# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import random

# --- Data Augmentation and Width Normalization ---

class WidthNormalize(object):
    """Resize an image so that its width equals ``target_width``.

    The image is expected to expose a PIL-style ``size`` tuple of
    ``(width, height)`` and a ``resize`` method. Resampling is bicubic.

    Args:
        target_width: Width, in pixels, of the returned image.
        preserve_aspect_ratio: When true, the height is rescaled by the same
            factor as the width (truncated to an integer). When false, the
            height is left unchanged and only the width is stretched or
            squeezed.
    """

    def __init__(self, target_width, preserve_aspect_ratio=False):
        self.preserve_aspect_ratio = preserve_aspect_ratio
        self.target_width = target_width

    def __call__(self, img):
        src_width, src_height = img.size
        out_height = src_height
        if self.preserve_aspect_ratio:
            # Same factor for both axes; int() truncates toward zero.
            out_height = int(src_height * (self.target_width / src_width))
        return img.resize((self.target_width, out_height), Image.BICUBIC)

class RandomDistort(object):
    """Randomly translate, then scale, then rotate an image.

    With probability ``p`` the three distortions are applied in that order,
    each as its own torchvision functional call; otherwise the input is
    returned untouched. The output has the same size as the input.

    Args:
        max_translation: Largest shift in pixels. The horizontal and vertical
            offsets are drawn independently as integers from
            ``[-max_translation, max_translation]``, so this must be an int.
        max_scaling: Upper bound of the zoom factor. The factor is drawn
            uniformly between ``1 / max_scaling`` and ``max_scaling``.
        max_rotation: Largest rotation in degrees. The angle is drawn
            uniformly from ``[-max_rotation, max_rotation]``.
        p: Probability of distorting a given image.
    """

    def __init__(self, max_translation=3, max_scaling=1.15, max_rotation=15, p=0.5):
        self.max_translation = max_translation
        self.max_scaling = max_scaling
        self.max_rotation = max_rotation
        self.p = p

    def __call__(self, img):
        # Leave the image alone with probability (1 - p).
        if random.random() > self.p:
            return img

        # NOTE: the image size is deliberately not read here. None of the
        # operations below need it, and unpacking ``img.size`` would reject
        # tensor inputs (where ``size`` is a method) for no benefit.
        tx = random.randint(-self.max_translation, self.max_translation)
        ty = random.randint(-self.max_translation, self.max_translation)
        img = transforms.functional.affine(img, angle=0, translate=(tx, ty), scale=1, shear=0)

        scale = random.uniform(1 / self.max_scaling, self.max_scaling)
        img = transforms.functional.affine(img, angle=0, translate=(0, 0), scale=scale, shear=0)

        angle = random.uniform(-self.max_rotation, self.max_rotation)
        return transforms.functional.rotate(img, angle)

