<a href="https://colab.research.google.com/github/Quantamaster/ViT-SAM2-Vision-Lab-Pro-IISc-/blob/main/q1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# NOTE(review): neither timm nor albumentations is imported anywhere in this
# notebook; confirm they are still needed before keeping this install line.
!pip install -q timm albumentations==1.3.0

# Report the PyTorch version and the visible CUDA device name (or "CPU").
# This cell only prints; the training device is picked by CFG.device.
import torch
print("PyTorch", torch.__version__, "Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/123.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m122.9/123.5 kB[0m [31m85.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m123.5/123.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hPyTorch 2.8.0+cu126 Device: CPU


In [2]:
# Cell 2: imports
import math, time, os, random
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode
import numpy as np

# for progress
from tqdm import tqdm


In [3]:
# Cell 3: config
class CFG:
    """Hyper-parameters for the CIFAR-10 ViT run.

    All values are class attributes; the rest of the notebook reads them
    through the instance ``cfg`` created right after the class.
    """

    # --- reproducibility / runtime ---
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    data_dir: str = "/content/data"

    # --- optimisation ---
    batch_size: int = 256              # use 512 if GPU memory allows
    epochs: int = 200
    lr: float = 3e-4
    weight_decay: float = 0.05

    # --- input / patching ---
    img_size: int = 32
    patch_size: int = 4                # 32 / 4 = 8 -> an 8x8 grid = 64 patches
    in_chans: int = 3
    num_classes: int = 10

    # --- model (small/tiny ViT) ---
    embed_dim: int = 192
    depth: int = 12
    num_heads: int = 3
    mlp_ratio: float = 4.0
    drop_rate: float = 0.0
    attn_drop: float = 0.0

    # --- augmentation ---
    # NOTE(review): train_one_epoch does not currently read these two values.
    mixup_alpha: float = 0.8
    use_cutmix: bool = True

    # --- checkpointing / stability ---
    save_path: str = "/content/vit_cifar_checkpt.pth"
    grad_clip: float = 1.0


cfg = CFG()


In [4]:
# Cell 4: seed
def seed_all(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
# Seed all RNGs with the configured seed before the datasets and model are built.
seed_all(cfg.seed)


In [5]:
# Cell 5: dataloaders (RandAugment optional)
from torchvision.transforms import AutoAugmentPolicy

# Per-channel CIFAR-10 normalisation statistics, shared by train and test.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

train_transforms = transforms.Compose([
    transforms.RandomCrop(cfg.img_size, padding=4),
    transforms.RandomHorizontalFlip(),
    # optional: stronger augmentations (RandAugment) would be inserted here
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

train_ds = datasets.CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transforms)
test_ds = datasets.CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transforms)

# Pinned host memory only helps host->GPU copies. On a CPU-only runtime it is
# pure overhead, and recent PyTorch versions warn about it, so enable it only
# for CUDA. Also cap the worker count at the number of visible CPUs so small
# VMs are not oversubscribed.
_loader_kwargs = dict(
    batch_size=cfg.batch_size,
    num_workers=min(4, os.cpu_count() or 1),
    pin_memory=cfg.device == "cuda",
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_kwargs)


100%|██████████| 170M/170M [00:38<00:00, 4.38MB/s]


In [6]:
# Cell 6: Mixup/CutMix helpers
def rand_bbox(size, lam):
    """Sample a random CutMix box covering roughly ``(1 - lam)`` of the image area.

    Args:
        size: shape of the image batch; only ``size[2]`` and ``size[3]`` are read.
            ``bbx1``/``bbx2`` index the ``size[2]`` axis and ``bby1``/``bby2`` the
            ``size[3]`` axis.
        lam: mixing coefficient, clipped to [0, 1]. ``lam=1`` yields an empty box;
            ``lam=0`` a box with the full side lengths (before clipping to the image).

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` with ``0 <= bbx1 <= bbx2 <= size[2]`` and
        ``0 <= bby1 <= bby2 <= size[3]``.
    """
    # NOTE: the W/H names follow the reference CutMix code; for a (B, C, H, W)
    # tensor size[2] is really the height. Irrelevant for square CIFAR images.
    W = size[2]
    H = size[3]
    lam = float(np.clip(lam, 0.0, 1.0))  # guard against sqrt of a negative number
    cut_rat = np.sqrt(1. - lam)
    # Builtin int(): the np.int alias was removed in NumPy 1.24 (AttributeError).
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # box centre, uniform over the image
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # clip so that boxes centred near an edge stay inside the image
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


In [7]:
# Cell 7: Vision Transformer (simple, clean)
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping ``patch_size`` x ``patch_size`` patches and
    project each patch to an ``embed_dim``-dimensional token.

    The projection is a Conv2d whose kernel size equals its stride, so every
    output position corresponds to exactly one patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        patches_per_side = img_size // patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map (B, C, H, W) images to (B, (H//p) * (W//p), embed_dim) patch tokens."""
        feature_map = self.proj(x)                      # B, embed_dim, H//p, W//p
        return feature_map.flatten(2).transpose(1, 2)   # B, num_tokens, embed_dim

class MLP(nn.Module):
    """Feed-forward sub-layer of a transformer block.

    Computes ``drop(fc2(drop(GELU(fc1(x)))))``. ``hidden_features`` and
    ``out_features`` fall back to ``in_features`` when not given (or falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class AttentionBlock(nn.Module):
    """Pre-norm transformer encoder block: ``x + MHSA(LN(x))`` then ``x + MLP(LN(x))``.

    Args:
        dim: token embedding width (nn.MultiheadAttention requires it to be
            divisible by ``num_heads``).
        num_heads: number of attention heads.
        attn_drop: dropout applied inside nn.MultiheadAttention.
        proj_drop: dropout used inside the MLP.
        mlp_ratio: MLP hidden width as a multiple of ``dim``. ``None`` (default)
            falls back to the global ``cfg.mlp_ratio``, as before.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., mlp_ratio=None):
        super().__init__()
        if mlp_ratio is None:
            mlp_ratio = cfg.mlp_ratio
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=attn_drop, batch_first=True)
        self.drop_path = nn.Identity()  # placeholder; not applied in forward()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=proj_drop)

    def forward(self, x):
        # x: B, N, D (batch_first=True)
        # Normalise once and reuse the same tensor for query/key/value. The
        # attention weights are never used, so skip computing them
        # (need_weights=False); this also allows PyTorch's fast path.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    """Minimal ViT classifier.

    Pipeline: patch embedding, prepended [CLS] token, learned positional
    embedding, a stack of ``AttentionBlock``s, a final LayerNorm, and a linear
    head applied to the [CLS] token.

    Args:
        img_size, patch_size, in_chans: forwarded to ``PatchEmbed``.
        num_classes: output width of the classification head.
        embed_dim, depth, num_heads: token width, number of blocks, heads per block.
        drop_rate: dropout after adding the positional embedding and inside every
            block's MLP.
        attn_drop_rate: dropout inside attention; ``None`` (default) falls back to
            the global ``cfg.attn_drop``.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=192, depth=12, num_heads=3, drop_rate=0.,
                 attn_drop_rate: Optional[float] = None):
        super().__init__()
        if attn_drop_rate is None:
            attn_drop_rate = cfg.attn_drop
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        # learnable [CLS] token and positional embeddings (one slot for CLS + one per patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Transformer blocks. Use this constructor's own drop_rate rather than the
        # global cfg.drop_rate, so the argument is honoured everywhere.
        self.blocks = nn.ModuleList([
            AttentionBlock(dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop_rate, proj_drop=drop_rate)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # classifier
        self.head = nn.Linear(embed_dim, num_classes)

        # init: truncated normal (std 0.02) for the embeddings, then per-module init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init used via ``self.apply``.

        nn.Linear: truncated-normal(std=0.02) weights and zero bias.
        nn.LayerNorm: weight=1, bias=0.
        Other modules (e.g. the patch-embedding Conv2d) keep PyTorch's default init.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # affine-less LayerNorm has no weight/bias tensors
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return class logits of shape (B, num_classes).

        Adding ``pos_embed`` requires exactly ``num_patches + 1`` tokens, i.e.
        inputs of spatial size ``img_size`` x ``img_size``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)                         # B, N, D
        cls_tokens = self.cls_token.expand(B, -1, -1)   # B, 1, D
        x = torch.cat((cls_tokens, x), dim=1)           # B, N+1, D
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.head(x[:, 0])                       # classify from the [CLS] token


In [8]:
# Cell 8: instantiate
# Build the ViT from the config values, then move its parameters to cfg.device.
model = VisionTransformer(
    img_size=cfg.img_size,
    patch_size=cfg.patch_size,
    in_chans=cfg.in_chans,
    num_classes=cfg.num_classes,
    embed_dim=cfg.embed_dim,
    depth=cfg.depth,
    num_heads=cfg.num_heads,
    drop_rate=cfg.drop_rate,
)
model = model.to(cfg.device)

# Loss with label smoothing
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against targets smoothed toward the uniform distribution.

    Per sample: ``(1 - s) * (-log p[y]) + s * mean_c(-log p[c])``, then averaged
    over the batch, where ``s`` is ``smoothing``.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, preds, target):
        # preds: logits with classes on the last dim; target: integer class ids
        logp = F.log_softmax(preds, dim=-1)
        true_class_term = -logp.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        uniform_term = -logp.mean(dim=-1)
        confidence = 1.0 - self.smoothing
        per_sample = confidence * true_class_term + self.smoothing * uniform_term
        return per_sample.mean()

criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

# AdamW's decoupled weight decay is meant for weight matrices. Following the
# DeiT/timm recipe, 1-D parameters (biases, LayerNorm gains) and the learned
# cls/pos embeddings go into a group with weight_decay=0, so they are not
# pulled toward zero every step.
_no_decay_names = {"pos_embed", "cls_token"}
_named_params = list(model.named_parameters())
optimizer = torch.optim.AdamW(
    [
        {
            "params": [p for n, p in _named_params if p.ndim > 1 and n not in _no_decay_names],
            "weight_decay": cfg.weight_decay,
        },
        {
            "params": [p for n, p in _named_params if p.ndim <= 1 or n in _no_decay_names],
            "weight_decay": 0.0,
        },
    ],
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)

# cosine annealing with warmup via lambda
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Return a LambdaLR: linear warmup from 0 to 1, then half-cosine decay to 0.

    The multiplier is ``step / num_warmup_steps`` during warmup. After that it is
    ``0.5 * (1 + cos(pi * progress))``, where progress runs from 0 at the end
    of warmup to 1 at ``num_training_steps``.
    """
    warmup_len = max(1, num_warmup_steps)
    decay_len = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = (current_step - num_warmup_steps) / decay_len
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        return current_step / warmup_len

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# The scheduler is stepped once per optimizer step (see train_one_epoch):
# 5 epochs of linear warmup, then cosine decay over the remaining epochs.
_steps_per_epoch = len(train_loader)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=5 * _steps_per_epoch,
    num_training_steps=cfg.epochs * _steps_per_epoch,
)


In [9]:
# Cell 9: train & eval
def train_one_epoch(epoch):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Steps the optimizer and the per-iteration LR scheduler after every batch.
    Clips the global gradient norm when ``cfg.grad_clip`` is truthy.

    Returns:
        (mean training loss, training accuracy in percent) over the epoch.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    progress = tqdm(train_loader, total=len(train_loader))
    for images, labels in progress:
        images = images.to(cfg.device)
        labels = labels.to(cfg.device)
        # optional MixUp / CutMix could be applied here
        # (cfg.mixup_alpha / cfg.use_cutmix are not read by this loop)
        logits = model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        if cfg.grad_clip:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optimizer.step()
        scheduler.step()

        # criterion returns a batch mean, so re-weight by batch size for the epoch mean
        loss_sum += loss.item() * images.size(0)
        n_seen += labels.size(0)
        n_correct += logits.max(1)[1].eq(labels).sum().item()
        progress.set_description(f"Epoch {epoch} Loss {(loss_sum/n_seen):.4f} Acc {100.*n_correct/n_seen:.2f}")
    return loss_sum / n_seen, 100. * n_correct / n_seen

@torch.no_grad()
def evaluate():
    """Score ``model`` on ``test_loader`` without tracking gradients.

    Returns:
        (mean loss under ``criterion``, accuracy in percent).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = n_seen = 0
    for images, labels in test_loader:
        images, labels = images.to(cfg.device), labels.to(cfg.device)
        logits = model(images)
        loss_sum += criterion(logits, labels).item() * images.size(0)
        n_correct += logits.max(1)[1].eq(labels).sum().item()
        n_seen += labels.size(0)
    return loss_sum / n_seen, 100. * n_correct / n_seen


