## Training a Simple ViT using PyTorch [20 marks]

In this question, you will implement a vision-transformer-based image classification model using PyTorch.
Implement a basic version of the Vision Transformer (https://arxiv.org/pdf/2010.11929.pdf), which first divides an image into patches and then passes them through a stack of multi-head self-attention modules to perform classification. Refer to the paper for details. Note that classification is performed from the CLS token of the final transformer layer.
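For orientation, here is a compact sketch of that pipeline. It is not the notebook's implementation: it swaps in PyTorch's built-in `nn.TransformerEncoder` (pre-norm layers via `norm_first=True`) in place of hand-written attention and MLP blocks, and the class name `TinyViT` and its default sizes (embedding size 48 = 4·4·3, depth 2) are illustrative assumptions.

```python
import torch
import torch.nn as nn


class TinyViT(nn.Module):
    """Compact ViT sketch built on PyTorch's own encoder layers (not the notebook's hand-written blocks)."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=48,
                 num_heads=4, depth=2, num_classes=10):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # kernel = stride = patch size: cut non-overlapping patches and apply one shared linear map
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim) * 0.02)
        layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                           dim_feedforward=4 * embed_dim, activation="gelu",
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth, enable_nested_tensor=False)
        self.head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)     # (B, 1, D)
        x = torch.cat([cls, x], dim=1) + self.pos_embed     # (B, N+1, D)
        x = self.encoder(x)
        return self.head(x[:, 0])                           # logits from the CLS token only


logits = TinyViT()(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

The result is a `(2, 10)` tensor of logits, one row per image. Only `x[:, 0]`, the CLS position, reaches the classifier; the patch outputs of the last layer are discarded.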

[Experiment 1] Train this model on the CIFAR-10 dataset for 10-class classification. Keep the number of attention heads at 4 for all experiments. [6 points]

[Experiment 2] Try different patch sizes (e.g., 4x4, 8x8, 16x16). You can divide the image into overlapping as well as non-overlapping patches. [4 points]
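For experiment 2, the sequence length depends on both the patch size and the overlap. Assuming a patcher like the notebook's (a `Conv2d` with kernel size $p$, stride $p - o$ for an overlap of $o$ pixels, and no padding), each side of a 32x32 CIFAR-10 image yields $\lfloor (32 - p)/(p - o) \rfloor + 1$ patches. The helper `num_patches` below is hypothetical; it cross-checks the formula against real `Conv2d` output shapes. The positional-embedding table needs `num_patches + 1` rows (one extra for CLS).

```python
import torch
import torch.nn as nn


def num_patches(img_size: int, patch_size: int, overlap: int = 0) -> int:
    """Patches produced by a kernel of size patch_size and stride patch_size - overlap (no padding)."""
    stride = patch_size - overlap
    assert 0 < stride <= patch_size, "overlap must be in [0, patch_size)"
    per_side = (img_size - patch_size) // stride + 1
    return per_side * per_side


# Cross-check against the sequence length an actual Conv2d patcher produces
for p, o in [(4, 0), (8, 0), (16, 0), (8, 4), (16, 8)]:
    conv = nn.Conv2d(3, p * p * 3, kernel_size=p, stride=p - o)
    out = conv(torch.randn(1, 3, 32, 32)).flatten(2)  # (1, D, N)
    assert out.shape[-1] == num_patches(32, p, o)
    print(f"patch {p}x{p}, overlap {o}: {num_patches(32, p, o)} patches")
```

When $32 - p$ is not a multiple of the stride, the rightmost and bottom pixels are never covered; for example $p = 8$, $o = 3$ gives 5 patches per side spanning only 28 pixels.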

[Experiment 3] How does model performance change if you vary the number of attention heads? [4 points]
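For experiment 3, two properties of `nn.MultiheadAttention` are worth keeping in mind: `embed_dim` must be divisible by `num_heads` (each head works on `embed_dim // num_heads` dimensions), and the layer's parameter count does not depend on the number of heads, because the Q, K, V and output projections are full `embed_dim x embed_dim` matrices that are split across heads. The sketch below checks both for the 48-dimensional embedding that 4x4 patches give in this notebook.

```python
import torch.nn as nn

embed_dim = 4 * 4 * 3  # 48: the notebook's embedding size for 4x4 RGB patches

param_counts = {}
for heads in [1, 2, 3, 4, 6, 8, 12]:
    mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=heads, batch_first=True)
    param_counts[heads] = sum(p.numel() for p in mha.parameters())
    print(f"heads={heads:2d}  head_dim={embed_dim // heads:2d}  params={param_counts[heads]}")

# A head count that does not divide embed_dim is rejected at construction time
rejected = False
try:
    nn.MultiheadAttention(embed_dim=embed_dim, num_heads=5)
except (AssertionError, ValueError) as err:
    rejected = True
    print("num_heads=5 rejected:", err)
```

Changing the head count therefore trades more, narrower heads against fewer, wider ones at a fixed parameter budget in the attention layers, so accuracy differences reflect that split rather than model size.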

[Experiment 4] Perform classification by using the CLS token from different layers of the model. [6 points]
Create a detailed report of all the experiments and analyses.
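For experiment 4, one way to read out the CLS token after every encoder layer is to run the blocks one at a time, rather than calling the whole stack at once, and attach a separate classifier to each layer's CLS output. The sketch below assumes the encoder has already been trained and is frozen, so each per-layer head is a linear probe. It uses built-in `nn.TransformerEncoderLayer` blocks as stand-ins for the notebook's `Transformer_Encoder`, and the names `cls_heads` and `per_layer_logits` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, num_heads, depth, num_classes = 48, 4, 6, 10

# Stand-in for a trained encoder (assumption): built-in pre-norm blocks, frozen and in eval mode
blocks = nn.ModuleList([
    nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=4 * embed_dim,
                               batch_first=True, norm_first=True)
    for _ in range(depth)
]).eval().requires_grad_(False)

# One classifier (linear probe) per layer, trained on that layer's CLS token
cls_heads = nn.ModuleList([
    nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
    for _ in range(depth)
])


def per_layer_logits(tokens: torch.Tensor) -> list:
    """tokens: (B, N+1, D) with CLS at index 0. Returns one (B, num_classes) logit tensor per layer."""
    outputs = []
    x = tokens
    for block, head in zip(blocks, cls_heads):
        x = block(x)                   # output of this encoder layer
        outputs.append(head(x[:, 0]))  # classify from this layer's CLS token
    return outputs


# One training step for all probes on a dummy batch (65 tokens = 64 patches + CLS)
tokens = torch.randn(8, 65, embed_dim)
labels = torch.randint(0, num_classes, (8,))
optimizer = torch.optim.Adam(cls_heads.parameters(), lr=1e-3)
logits = per_layer_logits(tokens)
loss = sum(F.cross_entropy(l, labels) for l in logits)  # encoder is frozen, so each probe only sees its own term
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because the notebook stores its blocks in an `nn.Sequential`, the same loop can iterate over `vit.transformer_encoder` directly. An alternative design is to train a separate model whose classifier reads the CLS token of layer k (that is, an encoder truncated to k layers); probing is cheaper but only measures what a fixed encoder has already learned.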


1. Convert the image into patches.
2. Vectorize each patch: $d_1 \times d_2 \times d_3 \rightarrow d_1 d_2 d_3 \times 1$ (one vector per patch).
3. Apply a dense layer with no non-linear activation to these vectors. All patches share the same $W$ and $b$: $z_i = W^\top x_i + b$, where $x_i$ is the vector for patch $i$. Positional embeddings are added to these outputs, giving $z_1, z_2, \dots, z_n$.
4. The CLS token is passed through an embedding layer to create $z_0$ (same shape as the other $z_i$). The transformer output at this position is used for classification.
5. $z_0, \dots, z_n$ are the inputs to multi-head self-attention, which outputs $n+1$ vectors.
6. Apply a dense layer to each of the $n+1$ output vectors.
7. Stack as many multi-head self-attention + dense layer blocks as desired (together called the transformer encoder network).
8. At the last layer, take the output vector at the CLS position, $c_0$, and feed it to a softmax classifier. Its output $p$ has one entry per class.
9. During training, the loss is the cross-entropy between $p$ and the ground-truth label.
10. Update the parameters with gradient descent.
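Steps 1 to 4 can be sketched directly with `torch.nn.functional.unfold`, which cuts out non-overlapping patches and flattens each one, followed by a single `nn.Linear` shared by all patches. The last few lines check that a `Conv2d` patcher with kernel size and stride both equal to the patch size (the approach used in the notebook with no overlap) is the same operation: copying the linear layer's weights into the convolution gives identical patch embeddings. Batch size 2 and the zero-initialised CLS and positional parameters are illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

B, C, H, W, p = 2, 3, 32, 32, 4
D = p * p * C                    # embedding size; the notebook also uses patch_size*patch_size*3
images = torch.randn(B, C, H, W)

# Steps 1-2: cut non-overlapping p x p patches and flatten each one to a p*p*C vector
patches = F.unfold(images, kernel_size=p, stride=p)  # (B, p*p*C, n)
patches = patches.transpose(1, 2)                    # (B, n, p*p*C), n = (H/p)*(W/p)

# Step 3: one dense layer shared by every patch, no activation: z_i = W^T x_i + b
proj = nn.Linear(p * p * C, D)
z = proj(patches)                                    # (B, n, D)

# Step 4 (+ positions): learnable CLS vector z_0 prepended to each sequence, then positional embeddings
cls = nn.Parameter(torch.zeros(1, 1, D))
pos = nn.Parameter(torch.zeros(1, z.shape[1] + 1, D))
tokens = torch.cat([cls.expand(B, -1, -1), z], dim=1) + pos  # (B, n+1, D)

# Equivalence check: a Conv2d with kernel = stride = p and the same weights gives the same embeddings
conv = nn.Conv2d(C, D, kernel_size=p, stride=p)
with torch.no_grad():
    conv.weight.copy_(proj.weight.view(D, C, p, p))  # unfold orders features as (channel, row, col)
    conv.bias.copy_(proj.bias)
conv_z = conv(images).flatten(2).transpose(1, 2)     # (B, n, D)
print("conv == unfold + linear:", torch.allclose(conv_z, z, atol=1e-5))
```

Steps 8 and 9 merge in practice: the classifier can output raw logits, because `nn.CrossEntropyLoss` (used in the training cell) applies log-softmax internally, so no explicit softmax layer is needed.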

Potential Hyperparameters
1. Patch size
2. Overlap between patches (overlapping vs. non-overlapping)
3. Number of attention heads





In [1]:
import matplotlib.pyplot as pyplot
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import os
from torch.utils.data import DataLoader as dataloader
import torchvision.datasets as datasets
import torch.optim as optim
import gc




In [2]:
device="cuda" if torch.cuda.is_available() else "cpu"
print(device)


cuda


In [3]:
    
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize to [-1, 1]
])

train_data=datasets.CIFAR10(root="/root/DLVA Assig-2/Assignment2/Assignment2/dataset", train=True, download=True,transform=transform)
test_data=datasets.CIFAR10(root="/root/DLVA Assig-2/Assignment2/Assignment2/dataset", train=False, download=True,transform=transform)

train_dataloader=dataloader(train_data, batch_size=300, shuffle=True)
test_dataloader=dataloader(test_data, batch_size=64, shuffle=False)  # evaluation does not need shuffling



Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Hyperparameters
epochs=25
patch_size=4
overlapping=0  # overlap between adjacent patches, in pixels (0 = non-overlapping)
num_heads=4


In [5]:
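class patch(nn.Module):
    # Splits an image into patches and projects each one to a vector of size patch_size*patch_size*3.
    # A Conv2d with kernel_size=patch_size and stride=patch_size-overlapping applies the same
    # linear map (shared W and b) to every patch in one pass.
    def __init__(self,overlapping,patch_size,in_channels:int=3):
        super(patch,self).__init__()
        self.patcher=nn.Conv2d(in_channels=in_channels,out_channels=patch_size*patch_size*3,kernel_size=patch_size,stride=patch_size-overlapping,padding=0)
        self.Flatten=nn.Flatten(start_dim=2,end_dim=3)

    def forward(self,x):
        x_patch=self.patcher(x)               # (B, D, H', W')
        x_flat=self.Flatten(x_patch)          # (B, D, N)
        return torch.permute(x_flat,(0,2,1))  # (B, N, D)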
        
class multi_head_attention(nn.Module):
    def __init__(self,patch_size,num_heads):
        super(multi_head_attention,self).__init__()

        self.layer_norm=nn.LayerNorm(normalized_shape=patch_size*patch_size*3)
        self.multihead_attn=nn.MultiheadAttention(embed_dim=patch_size*patch_size*3,num_heads=num_heads,batch_first=True)

    def forward(self,x):
        x=self.layer_norm(x)  # pre-norm: LayerNorm before attention, as in the ViT paper
        attn_output,_=self.multihead_attn(query=x,key=x,value=x,need_weights=False)
        return attn_output

class MLP(nn.Module):
    def __init__(self,patch_size,mlp_size:int=3072,dropout:float=0.1):
        super(MLP,self).__init__()
        self.layer_norm=nn.LayerNorm(normalized_shape=patch_size*patch_size*3)
        self.mlp=nn.Sequential(
            nn.Linear(patch_size*patch_size*3,mlp_size),
            nn.GELU(),
            nn.Dropout(p=dropout,inplace=True),
            nn.Linear(mlp_size,patch_size*patch_size*3),
            nn.Dropout(p=dropout,inplace=True)
        )

    def forward(self,x):
        x=self.layer_norm(x)
        x=self.mlp(x)
        return x
        
class Transformer_Encoder(nn.Module):
    def __init__(self,patch_size=patch_size,num_heads=num_heads,mlp_size=3072,mlp_dropout:float=0.1):
        super(Transformer_Encoder,self).__init__()
        self.msa=multi_head_attention(patch_size,num_heads)
        self.mlp=MLP(patch_size,mlp_size=mlp_size,dropout=mlp_dropout)

    def forward(self,x):
        x=self.msa(x) + x
        x=self.mlp(x) + x
        return x


class ViT(nn.Module):
    # overlapping defaults to the global hyperparameter so that ViT(patch_size,num_heads) still works
    def __init__(self,patch_size,num_heads,img_size:int=32,in_channels:int=3,overlapping:int=overlapping,
                 mlp_dropout:float=0.1,mlp_size=3072,embedding_dropout:float=0.1,num_transformer_layers:int=10):
        super(ViT,self).__init__()

        # Patches per side for kernel patch_size and stride patch_size-overlapping (no padding);
        # equals img_size//patch_size when patches do not overlap
        patches_per_side=(img_size-patch_size)//(patch_size-overlapping)+1
        self.num_patches=patches_per_side**2
        embed_dim=patch_size*patch_size*3
        self.class_embedding=nn.Parameter(data=torch.randn(1,1,embed_dim),requires_grad=True)
        self.position_embedding=nn.Parameter(data=torch.randn(1,self.num_patches+1,embed_dim),requires_grad=True)
        self.embedding_drop=nn.Dropout(p=embedding_dropout)
        self.patch_embedding=patch(in_channels=in_channels,overlapping=overlapping,patch_size=patch_size)
        self.transformer_encoder=nn.Sequential(*[Transformer_Encoder(patch_size=patch_size,num_heads=num_heads,mlp_size=mlp_size,mlp_dropout=mlp_dropout) for _ in range(num_transformer_layers)])
        self.classifier=nn.Sequential(nn.LayerNorm(normalized_shape=embed_dim),nn.Linear(in_features=embed_dim,out_features=10))

    def forward(self,x):
        batch_size=x.shape[0]
        cls_token=self.class_embedding.expand(batch_size,-1,-1)  # (B, 1, D)
        x=self.patch_embedding(x)                               # (B, N, D)
        x=torch.cat((cls_token,x),dim=1)                        # (B, N+1, D)
        x=self.position_embedding+x
        x=self.embedding_drop(x)
        x=self.transformer_encoder(x)
        x=self.classifier(x[:,0])                               # logits from the CLS token
        return x

# engine.py provides train(); add its folder to the import path
# (the imp module used previously is deprecated and removed in Python 3.12)
import sys
sys.path.append('./pytorch_deep_learning/going_modular/going_modular')
import engine



In [6]:


vit=ViT(patch_size,num_heads).to(device)
optimizer = optim.Adam(vit.parameters(), lr=5e-3)
loss=nn.CrossEntropyLoss()

torch.cuda.empty_cache()
gc.collect()

results =engine.train( model=vit,
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       optimizer=optimizer,
                       loss_fn=loss,
                       epochs=epochs,
                       device=device)
    


Epoch: 1 | train_loss: 2.0756 | train_acc: 0.2007 | test_loss: 1.8880 | test_acc: 0.2722
Epoch: 2 | train_loss: 1.7820 | train_acc: 0.3341 | test_loss: 1.6573 | test_acc: 0.3883
Epoch: 3 | train_loss: 1.6065 | train_acc: 0.4092 | test_loss: 1.5219 | test_acc: 0.4399
Epoch: 4 | train_loss: 1.5299 | train_acc: 0.4386 | test_loss: 1.4715 | test_acc: 0.4576
Epoch: 5 | train_loss: 1.4756 | train_acc: 0.4580 | test_loss: 1.4297 | test_acc: 0.4827
Epoch: 6 | train_loss: 1.4272 | train_acc: 0.4767 | test_loss: 1.4056 | test_acc: 0.4924
Epoch: 7 | train_loss: 1.3971 | train_acc: 0.4912 | test_loss: 1.3623 | test_acc: 0.5015
Epoch: 8 | train_loss: 1.3630 | train_acc: 0.5028 | test_loss: 1.3274 | test_acc: 0.5164
Epoch: 9 | train_loss: 1.3448 | train_acc: 0.5108 | test_loss: 1.3094 | test_acc: 0.5209
Epoch: 10 | train_loss: 1.3184 | train_acc: 0.5211 | test_loss: 1.3084 | test_acc: 0.5239
Epoch: 11 | train_loss: 1.3016 | train_acc: 0.5271 | test_loss: 1.3015 | test_acc: 0.5244
Epoch: 12 | train_l