In [None]:
"""
CIFAR‑100 ─ Vision Transformer scalability study + ResNet‑18 baseline
refactored version (logic preserved, style changed)
"""

# ───────────────────────────── Imports & Globals ──────────────────────────────
import time, random, numpy as np
import torch, torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchsummary import summary           # <- pip install torchsummary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SEED          = 42
BATCH         = 64
EPOCHS_VIT    = 20
EPOCHS_RESNET = 10
LR            = 0.001
NUM_CLASSES   = 100

torch.manual_seed(SEED);  np.random.seed(SEED);  random.seed(SEED)

# ────────────────────────────── Data pipeline ─────────────────────────────────
C100_MEAN = (0.5071, 0.4867, 0.4408)
C100_STD  = (0.2675, 0.2565, 0.2761)

transf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(C100_MEAN, C100_STD)
])

train_set = torchvision.datasets.CIFAR100('./data', train=True, download=True,  transform=transf)
test_set  = torchvision.datasets.CIFAR100('./data', train=False,               transform=transf)
train_loader = DataLoader(train_set, batch_size=BATCH, shuffle=True,  num_workers=2)
test_loader  = DataLoader(test_set,  batch_size=BATCH, shuffle=False, num_workers=2)

# ───────────────────────────── ViT building blocks ───────────────────────────
class Patchify(nn.Module):
    def __init__(self, img=32, patch=4, ch=3, dim=256):
        super().__init__()
        self.n = (img // patch) ** 2
        self.to_patch = nn.Conv2d(ch, dim, patch, patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos = nn.Parameter(torch.zeros(1, self.n + 1, dim))
        nn.init.trunc_normal_(self.pos, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.size(0)
        x = self.to_patch(x).flatten(2).transpose(1, 2)            # B N D
        cls = self.cls_token.expand(B, -1, -1)
        return torch.cat([cls, x], 1) + self.pos


class MHSA(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        assert dim % heads == 0
        self.h = heads
        self.dk = dim // heads
        self.proj_qkv = nn.Linear(dim, dim * 3)
        self.out = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.proj_qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2,0,3,1,4)
        q, k, v = qkv
        att = (q @ k.transpose(-1,-2)) * (self.dk ** -0.5)
        x = (att.softmax(-1) @ v).transpose(1,2).reshape(B, N, D)
        return self.out(x)


class FeedForward(nn.Module):
    def __init__(self, dim, ratio=4, p=0.):
        super().__init__()
        hid = dim * ratio
        self.net = nn.Sequential(
            nn.Linear(dim, hid), nn.GELU(), nn.Dropout(p),
            nn.Linear(hid, dim), nn.Dropout(p)
        )
    def forward(self, x): return self.net(x)


class EncoderBlock(nn.Module):
    def __init__(self, dim, heads, mlp_ratio):
        super().__init__()
        self.norm1, self.att, self.norm2, self.ffn = (
            nn.LayerNorm(dim), MHSA(dim, heads),
            nn.LayerNorm(dim), FeedForward(dim, mlp_ratio)
        )
    def forward(self, x):
        x = x + self.att(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class ViT(nn.Module):
    def __init__(self, img=32, patch=4, dim=256, depth=4, heads=4,
                 mlp_ratio=4, classes=100):
        super().__init__()
        self.patch = Patchify(img, patch, 3, dim)
        self.body  = nn.Sequential(*[EncoderBlock(dim, heads, mlp_ratio)
                                     for _ in range(depth)])
        self.norm  = nn.LayerNorm(dim)
        self.head  = nn.Linear(dim, classes)
        self.apply(self._init)

    @staticmethod
    def _init(m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.patch(x)
        x = self.body(x)
        return self.head(self.norm(x[:,0]))

# ────────────────────────────── Train / Evaluate ─────────────────────────────
def loop(model, loader, opt=None):
    train = opt is not None
    model.train() if train else model.eval()
    crit = nn.CrossEntropyLoss()
    hits = tots = 0; t0 = time.time()
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        out = model(xb)
        loss = crit(out, yb)
        if train:
            opt.zero_grad(); loss.backward(); opt.step()
        hits += (out.argmax(1) == yb).sum().item()
        tots += yb.size(0)
    return hits/tots*100, time.time()-t0

# ──────────────────────────── Experiment catalogue ───────────────────────────
variants = [
    dict(tag='ViT‑Tiny',   p=4, d=256, L=4,  H=2, R=2),
    dict(tag='ViT‑Small',  p=8, d=256, L=8,  H=2, R=2),
    dict(tag='ViT‑Medium', p=4, d=512, L=4,  H=4, R=4),
    dict(tag='ViT‑Large',  p=8, d=512, L=8,  H=4, R=4),
]

log = []
for cfg in variants:
    print(f"\n🟢 Training {cfg['tag']}")
    net = ViT(patch=cfg['p'], dim=cfg['d'], depth=cfg['L'],
              heads=cfg['H'], mlp_ratio=cfg['R']).to(device)
    opt = torch.optim.Adam(net.parameters(), lr=LR)
    summary(net, input_size=(3,32,32), batch_size=BATCH, device=str(device))
    epoch_times=[]
    for ep in range(1, EPOCHS_VIT+1):
        _, sec = loop(net, train_loader, opt)
        epoch_times.append(sec)
        print(f"  epoch {ep}/{EPOCHS_VIT} ─ {sec:.2f}s")
    acc,_    = loop(net, test_loader)
    params   = sum(p.numel() for p in net.parameters())/1e6
    flops_ap = sum(p.numel() for p in net.parameters() if p.requires_grad)*2*32*32/1e9
    log.append((cfg['tag'], cfg['p'], cfg['d'], cfg['L'], cfg['H'], cfg['R'],
                params, flops_ap, np.mean(epoch_times), acc))

# ───────────────────────────── ResNet‑18 baseline ────────────────────────────
print("\n🟢 Training ResNet‑18 baseline")
res = torchvision.models.resnet18(num_classes=NUM_CLASSES).to(device)
summary(res, input_size=(3,32,32), batch_size=BATCH, device=str(device))
opt = torch.optim.Adam(res.parameters(), lr=LR)
epoch_times=[]
for ep in range(1, EPOCHS_RESNET+1):
    _, sec = loop(res, train_loader, opt)
    epoch_times.append(sec)
    print(f"  epoch {ep}/{EPOCHS_RESNET} ─ {sec:.2f}s")
acc,_   = loop(res, test_loader)
params  = sum(p.numel() for p in res.parameters())/1e6
flops   = sum(p.numel() for p in res.parameters() if p.requires_grad)*2*32*32/1e9
log.append(('ResNet‑18','N/A','N/A',18,'N/A','N/A',params,flops,np.mean(epoch_times),acc))

# ──────────────────────────────── Final table ────────────────────────────────
print("\n" + "="*118)
hdr = ("Model","Patch","Embed","Depth","Heads","MLP",
       "Params(M)","FLOPs(G)","Time/Epoch(s)","Accuracy")
print("{:<15}{:<8}{:<8}{:<8}{:<8}{:<8}{:<15}{:<15}{:<15}{:<10}".format(*hdr))
print("-"*118)
for r in log:
    print("{:<15}{:<8}{:<8}{:<8}{:<8}{:<8}{:<15.2f}{:<15.2f}{:<15.2f}{:<10.2f}".format(*r))
print("="*118)



🟢 Training ViT‑Tiny
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [64, 256, 8, 8]          12,544
          Patchify-2              [64, 65, 256]               0
         LayerNorm-3              [64, 65, 256]             512
            Linear-4              [64, 65, 768]         197,376
            Linear-5              [64, 65, 256]          65,792
              MHSA-6              [64, 65, 256]               0
         LayerNorm-7              [64, 65, 256]             512
            Linear-8              [64, 65, 512]         131,584
              GELU-9              [64, 65, 512]               0
          Dropout-10              [64, 65, 512]               0
           Linear-11              [64, 65, 256]         131,328
          Dropout-12              [64, 65, 256]               0
      FeedForward-13              [64, 65, 256]               0
     EncoderBlock-

In [1]:
# ==============================================================
# CIFAR‑100 • Swin‑Tiny / Swin‑Small fine‑tune + scratch baseline
# ==============================================================

import time
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm
from transformers import (
    AutoImageProcessor,
    SwinForImageClassification,
    SwinConfig,
)

# ------------------- 0. hyper‑parameters -----------------------
IMG_RES      = 224
BATCH        = 32
EPOCHS       = 5
LR           = 2e-5
NUM_CLASSES  = 100
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CHECKPOINTS: Dict[str, str | None] = {
    "tiny":   "microsoft/swin-tiny-patch4-window7-224",
    "small":  "microsoft/swin-small-patch4-window7-224",
    "scratch": None,       # this one will be built from config
}

# ------------------- 1. dataset --------------------------------
proc   = AutoImageProcessor.from_pretrained(CHECKPOINTS["tiny"])
aug    = transforms.Compose(
    [transforms.Resize((IMG_RES, IMG_RES)),
     transforms.ToTensor(),
     transforms.Normalize(mean=proc.image_mean, std=proc.image_std)]
)

root   = Path("./data")
train_ds = datasets.CIFAR100(root, train=True,  transform=aug, download=True)
test_ds  = datasets.CIFAR100(root, train=False, transform=aug, download=True)

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=2)
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=2)

# ------------------- 2. helpers --------------------------------
def detach_backbone(model: SwinForImageClassification) -> None:
    for p in model.swin.parameters():
        p.requires_grad = False

@torch.no_grad()
def eval_top1(model: nn.Module) -> float:
    model.eval()
    correct = total = 0
    for x, y in test_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        pred = model(x).logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total   += y.numel()
    return 100.0 * correct / total

@dataclass
class Score:
    acc: float
    epoch_time: float

# ------------------- 3. training loop --------------------------
criterion = nn.CrossEntropyLoss()
scores: Dict[str, Score] = {}

for tag, ckpt in CHECKPOINTS.items():
    print(f"\n▶️  Running: {tag}")
    # model build
    if ckpt is None:                                          # scratch
        tiny_cfg = SwinConfig(
            image_size = IMG_RES,
            patch_size = 4,
            num_channels = 3,
            embed_dim   = 96,
            depths      = [2, 2, 6, 2],
            num_heads   = [3, 6, 12, 24],
            window_size = 7,
            num_labels  = NUM_CLASSES,
        )
        net = SwinForImageClassification(tiny_cfg)
        train_params = net.parameters()
    else:                                                     # pretrained
        net = SwinForImageClassification.from_pretrained(
            ckpt,
            num_labels = NUM_CLASSES,
            ignore_mismatched_sizes = True,
        )
        detach_backbone(net)
        train_params = net.classifier.parameters()

    net.to(DEVICE)
    optim = torch.optim.Adam(train_params, lr=LR)

    # epochs
    times: list[float] = []
    for ep in range(1, EPOCHS + 1):
        net.train()
        tic = time.perf_counter()
        for img, lbl in tqdm(train_loader, leave=False,
                             desc=f"{tag} | epoch {ep}/{EPOCHS}"):
            img, lbl = img.to(DEVICE), lbl.to(DEVICE)
            loss = criterion(net(img).logits, lbl)
            optim.zero_grad()
            loss.backward()
            optim.step()
        toc = time.perf_counter() - tic
        times.append(toc)
        print(f"  • epoch {ep:2d}: {toc:6.2f}s")

    # test + record
    top1 = eval_top1(net)
    mean_t = sum(times) / EPOCHS
    scores[tag] = Score(top1, mean_t)
    print(f"  ✔ acc={top1:5.2f}% | avg epoch time={mean_t:6.2f}s")

# ------------------- 4. summary --------------------------------
print("\n================  RESULTS  ================")
print(f"{'Model':<10}{'Accuracy':>12}{'Avg Epoch (s)':>16}")
print("-------------------------------------------")
for k, v in scores.items():
    print(f"{k:<10}{v.acc:>12.2f}{v.epoch_time:>16.2f}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
100%|██████████| 169M/169M [00:13<00:00, 12.6MB/s]



▶️  Running: tiny


config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tiny | epoch 1/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  1:  60.10s


tiny | epoch 2/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  2:  58.75s


tiny | epoch 3/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  3:  58.79s


tiny | epoch 4/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  4:  58.82s


tiny | epoch 5/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  5:  58.74s
  ✔ acc=66.52% | avg epoch time= 59.04s

▶️  Running: small


config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/199M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/199M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-small-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([100, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([100]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


small | epoch 1/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  1:  96.12s


small | epoch 2/5:   0%|          | 0/1563 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b17f7fb6c00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b17f7fb6c00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 16

  • epoch  2:  96.28s


small | epoch 3/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  3:  96.12s


small | epoch 4/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  4:  96.09s


small | epoch 5/5:   0%|          | 0/1563 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b17f7fb6c00>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^Exception ignored in: ^^<function _MultiProcessingDataLoaderIter.__del__ at 0x7b17f7fb6c00>^
^Traceback (most recent call last):
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
^^    ^self._shutdown_workers()^
^  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
^^    ^if w.is_alive():^
^ ^ ^ ^ ^ ^

  • epoch  5:  96.38s
  ✔ acc=70.68% | avg epoch time= 96.20s

▶️  Running: scratch


scratch | epoch 1/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  1: 167.04s


scratch | epoch 2/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  2: 166.75s


scratch | epoch 3/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  3: 166.88s


scratch | epoch 4/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  4: 166.92s


scratch | epoch 5/5:   0%|          | 0/1563 [00:00<?, ?it/s]

  • epoch  5: 166.86s
  ✔ acc=37.34% | avg epoch time=166.89s

Model         Accuracy   Avg Epoch (s)
-------------------------------------------
tiny             66.52           59.04
small            70.68           96.20
scratch          37.34          166.89
