# Vision Transformer on CIFAR-10 (PyTorch)


# Setup and Advanced Training Utilities for ViT

# This section handles the necessary imports, hardware setup, and reproducibility settings, and introduces three utility modules (**DropPath**, **WarmupCosine scheduler**, **label-smoothing cross-entropy**) used for robust Vision Transformer (ViT) training, particularly on smaller datasets like CIFAR-10.


In [1]:
!nvidia-smi -L

import math, os, random, json, time
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Repro: seed every random number generator used in this notebook.
seed = 1337
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# NOTE(review): benchmark mode lets cuDNN autotune convolution kernels for
# speed; the PyTorch reproducibility notes say this can select different
# algorithms between runs -- confirm this trade-off is intended under "Repro".
torch.backends.cudnn.benchmark = True

# --- DropPath (stochastic depth) ---
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    During training each sample along dim 0 is kept with probability
    ``1 - p`` and the kept samples are rescaled by ``1 / (1 - p)`` so the
    expected value is unchanged. In eval mode, or when ``p == 0``, the input
    is returned untouched.

    Args:
        p: drop probability in ``[0, 1]``.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p=0.0):
        super().__init__()
        p = float(p)
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"drop probability must be in [0, 1], got {p}")
        self.p = p

    def forward(self, x):
        if self.p == 0.0 or not self.training:
            return x
        keep = 1.0 - self.p
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep)
        # With p == 1 the mask is all zeros; dividing by keep == 0 would give
        # 0/0 = NaN, so only rescale when something can actually be kept.
        if keep > 0.0:
            mask.div_(keep)
        return x * mask

    def extra_repr(self):
        return f"p={self.p}"

# --- Warmup + Cosine scheduler (step every batch) ---
class WarmupCosine:
    """Linear warmup followed by cosine decay, stepped once per batch.

    For step ``t`` (1-based, incremented by :meth:`step`):

    * ``t <= warmup_steps``: ``lr = base * t / warmup_steps``
    * afterwards: cosine interpolation from ``base`` down to ``min_lr``,
      reaching ``min_lr`` at ``total_steps`` and staying there.

    The constructor does not modify the optimizer; learning rates are only
    written when :meth:`step` is called.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry); the current ``'lr'`` of each group is taken as
            that group's base learning rate.
        warmup_steps: number of warmup steps.
        total_steps: total number of scheduled steps (warmup included).
        min_lr: final learning rate of the cosine phase.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=1e-5):
        self.opt = optimizer
        self.warm = int(warmup_steps)
        self.total = int(total_steps)
        self.min_lr = float(min_lr)
        self.base = [g['lr'] for g in optimizer.param_groups]
        self.t = 0

    def step(self):
        """Advance one step and write the new learning rate into every group."""
        self.t += 1
        if self.t <= self.warm:
            scale = self.t / max(1, self.warm)
            for g, base in zip(self.opt.param_groups, self.base):
                g['lr'] = base * scale
            return
        # Clamp progress to 1 so that stepping past total_steps holds min_lr
        # instead of riding the cosine back up towards the base rate.
        p = min(1.0, (self.t - self.warm) / max(1, self.total - self.warm))
        cos_factor = 0.5 * (1 + math.cos(math.pi * p))
        for g, base in zip(self.opt.param_groups, self.base):
            g['lr'] = self.min_lr + (base - self.min_lr) * cos_factor

    def get_last_lr(self):
        """Return the learning rate currently set on each param group."""
        return [g['lr'] for g in self.opt.param_groups]

# --- Label smoothing CE (works well for ViT) ---
class LabelSmoothingCE(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The true class receives probability ``1 - eps`` and every other class
    receives ``eps / (n - 1)``, where ``n`` is the size of the last logits
    dimension. The loss is averaged over samples.

    Args:
        eps: total probability mass moved away from the true class.
    """

    def __init__(self, eps=0.1):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        # Classes are scattered along dim 1, so logits are expected to be
        # 2-D (samples, classes) with integer class indices in `target`.
        num_classes = logits.size(-1)
        log_probs = F.log_softmax(logits, dim=-1)
        with torch.no_grad():
            off_value = self.eps / (num_classes - 1)
            soft_targets = torch.full_like(logits, off_value)
            soft_targets.scatter_(1, target.unsqueeze(1), 1 - self.eps)
        per_sample = torch.sum(-soft_targets * log_probs, dim=-1)
        return per_sample.mean()


