<a href="https://colab.research.google.com/github/ArtDowdy/deep-learning-engagement/blob/main/vit_cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
# Vision Transformer (from scratch) on CIFAR-10 (PyTorch, Colab-ready)
# - Patch embedding via Conv2d
# - Class token + learnable positional embeddings
# - TransformerEncoder (multihead self-attention)
# - AMP training, cosine schedule, label smoothing, early stopping
# - Metrics: Top-1 accuracy; Confusion Matrix
# - Exports: best checkpoint + TorchScript
# - Attention rollout visualization for 1 sample
# ============================================================

import math, os, time, random
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as tvutils
from sklearn.metrics import accuracy_score, confusion_matrix

# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")
print("Device:", device)

# ---------------- Config ----------------
@dataclass
class CFG:
    """Hyper-parameters for the model, data pipeline and training loop.

    A single module-level instance ``cfg`` is created right after the class and
    read by the rest of the script.
    """
    # --- model ---
    img_size: int = 32           # input height/width in pixels
    patch_size: int = 4          # 32x32 -> 8x8 = 64 patches
    in_chans: int = 3
    num_classes: int = 10
    embed_dim: int = 192         # ViT-Tiny-ish
    depth: int = 6               # number of Transformer blocks
    num_heads: int = 3           # attention heads per block (192 / 3 = 64 dims per head)
    mlp_ratio: float = 4.0       # MLP hidden width = embed_dim * mlp_ratio
    dropout: float = 0.1         # used after positional embeddings, on residual branches and in the MLP
    attn_dropout: float = 0.1    # passed to nn.MultiheadAttention(dropout=...)
    # --- optimisation / training ---
    batch_size: int = 256
    epochs: int = 30             # also the cosine schedule's T_max
    lr: float = 3e-4             # initial AdamW learning rate (annealed by the cosine schedule)
    weight_decay: float = 0.05
    early_patience: int = 5      # epochs without a > 1e-4 test-accuracy gain before stopping
    label_smoothing: float = 0.05
    num_workers: int = 2         # DataLoader worker processes
    seed: int = 42

cfg = CFG()

def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: integer seed applied to all four generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


seed_everything(cfg.seed)

# ---------------- Data ----------------
# CIFAR-10 per-channel mean / std used by Normalize (and reused later to undo it
# for visualisation).
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

# Training augmentation: reflect-padded random crop, horizontal flip and the
# CIFAR-10 AutoAugment policy on the PIL image, then tensor conversion,
# normalisation and RandomErasing on the normalised tensor.
train_tfms = transforms.Compose([
    transforms.RandomCrop(cfg.img_size, padding=4, padding_mode="reflect"),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
])

# Evaluation: no augmentation, only tensor conversion + normalisation.
test_tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

data_root = "./data"
train_set = datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_tfms)
test_set = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tfms)

# Pinned (page-locked) host memory only speeds up host->GPU copies. On a CPU-only
# runtime it is useless and DataLoader emits a warning, so enable it only with CUDA.
_pin = torch.cuda.is_available()
train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=_pin)
test_loader = DataLoader(test_set, batch_size=cfg.batch_size, shuffle=False,
                         num_workers=cfg.num_workers, pin_memory=_pin)

