# MNIST CNN with PyTorch (torchvision.datasets.MNIST)

**Author:** Auto-generated by M365 Copilot  
**Date:** 2025-12-18

This notebook implements a **basic Convolutional Neural Network (CNN)** on the **MNIST** handwritten digits dataset using **PyTorch** and **torchvision**. It includes:

- Clean, well-commented training & evaluation code
- Plots for loss and accuracy
- Confusion matrix visualization
- Notes and practical guidance on **hyperparameter tuning**
- (Optional) a small hyperparameter sweep snippet

> Reference: torchvision MNIST dataset docs — https://docs.pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html


## 1. Setup

Run the following cell **once** if PyTorch/torchvision are not installed in your environment, uncommenting the `pip install` lines you need. Google Colab already ships with PyTorch and torchvision, so there you can usually skip this cell.


In [None]:
# If needed, uncomment to install
# !pip install --upgrade pip
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

# For CUDA-enabled machines (pick the correct version for your GPU)
# See: https://pytorch.org/get-started/locally/
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118


## 2. Imports & Configuration

In [None]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Ensure plots render inline
%matplotlib inline

# Report the runtime environment.
print('PyTorch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', DEVICE)

# Reproducibility: one seed shared by every RNG the notebook touches.
SEED = 42


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's CPU RNGs, plus all CUDA RNGs if CUDA is present."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


_seed_everything(SEED)

# Deterministic cuDNN kernels and no autotuning: slower, but repeatable runs.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


## 3. Data Loading & Visualization

In [None]:
# Transforms: convert to a [0, 1] tensor, then normalize with the standard MNIST
# statistics (mean=0.1307, std=0.3081) for stable training.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# (Optional) light augmentation for robustness; keep disabled for baseline.
# NOTE: random_split below returns two views of ONE underlying dataset, so passing
# aug_transform to that dataset would augment the validation images too. If you
# enable it, build a second MNIST(train=True) with the plain `transform` and split
# it with the same generator seed to obtain an un-augmented validation subset.
# aug_transform = transforms.Compose([
#     transforms.RandomRotation(degrees=10),
#     transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
#     transforms.ToTensor(),
#     transforms.Normalize((0.1307,), (0.3081,))
# ])

# Download/Load MNIST
DATA_DIR = './data'
train_dataset = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=transform)

# Hold out a validation split from the training data (55k train / 5k val).
# A fixed-seed generator keeps the split identical across runs.
VAL_SIZE = 5000
train_size = len(train_dataset) - VAL_SIZE
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [train_size, VAL_SIZE],
    generator=torch.Generator().manual_seed(SEED),
)

# DataLoaders
BATCH_SIZE = 128
# Pinned (page-locked) host memory only speeds up host->GPU copies; on a CPU-only
# machine DataLoader warns about it and gains nothing, so enable it just for CUDA.
PIN_MEMORY = DEVICE.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

# Visualize a 4x4 grid of training samples. The images are normalized, so imshow
# (no vmin/vmax given) rescales each one to its own range for display.
batch_imgs, batch_labels = next(iter(train_loader))
imgs = batch_imgs[:16].squeeze(1).numpy()  # drop the channel dim -> (16, 28, 28)
labels = batch_labels[:16].numpy()

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(imgs[i], cmap='gray')
    ax.set_title(f'label: {labels[i]}')
    ax.axis('off')
plt.tight_layout()
plt.show()


## 4. Model: A Simple CNN

