In [7]:
"""
CIFAR‑100 ─ Vision Transformer scalability study + ResNet‑18 baseline
refactored version (logic preserved, style changed)
"""

# ───────────────────────────── Imports & Globals ──────────────────────────────
import time, random, numpy as np
import torch, torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchsummary import summary           # <- pip install torchsummary
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters shared by the experiments in this cell.
SEED = 42
BATCH = 64
EPOCHS_VIT = 20
EPOCHS_RESNET = 10
LR = 1e-3
NUM_CLASSES = 100

# Seed every random number generator this cell imports.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# ────────────────────────────── Data pipeline ─────────────────────────────────
# Per-channel mean / std handed to Normalize below.
C100_MEAN = (0.5071, 0.4867, 0.4408)
C100_STD = (0.2675, 0.2565, 0.2761)

# Both splits share one preprocessing chain: tensor conversion, then normalisation.
transf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=C100_MEAN, std=C100_STD),
])

# Only the training split requests a download.
train_set = torchvision.datasets.CIFAR100(
    root='./data', train=True, transform=transf, download=True
)
test_set = torchvision.datasets.CIFAR100(
    root='./data', train=False, transform=transf
)

# Training batches are shuffled; test batches keep dataset order.
train_loader = DataLoader(dataset=train_set, batch_size=BATCH, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_set, batch_size=BATCH, shuffle=False, num_workers=2)

# ───────────────────────────── ViT building blocks ───────────────────────────
class Patchify(nn.Module):
    """Convert an image batch into a token sequence for the transformer.

    A strided convolution (kernel == stride == ``patch``) embeds each patch
    into ``dim`` channels. A learnable class token is prepended and a
    learnable positional table of ``n + 1`` rows is added to the sequence.

    Args:
        img: side length used to compute the number of patches per image.
        patch: patch side length (convolution kernel size and stride).
        ch: number of input channels.
        dim: embedding width of every token.
    """

    def __init__(self, img=32, patch=4, ch=3, dim=256):
        super().__init__()
        per_side = img // patch
        self.n = per_side * per_side
        self.to_patch = nn.Conv2d(ch, dim, kernel_size=patch, stride=patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos = nn.Parameter(torch.zeros(1, self.n + 1, dim))
        # Same initialisation order as before: positions first, then class token.
        for table in (self.pos, self.cls_token):
            nn.init.trunc_normal_(table, std=0.02)

    def forward(self, x):
        """Return ``(B, n + 1, dim)`` tokens for input ``x`` of shape ``(B, ch, H, W)``."""
        embedded = self.to_patch(x)                        # B D H' W'
        tokens = embedded.flatten(2).transpose(1, 2)       # B N D
        cls = self.cls_token.expand(x.size(0), -1, -1)     # B 1 D
        sequence = torch.cat([cls, tokens], dim=1)
        return sequence + self.pos


class MHSA(nn.Module):
    """Multi-head self-attention with a single fused query/key/value projection.

    Args:
        dim: token embedding width; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, dim, heads):
        super().__init__()
        assert dim % heads == 0
        self.h = heads
        self.dk = dim // heads
        self.proj_qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over the token axis of ``x`` (``B, N, D``) and return ``B, N, D``."""
        batch, tokens, width = x.shape
        fused = self.proj_qkv(x).reshape(batch, tokens, 3, self.h, self.dk)
        # Move the q/k/v axis to the front so it can be split: 3 x (B, h, N, dk).
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)
        scores = torch.matmul(query, key.transpose(-1, -2)) * (self.dk ** -0.5)
        weights = torch.softmax(scores, dim=-1)
        mixed = torch.matmul(weights, value)               # B h N dk
        # Put heads next to their features again and merge them back into D.
        merged = mixed.transpose(1, 2).reshape(batch, tokens, width)
        return self.out(merged)


