# Vision Transformer

In this notebook we will be replicating the ["An Image is Worth 16x16 Words: Transformers for Image Recognition"](https://arxiv.org/abs/2010.11929) Vision Transformer from scratch.

Imports

In [49]:
import einops 
from tqdm.notebook import tqdm
import torch
from torch import nn
import torchvision
import torch.optim as optim
from torchvision.transforms import Compose,Resize,ToTensor,Normalize,RandomHorizontalFlip,RandomCrop

In [50]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'device : {device}')
#------------------------Hyperparams--------------#
patch_size = 16
latent_size = 768 
n_channels = 3
num_heads = 12
num_encoders = 12
dropout = 0.1
num_classes = 10
size = 224

epochs = 10
base_lr = 10e-3
weight_decay = 0.03
batch_size = 1

device : cuda:0


In [51]:
#-------------------------------Linear Inupt Projection---------------------------------#
class InputEmbedding(nn.Module):
    def __init__(self,patch_size=patch_size,n_channels=n_channels,device=device,latent_size=latent_size,batch_size=batch_size):
        """Initialize all the variables"""

        super(InputEmbedding,self).__init__()
        self.latent_size = latent_size
        self.batch_size = batch_size
        self.device = device
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.input_size = self.patch_size*self.patch_size*self.n_channels   # width*height*channels
        
        # Linear Projection
        self.LinearProjection = nn.Linear(self.input_size,self.latent_size)

        # Class Token
        self.class_token = nn.Parameter(torch.rand(self.batch_size,1,self.latent_size)).to(self.device)

        # Positional Embedding
        self.pos_embedding = nn.Parameter(torch.rand(self.batch_size,1,self.latent_size)).to(self.device)

    def forward(self,input_data):
        """
        Takes the input data in batch_size and creates patches. Applys Linear Projection on it and adds class_token and pos_embedding to each patch.

        Args : 
            Input : input_data.
            Output: input_patches and pos_embedding.
        """
        input_data = input_data.to(self.device)
        patches = einops.rearrange(
            input_data, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)',
            h1 = self.patch_size , w1 = self.patch_size
            )
        
        print(f'Input data : {input_data.shape}')
        print(f'Patches : {patches.shape}')

        linear_projection = self.LinearProjection(patches).to(self.device)
        b,n,_ = linear_projection.shape 
        linear_projection = torch.cat((self.class_token,linear_projection),dim=1)
        pos_embed = einops.repeat(self.pos_embedding, ' b 1 d -> b m d' , m = n+1)
        print(f'Linear Projection : {linear_projection.size()}')
        print(f'Positional : {pos_embed.size()}')
        
        linear_projection = linear_projection+pos_embed

        return linear_projection

In [52]:
# Understanding Code
test_input = torch.randn((1,3,224,224))
test_class = InputEmbedding().to(device)
embed_test = test_class(test_input)

Input data : torch.Size([1, 3, 224, 224])
Patches : torch.Size([1, 196, 768])
Linear Projection : torch.Size([1, 197, 768])
Positional : torch.Size([1, 197, 768])


In [53]:
#-----------------------------Encoder----------------------#
class EncoderBlock(nn.Module):
    def __init__(self,latent_size=latent_size,num_heads=num_heads,device=device,dropout= dropout):
        """
        Initialize variables for Encoder Layer.
        """
        super(EncoderBlock,self).__init__()

        self.latent_size = latent_size
        self.num_heads = num_heads
        self.device = device
        self.dropout = dropout

        # Normalization 
        self.normalize = nn.LayerNorm(self.latent_size) 
        self.multihhead = nn.MultiheadAttention(self.latent_size,self.num_heads,dropout=self.dropout)
        self.enc_MLP = nn.Sequential(
            nn.Linear(self.latent_size,self.latent_size*4),          # Given in the paper (output of linear is 4 times)
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_size*4,self.latent_size),          # Given in the paper (output is back to latent_size)  
            nn.Dropout(self.dropout)
        )

    def forward(self,embedded_patches):
        """
        Define : 
        Args: 
            Input : 
            Output : 
        """
        firstnorm_out = self.normalize(embedded_patches)
        attention_output = self.multihhead(firstnorm_out,firstnorm_out,firstnorm_out)[0]          # return a tuple from which we take only the 1st value.
        first_resudial = attention_output + embedded_patches

        secondnorm_out = self.normalize(first_resudial)
        mlp_output = self.enc_MLP(secondnorm_out)
        

        return mlp_output + first_resudial


In [54]:
test_encoder = EncoderBlock().to(device)
test_encoder(embed_test)

tensor([[[ 1.4165,  1.2501,  0.9870,  ...,  1.0960,  0.9106,  0.8004],
         [ 0.0430,  1.0837, -0.2006,  ...,  1.2386,  1.0093,  2.0220],
         [ 0.7353,  0.8299,  1.2798,  ..., -0.0305,  1.8974, -0.5246],
         ...,
         [ 0.1577,  0.4087,  0.3084,  ...,  0.5530,  1.5896,  0.0213],
         [ 0.3268,  0.6979, -0.2313,  ...,  2.1076,  0.3836, -0.2035],
         [-0.3212, -0.8197, -0.2331,  ...,  0.7070, -0.3289,  1.0611]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [61]:
#-------------------------ViT--------------------------#
class VisionTransfomer(nn.Module):
    def __init__(self, num_encoders=num_encoders, latent_size=latent_size, device=device, num_classes=num_classes, dropout=dropout):
        super(VisionTransfomer, self).__init__()
        self.num_encoder = num_encoders
        self.latent_size = latent_size
        self.device = device
        self.num_classes = num_classes
        self.dropout = dropout

        self.embedding = InputEmbedding()

        # Create the stack of encoders
        self.encStack = nn.ModuleList([EncoderBlock() for i in range(self.num_encoder)])

        self.MLP_head = nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(self.latent_size, self.latent_size),
            nn.Linear(self.latent_size, self.num_classes)
        )

    def forward(self, test_input):
        enc_output = self.embedding(test_input)

        for enc_layer in self.encStack:
            enc_output = enc_layer(enc_output)

        cls_token_embed = enc_output[:, 0]

        return self.MLP_head(cls_token_embed)

In [62]:
model = VisionTransfomer().to(device)
vit_output = model(test_input)
print(vit_output)
print(vit_output.size())

Input data : torch.Size([1, 3, 224, 224])
Patches : torch.Size([1, 196, 768])
Linear Projection : torch.Size([1, 197, 768])
Positional : torch.Size([1, 197, 768])
tensor([[-0.0344,  0.0879, -0.1353,  0.2790, -0.4280,  0.4329,  0.2900,  0.1555,
         -0.3983, -0.0051]], device='cuda:0', grad_fn=<AddmmBackward0>)
torch.Size([1, 10])
