Replication of the paper *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale* (Dosovitskiy et al.) in PyTorch.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torchsummary import summary
from torch import nn
from PIL import Image
from torchvision import transforms , datasets
from torch.utils.data import DataLoader

In [None]:
# Installed PyTorch version (shown as the cell output).
(torch.__version__)

In [None]:
# Installed torchvision version (shown as the cell output).
(torchvision.__version__)

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

## Dataset Prep

In [None]:
# Dataset root; each split holds one sub-folder per class (e.g. train/pizza/5764.jpg).
dataset_root = "../datasets/pizza_steak_sushi"
train_dir = f"{dataset_root}/train"
test_dir = f"{dataset_root}/test"

In [None]:
# Peek at one raw training image to see its native size before any resizing.
sample_path = "../datasets/pizza_steak_sushi/train/pizza/5764.jpg"
image = Image.open(sample_path)
print("image_size" , image.size)
image

In [None]:
# Side length (pixels) every image is resized to before entering the model.
img_size = 224

# Resize to a fixed square, then convert the PIL image to a tensor.
resize_step = transforms.Resize((img_size, img_size))
to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([resize_step, to_tensor_step])

In [None]:
batch_size = 32

# Build the datasets from the `train_dir` / `test_dir` constants so each path is defined once.
# ImageFolder derives the class labels from the sub-directory names of each split.
train_data = datasets.ImageFolder(train_dir, transform=transform)
test_data = datasets.ImageFolder(test_dir, transform=transform)

In [None]:
# Settings shared by both loaders; only shuffling differs between train and test.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

