In [81]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

In [82]:
%%capture

def load_data():
    # Data loading and preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    
    train_dataset = datasets.MNIST(
        root='./data', 
        train=True, 
        transform=transform,
        download=True)
    
    test_dataset = datasets.MNIST(
        root='./data', 
        train=False, 
        transform=transform,
        download=True)
    
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    return train_loader, test_loader

# Create DataLoaders for batch processing
train_loader, test_loader = load_data() 

In [83]:
class MultiheadMLP(nn.Module): 
    def __init__(self): 
        super().__init__()

        # Fully connected layers learning a shared representation for both tasks
        self.shared_encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # Classification head for the 10 digits
        self.digits_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
            # nn.Softmax() - This didn't help
        )

        # Classification head to determine parity
        self.parity_head = nn.Sequential(
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
            # nn.Softmax()
        )

    def forward(self, x):
        x = x.view(-1, 784)  # Flatten input
        
        # Shared base forward pass
        shared_output = self.shared_encoder(x)
        
        # Task-specific heads
        digits_output = self.digits_head(shared_output)
        parity_output = self.parity_head(shared_output)
        
        return digits_output, parity_output
        

In [94]:
def train(model, train_loader, a=1, b=1):
    model.train()  # Set the model to training mode
    num_epochs = 10
    loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    for epoch in range(num_epochs):  # Iterate over epochs
        total_epoch_loss = 0 
        
        for images, labels in train_loader:  # Iterate over batches
            # Create odd/even labels for the parity task
            parity_labels = (labels % 2).long()  # 0 for even, 1 for odd
            
            # Forward pass
            digits_output, parity_output = model(images)
            
            # Compute losses for each task
            digits_loss = loss(digits_output, labels)  # Loss for digit classification
            parity_loss = loss(parity_output, parity_labels)  # Loss for parity classification
            
            # Combine losses with equal weights (can adjust weights if necessary)
            total_loss = a*digits_loss + b*parity_loss
            
            # Backpropagation
            optimizer.zero_grad()  # Zero out gradients
            total_loss.backward()  # Compute gradients
            optimizer.step()  # Update model parameters
            
            # Accumulate the batch loss for tracking
            total_epoch_loss += total_loss.item()
        
        # Print epoch-level summary
        if (epoch) % 5 == 0: 
            print(f"Epoch {epoch}/{num_epochs}, "
                  f"Total Loss: {total_epoch_loss:.4f}")

In [85]:
def evaluate(model, loader):
    model.eval()
    digits_correct, parity_correct, total = 0, 0, 0
    
    with torch.no_grad():
        for images, labels in loader:
            # Generate odd/even labels from `labels`
            parity_labels = (labels % 2).long()  # 0 for even, 1 for odd

            # Forward pass
            digits_output, parity_output = model(images)

            digits_preds = torch.argmax(digits_output, dim=1)
            parity_preds = torch.argmax(parity_output, dim=1)
            
            digits_batch_correct = (digits_preds == labels).sum()
            parity_batch_correct = (parity_preds == parity_labels).sum()

            # Update counters
            digits_correct += digits_batch_correct.item()
            parity_correct += parity_batch_correct.item()
            total += labels.size(0)

    # Print the accuracy for each task
    print(f"Digit Classification Accuracy: {100 * digits_correct / total:.2f}%")
    print(f"Odd/Even Classification Accuracy: {100 * parity_correct / total:.2f}%")
    print("------------------------------")


In [86]:
def experiment_pipeline(train_loader):
    model = MultiheadMLP()

    # Equal weights
    train(model, train_loader)
    evaluate(model, train_loader)

    # Prioritize digits
    train(model, train_loader, b=0)
    evaluate(model, train_loader)

    # Prioritize parity 
    train(model, train_loader, a=0)
    evaluate(model, train_loader) 

    # Intermediate weights
    train(model, train_loader, a=0.7, b=0.3)
    evaluate(model, train_loader)

In [95]:
experiment_pipeline(train_loader)

Epoch 0/10, Total Loss: 616.3267
Epoch 5/10, Total Loss: 132.9142
Digit Classification Accuracy: 98.10%
Odd/Even Classification Accuracy: 98.99%
------------------------------
Epoch 0/10, Total Loss: 60.2312
Epoch 5/10, Total Loss: 39.1555
Digit Classification Accuracy: 99.01%
Odd/Even Classification Accuracy: 99.18%
------------------------------
Epoch 0/10, Total Loss: 24.4452
Epoch 5/10, Total Loss: 15.0014
Digit Classification Accuracy: 96.57%
Odd/Even Classification Accuracy: 99.83%
------------------------------
Epoch 0/10, Total Loss: 32.4823
Epoch 5/10, Total Loss: 17.2412
Digit Classification Accuracy: 99.36%
Odd/Even Classification Accuracy: 99.79%
------------------------------
