In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
import random

# ----------------------------------
# 0. Reproducibility and device
# ----------------------------------

def set_seed(seed: int = 42) -> None:
    """Set random seeds for reproducible experiments."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_device() -> torch.device:
    """Selects MPS (Apple), CUDA, or CPU."""
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

set_seed(42)
DEVICE = get_device()
print("Using device for Set B:", DEVICE)

EPOCHS_B = 100
LR_B = 0.01
BATCH_SIZE_B = 128
NOISE_FRACTION = 0.25

# ----------------------------------
# 1. Complexity measure helpers
# ----------------------------------

def calculate_l2_norm(model: nn.Module) -> float:
    """Computes the Frobenius norm of all weight matrices."""
    l2_norm = 0.0
    for name, param in model.named_parameters():
        if "weight" in name:
            l2_norm += torch.sum(param.detach() ** 2)
    return torch.sqrt(l2_norm).item()

def calculate_spectral_norm(model: nn.Module) -> float:
    """Computes the sum of maximum singular values across weight matrices."""
    spectral_norm_sum = 0.0
    for name, param in model.named_parameters():
        if "weight" in name and param.dim() > 1:
            W = param
            try:
                if W.numel() > 0:
                    _, S, _ = torch.linalg.svd(W, full_matrices=False)
                    spectral_norm_sum += S[0].item()
            except Exception:
                continue
    return spectral_norm_sum

def calculate_sharpness(
    model: nn.Module,
    criterion: nn.Module,
    data_loader: DataLoader,
    rho: float = 0.01,
    device: torch.device = DEVICE,
) -> float:
    """
    Approximates sharpness using a single SAM-style perturbation step.
    S(w*) = (L(w* + Îµ) - L(w*)) / (1 + L(w*)).
    """
    model.eval()

    try:
        data_batch, target_batch = next(iter(data_loader))
    except StopIteration:
        return 0.0

    data_batch, target_batch = data_batch.to(device), target_batch.to(device)

    with torch.no_grad():
        outputs = model(data_batch)
        base_loss = criterion(outputs, target_batch).item()

    model.zero_grad()
    outputs = model(data_batch)
    loss = criterion(outputs, target_batch)
    loss.backward()

    grad_norm_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            grad_norm_sq += torch.sum(p.grad ** 2)
    grad_norm = torch.sqrt(grad_norm_sq)
    if grad_norm.item() == 0.0:
        return 0.0

    epsilon_map = {}
    for name, p in model.named_parameters():
        if p.grad is not None:
            eps = (p.grad / grad_norm) * rho
            p.data.add_(eps)
            epsilon_map[name] = eps

    with torch.no_grad():
        outputs_perturbed = model(data_batch)
        pert_loss = criterion(outputs_perturbed, target_batch).item()

    for name, p in model.named_parameters():
        if name in epsilon_map:
            p.data.sub_(epsilon_map[name])

    sharp = (pert_loss - base_loss) / (1.0 + base_loss)
    return max(0.0, sharp)

def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device = DEVICE,
) -> tuple[float, float]:
    """Computes average loss and classification error on a dataset."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            total_loss += loss.item() * data.size(0)

            _, preds = torch.max(outputs.data, 1)
            correct += (preds == target).sum().item()
            total += target.size(0)

    avg_loss = total_loss / total
    error = 1.0 - (correct / total)
    return avg_loss, error

# ----------------------------------
# 2. Model definition (Deep FFN)
# ----------------------------------

class DeepFFN(nn.Module):
    """Five-layer fully-connected network for the memorisation regime."""
    def __init__(
        self,
        input_dim: int = 784,
        hidden_dim: int = 512,
        num_hidden_layers: int = 5,
        num_classes: int = 10,
    ):
        super().__init__()
        layers = []

        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.ReLU())

        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())

        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(x.size(0), -1)
        return self.net(x)

# ----------------------------------
# 3. Data loading with label noise
# ----------------------------------