# ---------------- Model ----------------
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and project each one.

    A single Conv2d whose kernel size and stride both equal ``patch_size`` acts
    as a shared linear projection applied to every patch.

    Attributes:
        img_size: expected input height/width.
        patch_size: side length of each square patch.
        grid_size: patches per spatial axis (``img_size // patch_size``).
        num_patches: total number of patches (``grid_size ** 2``).
        proj: the patchifying convolution.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map (B, C, H, W) images to (B, N, E) tokens in row-major patch order."""
        feature_map = self.proj(x)                  # (B, E, H/P, W/P)
        tokens = feature_map.flatten(start_dim=2)   # (B, E, N)
        return tokens.transpose(1, 2)               # (B, N, E)

class MLP(nn.Module):
    """Transformer feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        mlp_ratio: hidden width as a multiple of ``dim`` (truncated with ``int``).
        drop: dropout probability, applied after the activation and after ``fc2``.
    """

    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden_features = int(dim * mlp_ratio)
        # Creation order is kept fixed so seeded initialisation is reproducible.
        self.fc1 = nn.Linear(dim, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention and MLP, each with a residual.

    Args:
        dim: token embedding size.
        num_heads: number of attention heads.
        mlp_ratio: hidden width multiplier for the MLP.
        drop: dropout on both residual branches (and inside the MLP).
        attn_drop: dropout passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop, batch_first=True)
        self.drop = nn.Dropout(drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, drop)

    def forward(self, x, need_attn=False):
        """Apply the layer to batch-first tokens ``x`` of shape (B, T, dim).

        Returns the updated tokens, or ``(tokens, weights)`` when ``need_attn``
        is True, where ``weights`` holds per-head attention (B, heads, T, T).
        """
        normed = self.norm1(x)
        attn_out, weights = self.attn(
            normed, normed, normed,
            need_weights=need_attn,
            average_attn_weights=False,  # keep heads separate for rollout
        )
        x = x + self.drop(attn_out)
        x = x + self.drop(self.mlp(self.norm2(x)))
        if need_attn:
            return x, weights
        return x

class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding -> Transformer blocks -> CLS head.

    Args:
        img_size, patch_size, in_chans: forwarded to ``PatchEmbed``.
        num_classes: output size of the linear head.
        embed_dim: token width.
        depth: number of ``Block`` layers.
        num_heads, mlp_ratio: per-block attention heads and MLP width multiplier.
        drop: dropout after adding positional embeddings and inside every block.
        attn_drop: attention dropout inside every block.
    """

    def __init__(self, img_size, patch_size, in_chans, num_classes,
                 embed_dim, depth, num_heads, mlp_ratio, drop, attn_drop):
        super().__init__()
        # Submodules/parameters are created in a fixed order so that seeded random
        # initialisation and state_dict keys stay reproducible.
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One learned position vector for the CLS token plus one per patch.
        self.pos_embed = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(drop)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, num_heads, mlp_ratio, drop, attn_drop) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise only the classification head (truncated normal, zero bias)."""
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

    def forward_features(self, x, collect_attn=False):
        """Encode an image batch and return the final-normed CLS embedding.

        Args:
            x: images of shape (B, C, H, W).
            collect_attn: also return the per-layer attention maps, each
                (B, heads, T, T) with T = 1 + num_patches.

        Returns:
            (B, embed_dim) CLS features, or ``(features, attn_maps)``.
        """
        tokens = self.patch_embed(x)                              # (B, N, E)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)      # (B, 1, E)
        tokens = torch.cat((cls, tokens), dim=1)                  # (B, 1+N, E)
        tokens = self.pos_drop(tokens + self.pos_embed[:, :tokens.size(1), :])
        attn_maps = []
        for blk in self.blocks:
            if collect_attn:
                tokens, layer_attn = blk(tokens, need_attn=True)
                attn_maps.append(layer_attn)
            else:
                tokens = blk(tokens)
        cls_feat = self.norm(tokens)[:, 0]
        if collect_attn:
            return cls_feat, attn_maps
        return cls_feat

    def forward(self, x):
        """Return class logits of shape (B, num_classes)."""
        return self.head(self.forward_features(x))

# Cross-entropy against label-smoothed targets.
class LabelSmoothingCE(nn.Module):
    """Cross-entropy with label smoothing.

    The target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` uniformly over the remaining ``C - 1`` classes.

    Args:
        smoothing: probability mass moved off the true class.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        """Mean loss over the batch for (B, C) ``logits`` and (B,) integer ``target``."""
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)
        off_value = self.smoothing / (num_classes - 1)
        on_value = 1 - self.smoothing
        with torch.no_grad():
            soft_targets = torch.full_like(log_probs, off_value)
            soft_targets.scatter_(1, target.unsqueeze(1), on_value)
        per_sample = -(soft_targets * log_probs).sum(dim=-1)
        return per_sample.mean()

model = ViT(cfg.img_size, cfg.patch_size, cfg.in_chans, cfg.num_classes,
            cfg.embed_dim, cfg.depth, cfg.num_heads, cfg.mlp_ratio,
            cfg.dropout, cfg.attn_dropout).to(device)

# ---------------- Optim & Sched ----------------
def _adamw_param_groups(module, weight_decay):
    """Split ``module``'s trainable parameters into AdamW decay / no-decay groups.

    Parameters with ndim <= 1 (all biases and LayerNorm gains) plus the learned
    ``cls_token`` and ``pos_embed`` get zero weight decay, following the common
    ViT/DeiT training recipe. Only matrix/conv weights are decayed.
    Empty groups are dropped.
    """
    decay, no_decay = [], []
    for pname, param in module.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or pname in ("cls_token", "pos_embed"):
            no_decay.append(param)
        else:
            decay.append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return [g for g in groups if g["params"]]


optimizer = torch.optim.AdamW(_adamw_param_groups(model, cfg.weight_decay),
                              lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)  # stepped once per epoch
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())  # AMP loss scaling, off on CPU
criterion = LabelSmoothingCE(cfg.label_smoothing)

# ---------------- Train/Eval ----------------
best_acc = 0.0     # best test accuracy seen so far
best_state = None  # CPU copy of the weights that achieved best_acc
pat = 0            # consecutive epochs without improvement (early-stopping counter)

def evaluate(model, loader):
    """Run ``model`` over ``loader`` and return top-1 accuracy and a confusion matrix.

    The model is put in eval mode. Inputs are moved to the global ``device`` and
    run under CUDA autocast when CUDA is available.

    Returns:
        acc: fraction of correctly classified samples.
        cm: (C, C) confusion matrix (rows = true label, columns = prediction),
            where C is the model's logit width. It always has full size, even
            if some class never occurs in the labels or the predictions.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    ys, ps = [], []
    num_classes = None
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                logits = model(x)
            num_classes = logits.size(1)
            # Labels are only needed on the host, so skip the device round trip.
            ys.append(y.cpu().numpy())
            ps.append(logits.argmax(dim=1).cpu().numpy())
    if not ys:
        raise ValueError("evaluate() received an empty loader")
    ys = np.concatenate(ys)
    ps = np.concatenate(ps)
    acc = accuracy_score(ys, ps)
    # Pin the label set so the matrix shape does not depend on which classes occur.
    cm = confusion_matrix(ys, ps, labels=np.arange(num_classes))
    return acc, cm

# Main loop: one pass over train_loader per epoch, then evaluation on the test set.
# NOTE(review): the test set drives both early stopping / best-checkpoint selection
# and the reported "best test accuracy", so that number is optimistically biased.
# A validation split carved from the training data would avoid this — TODO consider.
for epoch in range(1, cfg.epochs+1):
    model.train()
    running = 0.0  # sum of per-sample losses over the epoch
    t0 = time.time()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        # Mixed-precision forward when CUDA is available (autocast disabled otherwise).
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            logits = model(x)
            loss = criterion(logits, y)
        # Loss scaling for AMP; the scaler is constructed with enabled=CUDA-available.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running += loss.item() * x.size(0)  # criterion returns a batch mean; re-weight by batch size
    scheduler.step()  # cosine schedule advances once per epoch

    train_loss = running / len(train_set)
    # cm here is this epoch's confusion matrix; it is not reported in this loop.
    acc, cm = evaluate(model, test_loader)
    dt = time.time() - t0  # includes evaluation time
    print(f"Epoch {epoch:02d} | loss={train_loss:.4f} | test_acc={acc:.3f} | time={dt:.1f}s")

    # Keep a CPU snapshot of the weights whenever accuracy improves by more than 1e-4;
    # otherwise count towards early stopping.
    if acc > best_acc + 1e-4:
        best_acc = acc
        best_state = {k: v.detach().cpu() for k,v in model.state_dict().items()}
        pat = 0
    else:
        pat += 1
        if pat >= cfg.early_patience:
            print("Early stopping.")
            break

# restore best
# best_state holds CPU copies of the best epoch's weights; load_state_dict copies
# them back into the model's existing parameters on `device`.
if best_state is not None:
    model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
print(f"Best test accuracy: {best_acc:.3f}")

# Report the confusion matrix of the restored model. The `cm` left over from the
# training loop belongs to the last epoch run. That is not the best epoch whenever
# early stopping fired or accuracy dipped at the end.
_, best_cm = evaluate(model, test_loader)
print("Confusion matrix (rows = true class, cols = predicted class):")
print("Classes:", ", ".join(test_set.classes))
print(best_cm)

# ---------------- Save & Export ----------------
art_dir = "artifacts_vit"
os.makedirs(art_dir, exist_ok=True)

# Checkpoint = current model weights + the config (as a plain dict) that built them.
ckpt_path = "artifacts_vit/vit_cifar10_best.pt"
checkpoint = {"state_dict": model.state_dict(), "cfg": vars(cfg)}
torch.save(checkpoint, ckpt_path)
print("Saved checkpoint:", ckpt_path)

# TorchScript export: trace the eval-mode model with one dummy image on `device`.
ts_path = "artifacts_vit/vit_cifar10_scripted.pt"
model.eval()
example = torch.randn(1, 3, cfg.img_size, cfg.img_size)
example = example.to(device)
scripted = torch.jit.trace(model, example)
scripted.save(ts_path)
print("Saved TorchScript:", ts_path)

# ---------------- Attention Rollout Viz ----------------
# Grab one test image; compute attention maps and rollout.
# inv_norm undoes Normalize(mean, std). A second Normalize with mean' = -m/s and
# std' = 1/s maps z -> (z + m/s) * s = z*s + m, recovering the original pixels.
_undo_mean = [-m / s for m, s in zip(mean, std)]
_undo_std = [1 / s for s in std]
inv_norm = transforms.Normalize(mean=_undo_mean, std=_undo_std)

def attention_rollout(attn_maps):
    """Chain per-layer attention into a CLS-to-patch relevance map (attention rollout).

    For each layer the heads are averaged and the identity is added to account for
    the residual connection. Rows are then re-normalised to sum to 1. The layers
    are composed by left-multiplying each later layer onto the running product.

    Args:
        attn_maps: list with one tensor per layer, each (B, heads, T, T),
            T = 1 + number of patches, token 0 being the CLS token.

    Returns:
        (B, T - 1) tensor: rolled-out attention from CLS to each patch token.
    """
    with torch.no_grad():
        per_layer = torch.stack(attn_maps, dim=0).mean(dim=2)   # (L, B, T, T)
        tokens = per_layer.shape[-1]
        eye = torch.eye(tokens, device=per_layer.device)[None, None]
        per_layer = per_layer + eye
        per_layer = per_layer / per_layer.sum(dim=-1, keepdim=True)
        rollout = per_layer[0]
        for layer_attn in per_layer[1:]:
            rollout = torch.bmm(layer_attn, rollout)
        return rollout[:, 0, 1:]

# one sample: the first test image. test_loader is unshuffled, so this is the same
# image as its first batch element, without spinning up workers for a full batch.
model.eval()  # attention maps should come from the dropout-free model
x_img = test_set[0][0].unsqueeze(0).to(device)
with torch.no_grad():
    _, attn_maps = model.forward_features(x_img, collect_attn=True)

g = model.patch_embed.grid_size
roll = attention_rollout(attn_maps).reshape(1, g, g)
# Upsample the patch-level map to image resolution, then min-max scale to [0, 1].
roll = F.interpolate(roll.unsqueeze(1), size=(cfg.img_size, cfg.img_size),
                     mode="bilinear", align_corners=False)
roll = roll.squeeze().cpu().numpy()
roll = (roll - roll.min()) / (roll.max() - roll.min() + 1e-8)

# save overlay
import PIL.Image as Image
import matplotlib.cm as mpl_cm  # aliased: a bare `cm` would clobber the confusion matrix variable

img_vis = inv_norm(x_img[0].cpu()).clamp(0, 1)   # (3, H, W) in [0, 1]
img_np = img_vis.permute(1, 2, 0).numpy()        # (H, W, 3)
heatmap = mpl_cm.jet(roll)[..., :3]              # drop the alpha channel
overlay = (0.55 * img_np + 0.45 * heatmap).clip(0, 1)
Image.fromarray(np.uint8(overlay * 255)).save("artifacts_vit/attention_rollout.png")
print("Saved attention rollout: artifacts_vit/attention_rollout.png")
