# Q1 — Vision Transformer (ViT) on CIFAR-10

Implementation of ViT in PyTorch.

In [1]:
# ==============================
# 1. Setup & Imports
# ==============================
import os, random, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR

# Reproducibility
seed = 0
random.seed(seed); np.random.seed(seed)
torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


Device: Tesla T4


In [2]:
# ==============================
# 2. Data & Augmentations
# ==============================
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2470, 0.2435, 0.2616)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)
val_ds   = datasets.CIFAR10('./data', train=False, download=True, transform=val_transform)

batch_size = 128
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)


100%|██████████| 170M/170M [00:06<00:00, 27.4MB/s]


In [3]:
# ==============================
# 3. MixUp helper
# ==============================
def mixup_data(x, y, alpha=0.8, device='cuda'):
    if alpha <= 0:
        return x, y, 1.0, None
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


In [4]:
# ==============================
# 4. Vision Transformer (ViT)
# ==============================
class DropPath(nn.Module):
    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob
    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor

class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        x = self.proj(x)             # B, embed_dim, H/ps, W/ps
        x = x.flatten(2).transpose(1,2)  # B, N, C
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1,2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias=True, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path>0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim*mlp_ratio), drop=drop)
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class ViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=256, depth=6, num_heads=4, mlp_ratio=4.,
                 drop_rate=0., drop_path_rate=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        n_patches = self.patch_embed.n_patches
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches+1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        # stochastic depth schedule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate, attn_drop=drop_rate, drop_path=dpr[i])
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        # init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        cls = x[:,0]
        out = self.head(cls)
        return out

# Instantiate model
model = ViT(img_size=32, patch_size=4, embed_dim=256, depth=6, num_heads=4,
            mlp_ratio=4.0, drop_rate=0.1, drop_path_rate=0.1).to(device)



In [5]:
# ==============================
# 5. Optimizer, Scheduler, EMA
# ==============================
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
opt = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

epochs = 200
steps_per_epoch = len(train_loader)
total_steps = epochs * steps_per_epoch
warmup_steps = int(0.05 * total_steps)

def lr_lambda(current_step):
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(opt, lr_lambda)

class ModelEMA:
    def __init__(self, model, decay=0.9999, device=None):
        self.ema = {k: v.detach().clone().to(device) for k,v in model.state_dict().items()}
        self.decay = decay; self.device = device
    def update(self, model):
        for k,v in model.state_dict().items():
            self.ema[k].mul_(self.decay).add_(v.detach().to(self.device), alpha=1.0-self.decay)
    def state_dict(self): return self.ema

ema = ModelEMA(model, decay=0.9999, device=device)



In [6]:
# ==============================
# 6. Training & Evaluation Loops
# ==============================
scaler = GradScaler()

def train_one_epoch(model, loader, opt, scheduler, epoch, device, ema=None, mixup_alpha=0.8):
    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        if mixup_alpha > 0:
            x, y_a, y_b, lam = mixup_data(x, y, mixup_alpha, device)
        opt.zero_grad()
        with autocast():
            logits = model(x)
            if mixup_alpha > 0:
                loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            else:
                loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        scheduler.step()
        if ema is not None: ema.update(model)
        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        total_correct += (preds == y).sum().item()
        total += x.size(0)
    return total_loss/total, total_correct/total

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total, correct = 0, 0
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        logits = model(x)
        preds = logits.argmax(dim=1)
        total += x.size(0)
        correct += (preds == y).sum().item()
    return correct/total


  scaler = GradScaler()


In [7]:
# ==============================
# 7. Train Loop with checkpointing
# ==============================
best = 0.0
for epoch in range(1, epochs+1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, opt, scheduler, epoch, device, ema, mixup_alpha=0.8)
    val_acc = evaluate(model, val_loader, device)
    print(f"Epoch {epoch}/{epochs} - Train Loss {tr_loss:.4f} | Train Acc {tr_acc:.4f} | Val Acc {val_acc:.4f}")
    if val_acc > best:
        best = val_acc
        torch.save(model.state_dict(), "best_vit.pt")
        print(f"  ✅ Saved new best model @ epoch {epoch} with val_acc {val_acc:.4f}")


  with autocast():


Epoch 1/200 - Train Loss 2.2584 | Train Acc 0.1455 | Val Acc 0.2902
  ✅ Saved new best model @ epoch 1 with val_acc 0.2902
Epoch 2/200 - Train Loss 2.1540 | Train Acc 0.1779 | Val Acc 0.3111
  ✅ Saved new best model @ epoch 2 with val_acc 0.3111
Epoch 3/200 - Train Loss 2.0996 | Train Acc 0.2052 | Val Acc 0.3747
  ✅ Saved new best model @ epoch 3 with val_acc 0.3747
Epoch 4/200 - Train Loss 2.0582 | Train Acc 0.2057 | Val Acc 0.4200
  ✅ Saved new best model @ epoch 4 with val_acc 0.4200
Epoch 5/200 - Train Loss 2.0041 | Train Acc 0.2438 | Val Acc 0.4616
  ✅ Saved new best model @ epoch 5 with val_acc 0.4616
Epoch 6/200 - Train Loss 1.9567 | Train Acc 0.2629 | Val Acc 0.4889
  ✅ Saved new best model @ epoch 6 with val_acc 0.4889
Epoch 7/200 - Train Loss 1.9228 | Train Acc 0.2710 | Val Acc 0.5186
  ✅ Saved new best model @ epoch 7 with val_acc 0.5186
Epoch 8/200 - Train Loss 1.9004 | Train Acc 0.2778 | Val Acc 0.5527
  ✅ Saved new best model @ epoch 8 with val_acc 0.5527
Epoch 9/200 - Tr