# Q1 — Vision Transformer (ViT) on CIFAR-10 (Patched v3)

Speed patches + explicit ViT class + results table & JSON at the end.

In [1]:
!pip -q install torch torchvision torchmetrics tqdm --upgrade

import os, sys, json, math, time, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm

print("PyTorch:", torch.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
    print("VRAM (GB):", round(torch.cuda.get_device_properties(0).total_memory/1e9, 2))
try:
    torch.set_float32_matmul_precision("medium")
except Exception as e:
    print("TF32 setting skipped:", e)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m983.0/983.2 kB[0m [31m40.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[?25hPyTorch: 2.8.0+cu126
Device: cuda
GPU: Tesla T4
VRAM (GB): 15.83


## Config (single source of truth)

In [2]:
#CONFIG
def now_str():
    return time.strftime("%Y%m%d-%H%M%S")

config = {
    "run_id": f"vit_cifar10_v3_{now_str()}",
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "input_size": 32,
    "patch_size": 4,
    "num_classes": 10,
    "val_split": 5000,
    "randaugment": {"enabled": True, "N": 2, "M": 10},
    "mixup": {"p": 0.5, "alpha": 0.2},
    "cutmix": {"p": 0.0, "alpha": 1.0},
    "label_smoothing": 0.1,
    "embed_dim": 384,
    "depth": 12,
    "num_heads": 6,
    "mlp_ratio": 4.0,
    "drop_path": 0.1,
    "dropout": 0.0,
    "epochs": 100,
    "batch_size_target": 512,
    "grad_accum_steps": 1,
    "optimizer": {"name":"AdamW","lr":6e-4,"weight_decay":0.1,"betas":(0.9,0.999),"eps":1e-8},
    "scheduler": {"type":"cosine","warmup_epochs":10},
    "ema_decay": 0.2,
    "grad_clip": 1.0,
    "use_amp": True,
    "out_dir": "./outputs",
    "eval_every": 5
}
CFG = config
os.makedirs(CFG["out_dir"], exist_ok=True)
print(json.dumps(CFG, indent=2))

{
  "run_id": "vit_cifar10_v3_20251004-101042",
  "seed": 42,
  "device": "cuda",
  "input_size": 32,
  "patch_size": 4,
  "num_classes": 10,
  "val_split": 5000,
  "randaugment": {
    "enabled": true,
    "N": 2,
    "M": 10
  },
  "mixup": {
    "p": 0.5,
    "alpha": 0.2
  },
  "cutmix": {
    "p": 0.0,
    "alpha": 1.0
  },
  "label_smoothing": 0.1,
  "embed_dim": 384,
  "depth": 12,
  "num_heads": 6,
  "mlp_ratio": 4.0,
  "drop_path": 0.1,
  "dropout": 0.0,
  "epochs": 100,
  "batch_size_target": 512,
  "grad_accum_steps": 1,
  "optimizer": {
    "name": "AdamW",
    "lr": 0.0006,
    "weight_decay": 0.1,
    "betas": [
      0.9,
      0.999
    ],
    "eps": 1e-08
  },
  "scheduler": {
    "type": "cosine",
    "warmup_epochs": 10
  },
  "ema_decay": 0.2,
  "grad_clip": 1.0,
  "use_amp": true,
  "out_dir": "./outputs",
  "eval_every": 5
}


## Reproducibility

In [3]:
#@title Seeds & cuDNN
def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True
set_seed(CFG["seed"])

## Data & Augmentations (fast DataLoader)

In [4]:
#Data pipeline
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD  = (0.2470, 0.2435, 0.2616)

def build_transforms(input_size: int, ra_cfg: dict):
    if input_size == 32:
        train_tfms = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        train_tfms = [
            transforms.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0), interpolation=InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
        ]
    if ra_cfg.get("enabled", False):
        N, M = ra_cfg.get("N", 2), ra_cfg.get("M", 10)
        train_tfms.append(transforms.RandAugment(num_ops=N, magnitude=M))
    train_tfms += [transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD)]
    test_tfms = [
        transforms.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
    return transforms.Compose(train_tfms), transforms.Compose(test_tfms)

train_tfms, test_tfms = build_transforms(CFG["input_size"], CFG["randaugment"])
data_root = "./data"
full_train = datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_tfms)
test_set   = datasets.CIFAR10(root=data_root, train=False, download=True, transform=test_tfms)

val_size = CFG["val_split"]
train_size = len(full_train) - val_size
train_set, val_set = random_split(full_train, [train_size, val_size], generator=torch.Generator().manual_seed(CFG["seed"]))

def make_loader(ds, batch_size, shuffle=False):
    import os
    num_workers = max(2, min(8, os.cpu_count() - 1 if os.cpu_count() else 2))
    return DataLoader(
        ds, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=True,
        prefetch_factor=2, persistent_workers=True,
    )

print(f"Train size: {len(train_set)} | Val size: {len(val_set)} | Test size: {len(test_set)}")

100%|██████████| 170M/170M [00:04<00:00, 42.5MB/s]


Train size: 45000 | Val size: 5000 | Test size: 10000


## MixUp & CutMix helpers

In [5]:
#@title MixUp & CutMix
import numpy as np, random

def rand_bbox(W, H, lam):
    cut_rat = (1. - lam) ** 0.5
    cut_w = int(W * cut_rat); cut_h = int(H * cut_rat)
    cx = np.random.randint(W); cy = np.random.randint(H)
    x1 = np.clip(cx - cut_w // 2, 0, W); y1 = np.clip(cy - cut_h // 2, 0, H)
    x2 = np.clip(cx + cut_w // 2, 0, W); y2 = np.clip(cy + cut_h // 2, 0, H)
    return x1, y1, x2, y2

def apply_mixup_cutmix(x, y, num_classes, mixup_cfg, cutmix_cfg):
    B, C, H, W = x.shape
    y_onehot = torch.zeros(B, num_classes, device=x.device, dtype=x.dtype)
    y_onehot.scatter_(1, y.view(-1,1), 1.0)

    r = random.random()
    if r < mixup_cfg.get("p", 0.0):
        alpha = mixup_cfg.get("alpha", 0.2)
        lam = np.random.beta(alpha, alpha)
        index = torch.randperm(B, device=x.device)
        x = lam * x + (1 - lam) * x[index, :]
        y_soft = lam * y_onehot + (1 - lam) * y_onehot[index, :]
        return x, y_soft, True

    if r < (mixup_cfg.get("p", 0.0) + cutmix_cfg.get("p", 0.0)):
        alpha = cutmix_cfg.get("alpha", 1.0)
        lam = np.random.beta(alpha, alpha)
        index = torch.randperm(B, device=x.device)
        x1, y1, x2, y2 = rand_bbox(W, H, lam)
        x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
        lam = 1 - ((x2 - x1) * (y2 - y1) / (W * H))
        y_soft = lam * y_onehot + (1 - lam) * y_onehot[index, :]
        return x, y_soft, True

    return x, y_onehot, False

## ViT (SDPA/Flash) + explicit `ViT` wrapper

In [6]:
#ViT (with SDPA/Flash)
class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.):
        super().__init__(); self.drop_prob = drop_prob
    def forward(self, x):
        if self.drop_prob == 0. or not self.training: return x
        keep = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        rand = keep + torch.rand(shape, dtype=x.dtype, device=x.device)
        rand.floor_()
        return x.div(keep) * rand

class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden); self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim); self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x); x = self.act(x); x = self.drop(x)
        x = self.fc2(x); x = self.drop(x); return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads; self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop); self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2,0,3,1,4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
        x = out.transpose(1,2).reshape(B, N, C)
        x = self.proj(x); x = self.proj_drop(x); return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop=0.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim); self.attn = Attention(dim, num_heads=num_heads)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim); self.mlp = MLP(dim, mlp_ratio=mlp_ratio, drop=drop)
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x))); return x

