In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

In [2]:
%%capture

def _flatten_and_scale(images) -> torch.Tensor:
    """Flatten a stack of 28x28 images into rows and rescale pixels to [0, 1].

    Args:
        images: The ``.data`` tensor of a torchvision MNIST dataset; raw pixel
            values in 0..255 (presumably uint8 — confirm against torchvision).

    Returns:
        A ``torch.float32`` tensor with one row per image and one column per
        pixel (28 * 28 = 784 columns).
    """
    # Reshape and divide in NumPy (float64) before casting to float32, matching
    # the original preprocessing bit-for-bit.
    flat = images.numpy().reshape(-1, 784) / 255.0
    return torch.tensor(flat, dtype=torch.float32)


def load_data(return_labels: bool = False):
    """Download MNIST (if needed) and return its images as flattened float tensors.

    Pixels are normalized to [0, 1] (for Sigmoid & PCA later on).

    Args:
        return_labels: If True, also return each split's digit labels, taken
            from ``dataset.targets``. Defaults to False, which keeps the
            two-value return that existing callers unpack.

    Returns:
        ``(x_train, x_test)``: float32 tensors of shape (n_train, 784) and
        (n_test, 784). When ``return_labels`` is True, returns
        ``(x_train, x_test, y_train, y_test)`` instead.
    """
    # Constructing the datasets downloads the files into ./data on first use.
    # The raw `.data` tensors are read directly below, so no per-sample
    # transform is attached (a transform would only run on dataset indexing).
    train_dataset = datasets.MNIST(root='./data', train=True, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True)

    x_train = _flatten_and_scale(train_dataset.data)
    x_test = _flatten_and_scale(test_dataset.data)

    if return_labels:
        return x_train, x_test, train_dataset.targets, test_dataset.targets
    return x_train, x_test

X_train, _ = load_data()
X_train = X_train

# Create DataLoaders for batch processing
train_loader = DataLoader(TensorDataset(X_train), batch_size=64, shuffle=True)

In [3]:
class MultiheadMLP(nn.Module):
    """Multi-task MLP: a shared encoder feeding a digit head and a parity head.

    Both heads return raw, unnormalized logits. ``nn.CrossEntropyLoss`` (used to
    train this model) applies log-softmax internally, so a softmax layer inside
    the heads would normalize the outputs twice and flatten the gradients.
    Call ``torch.softmax(logits, dim=1)`` at inference time if probabilities
    are needed; ``argmax`` predictions are the same either way.

    Input: float tensor of shape (batch, 28 * 28), e.g. flattened MNIST images.
    Output: ``(digits_logits, parity_logits)`` with shapes (batch, 10) and
    (batch, 2).
    """

    def __init__(self):
        super().__init__()

        # Fully connected layers learning a shared representation for both tasks.
        self.shared_encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )

        # Classification head for the 10 digits (outputs logits).
        self.digits_head = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

        # Classification head for parity: 2 logits (class-index meaning, e.g.
        # 0 = even / 1 = odd, is set by the training targets — confirm there).
        self.parity_head = nn.Sequential(
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        )

    def forward(self, x):
        """Run the shared encoder once and feed its output to both heads.

        Args:
            x: Float tensor of shape (batch, 784).

        Returns:
            Tuple ``(digits_logits, parity_logits)``.
        """
        shared_output = self.shared_encoder(x)

        # Task-specific heads share the same encoded features.
        digits_output = self.digits_head(shared_output)
        parity_output = self.parity_head(shared_output)

        return digits_output, parity_output
        

In [4]:
# Seed torch's global RNG so weight initialization — and the shuffle order of
# train_loader, which draws its seed from the same global generator when it is
# iterated — is reproducible on Restart & Run All.
SEED = 42
LEARNING_RATE = 1e-3
torch.manual_seed(SEED)

model = MultiheadMLP()
# Loss *function* (criterion) for both heads; it expects raw logits and integer
# class targets. Keep this name bound to the criterion: rebinding it to a loss
# value in a training loop (e.g. `loss = loss(out, y)`) breaks later iterations.
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)