class FeedForward(nn.Module):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output width.
        ratio: hidden width is ``dim * ratio``.
        p: dropout probability used after both the activation and the output layer.
    """

    def __init__(self, dim, ratio=4, p=0.):
        super().__init__()
        hidden = dim * ratio
        stages = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden, dim),
            nn.Dropout(p),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``; the last dimension must be ``dim``."""
        return self.net(x)


class EncoderBlock(nn.Module):
    """Transformer encoder layer with LayerNorm applied before each sub-layer.

    Both the attention and the feed-forward sub-layers are wrapped in a
    residual connection.

    Args:
        dim: token embedding width.
        heads: number of attention heads passed to ``MHSA``.
        mlp_ratio: hidden-width multiplier passed to ``FeedForward``.
    """

    def __init__(self, dim, heads, mlp_ratio):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.att = MHSA(dim, heads)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, mlp_ratio)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = x + self.att(self.norm1(x))
        return after_attention + self.ffn(self.norm2(after_attention))


class ViT(nn.Module):
    """Vision Transformer classifier that predicts from the class token.

    The pipeline is: ``Patchify`` (with 3 input channels) -> ``depth`` stacked
    ``EncoderBlock`` layers -> LayerNorm -> linear head on token 0.

    Args:
        img: image side length forwarded to ``Patchify``.
        patch: patch side length forwarded to ``Patchify``.
        dim: token embedding width.
        depth: number of encoder blocks.
        heads: attention heads per block.
        mlp_ratio: feed-forward hidden-width multiplier per block.
        classes: number of output logits.
    """

    def __init__(self, img=32, patch=4, dim=256, depth=4, heads=4,
                 mlp_ratio=4, classes=100):
        super().__init__()
        self.patch = Patchify(img, patch, 3, dim)
        blocks = [EncoderBlock(dim, heads, mlp_ratio) for _ in range(depth)]
        self.body = nn.Sequential(*blocks)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, classes)
        # Re-initialise every nn.Linear in the tree (see _init).
        self.apply(self._init)

    @staticmethod
    def _init(m):
        """Give ``nn.Linear`` modules truncated-normal weights and zero bias."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(B, classes)`` for an image batch ``x``."""
        tokens = self.body(self.patch(x))
        class_token = tokens[:, 0]
        return self.head(self.norm(class_token))

# ────────────────────────────── Train / Evaluate ─────────────────────────────
def loop(model, loader, opt=None):
    """Run one full pass over ``loader`` and report accuracy and elapsed time.

    With an optimizer the pass is a training epoch: the model is put in
    training mode and every batch triggers a cross-entropy backward pass and
    an optimizer step. Without one the pass is an evaluation: the model is
    put in eval mode and autograd is switched off, so no graph is built for
    outputs that are never back-propagated.

    Args:
        model: module mapping an input batch to per-class logits.
        loader: iterable yielding ``(inputs, integer_labels)`` batches.
        opt: optimizer over ``model``'s parameters, or ``None`` to evaluate.

    Returns:
        ``(accuracy_percent, seconds)``. Accuracy is top-1 over every sample
        seen; it is ``0.0`` when ``loader`` yields nothing.
    """
    train = opt is not None
    model.train(train)                      # train(False) is the same as eval()
    crit = nn.CrossEntropyLoss()

    # Send batches to wherever the model's weights live. Only a parameter-free
    # model falls back to the module-level ``device``.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    hits = tots = 0
    t0 = time.perf_counter()                # monotonic clock for durations
    with torch.set_grad_enabled(train):
        for xb, yb in loader:
            xb, yb = xb.to(target), yb.to(target)
            out = model(xb)
            if train:
                # The loss is only needed when there is something to update.
                loss = crit(out, yb)
                opt.zero_grad()
                loss.backward()
                opt.step()
            hits += (out.argmax(1) == yb).sum().item()
            tots += yb.size(0)
    elapsed = time.perf_counter() - t0

    accuracy_pct = hits / tots * 100 if tots else 0.0
    return accuracy_pct, elapsed