class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192):
        super().__init__()
        assert img_size % patch_size == 0
        self.grid = img_size // patch_size; self.num_patches = self.grid * self.grid
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        x = self.proj(x); x = x.flatten(2).transpose(1, 2); return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=192, depth=8, num_heads=3, mlp_ratio=4.0,
                 drop_rate=0.0, drop_path=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        self.blocks = nn.ModuleList([Block(embed_dim, num_heads, mlp_ratio, drop_rate, dpr[i]) for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim); self.head = nn.Linear(embed_dim, num_classes)
        nn.init.trunc_normal_(self.pos_embed, std=0.02); nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02);
            if m.bias is not None: nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
    def forward_features(self, x):
        x = self.patch_embed(x); B, N, C = x.shape
        x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1) + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks: x = blk(x)
        x = self.norm(x); return x[:, 0]
    def forward(self, x): return self.head(self.forward_features(x))

class ViT(nn.Module):
    def __init__(self, img_size, patch_size, num_classes, embed_dim, depth, num_heads,
                 mlp_ratio=4.0, drop_rate=0.0, drop_path=0.0):
        super().__init__()
        self.core = VisionTransformer(img_size, patch_size, 3, num_classes, embed_dim, depth, num_heads, mlp_ratio, drop_rate, drop_path)
    def forward(self, x): return self.core(x)

