# 🤖 Intro to Machine Learning Week 3

By Ellie Zhou, August 2025
Build a CNN to classify handwritten digits (0-9) using PyTorch!

**What we'll do:**
- Load MNIST dataset
- Build CNN architecture  
- Train and evaluate
- Visualize results

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

# Seed the RNG and pick the compute device
torch.manual_seed(42)  # For reproducible results
if torch.cuda.is_available():
    device = torch.device('cuda')  # Use the GPU when one is visible
else:
    device = torch.device('cpu')   # Otherwise fall back to the CPU
print(f"Device: {device}")

## Load Data

In [None]:
# Preprocessing pipeline applied to every image before it reaches the CNN
transform = transforms.Compose([
    # PIL image -> tensor (0-1 range)
    transforms.ToTensor(),
    # Standardize with the MNIST mean & std
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# MNIST: 28x28 grayscale images of handwritten digits
# (60,000 training images, 10,000 test images)
train_dataset = MNIST('./data', train=True, transform=transform, download=True)
test_dataset = MNIST('./data', train=False, transform=transform, download=True)

# Batch the data; only the training set is reshuffled each epoch
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

print(f"Training samples: {len(train_dataset):,}")
print(f"Test samples: {len(test_dataset):,}")

# Preview the first 8 training images in a single row
plt.figure(figsize=(12, 3))
for column in range(1, 9):
    image, label = train_dataset[column - 1]   # (image tensor, digit label)
    plt.subplot(1, 8, column)
    # .squeeze() drops the channel dimension so imshow gets a 2D array
    plt.imshow(image.squeeze(), cmap='gray')
    plt.title(f'{label}')                      # True digit as the panel title
    plt.axis('off')                            # No ticks/frames
plt.show()

## Build CNN

**Architecture:** Conv → Pool → Conv → Pool → Conv → FC → Dropout → FC

In [None]:
class CNN(nn.Module):
    """Three-conv-layer CNN mapping 1-channel 28x28 images to 10 class scores.

    Layer order: Conv → ReLU → Pool → Conv → ReLU → Pool → Conv → ReLU →
    FC → ReLU → Dropout → FC. The output is the raw result of the last
    linear layer (no softmax is applied here).
    """

    # (channels, height, width) that conv3 must produce; fc1 is sized for it
    _FEATURE_SHAPE = (64, 3, 3)

    def __init__(self):
        super().__init__()
        # Convolutional layers - extract features from images
        self.conv1 = nn.Conv2d(1, 32, 3)     # 28x28x1 → 26x26x32
        self.conv2 = nn.Conv2d(32, 64, 3)    # 13x13x32 → 11x11x64
        self.conv3 = nn.Conv2d(64, 64, 3)    # 5x5x64 → 3x3x64
        self.pool = nn.MaxPool2d(2, 2)       # Reduces spatial dimensions by half

        # Fully connected layers - for final classification
        self.fc1 = nn.Linear(64 * 3 * 3, 128)  # Flattened conv output (576) → 128 neurons
        self.fc2 = nn.Linear(128, 10)          # 128 → 10 classes (digits 0-9)
        self.dropout = nn.Dropout(0.5)         # Prevent overfitting

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Image tensor whose conv features come out as 64x3x3
                (e.g. a batch of 1x28x28 images).

        Returns:
            Tensor with one row of 10 scores per input image.

        Raises:
            ValueError: If the features after conv3 are not 64x3x3.
        """
        # Conv blocks: convolution → ReLU → pooling
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))

        # Guard the flatten: view(-1, 576) on its own would silently regroup
        # elements across the batch for other input sizes (nine 32x32 images
        # would come out as sixteen rows) instead of failing.
        feature_shape = tuple(x.shape[-3:])
        if feature_shape != self._FEATURE_SHAPE:
            raise ValueError(
                f"CNN expects inputs that reduce to {self._FEATURE_SHAPE} conv "
                f"features (e.g. 1x28x28 images), got {feature_shape}"
            )

        # Flatten and fully connected layers
        x = x.view(-1, 64 * 3 * 3)  # Flatten 2D feature maps to 1D
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Instantiate the network on the chosen device, then report its size and layout
model = CNN().to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(model)  # Module repr: one line per layer

## Train Model

In [None]:
# Loss and optimizer shared by the training and validation functions
# Cross-entropy: loss function for multi-class classification
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters, learning rate 0.001
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training function - one full pass over the training data
def train_epoch():
    """Train `model` for one epoch over `train_loader`.

    Returns:
        tuple: (mean of the per-batch losses, accuracy in percent).
    """
    model.train()  # Training mode (enables dropout)
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass (gradients from the previous step are cleared first)
        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)

        # Backward pass and weight update
        batch_loss.backward()
        optimizer.step()

        # Running statistics for this epoch
        loss_sum += batch_loss.item()
        predicted = torch.max(logits, 1).indices  # Highest-scoring class per row
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    mean_loss = loss_sum / len(train_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

# Validation function - score the model without updating weights
def validate():
    """Evaluate `model` on every batch of `test_loader` with gradients disabled.

    Returns:
        tuple: (mean of the per-batch losses, accuracy in percent).
    """
    model.eval()  # Evaluation mode (disables dropout)
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():  # No gradient bookkeeping during evaluation
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            logits = model(images)
            loss_sum += criterion(logits, labels).item()
            predicted = torch.max(logits, 1).indices  # Highest-scoring class per row
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    mean_loss = loss_sum / len(test_loader)
    accuracy = 100. * n_correct / n_seen
    return mean_loss, accuracy

# Run 8 epochs, recording train/validation accuracy after each one
train_accs = []
val_accs = []

for epoch in range(8):
    train_loss, train_acc = train_epoch()   # One pass of weight updates
    val_loss, val_acc = validate()          # Evaluation only, no updates

    print(f'Epoch {epoch+1}: Train {train_acc:.1f}% | Val {val_acc:.1f}%')

    # Keep the history for the accuracy plot
    train_accs.append(train_acc)
    val_accs.append(val_acc)

print(f'\nFinal Test Accuracy: {val_accs[-1]:.2f}%')

## Results

In [None]:
# Gather predictions and true labels for the whole test set
model.eval()
y_pred, y_true = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        batch_pred = torch.max(model(images), 1).indices  # Highest-scoring class per row
        y_pred.extend(batch_pred.cpu().numpy())
        y_true.extend(labels.cpu().numpy())

cm = confusion_matrix(y_true, y_pred)

# Left panel: accuracy per epoch for training and validation
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(train_accs, 'b-', label='Training')
plt.plot(val_accs, 'r-', label='Validation')
plt.title('Training Progress')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

# Right panel: confusion matrix - shows which digits get confused
plt.subplot(1, 2, 2)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.tight_layout()
plt.show()

print('Classification Report:')
print(classification_report(y_true, y_pred))

## Sample Predictions

In [None]:
# Plot 12 test images with the predicted digit and its softmax probability
model.eval()
plt.figure(figsize=(12, 8))

with torch.no_grad():
    images, labels = next(iter(test_loader))  # First batch of the test loader
    images = images.to(device)
    labels = labels.to(device)
    logits = model(images)
    prob = F.softmax(logits, dim=1)           # Scores -> probabilities per row
    pred = torch.max(logits, 1).indices       # Highest-scoring class per row

    # 3x4 grid of the first 12 images in the batch
    for idx in range(12):
        plt.subplot(3, 4, idx + 1)
        plt.imshow(images[idx].cpu().squeeze(), cmap='gray')

        true_label = labels[idx].cpu().item()
        pred_label = pred[idx].cpu().item()
        confidence = prob[idx][pred_label].cpu().item()  # Probability of the predicted class

        # Title color: green when the prediction matches, red otherwise
        if true_label == pred_label:
            color = 'green'
        else:
            color = 'red'
        plt.title(f'True: {true_label}, Pred: {pred_label}\n({confidence:.3f})', color=color)
        plt.axis('off')

plt.suptitle('Sample Predictions (Green=Correct, Red=Wrong)')
plt.tight_layout()
plt.show()

## Learned Features

In [None]:
# Plot the first-layer (conv1) filters, then save the weights and print a summary
filters = model.conv1.weight.detach().cpu().numpy()

plt.figure(figsize=(12, 3))
for i in range(16):  # First 16 of the 32 conv1 filters, laid out 2x8
    plt.subplot(2, 8, i+1)
    plt.imshow(filters[i][0], cmap='viridis')  # Single input channel -> one 3x3 kernel
    plt.title(f'Filter {i+1}', fontsize=8)
    plt.axis('off')
plt.suptitle('Learned Filters (First Layer)')
plt.tight_layout()
plt.show()

# Write the model's state_dict to disk
torch.save(model.state_dict(), 'mnist_cnn.pth')
print('✅ Model saved as mnist_cnn.pth')

# Closing summary: last recorded validation accuracy and parameter count
print(f'\n🎉 Training Complete!')
print(f'Final Accuracy: {val_accs[-1]:.2f}%')
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')