# MNIST Handwritten Digits Classification (PyTorch CNN)

This notebook trains a Convolutional Neural Network (CNN) on the MNIST dataset to classify handwritten digits. It includes model architecture, training loop, evaluation, and visualization of sample predictions. The target is **>95% test accuracy**.

Run the cells in order. If you have a GPU available, PyTorch will use it automatically.

## 1. Imports and device setup

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import random
print('PyTorch version:', torch.__version__)

# Reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)


## 2. Prepare MNIST dataset and DataLoaders

In [None]:

# Transform: normalize to mean=0.1307, std=0.3081 (standard for MNIST)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download datasets (will store in ./data)
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST('./data', train=False, download=True, transform=transform)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

print('Train samples:', len(train_dataset))
print('Test samples:', len(test_dataset))


## 3. Define the CNN model

In [None]:

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # conv layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # 28x28 -> 28x28
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # 28x28 -> 28x28
        self.pool = nn.MaxPool2d(2, 2)                           # 28x28 -> 14x14 after pool
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # fully connected
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)               # shape: [batch, 64, 14, 14] after second conv+pool
        x = self.dropout1(x)
        x = x.view(x.size(0), -1)      # flatten
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

model = SimpleCNN().to(device)
print(model)


## 4. Training and evaluation helper functions

In [None]:

def train_one_epoch(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * data.size(0)
        preds = outputs.argmax(dim=1)
        correct += preds.eq(target).sum().item()
        total += data.size(0)
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    print(f"Train Epoch: {epoch} \tLoss: {epoch_loss:.4f} \tAccuracy: {epoch_acc:.4f}")
    return epoch_loss, epoch_acc

def evaluate(model, device, data_loader, criterion=None):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            if criterion is not None:
                running_loss += criterion(outputs, target).item() * data.size(0)
            preds = outputs.argmax(dim=1)
            all_preds.append(preds.cpu().numpy())
            all_targets.append(target.cpu().numpy())
            correct += preds.eq(target).sum().item()
            total += data.size(0)
    loss = (running_loss / total) if (criterion is not None) else None
    acc = correct / total
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    return loss, acc, all_preds, all_targets


## 5. Train the model

We'll train for up to 10 epochs with early stopping by monitoring test accuracy. Use Adam optimizer and CrossEntropyLoss.

In [None]:

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10

best_acc = 0.0
patience = 3
epochs_no_improve = 0

train_history = {'loss': [], 'acc': []}
test_history = {'loss': [], 'acc': []}

for epoch in range(1, num_epochs+1):
    train_loss, train_acc = train_one_epoch(model, device, train_loader, optimizer, criterion, epoch)
    test_loss, test_acc, _, _ = evaluate(model, device, test_loader, criterion)
    print(f"Test  Epoch: {epoch} \tLoss: {test_loss:.4f} \tAccuracy: {test_acc:.4f}\n")
    
    train_history['loss'].append(train_loss)
    train_history['acc'].append(train_acc)
    test_history['loss'].append(test_loss)
    test_history['acc'].append(test_acc)
    
    # Early stopping / save best
    if test_acc > best_acc:
        best_acc = test_acc
        epochs_no_improve = 0
        # Save best model state dict
        torch.save(model.state_dict(), 'best_mnist_cnn.pth')
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping after {epoch} epochs. Best test acc: {best_acc:.4f}")
            break

print('Training finished. Best test accuracy:', best_acc)


## 6. Load best model and evaluate on test set

In [None]:

# Load best model (if saved)
best_model = SimpleCNN().to(device)
try:
    best_model.load_state_dict(torch.load('best_mnist_cnn.pth', map_location=device))
    print('Loaded best model weights.')
except Exception as e:
    print('Could not load saved model, using current model. Error:', e)

test_loss, test_acc, test_preds, test_targets = evaluate(best_model, device, test_loader, criterion)
print(f"Final Test Accuracy: {test_acc:.4f}")


## 7. Visualize predictions on 5 random test images

In [None]:

# Pick 5 random samples from test set and show predictions
import matplotlib.pyplot as plt
idxs = np.random.choice(len(test_dataset), size=5, replace=False)
fig, axes = plt.subplots(1, 5, figsize=(12,3))
for ax, idx in zip(axes, idxs):
    img, label = test_dataset[idx]
    # img is normalized tensor; unnormalize for display
    img_disp = img.squeeze().numpy() * 0.3081 + 0.1307
    ax.imshow(img_disp, cmap='gray')
    ax.axis('off')
    # model prediction
    best_model.eval()
    with torch.no_grad():
        input_tensor = img.unsqueeze(0).to(device)
        out = best_model(input_tensor)
        pred = out.argmax(dim=1).item()
    ax.set_title(f"true: {label}\npred: {pred}")
plt.show()


## 8. Plot training curves

In [None]:

plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.plot(train_history['loss'], label='train loss')
plt.plot(test_history['loss'], label='test loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss curves')

plt.subplot(1,2,2)
plt.plot(train_history['acc'], label='train acc')
plt.plot(test_history['acc'], label='test acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy curves')

plt.tight_layout()
plt.show()


----

**Notes:**
- Achieving >95% test accuracy on MNIST is straightforward with this model; if accuracy is low, increase `num_epochs`, tweak model size, or add learning rate scheduling.
- The notebook saves the best model to `best_mnist_cnn.pth`.