## Optimizer / Scheduler

In [7]:
# Optimizer / Scheduler
def split_weight_decay_params(model: nn.Module):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad: continue
        if p.ndimension() == 1 or name.endswith(".bias") or "norm" in name.lower():
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": CFG["optimizer"]["weight_decay"]},
        {"params": no_decay, "weight_decay": 0.0},
    ]

def build_optimizer(model):
    groups = split_weight_decay_params(model)
    return torch.optim.AdamW(groups, lr=CFG["optimizer"]["lr"], betas=CFG["optimizer"]["betas"], eps=CFG["optimizer"]["eps"])

def build_scheduler(optimizer, steps_per_epoch):
    total_epochs = CFG["epochs"]
    warmup_epochs = CFG["scheduler"]["warmup_epochs"]
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps  = total_epochs * steps_per_epoch
    def lr_lambda(step):
        if step < warmup_steps: return float(step) / float(max(1, warmup_steps))
        progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

## Train / Eval (eval-every-N, optional EMA)

In [8]:
# Train / Eval
def try_loader_oom(ds, target_bs, shuffle):
    bs = target_bs
    while bs >= 8:
        try:
            loader = make_loader(ds, bs, shuffle=shuffle)
            xb, yb = next(iter(loader))
            xb = xb.to(device, non_blocking=True); yb = yb.to(device, non_blocking=True)
            del xb, yb
            return loader, bs
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                if torch.cuda.is_available(): torch.cuda.empty_cache()
                bs //= 2
            else:
                raise e
    return make_loader(ds, 8, shuffle=shuffle), 8

def build_model_and_optimizer():
    model = ViT(
        img_size=CFG["input_size"], patch_size=CFG["patch_size"], num_classes=CFG["num_classes"],
        embed_dim=CFG["embed_dim"], depth=CFG["depth"], num_heads=CFG["num_heads"],
        mlp_ratio=CFG["mlp_ratio"], drop_rate=CFG["dropout"], drop_path=CFG["drop_path"],
    ).to(device)
    opt = build_optimizer(model)
    return model, opt

def smoothed_cross_entropy(logits, targets, smoothing, num_classes):
    log_probs = F.log_softmax(logits, dim=-1)
    with torch.no_grad():
        true_dist = torch.zeros_like(log_probs)
        true_dist.fill_(smoothing / (num_classes - 1))
        true_dist.scatter_(1, targets.data.unsqueeze(1), 1.0 - smoothing)
    return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    acc = MulticlassAccuracy(num_classes=CFG["num_classes"]).to(device)
    total_loss, n = 0.0, 0
    for x, y in loader:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        acc.update(logits, y)
        total_loss += loss.item() * x.size(0); n += x.size(0)
    return total_loss / n, acc.compute().item()

