# Experiment E2: KL-Controlled Learning

**Validation of the per-step KL-divergence control protocol**

---

## Generated Files

This notebook generates:
- `E2_results_full.pkl`
- `E2_results_summary.csv`
- `E2_gate_validation.csv`
- `E2_metadata.json`
- `E2_figure.png` (publication quality)
- `E2_figure.pdf` (publication quality)

**Runtime**: ~10-15 minutes on GPU

---

In [None]:
# Persist every output of this notebook to Google Drive so it survives the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

import os

SAVE_DIR = '/content/drive/MyDrive/paper-E-final/E2'
os.makedirs(SAVE_DIR, exist_ok=True)  # no error if the directory is already there
print(f'Results → {SAVE_DIR}')

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import json
import pickle
from datetime import datetime
from tqdm.notebook import tqdm

def set_seed(seed):
    """Seed torch (CPU and all CUDA devices) and NumPy's global RNG.

    Also requests deterministic cuDNN algorithms and disables cuDNN benchmark
    (autotuning) mode, so repeated runs with the same seed are reproducible.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

# Global reproducibility / precision settings applied before any cell below runs.
set_seed(42)
torch.set_default_dtype(torch.float64)  # new tensors and nn parameters default to double precision

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
def gate_a_fisher_quality(fisher, name="Fisher", threshold=-1e-7):
    """Gate A: check that a Fisher matrix is positive semi-definite up to round-off.

    The matrix is symmetrised first. It fails only when its smallest eigenvalue
    is below ``threshold * |largest eigenvalue|``, i.e. the tolerance for
    negative eigenvalues is relative to the spectrum's scale.

    Returns:
        (passed, message, details), where details holds the min/max eigenvalue
        and the condition number (``inf`` when the min eigenvalue is <= 0).
    """
    spectrum = torch.linalg.eigvalsh((fisher + fisher.T) / 2)
    lo = spectrum.min().item()
    hi = spectrum.max().item()
    details = {
        "min_eigenvalue": lo,
        "max_eigenvalue": hi,
        "condition_number": hi / lo if lo > 0 else float('inf'),
    }
    is_negative = lo < threshold * abs(hi)
    if is_negative:
        return False, f"FAIL: {name} negative {lo:.2e}", details
    return True, f"PASS: {name} PSD (min={lo:.2e})", details

def gate_b_kl_calibration(kl_measured, kl_target, tolerance=0.2):
    """Gate B: check that a measured per-step KL is within ``tolerance`` of its target.

    Passes when ``kl_measured / kl_target`` lies in ``[1 - tolerance, 1 + tolerance]``.

    Returns:
        (passed, message, details) with the measured KL, target KL and their ratio.
    """
    ratio = kl_measured / kl_target
    details = dict(kl_measured=float(kl_measured), kl_target=float(kl_target), ratio=float(ratio))
    within = bool((1.0 - tolerance) <= ratio <= (1.0 + tolerance))
    verdict = "PASS" if within else "FAIL"
    return within, f"{verdict}: ratio={ratio:.4f}", details

In [None]:
class SoftmaxRegression(nn.Module):
    """Multinomial logistic regression: one bias-free linear map to class logits.

    ``forward`` returns raw logits; the softmax is applied by the loss / caller.
    """

    def __init__(self, input_dim, num_classes):
        super(SoftmaxRegression, self).__init__()
        # No bias: the only parameter is the (num_classes, input_dim) weight matrix.
        self.linear = nn.Linear(in_features=input_dim, out_features=num_classes, bias=False)

    def forward(self, x):
        logits = self.linear(x)
        return logits

In [None]:
def compute_fisher_analytic(model, X):
    """Return the exact (model) Fisher information of a softmax-regression model on X.

    With p_n = softmax(model(x_n)) and H_n = diag(p_n) - p_n p_n^T, the per-sample
    Fisher w.r.t. the flattened weight is kron(H_n, x_n x_n^T); this returns its
    average over the N rows of X, symmetrised to remove round-off asymmetry.

    Assumes the model's parameters are a single bias-free weight of shape (C, D)
    flattened row-major (class-major), matching the kron index layout — TODO
    confirm if used with any other model.

    Args:
        model: module mapping X of shape (N, D) to logits of shape (N, C).
        X: (N, D) input tensor.

    Returns:
        (C*D, C*D) symmetric tensor with the dtype/device of the model output.

    Raises:
        ValueError: if X has no rows (the average would be undefined).
    """
    if X.shape[0] == 0:
        raise ValueError("compute_fisher_analytic: X must contain at least one sample")
    with torch.no_grad():
        probs = F.softmax(model(X), dim=1)
        N, C = probs.shape
        D = X.shape[1]
        # Per-sample softmax Hessians H_n = diag(p_n) - p_n p_n^T, shape (N, C, C).
        H = torch.diag_embed(probs) - probs.unsqueeze(2) * probs.unsqueeze(1)
        # Per-sample input outer products x_n x_n^T, flattened to (N, D*D).
        XX = (X.unsqueeze(2) * X.unsqueeze(1)).reshape(N, D * D)
        # M[(a,b),(d,e)] = sum_n H_n[a,b] * x_n[d] * x_n[e]; one matmul replaces the
        # per-sample Python loop of kron products.
        M = H.reshape(N, C * C).T @ XX
        # Reorder (a,b,d,e) -> (a,d,b,e) so row = a*D + d and col = b*D + e, exactly
        # the index layout of torch.kron(H_n, x_n x_n^T).
        fisher = M.reshape(C, C, D, D).permute(0, 2, 1, 3).reshape(C * D, C * D) / N
        return (fisher + fisher.T) / 2

In [None]:
def generate_data(n_samples=5000, n_features=16, n_classes=10, noise_std=1.0, seed=42, center_seed=None):
    """Sample a shuffled Gaussian-mixture classification dataset with orthogonal class centres.

    Class centres are the first ``n_classes`` columns of Q from a QR decomposition
    of a random Gaussian (n_features, n_classes) matrix, scaled by 2.5. Each class
    gets ``n_samples // n_classes`` points (so up to n_classes - 1 samples are
    dropped when n_samples is not divisible) with isotropic noise of std ``noise_std``.

    Args:
        seed: seeds the sample noise and the final shuffle.
        center_seed: seeds the class centres; defaults to ``seed``. Splits that
            must come from the same mixture (e.g. train and test) MUST share the
            same center_seed, otherwise their classes have different centres.

    Returns:
        (X, y) on the module-level ``device``: X float64 of shape
        (n_classes * (n_samples // n_classes), n_features), y int64 labels.

    Raises:
        ValueError: if n_classes > n_features (not enough orthogonal directions).
    """
    if n_classes > n_features:
        raise ValueError(
            f"n_classes ({n_classes}) must not exceed n_features ({n_features}): "
            "class centres are orthonormal directions")
    if center_seed is None:
        center_seed = seed
    torch.manual_seed(seed)
    np.random.seed(seed)  # NumPy is not sampled here; reseeded for consistency with set_seed
    # Dedicated CPU generator so the centres depend only on center_seed, not on seed.
    center_gen = torch.Generator().manual_seed(center_seed)
    centers = torch.randn(n_classes, n_features, dtype=torch.float64, generator=center_gen)
    Q, _ = torch.linalg.qr(centers.T)
    centers = (Q[:, :n_classes].T * 2.5).to(device)
    samples_per_class = n_samples // n_classes
    X_list, y_list = [], []
    for c in range(n_classes):
        X_class = centers[c] + torch.randn(samples_per_class, n_features, dtype=torch.float64, device=device) * noise_std
        y_class = torch.full((samples_per_class,), c, dtype=torch.long, device=device)
        X_list.append(X_class)
        y_list.append(y_class)
    X = torch.cat(X_list)
    y = torch.cat(y_list)
    perm = torch.randperm(X.shape[0], device=device)
    return X[perm], y[perm]

# Both splits share center_seed=42 so they sample the SAME class mixture;
# only the noise/shuffle seed differs between train and test.
X_train, y_train = generate_data(n_samples=5000, seed=42, center_seed=42)
X_test, y_test = generate_data(n_samples=1000, seed=43, center_seed=42)
print(f'Data: train={X_train.shape}, test={X_test.shape}')

In [None]:
def get_parameters_flat(model):
    """Concatenate all of ``model``'s parameters (in ``parameters()`` order) into one 1-D tensor."""
    pieces = [param.reshape(-1) for param in model.parameters()]
    return torch.cat(pieces)

def set_parameters_flat(model, params):
    """Copy consecutive slices of the 1-D tensor ``params`` into ``model``'s parameters in place.

    Slices are taken in ``parameters()`` order and reshaped to each parameter's
    shape; values are copied, so ``params`` is not aliased by the model.
    """
    start = 0
    for param in model.parameters():
        stop = start + param.numel()
        param.data.copy_(params[start:stop].view_as(param))
        start = stop

def compute_empirical_kl(model, X, theta_old):
    """Return the mean over rows of X of KL(p_old || p_new), as a Python float.

    p_new is the model's softmax output at its current parameters; p_old is the
    output with the parameters temporarily set to the flat vector ``theta_old``.
    The current parameters are restored before returning, even if the forward
    pass at ``theta_old`` raises.

    The KL is computed from log-softmax differences, so it is exactly 0 when
    theta_old equals the current parameters and carries no bias from additive
    epsilon terms on tiny probabilities.
    """
    with torch.no_grad():
        theta_new = get_parameters_flat(model).clone()
        try:
            set_parameters_flat(model, theta_old)
            log_p_old = F.log_softmax(model(X), dim=1)
        finally:
            set_parameters_flat(model, theta_new)
        log_p_new = F.log_softmax(model(X), dim=1)
        # KL(p_old || p_new) = sum_c p_old * (log p_old - log p_new), averaged over rows.
        kl = (log_p_old.exp() * (log_p_old - log_p_new)).sum(dim=1).mean()
        return kl.item()

def evaluate(model, X, y):
    """Return ``(mean cross-entropy loss, accuracy)`` of ``model`` on ``(X, y)`` as floats.

    Runs without gradient tracking; accuracy is the fraction of rows whose
    arg-max logit equals the label.
    """
    with torch.no_grad():
        scores = model(X)
        hits = scores.argmax(dim=1).eq(y)
        return F.cross_entropy(scores, y).item(), hits.float().mean().item()

In [None]:
def train_with_kl_control(model, X_train, y_train, X_test, y_test, epsilon_step, n_epochs, method='sgd'):
    """Full-batch training where every step is sized to a target per-step KL divergence.

    Each epoch takes one step theta <- theta - eta * d on the full training set,
    with d the gradient ('sgd') or the Fisher pseudo-inverse times the gradient
    ('natgrad'; eigenvalues <= 1e-8 are treated as zero). eta is first chosen
    from the quadratic model KL ~= 0.5 * eta^2 * d^T F d so the predicted KL is
    ``epsilon_step``; the step is then applied, the true KL measured, and eta
    rescaled by sqrt(epsilon_step / KL) up to 2 times until the measured/target
    ratio lies in [0.8, 1.2].

    Note: the Fisher F is computed once at the initial parameters and reused for
    every epoch, both for the step-size prediction and the natgrad preconditioner.

    Returns:
        history: dict of per-epoch lists (losses/accuracies, measured KL, KL ratio,
            the eta actually applied, and the number of rescalings applied).
        gate_log: list of dicts; Gate A once before training (logged as epoch 0),
            Gate B every 10 epochs.

    Raises:
        ValueError: if ``method`` is not 'sgd' or 'natgrad'.
        RuntimeError: if Gate A fails, or d^T F d is not positive (step cannot be sized).
    """
    if method not in ('sgd', 'natgrad'):
        raise ValueError(f"Unknown method {method!r}; expected 'sgd' or 'natgrad'")

    history = {'epoch': [], 'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': [],
               'kl_step': [], 'kl_ratio': [], 'eta': [], 'n_corrections': []}
    gate_log = []
    max_corrections = 2  # eta rescalings allowed per epoch after the initial guess

    fisher = compute_fisher_analytic(model, X_train)
    passed_a, msg_a, details_a = gate_a_fisher_quality(fisher)
    gate_log.append({'epoch': 0, 'gate': 'A', 'passed': passed_a, **details_a})
    if not passed_a:
        raise RuntimeError(f"Gate A failed: {msg_a}")

    if method == 'natgrad':
        # Pseudo-inverse of F: eigenvalues <= 1e-8 are treated as zero instead of inverted.
        eigvals, eigvecs = torch.linalg.eigh(fisher)
        eigvals_inv = torch.where(eigvals > 1e-8, 1.0 / eigvals, torch.zeros_like(eigvals))
        fisher_inv = eigvecs @ torch.diag(eigvals_inv) @ eigvecs.T

    for epoch in range(n_epochs):
        theta_old = get_parameters_flat(model).detach().clone()
        model.zero_grad()
        loss = F.cross_entropy(model(X_train), y_train)
        loss.backward()
        grad = torch.cat([p.grad.flatten() for p in model.parameters()])

        direction = grad if method == 'sgd' else fisher_inv @ grad
        # Second-order prediction: KL(theta, theta - eta*d) ~= 0.5 * eta^2 * d^T F d.
        quad_form = torch.dot(direction, fisher @ direction)
        if not quad_form > 0:
            raise RuntimeError(
                f"Cannot size step at epoch {epoch}: d^T F d = {quad_form.item():.3e} is not positive")
        eta = torch.sqrt(2 * epsilon_step / quad_form)

        # Apply the step, measure the true KL, and rescale eta (KL is ~quadratic in
        # eta) until calibrated or the correction budget is spent. eta is only
        # rescaled when another attempt follows, so the recorded eta and
        # n_corrections always describe the step actually applied.
        n_corrections = 0
        while True:
            set_parameters_flat(model, theta_old - eta * direction)
            kl_emp = compute_empirical_kl(model, X_train, theta_old)
            kl_ratio = kl_emp / epsilon_step
            # kl_emp <= 0 would make the rescale factor undefined; stop instead of dividing by it.
            if 0.8 <= kl_ratio <= 1.2 or n_corrections >= max_corrections or kl_emp <= 0:
                break
            eta = eta * float(np.sqrt(epsilon_step / kl_emp))
            n_corrections += 1

        train_loss, train_acc = evaluate(model, X_train, y_train)
        test_loss, test_acc = evaluate(model, X_test, y_test)

        history['epoch'].append(epoch)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        history['kl_step'].append(kl_emp)
        history['kl_ratio'].append(kl_ratio)
        history['eta'].append(eta.item())
        history['n_corrections'].append(n_corrections)

        if epoch % 10 == 0:
            passed_b, msg_b, details_b = gate_b_kl_calibration(kl_emp, epsilon_step)
            gate_log.append({'epoch': epoch, 'gate': 'B', 'passed': passed_b, **details_b})

    return history, gate_log

In [None]:
CONFIG = {
    'n_features': 16,
    'n_classes': 10,
    'epsilon_values': [1e-4, 3e-4, 1e-3],
    'n_epochs': 100,
    'n_seeds': 5,
    'methods': ['sgd', 'natgrad'],
}

# Every (epsilon index, epsilon, method, seed) combination, in execution order.
run_grid = [
    (eps_idx, epsilon, method, seed)
    for eps_idx, epsilon in enumerate(CONFIG['epsilon_values'])
    for method in CONFIG['methods']
    for seed in range(CONFIG['n_seeds'])
]
print(f"Total experiments: {len(run_grid)}")

all_results = {}
all_gates = []

for eps_idx, epsilon, method, seed in run_grid:
    print(f"ε={epsilon:.0e}, {method}, seed={seed}")
    set_seed(seed)  # the seed only affects model initialisation; the data is shared
    model = SoftmaxRegression(CONFIG['n_features'], CONFIG['n_classes']).to(device)
    history, gate_log = train_with_kl_control(
        model, X_train, y_train, X_test, y_test,
        epsilon_step=epsilon, n_epochs=CONFIG['n_epochs'], method=method,
    )
    key = f"{method}_eps{eps_idx}_seed{seed}"
    all_results[key] = {'method': method, 'epsilon': epsilon, 'seed': seed, 'history': history}
    # Tag every gate record with the run it came from.
    all_gates.extend({**g, 'method': method, 'epsilon': epsilon, 'seed': seed} for g in gate_log)

print('✓ Completed')

In [None]:
gate_df = pd.DataFrame(all_gates)
gate_a = gate_df.loc[gate_df['gate'] == 'A']
gate_b = gate_df.loc[gate_df['gate'] == 'B']

# Pass rate of each gate, then the mean measured/target KL ratio over Gate B checks.
for gate_name, gate_rows in (('A', gate_a), ('B', gate_b)):
    print(f"Gate {gate_name}: {gate_rows['passed'].mean():.1%}")
print(f"Mean KL ratio: {gate_b['ratio'].mean():.4f}")

In [None]:
# Publication-quality figure (no title). One column per epsilon value:
# top row = test accuracy, bottom row = measured/target KL ratio, each shown as
# the mean over seeds with a +/-1 std band. The epsilon of each column is given
# as the legend title so panels are identifiable without a figure title.
n_cols = len(CONFIG['epsilon_values'])
fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 8), squeeze=False)

