In [11]:
import torch
import torchvision
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"


height = 224        #H
width = 224         #W
color_channels = 3  #C
patch_size = 16     #P
number_of_patches = int(height*width/patch_size**2) #N = HW/P**2

input_shape = (height , width , color_channels)
output_shape = (number_of_patches , patch_size**2*color_channels)

print(f"The input shape (single 2D Image) : {input_shape}")
print(f"The output shape (flattened into patches) : {output_shape}")

The input shape (single 2D Image) : (224, 224, 3)
The output shape (flattened into patches) : (196, 768)


In [2]:
#Create Patch Embedding Layer
class PatchEmbedding(nn.Module):
    def __init__(self , patch_size = 16 , in_channels = 3 , out_channels = 768):
        super(PatchEmbedding , self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch = nn.Conv2d(in_channels = in_channels , 
                              out_channels = out_channels , 
                              kernel_size = patch_size  ,
                              stride = patch_size , padding = 0)
        #Flatten
        self.flatten = nn.Flatten(start_dim = 2 , end_dim = 3)
    def forward(self , x):
        return self.flatten(self.patch(x)).permute(0,2,1) #-->[batch_size , N , P**2*C]

#Testing
input_image = torch.Tensor(1 , 3 , 224 , 224)
patch = PatchEmbedding(patch_size = 16 , in_channels = 3 , out_channels = 768)
print(f'The input image has the shape of {input_image.shape} -> [Batch_size , Color_channels , Height , Width]')
print("The image gets converted to 1D sequence of flattened 2d patches with size N x (P^2 x C)")
patch = PatchEmbedding(patch_size = 16 , in_channels = 3 , out_channels = 768)
patched_img = patch(input_image)
print(f'Output Shape : {patched_img.shape}')

The input image has the shape of torch.Size([1, 3, 224, 224]) -> [Batch_size , Color_channels , Height , Width]
The image gets converted to 1D sequence of flattened 2d patches with size N x (P^2 x C)
Output Shape : torch.Size([1, 196, 768])


In [13]:
#Create class token
batch_size = patched_img.shape[0]
embedding_dimension = patched_img.shape[-1]
class_token = nn.Parameter(torch.randn(batch_size, 1 , embedding_dimension),requires_grad = True)
print(class_token[:,:,:10])
print(f'Class token shape : {class_token.shape}')

#Add class token to the beginning of patch embedding
patch_embedding_with_class_token = torch.cat((class_token , patched_img), dim = 1)
print(f"Embedding patches with class token : {patch_embedding_with_class_token.shape} -> [Batch_size , num_patches , embedding_size]")

tensor([[[ 0.9788, -0.0973,  0.7067, -0.2775, -0.1617,  0.8443, -1.6049,
          -0.5992, -1.5365,  0.4258]]], grad_fn=<SliceBackward0>)
Class token shape : torch.Size([1, 1, 768])
Embedding patches with class token : torch.Size([1, 197, 768]) -> [Batch_size , num_patches , embedding_size]


In [14]:
#According to the paper position embedding has the shape of (N+1)*D
position_embedding = nn.Parameter(torch.randn(1 , number_of_patches+1 , embedding_dimension),requires_grad = True)
print(position_embedding[:,:,:10])
print(f"Position embedding shape : {position_embedding.shape}")

#Add position embedding to the beginning of patch embedding with class_token
class_patch_embedding = position_embedding + patch_embedding_with_class_token
print(f'Position Embedding with class token {class_patch_embedding}')
print(f'Shape {class_patch_embedding.shape}')

tensor([[[-0.8642, -0.0402, -0.0373,  ...,  0.6045,  0.9637,  1.0230],
         [-0.5235,  0.9575, -1.4920,  ..., -1.0307,  0.4456, -0.6138],
         [-0.4988,  0.9267,  0.0124,  ...,  1.2057, -0.6846, -1.5696],
         ...,
         [-0.1430,  1.2397, -0.8184,  ..., -1.2074, -1.8293,  0.9498],
         [ 1.6196, -2.5626,  0.9129,  ...,  0.9080,  0.4245,  1.2222],
         [ 1.1143, -2.7331, -1.4910,  ..., -0.9634,  1.6950,  0.3367]]],
       grad_fn=<SliceBackward0>)
Position embedding shape : torch.Size([1, 197, 768])
Position Embedding with class token tensor([[[ 1.1466e-01, -1.3756e-01,  6.6942e-01,  ...,  2.0557e+00,
          -3.0662e-02,  1.3205e+00],
         [-4.8788e-01,  9.2873e-01, -1.5038e+00,  ...,  5.5978e-01,
          -3.0382e-01, -7.6650e-01],
         [-4.6325e-01,  8.9795e-01,  5.0689e-04,  ..., -1.0950e+00,
          -1.8274e-01,  2.6580e+00],
         ...,
         [-1.0736e-01,  1.2109e+00, -8.3023e-01,  ..., -1.3148e+00,
          -6.5303e-01, -1.0779e-01],
  

In [15]:
#Equation-2 : Multi-head Attention Block
class Msa(nn.Module):
    def __init__(self , embedding_dimension = 768 , num_heads = 12):
        super(Msa , self).__init__()
        self.embedding_dimension = embedding_dimension
        self.num_heads = num_heads
        self.layer_norm = nn.LayerNorm(normalized_shape = embedding_dimension)
        self.msa = nn.MultiheadAttention(embed_dim = embedding_dimension , 
                                        num_heads = num_heads , batch_first = True)
    def forward(self , x):
        x = self.layer_norm(x)
        attn_output , _ = self.msa(query = x , key = x , value = x , need_weights = False)
        return attn_output


msa = Msa(embedding_dimension = 768 , num_heads = 12)
patched_img_with_msa = msa(class_patch_embedding)
print(f'Input shape of MSA : {class_patch_embedding.shape}')
print(f'Output shape of MSA : {patched_img_with_msa.shape}')

Input shape of MSA : torch.Size([1, 197, 768])
Output shape of MSA : torch.Size([1, 197, 768])


In [16]:
#Equation-3 : Multi layer Perceptron Block --- Contains two layers wth a GELU non-linearity
class MLPBlock(nn.Module):
    def __init__(self, embedding_dimension = 768 , mlp_size = 3072):
        super(MLPBlock , self).__init__()
        self.embedding_dimension = embedding_dimension
        self.mlp_size = mlp_size
        self.layer_norm = nn.LayerNorm(normalized_shape = embedding_dimension)
        self.mlp = nn.Sequential(
            nn.Linear(in_features = 768 , out_features = 3072),
            nn.GELU(),
            nn.Linear(in_features = 3072 , out_features = embedding_dimension)
        )
    def forward(self , x):
        return self.mlp(self.layer_norm(x))


mlp = MLPBlock(embedding_dimension = 768 , mlp_size = 3072)
patched_image_through_mlp = mlp(patched_img_with_msa)
print(f'Input for MLP block : {patched_img_with_msa.shape}')
print(f'Output for MLP block : {patched_image_through_mlp.shape}')

Input for MLP block : torch.Size([1, 197, 768])
Output for MLP block : torch.Size([1, 197, 768])


In [17]:
#Create transformer encoder
class TransformerEncoder(nn.Module):
    def __init__(self , embedding_dimension = 768 , num_heads = 12 , mlp_size = 3072):
        super(TransformerEncoder , self).__init__()
        self.embedding_dimension = embedding_dimension 
        self.num_heads = num_heads
        self.mlp_size = mlp_size
        
        self.msa_block = Msa(embedding_dimension = embedding_dimension , num_heads = num_heads)
        self.mlp_block = MLPBlock(embedding_dimension = embedding_dimension , mlp_size = mlp_size)
    def forward(self , x):
        #Create Residual Connection (input = output + input of previous layer)
        x = self.msa_block(x) + x
        x = self.mlp_block(x) + x
        return x

In [10]:
#Assembling ViT
class Vision_transformer(nn.Module):
    def __init__(self , patch_size = 16 , embedding_dimension = 768 , num_heads = 12 ,
                mlp_size = 3072 , num_layers = 12 , img_size = 224 , num_classes = 1000):
        super().__init__()
        
        num_patches = int((img_size*img_size)//patch_size**2)

        self.class_embedding = nn.Parameter(torch.randn(1,1,embedding_dimension),requires_grad = True)
        self.position_embedding = nn.Parameter(torch.randn(1 , num_patches+1 , embedding_dimension),requires_grad = True)
        self.embedding_dropout = nn.Dropout(p = 0.1)
        self.patch_embedding = PatchEmbedding(patch_size = patch_size , in_channels = 3 , out_channels = embedding_dimension)
        self.transformer_encoder = nn.Sequential(*[TransformerEncoder(embedding_dimension = embedding_dimension,
                                                                     num_heads = num_heads,
                                                                     mlp_size = mlp_size) for i in range(num_layers)])
        #MLP head
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape = embedding_dimension),
            nn.Linear(in_features = embedding_dimension , out_features = num_classes)
        )
    def forward(self , x):
        batch_size = x.shape[0]
        class_token = self.class_embedding.expand(batch_size , -1 ,-1)
        x = self.patch_embedding(x)
        #Concat class token with patch embedding (equation 1)
        x = torch.cat((class_token , x) , dim = 1)
        #Add position embedding with class token to the beginning (equation 1)
        x = self.position_embedding+x
        #Run embedding Dropout (Appendix B.1)
        x = self.embedding_dropout(x)
        #Equation 2 , 3
        x = self.transformer_encoder(x)
        #Equation 4
        x = self.classifier(x[:,0])
        return x