<a href="https://colab.research.google.com/github/SoubhikMajumdar/VisionTransformer/blob/main/ViTnotebookv3_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# STEP 1: Install required packages
!pip install einops torchvision --quiet


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m81.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m78.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m38.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m38.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens plus a class token.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to ``emb_dim`` channels. The patch
    grid is flattened row by row into a token sequence, a learnable class
    token is prepended, and a learnable positional embedding is added.

    Args:
        img_size: Side length of the square input that the positional
            embedding is sized for.
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input images.
        emb_dim: Dimensionality of every output token.

    Raises:
        ValueError: If ``patch_size`` is not in ``1..img_size``. Such a
            configuration yields zero patches, which would leave a
            positional embedding of length 1 that broadcasts over all tokens.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, emb_dim=128):
        super().__init__()
        if not 0 < patch_size <= img_size:
            raise ValueError(
                f"patch_size must be in 1..img_size, got patch_size={patch_size}, "
                f"img_size={img_size}"
            )
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        # One position per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, emb_dim))

    def forward(self, x):
        """Embed ``x`` of shape [B, in_channels, H, W] into [B, N + 1, emb_dim]."""
        batch_size = x.shape[0]
        x = self.proj(x)                          # [B, emb_dim, H/P, W/P]
        # Native equivalent of einops 'b c h w -> b (h w) c'.
        x = x.flatten(2).transpose(1, 2)          # [B, N, D]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_token, x), dim=1)      # [B, N+1, D]
        return x + self.pos_embed

class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies LayerNorm followed by multi-head self-attention, then LayerNorm
    followed by a two-layer GELU MLP. Each sub-layer output is passed through
    dropout and added back to its input (residual connection).

    Args:
        emb_dim: Token dimensionality (input and output).
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the MLP.
        dropout: Dropout probability applied to the output of each sub-layer
            before the residual addition. Inactive in ``eval()`` mode.
    """

    def __init__(self, emb_dim=128, num_heads=4, mlp_dim=256, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, emb_dim),
        )
        # Kept as a separate module (not inserted into ``mlp``) so the
        # state_dict keys of existing checkpoints remain valid.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform ``x`` of shape [B, N, emb_dim]; the shape is preserved."""
        # Normalise once and reuse the result as query, key and value.
        normed = self.norm1(x)
        # The attention weights are not used, so skip computing them.
        attn_out = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + self.dropout(attn_out)
        x = x + self.dropout(self.mlp(self.norm2(x)))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from patch embedding and encoder blocks.

    The input is embedded into patch tokens (with a class token prepended),
    passed through ``depth`` encoder blocks and a final LayerNorm, and the
    class token is mapped to class logits by a linear head.

    Args:
        img_size: Side length of the square input images.
        patch_size: Side length of each square patch.
        emb_dim: Token dimensionality.
        depth: Number of stacked encoder blocks.
        num_heads: Attention heads per encoder block.
        mlp_dim: Hidden width of each encoder block's MLP.
        num_classes: Number of output logits.
        in_channels: Number of input image channels (previously fixed at 3).
        dropout: Dropout probability forwarded to every encoder block.
    """

    def __init__(self, img_size=32, patch_size=4, emb_dim=128, depth=9,
                 num_heads=4, mlp_dim=256, num_classes=10,
                 in_channels=3, dropout=0.1):
        super().__init__()
        self.embed = PatchEmbedding(img_size, patch_size, in_channels=in_channels, emb_dim=emb_dim)
        self.blocks = nn.Sequential(*[
            TransformerEncoder(emb_dim, num_heads, mlp_dim, dropout=dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(emb_dim)
        self.cls_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        """Return logits of shape [B, num_classes] for images [B, in_channels, H, W]."""
        x = self.embed(x)
        x = self.blocks(x)
        x = self.norm(x)
        cls_token = x[:, 0]  # The class token sits at sequence position 0
        return self.cls_head(cls_token)


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing: resize to 32 pixels, replicate the gray channel to 3
# channels, convert to a tensor, then normalize with mean 0.5 and std 0.5.
# A single Resize is enough: a second one after ToTensor() would act on an
# image that is already 32x32.
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Download MNIST into ./data if it is not already there.
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training split; keep the test order fixed.
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=128, shuffle=False)


100%|██████████| 9.91M/9.91M [00:00<00:00, 57.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.64MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.66MB/s]


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = VisionTransformer().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train for 20 passes over the training set, reporting the summed batch loss
# and the training accuracy after each pass.
for epoch in range(20):
    model.train()
    total_loss = 0
    total_correct = 0
    for imgs, labels in trainloader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        preds = model(imgs)
        loss = loss_fn(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += (preds.argmax(dim=1) == labels).sum().item()

    acc = 100.0 * total_correct / len(trainset)
    print(f"Epoch {epoch+1}: Loss={total_loss:.3f}, Accuracy={acc:.2f}%")


Epoch 1: Loss=191.788, Accuracy=86.62%
Epoch 2: Loss=61.260, Accuracy=96.04%
Epoch 3: Loss=45.383, Accuracy=97.00%
Epoch 4: Loss=37.984, Accuracy=97.42%
Epoch 5: Loss=30.531, Accuracy=97.91%
Epoch 6: Loss=28.656, Accuracy=98.09%
Epoch 7: Loss=27.331, Accuracy=98.07%
Epoch 8: Loss=23.815, Accuracy=98.34%
Epoch 9: Loss=25.420, Accuracy=98.25%
Epoch 10: Loss=21.541, Accuracy=98.53%
Epoch 11: Loss=19.723, Accuracy=98.64%
Epoch 12: Loss=19.508, Accuracy=98.62%
Epoch 13: Loss=19.617, Accuracy=98.68%
Epoch 14: Loss=17.007, Accuracy=98.88%
Epoch 15: Loss=16.003, Accuracy=98.89%
Epoch 16: Loss=15.663, Accuracy=98.87%
Epoch 17: Loss=15.626, Accuracy=98.85%
Epoch 18: Loss=16.107, Accuracy=98.81%
Epoch 19: Loss=14.800, Accuracy=98.94%
Epoch 20: Loss=13.707, Accuracy=99.04%


In [None]:
# Measure accuracy on the test split with the model in eval mode and
# gradient tracking disabled.
model.eval()
correct = 0
with torch.no_grad():
    for imgs, labels in testloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        preds = model(imgs)
        # Count the samples whose highest-scoring class matches the label.
        correct += (preds.argmax(dim=1) == labels).sum().item()

acc = 100.0 * correct / len(testset)
print(f"Test Accuracy: {acc:.2f}%")


Test Accuracy: 98.03%


In [None]:
# Add up the element counts of every parameter that requires gradients.
total_params = sum(
    weight.numel()
    for weight in filter(lambda t: t.requires_grad, model.parameters())
)
print(f"Total Trainable Parameters: {total_params:,}")

Total Trainable Parameters: 1,208,586
