# Experiment E1: Pareto Frontier

**Local efficiency of the natural gradient versus SGD at matched per-step KL divergence**

---

## Generated Files

This notebook generates:
- `E1_results_raw.csv`
- `E1_results_aggregated.csv`
- `E1_gate_validation.csv`
- `E1_summary.json`
- `E1_figure.png` (publication quality)
- `E1_figure.pdf` (publication quality)

**Runtime**: ~5-10 minutes on GPU

---

In [None]:
# Mount Google Drive and make sure the output directory for this experiment exists.
import os
from google.colab import drive

drive.mount('/content/drive')

SAVE_DIR = '/content/drive/MyDrive/paper-E-final/E1'
os.makedirs(SAVE_DIR, exist_ok=True)
print(f'Results → {SAVE_DIR}')

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import json
from datetime import datetime
import time

# Reproducibility: seed torch and numpy, and pin cuDNN to deterministic kernels.
torch.manual_seed(42)
np.random.seed(42)
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

# All tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.float64)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

In [None]:
def gate_a_fisher_quality(fisher, name="Fisher", threshold=-1e-7):
    """Check that a Fisher matrix is positive semi-definite up to round-off.

    The symmetric part of ``fisher`` is eigendecomposed. The gate fails when the
    smallest eigenvalue is below ``threshold * |largest eigenvalue|``, so a
    slightly negative eigenvalue caused by floating-point error is tolerated.

    Args:
        fisher: square matrix to validate.
        name: label used in the returned message.
        threshold: relative tolerance (negative) for the smallest eigenvalue.

    Returns:
        ``(passed, message, details)`` where ``details`` holds the min/max
        eigenvalues and the condition number (``inf`` when min <= 0).
    """
    spectrum = torch.linalg.eigvalsh((fisher + fisher.T) / 2)
    lo = spectrum.min().item()
    hi = spectrum.max().item()
    tolerance = threshold * abs(hi)

    details = {
        "min_eigenvalue": lo,
        "max_eigenvalue": hi,
        "condition_number": hi / lo if lo > 0 else float('inf'),
    }

    passed = not lo < tolerance
    message = (
        f"PASS: {name} PSD (min={lo:.2e})"
        if passed
        else f"FAIL: {name} negative eigenvalue {lo:.2e}"
    )
    return passed, message, details

In [None]:
class SoftmaxRegression(nn.Module):
    """Multinomial logistic regression: one linear map from features to class logits.

    Args:
        input_dim: number of input features.
        num_classes: number of output classes (logits).
        use_bias: whether the linear layer has a bias term.
    """

    def __init__(self, input_dim, num_classes, use_bias=False):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes, bias=use_bias)

    def forward(self, x):
        """Return unnormalised class logits for ``x``."""
        logits = self.linear(x)
        return logits

In [None]:
def compute_fisher_analytic(model, X):
    """Return the exact expected Fisher information of a softmax-regression model.

    For logits ``z = W x (+ b)`` and ``p = softmax(z)``, each sample contributes
    ``kron(diag(p) - p p^T, x x^T)``. With a bias, ``x`` is augmented with a
    constant 1. The result is averaged over the rows of ``X`` and symmetrised.

    Args:
        model: callable returning (N, C) logits for ``X``. Assumes it is a single
            linear layer whose ``parameters()`` yield the (C, D) weight (flattened
            row-major) followed optionally by a (C,) bias, as ``nn.Linear`` does.
        X: (N, D) input matrix.

    Returns:
        (P, P) Fisher matrix with P = C*D (no bias) or C*D + C (bias), ordered
        like the flattened ``model.parameters()``.

    Raises:
        ValueError: if the model's parameter count matches neither layout.
    """
    with torch.no_grad():
        probs = F.softmax(model(X), dim=1)
        N, C = probs.shape
        D = X.shape[1]

        n_params = sum(p.numel() for p in model.parameters())
        if n_params == C * D:
            has_bias = False
        elif n_params == C * (D + 1):
            has_bias = True
        else:
            raise ValueError(
                f"model has {n_params} parameters; expected {C * D} (no bias) "
                f"or {C * (D + 1)} (with bias) for C={C}, D={D}"
            )

        feats = X
        if has_bias:
            # A bias is a weight on a constant feature.
            ones = torch.ones(N, 1, dtype=X.dtype, device=X.device)
            feats = torch.cat([X, ones], dim=1)
        Da = feats.shape[1]

        # Per-sample softmax curvature H_i = diag(p_i) - p_i p_i^T, shape (N, C, C).
        H = torch.diag_embed(probs) - probs.unsqueeze(2) * probs.unsqueeze(1)
        # Per-sample outer products x_i x_i^T, shape (N, Da, Da).
        XX = feats.unsqueeze(2) * feats.unsqueeze(1)

        # sum_i H_i[a, b] * XX_i[d, e] as a single matmul over the sample axis,
        # instead of N separate Kronecker products.
        M = H.reshape(N, C * C).T @ XX.reshape(N, Da * Da)
        # Reorder [a, b, d, e] -> [a, d, b, e] so that row a*Da + d / column
        # b*Da + e matches kron(H, x x^T).
        fisher = M.reshape(C, C, Da, Da).permute(0, 2, 1, 3).reshape(C * Da, C * Da) / N

        if has_bias:
            # Current layout interleaves the bias as feature index D of each class
            # block. Move all weights first, then the C bias entries, to match
            # the flattened (weight, bias) parameter order.
            idx = torch.arange(C * Da, device=X.device).reshape(C, Da)
            order = torch.cat([idx[:, :D].reshape(-1), idx[:, D]])
            fisher = fisher[order][:, order]

        fisher = (fisher + fisher.T) / 2
        return fisher

