# MNIST Noise Sensitivity
Load aggregated results and visualize test accuracy vs sigma for each activation.

In [7]:
import pandas as pd
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Read the aggregated results table and append the standard error of the test
# accuracy: the per-row standard deviation divided by sqrt(number of repeats).
summary_path = 'results/results_summary.csv'
df = pd.read_csv(summary_path).assign(
    stderr_test_accuracy=lambda frame: frame['std_test_accuracy'] / (frame['repeats'] ** 0.5)
)
df

In [None]:
# One figure per model type. Inside each figure there is one panel per
# activation, plotting mean test accuracy against sigma with standard-error bars.
model_types = sorted(df['model_type'].unique())
for model_type in model_types:
    model_rows = df[df['model_type'] == model_type]
    activation_names = sorted(model_rows['activation'].unique())
    n_panels = len(activation_names)
    # squeeze=False always returns a 2-D array of axes, so a single-panel
    # figure is handled the same way as a multi-panel one.
    fig, axes_grid = plt.subplots(
        1, n_panels, figsize=(4 * n_panels, 3), sharey=True, squeeze=False
    )
    panels = axes_grid[0]
    for ax, act in zip(panels, activation_names):
        curve = model_rows[model_rows['activation'] == act].sort_values('sigma')
        ax.errorbar(
            curve['sigma'],
            curve['mean_test_accuracy'],
            yerr=curve['stderr_test_accuracy'],
            marker='o',
        )
        ax.set_title(f'{model_type} / {act}')
        ax.set_xlabel('sigma')
        ax.grid(True, alpha=0.3)
    # The y axis is shared, so only the left-most panel carries the label.
    panels[0].set_ylabel('mean test accuracy')
    fig.tight_layout()
    plt.show()

In [None]:
import torch
from torchvision import datasets, transforms

# Per-image pixel statistics over the MNIST training split, shown as a 2x2
# grid of histograms (mean, max, median, mode).
mnist = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
means = []
maxes = []
medians = []
modes = []
for x, _ in mnist:
    flat = x.flatten()
    means.append(flat.mean().item())
    maxes.append(flat.max().item())
    # Median and mode are taken over strictly positive pixels only, so the
    # zero-valued pixels of an image do not contribute to either statistic.
    nonzero = flat[flat > 0]
    if nonzero.numel() == 0:
        # No positive pixel at all: record 0.0 instead of calling
        # median/mode on an empty tensor.
        medians.append(0.0)
        modes.append(0.0)
    else:
        medians.append(nonzero.median().item())
        modes.append(torch.mode(nonzero).values.item())

# (values, colour, title, x-label) for each panel, in row-major order.
# The median and mode panels are labelled "nonzero" because that is what the
# loop above computes; calling them the image median/mode would misdescribe
# the plotted quantity.
hist_panels = [
    (means, 'steelblue', 'MNIST image mean (per image)', 'mean pixel value'),
    (maxes, 'salmon', 'MNIST image max (per image)', 'max pixel value'),
    (medians, 'seagreen', 'MNIST nonzero-pixel median (per image)', 'median of nonzero pixel values'),
    (modes, 'slateblue', 'MNIST nonzero-pixel mode (per image)', 'mode of nonzero pixel values'),
]

fig, axes = plt.subplots(2, 2, figsize=(10, 8))
for ax, (values, color, panel_title, x_label) in zip(axes.flat, hist_panels):
    ax.hist(values, bins=50, color=color, alpha=0.8)
    ax.set_title(panel_title)
    ax.set_xlabel(x_label)
    ax.set_ylabel('count')
fig.tight_layout()
plt.show()

In [None]:
import random

# Display one training image chosen uniformly at random, with its label in the title.
mnist = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
idx = random.randrange(len(mnist))
img, label = mnist[idx]

fig, ax = plt.subplots(figsize=(3, 3))
# squeeze(0) drops the leading axis of the image tensor so imshow gets a 2-D array.
ax.imshow(img.squeeze(0), cmap='gray')
ax.set_title(f'Random MNIST sample (label={label})')
ax.axis('off')
plt.show()

In [8]:
from matplotlib import animation
from IPython.display import HTML

