
# CNN on MNIST — Teaching Notebook (Statistics PhD)

This notebook demonstrates:
- Loading MNIST and visualizing samples
- Training a **baseline MLP** vs a **CNN**
- Tracking loss/accuracy
- Confusion matrix visualization
- Simple **data augmentation** experiment
- Inspecting learned **convolutional filters**

> **Note:** On the first run, `torchvision` downloads MNIST automatically into `./data`; later runs reuse the local copy.


In [None]:

# --- Setup & configuration ---
import os, sys, math, time, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, utils

import matplotlib.pyplot as plt

# Reproducibility: seed every RNG the notebook touches (Python, NumPy, PyTorch).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Hyperparameters (feel free to tweak)
BATCH_SIZE = 128
LR = 1e-3
EPOCHS_BASELINE = 3    # short run for the MLP baseline
EPOCHS_CNN = 5         # a little longer for the CNN
VAL_SPLIT = 0.1        # fraction of the training set held out for validation
AUGMENT = False        # set True to train with random rotations / shifts

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


In [None]:

# --- Data: MNIST ---
# Training-time transform: optional augmentation (random rotation / translation)
# followed by ToTensor().
transform_list = [transforms.ToTensor()]
if AUGMENT:
    transform_list = [
        transforms.RandomRotation(10),
        transforms.RandomAffine(0, translate=(0.05, 0.05)),
        transforms.ToTensor()
    ]
transform = transforms.Compose(transform_list)

# Evaluation transform: never augmented, so validation/test metrics are measured on
# the real images and stay comparable between AUGMENT=True and AUGMENT=False runs.
eval_transform = transforms.ToTensor()

data_root = "./data"
train_full = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
# The same training images viewed without augmentation; only the validation split
# reads from it. When augmentation is off, train_full already is that view.
train_full_eval = (
    datasets.MNIST(root=data_root, train=True, transform=eval_transform, download=True)
    if AUGMENT else train_full
)
test_set = datasets.MNIST(root=data_root, train=False, transform=eval_transform, download=True)

# Train/Val split (seeded generator -> the same partition on every run)
val_size = int(len(train_full) * VAL_SPLIT)
train_size = len(train_full) - val_size
train_set, val_subset = random_split(
    train_full, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)
# random_split returns Subsets of train_full, which would carry the augmenting
# transform into validation. Re-point the same indices at the un-augmented view.
val_set = torch.utils.data.Subset(train_full_eval, val_subset.indices)

# pin_memory only speeds up host->GPU copies; recent PyTorch versions warn when it is
# requested on a CPU-only machine, so enable it only when training on CUDA.
pin = device.type == "cuda"
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=pin)
val_loader   = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=pin)
test_loader  = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=pin)

len(train_set), len(val_set), len(test_set)


In [None]:

# --- Visualize a mini-batch ---
# Pull one shuffled training batch and tile its first 36 images into a 6x6 grid.
imgs, labels = next(iter(train_loader))
n_show, n_cols = 36, 6
grid = utils.make_grid(imgs[:n_show], nrow=n_cols, padding=2)

fig, ax = plt.subplots(figsize=(6, 6))
# The grid is channels-first; matplotlib wants channels-last.
ax.imshow(grid.permute(1, 2, 0).squeeze())
ax.axis('off')
ax.set_title("MNIST samples")
plt.show()

# Ground-truth digits for the tiles above, in row-major order.
labels[:n_show]


## Baseline: MLP (fully connected)

In [None]:

class MLP(nn.Module):
    """Fully connected baseline: 784 -> 256 -> 128 -> 10 with ReLU between layers.

    Inputs are flattened from dim 1 onward, so each sample must contain 28*28 values
    (e.g. [B, 1, 28, 28] MNIST images). Returns raw logits of shape [B, 10].
    The layers live in ``self.net``, so state-dict keys are ``net.<index>.*``.
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28, 256, 128, 10]
        layers = [nn.Flatten()]
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if i < len(widths) - 2:  # no activation after the output layer
                layers.append(nn.ReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

mlp = MLP().to(device)
# Total count of learnable scalars in the MLP (shown as the cell output).
sum(map(torch.numel, mlp.parameters()))


## Convolutional Neural Network (CNN)

In [None]:

class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer classifier head.

    Shapes for a [B, 1, 28, 28] input:
        conv1 + ReLU + pool  -> [B, 32, 14, 14]
        conv2 + ReLU + pool  -> [B, 64, 7, 7]
        dropout + flatten    -> [B, 3136]
        fc1 + ReLU, then fc2 -> [B, 10] logits
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 keep the spatial size; each 2x2 pool halves it.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.25)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        flat = self.dropout(features).flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

cnn = SimpleCNN().to(device)
# Total count of learnable scalars in the CNN (shown as the cell output).
sum(map(torch.numel, cnn.parameters()))


## Training Utilities

In [None]:

def accuracy(logits, y):
    """Return, as a Python float, the fraction of rows whose arg-max logit (dim 1) equals ``y``."""
    predicted = logits.argmax(dim=1)
    hits = predicted.eq(y).float()
    return hits.mean().item()

@torch.no_grad()
def evaluate(model, loader, loss_fn, device=None):
    """Compute the per-sample mean loss and accuracy of ``model`` over ``loader``.

    Switches the model to eval mode (e.g. disables dropout) and runs without autograd.

    Args:
        model: classifier producing logits with classes along dim 1.
        loader: iterable of ``(x, y)`` batches (e.g. a DataLoader).
        loss_fn: criterion; assumed to return the *mean* loss over a batch
            (as ``nn.CrossEntropyLoss()`` does with its default reduction).
        device: where to move each batch. Defaults to the device of the model's
            first parameter (CPU for a parameter-free model), so this works even
            when the model is not on the notebook-global ``device``.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, both averaged per sample so a
        smaller final batch is weighted correctly.

    Raises:
        ValueError: if ``loader`` yields no samples (instead of dividing by zero).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    total_loss, total_correct, total_n = 0.0, 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        bs = y.size(0)
        # loss_fn gives a batch mean; re-weight by batch size for a per-sample average.
        total_loss += loss_fn(logits, y).item() * bs
        total_correct += (logits.argmax(1) == y).float().sum().item()
        total_n += bs
    if total_n == 0:
        raise ValueError("evaluate() got an empty loader; cannot average over 0 samples")
    return total_loss / total_n, total_correct / total_n

def train_model(model, train_loader, val_loader, epochs=5, lr=1e-3):
    """Train ``model`` with Adam + cross-entropy, validating after every epoch.

    Prints one summary line per epoch and returns a dict of per-epoch lists keyed
    ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``. Training metrics
    are running averages accumulated while the weights are still changing;
    validation metrics come from ``evaluate`` on the end-of-epoch weights.
    Batches are moved to the notebook-global ``device``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    history = {key: [] for key in ("train_loss", "train_acc", "val_loss", "val_acc")}

    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        start = time.time()

        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad(set_to_none=True)
            logits = model(xb)
            batch_loss = criterion(logits, yb)
            batch_loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size for a per-sample average.
            batch_n = yb.size(0)
            loss_sum += batch_loss.item() * batch_n
            n_correct += (logits.argmax(1) == yb).float().sum().item()
            n_seen += batch_n

        tr_loss, tr_acc = loss_sum / n_seen, n_correct / n_seen
        va_loss, va_acc = evaluate(model, val_loader, criterion)
        for key, value in zip(history, (tr_loss, tr_acc, va_loss, va_acc)):
            history[key].append(value)
        elapsed = time.time() - start

        print(f"Epoch {epoch:02d} | "
              f"train loss {tr_loss:.4f} acc {tr_acc:.4f} | "
              f"val loss {va_loss:.4f} acc {va_acc:.4f} | "
              f"{elapsed:.1f}s")
    return history


## Train Baseline MLP

In [None]:

# Train a freshly initialised MLP (re-running this cell always starts from scratch).
mlp = MLP().to(device)
hist_mlp = train_model(mlp, train_loader, val_loader, epochs=EPOCHS_BASELINE, lr=LR)

# train_model logs epochs as 01, 02, ...; plot against the same 1-based epoch numbers
# (plt.plot(values) alone would put the first epoch at x=0).
mlp_epochs = list(range(1, len(hist_mlp["train_loss"]) + 1))
for metric, ylabel, title in (("loss", "loss", "MLP Loss"), ("acc", "accuracy", "MLP Accuracy")):
    fig, ax = plt.subplots()
    ax.plot(mlp_epochs, hist_mlp[f"train_{metric}"], marker="o", label=f"train {metric}")
    ax.plot(mlp_epochs, hist_mlp[f"val_{metric}"], marker="o", label=f"val {metric}")
    ax.set(xlabel="epoch", ylabel=ylabel, title=title, xticks=mlp_epochs)
    ax.legend()
    plt.show()


## Train CNN

In [None]:

# Train a freshly initialised CNN (re-running this cell always starts from scratch).
cnn = SimpleCNN().to(device)
hist_cnn = train_model(cnn, train_loader, val_loader, epochs=EPOCHS_CNN, lr=LR)

# train_model logs epochs as 01, 02, ...; plot against the same 1-based epoch numbers
# (plt.plot(values) alone would put the first epoch at x=0).
cnn_epochs = list(range(1, len(hist_cnn["train_loss"]) + 1))
for metric, ylabel, title in (("loss", "loss", "CNN Loss"), ("acc", "accuracy", "CNN Accuracy")):
    fig, ax = plt.subplots()
    ax.plot(cnn_epochs, hist_cnn[f"train_{metric}"], marker="o", label=f"train {metric}")
    ax.plot(cnn_epochs, hist_cnn[f"val_{metric}"], marker="o", label=f"val {metric}")
    ax.set(xlabel="epoch", ylabel=ylabel, title=title, xticks=cnn_epochs)
    ax.legend()
    plt.show()


## Evaluate on Test Set & Confusion Matrix

In [None]:

from sklearn.metrics import confusion_matrix
import itertools

@torch.no_grad()
def predict_all(model, loader, device=None):
    """Run ``model`` over ``loader`` and collect true and predicted class labels.

    Switches the model to eval mode and runs without autograd.

    Args:
        model: classifier producing logits with classes along dim 1.
        loader: iterable of ``(x, y)`` batches; ``y`` may live on any device.
        device: where to run inference. Defaults to the device of the model's
            first parameter (CPU for a parameter-free model), so this works even
            when the model is not on the notebook-global ``device``.

    Returns:
        ``(y_true, y_pred)`` as 1-D NumPy arrays in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    y_true, y_pred = [], []
    for x, y in loader:
        logits = model(x.to(device))
        y_pred.extend(logits.argmax(1).cpu().tolist())
        y_true.extend(y.tolist())
    return np.array(y_true), np.array(y_pred)

y_true, y_pred = predict_all(cnn, test_loader)
# Pin the label set to digits 0-9 so the matrix is always 10x10 and lines up with the
# tick labels below, even on a subset where some digit never appears in either
# y_true or y_pred (sklearn would otherwise drop that row/column).
digit_labels = np.arange(10)
cm = confusion_matrix(y_true, y_pred, labels=digit_labels)

fig, ax = plt.subplots(figsize=(6, 6))
# A sequential map where high counts are DARK, matching the white-on-high text rule
# below (the default viridis makes high counts bright yellow, hurting readability).
im = ax.imshow(cm, interpolation='nearest', cmap="Blues")
ax.set_title("Confusion Matrix (CNN)")
fig.colorbar(im, ax=ax)
ax.set_xticks(digit_labels)
ax.set_xticklabels(digit_labels)
ax.set_yticks(digit_labels)
ax.set_yticklabels(digit_labels)
# Annotate each cell with its count: white text on dark (high) cells, black otherwise.
thresh = cm.max() / 2
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    ax.text(j, i, format(cm[i, j], 'd'),
            horizontalalignment="center",
            verticalalignment="center",
            color="white" if cm[i, j] > thresh else "black")
ax.set_ylabel('True label')
ax.set_xlabel('Predicted label')
fig.tight_layout()
plt.show()

# Overall test accuracy: fraction of test images classified correctly.
(test_acc := (y_true == y_pred).mean())


## Inspect First-Layer Filters

In [None]:

# Copy conv1's weights to the CPU; shape [32, 1, 3, 3] (out, in, kH, kW).
with torch.no_grad():
    w = cnn.conv1.weight.detach().cpu().clone()
# Min-max scale with ONE global range across all filters, so their relative contrast
# is preserved; the small epsilon guards against a constant tensor.
w_lo, w_hi = w.min(), w.max()
w = (w - w_lo) / (w_hi - w_lo + 1e-8)

grid = utils.make_grid(w, nrow=8, padding=1, normalize=False)
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(grid.permute(1, 2, 0).squeeze())
ax.axis('off')
ax.set_title("Conv1 filters (normalized)")
plt.show()


## Save & Load Model Checkpoints

In [None]:

ckpt_path = "cnn_mnist_ckpt.pt"
torch.save({"model_state": cnn.state_dict()}, ckpt_path)
print(f"Saved to {ckpt_path}")

# Example: reload into a fresh model.
# weights_only=True restricts unpickling to tensors and plain containers, so loading a
# checkpoint cannot execute arbitrary code (parameter available since PyTorch 1.13;
# it became the default in 2.6).
cnn2 = SimpleCNN().to(device)
state = torch.load(ckpt_path, map_location=device, weights_only=True)
load_result = cnn2.load_state_dict(state["model_state"])
cnn2.eval()  # freshly built modules start in train mode; switch off dropout for inference

# Sanity check: the save/load round trip reproduced every tensor exactly.
assert all(
    torch.equal(saved, reloaded)
    for saved, reloaded in zip(cnn.state_dict().values(), cnn2.state_dict().values())
), "reloaded checkpoint does not match the trained model"
load_result



## Try Data Augmentation
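To try augmentation, set `AUGMENT = True` in the setup cell, then **Restart the kernel and Run All**. The transforms and data loaders are built once, in the data cell, so re-running only the training cells is not enough.
Then compare the validation curves and the final test accuracy with those from a run with `AUGMENT = False`.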


## Parameter Count Comparison

In [None]:

def count_params(model):
    """Return the total number of scalar entries across all of ``model.parameters()``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total

# Freshly constructed (untrained) instances; only the architecture matters here.
for label, model_cls in (("MLP", MLP), ("CNN", SimpleCNN)):
    print(f"{label} params:", count_params(model_cls()))


## Playground: Quick Re-run Helper

In [None]:

# Playground: override the learning rate / CNN epochs here, then re-run the training
# cells to see the effect.
# NOTE: this rebinds the setup-cell constants LR and EPOCHS_CNN for the rest of the
# session, so the kernel state (and any re-run outputs) will no longer match the
# setup cell -- restart the kernel to go back to the original values.
LR, EPOCHS_CNN = 5e-4, 3
print("LR:", LR, "EPOCHS_CNN:", EPOCHS_CNN)