In [None]:
def generate_data(n_samples=5000, n_features=16, n_classes=10, noise_std=1.0, seed=42):
    """Generate a shuffled Gaussian-blob classification dataset.

    Class centres are the orthonormal columns of a QR factorisation of a random
    matrix, scaled by 2.5. Each class gets isotropic Gaussian noise with standard
    deviation ``noise_std``. Samples are split as evenly as possible: the first
    ``n_samples % n_classes`` classes receive one extra sample, so exactly
    ``n_samples`` rows are returned. Tensors are placed on the module-level
    ``device``.

    Note: reseeds the global torch and numpy RNGs with ``seed``.

    Returns:
        ``(X, y)`` with X of shape (n_samples, n_features) float64 and y of
        shape (n_samples,) long.

    Raises:
        ValueError: if ``n_classes < 1`` or ``n_classes > n_features`` (there
            cannot be more orthonormal centres than feature dimensions).
    """
    if n_classes < 1:
        raise ValueError(f"n_classes must be >= 1, got {n_classes}")
    if n_classes > n_features:
        raise ValueError(
            f"n_classes ({n_classes}) must not exceed n_features ({n_features}): "
            "class centres are orthonormal vectors in feature space"
        )
    torch.manual_seed(seed)
    np.random.seed(seed)
    centers = torch.randn(n_classes, n_features, dtype=torch.float64)
    Q, _ = torch.linalg.qr(centers.T)
    centers = (Q[:, :n_classes].T * 2.5).to(device)

    base, remainder = divmod(n_samples, n_classes)
    X_list, y_list = [], []
    for c in range(n_classes):
        count = base + (1 if c < remainder else 0)
        noise = torch.randn(count, n_features, dtype=torch.float64, device=device)
        X_list.append(centers[c] + noise * noise_std)
        y_list.append(torch.full((count,), c, dtype=torch.long, device=device))
    X = torch.cat(X_list)
    y = torch.cat(y_list)
    perm = torch.randperm(X.shape[0], device=device)
    return X[perm], y[perm]

# Shared training set used by every seed of the experiment below.
X_train, y_train = generate_data(seed=42, n_samples=5000)
print('Data: X={}, y={}'.format(X_train.shape, y_train.shape))

In [None]:
def get_parameters_flat(model):
    """Return all parameters of ``model`` concatenated into one 1-D tensor.

    Entries follow ``model.parameters()`` order, each flattened row-major. The
    result is detached, so storing it (e.g. as a reset point) does not keep
    autograd history tied to the live parameters.
    """
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

def set_parameters_flat(model, params):
    """Copy a flat vector into the parameters of ``model`` in place.

    ``params`` is consumed in ``model.parameters()`` order, each slice reshaped
    to the parameter's shape (the inverse of flattening every parameter and
    concatenating).

    Raises:
        ValueError: if ``params`` does not have exactly as many elements as the
            model has parameters (previously extra values were silently ignored).
    """
    flat = params.reshape(-1)
    total = sum(p.numel() for p in model.parameters())
    if flat.numel() != total:
        raise ValueError(f"expected {total} parameter values, got {flat.numel()}")
    offset = 0
    # In-place writes to leaf parameters must not be recorded by autograd.
    with torch.no_grad():
        for p in model.parameters():
            numel = p.numel()
            p.copy_(flat[offset:offset + numel].reshape(p.shape))
            offset += numel

