## MNIST Revisited

Let's now revisit MNIST. Because the data consists of two-dimensional images of handwritten digits, we can apply what we've learned about convolutions. In this section, we build a convolutional neural network (CNN or convnet) for this data set.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Subset
from torchvision import datasets, transforms
import numpy as np

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

In [None]:
# Load MNIST data
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

This time we are going to use a **validation set** to monitor our training progress. The same set can also be used for *hyperparameter tuning*. Remember, holding out a validation set lets us reserve the *test set* for a single, final estimate of how well the model should do in the real world; that is, the final model sees the test data only once.

In [None]:
# Use the first 10,000 samples of our training data as our validation set
val_indices = list(range(10000))
train_indices = list(range(10000, len(train_dataset)))

# Create subset datasets
val_dataset = Subset(train_dataset, val_indices)
partial_train_dataset = Subset(train_dataset, train_indices)

Note that in PyTorch, the data has the shape `(batch_size, channels, height, width)` when passed through the network; PyTorch uses the **channels first** convention by default. The `ToTensor()` transform converts each MNIST image to a float tensor of shape `(1, 28, 28)` with values scaled to the range [0, 1], and the `DataLoader` stacks these tensors along a new leading batch dimension.

In [None]:
# Create data loaders
batch_size = 256

train_loader = DataLoader(partial_train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
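
As a quick sanity check on the loaders above, the number of batches per epoch follows from the split sizes. This is a minimal sketch, assuming the standard MNIST sizes (60,000 training and 10,000 test images) and the `DataLoader` default of keeping the smaller last batch; `batches_per_epoch` is a hypothetical helper, not part of PyTorch.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    """Return (number of batches, size of the last batch) when no batch is dropped."""
    n_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (n_batches - 1) * batch_size
    return n_batches, last_batch

# Assumed sizes: 60,000 training images minus the 10,000 held out for validation
print(batches_per_epoch(50_000, 256))  # (196, 80)
# Validation and test sets: 10,000 images each
print(batches_per_epoch(10_000, 256))  # (40, 16)
```

Calling `len(train_loader)` should report the same batch count.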

In [None]:
# Let's check the shape of a batch
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
    break

We will now define our convolutional neural network by subclassing PyTorch's `nn.Module`. `nn.Conv2d` creates the convolutional layers we have been discussing in the lectures. `nn.Flatten` reshapes each sample's feature maps into a one-dimensional vector so that the output of the convolutional layers can be fed to the fully-connected layers. `nn.Linear` is PyTorch's equivalent of Keras's `Dense` layer.

In [None]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        # in_channels=1 because MNIST images are grayscale
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=3, stride=1, padding='same')   # 28x28 -> 28x28
        # padding='valid' means no padding, which matches the default padding=0
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32,
                               kernel_size=3, stride=2, padding='valid')  # 28x28 -> 13x13
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3, stride=1, padding='same')   # 13x13 -> 13x13
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=64,
                               kernel_size=3, stride=2, padding='valid')  # 13x13 -> 6x6
        
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 6 * 6, 128)  # 64 channels, 6x6 spatial dimensions
        self.fc2 = nn.Linear(128, 10)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # No softmax here - it's included in CrossEntropyLoss
        return x

model = CNN().to(device)
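
The `64 * 6 * 6` input size of `fc1` follows from tracing the spatial size through the four convolutions. Below is a minimal sketch of that arithmetic, using the standard output-size formula for a convolution without dilation. `conv_out` is a hypothetical helper, and `'same'` padding for a 3x3 kernel with stride 1 is written as one pixel of padding on each side.

```python
def conv_out(size, kernel, stride, padding):
    """Output length along one axis: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) for the four layers of the CNN above
layers = [
    ("conv1", 3, 1, 1),  # 'same' padding
    ("conv2", 3, 2, 0),  # 'valid' padding
    ("conv3", 3, 1, 1),  # 'same' padding
    ("conv4", 3, 2, 0),  # 'valid' padding
]

size = 28  # MNIST images are 28x28
for name, kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    print(f"{name}: {size}x{size}")  # 28x28, 13x13, 13x13, 6x6

print(64 * size * size)  # 2304 features enter fc1
```

If the kernel sizes, strides, or padding change, the first argument of `fc1` has to be updated to match.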

We are still tackling the same type of problem (multi-class classification), so the same loss and metrics will work for us here. PyTorch's `CrossEntropyLoss` combines a log-softmax with the negative log-likelihood loss, which is why the network outputs raw scores (logits) rather than probabilities. The optimizer `RMSprop` is the same one we used before and can be taken as a default method (or recipe) to try for updating the model parameters.

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters())  # PyTorch's default lr is 0.01; pass e.g. lr=0.001 to change it
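
To make the loss concrete, here is a minimal pure-Python sketch of what cross-entropy computes for a single sample: a log-softmax over the logits followed by the negative log-likelihood of the target class. The function `cross_entropy` is a hypothetical stand-in written for illustration; `nn.CrossEntropyLoss` applies the same calculation to tensors and, by default, averages it over the batch.

```python
import math

def cross_entropy(logits, target):
    """Loss for one sample: negative log-softmax value of the target class."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_sum_exp for z in logits]  # log-softmax
    return -log_probs[target]                      # negative log-likelihood

# Ten equal logits give a uniform prediction, so the loss equals log(10)
print(cross_entropy([0.0] * 10, target=3))
print(math.log(10))

# A large logit on the correct class gives a loss close to zero
print(cross_entropy([8.0, 0.0, 0.0], target=0))
```

A freshly initialized 10-class network usually starts with a loss in the neighbourhood of log(10), which is a handy reference point when reading the first epoch's output.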

We now train our model on the training data. For each batch, we compute the forward pass, calculate the loss, perform backpropagation, and update the parameters. After each *epoch* (one pass through all samples in the training data), we evaluate the model on the validation set. Note that the *validation data* is not used to train the model.

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
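        # Statistics: criterion returns the batch mean, so weight it by the batch size
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)  # index of the largest logit is the predicted class
        total += labels.size(0)
        correct += (predicted == labels).sum().item()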
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)  # index of the largest logit is the predicted class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc
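
Both functions multiply `loss.item()` by the batch size before accumulating. The reason is that the last batch of an epoch is usually smaller than the others, so a plain average of batch means would overweight it. The pure-Python sketch below illustrates this with made-up per-sample losses; the numbers are hypothetical and chosen only to make the difference visible.

```python
# Hypothetical per-sample losses, split into batches of unequal size
batches = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0]]  # the last batch is smaller

# What train_epoch and validate do: weight each batch mean by its size
running_loss = 0.0
total = 0
for batch in batches:
    batch_mean = sum(batch) / len(batch)  # what a mean-reduced criterion returns
    running_loss += batch_mean * len(batch)
    total += len(batch)
weighted = running_loss / total

# Naive alternative: average the batch means directly
naive = sum(sum(b) / len(b) for b in batches) / len(batches)

print(weighted)  # 1.6666666666666667, the true per-sample mean
print(naive)     # 2.0, overweights the small last batch
```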

In [None]:
# Training loop
num_epochs = 5
history = {
    'loss': [],
    'accuracy': [],
    'val_loss': [],
    'val_accuracy': []
}

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    # Store history
    history['loss'].append(train_loss)
    history['accuracy'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_accuracy'].append(val_acc)
    
    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
    print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

The values for the training loss and accuracy, as well as the validation loss and accuracy, are stored in the `history` dictionary. 

We will now use this information to plot how the loss and accuracy change as the number of epochs increases.

In [None]:
import matplotlib.pyplot as plt

loss_values = history['loss']
val_loss_values = history['val_loss']

epochs = range(1, len(loss_values) + 1)

# Code to plot the results
plt.plot(epochs, loss_values, 'b', label="Training Loss")
plt.plot(epochs, val_loss_values, 'r', label="Validation Loss")
plt.title("Training and Validation Loss")
plt.xlabel("Epochs")
plt.xticks(epochs)
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# As above, but this time we want to visualize the training and validation accuracy
acc_values = history['accuracy']
val_acc_values = history['val_accuracy']

plt.plot(epochs, acc_values, 'b', label="Training Accuracy")
plt.plot(epochs, val_acc_values, 'r', label="Validation Accuracy")
plt.title("Training and Validation Accuracy")
plt.xlabel("Epochs")
plt.xticks(epochs)
plt.ylabel("Accuracy")
plt.legend()
plt.show()
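
Besides plotting, the same `history` structure can be queried directly, for example to find the epoch with the lowest validation loss. A minimal sketch follows; `example_history` holds made-up numbers for illustration only and is not output from the model above.

```python
# Made-up values for illustration only; not results from the model above
example_history = {
    'val_loss': [0.30, 0.12, 0.09, 0.11, 0.14],
}

val_losses = example_history['val_loss']
# Index of the smallest validation loss, converted to a 1-based epoch number
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__) + 1
print(best_epoch)  # 3
```

A validation loss that starts rising while the training loss keeps falling is a common sign of overfitting, so the epoch found this way is a reasonable candidate for where to stop training.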

### Exercise 1

For the model we have created, calculate the number of parameters in each layer by hand and compare your results with the output of the model summary below.

In [None]:
# Print the layer structure (this does not show parameter counts)
print(model)

In [None]:
# !pip install torchinfo  # uncomment if torchinfo is not installed
from torchinfo import summary

summary(model, input_size=(1, 1, 28, 28))
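
For the hand calculation, the counting rules can be written down as two small functions. This is a sketch with hypothetical helper names (`conv2d_params`, `linear_params`), assuming square kernels and the default `bias=True`; the example layers are deliberately different from the model above so that the exercise is left to you.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Each output channel has one k x k filter per input channel, plus one bias."""
    return out_channels * (in_channels * kernel_size * kernel_size + 1)

def linear_params(in_features, out_features):
    """Each output unit has one weight per input feature, plus one bias."""
    return out_features * (in_features + 1)

# Hypothetical layers, not taken from the model above
print(conv2d_params(3, 16, 5))  # 16 * (3*5*5 + 1) = 1216
print(linear_params(100, 10))   # 10 * (100 + 1) = 1010
```

Note that stride and padding change the output size of a convolution but not its number of parameters, and that `nn.ReLU` and `nn.Flatten` have no parameters at all.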


### Exercise 2

Recreate a similar model, except this time reduce the height and width by using pooling layers (e.g., `nn.MaxPool2d`) instead of convolution layers with stride 2.
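
As a reminder of what a pooling layer computes, here is a minimal pure-Python sketch of 2x2 max pooling with stride 2 on a single small feature map. `max_pool_2x2` is a hypothetical helper for illustration; in the model you would use `nn.MaxPool2d(kernel_size=2)`, whose stride defaults to the kernel size.

```python
def max_pool_2x2(grid):
    """2x2 max pooling with stride 2 on a list of lists with even height and width."""
    return [
        [max(grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
         for c in range(0, len(grid[0]), 2)]
        for r in range(0, len(grid), 2)
    ]

grid = [
    [1, 2, 0, 1],
    [3, 4, 1, 0],
    [0, 1, 5, 6],
    [2, 0, 7, 8],
]
print(max_pool_2x2(grid))  # [[4, 1], [2, 8]]
```

Each 2x2 block is reduced to its largest value, so height and width are halved. Unlike a strided convolution, a pooling layer has no learnable parameters.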