def load_mnist_with_label_noise(
    batch_size: int = BATCH_SIZE_B,
    noise_fraction: float = NOISE_FRACTION,
    seed: int = 42,
):
    """
    Loads the MNIST dataset and applies label corruption to a fraction of the
    training labels. The test set remains unchanged.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    full_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST("./data", train=False, download=True, transform=transform)

    train_dataset = full_train

    targets = train_dataset.targets.clone()
    n_train = len(train_dataset)
    n_noisy = int(noise_fraction * n_train)

    rng = torch.Generator().manual_seed(seed)
    noisy_indices = torch.randperm(n_train, generator=rng)[:n_noisy]

    for idx in noisy_indices:
        original_label = targets[idx].item()
        new_label = random.randint(0, 9)
        while new_label == original_label:
            new_label = random.randint(0, 9)
        targets[idx] = new_label

    train_dataset.targets = targets

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

# ----------------------------------
# 4. Training and experiment driver
# ----------------------------------

def train_and_evaluate_set_b(
    epochs: int = EPOCHS_B,
    lr: float = LR_B,
    batch_size: int = BATCH_SIZE_B,
) -> pd.DataFrame:
    """
    Trains a deep FFN on MNIST with 25% randomised labels and reports
    generalisation and complexity metrics.
    """
    print("\n--- Set B: MNIST with 25% Label Noise (Memorisation Regime) ---")
    print(f"Epochs: {epochs}, Batch size: {batch_size}, Device: {DEVICE}")

    train_loader, test_loader = load_mnist_with_label_noise(
        batch_size=batch_size,
        noise_fraction=NOISE_FRACTION,
        seed=42,
    )

    model = DeepFFN().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters (Set B model): {total_params}")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * data.size(0)

        scheduler.step()
        avg_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch+1}/{epochs}] - Train Loss: {avg_loss:.4f}")

    train_loss, train_error = evaluate_model(model, train_loader, criterion, DEVICE)
    test_loss, test_error = evaluate_model(model, test_loader, criterion, DEVICE)
    gen_gap = test_error - train_error

    l2 = calculate_l2_norm(model)
    spec = calculate_spectral_norm(model)
    sharp = calculate_sharpness(model, criterion, train_loader, device=DEVICE)

    result = {
        "id": "B_FFN_512x5",
        "params": total_params,
        "train_error": train_error,
        "test_error": test_error,
        "gen_gap": gen_gap,
        "l2_norm": l2,
        "spectral_norm": spec,
        "sharpness": sharp,
    }

    df_b = pd.DataFrame([result])
    out_name = "dissertation_results_set_b.csv"
    df_b.to_csv(out_name, index=False)
    print(f"\nSet B results saved to '{out_name}'")
    print(df_b)

    return df_b

if __name__ == "__main__":
    train_and_evaluate_set_b()


Using device for Set B: mps

--- Set B: MNIST with 25% Label Noise (Memorisation Regime) ---
Epochs: 100, Batch size: 128, Device: mps
Total trainable parameters (Set B model): 1457674
Epoch [1/100] - Train Loss: 1.8620
Epoch [2/100] - Train Loss: 1.3144
Epoch [3/100] - Train Loss: 1.2510
Epoch [4/100] - Train Loss: 1.2184
Epoch [5/100] - Train Loss: 1.1984
Epoch [6/100] - Train Loss: 1.1848
Epoch [7/100] - Train Loss: 1.1733
Epoch [8/100] - Train Loss: 1.1641
Epoch [9/100] - Train Loss: 1.1567
Epoch [10/100] - Train Loss: 1.1495
Epoch [11/100] - Train Loss: 1.1433
Epoch [12/100] - Train Loss: 1.1358
Epoch [13/100] - Train Loss: 1.1289
Epoch [14/100] - Train Loss: 1.1236
Epoch [15/100] - Train Loss: 1.1163
Epoch [16/100] - Train Loss: 1.1112
Epoch [17/100] - Train Loss: 1.1056
Epoch [18/100] - Train Loss: 1.0994
Epoch [19/100] - Train Loss: 1.0928
Epoch [20/100] - Train Loss: 1.0847
Epoch [21/100] - Train Loss: 1.0766
Epoch [22/100] - Train Loss: 1.0703
Epoch [23/100] - Train Loss: 1.0

  _, S, _ = torch.linalg.svd(W, full_matrices=False)



Set B results saved to 'dissertation_results_set_b.csv'
            id   params  train_error  test_error   gen_gap    l2_norm  \
0  B_FFN_512x5  1457674     0.017217      0.1612  0.143983  47.206608   

   spectral_norm  sharpness  
0      21.608106   0.022762  