def compute_loss_and_grad(model, X, y):
    """Evaluate mean cross-entropy on ``(X, y)`` and its gradient.

    Clears existing gradients first, so repeated calls do not accumulate.

    Returns:
        ``(loss, grad)`` where ``loss`` is a Python float and ``grad`` is the
        gradient of every parameter flattened and concatenated in
        ``model.parameters()`` order.
    """
    model.zero_grad()
    objective = F.cross_entropy(model(X), y)
    objective.backward()
    flat_grad = torch.cat(tuple(param.grad.flatten() for param in model.parameters()))
    return objective.item(), flat_grad

def sgd_step(model, X, y, target_kl_step, fisher):
    """Take one gradient-descent step sized to a target per-step KL divergence.

    The step ``-eta * g`` uses ``eta = sqrt(2 * target_kl_step / (g^T F g))``, so
    the quadratic (Fisher) approximation ``0.5 * eta**2 * g^T F g`` of the KL
    between the predictive distributions before and after the step equals
    ``target_kl_step``. The realised KL(p_before || p_after) is then measured
    exactly on ``X``.

    Args:
        model: model whose parameters are updated in place.
        X, y: inputs and integer class targets.
        target_kl_step: desired per-step KL.
        fisher: (P, P) Fisher matrix at the current parameters.

    Returns:
        dict with ``loss_before``, ``loss_after``, ``delta_loss`` (before minus
        after), ``kl_step`` (measured KL), ``kl_ratio`` (measured / target),
        ``eta`` and ``direction`` (unit-norm gradient).

    Side effect: the model is left at the post-step parameters.
    """
    loss_before, grad = compute_loss_and_grad(model, X, y)
    grad_F_grad = torch.dot(grad, fisher @ grad)
    eta = torch.sqrt(2 * target_kl_step / grad_F_grad)

    with torch.no_grad():
        # compute_loss_and_grad leaves the parameters untouched, so these are the
        # pre-step predictions (previously they were taken after the update,
        # which made the measured KL identically ~0).
        logits_before = model(X)
        theta_after = get_parameters_flat(model) - eta * grad
        set_parameters_flat(model, theta_after)
        logits_after = model(X)

        logp_before = F.log_softmax(logits_before, dim=1)
        logp_after = F.log_softmax(logits_after, dim=1)
        # Exact KL(p_before || p_after), averaged over samples; log_softmax avoids
        # the bias of ad-hoc epsilons inside log().
        kl_actual = (logp_before.exp() * (logp_before - logp_after)).sum(dim=1).mean().item()
        loss_after = F.cross_entropy(logits_after, y).item()

    delta_loss = loss_before - loss_after
    return {
        'loss_before': loss_before,
        'loss_after': loss_after,
        'delta_loss': delta_loss,
        'kl_step': kl_actual,
        'kl_ratio': kl_actual / target_kl_step,
        'eta': eta.item(),
        'direction': grad / torch.norm(grad)
    }