class PadToSize(object):
    """Pad an image with black (fill value 0) up to ``target_size``.

    Args:
        target_size: Indexable ``(width, height)`` of the padded result.

    The missing pixels along each axis are split between the two sides. When
    the amount is odd, the right/bottom side receives the extra pixel.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    @staticmethod
    def _split(total):
        """Split ``total`` padding into (leading, trailing) amounts."""
        leading = total // 2
        return leading, total - leading

    def __call__(self, img):
        current_w, current_h = img.size
        left, right = self._split(self.target_size[0] - current_w)
        top, bottom = self._split(self.target_size[1] - current_h)

        # torchvision expects the padding as (left, top, right, bottom).
        return transforms.functional.pad(
            img,
            (left, top, right, bottom),
            padding_mode='constant',
            fill=0,
        )


def get_mnist_datasets(root='./data', width_targets=None, train_transform=None, test_transform=None):
    """Build MNIST train/test datasets, optionally with width-normalized copies.

    Every dataset's pipeline ends with the same two stages: padding to 29x29,
    then conversion to a tensor normalized with the MNIST mean and std. A
    caller-supplied ``train_transform`` / ``test_transform`` runs before the
    padding.

    Args:
        root: Directory where MNIST is stored (downloaded there if missing).
        width_targets: Optional iterable of target widths. For each width an
            extra train and test dataset is created whose images are first
            passed through ``WidthNormalize(width)``.
        train_transform: Optional transform applied to training images before
            padding (e.g. augmentation).
        test_transform: Optional transform applied to test images before
            padding.

    Returns:
        ``(train_dataset, test_dataset)`` when ``width_targets`` is ``None``.
        Otherwise ``(train_datasets, test_datasets)``, two lists whose first
        element is the un-normalized dataset, followed by one dataset per
        entry of ``width_targets`` in the same order.
    """
    # Shared tail of every pipeline: tensor conversion + MNIST mean/std.
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def _with_padding(custom):
        """Return custom (if any) -> pad to 29x29 -> normalized tensor."""
        stages = [] if custom is None else [custom]
        stages.append(PadToSize((29, 29)))
        stages.append(to_normalized_tensor)
        return transforms.Compose(stages)

    train_transform = _with_padding(train_transform)
    test_transform = _with_padding(test_transform)

    # The un-normalized datasets also trigger the download if needed.
    plain_train = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=train_transform)
    plain_test = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=test_transform)

    if width_targets is None:
        return plain_train, plain_test

    train_datasets = [plain_train]
    test_datasets = [plain_test]
    for target_width in width_targets:
        # Width normalization runs first, then the regular pipeline.
        train_pipeline = transforms.Compose([WidthNormalize(target_width), train_transform])
        test_pipeline = transforms.Compose([WidthNormalize(target_width), test_transform])
        train_datasets.append(
            torchvision.datasets.MNIST(root=root, train=True, download=False, transform=train_pipeline)
        )
        test_datasets.append(
            torchvision.datasets.MNIST(root=root, train=False, download=False, transform=test_pipeline)
        )
    return train_datasets, test_datasets


# --- Model Definition (Single Column DNN) ---

class MNIST_DNN(nn.Module):
    """Single-column convolutional network for 29x29 single-channel digits.

    Layout: conv(1->20, k=4) -> ReLU -> maxpool(2) -> conv(20->40, k=5) ->
    ReLU -> maxpool(3) -> fc(360->150) -> ReLU -> fc(150->num_classes).
    The network returns raw logits; no softmax is applied.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional layers (no padding, so each conv shrinks the map).
        self.conv1 = nn.Conv2d(1, 20, kernel_size=4, stride=1, padding=0)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(20, 40, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=3)

        # Feature-map size for a 29x29 input:
        # 29 -> conv1(k=4) -> 26 -> pool1(2) -> 13 -> conv2(k=5) -> 9 -> pool2(3) -> 3
        self._feature_shape = (40, 3, 3)
        self.fc_input_size = 40 * 3 * 3

        # Fully connected layers
        self.fc1 = nn.Linear(self.fc_input_size, 150)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(150, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image tensor whose last three dimensions are ``(1, 29, 29)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If the convolutional stack does not produce a
                40x3x3 feature map, i.e. the input is not 29x29.
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))

        # Flattening with a free batch dimension would silently regroup
        # samples for a wrongly sized input (e.g. nine 28x28 images would
        # become four rows), so check the per-sample shape first.
        if tuple(x.shape[-3:]) != self._feature_shape:
            raise ValueError(
                f"expected a {self._feature_shape} feature map per sample "
                f"(29x29 input), got {tuple(x.shape[-3:])}"
            )
        x = x.reshape(-1, self.fc_input_size)  # Flatten
        x = self.relu3(self.fc1(x))
        return self.fc2(x)

# --- Training Function ---

def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device='cpu', scheduler=None):
    """Train ``model`` in place and return it.

    Args:
        model: Module to train; it is moved to ``device`` and put in
            training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches. It is
            iterated once per epoch and does not need to support ``len()``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the model's parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.
        scheduler: Optional learning-rate scheduler. When given, its
            ``step()`` is called once at the end of every epoch.

    Returns:
        The same ``model`` object, after training.
    """
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        if scheduler is not None:
            scheduler.step()

        # Batches are counted here rather than via len(train_loader), so an
        # empty or length-less loader reports nan instead of crashing.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}')
    print('Finished Training')
    return model

# --- Evaluation Function ---

