# Best in Class?

This playbook builds an MNIST digit classifier that identifies images not seen in the training dataset as accurately as possible.

This means that we will make use of several different techniques along the way.

1. We will first train an Auto-Encoder to extract key features from our dataset in a positionally independent format.
2. The encoder from this will be used to give us extracted key features from the input data in a positionally independent format.
3. We will then train a combined CNN/Transformer model.
4. We will evaluate model performance after training on unseen test data.
5. Once there is a stable model, we will fine-tune the encoder for the classification task.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Hyperparameters shared by the training cells below
batch_size = 64  # samples per mini-batch for both the train and test DataLoaders
learning_rate = 0.001  # learning rate handed to every Adam optimizer in this notebook
num_epochs = 10  # epochs for the frozen-encoder classifier training loop

## Load the mnist dataset

This time we will use a modified (randomly augmented) version of both the MNIST training and test datasets.

This means that we should get a model which is able to extract positionally invariant input features and then work out the class of the image based on these features.

In [None]:
# Device configuration: prefer Apple-silicon MPS, then CUDA, and finally fall
# back to the CPU so the notebook runs on any machine instead of assuming MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the MNIST training set with ToTensor only (no normalisation); it is used
# purely to measure the global pixel statistics.
mnist_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
# A single batch holding the whole dataset lets the statistics be computed in one pass
data_loader = DataLoader(mnist_dataset, batch_size=len(mnist_dataset), shuffle=False)
images, _ = next(iter(data_loader))  # Shape: (60000, 1, 28, 28)
# Global mean and standard deviation over every pixel of every training image
mean = images.mean().item()
std = images.std().item()

# Training transform: light random augmentation, then tensor conversion and
# normalisation with the statistics measured above.
# NOTE(review): torchvision documents RandomResizedCrop's `scale` as bounds on
# the crop *area* relative to the original image, so the ranges used here are
# area ranges rather than side-length ranges - confirm this is what is intended.
transform = transforms.Compose([
    transforms.RandomRotation(degrees=10, expand=False),       # Random rotations up to 10deg
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Random shifts (up to 10% of image size)
    transforms.RandomResizedCrop(size=28, scale=(0.9, 1.1)),   # Random scaling (90% to 110% of original size)
    transforms.ColorJitter(brightness=0.1, contrast=0.1),      # Adjust brightness & contrast slightly
    transforms.ToTensor(),  # Converts to [0, 1]
    transforms.Normalize((mean,), (std,))
])

# Augmented MNIST training set
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Test transform: the same pipeline with stronger perturbations, so the test
# images differ from anything seen in training. Because these transforms are
# random, evaluation results will vary slightly from run to run.
transform_test = transforms.Compose([
    transforms.RandomRotation(degrees=20, expand=False),         # Random rotations up to 20deg
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15)),  # Random shifts (up to 15% of image size)
    transforms.RandomResizedCrop(size=28, scale=(0.95, 1.15)),   # Random scaling (95% to 115% of original size)
    transforms.ColorJitter(brightness=0.2, contrast=0.2),        # Adjust brightness & contrast slightly by a factor of 0.2
    transforms.ToTensor(),  # Converts to [0, 1]
    transforms.Normalize((mean,), (std,))
])

# Augmented MNIST test set
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform_test)

# Shuffle only the training data; keep the test order fixed
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

## First Build and train a simple AutoEncoder

This AutoEncoder will encode images into a latent-space and then decode them back into 'input-like' images. 

In [None]:
# Define the Autoencoder
class Autoencoder(nn.Module):
    """Convolutional auto-encoder for 1x28x28 images with a 32-dim latent space.

    The encoder maps an image batch of shape [batch, 1, 28, 28] to latent
    vectors of shape [batch, 32]; the decoder mirrors it back to
    [batch, 1, 28, 28] with values squashed into [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Encoder: Conv Layers + FC Layers merged into a single Sequential
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),  # Output: [batch, 32, 14, 14]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # Output: [batch, 64, 7, 7]
            nn.ReLU(),
            nn.Flatten(),  # Flatten before FC layers -> [batch, 64 * 7 * 7]
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32)  # Latent space
        )

        # Decoder: FC Layers + Transposed Conv Layers merged into a single Sequential
        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64 * 7 * 7),  # Must match the Unflatten shape below
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),  # Reshape before transposed conv
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> [batch, 32, 14, 14]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> [batch, 1, 28, 28]
            nn.Sigmoid()  # Output in range [0,1] for images
        )

    def forward(self, x):
        """Encode `x` to the latent space and return its reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

### We will train our auto-encoder for 5 epochs.

This model has to be 'good enough', but doesn't need to be extensively trained. The better it is, the better it will behave later, but for now 5 epochs is good enough.

In [None]:
# Pre-train the Autoencoder with a pixel-wise reconstruction (MSE) loss
autoencoder = Autoencoder().to(device)
criterion_ae = nn.MSELoss()
optimizer_ae = optim.Adam(autoencoder.parameters(), lr=learning_rate)

ae_epochs = 5  # Pre-train for fewer epochs; the auto-encoder only has to be 'good enough'

for epoch in range(ae_epochs):
    autoencoder.train()
    total_loss = 0

    # Per-batch (labels are ignored: the input image is its own target)
    for images, _ in train_loader:
        images = images.to(device)

        # Forward, Loss, Backwards, Step
        optimizer_ae.zero_grad()
        outputs = autoencoder(images)
        # NOTE(review): the inputs are normalised by the data transform, so the
        # targets are not confined to the decoder's sigmoid range of [0, 1];
        # the reconstruction loss therefore cannot reach zero.
        loss = criterion_ae(outputs, images)
        loss.backward()
        optimizer_ae.step()

        # Track loss, is our model improving?
        total_loss += loss.item()
    # Print the mean batch loss per-epoch to show we're doing something
    print(f'Autoencoder Epoch [{epoch+1}/{ae_epochs}], Loss: {total_loss/len(train_loader):.4f}')

