<a href="https://colab.research.google.com/github/OjasKhandelwal/Vision-Transformer-from-scratch/blob/main/Vision_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [80]:
import torch
import torchvision
import torch.utils.data as dataloader
import torch.nn as nn
import torch.nn.functional as F

In [81]:
# Transformation applied to every MNIST sample: PIL image -> float tensor.
# ToTensor yields a (C, H, W) tensor with pixel values scaled to [0, 1].
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)

In [82]:
# Load MNIST (downloaded into ./data on first run).
# Both splits must use the same `transform`: without it the validation split
# yields PIL images, which the DataLoader's default collate function cannot
# batch, so iterating `val_loader` would fail.
# NOTE: "validation" here is MNIST's official test split (train=False).
# `dataset` is kept as an alias of `train_dataset` in case later cells use it.
train_dataset = dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
validation_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [83]:
# Model / training configuration.
no_of_channels = 1  # MNIST is grayscale; an RGB dataset would use 3
num_classes = 10    # digits 0-9
img_size = 28       # MNIST images are 28x28 pixels
patch_size = 7      # 28 / 7 -> a 4x4 grid of patches
batch_size = 64
num_patches = (img_size // patch_size) ** 2  # 16 patch tokens per image
embedding_dim = 64
attention_heads = 4
transformer_blocks = 4
learning_rate = 0.001
num_epochs = 5
mlp_hidden_nodes = 128
seed = 42

# Sanity checks on the shape arithmetic the model cells rely on.
assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
assert embedding_dim % attention_heads == 0, "embedding_dim must be divisible by attention_heads"

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# repeat across Restart & Run All (some CUDA kernels may still be nondeterministic).
torch.manual_seed(seed)

In [84]:
#define batches
# Wrap the datasets in DataLoaders using the configured `batch_size`.
# Only the training split is shuffled. Evaluation metrics do not depend on
# sample order, so the validation loader keeps a fixed, deterministic order.
train_loader = dataloader.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = dataloader.DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

In [85]:
#Part 1: Patch Embedding
class PatchEmbedding(nn.Module):
  """Cut an image batch into non-overlapping square patches and embed each one.

  A Conv2d whose kernel_size equals its stride applies one shared linear
  projection to every patch. That is equivalent to flattening each patch and
  passing it through the same nn.Linear.

  Args (all optional; ``None`` falls back to the notebook-level config):
    in_channels: number of image channels; defaults to ``no_of_channels``.
    embed_dim: output width of each patch token; defaults to ``embedding_dim``.
    patch: patch side length in pixels; defaults to ``patch_size``.

  forward(x): x is (batch, in_channels, H, W). Returns
  (batch, (H // patch) * (W // patch), embed_dim), with patches ordered
  row-major over the patch grid.
  """
  def __init__(self, in_channels=None, embed_dim=None, patch=None):
    super().__init__()
    # Resolve defaults at construction time so edits to the config cell apply.
    in_channels = no_of_channels if in_channels is None else in_channels
    embed_dim = embedding_dim if embed_dim is None else embed_dim
    patch = patch_size if patch is None else patch
    self.patch_embedding = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch, stride=patch)

  def forward(self, x):
    # (B, C, H, W) -> (B, D, H // p, W // p)
    x = self.patch_embedding(x)
    # Flatten the patch grid into one axis: (B, D, N)
    x = x.flatten(2)
    # Token-major layout expected by the encoder: (B, N, D)
    return x.transpose(1, 2)

In [86]:
#Part 2: Transformer Encoder

class TransformerEncoder(nn.Module):
  """One pre-norm transformer encoder block.

  The block computes x + Attn(LN1(x)), then h + MLP(LN2(h)), where h is the
  result of the first step.

  Args (all optional; ``None`` falls back to the notebook-level config):
    embed_dim: token width; defaults to ``embedding_dim``.
    num_heads: attention heads; defaults to ``attention_heads``. nn.MultiheadAttention
      requires it to divide embed_dim.
    hidden_dim: width of the MLP hidden layer; defaults to ``mlp_hidden_nodes``.

  forward(x): x is (batch, tokens, embed_dim), because attention is built with
  batch_first=True. Returns a tensor of the same shape.
  """
  def __init__(self, embed_dim=None, num_heads=None, hidden_dim=None):
    super().__init__()
    # Resolve defaults at construction time so edits to the config cell apply.
    embed_dim = embedding_dim if embed_dim is None else embed_dim
    num_heads = attention_heads if num_heads is None else num_heads
    hidden_dim = mlp_hidden_nodes if hidden_dim is None else hidden_dim
    # Submodules are created in the original order, so seeded initialisation is unchanged.
    self.layer_norm1 = nn.LayerNorm(embed_dim)
    self.layer_norm2 = nn.LayerNorm(embed_dim)
    self.multihead_attention = nn.MultiheadAttention(embed_dim, num_heads=num_heads, batch_first=True)
    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, embed_dim)
        )

  def forward(self, x):
    # Self-attention sub-block (pre-norm). The attention-weight matrix is never
    # used, so need_weights=False stops PyTorch from computing it.
    normed = self.layer_norm1(x)
    attn_out, _ = self.multihead_attention(normed, normed, normed, need_weights=False)
    x = x + attn_out
    # Feed-forward sub-block (pre-norm).
    x = x + self.mlp(self.layer_norm2(x))
    return x


