In [1]:
"""
ENGR 413 — HW 3 Template (PyTorch + CIFAR-10)
Fall 2025 • Starter code covering Q1–Q13
Run cell-by-cell in a Jupyter Notebook or as a .py script.
You may trim epochs to fit your hardware. Plots and logs are saved to ./outputs.
"""

# ===============
# 0) Imports & Setup
# ===============
import os, math, time, random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt

# Pick the fastest available backend: CUDA first, then Apple's MPS, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# All figures are saved here; create it up front so savefig never fails.
os.makedirs("outputs", exist_ok=True)

# =====================
# 1) Data Preparation (Q1, Q2)
# =====================
# Human-readable CIFAR-10 class names, in label-index order.
CIFAR10_CLASSES = tuple(
    "airplane automobile bird cat deer dog frog horse ship truck".split()
)


@dataclass
class DataCfg:
    """Data-loading settings used by the dataset/dataloader helpers."""

    batch_size: int = 128      # (Q8) you may change in experiments
    num_workers: int = 2       # DataLoader worker processes
    data_root: str = "./data"  # directory where torchvision stores CIFAR-10


CFG = DataCfg()

# Q1.1: Load CIFAR10, compute global mean/std, and normalize

def compute_cifar10_mean_std(root: str = CFG.data_root) -> Tuple[Tuple[float,...], Tuple[float,...]]:
    """Return the per-channel (mean, std) of the CIFAR-10 training images.

    Pixels are scaled to [0, 1] by ``ToTensor``. Per-batch statistics are merged
    with the pairwise (Chan et al.) update, so the result is the population
    mean/std over every pixel of every training image without holding the whole
    split in memory. Downloads the training split into ``root`` if needed.
    """
    dataset = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=T.ToTensor())
    batches = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=CFG.num_workers)

    count = 0                  # images merged so far
    run_mean = torch.zeros(3)
    run_var = torch.zeros(3)   # running *population* variance (not a sum of squares)
    for images, _labels in batches:
        k = images.shape[0]
        flat = images.view(k, 3, -1)  # (batch, channel, pixels)
        chunk_mean = flat.mean(dim=(0, 2))
        chunk_var = flat.var(dim=(0, 2), unbiased=False)
        if count == 0:
            # First batch: its statistics simply become the running ones.
            run_mean, run_var, count = chunk_mean, chunk_var, k
            continue
        total = count + k
        shift = chunk_mean - run_mean
        run_mean = run_mean + shift * (k / total)
        # Weighted within-group variances plus the between-group term count*k*shift^2/total^2.
        run_var = (count * (run_var + (shift**2) * (k / total))) / total + (k * chunk_var) / total
        count = total
    return tuple(run_mean.tolist()), tuple(torch.sqrt(run_var).tolist())


# Build datasets and dataloaders

def make_dataloaders(batch_size: int = CFG.batch_size) -> Tuple[DataLoader, DataLoader, Tuple[float,...], Tuple[float,...]]:
    """Build normalized CIFAR-10 train/test loaders (Q1.1).

    Args:
      batch_size: training batch size (the test loader always uses 256).

    Returns:
      (train_loader, test_loader, mean, std), where mean/std are the per-channel
      training-set statistics passed to ``T.Normalize``. They are needed again
      to un-normalize images for display.
    """
    # Q1.1: statistics from the training split (this call also downloads it,
    # which is why the train dataset below can use download=False).
    mean, std = compute_cifar10_mean_std()
    # Augment only the training data: random crops of a 4-px padded image + horizontal flips.
    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    test_tf = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    train_set = torchvision.datasets.CIFAR10(CFG.data_root, train=True, download=False, transform=train_tf)
    test_set  = torchvision.datasets.CIFAR10(CFG.data_root, train=False, download=True, transform=test_tf)
    # Page-locked (pinned) host memory only speeds up host->CUDA copies. Requesting it
    # on CPU/MPS just triggers warnings or instability, so enable it only on CUDA.
    pin = DEVICE.type == "cuda"
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=CFG.num_workers, pin_memory=pin)
    test_loader  = DataLoader(test_set,  batch_size=256, shuffle=False, num_workers=CFG.num_workers, pin_memory=pin)
    return train_loader, test_loader, mean, std

# Q1.2 + Q1.3 (basic inspection helper)

def describe_dataset(train_loader: DataLoader, test_loader: DataLoader):
    """Print split sizes and the class names (Q1.2/Q1.3 sanity check)."""
    print(f"Train size: {len(train_loader.dataset)}, Test size: {len(test_loader.dataset)}")
    print("Classes:", CIFAR10_CLASSES)

# =====================
# 2) Visualization (Q3)
# =====================