# ──────────────────────────── Experiment catalogue ───────────────────────────
# Each entry: display tag, patch size (p), embedding width (d), depth (L),
# attention heads (H) and MLP ratio (R).
variants = [
    {'tag': 'ViT‑Tiny',   'p': 4, 'd': 256, 'L': 4, 'H': 2, 'R': 2},
    {'tag': 'ViT‑Small',  'p': 8, 'd': 256, 'L': 8, 'H': 2, 'R': 2},
    {'tag': 'ViT‑Medium', 'p': 4, 'd': 512, 'L': 4, 'H': 4, 'R': 4},
    {'tag': 'ViT‑Large',  'p': 8, 'd': 512, 'L': 8, 'H': 4, 'R': 4},
]

# One result row per trained model; the final table prints these in order.
log = []
for cfg in variants:
    print(f"\n🟢 Training {cfg['tag']}")
    vit_kwargs = dict(patch=cfg['p'], dim=cfg['d'], depth=cfg['L'],
                      heads=cfg['H'], mlp_ratio=cfg['R'])
    net = ViT(**vit_kwargs).to(device)
    opt = torch.optim.Adam(net.parameters(), lr=LR)
    summary(net, input_size=(3, 32, 32), batch_size=BATCH, device=str(device))

    # Train for a fixed number of epochs, timing each one.
    epoch_times = []
    for ep in range(1, EPOCHS_VIT + 1):
        _, sec = loop(net, train_loader, opt)
        epoch_times.append(sec)
        print(f"  epoch {ep}/{EPOCHS_VIT} ─ {sec:.2f}s")

    # Single evaluation on the test split after the last epoch.
    acc, _ = loop(net, test_loader)

    n_all = sum(p.numel() for p in net.parameters())
    n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    params = n_all / 1e6
    # NOTE(review): this proxy is trainable-parameter count x 2 x 32 x 32, so it
    # does not vary with patch size — confirm it is the intended FLOPs estimate.
    flops_ap = n_trainable * 2 * 32 * 32 / 1e9

    row = tuple(cfg[key] for key in ('tag', 'p', 'd', 'L', 'H', 'R'))
    log.append(row + (params, flops_ap, np.mean(epoch_times), acc))

# ───────────────────────────── ResNet‑18 baseline ────────────────────────────
print("\n🟢 Training ResNet‑18 baseline")
res = torchvision.models.resnet18(num_classes=NUM_CLASSES)
res = res.to(device)
summary(res, input_size=(3, 32, 32), batch_size=BATCH, device=str(device))
opt = torch.optim.Adam(res.parameters(), lr=LR)

# Same train-then-evaluate protocol as the ViT runs, with its own epoch budget.
epoch_times = []
for ep in range(1, EPOCHS_RESNET + 1):
    _, sec = loop(res, train_loader, opt)
    epoch_times.append(sec)
    print(f"  epoch {ep}/{EPOCHS_RESNET} ─ {sec:.2f}s")
acc, _ = loop(res, test_loader)

n_all = sum(p.numel() for p in res.parameters())
n_trainable = sum(p.numel() for p in res.parameters() if p.requires_grad)
params = n_all / 1e6
# NOTE(review): same parameter-count-based proxy as the ViT rows — confirm it is
# a meaningful FLOPs estimate for a convolutional network.
flops = n_trainable * 2 * 32 * 32 / 1e9

# ViT-only columns (patch, embed, heads, MLP ratio) are filled with 'N/A'.
log.append(('ResNet‑18', 'N/A', 'N/A', 18, 'N/A', 'N/A',
            params, flops, np.mean(epoch_times), acc))

# ──────────────────────────────── Final table ────────────────────────────────
# Column widths: six descriptive columns followed by four numeric ones.
widths = (15, 8, 8, 8, 8, 8, 15, 15, 15, 10)
header_fmt = "".join("{:<%d}" % w for w in widths)
row_fmt = ("".join("{:<%d}" % w for w in widths[:6])
           + "".join("{:<%d.2f}" % w for w in widths[6:]))
