In [6]:
import torch
import pandas as pd
from models.mnist_model import MNISTModel
from utils.data_loader import get_mnist_data_loader

# Setup
# Evaluate on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# get_mnist_data_loader returns three values; only the third is kept, as the
# test loader. The first two are discarded (presumably the train/validation
# loaders -- TODO confirm against utils.data_loader).
_, _, test_loader = get_mnist_data_loader(batch_size=64)

# Define model info
# Two-level mapping: outer key is the label shown in the results table's
# "model" column, inner key is the training condition, value is the path of
# the saved state dict for that (model, condition) pair.
# Every entry exposes the same four conditions: "default", "data1%", "data5%"
# and "noise". Judging by the file names these are the full-data run, runs on
# 1% / 5% of the data, and a run with 40% noise -- TODO confirm against the
# training scripts.
model_checkpoints = {
    "Cross Entropy": {
        "default": "checkpoints/mnist_model_cross_entropy.pth",
        "data1%": "checkpoints/mnist_model_cross_entropy_data1%.pth",
        "data5%": "checkpoints/mnist_model_cross_entropy_data5%.pth",
        "noise":  "checkpoints/mnist_model_cross_entropy_noisy40%.pth"
    },
    "Margin": {
        "default": "checkpoints/mnist_model_margin.pth",
        "data1%": "checkpoints/mnist_model_margin_data1%.pth",
        "data5%": "checkpoints/mnist_model_margin_data5%.pth",
        "noise":  "checkpoints/mnist_model_margin_noisy40%.pth"
    },
    "MultiLayerMargin": {
        "default": "checkpoints/mnist_model_multi_layer_margin.pth",
        "data1%": "checkpoints/mnist_model_multi_layer_margin_data1%.pth",
        "data5%": "checkpoints/mnist_model_multi_layer_margin_data5%.pth",
        "noise":  "checkpoints/mnist_model_multi_layer_margin_noisy40%.pth"
    },
    # Unlike the other three groups, these file names carry an "_fp16" suffix
    # (presumably half-precision training/checkpoints -- TODO confirm).
    "TrueMultiLayerMargin": {
        "default": "checkpoints/mnist_model_true_multi_layer_margin_fp16.pth",
        "data1%": "checkpoints/mnist_model_true_multi_layer_margin_data1%_fp16.pth",
        "data5%": "checkpoints/mnist_model_true_multi_layer_margin_data5%_fp16.pth",
        "noise":  "checkpoints/mnist_model_true_multi_layer_margin_noisy40%_fp16.pth"
    },
}

# Accuracy computation function
def compute_test_accuracy(model, dataloader, device) -> float:
    """Return the top-1 accuracy of ``model`` over ``dataloader`` as a percentage.

    The model is switched to evaluation mode and is left in that mode when the
    function returns. Gradients are disabled for the whole pass.

    Args:
        model: Callable module exposing ``eval()``; ``model(x)`` must return a
            tensor whose ``argmax(dim=1)`` is the predicted label per sample
            (assumes a ``(batch, num_classes)`` score tensor -- TODO confirm).
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass; the
            model is expected to already live on it.

    Returns:
        Accuracy in the range ``[0.0, 100.0]``.

    Raises:
        ValueError: If ``dataloader`` yields no samples, since accuracy is
            undefined in that case (previously a bare ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            # Highest-scoring class along dim 1 is taken as the prediction.
            preds = outputs.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError("cannot compute accuracy: dataloader yielded no samples")
    return 100.0 * correct / total

# Evaluate every checkpoint and tabulate the accuracies: one row per model
# label, one formatted accuracy column per training condition.
def _load_model(checkpoint_path):
    """Build a fresh MNISTModel on `device` and load the given state dict into it."""
    net = MNISTModel().to(device)
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(state)
    return net


column_order = ["model", "default test acc", "data1% test acc", "data5% test acc", "noise test acc"]
results = []

for model_name, paths in model_checkpoints.items():
    row = {"model": model_name}
    for condition, path in paths.items():
        # A new network per checkpoint, so no weights carry over between runs.
        model = _load_model(path)
        acc = compute_test_accuracy(model, test_loader, device)
        row[condition + " test acc"] = f"{acc:.2f}%"
    results.append(row)

df = pd.DataFrame(results)
df = df[column_order]  # enforce column order

# NOTE(review): the recorded output of this cell shows roughly 10% accuracy
# for three of the TrueMultiLayerMargin checkpoints, which looks like chance
# level for MNIST -- worth confirming those checkpoints are valid.
display(df)


Unnamed: 0,model,default test acc,data1% test acc,data5% test acc,noise test acc
0,Cross Entropy,97.92%,93.81%,97.00%,90.87%
1,Margin,98.59%,91.61%,96.05%,92.66%
2,MultiLayerMargin,96.82%,91.10%,95.43%,87.01%
3,TrueMultiLayerMargin,97.84%,9.82%,10.32%,9.58%