def train():
    model, opt = build_model_and_optimizer()
    train_loader, train_bs = try_loader_oom(train_set, CFG["batch_size_target"], shuffle=True)
    val_loader, _ = try_loader_oom(val_set, min(train_bs, CFG["batch_size_target"]), shuffle=False)
    test_loader,_ = try_loader_oom(test_set, min(train_bs, CFG["batch_size_target"]), shuffle=False)

    effective_bs = train_bs * CFG["grad_accum_steps"]
    scaled_lr = CFG["optimizer"]["lr"] * (effective_bs / 256.0)
    for pg in opt.param_groups: pg["lr"] = scaled_lr

    scaler = torch.cuda.amp.GradScaler(enabled=CFG["use_amp"] and device.type=="cuda")
    scheduler = build_scheduler(opt, steps_per_epoch=len(train_loader))

    use_ema = CFG.get("ema_decay", 0.0) and CFG["ema_decay"] > 0.0
    if use_ema:
        import copy
        ema = copy.deepcopy(model).to(device)
        for p in ema.parameters(): p.requires_grad_(False)
    else:
        ema = None

    best_val = 0.0
    history = []

    for epoch in range(CFG["epochs"]):
        model.train()
        tbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CFG['epochs']} (bs={train_bs}, eff={effective_bs})", leave=False)
        running_loss, running_acc, seen = 0.0, 0.0, 0
        acc = MulticlassAccuracy(num_classes=CFG["num_classes"]).to(device)

        opt.zero_grad(set_to_none=True)
        for step, (x, y) in enumerate(tbar):
            x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
            x_in, soft_targets, used_soft = apply_mixup_cutmix(x, y, CFG["num_classes"], CFG["mixup"], CFG["cutmix"])

            with torch.cuda.amp.autocast(enabled=CFG["use_amp"] and device.type=="cuda"):
                logits = model(x_in)
                if used_soft:
                    loss = torch.mean(torch.sum(-soft_targets * F.log_softmax(logits, dim=-1), dim=-1))
                else:
                    ls = CFG["label_smoothing"]
                    loss = smoothed_cross_entropy(logits, y, ls, CFG["num_classes"]) if ls and ls>0 else F.cross_entropy(logits, y)

            scaler.scale(loss / CFG["grad_accum_steps"]).backward()

            if (step + 1) % CFG["grad_accum_steps"] == 0:
                if CFG["grad_clip"] and CFG["grad_clip"] > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), CFG["grad_clip"])
                scaler.step(opt); scaler.update(); opt.zero_grad(set_to_none=True); scheduler.step()

                if use_ema:
                    with torch.no_grad():
                        d = CFG["ema_decay"]
                        for p_ema, p in zip(ema.parameters(), model.parameters()):
                            p_ema.copy_(p_ema * d + p * (1.0 - d))

            with torch.no_grad():
                acc.update(logits, y)
                running_loss += loss.item() * x.size(0)
                seen += x.size(0); running_acc = acc.compute().item()
                tbar.set_postfix(loss=f"{running_loss/seen:.4f}", acc=f"{running_acc*100:.2f}%")

        do_eval = ((epoch + 1) % CFG["eval_every"] == 0) or ((epoch + 1) == CFG["epochs"])
        val_loss = val_acc = test_loss = test_acc = None
        if do_eval:
            m_eval = ema if use_ema else model
            val_loss, val_acc = evaluate(m_eval, val_loader)
            test_loss, test_acc = evaluate(m_eval, test_loader)
            print(f"[Epoch {epoch+1}] val_acc={val_acc*100:.2f}% | test_acc={test_acc*100:.2f}% | lr={scheduler.get_last_lr()[0]:.2e}")
            if val_acc > best_val:
                best_val = val_acc
                ckpt = {"epoch": epoch+1, "model": "ViT", "state_dict": m_eval.state_dict(), "cfg": CFG, "best_val_acc": best_val}
                torch.save(ckpt, os.path.join(CFG["out_dir"], f"{CFG['run_id']}_best.pt"))

        history.append({
            "epoch": epoch+1, "train_loss": running_loss/seen, "train_acc": running_acc,
            "val_loss": val_loss, "val_acc": val_acc, "test_loss": test_loss, "test_acc": test_acc,
            "lr": scheduler.get_last_lr()[0], "batch_size": train_bs, "effective_bs": effective_bs,
        })

    results = {"run_id": CFG["run_id"], "config": CFG, "best_val_acc": best_val,
               "last_epoch": history[-1] if history else None,
               "history_tail": history[-5:] if len(history) > 5 else history}
    with open(os.path.join(CFG["out_dir"], "results.json"), "w") as f: json.dump(results, f, indent=2)
    with open(os.path.join(CFG["out_dir"], "best_config.json"), "w") as f: json.dump(CFG, f, indent=2)
    return results, history