for i, epsilon in enumerate(CONFIG['epsilon_values']):
    ax_curve = axes[0, i]
    ax_kl = axes[1, i]

    for method in CONFIG['methods']:
        histories = [all_results[f"{method}_eps{i}_seed{seed}"]['history']
                     for seed in range(CONFIG['n_seeds'])]
        epochs = histories[0]['epoch']
        label = 'SGD' if method == 'sgd' else 'NG'
        for ax, metric in ((ax_curve, 'test_acc'), (ax_kl, 'kl_ratio')):
            values = np.array([h[metric] for h in histories])
            mean, std = values.mean(axis=0), values.std(axis=0)
            (line,) = ax.plot(epochs, mean, label=label, linewidth=2)
            # Use the curve's own colour so the band always matches its mean line.
            ax.fill_between(epochs, mean - std, mean + std, color=line.get_color(), alpha=0.2)

    panel_label = f'ε = {epsilon:.0e}'

    ax_curve.set_xlabel('Epoch', fontsize=10)
    if i == 0:
        ax_curve.set_ylabel('Test Accuracy', fontsize=10)
    ax_curve.legend(title=panel_label, fontsize=9)
    ax_curve.grid(True, alpha=0.3)

    # Target ratio (1.0) and the [0.8, 1.2] calibration band used by the step corrections.
    ax_kl.axhline(y=1.0, color='red', linestyle='--', alpha=0.5)
    ax_kl.axhline(y=0.8, color='gray', linestyle=':', alpha=0.3)
    ax_kl.axhline(y=1.2, color='gray', linestyle=':', alpha=0.3)
    ax_kl.set_xlabel('Epoch', fontsize=10)
    if i == 0:
        ax_kl.set_ylabel('KL Ratio', fontsize=10)
    ax_kl.legend(title=panel_label, fontsize=9)
    ax_kl.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{SAVE_DIR}/E2_figure.png', dpi=300, bbox_inches='tight')
