# Publication Figures (MNIST/CIFAR-10)
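This notebook reads the per-run CSV results in `/kaggle/working/results`. It renders ablation-style error-bar plots (mean ± std across seeds) for each optimizer and saves them to `/kaggle/working/plots`. It also displays the statistical comparison tables when they are present.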

In [None]:
# Imports
import glob
import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Directory holding the per-run result CSVs. Defaults to the Kaggle working dir;
# set the RESULTS_DIR environment variable to run this notebook elsewhere.
results_dir = Path(os.environ.get('RESULTS_DIR', '/kaggle/working/results'))
results_dir.mkdir(exist_ok=True, parents=True)
print('Using results dir:', results_dir)

In [None]:
# Helper: aggregate final metrics per optimizer pattern
def collect(pattern, metric_col, results_root=None):
    """Aggregate the final value of ``metric_col`` per optimizer across seeds.

    Each file matching ``pattern`` is treated as one run.

    - The optimizer name is the part of the file name between the pattern's
      literal prefix (the text before the first ``*``) and ``_lr``. This keeps
      e.g. ``NN_SimpleMLP_MNIST_adam_lr...`` -> ``adam`` rather than
      ``MNIST_adam``, and preserves underscores in names like ``adam_w``.
      If the prefix cannot be used, the legacy regex ``NN_.*?_(\\w+)_lr`` is
      the fallback.
    - The seed is taken from ``seed<digits>``.
    - A run's final metric is the last non-missing value of ``metric_col``.

    Parameters
    ----------
    pattern : str
        Glob pattern relative to the results directory,
        e.g. ``'NN_SimpleMLP_MNIST_*_publication.csv'``.
    metric_col : str
        CSV column to aggregate, e.g. ``'test_acc'``.
    results_root : str or Path, optional
        Directory to search. Defaults to the notebook-level ``results_dir``.

    Returns
    -------
    pandas.DataFrame
        Columns ``Optimizer``, ``n`` (number of seeds), ``<metric_col>_mean``
        and ``<metric_col>_std``, sorted by optimizer name.

        - The std is the sample std (ddof=1), or NaN when n < 2.
        - The frame is empty (but keeps these columns) when nothing matches.
    """
    root = Path(results_root) if results_root is not None else results_dir
    mean_col, std_col = f'{metric_col}_mean', f'{metric_col}_std'
    # Literal file-name prefix before the first wildcard, e.g. 'NN_SimpleMLP_MNIST_'.
    prefix = os.path.basename(pattern.split('*', 1)[0])

    data = {}
    # sorted() makes the outcome independent of filesystem listing order.
    for f in sorted(glob.glob(str(root / pattern))):
        base = os.path.basename(f)
        if prefix and base.startswith(prefix):
            mopt = re.match(r'(.+?)_lr', base[len(prefix):])
        else:
            mopt = re.search(r'NN_.*?_(\w+)_lr', base)
        mseed = re.search(r'seed(\d+)', base)
        if not mopt or not mseed:
            continue
        opt = mopt.group(1)
        seed = int(mseed.group(1))

        df = pd.read_csv(f)
        if metric_col not in df.columns:
            print(f'Skipping {base}: no {metric_col!r} column')
            continue
        # Use the last *reported* value so a trailing NaN row does not turn
        # the whole optimizer's mean into NaN.
        values = df[metric_col].dropna()
        if values.empty:
            print(f'Skipping {base}: {metric_col!r} has no values')
            continue

        runs = data.setdefault(opt, {})
        if seed in runs:
            # E.g. an lr sweep: several files map to the same (optimizer, seed).
            print(f'Warning: {base} overrides an earlier run for {opt} seed {seed}')
        runs[seed] = float(values.iloc[-1])

    rows = []
    for opt, by_seed in sorted(data.items()):
        # Every optimizer key was created together with at least one seed.
        vals = np.array(list(by_seed.values()), dtype=float)
        rows.append({
            'Optimizer': opt,
            'n': len(vals),
            mean_col: float(np.mean(vals)),
            std_col: float(np.std(vals, ddof=1)) if len(vals) > 1 else np.nan,
        })
    return pd.DataFrame(rows, columns=['Optimizer', 'n', mean_col, std_col])

def plot_errorbars(df, value_col, err_col, title, ylabel, save_name, plots_dir=None):
    """Draw a bar chart of per-optimizer values with error bars and save it.

    Bar order:

    - ascending when ``value_col`` contains ``'loss'``;
    - descending otherwise (e.g. accuracy).

    Parameters
    ----------
    df : pandas.DataFrame or None
        Summary table, e.g. from ``collect``. It must have an ``Optimizer``
        column plus ``value_col``. If it is None or empty, nothing is plotted.
    value_col : str
        Column with bar heights, e.g. ``'test_acc_mean'``.
    err_col : str
        Column with error-bar half-widths. If the column is absent, bars are
        drawn without error bars.
    title, ylabel : str
        Axes title and y-axis label.
    save_name : str
        File name of the saved figure (saved at 300 dpi).
    plots_dir : str or Path, optional
        Output directory. Defaults to ``results_dir.parent / 'plots'``; it is
        created if missing.
    """
    if df is None or df.empty:
        print('Nothing to plot for', title)
        return
    order = df.sort_values(value_col, ascending=('loss' in value_col))
    x = np.arange(len(order))
    yerr = order[err_col].to_numpy() if err_col in order.columns else None

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.bar(x, order[value_col].to_numpy(), yerr=yerr, capsize=4)
    ax.set_xticks(x)
    ax.set_xticklabels(order['Optimizer'].values, rotation=30, ha='right')
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.grid(axis='y', alpha=0.3)
    fig.tight_layout()

    out_dir = Path(plots_dir) if plots_dir is not None else results_dir.parent / 'plots'
    out_dir.mkdir(parents=True, exist_ok=True)
    out = out_dir / save_name
    # Save the explicit figure before show(), then release it so repeated calls
    # do not accumulate open figures.
    fig.savefig(out, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print('Saved plot to', out)

In [None]:
# MNIST aggregation: final test accuracy and test loss per optimizer, mean ± std over seeds.
mnist_glob = 'NN_SimpleMLP_MNIST_*_publication.csv'
mnist_title = 'MNIST Optimizer Ablation (mean ± std)'

mnist = collect(mnist_glob, 'test_acc')
display(mnist)
plot_errorbars(mnist, 'test_acc_mean', 'test_acc_std', mnist_title,
               'Test Accuracy', 'mnist_ablation_accuracy.png')

mnist_loss = collect(mnist_glob, 'test_loss')
plot_errorbars(mnist_loss, 'test_loss_mean', 'test_loss_std', mnist_title,
               'Test Loss', 'mnist_ablation_loss.png')

In [None]:
# CIFAR-10 aggregation: final test accuracy and test loss per optimizer, mean ± std over seeds.
cifar_glob = 'NN_SimpleCIFAR10_*_publication.csv'
cifar_title = 'CIFAR-10 Optimizer Ablation (mean ± std)'

cifar = collect(cifar_glob, 'test_acc')
display(cifar)
plot_errorbars(cifar, 'test_acc_mean', 'test_acc_std', cifar_title,
               'Test Accuracy', 'cifar10_ablation_accuracy.png')

cifar_loss = collect(cifar_glob, 'test_loss')
plot_errorbars(cifar_loss, 'test_loss_mean', 'test_loss_std', cifar_title,
               'Test Loss', 'cifar10_ablation_loss.png')

In [None]:
# Show statistical comparisons if present (pandas is already imported in the imports cell).
paths = [
    results_dir / 'mnist_statistical_comparisons_publication.csv',
    results_dir / 'cifar10_statistical_comparisons_publication.csv',
    results_dir / 'nn_statistical_comparisons.csv',
]
for p in paths:
    if not p.exists():
        print('Missing', p)
        continue
    print('Showing', p)
    table = pd.read_csv(p)
    # Sort only when the expected column exists, so one table with a different
    # schema does not abort the loop with a KeyError before the others are shown.
    if 'p-value' in table.columns:
        table = table.sort_values('p-value')
    else:
        print("No 'p-value' column; showing unsorted")
    display(table)