# Applying a CNN to classify handwritten digits in the MNIST Digits dataset

Quickstart guide:

1. Create a new virtual environment:
    ```bash
    python3 -m venv venv
    ```
2. Activate the virtual environment (macOS/Linux):
    ```bash
    source venv/bin/activate
    ```
3. Install required packages:
    ```bash
    pip install -r requirements.txt
    ```
4. Create a [wandb](https://wandb.ai) account (if you do not have one yet) and log in to wandb, e.g. by running `wandb login` in the activated environment. For details, see the [official wandb documentation](https://docs.wandb.ai).

Ready to start!

In [None]:
import numpy as np
import struct
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch
import wandb
from torch.utils.data import TensorDataset, DataLoader
import plotly.express as px
from sklearn.metrics import confusion_matrix
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score
import pandas as pd

## Data preprocessing

Data preprocessing involves the following steps:
1. Load data from idx-ubyte files into numpy arrays.
2. Check the balance of the data.
3. Visualize the first five digits from the training dataset.
4. Normalize and standardize all pixel values.
5. Concatenate and re-split all data into 60% train, 20% validation and 20% test sets.

In [None]:
# Map IDX data-type codes to numpy dtypes; IDX stores multi-byte values big-endian.
_IDX_DTYPES = {
    0x08: np.dtype('>u1'),  # unsigned byte
    0x09: np.dtype('>i1'),  # signed byte
    0x0B: np.dtype('>i2'),  # short
    0x0C: np.dtype('>i4'),  # int
    0x0D: np.dtype('>f4'),  # float
    0x0E: np.dtype('>f8'),  # double
}


def read_idx(filename):
    """Read an IDX file (e.g. the MNIST idx-ubyte files) into a numpy array.

    The header is a 4-byte magic number (two zero bytes, a data-type code and
    the number of dimensions), followed by one big-endian uint32 size per
    dimension. The rest of the file is the data itself.

    Args:
        filename: path to the IDX file.

    Returns:
        numpy array with the shape stored in the header. Unsigned-byte files
        (such as MNIST) come back as a read-only uint8 view of the file
        contents. Multi-byte types are converted to native byte order.

    Raises:
        ValueError: if the magic number or data-type code is invalid, or the
            payload size does not match the header.
    """
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError(f"{filename}: not an IDX file (bad magic number)")
        try:
            dtype = _IDX_DTYPES[data_type]
        except KeyError:
            raise ValueError(
                f"{filename}: unsupported IDX data type 0x{data_type:02x}"
            ) from None
        # One big-endian uint32 per dimension, read in a single call
        shape = struct.unpack(f'>{dims}I', f.read(4 * dims))
        data = np.frombuffer(f.read(), dtype=dtype)
    if dtype.itemsize > 1:
        # Convert big-endian payloads to native order for downstream math
        data = data.astype(dtype.newbyteorder('='))
    return data.reshape(shape)


In [None]:
# Locations of the raw MNIST IDX files, relative to the notebook
training_images_filepath = 'mnist_digits_data/train-images.idx3-ubyte'
training_labels_filepath = 'mnist_digits_data/train-labels.idx1-ubyte'
test_images_filepath = 'mnist_digits_data/t10k-images.idx3-ubyte'
test_labels_filepath = 'mnist_digits_data/t10k-labels.idx1-ubyte'

# Parse each file (in the order listed) into a numpy array
train_images, train_labels, test_images, test_labels = (
    read_idx(path)
    for path in (
        training_images_filepath,
        training_labels_filepath,
        test_images_filepath,
        test_labels_filepath,
    )
)

In [None]:
"""
Create and print talble showing label distribution
Optional: Uncomment code below to show different data distribution for training, test and all labels
"""
all_labels = train_labels
# all_labels = test_labels
# all_labels = np.concatenate((train_labels, test_labels))

labels, counts = np.unique(all_labels, return_counts=True)
percentages = counts / counts.sum() * 100
df = pd.DataFrame({'Digit Label': labels, 'Frequency': counts, 'Percentage': percentages})
print(df.to_string(index=False, float_format='%.2f'))

In [None]:
"""
Visualize first 5 images of the training set
"""
fig, axes = plt.subplots(1, 5, figsize=(10, 2))  
for i, ax in enumerate(axes):
    ax.imshow(train_images[i], cmap='gray')  
    ax.set_title(f'Label: {train_labels[i]}')  
    ax.axis('off')  

plt.tight_layout()  
plt.show()

In [None]:
"""
Calculate mean and standard deviation of the normalized pixel values
"""
images = np.concatenate((train_images, test_images), axis=0) 

# Calculate mean and standard deviation on normalized pixel values
images = images.astype(np.float32) / 255
mean = np.mean(images)
std = np.std(images)

In [None]:
"""
Perform normalization and standarization of all pixel values
"""
# Normalize pixel values to [0, 1] 
train_images = torch.tensor(train_images).unsqueeze(1).float() / 255
test_images = torch.tensor(test_images).unsqueeze(1).float() / 255

# Without standardization, the sweep takes longer and the accuracy is lower
train_images = (train_images - mean) / std
test_images = (test_images - mean) / std


In [None]:
"""
Create TensorDataset 
Concatenate and re-split all data into 60% train, 20% validation and 20% test sets
"""
train_dataset = TensorDataset(train_images, torch.tensor(train_labels).long())
test_dataset = TensorDataset(test_images, torch.tensor(test_labels).long())
full_dataset = torch.utils.data.ConcatDataset([train_dataset, test_dataset])

# Split the dataset into train, validation, and test sets
total_size = len(full_dataset)
train_size = int(0.6 * total_size)
val_size = int(0.2 * total_size)
test_size = total_size - train_size - val_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size, test_size])