plt.savefig(f'{SAVE_DIR}/E2_figure.pdf', bbox_inches='tight')
print('Figure saved')

In [None]:
# Raw per-run histories (the full all_results dict).
with open(f'{SAVE_DIR}/E2_results_full.pkl', 'wb') as f:
    pickle.dump(all_results, f)

# One summary row per run. The original columns keep their names for
# compatibility, but note 'final_loss' is the final TRAIN loss while
# 'final_accuracy' is the final TEST accuracy; the explicit
# final_{train,test}_{loss,accuracy} columns appended at the end remove that ambiguity.
summary_records = []
for key, data in all_results.items():
    hist = data['history']
    summary_records.append({
        'method': data['method'], 'epsilon': data['epsilon'], 'seed': data['seed'],
        'final_loss': hist['train_loss'][-1], 'final_accuracy': hist['test_acc'][-1],
        'mean_kl_ratio': np.mean(hist['kl_ratio']), 'std_kl_ratio': np.std(hist['kl_ratio']),
        'mean_corrections': np.mean(hist['n_corrections']),
        'final_train_loss': hist['train_loss'][-1],
        'final_test_loss': hist['test_loss'][-1],
        'final_train_accuracy': hist['train_acc'][-1],
        'final_test_accuracy': hist['test_acc'][-1],
    })