def show_grid(loader: DataLoader, nrow: int = 8, ncol: int = 8, unnormalize: Optional[Tuple[Tuple[float,...], Tuple[float,...]]] = None, savepath: str = "outputs/visual_grid.png"):
    """Save a grid of up to ``nrow * ncol`` images from the first batch of ``loader`` (Q3).

    Args:
      loader: yields ``(images, labels)`` batches; only the images are drawn.
      nrow, ncol: grid shape (fewer images are shown if the batch is smaller).
      unnormalize: optional ``(mean, std)`` that was given to ``T.Normalize``;
        the normalization is undone so colors look natural.
      savepath: PNG file to write.
    """
    images, _ = next(iter(loader))
    images = images[: nrow * ncol]
    if unnormalize is not None:
        mean, std = [torch.tensor(v).view(1, -1, 1, 1) for v in unnormalize]
        images = images * std + mean
    # imshow expects floats in [0, 1]. Un-normalizing lands a hair outside that range
    # (e.g. the zero padding from RandomCrop), which makes imshow warn and clip,
    # so clamp explicitly.
    images = images.clamp(0.0, 1.0)
    grid = torchvision.utils.make_grid(images, nrow=ncol)  # make_grid's nrow = images per row
    npimg = grid.permute(1, 2, 0).detach().cpu().numpy()   # CHW -> HWC for matplotlib
    plt.figure(figsize=(7,7))
    plt.imshow(npimg)
    plt.axis('off')
    plt.title("CIFAR-10 sample grid")
    plt.savefig(savepath, bbox_inches='tight')
    plt.close()
    print(f"Saved visualization to {savepath}")

# ======================================
# 3) Models (Q4–Q7): CNN (+BN toggle) and MLP
# ======================================