results, history = train()

  scaler = torch.cuda.amp.GradScaler(enabled=CFG["use_amp"] and device.type=="cuda")
  with torch.cuda.amp.autocast(enabled=CFG["use_amp"] and device.type=="cuda"):
  self.gen = func(*args, **kwds)


[Epoch 5] val_acc=45.73% | test_acc=51.35% | lr=6.00e-04




[Epoch 10] val_acc=47.60% | test_acc=54.42% | lr=1.20e-03




[Epoch 15] val_acc=45.51% | test_acc=51.10% | lr=1.19e-03




[Epoch 20] val_acc=47.18% | test_acc=54.39% | lr=1.16e-03




[Epoch 25] val_acc=46.41% | test_acc=53.80% | lr=1.12e-03




[Epoch 30] val_acc=50.52% | test_acc=57.21% | lr=1.06e-03




[Epoch 35] val_acc=51.39% | test_acc=56.48% | lr=9.86e-04




[Epoch 40] val_acc=51.44% | test_acc=55.13% | lr=9.00e-04




[Epoch 45] val_acc=51.89% | test_acc=59.46% | lr=8.05e-04




[Epoch 50] val_acc=53.81% | test_acc=60.67% | lr=7.04e-04




[Epoch 55] val_acc=54.83% | test_acc=61.08% | lr=6.00e-04




[Epoch 60] val_acc=57.72% | test_acc=63.68% | lr=4.96e-04




[Epoch 65] val_acc=59.55% | test_acc=64.97% | lr=3.95e-04




[Epoch 70] val_acc=61.26% | test_acc=66.61% | lr=3.00e-04




[Epoch 75] val_acc=62.87% | test_acc=69.22% | lr=2.14e-04




[Epoch 80] val_acc=64.48% | test_acc=70.15% | lr=1.40e-04




[Epoch 85] val_acc=65.49% | test_acc=71.01% | lr=8.04e-05




[Epoch 90] val_acc=65.30% | test_acc=72.01% | lr=3.62e-05




[Epoch 95] val_acc=66.15% | test_acc=71.93% | lr=9.12e-06




[Epoch 100] val_acc=66.04% | test_acc=71.96% | lr=0.00e+00


## Final Test evaluation + tiny results table + final JSON

In [9]:
#Final evaluation & table
best_ckpts = [f for f in os.listdir(CFG["out_dir"]) if f.endswith("_best.pt")]
best_path = max([os.path.join(CFG["out_dir"], f) for f in best_ckpts], key=os.path.getmtime) if best_ckpts else None

if best_path:
    ckpt = torch.load(best_path, map_location=device)
    model_eval = ViT(
        img_size=CFG["input_size"], patch_size=CFG["patch_size"], num_classes=CFG["num_classes"],
        embed_dim=CFG["embed_dim"], depth=CFG["depth"], num_heads=CFG["num_heads"],
        mlp_ratio=CFG["mlp_ratio"], drop_rate=CFG["dropout"], drop_path=CFG["drop_path"],
    ).to(device)
    model_eval.load_state_dict(ckpt["state_dict"])
else:
    model_eval = None