## Initialize model architecture and define hyperparameters

This section implements the following steps:
1. Initialize the `sweep_config` for the Weights & Biases sweep with selected hyperparameter combinations.
2. Construct the CNN architecture.
3. Set the random seed and initialize the global variables that track the best model.

In [None]:
"""
Create TensorDataset 
Concatenate and re-split all data into 60% train, 20% validation and 20% test sets
"""
# Set fixed number of epochs due to computational limitations
n_epochs = 5

# Configuration of the wandb sweep
sweep_config = {
    'method': 'grid',  # Search method: Grid search
    'metric': {
      'name': 'accuracy',  
      'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            'values': [0.0001, 0.001, 0.01] # Possible values for learning rate
        },
        'batch_size': {
            'values': [16, 32, 64, 128] # Possible values for batch size
        }
    },
}
sweep_id = wandb.sweep(sweep_config, project="mnist-digits-cnn") # Initialize the sweep

In [None]:
"""
Define CNN architecture in Pytorch
"""
class simple_cnn(nn.Module):
    def __init__(self):
        super(simple_cnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5) # Default padding = 0, stride = 1
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d() # Default dropout value = 50% (0.5)
        # Fully connected layers
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2)) 
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training) # Dropout is applied to output of first fully connected layer
        x = self.fc2(x)
        return F.log_softmax(x) # Softmax to create probability distribution, required for nll_loss (not required for CrossEntropyLoss)

In [None]:
"""
Initialise network according to previously defined architecture and global variables
"""
random_seed = 1
torch.backends.cudnn.enabled = False # No GPU available
torch.manual_seed(random_seed)

global best_accuracy
global best_batch_size
global best_learning_rate

best_accuracy = 0.0  # Initialize with minimum starting value
best_batch_size = None  
best_learning_rate = None  

## Define and execute train and validation cycle 

This section implements the following steps:
1. Define train and validation functions
2. Create train_and_validate wrapper for wandb run
3. Execute wandb run with wandb agent to tune hyperparameters

In [None]:
"""
Define the training function with logging of the training loss in wandb
"""
def train(network, optimizer, loader):
    network.train()
    train_loss = 0
    for batch_idx, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        # Accumulate the loss
        train_loss += loss.item()
        
        # Log the training loss every 100 batches
        if (batch_idx + 1) % 100 == 0:
            avg_train_loss = train_loss / 100
            wandb.log({"train_loss": avg_train_loss})
            train_loss = 0  # Reset the train loss for the next 100 batches