# Animate the corruption of one MNIST "4": for each p swept from 0 to 1, every
# pixel is independently replaced by a uniform random value with probability p.
mnist = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
idx = next(position for position, (_, digit) in enumerate(mnist) if digit == 4)
img, label = mnist[idx]

# Draw a fresh seed for this trajectory, install it in both the `random` and
# torch global generators, and print it so the run can be reproduced.
seed = random.randrange(1, 1_000_000_000)
random.seed(seed)
torch.manual_seed(seed)
print(f'Animation seed: {seed}')


def _corrupt(image, prob):
    """Return `image` with each pixel replaced by uniform noise with probability `prob`.

    The mask is drawn before the replacement values, so the order in which the
    torch global generator is consumed is: mask, then replacement.
    """
    mask = torch.rand_like(image) < prob
    replacement = torch.rand_like(image)
    return torch.where(mask, replacement, image)


ps = torch.linspace(0.0, 1.0, 100)
# One independently corrupted 2-D frame per value of p.
frames = [_corrupt(img, p).squeeze(0).numpy() for p in ps]

fig, ax = plt.subplots(figsize=(3, 3))
im = ax.imshow(frames[0], cmap='gray', vmin=0, vmax=1)
title = ax.set_title(f'label={label}, p={ps[0].item():.2f}')
ax.axis('off')


def update(frame_idx: int):
    """Show frame `frame_idx` and refresh the title with the matching p value."""
    title.set_text(f'label={label}, p={ps[frame_idx].item():.2f}')
    im.set_data(frames[frame_idx])
    return im, title


anim = animation.FuncAnimation(fig, update, frames=len(frames), interval=50, blit=False)
# Close the static figure so only the HTML animation is rendered as cell output.
plt.close(fig)
HTML(anim.to_jshtml())

Animation seed: 616839595


In [10]:
from matplotlib import animation
from IPython.display import HTML

# Average the corruption animation over many independent trajectories. Each
# trajectory reseeds the torch global generator with its own seed, corrupts the
# same MNIST "4" at every p in [0, 1], and the per-p frames are then averaged.
mnist = datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor())
idx = next(i for i, (_, y) in enumerate(mnist) if y == 4)
img, label = mnist[idx]

num_trajectories = 1000
ps = torch.linspace(0.0, 1.0, 100)
# Accumulator with one slot per p value. Its per-frame shape and dtype are
# taken from the image rather than hard-coded, so it always matches `img`.
sum_frames = torch.zeros((len(ps), *img.shape), dtype=img.dtype)
# One seed per trajectory, drawn from the `random` module's global generator.
seeds = [random.randrange(1, 1_000_000_000) for _ in range(num_trajectories)]

with torch.no_grad():
    for seed in seeds:
        torch.manual_seed(seed)
        for i, p in enumerate(ps):
            # Each pixel is replaced by uniform noise with probability p.
            # The mask is drawn before the replacement values.
            mask = torch.rand_like(img) < p
            replacement = torch.rand_like(img)
            corrupted = torch.where(mask, replacement, img)
            sum_frames[i] += corrupted

# Mean over trajectories; squeeze(1) drops the size-1 axis that follows the
# p axis so every frame is a 2-D array for imshow.
avg_frames = (sum_frames / num_trajectories).squeeze(1).numpy()

fig, ax = plt.subplots(figsize=(3, 3))
im = ax.imshow(avg_frames[0], cmap='gray', vmin=0, vmax=1)
title = ax.set_title(f'label={label}, p={ps[0].item():.2f} (avg)')
ax.axis('off')


# Named `update_avg` so it does not redefine the `update` callback that the
# single-trajectory animation cell already defines.
def update_avg(frame_idx: int):
    """Show averaged frame `frame_idx` and refresh the title with the matching p value."""
    im.set_data(avg_frames[frame_idx])
    title.set_text(f'label={label}, p={ps[frame_idx].item():.2f} (avg)')
    return im, title


anim = animation.FuncAnimation(
    fig,
    update_avg,
    frames=len(avg_frames),
    interval=50,
    blit=False,
)
# Close the static figure so only the HTML animation is rendered as cell output.
plt.close(fig)
HTML(anim.to_jshtml())