GPU 0: Tesla T4 (UUID: GPU-47c9d731-ee0a-8b3e-5aca-0f092b4b3042)
GPU 1: Tesla T4 (UUID: GPU-1a08dde5-652d-14c6-0fdd-8f08473c8822)
Device: cuda


# ==== ViT implementation ====


In [2]:
# ==== ViT implementation ====
from torch import nn

class NewGELUActivation(nn.Module):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``.
    """

    def forward(self, x):
        inner = x + 0.044715 * torch.pow(x, 3.0)
        gate = torch.tanh(math.sqrt(2.0 / math.pi) * inner)
        return 0.5 * x * (1.0 + gate)

class PatchEmbeddings(nn.Module):
    """Split an image into non-overlapping patches and project each to a vector.

    Reads ``image_size``, ``patch_size``, ``num_channels`` and ``hidden_size``
    from ``config``. ``image_size`` must be a multiple of ``patch_size``.
    """

    def __init__(self, config):
        super().__init__()
        side, patch = config["image_size"], config["patch_size"]
        self.image_size = side
        self.patch_size = patch
        self.num_channels = config["num_channels"]
        self.hidden_size = config["hidden_size"]
        assert side % patch == 0
        patches_per_side = side // patch
        self.num_patches = patches_per_side ** 2
        # A conv whose kernel and stride both equal the patch size applies one
        # linear projection per non-overlapping patch.
        self.projection = nn.Conv2d(
            self.num_channels,
            self.hidden_size,
            kernel_size=patch,
            stride=patch,
        )

    def forward(self, x):
        grid = self.projection(x)                 # (B, D, H/ps, W/ps)
        return grid.flatten(2).transpose(1, 2)    # (B, N, D)

class Embeddings(nn.Module):
    """Patch embeddings plus a prepended CLS token and learned positions.

    Reads ``hidden_size`` and ``hidden_dropout_prob`` from ``config`` (in
    addition to whatever ``PatchEmbeddings`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.patch_embeddings = PatchEmbeddings(config)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config["hidden_size"]))
        # One position per patch, plus one for the CLS token.
        seq_len = self.patch_embeddings.num_patches + 1
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, seq_len, config["hidden_size"])
        )
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        patches = self.patch_embeddings(x)                       # (B, N, D)
        cls = self.cls_token.expand(patches.size(0), -1, -1)     # (B, 1, D)
        tokens = torch.cat([cls, patches], dim=1)                # CLS goes first
        tokens = tokens + self.position_embeddings               # learned positions
        return self.dropout(tokens)

class FasterMultiHeadAttention(nn.Module):
    """Multi-head self-attention using a single fused QKV projection.

    Reads ``hidden_size``, ``num_attention_heads``, ``qkv_bias``,
    ``attention_probs_dropout_prob`` and ``hidden_dropout_prob`` from
    ``config``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by
            ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        # Fail early: floor division would otherwise hide the mismatch until
        # forward() hits an opaque view() shape error.
        if self.num_heads <= 0 or self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )
        self.head_dim = self.hidden_size // self.num_heads
        self.qkv = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=config["qkv_bias"])
        self.attn_drop = nn.Dropout(config["attention_probs_dropout_prob"])
        self.proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.out_drop = nn.Dropout(config["hidden_dropout_prob"])
        self.scale = self.head_dim ** -0.5

    def forward(self, x, output_attentions=False):
        """Apply self-attention to ``x`` of shape (B, N, C).

        Returns:
            ``(out, attn)`` where ``out`` has the shape of ``x`` and ``attn``
            is the post-dropout attention tensor of shape (B, heads, N, N)
            when ``output_attentions`` is true, otherwise ``None``.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).chunk(3, dim=-1)
        # (B, N, C) -> (B, heads, N, head_dim) for each of q, k, v
        q, k, v = [t.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v
        # merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        out = out.transpose(1, 2).contiguous().view(B, N, C)
        out = self.out_drop(self.proj(out))
        return out, (attn if output_attentions else None)

class MLP(nn.Module):
    """Two-layer feed-forward network with dropout after each layer.

    Reads ``hidden_size``, ``intermediate_size`` and ``hidden_dropout_prob``
    from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        dim, inner_dim = config["hidden_size"], config["intermediate_size"]
        self.fc1 = nn.Linear(dim, inner_dim)
        self.act = NewGELUActivation()
        self.fc2 = nn.Linear(inner_dim, dim)
        self.drop = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)
        out = self.fc2(hidden)
        return self.drop(out)

class Block(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP.

    Each sub-layer output passes through DropPath before being added back to
    the residual stream.

    Args:
        config: model configuration dict.
        drop_path: DropPath probability used for both residual branches.
    """

    def __init__(self, config, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(config["hidden_size"])
        if config.get("use_faster_attention", True):
            self.attn = FasterMultiHeadAttention(config)
        else:
            # NOTE(review): MultiHeadAttention is not defined in the visible
            # part of this file -- confirm it exists before disabling
            # use_faster_attention.
            self.attn = MultiHeadAttention(config)
        self.drop_path1 = DropPath(drop_path)
        self.norm2 = nn.LayerNorm(config["hidden_size"])
        self.mlp = MLP(config)
        self.drop_path2 = DropPath(drop_path)

    def forward(self, x, output_attentions=False):
        # Attention sub-layer with residual connection.
        attn_out, attn_probs = self.attn(self.norm1(x), output_attentions=output_attentions)
        x = x + self.drop_path1(attn_out)
        # Feed-forward sub-layer with residual connection.
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x, (attn_probs if output_attentions else None)

class Encoder(nn.Module):
    """Stack of ``num_hidden_layers`` transformer blocks.

    The DropPath probability grows linearly from 0 at the first block to
    ``config["drop_path"]`` (0.1 when the key is absent) at the last one.
    """

    def __init__(self, config):
        super().__init__()
        num_layers = config["num_hidden_layers"]
        max_rate = config.get("drop_path", 0.1)
        # Per-block stochastic depth rates, evenly spaced from 0 to max_rate.
        rates = torch.linspace(0, max_rate, steps=num_layers).tolist()
        self.blocks = nn.ModuleList([Block(config, drop_path=rate) for rate in rates])

    def forward(self, x, output_attentions=False):
        collected = []
        for block in self.blocks:
            x, probs = block(x, output_attentions=output_attentions)
            if output_attentions:
                collected.append(probs)
        if not output_attentions:
            return x, None
        return x, collected

class ViTForClassification(nn.Module):
    """Vision Transformer classifier: embeddings -> encoder -> norm -> linear head.

    ``forward`` returns ``(logits, attentions)``; ``attentions`` is the list
    produced by the encoder when ``output_attentions`` is true, else ``None``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = Embeddings(config)
        self.encoder = Encoder(config)
        self.norm = nn.LayerNorm(config["hidden_size"])
        self.head = nn.Linear(config["hidden_size"], config["num_classes"])
        # Re-initialise every submodule with the scheme in _init_weights.
        self.apply(self._init_weights)

    def forward(self, x, output_attentions=False):
        tokens = self.embedding(x)
        tokens, atts = self.encoder(tokens, output_attentions=output_attentions)
        # Token 0 is the CLS token prepended by the embedding layer.
        cls_repr = self.norm(tokens)[:, 0]
        logits = self.head(cls_repr)
        return logits, (atts if output_attentions else None)

    def _init_weights(self, m):
        """Initialise one submodule; called for every module via ``apply``."""
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(m.weight, std=self.config["initializer_range"])
            bias = getattr(m, "bias", None)
            if bias is not None:
                nn.init.zeros_(bias)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
            return
        if isinstance(m, Embeddings):
            std = self.config["initializer_range"]
            nn.init.trunc_normal_(m.position_embeddings, std=std)
            nn.init.trunc_normal_(m.cls_token, std=std)


# ==== Data: CIFAR-10 with strong aug ====


In [3]:
# ==== Data: CIFAR-10 with strong aug ====
# Per-channel statistics used to normalise CIFAR-10 images.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random crop + flip + RandAugment, then tensor + normalise.
train_tfms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion + normalise.
test_tfms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
)

train_ds = datasets.CIFAR10(
    root='/content/data', train=True, download=True, transform=train_tfms
)
test_ds = datasets.CIFAR10(
    root='/content/data', train=False, download=True, transform=test_tfms
)

# Batch sizes chosen for a T4 (AMP helps); start at 128 and raise to 256 if
# memory allows.
train_bs = 128
test_bs = 256

train_loader = DataLoader(
    train_ds,
    batch_size=train_bs,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
test_loader = DataLoader(
    test_ds,
    batch_size=test_bs,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)
num_classes = 10


100%|██████████| 170M/170M [00:02<00:00, 78.9MB/s]


# ==== Trainer with AMP + scheduler + smoothing ====


In [4]:
# ==== Trainer with AMP + scheduler + smoothing ====
class Trainer:
    """Minimal training loop with mixed precision, a per-batch scheduler and
    best-checkpoint saving.

    The model's ``forward`` is expected to return a ``(logits, extra)`` pair;
    only ``logits`` is used.

    Args:
        model: the network to train; moved to ``device``.
        optimizer: optimizer over the model's parameters.
        loss_fn: training loss, called as ``loss_fn(logits, labels)``.
        device: device string or ``torch.device``. Mixed precision is enabled
            whenever the device type is CUDA (``"cuda"``, ``"cuda:0"``, ...).
        scheduler: optional object with a ``step()`` method, called once per
            training batch.
        exp_name: experiment name; outputs go to ``<out_dir>/<exp_name>``.
        out_dir: parent directory for experiment outputs.
    """

    def __init__(self, model, optimizer, loss_fn, device, scheduler=None, exp_name="vit-exp",
                 out_dir="/content"):
        self.model = model.to(device)
        self.opt = optimizer
        self.loss_fn = loss_fn
        self.device = device
        self.sched = scheduler
        # Compare on the device *type* so "cuda:0" and torch.device objects
        # enable AMP too, not only the exact string "cuda".
        self.use_amp = torch.device(device).type == "cuda"
        self.scaler = self._make_scaler(self.use_amp)
        self.exp_name = exp_name
        self.run_dir = os.path.join(out_dir, exp_name)
        os.makedirs(self.run_dir, exist_ok=True)

    @staticmethod
    def _make_scaler(enabled):
        """Build a gradient scaler, preferring the non-deprecated torch.amp API."""
        try:
            return torch.amp.GradScaler("cuda", enabled=enabled)
        except (AttributeError, TypeError):
            # Older PyTorch releases only provide the torch.cuda.amp variant.
            return torch.cuda.amp.GradScaler(enabled=enabled)

    def train_epoch(self, loader):
        """Run one training pass; return ``(mean_loss, accuracy)``."""
        self.model.train()
        total_loss, total, correct = 0.0, 0, 0
        for imgs, labels in loader:
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            self.opt.zero_grad(set_to_none=True)
            with torch.autocast(device_type="cuda", enabled=self.use_amp):
                logits, _ = self.model(imgs)
                loss = self.loss_fn(logits, labels)
            self.scaler.scale(loss).backward()
            self.scaler.step(self.opt)
            self.scaler.update()
            if self.sched is not None:
                self.sched.step()
            total_loss += loss.item() * imgs.size(0)
            total += imgs.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        denom = max(total, 1)
        return total_loss / denom, correct / denom

    @torch.no_grad()
    def evaluate(self, loader):
        """Evaluate with plain cross-entropy; return ``(mean_loss, accuracy)``."""
        self.model.eval()
        total_loss, total, correct = 0.0, 0, 0
        for imgs, labels in loader:
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            logits, _ = self.model(imgs)
            loss = F.cross_entropy(logits, labels)
            total_loss += loss.item() * imgs.size(0)
            total += imgs.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
        denom = max(total, 1)
        return total_loss / denom, correct / denom

    def fit(self, train_loader, test_loader, epochs, save_best=True):
        """Train for ``epochs`` epochs, evaluating after each one.

        Writes ``metrics.json`` (and ``best.pth`` when ``save_best`` is true
        and accuracy improves) into the experiment directory and returns the
        history dict of per-epoch losses and accuracies.
        """
        best_acc = 0.0
        history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}
        for ep in range(1, epochs+1):
            tr_loss, tr_acc = self.train_epoch(train_loader)
            te_loss, te_acc = self.evaluate(test_loader)
            history["train_loss"].append(tr_loss)
            history["train_acc"].append(tr_acc)
            history["test_loss"].append(te_loss)
            history["test_acc"].append(te_acc)
            print(f"Epoch {ep:03d} | train_acc={tr_acc:.4f} loss={tr_loss:.4f} | test_acc={te_acc:.4f} loss={te_loss:.4f}")
            # Track the best accuracy even when checkpoints are not saved, so
            # the summary below is correct for save_best=False as well.
            if te_acc > best_acc:
                best_acc = te_acc
                if save_best:
                    torch.save(self.model.state_dict(), os.path.join(self.run_dir, "best.pth"))
        print("Best test acc:", best_acc)
        with open(os.path.join(self.run_dir, "metrics.json"), "w") as f:
            json.dump(history, f, indent=2)
        return history


# ==== Config + train ====


In [5]:
# ==== Config + train ====
exp_name = "vit_cifar10_colab_t4"

# Architecture and regularisation hyper-parameters for the ViT.
config = dict(
    image_size=32,
    patch_size=4,                      # 32/4 -> 8x8 = 64 tokens
    num_channels=3,
    hidden_size=384,                   # good capacity on T4
    num_hidden_layers=8,               # if VRAM allows, try 10 later
    num_attention_heads=6,             # 384 / 6 = 64
    intermediate_size=4 * 384,
    hidden_dropout_prob=0.10,          # more regularization
    attention_probs_dropout_prob=0.10,
    initializer_range=0.02,
    num_classes=10,
    qkv_bias=True,
    use_faster_attention=True,
    drop_path=0.20,                    # more stochastic depth
)

# Model
model = ViTForClassification(config)

# Optimizer / schedule: AdamW with 5 warmup epochs, then cosine decay. The
# scheduler is stepped once per batch, so all step counts are in batches.
epochs = 150
base_lr = 3e-4
wd = 5e-2
opt = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=wd)
total_steps = epochs * len(train_loader)
sched = WarmupCosine(
    opt,
    warmup_steps=5*len(train_loader),
    total_steps=total_steps,
    min_lr=1e-5,
)

# Loss (label smoothing)
loss_fn = LabelSmoothingCE(eps=0.1)

trainer = Trainer(
    model,
    opt,
    loss_fn,
    device=device,
    scheduler=sched,
    exp_name=exp_name,
)
history = trainer.fit(train_loader, test_loader, epochs=epochs, save_best=True)


  self.scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))
  with torch.cuda.amp.autocast(enabled=(self.device=="cuda")):


Epoch 001 | train_acc=0.2175 loss=2.1790 | test_acc=0.3073 loss=1.8423
Epoch 002 | train_acc=0.3216 loss=1.9417 | test_acc=0.4051 loss=1.6341
Epoch 003 | train_acc=0.3798 loss=1.8366 | test_acc=0.4682 loss=1.4862
Epoch 004 | train_acc=0.4102 loss=1.7725 | test_acc=0.4658 loss=1.4663
Epoch 005 | train_acc=0.4333 loss=1.7332 | test_acc=0.4935 loss=1.4099
Epoch 006 | train_acc=0.4489 loss=1.6983 | test_acc=0.5256 loss=1.3226
Epoch 007 | train_acc=0.4664 loss=1.6665 | test_acc=0.5539 loss=1.2643
Epoch 008 | train_acc=0.4774 loss=1.6443 | test_acc=0.5541 loss=1.2623
Epoch 009 | train_acc=0.4860 loss=1.6251 | test_acc=0.5650 loss=1.2471
Epoch 010 | train_acc=0.4964 loss=1.6045 | test_acc=0.5652 loss=1.2348
Epoch 011 | train_acc=0.5103 loss=1.5773 | test_acc=0.5644 loss=1.2059
Epoch 012 | train_acc=0.5161 loss=1.5635 | test_acc=0.5746 loss=1.2144
Epoch 013 | train_acc=0.5232 loss=1.5483 | test_acc=0.6008 loss=1.1183
Epoch 014 | train_acc=0.5294 loss=1.5388 | test_acc=0.6076 loss=1.1192
Epoch 

# ==== Optional: reload best checkpoint & eval ====


In [6]:
# ==== Optional: reload best checkpoint & eval ====
best_path = f"/content/{exp_name}/best.pth"
if os.path.exists(best_path):
    # The checkpoint is a plain state_dict, so restrict unpickling to tensors
    # and basic containers instead of executing arbitrary pickled objects.
    state = torch.load(best_path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)
    test_loss, test_acc = trainer.evaluate(test_loader)
    print("Reloaded best checkpoint -> test_acc:", f"{test_acc:.4f}")
else:
    print("No best checkpoint found yet.")


Reloaded best checkpoint -> test_acc: 0.8760


In [1]:
pip install torch weightwatcher

Collecting torch
  Using cached torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting weightwatcher
  Downloading weightwatcher-0.7.5.5-py3-none-any.whl.metadata (26 kB)
Collecting filelock (from torch)
  Using cached filelock-3.19.1-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Using cached networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec (from torch)
  Using cached fsspec-2025.9.0-py3-none-any.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.8.93 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-runtime-cu12==12.8.90 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-cupti-cu12==12.8.90 (from torch)
  U