In [None]:
"""
Define the validation function with logging of the validation loss and accuracy in wandb
"""
def validation(epoch, network, loader):
    network.eval()
    validation_loss = 0
    correct = 0
    total_batches = 0
    total_samples = 0  

    # Training the model
    with torch.no_grad():
        for data, target in loader:
            output = network(data)
            validation_loss += F.nll_loss(output, target, reduction='mean').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
            total_batches += 1  
            total_samples += target.size(0)

    # Calculate average validation loss across all batches
    validation_loss /= total_batches

    # Calculate accuracy percentage
    accuracy = 100. * correct / total_samples
    accuracy = float(accuracy)  

    # Print for monitoring
    print("Validation Loss: ", validation_loss)
    print("Accuracy: ", accuracy)

    # Logging in wandb
    wandb.log({"validation_loss": validation_loss, "epoch": epoch, "accuracy": accuracy})

    # Save the current best model to best_model.pth for final evaluation on test set
    global best_accuracy, best_batch_size, best_learning_rate
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(network.state_dict(), 'best_model.pth')  # Save the best model based on accuracy
        best_batch_size = wandb.config.batch_size  # Save current batch size from the sweep config
        best_learning_rate = wandb.config.learning_rate  # Save current learning rate from the sweep config
        print(f"Saved new best model with accuracy {accuracy}%")


In [None]:
"""
Define train_and_validate wrapper for wandb run 
"""
def train_and_validate():
    # Initialize a new W&B run
    with wandb.init() as run:
        config = run.config

        # Setup model and optimizer
        network = simple_cnn()  
        optimizer = torch.optim.Adam(network.parameters(), lr=config.learning_rate)

        # Prepare DataLoaders
        train_loader = DataLoader(train_dataset, batch_size=int(config.batch_size), shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=1000, shuffle=False)  
        validation(0, network, val_loader)

        # Training loop (5 Epochs)
        for epoch in range(1, n_epochs + 1):
            print('Epoch:', epoch)
            train(network, optimizer, train_loader)
            validation(epoch, network, val_loader)

        wandb.finish()

In [None]:
"""
Execute run to perform the sweep
"""
wandb.agent(sweep_id, train_and_validate)

## Final testing and evaluation

This section implements the following steps:
1. Evaluate best model on test set
2. Compare the best model against a dummy classifier
3. Plot confusion matrix

In [None]:
"""
Create DataLoader for test data, evaluate the best model and print the final accuracy
"""
all_preds = []
all_targets = []

# Create DataLoader for test data
test_loader = DataLoader(
    test_dataset,
    batch_size=1000,
    shuffle=False 
)

def test():
    network = simple_cnn()
    network.load_state_dict(torch.load('best_model.pth'))
    network.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # Sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max log-probability (argmax) 
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_preds.extend(pred.squeeze().tolist())
            all_targets.extend(target.tolist())

    test_loss /= len(test_loader.dataset)
    accuracy = 100.0 * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')

test()

wandb.finish()


In [None]:
"""
Implement DummyClassifier from sklearn and evaluate on test set 
"""
# Implement DummyClassifier
dummy_clf = DummyClassifier(strategy="most_frequent")
dummy_clf.fit([[0]] * len(all_targets), all_targets)  # dummy_clf.fit to get the most frequent class
dummy_preds = dummy_clf.predict([[0]] * len(all_targets))  # Predict the most frequent class

# Calculate and print the accuracy of the dummy classifier
dummy_accuracy = accuracy_score(all_targets, dummy_preds)
print(f'Dummy classifier accuracy: {dummy_accuracy * 100:.2f}%')

In [None]:
"""
Optional: print the previously saved optimal hyperparameters
"""
print(best_batch_size)
print(best_learning_rate)

In [None]:
"""
Plot and show confusion matrix
"""
cm = confusion_matrix(all_targets, all_preds)
fig = px.imshow(cm,
                labels=dict(x="Predicted Labels", y="True Labels", color="Count"),
                x=list(range(10)),
                y=list(range(10)),
                color_continuous_scale='Blues', 
                text_auto=True)  

fig.update_layout(
    title_text='Confusion Matrix',  
    title_x=0.5,  
    width=620,  
    height=600,  
    xaxis=dict(
        tickmode='linear', tick0=0, dtick=1,
        showgrid=True,  
        tickfont=dict(size=15),  
        title_font=dict(size=16)  
    ),
    yaxis=dict(
        tickmode='linear', tick0=0, dtick=1,
        showgrid=True,  
        tickfont=dict(size=15),  
        title_font=dict(size=16)  
    ),
    title_font=dict(size=24),  
    title_pad=dict(b=10),  
    margin=dict(l=10, r=10, t=40, b=10)  
)

fig.update_xaxes(
    showline=True, linewidth=2, linecolor='black'
)
fig.update_yaxes(
    showline=True, linewidth=2, linecolor='black'
)

fig.show()