In [None]:
class SimpleCNN(nn.Module):
    """A compact CNN for 28x28 images such as MNIST.

    Architecture (shapes per sample, for 28x28 input):
        Conv(in_channels->32, 3x3, pad 1) + ReLU          -> (32, 28, 28)
        Conv(32->64, 3x3, pad 1) + ReLU + MaxPool(2x2)    -> (64, 14, 14)
        Dropout(conv_dropout)
        Flatten                                           -> 64*14*14
        FC(64*14*14 -> 128) + ReLU
        Dropout(fc_dropout)
        FC(128 -> num_classes)

    Args:
        num_classes: number of output logits (default 10, the MNIST digits).
        in_channels: number of input image channels (default 1, grayscale).
        conv_dropout: dropout probability after the conv/pool stage.
        fc_dropout: dropout probability after the hidden fully connected layer.

    fc1's input size is fixed to a 28x28 input resolution; other image sizes
    will fail with a shape mismatch at fc1.
    """

    def __init__(self, num_classes=10, in_channels=1, conv_dropout=0.25, fc_dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(conv_dropout)
        self.dropout2 = nn.Dropout(fc_dropout)
        # padding=1 keeps 28x28 through both 3x3 convs; the single 2x2 pool then
        # halves it, so the flattened feature vector is 64 * 14 * 14.
        self.fc1 = nn.Linear(64 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a batch ``(N, in_channels, 28, 28)`` to unnormalized logits ``(N, num_classes)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

# Build the network and move its parameters onto the selected device.
model = SimpleCNN().to(DEVICE)

# Report how many parameters the optimizer will actually update.
param_count = sum(t.numel() for t in filter(lambda t: t.requires_grad, model.parameters()))
print(f'Model has {param_count:,} trainable parameters')


## 5. Training Utilities

In [None]:
def accuracy_from_logits(logits, targets):
    """Return the fraction of rows whose argmax over dim 1 equals ``targets``.

    Args:
        logits: class scores with classes along dim 1, e.g. ``(batch, num_classes)``.
        targets: integer class indices, one per row of ``logits``.

    Returns:
        The accuracy as a Python float.
    """
    hits = logits.argmax(dim=1).eq(targets)
    return hits.float().mean().item()


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Metrics are averaged over the samples the loader actually yields, not over
    ``len(loader.dataset)``. They therefore stay correct with ``drop_last=True``
    or a sampler that does not visit every item of the dataset.

    Args:
        model: network to train; it is switched to train mode.
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss that returns the batch-mean loss (e.g. ``nn.CrossEntropyLoss()``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and accuracy in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    n_seen = 0
    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = images.size(0)
        # criterion yields a batch mean; re-weight by batch size so the epoch
        # average is per-sample even when the last batch is smaller.
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size
    if n_seen == 0:
        raise ValueError('train_one_epoch: loader yielded no samples')
    return total_loss / n_seen, total_correct / n_seen


def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return ``(mean_loss, accuracy)``.

    Metrics are averaged over the samples the loader actually yields, not over
    ``len(loader.dataset)``. They therefore stay correct with ``drop_last=True``
    or a sampler that visits only part of the dataset.

    Args:
        model: network to evaluate; it is switched to eval mode (dropout off).
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss that returns the batch-mean loss (e.g. ``nn.CrossEntropyLoss()``).
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: per-sample mean loss and accuracy in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_size = images.size(0)
            # criterion yields a batch mean; re-weight so the average is per-sample.
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        raise ValueError('evaluate: loader yielded no samples')
    return total_loss / n_seen, total_correct / n_seen


## 6. Train Loop (Baseline)

In [None]:
# Baseline hyperparameters.
LR = 1e-3            # Adam learning rate
WEIGHT_DECAY = 1e-4  # L2 penalty applied by Adam
EPOCHS = 10

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Halve the learning rate every 5 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# Per-epoch history, consumed by the plotting cell.
train_losses, train_accs, val_losses, val_accs = [], [], [], []

# Best-validation checkpoint seen so far.
best_val_acc, best_state = 0.0, None

for epoch in range(1, EPOCHS + 1):
    tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
    va_loss, va_acc = evaluate(model, val_loader, criterion, DEVICE)
    # Epoch-level schedule: step once per epoch, after the optimizer has stepped.
    scheduler.step()

    for history, value in ((train_losses, tr_loss), (train_accs, tr_acc),
                           (val_losses, va_loss), (val_accs, va_acc)):
        history.append(value)

    print(f"Epoch {epoch:02d}/{EPOCHS} | "
          f"Train Loss: {tr_loss:.4f}, Train Acc: {tr_acc*100:.2f}% | "
          f"Val Loss: {va_loss:.4f}, Val Acc: {va_acc*100:.2f}%")

    # Snapshot the weights (cloned onto the CPU) whenever validation accuracy improves,
    # so later optimizer updates cannot overwrite the saved copy.
    if va_acc > best_val_acc:
        best_val_acc = va_acc
        best_state = {name: tensor.cpu().clone() for name, tensor in model.state_dict().items()}

# Restore the best checkpoint so later cells evaluate/save it rather than the last epoch.
if best_state is not None:
    model.load_state_dict(best_state)


## 7. Training Curves: Loss & Accuracy

In [None]:
# Build the x-axis from the recorded history rather than EPOCHS, so the plot still
# works if training was interrupted or EPOCHS was edited after the run.
epochs = range(1, len(train_losses) + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: loss curves.
ax_loss.plot(epochs, train_losses, label='Train Loss')
ax_loss.plot(epochs, val_losses, label='Val Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Loss over Epochs')
ax_loss.legend()

# Right panel: accuracy curves, shown as percentages.
ax_acc.plot(epochs, [a * 100 for a in train_accs], label='Train Acc')
ax_acc.plot(epochs, [a * 100 for a in val_accs], label='Val Acc')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Accuracy over Epochs')
ax_acc.legend()

fig.tight_layout()
plt.show()


## 8. Test Evaluation & Confusion Matrix

In [None]:
test_loss, test_acc = evaluate(model, test_loader, criterion, DEVICE)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc*100:.2f}%')

# Confusion matrix: rows = true class, columns = predicted class.
num_classes = 10
cm = np.zeros((num_classes, num_classes), dtype=int)
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        preds = model(images).argmax(dim=1)
        # Encode each (true, pred) pair as one flat index and count the whole batch
        # at once on-device, instead of one .item() host sync per sample.
        flat = labels.view(-1) * num_classes + preds.view(-1)
        counts = torch.bincount(flat, minlength=num_classes * num_classes)
        cm += counts.view(num_classes, num_classes).cpu().numpy()

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(cm, cmap='Blues')
ax.figure.colorbar(im, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_xticks(range(num_classes))
ax.set_yticks(range(num_classes))
ax.set_title('MNIST Confusion Matrix (Test)')

# Annotate counts; use white text on dark (high-count) cells so it stays readable.
thresh = cm.max() / 2.0
for i in range(num_classes):
    for j in range(num_classes):
        ax.text(j, i, cm[i, j], ha='center', va='center',
                color='white' if cm[i, j] > thresh else 'black', fontsize=8)
plt.tight_layout()
plt.show()


## 9. Save & Load Model

In [None]:
SAVE_PATH = 'mnist_cnn.pth'
torch.save(model.state_dict(), SAVE_PATH)
print(f'Saved best model weights to {SAVE_PATH}')

# Example: re-load into a fresh model.
# weights_only=True limits unpickling to tensors and plain containers, so a tampered
# checkpoint cannot execute arbitrary code. It is the default since PyTorch 2.6 and
# requires PyTorch >= 1.13.
loaded_model = SimpleCNN().to(DEVICE)
loaded_model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE, weights_only=True))
loaded_model.eval()


## 10. Inference Demo

In [None]:
# Show predictions on a small batch
images, labels = next(iter(test_loader))
images, labels = images.to(DEVICE), labels.to(DEVICE)
outputs = loaded_model(images)
preds = outputs.argmax(dim=1)

# Plot first 16
imgs = images[:16].cpu().squeeze(1).numpy()
true = labels[:16].cpu().numpy()
est  = preds[:16].cpu().numpy()

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
    ax.imshow(imgs[i], cmap='gray')
    ax.set_title(f'T:{true[i]} P:{est[i]}', fontsize=10)
    ax.axis('off')
plt.tight_layout()
plt.show()


## 11. Hyperparameter Tuning Notes

**Goal:** Achieve strong validation/test accuracy with efficient training.

### Key Hyperparameters & Practical Ranges
- **Learning Rate (LR):** Start with `1e-3` for Adam. Try `{3e-4, 1e-3, 3e-3}`. For SGD, begin around `0.05`–`0.2` with momentum.
- **Batch Size:** MNIST trains well with `64`–`256`. Larger batches speed up on GPU but may need LR adjustment (linear scaling rule).
- **Weight Decay (L2):** Helps generalization. Try `{0, 1e-5, 1e-4, 5e-4}`.
- **Dropout:** `0.25`–`0.5` is typical; lower for small models if underfitting.
- **Optimizer:** `Adam` is a reliable baseline; `SGD(momentum=0.9)` can match/beat Adam with tuned LR and decay.
- **LR Schedule:** StepLR (multiply the LR by a factor `gamma` between 0.1 and 0.5 every 5 epochs) or CosineAnnealing; this often improves final accuracy.
- **Data Augmentation:** Light rotations (±10°) and translations (≤10%) can add robustness; avoid heavy transforms that distort digits.

### Tuning Strategy (Time-Effective)
1. **Baseline run** (as provided) to establish reference accuracy.
2. **LR sweep**: Keep other params fixed; test 3–5 LRs for 3–5 epochs, pick best.
3. **Regularization sweep**: Try a couple of weight_decay and dropout combinations.
4. **Optimizer comparison**: Adam vs SGD+momentum with tuned LR.
5. **Schedule on**: Enable StepLR/Cosine once core hyperparams are set.

### Diagnostics
- If **train acc >> val acc**: decrease LR, increase weight_decay or dropout, add light augmentation.
- If **train acc ≈ val acc but both low**: increase capacity (more filters), train longer, or try a slightly higher LR.
- If **loss plateaus early**: try LR warmup or restart with higher LR.

### Recommended Defaults (MNIST)
- Adam, `lr=1e-3`, `weight_decay=1e-4`
- Batch size `128`
- 10–15 epochs with StepLR (step_size=5, gamma=0.5)
- Optional light augmentation for extra robustness

> The MNIST dataset and transforms are provided by `torchvision.datasets.MNIST` and `torchvision.transforms`. See the official docs: https://docs.pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html


## 12. (Optional) Tiny Hyperparameter Sweep

In [None]:
# WARNING: This will train multiple short runs and can take extra time.
# It runs 2 epochs per config to quickly compare validation accuracy.

from copy import deepcopy

def run_short_experiment(lr, weight_decay, epochs=2, seed=None):
    """Train a fresh SimpleCNN briefly and return its best validation accuracy.

    PyTorch's global RNG is re-seeded before the model is built. Every
    configuration therefore starts from the same initial weights, and the
    DataLoader, which draws its shuffle seed from that global RNG, yields the
    same batch order. Differences between runs then reflect the
    hyperparameters rather than random initialization. Note that this also
    resets the global RNG state seen by any code run afterwards.

    Args:
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        epochs: number of training epochs.
        seed: RNG seed; ``None`` uses the notebook-wide ``SEED``.

    Returns:
        The highest validation accuracy (fraction in [0, 1]) over the epochs run.
    """
    torch.manual_seed(SEED if seed is None else seed)
    model = SimpleCNN().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_va = 0.0
    for _ in range(epochs):
        train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        _, va_acc = evaluate(model, val_loader, criterion, DEVICE)
        best_va = max(best_va, va_acc)
    return best_va

# Small grid of (learning rate, weight decay) settings to compare.
candidates = [
    {"lr": 3e-4, "wd": 0.0},
    {"lr": 1e-3, "wd": 1e-5},
    {"lr": 1e-3, "wd": 1e-4},
    {"lr": 3e-3, "wd": 1e-4},
]

results = []
for cfg in candidates:
    lr, wd = cfg["lr"], cfg["wd"]
    acc = run_short_experiment(lr, wd, epochs=2)
    results.append(dict(lr=lr, wd=wd, val_acc=acc))
    print(f"Config lr={lr:.1e}, wd={wd:.1e} -> best val acc: {acc*100:.2f}%")

# The highest short-run validation accuracy wins (None if the grid is empty).
best = None
if results:
    best = max(results, key=lambda r: r["val_acc"])
print('Best (short run):', best)


## 13. Appendix: Notes & Tips

- **Reproducibility:** We set seeds and deterministic flags; exact reproducibility may still vary across hardware and CUDA versions.
- **Runtime:** On CPU, expect several minutes (or more on slower machines) for 10 epochs. On a GPU, expect roughly a minute or less; data loading often dominates. Times vary by hardware.
- **Model Capacity:** This baseline typically already reaches around 99% test accuracy. Adding another conv+pool block or increasing filters (e.g., 64→128) can push it somewhat higher.
- **Evaluation Protocol:** Use a held-out validation split (as we did) to tune hyperparameters, then report final **test** performance with best weights.
- **Saving Artifacts:** Models (`.pth`) and plots can be saved to disk for reporting.
