# Experiment E3: Reparameterization Invariance Test

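**One-step update consistency under diagonal rescaling of parameter coordinates**

For each κ the parameters are written as θ = S(κ)·φ, where S(κ) scales the weights of the first half of the input features by √κ and the rest by 1/√κ. One SGD or natural-gradient step is taken in φ and mapped back to θ; a reparameterization-invariant method produces the same step for every κ.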

**Implementation**: based on a validated one-step reference implementation.

---

## Generated Files

- `E3_results.csv`
- `E3_summary.json`
- `E3_figure.png` (publication quality)
- `E3_figure.pdf` (publication quality)

**Runtime**: ~5 minutes

---

In [None]:
import os

from google.colab import drive

# Mount Google Drive at /content/drive; all E3 outputs are written there.
drive.mount('/content/drive')

# Output directory for every E3 artifact (created if it does not exist yet).
SAVE_DIR = os.path.join('/content/drive/MyDrive', 'paper-E-final', 'E3')
os.makedirs(SAVE_DIR, exist_ok=True)
print(f'Results → {SAVE_DIR}')

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import json
from datetime import datetime

# Every tensor created without an explicit dtype from here on is float64.
torch.set_default_dtype(torch.float64)

# Prefer the GPU when one is visible to PyTorch.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

In [None]:
# Model
class SoftmaxRegression(nn.Module):
    """Multinomial logistic regression: one linear map from features to class logits.

    Args:
        input_dim: Number of input features D.
        n_classes: Number of classes C (output logits).
        bias: Whether the linear layer has a bias term (off by default).
        dtype: Parameter dtype.
        device: Parameter device.
    """

    def __init__(self, input_dim, n_classes, bias=False, dtype=torch.float64, device="cpu"):
        super().__init__()
        self.linear = nn.Linear(
            in_features=input_dim,
            out_features=n_classes,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        """Return unnormalised logits for input ``x``."""
        logits = self.linear(x)
        return logits

In [None]:
# Utils: flatten / unflatten
def get_param_vector(model: nn.Module) -> torch.Tensor:
    """Return all parameters of ``model`` as one detached 1-D tensor.

    Parameters are concatenated in ``model.parameters()`` order, each flattened
    row-major; the result shares no autograd history with the model.
    """
    flat_parts = [param.detach().flatten() for param in model.parameters()]
    return torch.cat(flat_parts)

def set_param_vector_(model: nn.Module, vec: torch.Tensor) -> None:
    """Overwrite ``model``'s parameters in place from a flat vector.

    ``vec`` is consumed in ``model.parameters()`` order, each chunk reshaped to
    the parameter's shape (the inverse of ``get_param_vector``).

    Args:
        model: Module whose parameters are overwritten.
        vec: Tensor holding exactly as many elements as the model has
            parameters; it is flattened first.

    Raises:
        ValueError: if ``vec.numel()`` differs from the model's total parameter
            count. The check runs before any parameter is written, so on a
            mismatch the model is left untouched.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    flat = vec.reshape(-1)
    # Validate up front (not with `assert`, which -O strips) so that a wrong-sized
    # vector can never leave the model partially overwritten.
    if flat.numel() != expected:
        raise ValueError(
            f"vec has {flat.numel()} elements but the model has {expected} parameters"
        )

    # IMPORTANT: overwrite params deterministically, in parameters() order
    offset = 0
    with torch.no_grad():
        for p in params:
            n = p.numel()
            p.copy_(flat[offset:offset + n].view_as(p))
            offset += n

In [None]:
# True Fisher (analytic) for softmax regression (bias=False)
def compute_true_fisher_softmax_regression(model: "SoftmaxRegression", X: torch.Tensor) -> torch.Tensor:
    """Return the exact (true) Fisher information of a bias-free softmax regression.

    With p_b = softmax(model(x_b)), the Fisher is

        F = (1/B) * sum_b kron(Diag(p_b) - p_b p_b^T,  x_b x_b^T),

    a (C*D, C*D) matrix whose row/column index c*D + d matches a row-major
    flatten of a (C, D) weight matrix.

    Args:
        model: Callable mapping X (B, D) to logits (B, C). Assumed to have a
            single (C, D) weight and no bias, flattened row-major (as
            ``get_param_vector`` does for ``SoftmaxRegression(bias=False)``);
            with a bias the (C*D)-sized result would not line up with the
            parameter vector.
        X: Inputs of shape (B, D).

    Returns:
        Symmetrised Fisher matrix of shape (C*D, C*D).

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probs = F.softmax(model(X), dim=1)
    finally:
        # Do not leave the caller's model silently switched to eval mode.
        model.train(was_training)

    B, C = probs.shape
    D = X.shape[1]

    # [B, C, C]: Diag(p) - p p^T (covariance of the one-hot label under p)
    fisher_class = torch.diag_embed(probs) - torch.einsum("bc,bd->bcd", probs, probs)

    # Contract the batch directly: sum_b A_b[i,k] * x_b[j] * x_b[l] -> [C, D, C, D].
    # Equivalent to averaging the per-sample Kronecker products, but never
    # materialises the [B, C*D, C*D] tensor (~1 GB in float64 for B=5000, C=10, D=16).
    fisher = torch.einsum("bik,bj,bl->ijkl", fisher_class, X, X) / B
    F_true = fisher.reshape(C * D, C * D)

    # Symmetrize to remove round-off asymmetry
    F_true = 0.5 * (F_true + F_true.T)
    return F_true

In [None]:
# Gate A: Fisher PSD quality
def gate_a_fisher_psd(F: torch.Tensor, name="Fisher"):
    """Gate A: check that a (nominally) Fisher matrix is positive semi-definite.

    The matrix is symmetrised first. It passes when its smallest eigenvalue is at
    least ``-1e-8 * |largest eigenvalue|``, a tolerance for round-off negatives.

    Args:
        F: Square matrix to check.
        name: Label used in the report message.

    Returns:
        (passed, message, min_eig, max_eig) with ``passed`` a Python bool and the
        eigenvalues as Python floats.
    """
    eigenvalues = torch.linalg.eigvalsh(0.5 * (F + F.T))
    lo = eigenvalues.min().item()
    hi = eigenvalues.max().item()
    tolerance = -1e-8 * abs(hi)
    report = f"{name}: min_eig={lo:.3e}, max_eig={hi:.3e}, thresh={tolerance:.3e}"
    return (lo >= tolerance), report, lo, hi

In [None]:
# Gate C: all κ produce identical logits at init
def gate_c_equivalence(models_dict, X, kappa_values, threshold=1e-5):
    """Gate C: check that every κ-model reproduces the reference model's logits on X.

    Args:
        models_dict: Maps each κ to a callable model (invoked as ``models_dict[k](X)``).
        X: Input batch fed to every model.
        kappa_values: Ordered κ keys; ``kappa_values[0]`` is the reference.
        threshold: Largest allowed max-abs logit difference for any model.

    Returns:
        (passed, message, max_diff) where ``max_diff`` is the largest max-abs logit
        difference over all non-reference κ (0.0 when there are none).
    """
    reference_kappa = kappa_values[0]
    with torch.no_grad():
        reference_logits = models_dict[reference_kappa](X)

    diffs = []
    for kappa in kappa_values[1:]:
        with torch.no_grad():
            candidate = models_dict[kappa](X)
        diffs.append((kappa, (candidate - reference_logits).abs().max().item()))

    worst = max([0.0] + [d for _, d in diffs])
    failures = [(kappa, d) for kappa, d in diffs if d > threshold]

    if failures:
        detail = ", ".join([f"κ={k}: {d:.2e}" for k, d in failures])
        return False, "GateC FAIL: " + detail, worst
    return True, f"GateC PASS: max|Δlogits|={worst:.2e}", worst

In [None]:
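# Symmetric pseudo-inverse of a PSD matrix via eigendecomposition; eigenvalues <= rcond * max_eig are treated as zero.
# Note: the relative cutoff is applied in the coordinates F is given in, so the truncation itself is not rescaling-invariant.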
def pinv_psd(F: torch.Tensor, rcond=1e-8):
    """Pseudo-inverse of a symmetric PSD matrix via its eigendecomposition.

    ``F`` is symmetrised, eigen-decomposed as V diag(w) V^T, and every eigenvalue
    not strictly above ``rcond * max(w)`` is treated as zero (its inverse is 0).

    Args:
        F: Square matrix, assumed (near) symmetric PSD.
        rcond: Relative cutoff for discarding small eigenvalues.

    Returns:
        V diag(w_inv) V^T with the same shape as ``F``.
    """
    w, V = torch.linalg.eigh(0.5 * (F + F.T))
    keep = w > rcond * w.max()
    w_inv = torch.zeros_like(w)
    w_inv[keep] = 1.0 / w[keep]
    # Scale the columns of V by w_inv (same as V @ diag(w_inv)), then close with V^T.
    return (V * w_inv[None, :]) @ V.T

In [None]:
# Build parameter scaling vector S(κ)
def make_S_vec(kappa: float, n_classes: int, n_features: int, device, dtype):
    """Return the diagonal of the parameter scaling S(κ), one entry per weight.

    Features ``0 .. n_features//2 - 1`` are scaled by sqrt(κ) and the remaining
    ones by 1/sqrt(κ), so (for n_features >= 2) the largest/smallest scale ratio
    is κ. The per-feature pattern is tiled once per class, which matches a
    row-major flatten of a (n_classes, n_features) weight (index c*n_features + d).

    Returns:
        1-D tensor of length ``n_classes * n_features`` on ``device`` with ``dtype``.
    """
    root_kappa = math.sqrt(kappa)
    n_up = n_features // 2
    up = torch.full((n_up,), root_kappa, device=device, dtype=dtype)
    down = torch.full((n_features - n_up,), 1.0 / root_kappa, device=device, dtype=dtype)
    per_feature = torch.cat([up, down])
    return per_feature.repeat(n_classes)

In [None]:
# Core: compute Δθ in θ-space via φ-coordinate update
def one_step_update_in_theta_space(g_theta, F_theta, S_vec, method, eta, rcond=1e-8):
    """Take one optimiser step in rescaled coordinates φ and return it in θ-space.

    With θ = S·φ (S = diag(S_vec)), the chain rule gives g_φ = S g_θ and
    F_φ = S F_θ S. The step is computed in φ and mapped back with Δθ = S Δφ:

    * ``"sgd"``:     Δθ = -eta * S² g_θ (depends on S, hence on κ);
    * ``"natgrad"``: Δθ = -eta * S F_φ⁺ S g_θ (equals -eta * F_θ⁻¹ g_θ whenever
      F_θ is invertible, i.e. independent of S).

    Args:
        g_theta: Gradient in θ coordinates (1-D).
        F_theta: Fisher in θ coordinates (square).
        S_vec: Diagonal of S, same length as ``g_theta``.
        method: ``"sgd"`` or ``"natgrad"``.
        eta: Step size.
        rcond: Relative cutoff forwarded to ``pinv_psd`` (natgrad only).

    Returns:
        (delta_theta, F_phi).

    Raises:
        ValueError: for any other ``method``.
    """
    g_phi = S_vec * g_theta

    # CRITICAL: F_φ = diag(S) @ F_θ @ diag(S), done by row- then column-scaling.
    F_phi = S_vec.unsqueeze(1) * F_theta * S_vec.unsqueeze(0)
    F_phi = 0.5 * (F_phi + F_phi.T)

    if method == "sgd":
        direction_phi = g_phi
    elif method == "natgrad":
        direction_phi = pinv_psd(F_phi, rcond=rcond) @ g_phi
    else:
        raise ValueError("method must be 'sgd' or 'natgrad'")

    delta_theta = S_vec * (-eta * direction_phi)
    return delta_theta, F_phi

In [None]:
# Main E3: 1-step invariance test
def run_e3_one_step(X, y, seed=0, kappa_values=(1,10,100,1000), eta=1e-3, rcond=1e-8, device="cpu"):
    """Run the one-step reparameterization-invariance test for a single seed.

    A fresh bias-free ``SoftmaxRegression`` is initialised after seeding; its
    gradient ``g_theta`` and analytic Fisher ``F_theta`` at that point θ0 are
    computed once and reused for every κ. For each κ the update is computed in
    rescaled coordinates (θ = S(κ)·φ, see ``make_S_vec``) and mapped back to θ.

    Args:
        X: Feature matrix of shape (B, D).
        y: Integer labels of shape (B,); the class count is ``y.max() + 1``.
        seed: Seed for ``torch.manual_seed`` (controls θ0).
        kappa_values: Scalings to test. ``kappa_values[0]`` is the reference every
            other κ is compared against (κ=1, i.e. S = 1, in the default).
        eta: Step size.
        rcond: Relative eigenvalue cutoff for the natural-gradient pseudo-inverse.
        device: Device for the models and scaling vectors.

    Returns:
        dict with ``seed``, ``eta``, ``rcond``, the Gate A / Gate C messages, and
        ``consistency_natgrad`` / ``consistency_sgd``: maps each non-reference κ to
        max |Δθ(κ) - Δθ(reference)|.

    Raises:
        RuntimeError: if Gate A (Fisher PSD) or Gate C (logit equivalence) fails.
    """
    torch.manual_seed(seed)
    dtype = torch.float64

    B, D = X.shape
    C = int(y.max().item()) + 1

    # Build base model and θ0 (FIXED for all κ)
    model = SoftmaxRegression(D, C, bias=False, dtype=dtype, device=device)
    model.train()

    theta0 = get_param_vector(model)

    # Compute gθ and Fθ at θ0
    model.zero_grad(set_to_none=True)
    loss = F.cross_entropy(model(X), y)
    loss.backward()
    g_theta = torch.cat([p.grad.detach().reshape(-1) for p in model.parameters()])

    F_theta = compute_true_fisher_softmax_regression(model, X)

    # Gate A on F_theta
    passA, msgA, _, _ = gate_a_fisher_psd(F_theta, "F_theta")
    if not passA:
        raise RuntimeError("Gate A failed: " + msgA)

    # One scaling vector per κ, shared by Gate C and by both update methods.
    S_by_kappa = {k: make_S_vec(k, C, D, device=device, dtype=dtype) for k in kappa_values}

    # Gate C: every κ-model must produce the same logits at init. Each model is
    # loaded through the coordinate map used below (φ0 = θ0 / S(κ), θ = S(κ)·φ0),
    # so the gate actually exercises the reparameterization; loading θ0 directly
    # would make every model bit-identical and the gate could never fail.
    models = {}
    for k in kappa_values:
        S_vec = S_by_kappa[k]
        phi0 = theta0 / S_vec
        m = SoftmaxRegression(D, C, bias=False, dtype=dtype, device=device)
        set_param_vector_(m, S_vec * phi0)
        models[k] = m
    passC, msgC, _ = gate_c_equivalence(models, X, list(kappa_values), threshold=1e-5)
    if not passC:
        raise RuntimeError("Gate C failed: " + msgC)

    # Compute Δθ for each method and κ
    results = {}
    for method in ("sgd", "natgrad"):
        deltas = {}
        for k in kappa_values:
            delta_theta_k, F_phi_k = one_step_update_in_theta_space(
                g_theta, F_theta, S_by_kappa[k], method=method, eta=eta, rcond=rcond
            )

            # Gate A on F_phi_k
            passA2, msgA2, _, _ = gate_a_fisher_psd(F_phi_k, f"F_phi(κ={k})")
            if not passA2:
                raise RuntimeError("Gate A failed on transformed Fisher: " + msgA2)

            deltas[k] = delta_theta_k

        # Consistency vs the reference κ (kappa_values[0]).
        # NOTE(review): Diag(p) - p p^T has zero row sums, so F_theta is
        # rank-deficient; a natural-gradient Δθ is only pinned down up to that
        # null space, so a nonzero NG difference here need not imply different
        # predictions.
        ref = deltas[kappa_values[0]]
        results[method] = {
            k: (deltas[k] - ref).abs().max().item() for k in kappa_values[1:]
        }

    return {
        "seed": seed,
        "eta": eta,
        "rcond": rcond,
        "gateA_msg": msgA,
        "gateC_msg": msgC,
        "consistency_natgrad": results["natgrad"],
        "consistency_sgd": results["sgd"],
    }

In [None]:
# Generate data
def generate_data(n_samples=5000, n_features=16, n_classes=10, noise_std=1.0, seed=42, device=None):
    """Generate a shuffled Gaussian-blob classification dataset.

    Class centers are mutually orthonormal directions (from a QR factorisation of
    a random matrix) scaled to norm 2.5. Each class gets ``n_samples // n_classes``
    points drawn as ``center + noise_std * N(0, I)``; any remainder of
    ``n_samples`` is dropped.

    Args:
        n_samples: Requested total number of samples.
        n_features: Feature dimension D.
        n_classes: Number of classes C (must not exceed ``n_features``).
        noise_std: Standard deviation of the isotropic Gaussian noise.
        seed: Seed for ``torch.manual_seed`` (and ``np.random.seed``).
        device: Where to place X and y. Defaults to CUDA when available, else
            CPU (the same rule as the notebook's module-level ``device``).

    Returns:
        (X, y): float64 features of shape
        (n_classes * (n_samples // n_classes), n_features) and int64 labels,
        jointly shuffled.

    Raises:
        ValueError: if ``n_classes`` is not in ``[1, n_features]`` (orthonormal
            centers need n_classes <= n_features) or ``n_samples < n_classes``.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if not 1 <= n_classes <= n_features:
        raise ValueError(
            f"n_classes ({n_classes}) must be between 1 and n_features ({n_features}) "
            "so that class centers can be orthonormal"
        )
    samples_per_class = n_samples // n_classes
    if samples_per_class == 0:
        raise ValueError(
            f"n_samples ({n_samples}) must be at least n_classes ({n_classes})"
        )

    torch.manual_seed(seed)
    np.random.seed(seed)  # kept so the global NumPy RNG state matches previous runs

    centers = torch.randn(n_classes, n_features, dtype=torch.float64)
    # Reduced QR of the (D, C) matrix gives C orthonormal columns -> C orthonormal centers.
    Q, _ = torch.linalg.qr(centers.T)
    centers = (Q[:, :n_classes].T * 2.5).to(device)

    X_list, y_list = [], []
    for c in range(n_classes):
        noise = torch.randn(samples_per_class, n_features, dtype=torch.float64, device=device)
        X_list.append(centers[c] + noise * noise_std)
        y_list.append(torch.full((samples_per_class,), c, dtype=torch.long, device=device))
    X = torch.cat(X_list)
    y = torch.cat(y_list)
    perm = torch.randperm(X.shape[0], device=device)
    return X[perm], y[perm]

# Training set shared by every seed below: fixed data seed, default sizes spelled out.
X_train, y_train = generate_data(n_samples=5000, n_features=16, n_classes=10, noise_std=1.0, seed=42)
print(f'Data: {X_train.shape}')

In [None]:
# Run experiments
CONFIG = {
    'kappa_values': [1, 10, 100, 1000],
    'n_seeds': 3,
    'eta': 0.001,
    'rcond': 1e-8
}

# κ reported in the per-seed printout: the largest non-reference scaling
# (1000 in the default CONFIG). Derived from CONFIG so that editing
# kappa_values cannot turn the printout into a KeyError.
kappa_report = max(CONFIG['kappa_values'][1:])

print('Running E3 invariance tests...')
print(f"Total seeds: {CONFIG['n_seeds']}\n")

all_results = []

for seed in range(CONFIG['n_seeds']):
    print(f"Seed {seed}:")
    result = run_e3_one_step(
        X_train, y_train,
        seed=seed,
        kappa_values=tuple(CONFIG['kappa_values']),
        eta=CONFIG['eta'],
        rcond=CONFIG['rcond'],
        device=device
    )
    print(f"  Gate A: {result['gateA_msg']}")
    print(f"  Gate C: {result['gateC_msg']}")
    print(f"  NG κ={kappa_report}: {result['consistency_natgrad'][kappa_report]:.2e}")
    print(f"  SGD κ={kappa_report}: {result['consistency_sgd'][kappa_report]:.2e}")

    # One row per (κ, method); natgrad row first, then sgd, for each κ.
    for k in CONFIG['kappa_values'][1:]:
        for method, key in (('natgrad', 'consistency_natgrad'), ('sgd', 'consistency_sgd')):
            all_results.append({
                'seed': seed,
                'kappa': k,
                'method': method,
                'max_diff': result[key][k]
            })

print('\n✓ Completed')

In [None]:
df = pd.DataFrame(all_results)

print('\n' + '='*70)
print('RESULTS SUMMARY')
print('='*70)
summary_table = df.groupby(['method', 'kappa'])['max_diff'].agg(['mean', 'std'])
print(summary_table)

# Headline comparison at the largest tested κ (1000 in the default CONFIG).
# The variable names are kept for compatibility with anything that reads them.
kappa_max = df['kappa'].max()
ng_1000 = df[(df['method']=='natgrad') & (df['kappa']==kappa_max)]['max_diff'].mean()
sgd_1000 = df[(df['method']=='sgd') & (df['kappa']==kappa_max)]['max_diff'].mean()

print(f'\nAt κ={kappa_max}:')
print(f'  Natural Gradient: {ng_1000:.2e}')
print(f'  SGD: {sgd_1000:.2e}')
if ng_1000 > 0:
    print(f'  Ratio (SGD/NG): {sgd_1000/ng_1000:.1f}×')
else:
    # An exactly-zero (or NaN) NG discrepancy makes the ratio undefined.
    print('  Ratio (SGD/NG): undefined (NG difference is not positive)')

if ng_1000 < sgd_1000:
    print('\n✓✓✓ SUCCESS: NG is more invariant than SGD')
else:
    print('\n✗ WARNING: Unexpected result')

In [None]:
# Publication-quality figure (no title)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

# Display label and colour per method; dict order (SGD first) fixes legend order.
method_style = {
    'sgd': ('SGD', 'C0'),
    'natgrad': ('Natural Gradient', 'C1'),
}

# Mean / std over seeds of max|Δθ(κ) - Δθ(1)|, per method, indexed by κ.
method_stats = {
    method: df[df['method'] == method].groupby('kappa')['max_diff'].agg(['mean', 'std'])
    for method in method_style
}

# Left panel: absolute one-step discrepancy (log-log), mean ± std over seeds.
for method, grouped in method_stats.items():
    label, color = method_style[method]
    spread = grouped['std'].fillna(0.0)  # std is NaN when only one seed was run

    ax1.plot(grouped.index, grouped['mean'], marker='o', label=label,
             linewidth=2, markersize=8, color=color)
    ax1.fill_between(grouped.index, grouped['mean'] - spread,
                     grouped['mean'] + spread, alpha=0.2, color=color)

ax1.set_xscale('log')
ax1.set_yscale('log')
# NOTE(review): 1e-5 matches the Gate C logit threshold; confirm it is the intended reference level for Δθ.
ax1.axhline(y=1e-5, color='gray', linestyle='--', alpha=0.5, label='Threshold')
ax1.set_xlabel('Condition Number κ', fontsize=11)
ax1.set_ylabel('max |Δθ(κ) - Δθ(1)|', fontsize=11)
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3, which='both')

# Right panel: relative consistency on a scale SHARED by both methods,
# 1 - mean / (largest mean over both methods). Normalising each method by its
# own maximum would force every curve to 0 at its worst κ and erase the
# magnitude gap between SGD and natural gradient.
shared_scale = max(stats['mean'].max() for stats in method_stats.values())
if not shared_scale > 0:  # all-zero (or NaN) discrepancies: avoid 0/0
    shared_scale = 1.0

for method, grouped in method_stats.items():
    label, color = method_style[method]
    normalized = 1.0 - np.minimum(grouped['mean'] / shared_scale, 1.0)

    ax2.plot(grouped.index, normalized, marker='o', label=label,
             linewidth=2, markersize=8, color=color)

ax2.set_xscale('log')
ax2.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
ax2.set_xlabel('Condition Number κ', fontsize=11)
ax2.set_ylabel('Relative Consistency', fontsize=11)
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/E3_figure.png', dpi=300, bbox_inches='tight')
plt.savefig(f'{SAVE_DIR}/E3_figure.pdf', bbox_inches='tight')
print('Figure saved')

In [None]:
# Persist the per-seed rows, then an aggregate JSON summary.
df.to_csv(f'{SAVE_DIR}/E3_results.csv', index=False)

ng_data = df[df['method'] == 'natgrad']
sgd_data = df[df['method'] == 'sgd']


def _mean_max_diff(frame, kappa):
    """Return the mean ``max_diff`` over the rows of *frame* whose κ equals *kappa*."""
    return frame.loc[frame['kappa'] == kappa, 'max_diff'].mean()


_SUMMARY_KAPPAS = (10, 100, 1000)

summary = {
    'experiment': 'E3',
    'timestamp': datetime.now().isoformat(),
    'config': CONFIG,
    'ng_consistency': {
        f'kappa_{k}': float(_mean_max_diff(ng_data, k)) for k in _SUMMARY_KAPPAS
    },
    'sgd_degradation': {
        f'kappa_{k}': float(_mean_max_diff(sgd_data, k)) for k in _SUMMARY_KAPPAS
    },
    'ratio_at_1000': float(_mean_max_diff(sgd_data, 1000) / _mean_max_diff(ng_data, 1000)),
}

with open(f'{SAVE_DIR}/E3_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f'\n✓ All files saved to {SAVE_DIR}')
for artifact in ('E3_results.csv', 'E3_summary.json', 'E3_figure.png', 'E3_figure.pdf'):
    print(f'  - {artifact}')
print(f'\nSGD/NG invariance ratio at κ=1000: {summary["ratio_at_1000"]:.1f}×')
print('(Higher = NG more invariant)')