class SimpleMLP(nn.Module):
    """Fully connected baseline (Q7/Q10): flatten -> two hidden ReLU layers -> logits.

    Each sample must flatten to 3*32*32 = 3072 values (a CIFAR-10 image).
    """

    def __init__(self, hidden: int = 512, num_classes: int = 10):
        super().__init__()
        widths = [3 * 32 * 32, hidden, hidden]
        layers = [nn.Flatten()]
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(hidden, num_classes))  # output logits, no activation
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by optional BatchNorm and ReLU.

    Layer order: Conv -> [BN] -> ReLU -> Conv -> [BN] -> ReLU. Spatial size is
    preserved (padding=1); channels go ``in_ch`` -> ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int, use_bn: bool = True):
        super().__init__()
        layers = []
        for src_ch in (in_ch, out_ch):
            layers.append(nn.Conv2d(src_ch, out_ch, 3, padding=1))
            if use_bn:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class SimpleCNN(nn.Module):
    """Q4/Q5/Q6: customizable CNN with optional BatchNorm.

    One ``ConvBlock`` followed by a 2x2 max-pool per entry of ``channels``,
    then global average pooling and a linear classifier.

    Args:
      channels: per-stage output widths, e.g. [64, 128, 256] (used when None).
        Any sequence of ints is accepted; it is copied, never mutated. Each
        stage halves the spatial size via MaxPool2d(2).
      use_bn: add BatchNorm after conv layers (Q5).
      num_classes: number of output logits.
      in_channels: channels of the input images (3 for RGB CIFAR-10).

    Raises:
      ValueError: if ``channels`` is empty, since the classifier would then
        have no feature width.
    """

    def __init__(self, channels: Optional[List[int]] = None, use_bn: bool = True,
                 num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        # None sentinel instead of a list-literal default: a list default would be
        # a single object shared by every call.
        widths = [64, 128, 256] if channels is None else list(channels)
        if not widths:
            raise ValueError("channels must contain at least one stage width")
        blocks = []
        for c_in, c_out in zip([in_channels] + widths[:-1], widths):
            blocks += [ConvBlock(c_in, c_out, use_bn), nn.MaxPool2d(2)]
        self.features = nn.Sequential(*blocks)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(widths[-1], num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.head(x)

# Hardware info (Q3.4) + Q7: initialize and move models to device

def print_hardware_info():
    """Print the compute device the experiments will run on (Q3.4)."""
    print("Device:", DEVICE)
    kind = DEVICE.type
    if kind == 'cuda':
        detail = ("GPU:", torch.cuda.get_device_name(0))
    elif kind == 'mps':
        detail = ("Apple Silicon GPU via MPS backend",)
    else:
        detail = ("Using CPU",)
    print(*detail)

# ================================
# 4) Training Utilities (Q8–Q9)
# ================================

def make_optimizer(params, name: str = "sgd", lr: float = 0.1, momentum: float = 0.9, weight_decay: float = 5e-4):
    """Build an SGD optimizer by (case-insensitive) name (Q8/Q12).

    ``"sgd"`` is plain SGD, and ``momentum`` is ignored. ``"sgdm"``,
    ``"sgd_momentum"`` and ``"sgd+momentum"`` add momentum. Both variants
    apply ``weight_decay``.

    Raises:
      ValueError: for any other name (the message shows the lower-cased name).
    """
    key = name.lower()
    if key == "sgd":
        return SGD(params, lr=lr, weight_decay=weight_decay)
    if key in ("sgdm", "sgd_momentum", "sgd+momentum"):
        return SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    raise ValueError(f"Unknown optimizer: {key}")


def make_scheduler(optimizer, name: Optional[str] = None, epochs: int = 30):
    """Return an epoch-level LR scheduler for ``optimizer``, or None if ``name`` is None.

    ``"cosine"``/``"cosineannealinglr"`` anneals over ``epochs`` steps.
    ``"step"``/``"steplr"`` multiplies the LR by 0.1 every
    ``max(1, epochs // 3)`` epochs. Names are case-insensitive.

    Raises:
      ValueError: for an unrecognized name (the message shows it lower-cased).
    """
    if name is None:
        return None

    def cosine():
        return CosineAnnealingLR(optimizer, T_max=epochs)

    def step():
        return StepLR(optimizer, step_size=max(1, epochs // 3), gamma=0.1)

    key = name.lower()
    factory = {"cosine": cosine, "cosineannealinglr": cosine, "step": step, "steplr": step}.get(key)
    if factory is None:
        raise ValueError(f"Unknown scheduler: {key}")
    return factory()


def accuracy(logits, y):
    """Top-1 accuracy of ``logits`` (classes along dim 1) against labels ``y``, as a float in [0, 1]."""
    predicted = logits.argmax(dim=1)
    is_correct = (predicted == y).float()
    return is_correct.mean().item()


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
    """Return ``(mean loss, accuracy in [0, 1])`` of ``model`` over every batch of ``loader``.

    Puts the model in eval mode (e.g. BatchNorm then uses its running
    statistics) and runs without autograd. Each batch's loss is weighted by its
    size, so the result is a per-sample average even when the last batch is
    smaller. This assumes ``criterion`` returns a batch mean, as
    ``nn.CrossEntropyLoss()`` does by default.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0.0
    seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch = targets.size(0)
        loss_sum += batch_loss.item() * batch
        correct += (outputs.argmax(1) == targets).float().sum().item()
        seen += batch
    return loss_sum / seen, correct / seen


def train(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
          epochs: int = 10, lr: float = 0.1, optimizer_name: str = "sgdm",
          scheduler_name: Optional[str] = "cosine"):
    """Train ``model`` with cross-entropy, evaluating on ``test_loader`` after each epoch.

    Q4.3/4.4 + logging for Q5.1. The optimizer and scheduler come from
    ``make_optimizer``/``make_scheduler``; the scheduler (if any) steps once
    per epoch.

    Returns:
      history dict of per-epoch lists ``"train_loss"``, ``"train_acc"``,
      ``"test_loss"``, ``"test_acc"`` (accuracies are fractions in [0, 1]).
      Train metrics are averaged during the epoch while the weights are still
      changing, so they are not a post-epoch evaluation.
    """
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = make_optimizer(model.parameters(), optimizer_name, lr=lr)
    scheduler = make_scheduler(optimizer, scheduler_name, epochs)

    history = {key: [] for key in ("train_loss", "train_acc", "test_loss", "test_acc")}
    for epoch in range(1, epochs + 1):
        model.train()  # training-mode behavior, e.g. BatchNorm uses batch statistics
        loss_sum, correct, seen = 0.0, 0.0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # Bookkeeping only: weight the mean batch loss by the batch size.
            with torch.no_grad():
                batch = targets.size(0)
                loss_sum += loss.item() * batch
                correct += (outputs.argmax(1) == targets).float().sum().item()
                seen += batch

        if scheduler is not None:
            scheduler.step()

        tr_loss, tr_acc = loss_sum / seen, correct / seen
        te_loss, te_acc = evaluate(model, test_loader, criterion)
        # history keys are in insertion order: train_loss, train_acc, test_loss, test_acc
        for key, value in zip(history, (tr_loss, tr_acc, te_loss, te_acc)):
            history[key].append(value)
        print(f"Epoch {epoch:02d}/{epochs} | train loss {tr_loss:.3f} acc {tr_acc*100:.1f}% | test loss {te_loss:.3f} acc {te_acc*100:.1f}%")
    return history

# ==========================
# 5) Plotting & Comparison
# ==========================

def _plot_metric(histories, metric: str, ylabel: str, scale: float, title: str, path: str):
    """Draw train/test curves of ``metric`` for every run, save to ``path``, and always close the figure."""
    fig = plt.figure()
    try:
        for name, h in histories.items():
            for split in ("train", "test"):
                values = h[f"{split}_{metric}"]
                # Each run gets its own epoch axis, so runs of different lengths can share a plot.
                plt.plot(range(1, len(values) + 1), [v * scale for v in values], label=f"{name} - {split}")
        plt.xlabel("Epoch"); plt.ylabel(ylabel); plt.title(title)
        if histories:  # legend() warns when there are no labeled lines
            plt.legend()
        plt.savefig(path, bbox_inches='tight')
    finally:
        plt.close(fig)
    print("Saved:", path)


def plot_curves(histories: Dict[str, Dict[str, List[float]]], title: str, save_prefix: str):
    """Q5.1: save a loss figure and an accuracy figure comparing several runs.

    Args:
      histories: run label -> history dict as returned by ``train`` (keys
        ``train_loss``, ``test_loss``, ``train_acc``, ``test_acc``; accuracies
        in [0, 1] are plotted as percentages).
      title: figure title prefix.
      save_prefix: writes ``outputs/{save_prefix}_loss.png`` and ``outputs/{save_prefix}_acc.png``.
    """
    # One plot per metric to keep them readable
    _plot_metric(histories, "loss", "Loss", 1.0, title + " — Loss", f"outputs/{save_prefix}_loss.png")
    _plot_metric(histories, "acc", "Accuracy (%)", 100.0, title + " — Accuracy", f"outputs/{save_prefix}_acc.png")

# =====================================
# 6) Experiments for Questions 10–12
# =====================================

def run_all(epochs: int = 10, batch_size: int = 128, lr: float = 0.1):
    """Run the full pipeline: data checks (Q2/Q3/Q3.4) and experiments Q10–Q12.

    Every model is freshly initialized and trained with ``train`` using the same
    ``epochs``, ``lr`` and cosine schedule. Figures are saved under ./outputs.

    Returns:
      dict mapping run labels ('MLP', 'CNN+BN', 'CNN+BN_2', 'CNN_noBN_2',
      'CNN_SGD', 'CNN_SGDM') to their training histories, so Q13 analysis or
      re-plotting does not require retraining.
    """
    # Make data
    train_loader, test_loader, mean, std = make_dataloaders(batch_size)
    describe_dataset(train_loader, test_loader)  # (Q2)
    show_grid(train_loader, unnormalize=(mean, std))  # (Q3)
    print_hardware_info()  # (Q3.4)

    def new_cnn(use_bn: bool) -> nn.Module:
        # Q4/5/6: every CNN run uses the same channel layout; only BN / the optimizer vary.
        return SimpleCNN(channels=[64, 128, 256], use_bn=use_bn).to(DEVICE)

    def fit(model: nn.Module, optimizer_name: str = 'sgdm'):
        return train(model, train_loader, test_loader, epochs=epochs, lr=lr,
                     optimizer_name=optimizer_name, scheduler_name='cosine')

    # Models are built right before they are trained, so no idle model sits on the device.
    # ---- Train a lightweight set of epochs so the assignment can run quickly ----
    # You can increase epochs for better accuracy if you have GPU time.
    histories = {}

    # (Q10) MLP vs CNN
    print("\n[Q10] Training MLP vs CNN (BN)")
    mlp = SimpleMLP().to(DEVICE)  # Q7: initialize MLP and move to device
    histories['MLP'] = fit(mlp)
    histories['CNN+BN'] = fit(new_cnn(use_bn=True))
    plot_curves({k: histories[k] for k in ['MLP', 'CNN+BN']}, title='MLP vs CNN', save_prefix='q10_mlp_vs_cnn')

    # (Q11) CNN with vs without BatchNorm, both from fresh initializations
    print("\n[Q11] Training CNN with vs without BatchNorm")
    histories['CNN+BN_2'] = fit(new_cnn(use_bn=True))
    histories['CNN_noBN_2'] = fit(new_cnn(use_bn=False))
    plot_curves({'CNN+BN': histories['CNN+BN_2'], 'CNN no BN': histories['CNN_noBN_2']}, title='CNN: BN vs no BN', save_prefix='q11_bn_vs_nobn')

    # (Q12) Optimizers comparison on the same CNN architecture
    print("\n[Q12] Optimizer comparison (SGD vs SGD+Momentum)")
    histories['CNN_SGD'] = fit(new_cnn(use_bn=True), optimizer_name='sgd')
    histories['CNN_SGDM'] = fit(new_cnn(use_bn=True), optimizer_name='sgdm')
    plot_curves({'CNN SGD': histories['CNN_SGD'], 'CNN SGD+Momentum': histories['CNN_SGDM']}, title='Optimizers on CNN', save_prefix='q12_optims')

    print("\nDone. Figures saved under ./outputs. Use these when answering Q13 (overfitting signs, factors, etc.).")
    return histories


if __name__ == "__main__":
    # Quick demo settings sized for CPU users; on a GPU raise epochs to 20–30 for smoother curves.
    run_all(lr=0.1, batch_size=CFG.batch_size, epochs=5)


Train size: 50000, Test size: 10000
Classes: ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


: 

In [None]:
"""
ENGR 413 — HW 3 Template (PyTorch + CIFAR-10)
Fall 2025 • Starter code covering Q1–Q13
Run cell-by-cell in a Jupyter Notebook or as a .py script.
You may trim epochs to fit your hardware. Plots and logs are saved to ./outputs.
"""

# ===============
# 0) Imports & Setup
# ===============
import os, math, time, random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

# Stability tweak (macOS/MPS): let ops that MPS does not support fall back to the CPU.
# PyTorch reads this variable when the library loads, so it must be set BEFORE
# `import torch`; setting it afterwards has no effect. NOTE: in a notebook kernel
# where an earlier cell already imported torch, restart the kernel for it to apply.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

import torchvision
import torchvision.transforms as T
import matplotlib
matplotlib.use("Agg")  # non-interactive backend (figures are only saved to files); safer in notebook/headless envs
import matplotlib.pyplot as plt
import platform

DEVICE = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
os.makedirs("outputs", exist_ok=True)

# Keep CPU threads modest to reduce contention (another stability tweak for e.g. macOS/Windows).
torch.set_num_threads(max(1, (os.cpu_count() or 2)//2))
IS_MAC = platform.system() == "Darwin"

# =====================
# 1) Data Preparation (Q1, Q2)
# =====================
# Human-readable CIFAR-10 class names, in label-index order.
CIFAR10_CLASSES = tuple(
    "airplane automobile bird cat deer dog frog horse ship truck".split()
)


@dataclass
class DataCfg:
    """Data-loading settings; the defaults are the conservative single-process, unpinned choice."""

    batch_size: int = 128      # (Q8) you may change in experiments
    num_workers: int = 0       # 0 = load in the main process (safest); increased when CUDA is available
    pin_memory: bool = False   # page-locked host memory only helps host->CUDA copies
    data_root: str = "./data"  # directory where torchvision stores CIFAR-10


CFG = DataCfg()
# Worker processes and pinned memory are enabled only on CUDA (avoids macOS/Windows
# crashes). CPU/MPS runs stay single-process and unpinned.
CFG.pin_memory = DEVICE.type == "cuda"
CFG.num_workers = min(4, (os.cpu_count() or 2)) if CFG.pin_memory else 0

# Q1.1: Load CIFAR10, compute global mean/std, and normalize

def compute_cifar10_mean_std(root: str = CFG.data_root) -> Tuple[Tuple[float,...], Tuple[float,...]]:
    """Return the per-channel (mean, std) of the CIFAR-10 training images.

    Pixels are scaled to [0, 1] by ``ToTensor``. Per-batch statistics are merged
    with the pairwise (Chan et al.) update, so the result is the population
    mean/std over every pixel of every training image without holding the whole
    split in memory. Downloads the training split into ``root`` if needed.
    """
    dataset = torchvision.datasets.CIFAR10(root, train=True, download=True, transform=T.ToTensor())
    batches = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=CFG.num_workers)

    # (Fix: this initialization previously had stray characters appended, a SyntaxError.)
    count = 0                  # images merged so far
    run_mean = torch.zeros(3)
    run_var = torch.zeros(3)   # running *population* variance (not a sum of squares)
    for images, _labels in batches:
        k = images.shape[0]
        flat = images.view(k, 3, -1)  # (batch, channel, pixels)
        chunk_mean = flat.mean(dim=(0, 2))
        chunk_var = flat.var(dim=(0, 2), unbiased=False)
        if count == 0:
            # First batch: its statistics simply become the running ones.
            run_mean, run_var, count = chunk_mean, chunk_var, k
            continue
        total = count + k
        shift = chunk_mean - run_mean
        run_mean = run_mean + shift * (k / total)
        # Weighted within-group variances plus the between-group term count*k*shift^2/total^2.
        run_var = (count * (run_var + (shift**2) * (k / total))) / total + (k * chunk_var) / total
        count = total
    return tuple(run_mean.tolist()), tuple(torch.sqrt(run_var).tolist())


# Build datasets and dataloaders

def make_dataloaders(batch_size: int = CFG.batch_size) -> Tuple[DataLoader, DataLoader, Tuple[float,...], Tuple[float,...]]:
    """Build normalized CIFAR-10 train/test loaders (Q1.1).

    Args:
      batch_size: training batch size (the test loader always uses 256).

    Returns:
      (train_loader, test_loader, mean, std), where mean/std are the per-channel
      training-set statistics passed to ``T.Normalize``. They are needed again
      to un-normalize images for display.
    """
    # Also downloads the training split, so the train dataset below can use download=False.
    mean, std = compute_cifar10_mean_std()
    normalize = T.Normalize(mean, std)
    to_tensor = T.ToTensor()

    # Augment only the training data: random crops of a 4-px padded image + horizontal flips.
    train_tf = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), to_tensor, normalize])
    test_tf = T.Compose([to_tensor, normalize])

    train_set = torchvision.datasets.CIFAR10(CFG.data_root, train=True, download=False, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(CFG.data_root, train=False, download=True, transform=test_tf)

    # Worker count and memory pinning come from CFG.
    shared = {"num_workers": CFG.num_workers, "pin_memory": CFG.pin_memory}
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, **shared)
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False, **shared)
    return train_loader, test_loader, mean, std

# Q1.2 + Q1.3 (basic inspection helper)

def describe_dataset(train_loader: DataLoader, test_loader: DataLoader):
    """Print split sizes and the class names (Q1.2/Q1.3 sanity check)."""
    print(f"Train size: {len(train_loader.dataset)}, Test size: {len(test_loader.dataset)}")
    print("Classes:", CIFAR10_CLASSES)

# =====================
# 2) Visualization (Q3)
# =====================

def show_grid(loader: DataLoader, nrow: int = 8, ncol: int = 8, unnormalize: Optional[Tuple[Tuple[float,...], Tuple[float,...]]] = None, savepath: str = "outputs/visual_grid.png"):
    """Save a grid of up to ``nrow * ncol`` images from the first batch of ``loader`` (Q3).

    Args:
      loader: yields ``(images, labels)`` batches; only the images are drawn.
      nrow, ncol: grid shape (fewer images are shown if the batch is smaller).
      unnormalize: optional ``(mean, std)`` that was given to ``T.Normalize``;
        the normalization is undone so colors look natural.
      savepath: PNG file to write.
    """
    images, _ = next(iter(loader))
    images = images[: nrow * ncol]
    if unnormalize is not None:
        mean, std = [torch.tensor(v).view(1, -1, 1, 1) for v in unnormalize]
        images = images * std + mean
    # imshow expects floats in [0, 1]. Un-normalizing lands a hair outside that range
    # (e.g. the zero padding from RandomCrop), which makes imshow warn and clip,
    # so clamp explicitly.
    images = images.clamp(0.0, 1.0)
    grid = torchvision.utils.make_grid(images, nrow=ncol)  # make_grid's nrow = images per row
    npimg = grid.permute(1, 2, 0).detach().cpu().numpy()   # CHW -> HWC for matplotlib
    plt.figure(figsize=(7,7))
    plt.imshow(npimg)
    plt.axis('off')
    plt.title("CIFAR-10 sample grid")
    plt.savefig(savepath, bbox_inches='tight')
    plt.close()
    print(f"Saved visualization to {savepath}")

# ======================================
# 3) Models (Q4–Q7): CNN (+BN toggle) and MLP
# ======================================

class SimpleMLP(nn.Module):
    """Fully connected baseline (Q7/Q10): flatten -> two hidden ReLU layers -> logits.

    Each sample must flatten to 3*32*32 = 3072 values (a CIFAR-10 image).
    """

    def __init__(self, hidden: int = 512, num_classes: int = 10):
        super().__init__()
        widths = [3 * 32 * 32, hidden, hidden]
        layers = [nn.Flatten()]
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(hidden, num_classes))  # output logits, no activation
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by optional BatchNorm and ReLU.

    Layer order: Conv -> [BN] -> ReLU -> Conv -> [BN] -> ReLU. Spatial size is
    preserved (padding=1); channels go ``in_ch`` -> ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int, use_bn: bool = True):
        super().__init__()
        layers = []
        for src_ch in (in_ch, out_ch):
            layers.append(nn.Conv2d(src_ch, out_ch, 3, padding=1))
            if use_bn:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class SimpleCNN(nn.Module):
    """Q4/Q5/Q6: customizable CNN with optional BatchNorm.

    One ``ConvBlock`` followed by a 2x2 max-pool per entry of ``channels``,
    then global average pooling and a linear classifier.

    Args:
      channels: per-stage output widths, e.g. [64, 128, 256] (used when None).
        Any sequence of ints is accepted; it is copied, never mutated. Each
        stage halves the spatial size via MaxPool2d(2).
      use_bn: add BatchNorm after conv layers (Q5).
      num_classes: number of output logits.
      in_channels: channels of the input images (3 for RGB CIFAR-10).

    Raises:
      ValueError: if ``channels`` is empty, since the classifier would then
        have no feature width.
    """

    def __init__(self, channels: Optional[List[int]] = None, use_bn: bool = True,
                 num_classes: int = 10, in_channels: int = 3):
        super().__init__()
        # None sentinel instead of a list-literal default: a list default would be
        # a single object shared by every call.
        widths = [64, 128, 256] if channels is None else list(channels)
        if not widths:
            raise ValueError("channels must contain at least one stage width")
        blocks = []
        for c_in, c_out in zip([in_channels] + widths[:-1], widths):
            blocks += [ConvBlock(c_in, c_out, use_bn), nn.MaxPool2d(2)]
        self.features = nn.Sequential(*blocks)
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(widths[-1], num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.head(x)

# Hardware info (Q3.4) + Q7: initialize and move models to device

def print_hardware_info():
    """Print the compute device the experiments will run on (Q3.4)."""
    print("Device:", DEVICE)
    kind = DEVICE.type
    if kind == 'cuda':
        detail = ("GPU:", torch.cuda.get_device_name(0))
    elif kind == 'mps':
        detail = ("Apple Silicon GPU via MPS backend",)
    else:
        detail = ("Using CPU",)
    print(*detail)

# ================================
# 4) Training Utilities (Q8–Q9)
# ================================

def make_optimizer(params, name: str = "sgd", lr: float = 0.1, momentum: float = 0.9, weight_decay: float = 5e-4):
    """Build an SGD optimizer by (case-insensitive) name (Q8/Q12).

    ``"sgd"`` is plain SGD, and ``momentum`` is ignored. ``"sgdm"``,
    ``"sgd_momentum"`` and ``"sgd+momentum"`` add momentum. Both variants
    apply ``weight_decay``.

    Raises:
      ValueError: for any other name (the message shows the lower-cased name).
    """
    key = name.lower()
    if key == "sgd":
        return SGD(params, lr=lr, weight_decay=weight_decay)
    if key in ("sgdm", "sgd_momentum", "sgd+momentum"):
        return SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)
    raise ValueError(f"Unknown optimizer: {key}")


def make_scheduler(optimizer, name: Optional[str] = None, epochs: int = 30):
    """Return an epoch-level LR scheduler for ``optimizer``, or None if ``name`` is None.

    ``"cosine"``/``"cosineannealinglr"`` anneals over ``epochs`` steps.
    ``"step"``/``"steplr"`` multiplies the LR by 0.1 every
    ``max(1, epochs // 3)`` epochs. Names are case-insensitive.

    Raises:
      ValueError: for an unrecognized name (the message shows it lower-cased).
    """
    if name is None:
        return None

    def cosine():
        return CosineAnnealingLR(optimizer, T_max=epochs)

    def step():
        return StepLR(optimizer, step_size=max(1, epochs // 3), gamma=0.1)

    key = name.lower()
    factory = {"cosine": cosine, "cosineannealinglr": cosine, "step": step, "steplr": step}.get(key)
    if factory is None:
        raise ValueError(f"Unknown scheduler: {key}")
    return factory()


def accuracy(logits, y):
    """Top-1 accuracy of ``logits`` (classes along dim 1) against labels ``y``, as a float in [0, 1]."""
    predicted = logits.argmax(dim=1)
    is_correct = (predicted == y).float()
    return is_correct.mean().item()


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module):
    """Return ``(mean loss, accuracy in [0, 1])`` of ``model`` over every batch of ``loader``.

    Puts the model in eval mode (e.g. BatchNorm then uses its running
    statistics) and runs without autograd. Each batch's loss is weighted by its
    size, so the result is a per-sample average even when the last batch is
    smaller. This assumes ``criterion`` returns a batch mean, as
    ``nn.CrossEntropyLoss()`` does by default.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0.0
    seen = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch = targets.size(0)
        loss_sum += batch_loss.item() * batch
        correct += (outputs.argmax(1) == targets).float().sum().item()
        seen += batch
    return loss_sum / seen, correct / seen


def train(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader,
          epochs: int = 10, lr: float = 0.1, optimizer_name: str = "sgdm",
          scheduler_name: Optional[str] = "cosine"):
    """Train ``model`` with cross-entropy, evaluating on ``test_loader`` after each epoch.

    Q4.3/4.4 + logging for Q5.1. The optimizer and scheduler come from
    ``make_optimizer``/``make_scheduler``; the scheduler (if any) steps once
    per epoch.

    Returns:
      history dict of per-epoch lists ``"train_loss"``, ``"train_acc"``,
      ``"test_loss"``, ``"test_acc"`` (accuracies are fractions in [0, 1]).
      Train metrics are averaged during the epoch while the weights are still
      changing, so they are not a post-epoch evaluation.
    """
    model = model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = make_optimizer(model.parameters(), optimizer_name, lr=lr)
    scheduler = make_scheduler(optimizer, scheduler_name, epochs)

    history = {key: [] for key in ("train_loss", "train_acc", "test_loss", "test_acc")}
    for epoch in range(1, epochs + 1):
        model.train()  # training-mode behavior, e.g. BatchNorm uses batch statistics
        loss_sum, correct, seen = 0.0, 0.0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(DEVICE, non_blocking=True), targets.to(DEVICE, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # Bookkeeping only: weight the mean batch loss by the batch size.
            with torch.no_grad():
                batch = targets.size(0)
                loss_sum += loss.item() * batch
                correct += (outputs.argmax(1) == targets).float().sum().item()
                seen += batch

        if scheduler is not None:
            scheduler.step()

        tr_loss, tr_acc = loss_sum / seen, correct / seen
        te_loss, te_acc = evaluate(model, test_loader, criterion)
        # history keys are in insertion order: train_loss, train_acc, test_loss, test_acc
        for key, value in zip(history, (tr_loss, tr_acc, te_loss, te_acc)):
            history[key].append(value)
        print(f"Epoch {epoch:02d}/{epochs} | train loss {tr_loss:.3f} acc {tr_acc*100:.1f}% | test loss {te_loss:.3f} acc {te_acc*100:.1f}%")
    return history

# ==========================
# 5) Plotting & Comparison
# ==========================

def _plot_metric(histories, metric: str, ylabel: str, scale: float, title: str, path: str):
    """Draw train/test curves of ``metric`` for every run, save to ``path``, and always close the figure."""
    fig = plt.figure()
    try:
        for name, h in histories.items():
            for split in ("train", "test"):
                values = h[f"{split}_{metric}"]
                # Each run gets its own epoch axis, so runs of different lengths can share a plot.
                plt.plot(range(1, len(values) + 1), [v * scale for v in values], label=f"{name} - {split}")
        plt.xlabel("Epoch"); plt.ylabel(ylabel); plt.title(title)
        if histories:  # legend() warns when there are no labeled lines
            plt.legend()
        plt.savefig(path, bbox_inches='tight')
    finally:
        plt.close(fig)
    print("Saved:", path)


def plot_curves(histories: Dict[str, Dict[str, List[float]]], title: str, save_prefix: str):
    """Q5.1: save a loss figure and an accuracy figure comparing several runs.

    Args:
      histories: run label -> history dict as returned by ``train`` (keys
        ``train_loss``, ``test_loss``, ``train_acc``, ``test_acc``; accuracies
        in [0, 1] are plotted as percentages).
      title: figure title prefix.
      save_prefix: writes ``outputs/{save_prefix}_loss.png`` and ``outputs/{save_prefix}_acc.png``.
    """
    # One plot per metric to keep them readable
    _plot_metric(histories, "loss", "Loss", 1.0, title + " — Loss", f"outputs/{save_prefix}_loss.png")
    _plot_metric(histories, "acc", "Accuracy (%)", 100.0, title + " — Accuracy", f"outputs/{save_prefix}_acc.png")

# =====================================
# 6) Experiments for Questions 10–12
# =====================================

def run_all(epochs: int = 10, batch_size: int = 128, lr: float = 0.1):
    """Run the full pipeline: data checks (Q2/Q3/Q3.4) and experiments Q10–Q12.

    Every model is freshly initialized and trained with ``train`` using the same
    ``epochs``, ``lr`` and cosine schedule. Figures are saved under ./outputs.

    Returns:
      dict mapping run labels ('MLP', 'CNN+BN', 'CNN+BN_2', 'CNN_noBN_2',
      'CNN_SGD', 'CNN_SGDM') to their training histories, so Q13 analysis or
      re-plotting does not require retraining.
    """
    # Make data
    train_loader, test_loader, mean, std = make_dataloaders(batch_size)
    describe_dataset(train_loader, test_loader)  # (Q2)
    show_grid(train_loader, unnormalize=(mean, std))  # (Q3)
    print_hardware_info()  # (Q3.4)

    def new_cnn(use_bn: bool) -> nn.Module:
        # Q4/5/6: every CNN run uses the same channel layout; only BN / the optimizer vary.
        return SimpleCNN(channels=[64, 128, 256], use_bn=use_bn).to(DEVICE)

    def fit(model: nn.Module, optimizer_name: str = 'sgdm'):
        return train(model, train_loader, test_loader, epochs=epochs, lr=lr,
                     optimizer_name=optimizer_name, scheduler_name='cosine')

    # Models are built right before they are trained, so no idle model sits on the device.
    # ---- Train a lightweight set of epochs so the assignment can run quickly ----
    # You can increase epochs for better accuracy if you have GPU time.
    histories = {}

    # (Q10) MLP vs CNN
    print("\n[Q10] Training MLP vs CNN (BN)")
    mlp = SimpleMLP().to(DEVICE)  # Q7: initialize MLP and move to device
    histories['MLP'] = fit(mlp)
    histories['CNN+BN'] = fit(new_cnn(use_bn=True))
    plot_curves({k: histories[k] for k in ['MLP', 'CNN+BN']}, title='MLP vs CNN', save_prefix='q10_mlp_vs_cnn')

    # (Q11) CNN with vs without BatchNorm, both from fresh initializations
    print("\n[Q11] Training CNN with vs without BatchNorm")
    histories['CNN+BN_2'] = fit(new_cnn(use_bn=True))
    histories['CNN_noBN_2'] = fit(new_cnn(use_bn=False))
    plot_curves({'CNN+BN': histories['CNN+BN_2'], 'CNN no BN': histories['CNN_noBN_2']}, title='CNN: BN vs no BN', save_prefix='q11_bn_vs_nobn')

    # (Q12) Optimizers comparison on the same CNN architecture
    print("\n[Q12] Optimizer comparison (SGD vs SGD+Momentum)")
    histories['CNN_SGD'] = fit(new_cnn(use_bn=True), optimizer_name='sgd')
    histories['CNN_SGDM'] = fit(new_cnn(use_bn=True), optimizer_name='sgdm')
    plot_curves({'CNN SGD': histories['CNN_SGD'], 'CNN SGD+Momentum': histories['CNN_SGDM']}, title='Optimizers on CNN', save_prefix='q12_optims')

    print("\nDone. Figures saved under ./outputs. Use these when answering Q13 (overfitting signs, factors, etc.).")
    return histories


if __name__ == "__main__":
    # Quick demo settings sized for CPU users; on a GPU raise epochs to 20–30 for smoother curves.
    run_all(lr=0.1, batch_size=CFG.batch_size, epochs=5)


Train size: 50000, Test size: 10000
Classes: ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


: 