train_dataloader = DataLoader(dataset=train_data, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(dataset=test_data, shuffle=False, **loader_kwargs)

train_dataloader , test_dataloader

In [None]:
# Class names in label-index order, as ImageFolder discovered them in the train split.
classes = list(train_data.classes)
classes

In [None]:
# batch of images
image_batch , label_batch = next(iter(train_dataloader))
print("image_batch" , image_batch.shape)
print("label_batch" , label_batch.shape)
# get single image from batch
image , label = image_batch[0] , label_batch[0]

image.shape , label

In [None]:
# Show every image of the sampled batch in a square grid.
# Convert each (C, H, W) tensor to (H, W, C) for matplotlib; a distinct loop name
# keeps the `image` variable from the previous cell intact.
img = [img_chw.permute(1, 2, 0) for img_chw in image_batch]

# Use the real batch length: the final batch of a loader can be smaller than batch_size.
n_images = len(img)
grid_size = int(np.ceil(np.sqrt(n_images)))  # smallest square grid that fits every image

# squeeze=False keeps `axs` 2-D even for a 1x1 grid, so axs.flat always works.
fig, axs = plt.subplots(grid_size, grid_size, figsize=(20, 20), squeeze=False)

for idx, ax in enumerate(axs.flat):
    if idx < n_images:
        # RGB data: a colormap would be ignored, so none is passed.
        ax.imshow(img[idx], aspect='auto')
    ax.axis('off')
plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()

## Model

### Patch Embedding

In [None]:
# images are 224 x 224 x 3
height = 224 # H
width = 224 # W
channels = 3 # C
patch_size = 16 # P

# N = (H / P) * (W / P). Integer division per axis matches what a Conv2d with
# kernel = stride = P actually produces, even when H or W is not a multiple of P
# (truncating H * W / P**2 would over-count in that case).
numb_patches = (height // patch_size) * (width // patch_size)
print(f"Number of Patches (N) with image resolution {height}x{width} is {numb_patches} patches")

In [None]:
# Per-image shape entering the patch-embedding layer: (H, W, C).
embedding_layer_input_shape = (height, width, channels)

# Per-image output: numb_patches (196 for 224x224 with P=16) patches, each flattened
# to P**2 * C values.
patch_vector_len = patch_size ** 2 * channels
embedding_layer_output_shape = (numb_patches, patch_vector_len)

print(f"Input Image shape : {embedding_layer_input_shape}")
print(f"Output embedded shape flattened to patches : {embedding_layer_output_shape}")

In [None]:
# Pick one sample from the batch and display it with its class name.
sample_idx = 5
image = image_batch[sample_idx]
label = label_batch[sample_idx]

plt.imshow(image.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib
plt.title(classes[label])
plt.axis("off")
plt.show()

In [None]:
# Visualise only the first row of patches: the top `patch_size` pixel rows at full width.

permuted_image = image.permute(1, 2, 0)  # (H, W, C) view for matplotlib

top_row_strip = permuted_image[:patch_size, :, :]
plt.figure(figsize=(patch_size, patch_size))
plt.imshow(top_row_strip)

In [None]:
img_size = 224
patch_size = 16

# Validate before computing anything derived from the ratio.
assert img_size % patch_size == 0 , "image size must be divisible by patch_size"
# Integer division keeps the count an int (usable for ncols, prints as "14").
num_patches = img_size // patch_size

print(f"number of patches per row: {num_patches}\nPatch size is {patch_size} x {patch_size} pixels")

# squeeze=False keeps `axs` 2-D even when there is only one patch per row.
fig, axs = plt.subplots(nrows=1,
                        ncols=num_patches,
                        figsize=(num_patches, num_patches),
                        sharex=True,
                        sharey=True,
                        squeeze=False)

# Walk across the top strip of the image one patch width at a time.
for i, patch in enumerate(range(0, img_size, patch_size)):
    ax = axs[0, i]
    ax.imshow(permuted_image[:patch_size, patch:patch + patch_size, :])
    ax.set_xlabel(i + 1)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()
    

In [None]:
img_size = 224
patch_size = 16

assert img_size % patch_size == 0 , "the img size should be divisible by patch size"
# Square image, so patches per row == patches per column; integer division keeps it an int.
num_patches = img_size // patch_size

# Adjacent f-strings concatenate without the stray indentation a backslash continuation embeds.
print(f"Number of patches per row: {num_patches}"
      f"\nNumber of patches per column: {num_patches}"
      f"\nTotal patches: {num_patches * num_patches}"
      f"\nPatch size: {patch_size} pixels x {patch_size} pixels")

# squeeze=False guarantees a 2-D `axs` even in the degenerate 1x1 case.
fig, axs = plt.subplots(
    nrows=num_patches,
    ncols=num_patches,
    figsize=(num_patches, num_patches),
    sharex=True,
    sharey=True,
    squeeze=False,
)

for i, patch_height in enumerate(range(0, img_size, patch_size)):
    for j, patch_width in enumerate(range(0, img_size, patch_size)):
        ax = axs[i, j]
        ax.imshow(permuted_image[patch_height:patch_height + patch_size,
                                 patch_width:patch_width + patch_size, :])
        ax.set_ylabel(i + 1,
                      rotation="horizontal",
                      horizontalalignment="right",
                      verticalalignment="center")
        ax.set_xlabel(j + 1)
        ax.set_xticks([])
        ax.set_yticks([])
        # Keep the row/column numbers only on the outer edge of the grid.
        ax.label_outer()

fig.suptitle(f"{classes[label]} -> Patchified", fontsize=16)
plt.show()

In [None]:
patch_size = 16
embed_dim = 768  # D: number of feature / activation maps = length of each patch embedding

# kernel_size == stride == patch_size: the kernel visits each non-overlapping P x P
# patch exactly once and projects it to an embed_dim-long vector.
conv2d = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size, padding=0)

In [None]:
# Re-display the chosen sample before convolving it: (C, H, W) -> (H, W, C) for matplotlib.
hwc_view = image.permute(1, 2, 0)
plt.imshow(hwc_view)
plt.title(classes[label])
plt.axis("off")

In [None]:
print("image shape before: " , image.shape)
# Conv2d expects a batched input (N, C, H, W). Store the batched copy separately so
# `image` stays 3-D (C, H, W) for the later cells that permute / unsqueeze it again.
image_batched = image.unsqueeze(0)
print("image shape after" , image_batched.shape)

img_conv = conv2d(image_batched)
print("image shape after Conv2D : " , img_conv.shape)
# output shape [batch_size , embed_dim , feature_map_height , feature_map_width]
# output shape [batch_size , embed_dim , feature_map_height , feature_map_width]

In [None]:
import random

# Sample feature-map (channel) indices from the full range of conv output channels.
# Reading the count from the tensor keeps it in sync with embed_dim.
num_feature_maps = img_conv.shape[1]
num_to_show = min(10, num_feature_maps)  # random.sample raises if k exceeds the population
random_indices = random.sample(range(num_feature_maps), k=num_to_show)

print(f"showing random convolutional feature maps from indices : {random_indices}" )

# squeeze=False keeps `axs` 2-D even if only one map is shown.
fig, axs = plt.subplots(nrows=1, ncols=len(random_indices), figsize=(12, 12), squeeze=False)

for i, idx in enumerate(random_indices):
    image_conv_feat_map = img_conv[:, idx, :, :]
    # detach() returns a tensor that shares storage but is cut off from the autograd graph;
    # .numpy() refuses tensors that require grad, so this is needed before plotting.
    axs[0, i].imshow(image_conv_feat_map.squeeze().detach().numpy())
    axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
plt.show()

In [None]:
# With the image turned into patch embeddings, flatten the spatial dims of the feature map.
print(f"shape of the output of the conv : {img_conv.shape} -> [batch , embedding_dim , feature_map_height, feature_map_width]")

# Default nn.Flatten collapses everything but the batch dim, e.g. (1, 768 * 14 * 14) -- too much.
print(nn.Flatten()(img_conv).shape)

# Flattening only dims 2..3 (height and width) keeps one row per embedding channel.
flatten = nn.Flatten(start_dim=2, end_dim=3)
t = flatten(img_conv)

print(t.shape)

In [None]:
# Recap: run the single sample through conv -> flatten -> permute, printing every shape.
hwc_view = image.permute(1, 2, 0)
plt.imshow(hwc_view)
plt.title(classes[label])
plt.axis(False)
print(f"original image shape: {image.shape}")

batched = image.unsqueeze(0)  # (N, C, H, W) with N = 1
image_conv = conv2d(batched)
print(f"shape after convolution: {image_conv.shape}")  # embed_dim feature maps, each 14 x 14

image_flatten = flatten(image_conv)
print(f"shape after flatten: {image_flatten.shape}")  # embed_dim rows, each a flattened 196-long map

print("desired shape : N x(P^2 * C) --> (196 , 768)")

# Swap the last two axes so each row is one patch's embedding vector.
image_final = image_flatten.permute(0, 2, 1)
print(f"Patch embedding final shape : {image_final.shape} ")

In [None]:
# Take embedding dimension 0 for every patch: this shows how one learned feature varies
# along the patch sequence (it is not a single patch's full embedding vector).
single_sample = image_final[:, :, 0]
print("shape of this sample : " , single_sample.shape)

feature_values = single_sample.detach().numpy()
plt.figure(figsize=(20, 20))
plt.imshow(feature_values)
plt.title(f"Flattened feature map shape: {image_final.shape}")
plt.axis("off")

Summing up everything so far in a single `PatchEmbedding` module

In [None]:
class PatchEmbedding(nn.Module):
    """
    Turns a 2D input image into a 1D sequence of learnable patch embedding vectors.

    Args:
    in_channels (int) : Number of channels in the input image. (Defaults to 3)
    patch_size (int) : Side length, in pixels, of each square patch. (Defaults to 16)
    embed_dim (int) : Number of feature maps, i.e. the length of each patch embedding. (Defaults to 768)

    forward(image):
        image: tensor of shape (batch, in_channels, height, width); height and width
               must both be divisible by patch_size.
        returns: tensor of shape
               (batch, (height // patch_size) * (width // patch_size), embed_dim).
        raises: AssertionError if height or width is not divisible by patch_size.
    """
    def __init__(self ,
                 in_channels: int = 3 ,
                 patch_size : int = 16 ,
                 embed_dim: int = 768 ):
        super().__init__()
        # Stored on the instance so forward() validates against this module's own
        # configuration instead of a notebook-global variable.
        self.patch_size = patch_size
        # kernel_size == stride == patch_size: each non-overlapping patch is projected once.
        self.conv2d = nn.Conv2d(in_channels= in_channels ,
                           out_channels = embed_dim ,
                           kernel_size= patch_size ,
                           stride = patch_size ,
                           padding = 0)

        # (B, D, H/P, W/P) -> (B, D, N): merge the two spatial dims into one patch axis.
        self.flatten = nn.Flatten(start_dim = 2 , end_dim = 3)

    def forward(self , image):
        # Check both spatial dims; a non-divisible size would silently drop border pixels.
        height, width = image.shape[-2], image.shape[-1]
        assert height % self.patch_size == 0 and width % self.patch_size == 0, (
            f"Input image size must be divisible by patch size, "
            f"image shape: {height}x{width}, patch size: {self.patch_size}"
        )
        patched_image = self.conv2d(image)
        flattened_image = self.flatten(patched_image)

        # (B, D, N) -> (B, N, D): one row per patch.
        return flattened_image.permute(0 , 2 , 1)

In [None]:
patchify = PatchEmbedding(in_channels=3, patch_size=16, embed_dim=768)

# The module works on batches, so give the single image a leading batch dimension of 1.
batched_image = image.unsqueeze(0)
print(f"Input image shape: {batched_image.shape}")
patch_embedded_image = patchify(batched_image)
print(f"Output patch embedding shape: {patch_embedded_image.shape}")

In [None]:
# Input sizes (C, H, W) for the summary; the batch dimension is passed separately.
random_input_image = ( 3, 224, 224)
random_input_image_error = ( 3, 250, 250) # will error: 250 is not divisible by patch_size, so forward() asserts

# torchsummary builds its dummy input on `device` (its own default is "cuda"), so pass
# the notebook's device explicitly to keep this cell working on CPU-only machines.
summary(PatchEmbedding().to(device),
        input_size=random_input_image, # try swapping this for "random_input_image_error"
        batch_size=batch_size,
        device=device)

In [None]:
print("patch_embedded_image shape " , patch_embedded_image.shape)

# Use a dedicated name for the number of images here. Reassigning `batch_size` would
# silently change the DataLoader/summary batch size if earlier cells are re-run.
num_images, embedding_dim = patch_embedded_image.shape[0], patch_embedded_image.shape[-1]

# Learnable class token, one per image: [batch_size, number_of_tokens, embedding_dimension]
class_token = nn.Parameter(torch.ones((num_images, 1, embedding_dim)),
                           requires_grad=True)

print(f"class_token: {class_token[: , : , :10]}")
print(f"class_token shape: {class_token.shape}")

In [None]:
# Prepend the class token to the patch sequence along the token dimension (dim=1).
updated_patch = torch.cat([class_token, patch_embedded_image], dim=1)

print(f"shape of updated patch embedding with class token: {updated_patch.shape}")

# Each image was represented by 196 patch tokens; with the class token it now has 197.

print(updated_patch)

## Position Embedding

$\mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D}$, where $N$ is the number of patches, the extra $+1$ is for the class token, and $D$ is the embedding dimension.

In [None]:
# Token sequence (class token + patch tokens) together with its shape.
(updated_patch, updated_patch.shape)


In [None]:
# Sequence length for the position embedding: N patch tokens (+1 class token below).
numb_patches = int(height * width / patch_size ** 2)

# Embedding width D: last axis of the (batch, tokens, D) token tensor.
embedding_dim = updated_patch.shape[-1]

# Learnable E_pos of shape (1, N + 1, D); the leading 1 broadcasts across the batch.
position_embedding = nn.Parameter(torch.ones(1, numb_patches + 1, embedding_dim),
                                  requires_grad=True)

position_embedding[:, :10, :10], position_embedding.shape


In [None]:
# Add positional information to every token; broadcasting applies E_pos to each image.
patch_with_position = torch.add(updated_patch, position_embedding)
(patch_with_position, patch_with_position.shape)