
# MNIST Digit Classification with a CNN (PyTorch, CUDA/MPS Ready)

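- 5×5 grid of random training digits with their labels
- Small CNN (two conv/pool blocks → two linear layers)
- Training loop reporting train and test accuracy after each epoch
- Optional: visualize the learned kernels of both conv layers (conv1 and conv2)
- Single-image prediction demo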


In [None]:

import os, random, numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

SEED = 0

# Seed every RNG the notebook relies on (Python, NumPy, torch CPU, all CUDA
# devices), in that order, so sampling, weight init and shuffling repeat.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

def get_device():
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU.

    The ``mps`` backend attribute is looked up defensively because older torch
    builds do not expose ``torch.backends.mps``.
    """
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        chosen = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)

DEVICE = get_device()
print("Using device:", DEVICE)

# One extra, backend-specific line; nothing further is printed on CPU.
if DEVICE.type == "mps":
    print("Apple Metal (MPS) backend active")
elif DEVICE.type == "cuda":
    print("CUDA:", torch.cuda.get_device_name(0))


In [None]:

# ToTensor only: pixels become float tensors in [0, 1] with no normalization,
# so the same tensors can be plotted directly further down.
transform = transforms.Compose([transforms.ToTensor()])
root = os.path.join(os.curdir, "data")

# MNIST train split first, then the test split (downloaded into ./data if missing).
train_ds, test_ds = (
    datasets.MNIST(root=root, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

BATCH_SIZE = 256
# Only the training loader shuffles; evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split, do_shuffle in ((train_ds, True), (test_ds, False))
)

def show_grid(ds, rows=5, cols=5, seed=SEED, title=None):
    """Plot a ``rows`` x ``cols`` grid of randomly chosen samples from ``ds``.

    Samples are drawn without replacement by a NumPy generator seeded with
    ``seed``, so repeated calls show the same images. Each subplot is titled
    with the sample's label.

    Args:
        ds: indexable, sized dataset returning ``(image, label)`` pairs.
            ``image`` is squeezed on dim 0 before plotting (a 1-channel CHW
            tensor such as ``ToTensor`` produces for MNIST).
        rows: number of grid rows.
        cols: number of grid columns.
        seed: seed controlling which samples are shown.
        title: figure title. Defaults to a description that includes the
            actual number of samples shown.

    Raises:
        ValueError: if the grid has no cells or needs more samples than ``ds`` has.
    """
    n = rows * cols
    if rows <= 0 or cols <= 0:
        raise ValueError(f"rows and cols must be positive, got {rows}x{cols}")
    if n > len(ds):
        raise ValueError(f"cannot draw {n} distinct samples from a dataset of size {len(ds)}")

    rng = np.random.default_rng(seed)
    idxs = rng.choice(len(ds), size=n, replace=False)
    if title is None:
        # The count tracks the real grid size instead of being fixed at 25.
        title = f"{n} Random Training Digits (labels in titles)"

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so .ravel() always works.
    fig, axes = plt.subplots(rows, cols, figsize=(6, 6), dpi=140, squeeze=False)
    fig.suptitle(title)
    for ax, idx in zip(axes.ravel(), idxs):
        img, y = ds[int(idx)]
        ax.imshow(img.squeeze(0), cmap="gray", interpolation="nearest")
        ax.set_title(str(y))
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Preview a 5x5 sample of the training set.
show_grid(train_ds, rows=5, cols=5)


In [None]:

class SmallCNN(nn.Module):
    """Two (3x3 conv -> ReLU -> 2x2 max-pool) blocks followed by a two-layer MLP.

    The flattened feature size feeding ``fc1`` is derived from ``input_size``
    rather than hard-coded, so the same architecture works for other square
    input sizes and channel counts. The defaults (1 channel, 28x28) reproduce
    the MNIST configuration exactly (``fc1`` takes 64*5*5 features).

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input images (1 for MNIST).
        input_size: side length of the square input images (28 for MNIST).

    Raises:
        ValueError: if ``input_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, num_classes=10, in_channels=1, input_size=28):
        super().__init__()
        # Registration order matches the original layout, so parameter init
        # consumes the RNG identically under a fixed seed.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        side = self._feature_side(input_size)
        self.fc1 = nn.Linear(64 * side * side, 128)
        self.drop = nn.Dropout(0.2)
        self.fc2 = nn.Linear(128, num_classes)

    @staticmethod
    def _feature_side(input_size):
        """Return the spatial side length after conv1/pool1/conv2/pool2."""
        side = input_size
        for _ in range(2):
            # 3x3 conv with no padding shrinks by 2; 2x2 max-pool halves (floor).
            side = (side - 2) // 2
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for two conv/pool stages")
        return side

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, num_classes)."""
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        return self.fc2(x)

# Network on the selected device, cross-entropy over raw logits, and Adam
# with light L2 weight decay applied to every network parameter.
model = SmallCNN()
model.to(DEVICE)  # nn.Module.to moves parameters in place
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-4)


In [None]:

def accuracy(loader, model, device=None):
    """Return the fraction of samples in ``loader`` that ``model`` labels correctly.

    The model is switched to eval mode (dropout off) and intentionally left
    there. Callers that continue training must call ``model.train()`` again.

    Args:
        loader: iterable of ``(inputs, labels)`` batches.
        model: module returning per-class scores along dim 1.
        device: where each batch is moved before the forward pass. Defaults to
            the device of the model's first parameter (CPU if it has none).

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb).argmax(dim=1)
            # Keep the running count as a device tensor: a single host sync at
            # the end instead of one .item() per batch.
            correct = correct + (pred == yb).sum()
            total += yb.size(0)
    if total == 0:
        raise ValueError("accuracy() received an empty loader; no samples to score")
    return int(correct) / total

EPOCHS = 8
for epoch in range(1, EPOCHS + 1):
    # One optimization pass over the shuffled training set, dropout enabled.
    model.train()
    for inputs, targets in train_loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

    # Drain queued Metal work before the evaluation passes.
    if DEVICE.type == "mps":
        torch.mps.synchronize()

    # Full-pass accuracy on both splits (accuracy() leaves the model in eval mode).
    train_acc, test_acc = accuracy(train_loader, model), accuracy(test_loader, model)
    print(f"Epoch {epoch:02d}/{EPOCHS} | train_acc={train_acc*100:.2f}% | test_acc={test_acc*100:.2f}%")


In [None]:

import math
@torch.no_grad()
def show_conv_kernels(conv, title="Conv kernels", up=48, in_channel=0):
    """Plot every output filter of a 2-D conv layer as a grayscale image.

    For each output channel, the kernel slice for input channel ``in_channel``
    is upsampled to ``up`` x ``up`` pixels with bicubic interpolation. The
    upsampling is for visibility only and smooths the raw kernel. Each image
    is drawn on a symmetric gray scale clipped at the 99th percentile of the
    absolute values of the shown weights.

    Args:
        conv: layer whose ``weight`` has shape (out_channels, in_channels, kH, kW),
            e.g. ``nn.Conv2d``.
        title: figure title.
        up: side length, in pixels, of each upsampled kernel image.
        in_channel: which input-channel slice to show for multi-input layers
            (default 0).

    Raises:
        ValueError: if ``in_channel`` is out of range for ``conv``.
    """
    weight = conv.weight.detach().cpu().float()
    out_ch, in_ch = weight.shape[:2]
    if not 0 <= in_channel < in_ch:
        raise ValueError(f"in_channel must be in [0, {in_ch}), got {in_channel}")

    # Keep a singleton channel dim so all kernels are upsampled in one batched call.
    kernels = weight[:, in_channel:in_channel + 1]  # (out_ch, 1, kH, kW)
    vmax = max(float(np.percentile(kernels.abs().numpy(), 99)), 1e-6)
    # Index [:, 0] (not .squeeze()) so up == 1 still yields 2-D images.
    upsampled = F.interpolate(kernels, size=(up, up), mode="bicubic", align_corners=False)[:, 0].numpy()

    cols = math.ceil(math.sqrt(out_ch))
    rows = math.ceil(out_ch / cols)
    # squeeze=False keeps `axes` 2-D even when there is a single output channel.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.2, rows * 1.2), dpi=160, squeeze=False)
    fig.suptitle(title)
    flat_axes = axes.ravel()
    for ax, img in zip(flat_axes, upsampled):
        ax.imshow(img, cmap="gray", vmin=-vmax, vmax=vmax, interpolation="nearest")
        ax.axis("off")
    # Hide the unused cells in the last row of the grid.
    for ax in flat_axes[out_ch:]:
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Visualize the learned filters of both conv layers, conv1 first.
for layer, caption in ((model.conv1, "Conv1 kernels"), (model.conv2, "Conv2 kernels")):
    show_conv_kernels(layer, caption)


In [None]:

# Quick prediction demo on a single test image.
# model.eval() is set explicitly so dropout is off even if the model was left
# in training mode (e.g. the training cell was interrupted mid-epoch).
i = 207
x, y = test_ds[i]
model.eval()
with torch.no_grad():
    pred = model(x.unsqueeze(0).to(DEVICE)).argmax(dim=1).item()
plt.figure(figsize=(2.5, 2.5), dpi=140)
plt.imshow(x.squeeze(0), cmap="gray")
plt.title(f"pred={pred}, true={y}")
plt.axis("off")
plt.show()