In [None]:
# Cell 10: run training
# CFG keeps its hyper-parameters as *class* attributes, so `cfg.__dict__` is an
# empty dict and the checkpoint previously recorded no config at all.
# Snapshot the class-level values, plus any per-instance overrides, once up front.
cfg_snapshot = {
    k: v for k, v in vars(type(cfg)).items()
    if not k.startswith("__") and not callable(v)
}
cfg_snapshot.update(vars(cfg))

best_acc = 0.0
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
for epoch in range(1, cfg.epochs + 1):
    t0 = time.time()
    tr_loss, tr_acc = train_one_epoch(epoch)
    val_loss, val_acc = evaluate()  # "val" here is the CIFAR-10 test split (test_loader)
    history["train_loss"].append(tr_loss)
    history["train_acc"].append(tr_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)
    print(f"Epoch {epoch} finished in {time.time()-t0:.1f}s | Train Acc {tr_acc:.2f} | Val Acc {val_acc:.2f}")
    # save best: overwrite the checkpoint only when accuracy improves
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            "model_state": model.state_dict(),
            "cfg": cfg_snapshot,
            "epoch": epoch,
            "best_acc": best_acc,
        }, cfg.save_path)
        print("Saved best model:", cfg.save_path)


Epoch 1 Loss 2.0138 Acc 28.22: 100%|██████████| 196/196 [42:14<00:00, 12.93s/it]


Epoch 1 finished in 2696.6s | Train Acc 28.22 | Val Acc 34.24
Saved best model: /content/vit_cifar_checkpt.pth


Epoch 2 Loss 1.8714 Acc 35.74:   1%|          | 2/196 [00:44<1:08:25, 21.16s/it]

In [None]:
# Cell 11: final eval & quick plot
# NOTE: this scores the model's current (last-trained-epoch) weights; the
# best-accuracy weights are the ones saved to cfg.save_path by the training loop.
val_loss, val_acc = evaluate()
print("Final Test Acc:", val_acc)

import matplotlib.pyplot as plt
# The training loop numbers epochs from 1, so label the x-axis the same way.
epoch_axis = range(1, len(history["train_acc"]) + 1)
plt.plot(epoch_axis, history["train_acc"], label="train")
plt.plot(epoch_axis, history["val_acc"], label="val")
plt.xlabel("epoch")
plt.ylabel("accuracy")
plt.legend()
plt.show()