rule = "=" * 118

print("\n" + rule)
hdr = ("Model", "Patch", "Embed", "Depth", "Heads", "MLP",
       "Params(M)", "FLOPs(G)", "Time/Epoch(s)", "Accuracy")
print(header_fmt.format(*hdr))
print("-" * 118)
# One line per logged experiment, numeric columns shown with two decimals.
for row in log:
    print(row_fmt.format(*row))
print(rule)



🟢 Training ViT‑Tiny
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [64, 256, 8, 8]          12,544
          Patchify-2              [64, 65, 256]               0
         LayerNorm-3              [64, 65, 256]             512
            Linear-4              [64, 65, 768]         197,376
            Linear-5              [64, 65, 256]          65,792
              MHSA-6              [64, 65, 256]               0
         LayerNorm-7              [64, 65, 256]             512
            Linear-8              [64, 65, 512]         131,584
              GELU-9              [64, 65, 512]               0
          Dropout-10              [64, 65, 512]               0
           Linear-11              [64, 65, 256]         131,328
          Dropout-12              [64, 65, 256]               0
      FeedForward-13              [64, 65, 256]               0
     EncoderBlock-

In [8]:
# ================================================================
# Fine‑tune or train Swin Transformers on CIFAR‑100 (PyTorch)
# ================================================================
# ➜ Copy everything into one Colab cell and hit “Run”.
# ------------------------------------------------
# Imports
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from tqdm.auto import tqdm

from transformers import (
    AutoImageProcessor,
    SwinForImageClassification,
    SwinConfig,
)

# ------------------------------------------------
# Config section
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

IMG_SIZE = 224       # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
EPOCHS = 5
LR = 2e-5
NUM_CLASSES = 100

# Experiment alias -> Hugging Face checkpoint id; experiments run in this order.
CKPTS = dict(
    swin_tiny="microsoft/swin-tiny-patch4-window7-224",
    swin_small="microsoft/swin-small-patch4-window7-224",
    scratch=None,    # no checkpoint: the model is built from a config instead
)

# ------------------------------------------------
# 1. Data
# The swin_tiny processor supplies the normalisation statistics for all runs.
processor = AutoImageProcessor.from_pretrained(CKPTS["swin_tiny"])

# Resize to IMG_SIZE x IMG_SIZE, convert to a tensor, then normalise.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(processor.image_mean, processor.image_std),
])