## Build Our Positionally-Invariant Transformer-based Classifier

In [None]:
# Lock the pre-trained encoder weights while the Transformer head is trained.
# Freezing is optional, but it keeps the early classifier training more stable.
for encoder_weight in autoencoder.encoder.parameters():
    encoder_weight.requires_grad = False

In [None]:
# Define the Transformer Classifier
class TransformerClassifier(nn.Module):
    """Classifier that runs a pre-trained encoder, a Transformer stack and an MLP head.

    Args:
        encoder: module mapping an image batch to 32-dim feature vectors
            (the code below hard-codes the embedding size to 32).
        num_heads: number of attention heads in each Transformer layer.
        num_classes: number of output logits.

    Attributes:
        frozen_encoder: when True (the default) the encoder is run under
            `torch.no_grad()`, so no gradients flow into it. Set it to False
            to fine-tune the encoder together with the rest of the model.
    """

    def __init__(self, encoder, num_heads=4, num_classes=10):
        super().__init__()

        # We know the output from the encoder is a 32-dim LS
        # Capture these for now
        self.encoder = encoder
        self.embedding_dim = 32
        self.frozen_encoder = True

        # Transformer Encoder Layer
        self.transformer_layer = nn.TransformerEncoderLayer(
            # The model takes the encoded input and effectively 'learns' relationships between them
            d_model=self.embedding_dim, nhead=num_heads, dim_feedforward=128, dropout=0.1,
            # forward() feeds tensors shaped (batch, seq=1, dim). Without
            # batch_first=True the layer would read the batch axis as the
            # sequence axis and let samples in a batch attend to each other.
            batch_first=True,
        )
        # This class wraps our individual transformer into a full 'model' with multiple layers
        self.transformer_encoder = nn.TransformerEncoder(self.transformer_layer, num_layers=5)

        # Classification Head
        # These pieces encode down from the Transformer back to an image class
        self.fc1 = nn.Linear(self.embedding_dim, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes) for the batch `x`."""

        # We want to use a pre-trained positionally-invariant feature encoder
        if self.frozen_encoder:
            with torch.no_grad():
                x = self.encoder(x)  # Use pre-trained encoder
        else:
            x = self.encoder(x)

        # Adding a sequence dimension (sequence length = 1)
        x = x.unsqueeze(1)
        x = self.transformer_encoder(x)
        x = x.squeeze(1)  # Remove sequence dimension

        # 'Project down' to a decision; logits are returned un-normalised
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Instantiate the Transformer-based classifier on top of the pre-trained encoder
model = TransformerClassifier(autoencoder.encoder).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    # Per-batch
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forwards, Loss, Backwards, Step
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Track Loss
        total_loss += loss.item()

        # Print some extra data as this can be slow to train.
        # CrossEntropyLoss already averages over the batch, so the value is
        # reported as-is rather than divided by the batch size a second time.
        if i % 100 == 0:
            print(f'Average loss for this Batch: {loss.item():.4f}')

    # Mean of the per-batch losses over the whole epoch
    print(f'Transformer Classifier Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')

In [None]:
# Evaluation on the test set
# The test transform is random, so the reported accuracy varies a little between runs.
model.eval()  # Disable dropout for evaluation
correct = 0
total = 0
with torch.no_grad():  # No gradients are needed when only measuring accuracy
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        # The predicted class is the index of the largest logit
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy on the MNIST test set: {accuracy:.2f}%')

## Push the model to its limits

Now we have a 'stable' model, let's let the model further guide 'fine-tuning' the encoder.

In [None]:
# Release the encoder so it can be fine-tuned together with the Transformer:
# stop wrapping it in no_grad inside forward() and make its weights trainable again.
model.frozen_encoder = False
for encoder_weight in model.encoder.parameters():
    encoder_weight.requires_grad = True

In [None]:
# Re-create the loss and optimizer for fine-tuning the existing model.
# The encoder weights are trainable now, so they are updated alongside the rest.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

finetune_epochs = 5  # Short fine-tuning run on top of the already-trained classifier

# Training loop
for epoch in range(finetune_epochs):
    model.train()
    total_loss = 0

    # Per-batch
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forwards, Loss, Backwards, Step
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Track Loss
        total_loss += loss.item()

        # Print some extra data as this can be slow to train.
        # CrossEntropyLoss already averages over the batch, so the value is
        # reported as-is rather than divided by the batch size a second time.
        if i % 100 == 0:
            print(f'Average loss for this Batch: {loss.item():.4f}')

    # Mean of the per-batch losses over the whole epoch
    print(f'Transformer Classifier Epoch [{epoch+1}/{finetune_epochs}], Loss: {total_loss/len(train_loader):.4f}')

In [None]:
# Evaluation on the test set after fine-tuning
# The test transform is random, so the reported accuracy varies a little between runs.
model.eval()  # Disable dropout for evaluation
correct = 0
total = 0
with torch.no_grad():  # No gradients are needed when only measuring accuracy
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # The predicted class is the index of the largest logit
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy on the MNIST test set: {accuracy:.2f}%')

## Outputs

You should find that this model is capable of ~90% accuracy when evaluated over the noisy version of the MNIST test dataset.
This shows that the model is better at classifying images not present in the original dataset, and is therefore more general.

You should find that once the model 'un-freezes' the encoder component of the classifier and trains that in combination with the transformer layer it should improve in terms of accuracy.