def evaluate_model(model, test_loader, device='cpu'):
    """Measure and print top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        model: Module to evaluate; it is moved to ``device`` and switched to
            evaluation mode.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the model and every batch are moved to.

    Returns:
        Accuracy as a percentage in ``[0, 100]``.
    """
    model.to(device)
    model.eval()

    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            # Index of the largest logit per row is the predicted class.
            predicted = logits.max(dim=1).indices

            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(f'Accuracy on the test set: {accuracy:.2f}%')
    return accuracy

# --- MCDNN Prediction (Averaging) ---

def predict_mcdnn(models, test_loaders, device):
    """Predict with the multi-column ensemble by averaging class probabilities.

    Each model is run on its own loader, so every column sees test images
    preprocessed the same way as its training data. The loaders must iterate
    the same samples in the same order (no shuffling).

    Args:
        models: Sequence of trained models (the columns).
        test_loaders: Sequence of loaders yielding ``(inputs, labels)``
            batches, one per model and in the same order as ``models``.
        device: Device the models and inputs are moved to.

    Returns:
        ``(predicted_classes, labels)``: two 1-D CPU tensors with one entry
        per test sample.

    Raises:
        ValueError: If no models are given, if the number of loaders differs
            from the number of models, or if the loaders do not yield
            identical label sequences.
    """
    if len(models) == 0:
        raise ValueError("predict_mcdnn needs at least one model")
    if len(models) != len(test_loaders):
        raise ValueError(
            f"got {len(models)} models but {len(test_loaders)} test loaders"
        )

    per_model_probs = []  # one [num_samples, num_classes] tensor per model
    reference_labels = None

    for model, loader in zip(models, test_loaders):
        model.to(device)
        model.eval()
        batch_probs = []
        batch_labels = []
        with torch.no_grad():
            for inputs, labels in loader:
                outputs = model(inputs.to(device))
                # The models emit logits; averaging needs probabilities.
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                batch_probs.append(probabilities.cpu())  # keep results off the GPU
                # Labels must be collected per batch, not once per model.
                batch_labels.append(labels.cpu())
        per_model_probs.append(torch.cat(batch_probs))

        column_labels = torch.cat(batch_labels)
        if reference_labels is None:
            reference_labels = column_labels
        elif not torch.equal(reference_labels, column_labels):
            raise ValueError("test loaders are not aligned: label sequences differ")

    # [num_models, num_samples, num_classes] -> mean over models.
    averaged_predictions = torch.stack(per_model_probs).mean(dim=0)
    predicted_classes = averaged_predictions.argmax(dim=1)
    return predicted_classes, reference_labels


# --- Main Script ---

if __name__ == '__main__':
    # Hyperparameters
    batch_size = 64
    learning_rate = 0.001
    num_epochs = 10
    width_targets = [10, 14, 18]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Data loading: index 0 is the un-normalized dataset, followed by one
    # dataset per entry of width_targets.
    train_transforms = RandomDistort()
    train_datasets, test_datasets = get_mnist_datasets(width_targets=width_targets, train_transform=train_transforms)

    # test_datasets[i] carries the same width normalization as
    # train_datasets[i]. The loaders are unshuffled, so sample order is
    # identical across columns.
    test_loaders = [
        DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        for test_dataset in test_datasets
    ]

    # Train one column per dataset.
    trained_models = []
    for i, (train_dataset, test_dataloader) in enumerate(zip(train_datasets, test_loaders)):
        print(f"Training model for dataset {i + 1}...")
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        model = MNIST_DNN().to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        # NOTE(review): this scheduler is constructed but never stepped. The
        # training call below does not receive it, so the learning rate stays
        # at `learning_rate` for the whole run.
        lambda_annealing = lambda epoch: 0.993**epoch
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_annealing)

        model = train_model(model, train_dataloader, criterion, optimizer, num_epochs, device)
        evaluate_model(model, test_dataloader, device)
        trained_models.append(model)

    # Ensemble prediction: every column scores its own preprocessed view of
    # the test set, and the probabilities are averaged.
    mcdnn_predictions, final_test_labels = predict_mcdnn(trained_models, test_loaders, device)

    # Calculate accuracy
    correct = (mcdnn_predictions == final_test_labels).sum().item()
    total = final_test_labels.size(0)
    mcdnn_accuracy = 100 * correct / total
    print(f'MCDNN Accuracy on the original test set: {mcdnn_accuracy:.2f}%')