
# Torchvision Tutorial: Dataset → Pretrained Model → Training

This notebook walks through a minimal, **practical** pipeline using `torch` and `torchvision`:
1) Prepare a dataset (CIFAR-10 or a `FakeData` fallback)  
2) Load a pretrained `ResNet-18` from `torchvision.models`  
3) Fine-tune (or feature-extract) on our dataset and evaluate

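> Tip: If you're offline in class or the server blocks downloads, the dataset cell automatically switches to `torchvision.datasets.FakeData` (synthetic images), so the data pipeline still runs. The pretrained ResNet-18 weights are also downloaded on first use, so run the model cell once while online (or pre-populate the torch hub cache) before going offline.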


In [None]:

import sys, torch, torchvision
# Record the software stack so results can be traced back to exact versions.
for label, version in (
    ("Python:", sys.version),
    ("PyTorch:", torch.__version__),
    ("Torchvision:", torchvision.__version__),
):
    print(label, version)

# Use the GPU whenever PyTorch can see one; later cells move model and data to `device`.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
device = torch.device("cuda" if cuda_ok else "cpu")
device


## 0. Config

In [None]:

from dataclasses import dataclass

@dataclass
class CFG:
    """Hyper-parameters and run settings for this tutorial.

    `num_classes` is overwritten once the dataset has been loaded, so the
    default below only matters before that cell runs.
    """

    # --- data ---
    data_root: str = "./data"
    batch_size: int = 64
    num_workers: int = 0
    num_classes: int = 10           # CIFAR-10
    # --- optimisation ---
    epochs: int = 5                 # keep small for demo; increase for real runs
    lr: float = 3e-4
    weight_decay: float = 1e-4
    # --- model ---
    feature_extract: bool = False   # True = freeze backbone and train only classifier
    # --- reproducibility ---
    seed: int = 42


cfg = CFG()
print(cfg)


## 1. Utilities

In [None]:

import random, numpy as np, torch
def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch with `seed`.

    When CUDA is available, every visible GPU is seeded explicitly as well.
    Returns None.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Seed every RNG once, before any data or model is created.
set_seed(seed=cfg.seed)

try:
    from tqdm.auto import tqdm
except ImportError:
    # tqdm is optional: only a missing package should trigger the fallback
    # (other errors still surface). The stand-in accepts the same call shapes
    # as tqdm (extra positional / keyword args) and simply returns the iterable.
    def tqdm(iterable, *args, **kwargs):
        """No-op replacement for tqdm: return `iterable` unchanged."""
        return iterable


## 2. Dataset & Dataloaders

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch
import numpy as np
from collections import defaultdict

# Channel statistics of the ImageNet data the torchvision ResNet-18 weights were
# trained on (the same values `ResNet18_Weights.DEFAULT.transforms()` uses).
# The pretrained backbone expects inputs normalised this way; feeding raw [0, 1]
# pixels gives it input statistics it never saw during pre-training.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training: resize to the backbone's 224x224 input, light augmentation, normalise.
basic_train_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation: same geometry and normalisation, no random augmentation.
basic_val_tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

def get_cifar10_or_fakedata(root: str, train: bool, tfms):
    """Load a CIFAR-10 split, falling back to synthetic `FakeData` if that fails.

    Args:
        root: directory where CIFAR-10 is stored / downloaded to.
        train: True for the training split, False for the held-out split.
        tfms: transform applied to every image.

    Returns:
        (dataset, num_classes); num_classes is 10 in both branches.
    """
    try:
        ds = datasets.CIFAR10(root=root, train=train, download=True, transform=tfms)
        print("Loaded CIFAR-10 ✅")
        return ds, 10
    except Exception as e:
        # Deliberate best-effort fallback (e.g. offline / blocked download)
        # so the rest of the notebook still runs.
        print("CIFAR-10 unavailable, switching to FakeData. Reason:", e)
        size = (3, 224, 224)
        classes = 10
        train_size, val_size = 1000, 200
        # FakeData derives each sample from its index plus `random_offset`.
        # With the default offset (0) for both splits, the validation set would
        # be an exact copy of the first `val_size` training samples, so give
        # the held-out split a disjoint offset.
        offset = 0 if train else train_size
        ds = datasets.FakeData(size=train_size if train else val_size,
                               image_size=size, num_classes=classes,
                               transform=tfms, random_offset=offset)
        return ds, classes

# Load both splits with their own transforms (train first, then val).
train_ds, nclasses_train = get_cifar10_or_fakedata(
    cfg.data_root, train=True, tfms=basic_train_tfms
)
val_ds, nclasses_val = get_cifar10_or_fakedata(
    cfg.data_root, train=False, tfms=basic_val_tfms
)

# The classifier head is sized from this value, so both splits must agree.
assert nclasses_val == nclasses_train
cfg.num_classes = nclasses_train


# --------- Key: subset selection function ---------
def make_subset_indices(dataset, n, stratify=True, seed=42):
    """
    Return indices for a reduced subset of the dataset.

    - If `stratify=True`, performs class-balanced sampling (requires `dataset.targets` or `dataset.labels`).
    - If labels are not available (e.g., FakeData), falls back to random sampling.
    - If `n <= 0`, returns every index, i.e. `np.arange(len(dataset))`.
    - `n` larger than the dataset is clipped to `len(dataset)`.

    In stratified mode each class contributes up to `n // num_classes` samples.
    Any slots still open afterwards are filled with random unused indices.
    Open slots come either from the integer division or from classes that have
    fewer samples than their share. So for `n > 0` exactly
    `min(n, len(dataset))` unique indices are returned.

    Args:
        dataset: sized object; may expose `targets` or `labels` with one
            label per sample.
        n: requested subset size (0 or negative means "use everything").
        stratify: balance the subset across classes when labels are available.
        seed: seed for `np.random.default_rng`, making the subset reproducible.

    Returns:
        1-D NumPy integer array of unique indices into `dataset`.
    """
    rng = np.random.default_rng(seed)
    N = len(dataset)
    n = int(min(n, N))
    if n <= 0:
        return np.arange(N)

    # Try to access labels (the first matching attribute wins).
    targets = None
    for attr in ["targets", "labels"]:
        if hasattr(dataset, attr):
            t = getattr(dataset, attr)
            try:
                targets = np.asarray(t)
            except Exception:
                targets = None
            break
    # Only a flat array with one label per sample can drive stratification.
    if targets is not None and (targets.ndim != 1 or targets.shape[0] != N):
        targets = None

    if not stratify or targets is None:
        # Random sampling only
        return rng.choice(N, size=n, replace=False)

    # Stratified sampling: balance across classes.
    by_cls = defaultdict(list)
    for i, y in enumerate(targets):
        by_cls[int(y)].append(i)
    per_cls = n // len(by_cls)  # by_cls is non-empty because N >= n > 0

    picks = []
    for idxs in by_cls.values():
        idxs = np.asarray(idxs)
        take = min(per_cls, len(idxs))
        picks.append(rng.choice(idxs, size=take, replace=False))
    chosen = np.concatenate(picks)

    # Count open slots from what was actually drawn, so classes smaller than
    # `per_cls` do not silently shrink the subset; fill them at random.
    leftover = n - len(chosen)
    if leftover > 0:
        mask = np.ones(N, dtype=bool)
        mask[chosen] = False
        pool = np.flatnonzero(mask)
        if len(pool) > 0:
            extra = rng.choice(pool, size=min(leftover, len(pool)), replace=False)
            chosen = np.concatenate([chosen, extra])

    return chosen


# --------- Configure subset size (0 = use full dataset) ---------
num_train_samples = 128   # Smaller subset for faster demo
num_val_samples   = 256   # Validation can be larger; set 0 to use all data

train_indices = make_subset_indices(train_ds, num_train_samples, stratify=True, seed=cfg.seed)
val_indices   = make_subset_indices(val_ds,   num_val_samples,   stratify=True, seed=cfg.seed)

train_subset = Subset(train_ds, train_indices)
val_subset   = Subset(val_ds,   val_indices)

# Pinned (page-locked) host memory only speeds up host->GPU copies. Requesting
# it on a CPU-only run gains nothing and typically makes DataLoader warn, so
# enable it only when training on CUDA.
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_subset, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=pin_memory)
val_loader   = DataLoader(val_subset,   batch_size=cfg.batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=pin_memory)

len(train_subset), len(val_subset)

### Peek at a batch

In [None]:

import math
import matplotlib.pyplot as plt

def _as_displayable(img_chw):
    """Turn one (C, H, W) image tensor into an (H, W, C) array for `imshow`.

    Values already inside [0, 1] are shown unchanged. If the tensor falls
    outside that range (e.g. the transforms normalise with a mean/std), it is
    min-max rescaled into [0, 1] so `imshow` does not clip it.
    """
    img = img_chw.detach().cpu().permute(1, 2, 0).numpy()
    lo, hi = float(img.min()), float(img.max())
    if lo < 0.0 or hi > 1.0:
        img = (img - lo) / max(hi - lo, 1e-8)
    return img


images, labels = next(iter(train_loader))
grid_cols = 8
grid_rows = min(2, math.ceil(images.size(0) / grid_cols))
n_show = min(grid_rows * grid_cols, images.size(0))

fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(12, 4), squeeze=False)
for i, ax in enumerate(axes.flat):
    ax.axis("off")   # hide ticks; also blanks any unused grid cells
    if i < n_show:
        ax.imshow(_as_displayable(images[i]))
        ax.set_title(str(int(labels[i])))
plt.show()


## 3. Load a Pretrained Model (ResNet-18)

In [None]:

import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

weights = ResNet18_Weights.DEFAULT
# Reference preprocessing bundled with these weights. It is not applied by the
# DataLoaders in section 2, which build their own transforms.
preprocess = weights.transforms()

# Load the pretrained backbone, then replace its classification head so the
# number of outputs matches our dataset.
model = resnet18(weights=weights)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, cfg.num_classes)

if cfg.feature_extract:
    # Feature extraction: freeze every parameter that is not part of the new head.
    head_ids = {id(p) for p in model.fc.parameters()}
    for param in model.parameters():
        if id(param) not in head_ids:
            param.requires_grad = False

model = model.to(device)
model


## 4. Training Setup

In [None]:

import torch.optim as optim

criterion = nn.CrossEntropyLoss()

# Hand the optimizer only parameters that will receive gradients
# (just the classifier head when cfg.feature_extract is True).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(
    trainable_params,
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)


## 5. Train & Evaluate

In [None]:

_GRAD_SCALERS = {}  # enabled flag -> GradScaler reused across epochs


def _default_grad_scaler(enabled):
    """Return the shared GradScaler for `enabled`, creating it on first use.

    Reusing one scaler lets its loss scale keep adapting across epochs instead
    of restarting from the initial value on every call. Uses the
    `torch.amp.GradScaler("cuda", ...)` API when available and falls back to
    `torch.cuda.amp.GradScaler` on older PyTorch releases.
    """
    if enabled not in _GRAD_SCALERS:
        if hasattr(torch.amp, "GradScaler"):
            _GRAD_SCALERS[enabled] = torch.amp.GradScaler("cuda", enabled=enabled)
        else:
            _GRAD_SCALERS[enabled] = torch.cuda.amp.GradScaler(enabled=enabled)
    return _GRAD_SCALERS[enabled]


def train_one_epoch(model, loader, optimizer, device, scaler=None):
    """Run one training epoch; return (mean loss, accuracy) over all samples seen.

    Mixed precision (float16 autocast) is enabled only when `device` is a CUDA
    device. In that case gradients go through a GradScaler so that small
    float16 gradients do not underflow to zero. Steps whose scaled gradients
    overflow are skipped by the scaler.

    Args:
        model: network to train (switched to train mode).
        loader: iterable of (images, labels) batches.
        optimizer: optimizer over the trainable parameters.
        device: device (or device string) to run on.
        scaler: optional GradScaler; when None, a scaler shared across calls is used.

    Uses the module-level `criterion` and `tqdm`.
    """
    device = torch.device(device)
    use_amp = device.type == "cuda"
    # Autocast is keyed on the training device, not on whether CUDA merely exists.
    # A disabled "cpu" context is a no-op and is valid on every PyTorch version.
    amp_device_type = "cuda" if use_amp else "cpu"
    if scaler is None:
        scaler = _default_grad_scaler(use_amp)

    model.train()
    total_loss, total_correct, total = 0.0, 0, 0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=amp_device_type, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, labels)
        # With enabled=False these calls reduce to loss.backward() / optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * images.size(0)
        total_correct += (logits.argmax(1) == labels).sum().item()
        total += images.size(0)
    return total_loss / total, total_correct / total

@torch.no_grad()
def evaluate(model, loader, device):
    """Return (mean loss, accuracy) of `model` over every sample in `loader`.

    Runs in eval mode with autograd disabled; uses the module-level `criterion`.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        out = model(batch_x)
        batch_n = batch_x.size(0)
        loss_sum += criterion(out, batch_y).item() * batch_n
        n_correct += (out.argmax(1) == batch_y).sum().item()
        n_seen += batch_n
    return loss_sum / n_seen, n_correct / n_seen

# One (epoch, train_loss, train_acc, val_loss, val_acc) tuple per epoch.
history = []
for epoch in range(1, cfg.epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, device)
    record = (epoch, train_loss, train_acc, val_loss, val_acc)
    history.append(record)
    print(f"Epoch {epoch:02d}/{cfg.epochs} | "
          f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
          f"val loss {val_loss:.4f} acc {val_acc:.4f}")
history


### Plot training curves

In [None]:

import matplotlib.pyplot as plt
if not history:
    print("No history yet. Run training first.")
else:
    epochs, tr_l, tr_a, va_l, va_a = zip(*history)
    # One figure per metric: (y-label, title, train series, val series, legend labels).
    for ylabel, title, train_vals, val_vals, train_lbl, val_lbl in (
        ("loss", "Loss", tr_l, va_l, "train loss", "val loss"),
        ("accuracy", "Accuracy", tr_a, va_a, "train acc", "val acc"),
    ):
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(epochs, train_vals, label=train_lbl)
        ax.plot(epochs, val_vals, label=val_lbl)
        ax.set_xlabel("epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.set_title(title)
        plt.show()


## 6. Save & Load the Model

In [None]:

save_path = "resnet18_finetuned.pth"
import torch

# Bundle the trained weights with the config that produced them.
checkpoint = {
    "state_dict": model.state_dict(),
    "cfg": vars(cfg),
}
torch.save(checkpoint, save_path)
print("Saved to:", save_path)

# Example of how to load later. Rebuild the same architecture first, e.g. re-run
# section 3 with the same cfg.num_classes:
# state = torch.load(save_path, map_location="cpu", weights_only=True)
# model.load_state_dict(state["state_dict"])
# Keep the default strict=True so missing/unexpected keys raise an error instead
# of being silently ignored. weights_only=True (PyTorch >= 1.13) refuses to
# unpickle arbitrary objects, which matters for checkpoints from untrusted sources.



## 7. Extensions (For Students)
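- **Switch to feature extraction:** set `cfg.feature_extract = True` on the `cfg` instance. Changing the `CFG` class default after `cfg` was created has no effect. Then re-run the model cell and the training-setup cell, so the optimizer only receives the classifier parameters.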
- **Try a different backbone:** `torchvision.models.efficientnet_b0`, `mobilenet_v3_small`, etc.
- **Augmentations:** add `transforms.AutoAugment` or `transforms.RandAugment` to `basic_train_tfms` and compare the results.
- **Schedulers:** try `OneCycleLR` or `StepLR` and compare.
- **Per-class accuracy:** compute a confusion matrix with `sklearn.metrics.confusion_matrix`.
- **Early stopping** and **model checkpointing** on best validation accuracy.
