<a href="https://colab.research.google.com/github/LokeshSreenathJ/Transformers/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.utils.data as dataloader
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

Data transformation: convert each PIL image to a tensor

In [2]:
# Preprocessing pipeline applied to every sample: turn the PIL image into a tensor.
data_transformation = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [3]:
# Build the MNIST training split first, then the held-out split; both share
# the same download location and preprocessing pipeline.
train_dataset, val_dataset = (
    torchvision.datasets.MNIST(
        root="./data",
        train=is_train_split,
        download=True,
        transform=data_transformation,
    )
    for is_train_split in (True, False)
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 12.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 344kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.18MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.48MB/s]


Hyperparameters and configuration

In [4]:
# --- Input images and patching ---
img_size = 28       # side length of the square input images
num_channels = 1    # grayscale images carry a single channel
patch_size = 7      # side length of each square patch
num_patches = (img_size // patch_size) ** 2  # patches per image: a 4 x 4 grid for these sizes

# --- Model ---
token_dim = 32          # length of the vector each patch is projected to; free to tune
num_heads = 4           # attention heads per encoder block
transfomer_blocks = 4   # number of stacked encoder blocks (spelling kept: other cells use this name)
mlp_hidden_dim = 64     # hidden width of the encoder MLP; free to tune
num_clases = 10         # output classes (spelling kept: other cells use this name)

# --- Training ---
batch_size = 64
learning_rate = 3e-4
epochs = 5

In [5]:
# Use the shared `batch_size` setting instead of repeating the literal 64, so
# changing the configuration cell actually changes the loaders.
# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = dataloader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = dataloader.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Part 1 of ViT : Patch Embedding

In [6]:
class PatchEmbedding(nn.Module):
  """Cut an image batch into non-overlapping patches and embed each one as a token.

  A convolution whose kernel size and stride both equal ``patch_size`` applies
  one learned projection per patch, giving ``token_dim`` features for each.
  """

  def __init__(self):
    # Initialise nn.Module's own state before registering any sub-module.
    super().__init__()
    patch_projection = nn.Conv2d(
        in_channels=num_channels,
        out_channels=token_dim,
        kernel_size=patch_size,
        stride=patch_size,
    )
    self.patch_embed = patch_projection

  def forward(self, x):
    """Embed the patches of ``x``.

    Args:
      x: image batch, expected as (batch, num_channels, height, width).

    Returns:
      Tensor shaped (batch, patches, token_dim) for such an input.
    """
    feature_map = self.patch_embed(x)
    # Merge the spatial grid into one patch axis, then move it ahead of the feature axis.
    return feature_map.flatten(start_dim=2).transpose(1, 2)

Part 2 of ViT : Transformer Encoder

In [7]:
class TransformerEncoder(nn.Module):
  """One pre-norm transformer encoder block.

  Applies LayerNorm -> multi-head self-attention -> residual add, followed by
  LayerNorm -> two-layer GELU MLP -> residual add. Layer sizes come from the
  notebook-level ``token_dim``, ``num_heads`` and ``mlp_hidden_dim``.
  """

  def __init__(self):
    super().__init__()
    self.layernorm1 = nn.LayerNorm(token_dim)
    self.layernorm2 = nn.LayerNorm(token_dim)
    # batch_first=True is required: tokens reach this block as
    # (batch, tokens, token_dim). With the default (batch_first=False) the
    # module treats axis 0 as the sequence axis, so attention would mix the
    # different images of a batch instead of the patches of a single image.
    self.multihead_attention = nn.MultiheadAttention(token_dim, num_heads=num_heads, batch_first=True)
    self.mlp = nn.Sequential(
        nn.Linear(token_dim, mlp_hidden_dim),
        nn.GELU(),
        nn.Linear(mlp_hidden_dim, token_dim)
    )

  def forward(self,x):
    """Run the block on a token tensor.

    Args:
      x: tokens shaped (batch, tokens, token_dim).

    Returns:
      Tensor with the same shape as ``x``.
    """
    # Self-attention sub-layer: normalise first, then add the skip connection.
    residual1 = x
    x = self.layernorm1(x)
    # Only the attended values are used, so skip computing the attention weights.
    x = self.multihead_attention(x, x, x, need_weights=False)[0]
    x = x + residual1

    # Feed-forward sub-layer with its own normalisation and skip connection.
    residual2 = x
    x = self.layernorm2(x)
    x = self.mlp(x)
    x = x + residual2

    return x




Part 3 of ViT : MLP Classification head

In [8]:
class MLPHead(nn.Module):
  """Classification head: layer-normalise the features, then map them to class scores."""

  def __init__(self):
    super().__init__()
    feature_norm = nn.LayerNorm(token_dim)
    classifier = nn.Linear(token_dim, num_clases)
    self.layernorm = feature_norm
    self.mlp = classifier

  def forward(self, x):
    """Return one score per class for features whose last axis has size ``token_dim``."""
    # The result is raw logits; no softmax is applied here.
    return self.mlp(self.layernorm(x))


Parts 1, 2 and 3 combined: the full Vision Transformer

In [None]:
class VisionTransformer(nn.Module):
  def __init__(self):
    super().__init__()
    self.patch_embedding = PatchEmbedding()
    self.cls_token = nn.Parameter(torch.randn(1,1,token_dim))
    self.position_encoding = nn.Parameter(torch.randn(1,num_patches+1,token_dim))
    self.transformer_blocks = nn.Sequential(*(TransformerEncoder() for _ in range(transfomer_blocks)))
    self.mlphead = MLPHead()

  def forward(self,x):
    x = self.patch_embedding(x)
    x =