def natgrad_step(model, X, y, target_kl_step, fisher, eig_threshold=1e-8):
    """Take one natural-gradient step sized to a target per-step KL divergence.

    The direction is ``v = F^+ g``. ``F^+`` is the eigen-pseudo-inverse of
    ``fisher``: eigenvalues <= ``eig_threshold`` are treated as zero. The softmax
    Fisher is rank-deficient because ``(diag(p) - p p^T) 1 = 0`` when ``p`` sums
    to 1. The step ``-eta * v`` uses ``eta = sqrt(2 * target_kl_step / (v^T F v))``,
    so the quadratic KL approximation equals ``target_kl_step``. The realised
    KL(p_before || p_after) is then measured exactly on ``X``.

    Args:
        model: model whose parameters are updated in place.
        X, y: inputs and integer class targets.
        target_kl_step: desired per-step KL.
        fisher: (P, P) symmetric Fisher matrix at the current parameters.
        eig_threshold: eigenvalue cutoff for the pseudo-inverse.

    Returns:
        dict with ``loss_before``, ``loss_after``, ``delta_loss``, ``kl_step``
        (measured KL), ``kl_ratio`` (measured / target), ``eta`` and
        ``direction`` (unit-norm natural gradient).

    Side effect: the model is left at the post-step parameters.
    """
    loss_before, grad = compute_loss_and_grad(model, X, y)
    eigvals, eigvecs = torch.linalg.eigh(fisher)
    eigvals_inv = torch.where(eigvals > eig_threshold, 1.0 / eigvals, torch.zeros_like(eigvals))
    # Apply F^+ = U diag(1/lambda) U^T to g without forming the full matrix.
    natgrad = eigvecs @ (eigvals_inv * (eigvecs.T @ grad))
    natgrad_F_natgrad = torch.dot(natgrad, fisher @ natgrad)
    eta = torch.sqrt(2 * target_kl_step / natgrad_F_natgrad)

    with torch.no_grad():
        # compute_loss_and_grad leaves the parameters untouched, so these are the
        # pre-step predictions (previously they were taken after the update,
        # which made the measured KL identically ~0).
        logits_before = model(X)
        theta_after = get_parameters_flat(model) - eta * natgrad
        set_parameters_flat(model, theta_after)
        logits_after = model(X)

        logp_before = F.log_softmax(logits_before, dim=1)
        logp_after = F.log_softmax(logits_after, dim=1)
        # Exact KL(p_before || p_after), averaged over samples.
        kl_actual = (logp_before.exp() * (logp_before - logp_after)).sum(dim=1).mean().item()
        loss_after = F.cross_entropy(logits_after, y).item()

    delta_loss = loss_before - loss_after
    return {
        'loss_before': loss_before,
        'loss_after': loss_after,
        'delta_loss': delta_loss,
        'kl_step': kl_actual,
        'kl_ratio': kl_actual / target_kl_step,
        'eta': eta.item(),
        'direction': natgrad / torch.norm(natgrad)
    }

In [None]:
# Experiment configuration: 9 target KL levels x 10 seeds x {SGD, NatGrad}.
CONFIG = {
    'n_features': 16,
    'n_classes': 10,
    'use_bias': False,
    'epsilon_values': [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1],
    'n_seeds': 10
}

print(f"Total experiments: {len(CONFIG['epsilon_values']) * CONFIG['n_seeds'] * 2}")

results = []
gate_results = []

for seed in range(CONFIG['n_seeds']):
    # Each seed gets a fresh random initialisation; both optimisers start from it.
    torch.manual_seed(seed)
    model = SoftmaxRegression(CONFIG['n_features'], CONFIG['n_classes'], use_bias=CONFIG['use_bias']).to(device)
    fisher = compute_fisher_analytic(model, X_train)
    passed, msg, details = gate_a_fisher_quality(fisher, name=f"Seed {seed}")
    if seed == 0:
        print(f"Gate A: {msg}")
    gate_results.append({'seed': seed, 'gate_a_passed': passed, **details})
    if not passed:
        print(f"⚠️ WARNING: Gate A failed for seed {seed}")

    theta_0 = get_parameters_flat(model).detach().clone()
    n_params = theta_0.numel()
    if fisher.shape != (n_params, n_params):
        # Guards against e.g. use_bias=True with a Fisher that omits the bias.
        raise ValueError(
            f"Fisher shape {tuple(fisher.shape)} does not match {n_params} model parameters"
        )

    for eps in CONFIG['epsilon_values']:
        # Both steps start from the same point and use the Fisher at theta_0.
        set_parameters_flat(model, theta_0)
        sgd_result = sgd_step(model, X_train, y_train, eps, fisher)
        set_parameters_flat(model, theta_0)
        ng_result = natgrad_step(model, X_train, y_train, eps, fisher)
        # Direction vectors are only needed for the cosine; remove them so the
        # results table (and the raw CSV written from it) holds scalars only.
        sgd_dir = sgd_result.pop('direction')
        ng_dir = ng_result.pop('direction')
        cos_sim = torch.dot(sgd_dir, ng_dir).item()
        results.append({'seed': seed, 'epsilon': eps, 'method': 'sgd', **sgd_result})
        results.append({'seed': seed, 'epsilon': eps, 'method': 'natgrad', **ng_result, 'cos_g_v': cos_sim})

print("✓ Completed")

