<a href="https://colab.research.google.com/github/AryanJadhao/Vision-Transformer/blob/main/chatGPT_sol_ViT(cifar10).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# ===============================
# 1. DATASET (aug + normalization)
# ===============================
# Per-channel mean / std applied to both the training and the test images.
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

# Training pipeline: random augmentation first, then tensor conversion and
# normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Test pipeline: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Both splits are stored under ./data and downloaded if not already present.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

batch_size = 128

# Settings shared by both loaders; only the shuffling differs between them.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


# ===============================
# 2. VISION TRANSFORMER MODEL
# ===============================

# --- Input geometry ---
img_size, patch_size = 32, 4
patches_per_side = img_size // patch_size
num_patches = patches_per_side * patches_per_side  # 8 * 8 = 64 patches

# --- Model hyperparameters ---
num_classes = 10        # size of the classifier output
embedding_dim = 192     # width of every token vector
num_heads = 3           # 192 / 3 = 64 dimensions per attention head
mlp_hidden_dim = 768    # hidden width of the MLP inside each encoder block
num_layers = 6          # number of stacked encoder blocks
dropout = 0.1           # rate used for attention dropout and the MLP dropouts


# Patch Embedding
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of normalized patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch of a 3-channel image to an
    ``embedding_dim``-sized vector. The resulting grid is flattened into a
    token sequence and passed through a LayerNorm.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels=3,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Embed ``x`` and return one token per patch.

        With this file's settings (32x32 inputs, patch_size 4,
        embedding_dim 192) the result has shape (B, 64, 192).
        """
        patch_grid = self.proj(x)                    # B, 192, 8, 8
        tokens = patch_grid.flatten(start_dim=2)     # B, 192, 64
        tokens = tokens.transpose(1, 2)              # B, 64, 192
        return self.norm(tokens)


# Transformer Encoder Block
class TransformerEncoder(nn.Module):
    """One pre-norm Transformer encoder block.

    The block applies two residual sub-layers in sequence:

    1. LayerNorm -> multi-head self-attention -> add the sub-layer input.
    2. LayerNorm -> MLP (Linear, GELU, Dropout, Linear, Dropout) -> add the
       sub-layer input.

    The output has the same shape as the input, so blocks can be stacked.
    """

    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(embedding_dim)
        # batch_first=True: tensors are laid out as (batch, tokens, embedding).
        self.attn  = nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, embedding_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Run the attention and MLP sub-layers on ``x`` and return the result.

        Args:
            x: Token tensor laid out as (batch, tokens, embedding_dim).

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Self-attention block
        h = x
        x = self.norm1(x)
        # Only the attention output is used. need_weights=False avoids
        # computing the head-averaged attention map that would be thrown away
        # and lets PyTorch use its optimized attention implementation.
        x = self.attn(x, x, x, need_weights=False)[0]
        x = x + h

        # MLP block
        h = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + h

        return x


# Vision Transformer: patch embedding + class token + encoder stack + linear head
class ViT(nn.Module):
    """Vision Transformer classifier built from the blocks defined in this file.

    Images are converted to patch tokens, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    passed through ``num_layers`` encoder blocks. After a final LayerNorm the
    class token is mapped to ``num_classes`` logits by a linear layer.
    """

    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbedding()

        # One extra sequence position is reserved for the class token.
        seq_len = num_patches + 1
        self.class_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embedding_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        blocks = [TransformerEncoder() for _ in range(num_layers)]
        self.encoder = nn.Sequential(*blocks)
        self.norm = nn.LayerNorm(embedding_dim)
        self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Return class logits (one row per image) for the image batch ``x``."""
        batch = x.size(0)
        tokens = self.patch_embed(x)

        # The single learned class token is broadcast across the batch and
        # placed at sequence position 0.
        cls_tokens = self.class_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = tokens + self.pos_embed

        encoded = self.norm(self.encoder(tokens))

        # Classification uses only the class-token position.
        return self.mlp_head(encoded[:, 0])


# ===============================
# 3. TRAINING SETUP
# ===============================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT().to(device)

# Defined before the scheduler so the run length and the cosine schedule
# length come from the same value and cannot drift apart.
epochs = 30

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
# The scheduler is stepped once per epoch by the training loop, so T_max is
# the number of epochs: the learning rate anneals over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


# ===============================
# 4. TRAINING LOOP
# ===============================
# Each epoch: one pass over train_loader with an optimizer step per batch,
# then a report of the mean loss and accuracy over the samples seen, then one
# scheduler step.
for epoch in range(epochs):
    model.train()
    total, correct = 0, 0
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_samples = labels.size(0)

        # The criterion yields the mean loss over the batch. Weighting it by
        # the batch size makes the epoch figure a true per-sample mean even
        # when the final batch is smaller than the others.
        running_loss += loss.item() * batch_samples

        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += batch_samples

    train_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{epochs} | Loss: {running_loss/total:.4f} | Acc: {train_acc:.2f}%")

    scheduler.step()


# ===============================
# 5. TEST ACCURACY
# ===============================
# Evaluation mode turns off dropout; no gradients are needed for scoring.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        # The index of the largest logit in each row is the predicted class.
        preds = outputs.max(1)[1]

        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"\nTest Accuracy: {100 * correct / total:.2f}%")


100%|██████████| 170M/170M [00:19<00:00, 8.78MB/s]


Epoch 1/30 | Loss: 1.7530 | Acc: 34.88%
Epoch 2/30 | Loss: 1.4351 | Acc: 47.66%
Epoch 3/30 | Loss: 1.2867 | Acc: 53.40%
Epoch 4/30 | Loss: 1.2010 | Acc: 56.63%
Epoch 5/30 | Loss: 1.1153 | Acc: 59.71%
Epoch 6/30 | Loss: 1.0600 | Acc: 61.96%
Epoch 7/30 | Loss: 1.0085 | Acc: 64.06%
Epoch 8/30 | Loss: 0.9597 | Acc: 65.79%
Epoch 9/30 | Loss: 0.9156 | Acc: 67.39%
Epoch 10/30 | Loss: 0.8764 | Acc: 68.79%
Epoch 11/30 | Loss: 0.8415 | Acc: 70.09%
Epoch 12/30 | Loss: 0.8061 | Acc: 71.18%
Epoch 13/30 | Loss: 0.7689 | Acc: 72.63%
Epoch 14/30 | Loss: 0.7405 | Acc: 73.60%
Epoch 15/30 | Loss: 0.7136 | Acc: 74.62%
Epoch 16/30 | Loss: 0.6832 | Acc: 75.74%
Epoch 17/30 | Loss: 0.6576 | Acc: 76.55%
Epoch 18/30 | Loss: 0.6278 | Acc: 77.71%
Epoch 19/30 | Loss: 0.6045 | Acc: 78.41%
Epoch 20/30 | Loss: 0.5803 | Acc: 79.34%
Epoch 21/30 | Loss: 0.5629 | Acc: 80.05%
Epoch 22/30 | Loss: 0.5387 | Acc: 80.87%
Epoch 23/30 | Loss: 0.5243 | Acc: 81.31%
Epoch 24/30 | Loss: 0.5080 | Acc: 81.78%
Epoch 25/30 | Loss: 0.490