data_root = Path("./data")
# Train split first, test split second; both request a download.
train_set, test_set = (
    torchvision.datasets.CIFAR100(
        root=data_root, train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Training batches are shuffled; test batches keep dataset order.
train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, BATCH_SIZE, shuffle=False, num_workers=2)

# ------------------------------------------------
# 2. Utilities
def freeze_backbone(model: SwinForImageClassification):
    """Turn off gradients for every parameter under ``model.swin``.

    Parameters outside ``model.swin`` are left untouched.
    """
    for weight in model.swin.parameters():
        weight.requires_grad_(False)

@torch.no_grad()
def accuracy(model, loader):
    """Return top-1 accuracy (in percent) of ``model`` over ``loader``.

    The model is switched to eval mode and left there. Its forward output
    must expose a ``.logits`` attribute holding per-class scores.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        predicted = torch.argmax(model(inputs).logits, dim=1)
        n_correct += predicted.eq(targets).sum().item()
        n_seen += targets.numel()
    return 100.0 * n_correct / n_seen

@dataclass
class Metrics:
    """Summary of one experiment, stored per alias in ``results``."""
    # Test-set accuracy in percent, as returned by accuracy().
    acc: float
    # Mean seconds per training epoch.
    epoch_time: float

# ------------------------------------------------
# 3. Train / evaluate loop
# alias -> final metrics, filled in as each experiment finishes.
results: Dict[str, Metrics] = {}

criterion = nn.CrossEntropyLoss()

for alias, ckpt in CKPTS.items():
    print(f"\n⏩  Experiment: {alias}")

    # ---- build model
    if ckpt is not None:
        # Start from the published weights. The checkpoint's classifier has a
        # different number of labels, so it is re-created for NUM_CLASSES.
        model = SwinForImageClassification.from_pretrained(
            ckpt,
            num_labels=NUM_CLASSES,
            ignore_mismatched_sizes=True,
        )
        freeze_backbone(model)                 # only the head learns
    else:
        # No checkpoint: randomly initialised model built from an explicit config.
        config = SwinConfig(
            image_size=IMG_SIZE,
            patch_size=4,
            num_channels=3,
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=7,
            num_labels=NUM_CLASSES,
        )
        model = SwinForImageClassification(config)

    model.to(DEVICE)

    # Only parameters that still require gradients are given to the optimizer.
    trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = torch.optim.Adam(trainable, lr=LR)

    # ---- training
    epoch_times = []
    for epoch in range(1, EPOCHS + 1):
        # NOTE(review): train() also puts a frozen backbone into training mode —
        # confirm that is intended for the head-only runs.
        model.train()
        start = time.perf_counter()
        progress = tqdm(train_loader, desc=f"[{alias}] epoch {epoch}/{EPOCHS}", leave=False)
        for images, labels in progress:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            optimizer.zero_grad()
            logits = model(images).logits
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            progress.set_postfix(loss=f"{loss.item():4.2f}")
        elapsed = time.perf_counter() - start
        epoch_times.append(elapsed)
        print(f"    Epoch {epoch:>2}: {elapsed:6.1f}s")

    # ---- evaluation
    test_acc = accuracy(model, test_loader)
    mean_time = sum(epoch_times) / EPOCHS
    results[alias] = Metrics(acc=test_acc, epoch_time=mean_time)
    print(f"    ✅  Test accuracy: {test_acc:5.2f}% | avg epoch time: {mean_time:6.1f}s")

# ------------------------------------------------
# 4. Pretty‑print final scorecard
print("\n==================  Summary  ==================")
print("{:<12}{:>14}{:>20}".format("Model", "Accuracy (%)", "Avg Epoch Time (s)"))
print("-" * 46)
# One line per experiment, in the order the experiments were run.
for name in results:
    entry = results[name]
    print("{:<12}{:>14.2f}{:>20.2f}".format(name, entry.acc, entry.epoch_time))



⏩  Experiment: swin_tiny


Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[swin_tiny] epoch 1/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  1:   69.8s


[swin_tiny] epoch 2/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  2:   69.3s


[swin_tiny] epoch 3/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  3:   69.1s


[swin_tiny] epoch 4/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  4:   69.2s


[swin_tiny] epoch 5/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  5:   69.3s
    ✅  Test accuracy: 66.25% | avg epoch time:   69.3s

⏩  Experiment: swin_small


config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/199M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/199M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-small-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[swin_small] epoch 1/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  1:  105.3s


[swin_small] epoch 2/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  2:  115.4s


[swin_small] epoch 3/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  3:  105.2s


[swin_small] epoch 4/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  4:  105.3s


[swin_small] epoch 5/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  5:  105.3s
    ✅  Test accuracy: 70.33% | avg epoch time:  107.3s

⏩  Experiment: scratch


[scratch] epoch 1/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  1:  178.4s


[scratch] epoch 2/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  2:  178.3s


[scratch] epoch 3/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  3:  178.2s


[scratch] epoch 4/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  4:  178.0s


[scratch] epoch 5/5:   0%|          | 0/1563 [00:00<?, ?it/s]

    Epoch  5:  178.2s
    ✅  Test accuracy: 36.04% | avg epoch time:  178.2s

Model         Accuracy (%)  Avg Epoch Time (s)
----------------------------------------------
swin_tiny            66.25               69.34
swin_small           70.33              107.31
scratch              36.04              178.21