summary_df = pd.DataFrame(summary_records)
summary_df.to_csv(f'{SAVE_DIR}/E2_results_summary.csv', index=False)
gate_df.to_csv(f'{SAVE_DIR}/E2_gate_validation.csv', index=False)

metadata = {
    'experiment': 'E2',
    'timestamp': datetime.now().isoformat(),
    'config': CONFIG,
    'gate_summary': {
        'gate_a_pass_rate': float(gate_a['passed'].mean()),
        'gate_b_pass_rate': float(gate_b['passed'].mean()),
        'mean_kl_ratio': float(gate_b['ratio'].mean())
    },
    'performance_summary': {
        # Both means are over final TEST accuracy.
        'sgd_mean_accuracy': float(summary_df[summary_df['method']=='sgd']['final_accuracy'].mean()),
        'ng_mean_accuracy': float(summary_df[summary_df['method']=='natgrad']['final_accuracy'].mean())
    }
}

with open(f'{SAVE_DIR}/E2_metadata.json', 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

print(f'\n✓ All files saved to {SAVE_DIR}')
for filename in ('E2_results_full.pkl', 'E2_results_summary.csv', 'E2_gate_validation.csv',
                 'E2_metadata.json', 'E2_figure.png', 'E2_figure.pdf'):
    print(f'  - {filename}')