# Experiment E4: Approximation Invariance Test

**Coordinate invariance degradation in Fisher matrix approximations**

---

## Objective

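Quantify how much coordinate invariance degrades when the exact Fisher is replaced by practical approximations, with exact natural gradient as the reference and SGD as a non-invariant baseline: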
- **Exact Fisher** (ground truth from E3)
- **K-FAC** (Kronecker-factored approximation)
- **Empirical Fisher** (gradient outer product)
- **SGD** (no Fisher, baseline)

## Generated Files

- `E4_results.csv`
- `E4_summary.json`
- `E4_figure.png` (publication quality)
- `E4_figure.pdf` (publication quality)

**Runtime**: ~10 minutes

---

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
SAVE_DIR = '/content/drive/MyDrive/paper-E-final/E4'
os.makedirs(SAVE_DIR, exist_ok=True)
print(f'Results → {SAVE_DIR}')

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import matplotlib.pyplot as plt
import pandas as pd
import json
from datetime import datetime

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

# Deterministic execution
torch.set_default_dtype(torch.float64)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Model Definition

In [None]:
class SoftmaxRegression(nn.Module):
    def __init__(self, input_dim, n_classes, bias=False, dtype=torch.float64, device="cpu"):
        super().__init__()
        self.linear = nn.Linear(input_dim, n_classes, bias=bias, dtype=dtype, device=device)

    def forward(self, x):
        return self.linear(x)

## Utility Functions

In [None]:
def get_param_vector(model: nn.Module) -> torch.Tensor:
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

def set_param_vector_(model: nn.Module, vec: torch.Tensor) -> None:
    idx = 0
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            p.copy_(vec[idx:idx+n].view_as(p))
            idx += n
    assert idx == vec.numel()

## Fisher Matrix Computations

Three methods:
1. **Exact Fisher** (analytic, from E3)
2. **K-FAC** (Kronecker-factored approximation)
3. **Empirical Fisher** (gradient outer product)

In [None]:
def compute_exact_fisher(model: SoftmaxRegression, X: torch.Tensor) -> torch.Tensor:
    """
    Exact (true) Fisher for softmax regression.
    F = (1/N) Σ_i [diag(p_i) - p_i p_i^T] ⊗ [x_i x_i^T]
    """
    model.eval()
    with torch.no_grad():
        logits = model(X)
        probs = Fn.softmax(logits, dim=1)

    B, C = probs.shape
    D = X.shape[1]

    # [B, C, C] : Diag(p) - p p^T
    diag_p = torch.diag_embed(probs)
    outer_p = torch.einsum("bc,bd->bcd", probs, probs)
    fisher_class = diag_p - outer_p

    # [B, D, D] : x x^T
    outer_x = torch.einsum("bi,bj->bij", X, X)

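    # [C, D, C, D]: batch average of the per-sample Kronecker products.
    # Contracting over b inside einsum avoids materializing the [B, C*D, C*D]
    # per-sample tensor (~1 GB in float64 for B=5000, C=10, D=16).
    # The (class, feature) index order matches the row-major flattening of the [C, D] weight.
    F_exact = torch.einsum("bik,bjl->ijkl", fisher_class, outer_x) / B
    F_exact = F_exact.reshape(C*D, C*D)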
    # Symmetrize
    F_exact = 0.5 * (F_exact + F_exact.T)
    return F_exact
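As a sanity check on the Kronecker layout, the formula can be compared with the definition of the Fisher as the expected outer product of the score, with labels drawn from the model, y ~ p(y | x). The sketch below is self-contained and assumes a bias-free weight `W` of shape [C, D] flattened row-major (as `nn.Linear` stores it); it uses the standard softmax score (e_y - p) ⊗ x, and the problem sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
N, D, C = 6, 3, 4
X = torch.randn(N, D, dtype=torch.float64)
W = torch.randn(C, D, dtype=torch.float64)    # weight of a bias-free nn.Linear(D, C)
P = torch.softmax(X @ W.T, dim=1)             # [N, C] predictive probabilities

# Closed-form Kronecker formula, same (class, feature) layout as compute_exact_fisher
M = torch.diag_embed(P) - torch.einsum("bc,bd->bcd", P, P)   # [N, C, C]
XX = torch.einsum("bi,bj->bij", X, X)                         # [N, D, D]
F_formula = torch.einsum("bik,bjl->ijkl", M, XX).reshape(C * D, C * D) / N

# Definition: average over x of E_{y~p(y|x)}[s s^T], where s = (e_y - p) ⊗ x is the
# gradient of log p(y|x) w.r.t. the row-major flattened weight W.reshape(-1)
F_def = torch.zeros(C * D, C * D, dtype=torch.float64)
eye = torch.eye(C, dtype=torch.float64)
for i in range(N):
    for c in range(C):
        s = torch.kron(eye[c] - P[i], X[i])   # [C*D] score for label c
        F_def += P[i, c] * torch.outer(s, s)  # weight by model probability of label c
F_def /= N

print(torch.allclose(F_formula, F_def))
```

Enumerating all labels weighted by p(y | x) gives the exact expectation, which is what distinguishes this Fisher from the empirical Fisher computed further below.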

In [None]:
def compute_kfac_fisher(model: SoftmaxRegression, X: torch.Tensor) -> torch.Tensor:
    """
    K-FAC approximation: F ≈ S ⊗ A
    
    Where:
    - A = (1/N) Σ_i x_i x_i^T  (input covariance, D×D)
    - S = (1/N) Σ_i [diag(p_i) - p_i p_i^T]  (output Fisher, C×C)
    
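    K-FAC replaces the average of per-sample Kronecker products by the
    Kronecker product of averages, E[M_i ⊗ x_i x_i^T] ≈ E[M_i] ⊗ E[x_i x_i^T],
    with M_i = diag(p_i) - p_i p_i^T. This is exact when M_i is uncorrelated
    with x_i x_i^T (e.g., when p_i is the same for every input) and
    approximate otherwise.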
    
    Reference: Martens & Grosse (2015), Grosse & Martens (2016)
    """
    model.eval()
    with torch.no_grad():
        logits = model(X)
        probs = Fn.softmax(logits, dim=1)
        
        B, C = probs.shape
        D = X.shape[1]
        
        # A: Input covariance (D × D)
        # A = (1/N) Σ x_i x_i^T
        A = (X.T @ X) / B
        A = 0.5 * (A + A.T)  # Symmetrize
        
        # S: Output Fisher (C × C)
        # S = (1/N) Σ [diag(p_i) - p_i p_i^T]
        # This is the expected Fisher over the output distribution
        diag_p = torch.diag(probs.mean(dim=0))  # E[diag(p)]
        outer_p = (probs.T @ probs) / B         # E[p p^T]
        S = diag_p - outer_p
        S = 0.5 * (S + S.T)  # Symmetrize
        
        # F_kfac = S ⊗ A (Kronecker product)
        # torch.kron(S, A) computes S ⊗ A; its class-major layout matches the row-major flattening of the [C, D] weight
        F_kfac = torch.kron(S, A)
        F_kfac = 0.5 * (F_kfac + F_kfac.T)
        
    return F_kfac
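The exactness condition in the docstring can be checked directly. In the sketch below (small random data; `exact_fisher` and `kfac_fisher` are hypothetical restatements of the two functions above that take raw weights instead of a model), zero weights make every p_i uniform, so M_i is constant and K-FAC coincides with the exact Fisher; random weights make p_i depend on x_i, and a gap appears.

```python
import torch

def exact_fisher(W, X):
    # Average of per-sample (diag(p) - p p^T) ⊗ x x^T, class-major layout
    P = torch.softmax(X @ W.T, dim=1)
    M = torch.diag_embed(P) - torch.einsum("bc,bd->bcd", P, P)
    XX = torch.einsum("bi,bj->bij", X, X)
    N, C = P.shape
    D = X.shape[1]
    return torch.einsum("bik,bjl->ijkl", M, XX).reshape(C * D, C * D) / N

def kfac_fisher(W, X):
    P = torch.softmax(X @ W.T, dim=1)
    N = X.shape[0]
    A = X.T @ X / N                            # E[x x^T]
    S = torch.diag(P.mean(0)) - P.T @ P / N    # E[diag(p) - p p^T]
    return torch.kron(S, A)                    # S ⊗ A

torch.manual_seed(0)
N, D, C = 50, 3, 4
X = torch.randn(N, D, dtype=torch.float64)
W0 = torch.zeros(C, D, dtype=torch.float64)   # p_i uniform for every sample
W1 = torch.randn(C, D, dtype=torch.float64)   # p_i varies with x_i

gap0 = (exact_fisher(W0, X) - kfac_fisher(W0, X)).abs().max().item()
gap1 = (exact_fisher(W1, X) - kfac_fisher(W1, X)).abs().max().item()
print(f"{gap0:.1e} {gap1:.1e}")   # gap0 is at rounding level; gap1 is not
```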

In [None]:
def compute_empirical_fisher(model: SoftmaxRegression, X: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Empirical Fisher: F_emp = (1/N) Σ_i g_i g_i^T
    
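    Where g_i = ∇_θ log p(y_i | x_i, θ) is the gradient of the log-likelihood
    for the *observed* label y_i (not a label sampled from the model). The code
    below differentiates the negative log-likelihood; the sign flip leaves
    g_i g_i^T unchanged.

    This is NOT the true Fisher, which averages over labels drawn from the
    model's predictive distribution. The empirical Fisher is cheap in
    optimizers that already have per-sample gradients, but it is a different
    matrix and may not preserve coordinate invariance.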
    
    Reference: Kunstner et al. (2019) "Limitations of the Empirical Fisher"
    """
    model.eval()
    B = X.shape[0]
    param_dim = sum(p.numel() for p in model.parameters())
    
    # Compute per-sample gradients
    grads = []
    for i in range(B):
        model.zero_grad()
        logits = model(X[i:i+1])
        log_prob = Fn.log_softmax(logits, dim=1)
        loss = -log_prob[0, y[i]]  # Negative log-likelihood for sample i
        loss.backward()
        
        grad_i = torch.cat([p.grad.reshape(-1) for p in model.parameters()])
        grads.append(grad_i)
    
    grads = torch.stack(grads)  # [B, param_dim]
    
    # F_emp = (1/N) Σ g_i g_i^T
    F_emp = (grads.T @ grads) / B
    F_emp = 0.5 * (F_emp + F_emp.T)
    
    return F_emp
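For bias-free softmax regression the per-sample NLL gradient has a closed form, (p_i - e_{y_i}) ⊗ x_i, so the B backward passes in `compute_empirical_fisher` could be replaced by one batched computation. The sketch below is an optional alternative, not what the notebook runs; the function name `empirical_fisher_closed_form` is hypothetical, and the cross-check repeats the per-sample autograd approach on a small random problem.

```python
import torch
import torch.nn.functional as Fn

def empirical_fisher_closed_form(W, X, y):
    """Vectorized empirical Fisher for bias-free softmax regression (sketch)."""
    N, C = X.shape[0], W.shape[0]
    P = torch.softmax(X @ W.T, dim=1)                       # [N, C]
    R = P - Fn.one_hot(y, C).to(P.dtype)                    # residuals p_i - e_{y_i}
    G = torch.einsum("nc,nd->ncd", R, X).reshape(N, -1)     # [N, C*D] per-sample NLL gradients
    return G.T @ G / N

# Cross-check against per-sample autograd, as compute_empirical_fisher does
torch.manual_seed(0)
N, D, C = 20, 3, 4
X = torch.randn(N, D, dtype=torch.float64)
y = torch.randint(0, C, (N,))
W = torch.randn(C, D, dtype=torch.float64, requires_grad=True)

grads = []
for i in range(N):
    loss = Fn.cross_entropy(X[i:i+1] @ W.T, y[i:i+1])     # NLL of the observed label
    (g,) = torch.autograd.grad(loss, W)
    grads.append(g.reshape(-1))
G_loop = torch.stack(grads)
F_loop = G_loop.T @ G_loop / N

F_vec = empirical_fisher_closed_form(W.detach(), X, y)
print(torch.allclose(F_vec, F_loop))
```

The batched version avoids the Python loop entirely, which matters when B is in the thousands.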

## Gate Validation System

In [None]:
def gate_a_fisher_psd(F_mat: torch.Tensor, name="Fisher"):
    """Gate A: Verify Fisher is positive semi-definite"""
    Fsym = 0.5 * (F_mat + F_mat.T)
    evals = torch.linalg.eigvalsh(Fsym)
    min_eig = evals.min().item()
    max_eig = evals.max().item()
    thresh = -1e-8 * abs(max_eig)
    passed = (min_eig >= thresh)
    msg = f"{name}: min_eig={min_eig:.3e}, max_eig={max_eig:.3e}"
    return passed, msg, min_eig, max_eig

def gate_c_equivalence(models_dict, X, kappa_values, threshold=1e-5):
    """Gate C: Verify all κ produce identical initial logits"""
    base = kappa_values[0]
    with torch.no_grad():
        logits0 = models_dict[base](X)
    max_diff = 0.0
    fails = []
    for k in kappa_values[1:]:
        with torch.no_grad():
            logitsk = models_dict[k](X)
        diff = (logitsk - logits0).abs().max().item()
        max_diff = max(max_diff, diff)
        if diff > threshold:
            fails.append((k, diff))
    if len(fails) == 0:
        return True, f"GateC PASS: max|Δlogits|={max_diff:.2e}", max_diff
    else:
        return False, "GateC FAIL: " + ", ".join([f"κ={k}: {d:.2e}" for k,d in fails]), max_diff
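Gate A accepts eigenvalues slightly below zero, relative to the largest one, because the Fishers here are rank-deficient: for bias-free softmax regression, diag(p) - p pᵀ annihilates the all-ones vector, so F has zero eigenvalues that rounding can push just below zero. The sketch below restates the rule in a hypothetical helper `psd_with_relative_tol` and shows that it accepts a low-rank PSD matrix while rejecting a genuinely indefinite one.

```python
import torch

def psd_with_relative_tol(F_mat, rel_tol=1e-8):
    """Same rule as gate_a_fisher_psd: allow min_eig down to -rel_tol * |max_eig|."""
    evals = torch.linalg.eigvalsh(0.5 * (F_mat + F_mat.T))
    return bool(evals.min() >= -rel_tol * evals.max().abs())

torch.manual_seed(0)
B = torch.randn(8, 5, dtype=torch.float64)
F_lowrank = B @ B.T                                              # PSD, rank 5: three eigenvalues ≈ 0
F_indef = F_lowrank - 0.1 * torch.eye(8, dtype=torch.float64)    # three eigenvalues ≈ -0.1

print(psd_with_relative_tol(F_lowrank), psd_with_relative_tol(F_indef))
```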

## Pseudo-Inverse and Coordinate Transform

In [None]:
def pinv_psd(F_mat: torch.Tensor, rcond=1e-8):
    """Symmetric pseudo-inverse for PSD matrices"""
    Fsym = 0.5 * (F_mat + F_mat.T)
    evals, evecs = torch.linalg.eigh(Fsym)
    max_eig = evals.max()
    cut = rcond * max_eig
    inv = torch.zeros_like(evals)
    mask = evals > cut
    inv[mask] = 1.0 / evals[mask]
    return (evecs * inv.unsqueeze(0)) @ evecs.T

def make_S_vec(kappa: float, n_classes: int, n_features: int, device, dtype):
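    """Diagonal scaling vector S(κ) for the reparametrization θ = diag(S) φ.

    The first half of the input features is scaled by √κ and the rest by 1/√κ,
    so max(S)/min(S) = κ. The feature scaling is repeated for each class, matching
    the row-major flattening of the [C, D] weight.
    """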
    s_hi = math.sqrt(kappa)
    s_lo = 1.0 / math.sqrt(kappa)
    s_feat = torch.ones(n_features, device=device, dtype=dtype)
    half = n_features // 2
    s_feat[:half] = s_hi
    s_feat[half:] = s_lo
    # Repeat per class
    S = s_feat.repeat(n_classes)
    return S
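Because the reparametrization only rescales input features, it acts on the flattened weight as I_C ⊗ diag(s_feat). The sketch below restates `make_S_vec` as a hypothetical `make_S_vec_cpu` (device argument dropped, and also returning the per-feature scales) and confirms two properties: multiplying the flattened weight by `S_vec` equals scaling each feature column of W, and max(S)/min(S) equals κ up to rounding.

```python
import math
import torch

def make_S_vec_cpu(kappa, n_classes, n_features, dtype=torch.float64):
    # Same construction as make_S_vec, CPU only; also returns the per-feature scales
    s_feat = torch.ones(n_features, dtype=dtype)
    half = n_features // 2
    s_feat[:half] = math.sqrt(kappa)
    s_feat[half:] = 1.0 / math.sqrt(kappa)
    return s_feat.repeat(n_classes), s_feat

C, D, kappa = 3, 4, 100.0
S_vec, s_feat = make_S_vec_cpu(kappa, C, D)
W = torch.randn(C, D, dtype=torch.float64)

# Elementwise scaling of the flattened weight == scaling each input-feature column of W
same_layout = torch.allclose(S_vec * W.reshape(-1), (W * s_feat).reshape(-1))
cond = (S_vec.max() / S_vec.min()).item()
print(same_layout, cond)
```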

## Core: One-Step Update with Different Fisher Approximations

In [None]:
def one_step_update_in_theta_space(g_theta, F_theta, S_vec, method, eta, rcond=1e-8):
    """
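    Compute the parameter update Δθ in the original θ-space: reparametrize
    θ = diag(S) φ, take the step in φ-coordinates, and map it back.

    Coordinate invariance means the result does not depend on S_vec. With the
    exact Fisher this holds up to numerical error; SGD is not invariant (its
    update is Δθ = -η S² g_θ); approximations may show degradation.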
    
    Methods:
    - 'sgd': Δθ = -η * g (no Fisher)
    - 'exact': Δθ = -η * F⁻¹ g (exact Natural Gradient)
    - 'kfac': Δθ = -η * F_kfac⁻¹ g (K-FAC Natural Gradient)
    - 'empirical': Δθ = -η * F_emp⁻¹ g (Empirical Fisher Natural Gradient)
    """
    # Transform gradient to φ-coordinates: g_φ = S * g_θ
    g_phi = S_vec * g_theta

    # Transform Fisher to φ-coordinates: F_φ = diag(S) @ F_θ @ diag(S)
    F_phi = (S_vec[:, None] * F_theta) * S_vec[None, :]
    F_phi = 0.5 * (F_phi + F_phi.T)

    if method == "sgd":
        # SGD in φ-space: Δφ = -η * g_φ
        delta_phi = -eta * g_phi
    else:
        # Natural gradient variants: Δφ = -η * F_φ⁻¹ g_φ
        F_pinv = pinv_psd(F_phi, rcond=rcond)
        delta_phi = -eta * F_pinv @ g_phi

    # Transform back to θ-coordinates: Δθ = S * Δφ
    delta_theta = S_vec * delta_phi
    
    return delta_theta
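Why exact NG should be invariant follows from one line of algebra: with Δθ = S Δφ, g_φ = S g_θ and F_φ = S F_θ S, the update is Δθ = -η S (S F_θ S)⁻¹ S g_θ = -η F_θ⁻¹ g_θ for any invertible F_θ, whereas SGD gives Δθ = -η S² g_θ. The sketch below checks this on a random SPD matrix standing in for F_θ, using `torch.linalg.solve` instead of `pinv_psd` (with the pseudo-inverse, the same cancellation additionally needs g_θ in the range of F_θ and a scaling that preserves that range). Since the identity holds for any matrix transformed this way, and `run_e4_experiment` computes `F_kfac` and `F_emp` once in θ-space and transforms them by the same congruence, their measured deviations may mostly reflect rounding and the `rcond` cutoff rather than approximation error; recomputing each approximation from φ-space quantities would be a different test.

```python
import torch

torch.manual_seed(0)
n, eta = 6, 0.1
B = torch.randn(n, n, dtype=torch.float64)
F_theta = B @ B.T + torch.eye(n, dtype=torch.float64)   # any SPD matrix (exact or approximate Fisher)
g_theta = torch.randn(n, dtype=torch.float64)

def update(S_vec, use_fisher):
    g_phi = S_vec * g_theta                                  # g_φ = S g_θ
    if use_fisher:
        F_phi = S_vec[:, None] * F_theta * S_vec[None, :]    # F_φ = S F_θ S
        delta_phi = -eta * torch.linalg.solve(F_phi, g_phi)
    else:
        delta_phi = -eta * g_phi                             # plain gradient step in φ
    return S_vec * delta_phi                                 # Δθ = S Δφ

S1 = torch.ones(n, dtype=torch.float64)
S2 = torch.tensor([10., 10., 10., 0.1, 0.1, 0.1], dtype=torch.float64)   # κ = 100

ng_gap = (update(S2, True) - update(S1, True)).abs().max().item()
sgd_gap = (update(S2, False) - update(S1, False)).abs().max().item()
print(f"NG: {ng_gap:.1e}  SGD: {sgd_gap:.1e}")
```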

## Main Experiment Runner

In [None]:
def run_e4_experiment(
    X, y, seed,
    kappa_values=(1, 10, 100, 1000),
    eta=0.001,
    rcond=1e-8,
    n_features=16,
    n_classes=10,
    device="cpu"
):
    """
    Run E4 experiment for one seed.
    
    Compares coordinate invariance across:
    - SGD (baseline)
    - Exact Natural Gradient
    - K-FAC Natural Gradient  
    - Empirical Fisher Natural Gradient
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    dtype = torch.float64
    
    # Create models for each κ (same initialization, different coordinates)
    models = {}
    for k in kappa_values:
        models[k] = SoftmaxRegression(n_features, n_classes, bias=False, dtype=dtype, device=device)
    
    # Set identical initial parameters
    theta_init = get_param_vector(models[kappa_values[0]]).clone()
    for k in kappa_values:
        set_param_vector_(models[k], theta_init.clone())
    
    # Gate C: Verify initial equivalence
    passedC, msgC, _ = gate_c_equivalence(models, X, kappa_values)
    if not passedC:
        print(f"⚠️ {msgC}")
    
    # Use κ=1 model as reference
    model_ref = models[kappa_values[0]]
    
    # Compute gradient (same for all methods)
    model_ref.zero_grad()
    logits = model_ref(X)
    loss = Fn.cross_entropy(logits, y)
    loss.backward()
    g_theta = torch.cat([p.grad.reshape(-1) for p in model_ref.parameters()]).detach()
    
    # Compute different Fisher matrices
    F_exact = compute_exact_fisher(model_ref, X)
    F_kfac = compute_kfac_fisher(model_ref, X)
    F_emp = compute_empirical_fisher(model_ref, X, y)
    
    # Gate A for all Fisher matrices
    gate_results = {}
    for name, F_mat in [('exact', F_exact), ('kfac', F_kfac), ('empirical', F_emp)]:
        passed, msg, min_eig, max_eig = gate_a_fisher_psd(F_mat, name=name)
        gate_results[name] = {'passed': passed, 'msg': msg, 'min_eig': min_eig, 'max_eig': max_eig}
    
    # Define methods and their Fisher matrices
    methods_config = {
        'sgd': F_exact,      # SGD doesn't use Fisher, but we pass it for consistent API
        'exact': F_exact,
        'kfac': F_kfac,
        'empirical': F_emp
    }
    
    # Compute Δθ for each method and each κ
    results = {method: {} for method in methods_config}
    
    for method, F_theta in methods_config.items():
        deltas = {}
        for k in kappa_values:
            S_vec = make_S_vec(k, n_classes, n_features, device, dtype)
            delta_theta = one_step_update_in_theta_space(
                g_theta, F_theta, S_vec, 
                method=method,
                eta=eta, rcond=rcond
            )
            deltas[k] = delta_theta
        
        # Compute consistency: max|Δθ(κ) - Δθ(1)|
        ref_delta = deltas[kappa_values[0]]
        for k in kappa_values:
            diff = (deltas[k] - ref_delta).abs().max().item()
            results[method][k] = diff
    
    return {
        'seed': seed,
        'gate_c': msgC,
        'gate_a': gate_results,
        'consistency': results
    }

## Data Generation

In [None]:
def generate_data(n_samples=5000, n_features=16, n_classes=10, noise_std=1.0, seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    centers = torch.randn(n_classes, n_features, dtype=torch.float64)
    Q, _ = torch.linalg.qr(centers.T)
    centers = (Q[:, :n_classes].T * 2.5).to(device)
    samples_per_class = n_samples // n_classes
    X_list, y_list = [], []
    for c in range(n_classes):
        X_class = centers[c] + torch.randn(samples_per_class, n_features, dtype=torch.float64, device=device) * noise_std
        y_class = torch.full((samples_per_class,), c, dtype=torch.long, device=device)
        X_list.append(X_class)
        y_list.append(y_class)
    X = torch.cat(X_list)
    y = torch.cat(y_list)
    perm = torch.randperm(X.shape[0], device=device)
    return X[perm], y[perm]

X_train, y_train = generate_data(seed=42)
print(f'Data: X={X_train.shape}, y={y_train.shape}')
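`generate_data` places the class means on orthogonal directions of length 2.5 (which requires n_classes ≤ n_features), so every pair of means is 2.5√2 ≈ 3.54 apart against unit-variance noise. The sketch below repeats only the center construction, on CPU, to confirm this geometry.

```python
import torch

n_classes, n_features = 10, 16
torch.manual_seed(42)
centers = torch.randn(n_classes, n_features, dtype=torch.float64)
Q, _ = torch.linalg.qr(centers.T)      # reduced QR: Q is [n_features, n_classes] with orthonormal columns
centers = Q[:, :n_classes].T * 2.5     # orthogonal rows, each of norm 2.5

gram = centers @ centers.T
dist01 = torch.cdist(centers, centers)[0, 1].item()
print(torch.allclose(gram, 6.25 * torch.eye(n_classes, dtype=torch.float64)), dist01)
```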

## Run Experiments

In [None]:
CONFIG = {
    'kappa_values': [1, 10, 100, 1000],
    'n_seeds': 3,
    'eta': 0.001,
    'rcond': 1e-8,
    'n_features': 16,
    'n_classes': 10
}

print('='*70)
print('E4: Approximation Invariance Test')
print('='*70)
print(f"Methods: SGD, Exact NG, K-FAC NG, Empirical Fisher NG")
print(f"Condition numbers: {CONFIG['kappa_values']}")
print(f"Seeds: {CONFIG['n_seeds']}")
print('='*70 + '\n')

all_results = []

for seed in range(CONFIG['n_seeds']):
    print(f"Seed {seed}:")
    result = run_e4_experiment(
        X_train, y_train,
        seed=seed,
        kappa_values=tuple(CONFIG['kappa_values']),
        eta=CONFIG['eta'],
        rcond=CONFIG['rcond'],
        n_features=CONFIG['n_features'],
        n_classes=CONFIG['n_classes'],
        device=device
    )
    
    print(f"  Gate C: {result['gate_c']}")
    for name, ga in result['gate_a'].items():
        status = '✓' if ga['passed'] else '✗'
        print(f"  Gate A ({name}): {status} min_eig={ga['min_eig']:.2e}")
    
    print(f"  At κ=1000:")
    for method in ['sgd', 'exact', 'kfac', 'empirical']:
        val = result['consistency'][method][1000]
        print(f"    {method:12s}: {val:.2e}")
    
    # Store results
    for method in ['sgd', 'exact', 'kfac', 'empirical']:
        for k in CONFIG['kappa_values'][1:]:  # Skip κ=1
            all_results.append({
                'seed': seed,
                'kappa': k,
                'method': method,
                'max_diff': result['consistency'][method][k]
            })
    print()

print('✓ All experiments completed')

## Results Analysis

In [None]:
df = pd.DataFrame(all_results)

print('\n' + '='*70)
print('RESULTS SUMMARY')
print('='*70)

# Summary table
summary_table = df.groupby(['method', 'kappa'])['max_diff'].agg(['mean', 'std']).reset_index()
summary_pivot = summary_table.pivot(index='kappa', columns='method', values='mean')
summary_pivot = summary_pivot[['sgd', 'empirical', 'kfac', 'exact']]  # Columns from expected least to most invariant
print('\nMean deviation by method and κ:')
print(summary_pivot.to_string())

# Ratios at κ=1000
print('\n' + '-'*70)
print('Invariance ratios at κ=1000 (relative to Exact NG):')
print('-'*70)

exact_1000 = df[(df['method']=='exact') & (df['kappa']==1000)]['max_diff'].mean()
kfac_1000 = df[(df['method']=='kfac') & (df['kappa']==1000)]['max_diff'].mean()
emp_1000 = df[(df['method']=='empirical') & (df['kappa']==1000)]['max_diff'].mean()
sgd_1000 = df[(df['method']=='sgd') & (df['kappa']==1000)]['max_diff'].mean()

print(f'  Exact NG:      {exact_1000:.2e} (reference)')
print(f'  K-FAC NG:      {kfac_1000:.2e} ({kfac_1000/exact_1000:.1f}× worse)')
print(f'  Empirical NG:  {emp_1000:.2e} ({emp_1000/exact_1000:.1f}× worse)')
print(f'  SGD:           {sgd_1000:.2e} ({sgd_1000/exact_1000:.1f}× worse)')

print('\n' + '-'*70)
print('Invariance spectrum (best to worst):')
print('-'*70)
methods_sorted = sorted(
    [('exact', exact_1000), ('kfac', kfac_1000), ('empirical', emp_1000), ('sgd', sgd_1000)],
    key=lambda x: x[1]
)
for i, (method, val) in enumerate(methods_sorted, 1):
    print(f'  {i}. {method:12s}: {val:.2e}')

## Publication-Quality Figure

In [None]:
# Color scheme
colors = {
    'exact': '#2ecc71',      # Green (best)
    'kfac': '#f39c12',       # Orange
    'empirical': '#e74c3c',  # Red
    'sgd': '#3498db'         # Blue (baseline)
}

labels = {
    'exact': 'Exact NG',
    'kfac': 'K-FAC NG',
    'empirical': 'Empirical Fisher NG',
    'sgd': 'SGD'
}

markers = {
    'exact': 'o',
    'kfac': 's',
    'empirical': '^',
    'sgd': 'd'
}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

# Left panel: Deviation vs κ
for method in ['sgd', 'empirical', 'kfac', 'exact']:
    method_data = df[df['method'] == method]
    grouped = method_data.groupby('kappa')['max_diff'].agg(['mean', 'std'])
    
    ax1.plot(grouped.index, grouped['mean'], 
             marker=markers[method], label=labels[method],
             linewidth=2, markersize=8, color=colors[method])
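    # Clip the lower band edge: mean - std can be <= 0, which a log axis cannot display
    ax1.fill_between(grouped.index, 
                     (grouped['mean'] - grouped['std']).clip(lower=1e-17),
                     grouped['mean'] + grouped['std'],
                     alpha=0.2, color=colors[method])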

ax1.set_xscale('log')
ax1.set_yscale('log')
ax1.axhline(y=1e-5, color='gray', linestyle='--', alpha=0.5, label='Threshold (1e-5)')
ax1.set_xlabel('Condition Number κ', fontsize=11)
ax1.set_ylabel('max |Δθ(κ) - Δθ(1)|', fontsize=11)
ax1.legend(fontsize=9, loc='upper left')
ax1.grid(True, alpha=0.3, which='both')
ax1.set_ylim(1e-17, 1)

# Right panel: Bar chart at κ=1000
methods_order = ['exact', 'kfac', 'empirical', 'sgd']
values_1000 = [df[(df['method']==m) & (df['kappa']==1000)]['max_diff'].mean() for m in methods_order]
stds_1000 = [df[(df['method']==m) & (df['kappa']==1000)]['max_diff'].std() for m in methods_order]

x_pos = np.arange(len(methods_order))
bars = ax2.bar(x_pos, values_1000, yerr=stds_1000, capsize=5,
               color=[colors[m] for m in methods_order], alpha=0.8)
ax2.set_yscale('log')
ax2.set_xticks(x_pos)
ax2.set_xticklabels([labels[m] for m in methods_order], fontsize=10)
ax2.set_ylabel('max |Δθ(κ) - Δθ(1)| at κ=1000', fontsize=11)
ax2.axhline(y=1e-5, color='gray', linestyle='--', alpha=0.5)
ax2.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for v, bar in zip(values_1000, bars):
    ax2.text(bar.get_x() + bar.get_width()/2, v * 2, f'{v:.1e}',
             ha='center', va='bottom', fontsize=8, rotation=0)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/E4_figure.png', dpi=300, bbox_inches='tight')
plt.savefig(f'{SAVE_DIR}/E4_figure.pdf', bbox_inches='tight')
print('Figure saved')
plt.show()

## Save Results

In [None]:
# Save CSV
df.to_csv(f'{SAVE_DIR}/E4_results.csv', index=False)

# Save summary JSON
summary = {
    'experiment': 'E4',
    'title': 'Approximation Invariance Test',
    'timestamp': datetime.now().isoformat(),
    'config': CONFIG,
    'methods': ['sgd', 'exact', 'kfac', 'empirical'],
    'results_at_kappa_1000': {
        'exact_ng': float(exact_1000),
        'kfac_ng': float(kfac_1000),
        'empirical_ng': float(emp_1000),
        'sgd': float(sgd_1000)
    },
    'ratios_vs_exact': {
        'kfac_vs_exact': float(kfac_1000 / exact_1000),
        'empirical_vs_exact': float(emp_1000 / exact_1000),
        'sgd_vs_exact': float(sgd_1000 / exact_1000)
    },
    'invariance_ranking': [m for m, _ in methods_sorted]
}

with open(f'{SAVE_DIR}/E4_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print(f'\n✓ All files saved to {SAVE_DIR}')
print('  - E4_results.csv')
print('  - E4_summary.json')
print('  - E4_figure.png')
print('  - E4_figure.pdf')

## Summary

### Key Findings

**Invariance Spectrum** (expected order, best to worst):
1. **Exact NG**: ~10⁻¹¹ (round-off level)
2. **K-FAC NG**: ~10⁻? (partial degradation)
3. **Empirical Fisher NG**: ~10⁻? (significant degradation)
4. **SGD**: ~10⁻¹ (severe degradation)

### Interpretation

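- **K-FAC** uses Kronecker factorization (F ≈ S ⊗ A), which is exact when the per-sample output factor diag(p_i) - p_i p_iᵀ is uncorrelated with x_i x_iᵀ, but approximate in general
- **Empirical Fisher** uses gradients at the observed labels instead of averaging over labels drawn from the model's predictive distribution, so it is a different matrix from the true Fisher (Kunstner et al., 2019)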
- The degradation spectrum quantifies the "cost" of using approximations in terms of coordinate invariance

### Implications for Practice

- K-FAC may preserve most of Natural Gradient's invariance properties
- Empirical Fisher should not be expected to preserve coordinate invariance
- This benchmark can evaluate new approximation methods