def make_test_loader():
    _, test_bs = try_loader_oom(test_set, CFG["batch_size_target"], shuffle=False)
    return make_loader(test_set, test_bs, shuffle=False)

@torch.no_grad()
def test_acc_only(model, loader):
    model.eval()
    acc = MulticlassAccuracy(num_classes=CFG["num_classes"]).to(device)
    for x, y in loader:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        logits = model(x); acc.update(logits, y)
    return acc.compute().item()

from textwrap import dedent

if model_eval is not None:
    test_loader = make_test_loader()
    acc = test_acc_only(model_eval, test_loader)
    test_acc_pct = round(acc*100, 2)
    row = {"input": CFG["input_size"], "patch": CFG["patch_size"], "dim": CFG["embed_dim"],
           "depth": CFG["depth"], "heads": CFG["num_heads"], "drop_path": CFG["drop_path"],
           "epochs": CFG["epochs"], "test_acc%": test_acc_pct}
    print("Results row:", row)

    final_json = {"run_id": CFG["run_id"], "config": CFG, "best_val_acc": float(ckpt.get("best_val_acc", 0.0)), "final_test_acc_pct": float(test_acc_pct)}
    with open(os.path.join(CFG["out_dir"], "results_final.json"), "w") as f: json.dump(final_json, f, indent=2)

    table = dedent(f"""
    | Config | Input | Patch | Dim | Depth | Heads | DropPath | Epochs | Test Acc (%) |
    |---|---:|---:|---:|---:|---:|---:|---:|---:|
    | {CFG['run_id']} | {CFG['input_size']}×{CFG['input_size']} | {CFG['patch_size']} | {CFG['embed_dim']} | {CFG['depth']} | {CFG['num_heads']} | {CFG['drop_path']} | {CFG['epochs']} | **{test_acc_pct}** |
    """)
    print(table)
    with open(os.path.join(CFG["out_dir"], "results_table.md"), "w") as f: f.write(table)
else:
    print("No best checkpoint found; run training first.")

Results row: {'input': 32, 'patch': 4, 'dim': 384, 'depth': 12, 'heads': 6, 'drop_path': 0.1, 'epochs': 100, 'test_acc%': 71.93}

| Config | Input | Patch | Dim | Depth | Heads | DropPath | Epochs | Test Acc (%) |
|---|---:|---:|---:|---:|---:|---:|---:|---:|
| vit_cifar10_v3_20251004-101042 | 32×32 | 4 | 384 | 12 | 6 | 0.1 | 100 | **71.93** |



## (Optional) Sanity check — tokens & forward

In [10]:
#Sanity Check
N = CFG["input_size"] // CFG["patch_size"]
tokens = N*N + 1
print("Tokens incl. CLS:", tokens)
x = torch.randn(2, 3, CFG["input_size"], CFG["input_size"]).to(device)
m = ViT(img_size=CFG["input_size"], patch_size=CFG["patch_size"], num_classes=CFG["num_classes"],
        embed_dim=CFG["embed_dim"], depth=CFG["depth"], num_heads=CFG["num_heads"],
        mlp_ratio=CFG["mlp_ratio"], drop_rate=CFG["dropout"], drop_path=CFG["drop_path"]).to(device)
with torch.inference_mode(): out = m(x)
print("Forward ok. Logits shape:", out.shape)

Tokens incl. CLS: 65
Forward ok. Logits shape: torch.Size([2, 10])


## Bonus: Short, crisp analysis (for README)

- **Patch size (32×)**: `p=4` (65 tokens incl. CLS) > `p=8` (17 tokens) for CIFAR-10 detail.
- **Depth/Width**: Small (dim=384, L=12, H=6) > Tiny (dim=192) on Colab.
- **Regularization**: RandAug + MixUp + LS + DropPath≈0.1 generalizes well.
- **Schedule**: AdamW + warmup→cosine; **EMA 0.9998** can add ~0.1–0.3% test.
- **Throughput**: 224×/p=16 (197 tokens) ≈ 9× attention compute vs 32×/p=4 (65).