In [None]:
# Summarise Gate A (Fisher PSD check) across all seeds.
gate_df = pd.DataFrame(gate_results)
pass_rate = gate_df['gate_a_passed'].mean()
print(f"\nGate A pass rate: {pass_rate:.1%}")
gate_verdict = "✓✓✓ GATE A: ALL PASSED" if pass_rate == 1.0 else "⚠️ GATE A: SOME FAILED"
print(gate_verdict)

In [None]:
# Mean/std across seeds for every (epsilon, method) pair, as flat named columns.
df = pd.DataFrame(results)
agg_df = (
    df.groupby(['epsilon', 'method'])
    .agg(
        delta_loss_mean=('delta_loss', 'mean'),
        delta_loss_std=('delta_loss', 'std'),
        kl_step_mean=('kl_step', 'mean'),
        kl_step_std=('kl_step', 'std'),
        kl_ratio_mean=('kl_ratio', 'mean'),
        kl_ratio_std=('kl_ratio', 'std'),
    )
    .reset_index()
)

# Per-method views indexed by epsilon so the two series align on subtraction.
sgd_df, ng_df = (agg_df[agg_df['method'] == m].set_index('epsilon') for m in ('sgd', 'natgrad'))
sgd_gain = sgd_df['delta_loss_mean']
# Percent improvement of NatGrad's loss reduction over SGD's at each epsilon.
relative_advantage = (ng_df['delta_loss_mean'] - sgd_gain) / sgd_gain * 100
print(f"\nMean NG advantage: {relative_advantage.mean():+.2f}%")

In [None]:
# Publication-quality figure (no title, no panel labels)
# Left: loss reduction vs measured per-step KL (mean ± std over seeds).
# Right: NatGrad's relative advantage over SGD at each target KL level.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

sgd_data = agg_df[agg_df['method'] == 'sgd']
ng_data = agg_df[agg_df['method'] == 'natgrad']

ax1.errorbar(sgd_data['kl_step_mean'], sgd_data['delta_loss_mean'],
             yerr=sgd_data['delta_loss_std'], xerr=sgd_data['kl_step_std'],
             marker='s', label='SGD', capsize=3, linewidth=2)
ax1.errorbar(ng_data['kl_step_mean'], ng_data['delta_loss_mean'],
             yerr=ng_data['delta_loss_std'], xerr=ng_data['kl_step_std'],
             marker='o', label='Natural Gradient', capsize=3, linewidth=2)
ax1.set_xscale('log')
ax1.set_xlabel('Per-Step KL Divergence', fontsize=11)
ax1.set_ylabel('Loss Reduction', fontsize=11)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# relative_advantage is indexed by the *target* KL (epsilon), not the measured
# KL, so the axis label says so.
eps_values = relative_advantage.index.to_numpy()
adv_values = relative_advantage.values
ax2.plot(eps_values, adv_values, marker='o', linewidth=2, markersize=7, color='C2')
ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax2.set_xscale('log')
ax2.set_xlabel('Target Per-Step KL Divergence (ε)', fontsize=11)
ax2.set_ylabel('Relative Advantage (%)', fontsize=11)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/E1_figure.png', dpi=300, bbox_inches='tight')
plt.savefig(f'{SAVE_DIR}/E1_figure.pdf', bbox_inches='tight')
print("Figure saved")

In [None]:
# Persist the raw, aggregated and gate-validation tables, then a JSON summary.
for table, fname in ((df, 'E1_results_raw.csv'),
                     (agg_df, 'E1_results_aggregated.csv'),
                     (gate_df, 'E1_gate_validation.csv')):
    table.to_csv(f'{SAVE_DIR}/{fname}', index=False)

summary = {
    'experiment': 'E1',
    'timestamp': datetime.now().isoformat(),
    'config': CONFIG,
    'gate_a_pass_rate': float(pass_rate),
    'mean_relative_advantage_pct': float(relative_advantage.mean()),
    'min_advantage_pct': float(relative_advantage.min()),
    'max_advantage_pct': float(relative_advantage.max())
}

with open(f'{SAVE_DIR}/E1_summary.json', 'w') as fh:
    json.dump(summary, fh, indent=2)

print(f"\n✓ All files saved to {SAVE_DIR}")
for fname in ('E1_results_raw.csv', 'E1_results_aggregated.csv', 'E1_gate_validation.csv',
              'E1_summary.json', 'E1_figure.png', 'E1_figure.pdf'):
    print(f"  - {fname}")