
# README Figures Helper

Utilities to generate polished figures for the README showing:
1. Linear behavior of $\Delta H$ across learning rates (using E=512 runs).
2. Agreement statistics between true and linear estimates for subsampled batches (from E=2048 synthetic subsets).


In [None]:

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Global matplotlib appearance applied to every figure produced below.
plt.style.use('seaborn-v0_8')
plt.rcParams.update({'font.size': 12})

# Experiment directories, resolved relative to the current working directory.
ROOT = Path('..')
ROOT_512, ROOT_2048, ROOT_2048_LIN, ROOT_256 = (
    ROOT / subdir for subdir in ('E_512', 'E_2048', 'data_for_lin_trend', 'E_256')
)

# One results.json per run_* directory; sorting keeps the run order deterministic.
paths_512, paths_2048, paths_2048_lin, paths_256 = (
    sorted(root.glob('run_*/results.json'))
    for root in (ROOT_512, ROOT_2048, ROOT_2048_LIN, ROOT_256)
)
print(f"Loaded {len(paths_512)} E=512 runs, {len(paths_2048)} E=2048 runs, {len(paths_2048_lin)} E=2048 (lin-trend) runs, {len(paths_256)} E=256 runs")


In [None]:

def load_runs(paths):
    """Parse a collection of JSON result files into a dict keyed by run name.

    Args:
        paths: Iterable of ``pathlib.Path`` objects, each pointing at a JSON
            file.

    Returns:
        dict mapping ``path.parent.name`` (the directory containing the file)
        to the parsed JSON payload, in the iteration order of ``paths``. When
        two paths share the same parent directory name, the later one
        overwrites the earlier one.
    """
    runs = {}
    for path in paths:
        # JSON is UTF-8 by specification; do not rely on the platform locale.
        with path.open(encoding='utf-8') as f:
            runs[path.parent.name] = json.load(f)
    return runs

# Parse every discovered results.json, one dict of runs per experiment family.
runs_512, runs_2048, runs_2048_lin, runs_256 = (
    load_runs(found) for found in (paths_512, paths_2048, paths_2048_lin, paths_256)
)
print('E=512 run IDs:', list(runs_512))
print('E=2048 (lin-trend) run IDs:', list(runs_2048_lin))


In [None]:

def summary_from_run(data: dict) -> pd.DataFrame:
    """Tabulate true and approximated Delta H for every learning rate of a run.

    Args:
        data: Parsed run payload. Reads ``data['approx']['delta_h_per_lr']``,
            the optional ``data['approx']['quadratic']`` mapping (keys
            ``gdotv`` and ``vHvv``) and ``data['true']['entries']``, where each
            entry provides ``eta`` and ``delta_h_true``.

    Returns:
        DataFrame sorted by ``eta`` with float columns ``eta``, ``delta_true``,
        ``delta_linear`` (``delta_h_per_lr * eta``) and ``delta_linquad``
        (``gdotv * eta + 0.5 * vHvv * eta**2``). ``delta_linquad`` is NaN when
        the run has no ``vHvv``. The four columns are present even when the
        run has no entries.
    """
    approx = data['approx']
    gdotv_linear = float(approx['delta_h_per_lr'])
    quad = approx.get('quadratic') or {}
    # Fall back to the linear coefficient when the quadratic fit has no gdotv.
    gdotv_quad = float(quad.get('gdotv', gdotv_linear))
    vHvv = quad.get('vHvv')
    curvature = None if vHvv is None else float(vHvv)

    rows = []
    for entry in sorted(data['true']['entries'], key=lambda e: float(e['eta'])):
        eta = float(entry['eta'])
        if curvature is None:
            # NaN (not None) keeps the column float64 so numeric aggregation
            # and plotting treat "missing" uniformly across runs.
            delta_linquad = float('nan')
        else:
            delta_linquad = gdotv_quad * eta + 0.5 * curvature * (eta ** 2)
        rows.append({
            'eta': eta,
            'delta_true': float(entry['delta_h_true']),
            'delta_linear': gdotv_linear * eta,
            'delta_linquad': delta_linquad,
        })

    # Explicit columns so a run without entries still yields the full schema.
    columns = ['eta', 'delta_true', 'delta_linear', 'delta_linquad']
    return pd.DataFrame(rows, columns=columns).astype(float)


In [None]:

def per_sequence_contributions(data: dict, eta: float):
    """Return per-sequence true and linear Delta H contributions at one eta.

    Args:
        data: Parsed run payload. Reads ``data['true']['entries']`` (each with
            ``eta`` and a ``diagnostics`` mapping) and
            ``data['approx']['per_sequence']`` (records with ``gdotv``).
        eta: Learning rate to look up. An entry matches when its ``eta``
            differs by at most ``max(1e-12, 1e-6 * abs(eta))``.

    Returns:
        Tuple ``(true_seq, linear_seq)`` of 1-D float arrays where
        ``true_seq = weights * h_new / dot(token_counts, weights)
        - token_counts / sum(token_counts) * base_entropy`` and
        ``linear_seq[i] = gdotv_i * eta``. The two arrays are not checked for
        equal length.

    Raises:
        ValueError: If no entry matches ``eta``.
    """
    eta = float(eta)
    tolerance = max(1e-12, 1e-6 * abs(eta))
    entry = next(
        (e for e in data['true']['entries'] if abs(float(e['eta']) - eta) <= tolerance),
        None,
    )
    if entry is None:
        raise ValueError(f"eta={eta} not found in run")
    diag = entry['diagnostics']

    # Undo the normalization: stored weights are scaled by 1 / weights_sum.
    weights_sum = float(diag.get('weights_sum', 1.0))
    norm_w = np.asarray(diag['normalized_weights'], dtype=float)
    weights = norm_w * weights_sum
    token_counts = np.asarray(diag['token_counts'], dtype=float)
    h_new = np.asarray(diag['h_new'], dtype=float)
    base_entropy = float(diag['base_entropy'])

    denom = np.dot(token_counts, weights) if token_counts.size else 1.0
    if denom == 0.0:
        denom = 1.0
    total_tokens = token_counts.sum() if token_counts.size else 1.0
    # Same guard as for denom: all-zero token counts would otherwise give 0/0.
    if total_tokens == 0.0:
        total_tokens = 1.0

    true_seq = (weights * h_new) / denom - (token_counts / total_tokens) * base_entropy

    approx_seq = data['approx']['per_sequence']
    linear_seq = np.asarray([float(rec['gdotv']) * eta for rec in approx_seq], dtype=float)
    return true_seq, linear_seq


In [None]:

def plot_delta_h_vs_eta_grid(include_linquad: bool = True, savepath: Path | None = None) -> None:
    """Draw one Delta H vs eta panel per run in ``runs_512`` on a 2-column grid.

    Each panel shows the true values and the linear estimate from
    ``summary_from_run``; the linear+quadratic estimate is added when
    ``include_linquad`` is true and the run has at least one non-missing value.

    Args:
        include_linquad: Whether to draw the linear+quadratic curve.
        savepath: If given, parent directories are created and the figure is
            saved there at 200 dpi before being shown.

    Raises:
        ValueError: If ``runs_512`` is empty.
    """
    run_items = list(runs_512.items())
    if not run_items:
        # plt.subplots(0, 2) would fail with a far less helpful message.
        raise ValueError('runs_512 is empty; there is nothing to plot')
    cols = 2
    rows = (len(run_items) + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4), sharex=True, sharey=True)
    axes = np.atleast_1d(axes).flatten()

    # Collect legend entries from every panel: the first run alone may lack
    # the Lin+Quad curve even though later runs have it.
    legend_entries = {}
    for ax, (run_id, data) in zip(axes, run_items):
        df = summary_from_run(data)
        ax.plot(df['eta'], df['delta_true'], marker='o', label='True')
        ax.plot(df['eta'], df['delta_linear'], marker='s', linestyle='--', label='Linear')
        if include_linquad and df['delta_linquad'].notna().any():
            ax.plot(df['eta'], df['delta_linquad'], marker='^', linestyle=':', label='Lin+Quad')
        ax.set_xscale('log')
        ax.set_xlabel('Learning rate (eta)')
        ax.set_ylabel('Delta H')
        ax.set_title(run_id)
        ax.grid(True, ls=':', alpha=0.5)
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)

    # Hide the unused trailing panel when the number of runs is odd.
    for ax in axes[len(run_items):]:
        ax.axis('off')

    fig.legend(list(legend_entries.values()), list(legend_entries), loc='upper center', ncol=3)
    # Leave a strip at the top of the figure for the shared legend.
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    if savepath is not None:
        savepath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savepath, dpi=200)
    plt.show()


In [None]:

def plot_delta_h_vs_eta_aggregate(include_linquad: bool = True, savepath: Path | None = None) -> None:
    """Plot Delta H vs eta averaged over all runs in ``runs_512``.

    The true values are drawn as mean with standard-deviation error bars, the
    linear estimate as its mean, and the linear+quadratic estimate as its mean
    when requested and available.

    Args:
        include_linquad: Whether to draw the linear+quadratic mean curve.
        savepath: If given, parent directories are created and the figure is
            saved there at 200 dpi before being shown.
    """
    summaries = []
    for run_id, data in runs_512.items():
        df = summary_from_run(data)
        df['run'] = run_id
        summaries.append(df)
    combined = pd.concat(summaries, ignore_index=True)

    grouped = combined.groupby('eta').agg({
        'delta_true': ['mean', 'std'],
        'delta_linear': ['mean'],
        'delta_linquad': ['mean']
    })
    # Flatten the (column, statistic) MultiIndex into names like 'delta_true_mean'.
    grouped.columns = ['_'.join(filter(None, col)).strip('_') for col in grouped.columns]
    grouped = grouped.reset_index()

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.errorbar(grouped['eta'], grouped['delta_true_mean'], yerr=grouped['delta_true_std'], fmt='o-', label='True (mean ± std)')
    ax.plot(grouped['eta'], grouped['delta_linear_mean'], 's--', label='Linear (mean)')
    # The column always exists after the aggregation, so also require real
    # data; otherwise the legend would advertise a curve that is not drawn.
    has_linquad = (
        'delta_linquad_mean' in grouped
        and grouped['delta_linquad_mean'].notna().any()
    )
    if include_linquad and has_linquad:
        ax.plot(grouped['eta'], grouped['delta_linquad_mean'], '^:', label='Lin+Quad (mean)')
    ax.set_xscale('log')
    ax.set_xlabel('Learning rate (eta)')
    ax.set_ylabel('Delta H')
    ax.set_title('E=512 aggregate Delta H vs eta')
    ax.grid(True, ls=':', alpha=0.5)
    ax.legend()
    fig.tight_layout()
    if savepath is not None:
        savepath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savepath, dpi=200)
    plt.show()


In [None]:

def aggregate_summary_table(runs_dict=runs_512, include_linquad: bool = True) -> pd.DataFrame:
    """Aggregate per-run Delta H summaries into one row per learning rate.

    Args:
        runs_dict: Mapping of run id to parsed run payload.
        include_linquad: When true, report both mean and std of
            ``delta_linquad``; otherwise only its mean.

    Returns:
        DataFrame with an ``eta`` column plus flattened statistic columns such
        as ``delta_true_mean`` and ``delta_true_std``.
    """
    per_run = [
        summary_from_run(payload).assign(run=name)
        for name, payload in runs_dict.items()
    ]
    stacked = pd.concat(per_run, ignore_index=True)

    linquad_stats = ['mean', 'std'] if include_linquad else ['mean']
    spec = {
        'delta_true': ['mean', 'std'],
        'delta_linear': ['mean', 'std'],
        'delta_linquad': linquad_stats,
    }
    table = stacked.groupby('eta').agg(spec)

    # Collapse the (column, statistic) MultiIndex into single flat names.
    flat_names = []
    for levels in table.columns:
        flat_names.append('_'.join(part for part in levels if part).strip('_'))
    table.columns = flat_names
    return table.reset_index()


In [None]:

def plot_delta_h_true_vs_eta(run_id: str, runs_dict=runs_2048_lin, savepath: Path | None = None) -> dict:
    """Plot |Delta H true| against eta on log-log axes with a power-law fit.

    A straight line is fitted to ``log|delta_true|`` versus ``log(eta)``.
    Points that cannot be placed on log axes (non-finite values, ``eta <= 0``
    or ``delta_true == 0``) are excluded from both the fit and the plot.

    Args:
        run_id: Key of the run inside ``runs_dict``.
        runs_dict: Mapping of run id to parsed run payload.
        savepath: If given, parent directories are created and the figure is
            saved there at 240 dpi before being shown.

    Returns:
        dict with ``run``, the fitted ``slope`` and ``intercept`` (in natural
        log space), ``r_squared`` of the log-log fit (NaN when the fitted
        values have zero variance) and ``sign``, the sign of the mean Delta H
        over the fitted points (1.0 when that mean is zero).

    Raises:
        ValueError: If fewer than two points are usable for the fit.
    """
    df = summary_from_run(runs_dict[run_id])
    eta_all = df['eta'].to_numpy(dtype=float)
    delta_all = df['delta_true'].to_numpy(dtype=float)

    # log(0) and log of non-finite values would poison the least-squares fit.
    usable = (
        np.isfinite(eta_all) & np.isfinite(delta_all)
        & (eta_all > 0) & (delta_all != 0)
    )
    if usable.sum() < 2:
        raise ValueError(
            f"run {run_id!r} has fewer than two points with eta > 0 and "
            "non-zero finite Delta H; cannot fit a log-log trend"
        )
    eta_vals = eta_all[usable]
    delta_true = delta_all[usable]

    # np.sign returns 0.0 for a zero mean; report that case as +1.
    sign_true = np.sign(delta_true.mean()) or 1.0
    abs_true = np.abs(delta_true)

    log_eta = np.log(eta_vals)
    log_true = np.log(abs_true)
    slope, intercept = np.polyfit(log_eta, log_true, 1)
    fit_eta = np.logspace(np.log10(eta_vals.min()), np.log10(eta_vals.max()), 200)
    fit_abs = np.exp(intercept) * fit_eta ** slope

    # Coefficient of determination computed in log space.
    ss_tot = np.sum((log_true - log_true.mean()) ** 2)
    ss_res = np.sum((log_true - (slope * log_eta + intercept)) ** 2)
    r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else np.nan

    fig, ax = plt.subplots(figsize=(5.5, 4.2))
    ax.loglog(eta_vals, abs_true, marker='o', linestyle='none', label='|DeltaH_true|', color='tab:blue')
    ax.loglog(fit_eta, fit_abs, color='black', linestyle='--', label=f'Fit slope={slope:.2f}')
    ax.set_xlabel('Learning rate (eta)')
    ax.set_ylabel('|Delta H true|')
    ax.set_title(f'{run_id} (sign={int(sign_true)})')
    ax.grid(True, which='both', ls=':', alpha=0.4)
    for spine in ['top', 'right']:
        ax.spines[spine].set_visible(False)
    ax.spines['bottom'].set_linewidth(1.2)
    ax.spines['left'].set_linewidth(1.2)
    ax.tick_params(axis='both', which='both', direction='out', width=1.1)
    ax.legend()
    ax.text(0.05, 0.05, f'R^2={r_squared:.3f}', transform=ax.transAxes, fontsize=11,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.75))
    fig.tight_layout()
    if savepath is not None:
        savepath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savepath, dpi=240)
    plt.show()

    return {
        'run': run_id,
        'slope': slope,
        'intercept': intercept,
        'r_squared': r_squared,
        'sign': sign_true,
    }



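### Example Usage: true Delta H fits and subset error distribution

`plot_delta_h_true_aggregate` and `plot_delta_h_true_linear_overlay` are defined in the cells below this one; run those cells before this example.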

```python
# 1) Delta H true vs eta (per run)
fit_stats = plot_delta_h_true_vs_eta('run_01', savepath=Path('figures/delta_h_true_vs_eta_run01.png'))
fit_stats

# All runs (log-log overlay)
aggregate_stats = plot_delta_h_true_aggregate(savepath=Path('figures/delta_h_true_vs_eta_aggregate_loglog.png'))
aggregate_stats.to_csv('figures/delta_h_true_vs_eta_fit_stats_loglog.csv', index=False)

# All runs (linear overlay)
lin_stats = plot_delta_h_true_linear_overlay(savepath=Path('figures/delta_h_true_vs_eta_aggregate_linear.png'))
lin_stats.to_csv('figures/delta_h_true_vs_eta_fit_stats_linear.csv', index=False)

# 2) Synthetic subset error distribution
stats_eta = plot_err_rel_distribution(eta=3.2e-6, subset_size=256, samples_per_run=500, seed=42,
                                      savepath=Path('figures/err_rel_distribution_eta_3.2e-6.png'))
stats_eta.to_csv('figures/err_rel_distribution_eta_3.2e-6.csv', index=False)
```


In [None]:

def plot_delta_h_true_aggregate(runs_dict=runs_2048_lin, savepath: Path | None = None) -> pd.DataFrame:
    """Overlay |Delta H true| vs eta for every run on one log-log plot.

    For each run a straight line is fitted to ``log|delta_true|`` versus
    ``log(eta)`` and drawn in the run's color. Points that cannot be placed on
    log axes (non-finite values, ``eta <= 0`` or ``delta_true == 0``) are
    excluded. A run with fewer than two usable points is still listed in the
    result, with NaN statistics and no fit line.

    Args:
        runs_dict: Mapping of run id to parsed run payload.
        savepath: If given, parent directories are created and the figure is
            saved there at 240 dpi before being shown.

    Returns:
        DataFrame with one row per run and columns ``run``, ``slope``,
        ``intercept`` (natural-log space) and ``r_squared``.
    """
    fig, ax = plt.subplots(figsize=(6.5, 4.3))
    stats = []
    colors = plt.cm.tab10.colors
    for idx, (run_id, data) in enumerate(runs_dict.items()):
        df = summary_from_run(data)
        eta_all = df['eta'].to_numpy(dtype=float)
        delta_all = df['delta_true'].to_numpy(dtype=float)
        # log(0) and log of non-finite values would poison the fit.
        usable = (
            np.isfinite(eta_all) & np.isfinite(delta_all)
            & (eta_all > 0) & (delta_all != 0)
        )
        eta_vals = eta_all[usable]
        abs_true = np.abs(delta_all[usable])
        color = colors[idx % len(colors)]  # cycle when there are more runs than colors
        if eta_vals.size:
            ax.loglog(eta_vals, abs_true, marker='o', linestyle='none', color=color, alpha=0.6)
        if eta_vals.size < 2:
            # Not enough data for a line; keep the run visible in the table
            # instead of aborting the whole figure.
            stats.append({'run': run_id, 'slope': np.nan, 'intercept': np.nan, 'r_squared': np.nan})
            continue
        log_eta = np.log(eta_vals)
        log_true = np.log(abs_true)
        slope, intercept = np.polyfit(log_eta, log_true, 1)
        fit_eta = np.logspace(np.log10(eta_vals.min()), np.log10(eta_vals.max()), 200)
        fit_abs = np.exp(intercept) * fit_eta ** slope
        # Coefficient of determination computed in log space.
        ss_tot = np.sum((log_true - log_true.mean()) ** 2)
        ss_res = np.sum((log_true - (slope * log_eta + intercept)) ** 2)
        r_squared = 1.0 - ss_res / ss_tot if ss_tot > 0 else np.nan
        stats.append({'run': run_id, 'slope': slope, 'intercept': intercept, 'r_squared': r_squared})
        ax.loglog(fit_eta, fit_abs, color=color, alpha=0.4)
    ax.set_xlabel('Learning rate (eta)')
    ax.set_ylabel('|Delta H true|')
    ax.set_title('E=2048 Delta H true vs eta (per run)')
    ax.grid(True, which='both', ls=':', alpha=0.4)
    for spine in ['top', 'right']:
        ax.spines[spine].set_visible(False)
    ax.spines['bottom'].set_linewidth(1.2)
    ax.spines['left'].set_linewidth(1.2)
    ax.tick_params(axis='both', which='both', direction='out', width=1.1)
    fig.tight_layout()
    if savepath is not None:
        savepath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savepath, dpi=240)
    plt.show()
    return pd.DataFrame(stats)


In [None]:

def plot_delta_h_true_linear_overlay(runs_dict=runs_2048_lin, savepath: Path | None = None) -> pd.DataFrame:
    """Overlay Delta H true vs eta for every run on linear axes with line fits.

    Each run gets a degree-1 least-squares fit of ``delta_true`` against
    ``eta``, drawn dashed in the run's color on top of its data points.

    Args:
        runs_dict: Mapping of run id to parsed run payload.
        savepath: If given, parent directories are created and the figure is
            saved there at 240 dpi before being shown.

    Returns:
        DataFrame with one row per run and columns ``run``, ``slope_linear``,
        ``intercept_linear`` and ``r_squared_linear`` (NaN when the true
        values have zero variance).
    """
    fig, ax = plt.subplots(figsize=(6.5, 4.3))
    palette = plt.cm.tab10.colors
    records = []
    for position, (run_id, payload) in enumerate(runs_dict.items()):
        summary = summary_from_run(payload)
        x = summary['eta'].to_numpy()
        y = summary['delta_true'].to_numpy()

        line_coeffs = np.polyfit(x, y, 1)
        line = np.poly1d(line_coeffs)
        fitted = line(x)

        # Coefficient of determination of the straight-line fit.
        total_var = np.sum((y - y.mean()) ** 2)
        residual_var = np.sum((y - fitted) ** 2)
        if total_var > 0:
            r_squared = 1.0 - residual_var / total_var
        else:
            r_squared = np.nan

        slope_lin, intercept_lin = line_coeffs
        records.append({
            'run': run_id,
            'slope_linear': slope_lin,
            'intercept_linear': intercept_lin,
            'r_squared_linear': r_squared,
        })

        shade = palette[position % len(palette)]  # wrap around the 10-color map
        ax.plot(x, y, marker='o', linestyle='none', color=shade, alpha=0.5)
        ax.plot(x, fitted, linestyle='--', color=shade, alpha=0.7)

    ax.set_xlabel('Learning rate (eta)')
    ax.set_ylabel('Delta H true (per token)')
    ax.set_title('E=2048 Delta H true vs eta (linear scale)')
    ax.grid(True, ls=':', alpha=0.4)
    for hidden in ('top', 'right'):
        ax.spines[hidden].set_visible(False)
    for emphasized in ('bottom', 'left'):
        ax.spines[emphasized].set_linewidth(1.2)
    ax.tick_params(axis='both', which='both', direction='out', width=1.1)
    fig.tight_layout()
    if savepath is not None:
        savepath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(savepath, dpi=240)
    plt.show()
    return pd.DataFrame(records)



### Example Usage: Delta H vs eta figures (E=512)

```python
plot_delta_h_vs_eta_grid(savepath=Path('figures/delta_h_vs_eta_grid.png'))
plot_delta_h_vs_eta_aggregate(savepath=Path('figures/delta_h_vs_eta_aggregate.png'))
```



### Example Usage: Delta H vs eta figures, summary table, and subset error distribution

```python
# 1) Delta H vs eta
plot_delta_h_vs_eta_grid(savepath=Path('figures/delta_h_vs_eta_grid.png'))
plot_delta_h_vs_eta_aggregate(savepath=Path('figures/delta_h_vs_eta_aggregate.png'))
aggregate_summary = aggregate_summary_table()
aggregate_summary.to_csv('figures/delta_h_vs_eta_summary.csv', index=False)

# 2) Synthetic subset error distribution
stats_eta = plot_err_rel_distribution(eta=3.2e-6, subset_size=256, samples_per_run=500, seed=42,
                                      savepath=Path('figures/err_rel_distribution_eta_3.2e-6.png'))
stats_eta.to_csv('figures/err_rel_distribution_eta_3.2e-6.csv', index=False)
```