In [87]:
#Part 3: MLP Head for Classification

class MLP_Head(nn.Module):
  """Classification head: LayerNorm followed by a single Linear layer to logits.

  Args (all optional; ``None`` falls back to the notebook-level config):
    embed_dim: input feature width; defaults to ``embedding_dim``.
    n_classes: number of output logits; defaults to ``num_classes``.
  """
  def __init__(self, embed_dim=None, n_classes=None):
    super().__init__()
    # Resolve defaults at construction time so edits to the config cell apply.
    embed_dim = embedding_dim if embed_dim is None else embed_dim
    n_classes = num_classes if n_classes is None else n_classes
    self.layer_norm1 = nn.LayerNorm(embed_dim)
    self.mlp_head = nn.Linear(embed_dim, n_classes)

  def forward(self, x):
    # (..., embed_dim) -> (..., n_classes) raw logits; no softmax is applied here.
    return self.mlp_head(self.layer_norm1(x))



In [88]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from the modules defined earlier in this notebook.

    The forward pass runs these steps in order:
      1. PatchEmbedding turns the image batch into patch tokens.
      2. A learnable [CLS] token is prepended to the sequence.
      3. Learnable position embeddings are added.
      4. A stack of ``transformer_blocks`` TransformerEncoder layers is applied.
      5. MLP_Head is applied to the [CLS] position only.
    """

    def __init__(self):
        super().__init__()
        # Parameter creation order is significant for seeded initialisation; keep it fixed.
        self.patch_embedding = PatchEmbedding()
        # Single learnable [CLS] token, broadcast over the batch in forward().
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        # One learnable position vector per sequence slot: [CLS] plus every patch.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embedding_dim))
        encoder_layers = [TransformerEncoder() for _ in range(transformer_blocks)]
        self.transformer_encoder = nn.Sequential(*encoder_layers)
        self.mlp_head = MLP_Head()

    def forward(self, x):
        patch_tokens = self.patch_embedding(x)
        # Give every example in the batch its own view of the shared [CLS] token.
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        sequence = torch.cat((cls_tokens, patch_tokens), dim=1) + self.pos_embedding
        encoded = self.transformer_encoder(sequence)
        # Classify from the [CLS] slot (index 0) alone.
        return self.mlp_head(encoded[:, 0])


In [89]:
# Run on the GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = VisionTransformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


In [90]:
#training loop
# Each epoch makes one pass over train_loader, logs every 100th batch, then
# prints a per-sample average loss and accuracy for the whole epoch.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0   # sum of per-sample losses over the epoch
    correct_epoch = 0
    total_epoch = 0
    print(f"Epoch {epoch+1}/{num_epochs}")

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # `criterion` returns the batch *mean* (default reduction of
        # nn.CrossEntropyLoss). Weight it by batch size so the epoch average
        # stays per-sample even though the final batch is smaller.
        n_batch = target.size(0)
        total_loss += loss.item() * n_batch

        pred = output.argmax(dim=1)
        batch_correct = (pred == target).sum().item()  # computed once, reused below
        correct_epoch += batch_correct
        total_epoch += n_batch

        batch_acc = 100.0 * batch_correct / n_batch

        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx}/{len(train_loader)} - Loss: {loss.item():.4f}, Accuracy: {batch_acc:.2f}%")

    epoch_acc = 100.0 * correct_epoch / total_epoch
    avg_loss = total_loss / total_epoch
    print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}, Accuracy: {epoch_acc:.2f}%\n")


Epoch 1/5
Batch 0/938 - Loss: 2.6119, Accuracy: 12.50%
Batch 100/938 - Loss: 0.6576, Accuracy: 76.56%
Batch 200/938 - Loss: 0.3924, Accuracy: 89.06%
Batch 300/938 - Loss: 0.1959, Accuracy: 90.62%
Batch 400/938 - Loss: 0.2048, Accuracy: 92.19%
Batch 500/938 - Loss: 0.0532, Accuracy: 100.00%
Batch 600/938 - Loss: 0.1011, Accuracy: 98.44%
Batch 700/938 - Loss: 0.2106, Accuracy: 92.19%
Batch 800/938 - Loss: 0.1835, Accuracy: 95.31%
Batch 900/938 - Loss: 0.1226, Accuracy: 95.31%
Epoch 1 - Avg Loss: 0.3668, Accuracy: 88.22%

Epoch 2/5
Batch 0/938 - Loss: 0.1167, Accuracy: 96.88%
Batch 100/938 - Loss: 0.1421, Accuracy: 96.88%
Batch 200/938 - Loss: 0.1058, Accuracy: 95.31%
Batch 300/938 - Loss: 0.0437, Accuracy: 98.44%
Batch 400/938 - Loss: 0.1506, Accuracy: 93.75%
Batch 500/938 - Loss: 0.1313, Accuracy: 96.88%
Batch 600/938 - Loss: 0.0459, Accuracy: 100.00%
Batch 700/938 - Loss: 0.1374, Accuracy: 95.31%
Batch 800/938 - Loss: 0.0578, Accuracy: 96.88%
Batch 900/938 - Loss: 0.2944, Accuracy: 92.