
# Post-Correction ΔH Analysis

This notebook inspects runs produced after adding the SNIS denominator correction. It summarizes aggregate agreement between true and approximate ΔH values and visualizes per-sequence errors.


In [None]:

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

plt.style.use('seaborn-v0_8')

# Run directories are looked up one level above the working directory; every run
# contributes a single `run_*/results.json` file.
RESULTS_ROOT_2048 = Path('..', 'E_2048')
RESULTS_ROOT_256 = Path('..', 'E_256')

result_paths_2048 = sorted(RESULTS_ROOT_2048.glob('run_*/results.json'))
result_paths_256 = sorted(RESULTS_ROOT_256.glob('run_*/results.json'))

print(
    f"Found {len(result_paths_2048)} E=2048 run files "
    f"and {len(result_paths_256)} E=256 run files"
)


In [None]:

def load_run(path: Path):
    """Parse one run's ``results.json`` file and return the decoded JSON object.

    The file is decoded as UTF-8 explicitly: JSON text is UTF-8 by specification,
    and relying on the platform default encoding (e.g. cp1252 on Windows) would
    garble any non-ASCII content.
    """
    with path.open(encoding='utf-8') as f:
        return json.load(f)

runs_data_2048 = {path.parent.name: load_run(path) for path in result_paths_2048}
runs_data_256 = {path.parent.name: load_run(path) for path in result_paths_256}
print("E=2048 runs:", list(runs_data_2048.keys()))
print("E=256 runs:", list(runs_data_256.keys()))

# All analysis cells below read the single mapping `runs_data` (run id -> parsed
# results). Choose which configuration to analyse here. The two sets are not merged
# because run ids are directory names (`run_*`) and may repeat across E_2048 / E_256.
ACTIVE_E = 2048
runs_data = {2048: runs_data_2048, 256: runs_data_256}[ACTIVE_E]
print(f"Analysing E={ACTIVE_E} runs:", list(runs_data))


In [None]:

# One row per (run, η): the true ΔH next to its linear and lin+quad approximations.
# The lin+quad estimate uses the quadratic block's own first-order term (`gdotv`)
# when it was recorded, falling back to the linear slope `delta_h_per_lr`.
summary_rows = []
for run_id, data in runs_data.items():
    approx = data['approx']
    gdotv_linear = float(approx['delta_h_per_lr'])
    quad = approx.get('quadratic') or {}
    gdotv_quad = float(quad.get('gdotv', gdotv_linear))
    vHvv = quad.get('vHvv')

    for entry in data['true']['entries']:
        eta = float(entry['eta'])
        delta_true = float(entry['delta_h_true'])
        delta_linear = gdotv_linear * eta
        # NaN (not None) keeps the column numeric when curvature was not captured;
        # NaN also propagates naturally through the ratio below.
        delta_linquad = (
            gdotv_quad * eta + 0.5 * float(vHvv) * eta ** 2
            if vHvv is not None
            else np.nan
        )
        summary_rows.append({
            'run': run_id,
            'eta': eta,
            'delta_true': delta_true,
            'delta_linear': delta_linear,
            'delta_linquad': delta_linquad,
            'linear_over_true': delta_linear / delta_true if delta_true != 0 else np.nan,
            'linquad_over_true': delta_linquad / delta_true if delta_true != 0 else np.nan,
        })

# Explicit columns keep the sort working even when no run files were found.
summary_df = (
    pd.DataFrame(summary_rows, columns=[
        'run', 'eta', 'delta_true', 'delta_linear', 'delta_linquad',
        'linear_over_true', 'linquad_over_true',
    ])
    .sort_values(['run', 'eta'])
    .reset_index(drop=True)
)
summary_df


In [None]:

# Quick text dump of the linear/true ratio for every (run, η) row of summary_df.
print("Linear / True ratios:")
for record in summary_df.itertuples(index=False):
    print(f"{record.run} η={record.eta:.2e}: linear/true={record.linear_over_true:.6f}")


In [None]:

def _select_eta_entry(data: dict, eta: float) -> dict:
    for entry in data['true']['entries']:
        if abs(float(entry['eta']) - eta) <= max(1e-12, 1e-6 * eta):
            return entry
    raise ValueError(f"η={eta} not found in run")


def get_per_sequence_contributions(run_id: str, eta: float):
    """Split ΔH for one run and learning rate into per-sequence contributions.

    Returns a pair of float arrays ``(true_contrib, linear_seq)``:

    - ``true_contrib[i] = w_i * h_new_i / Σ_j c_j w_j - (c_i / Σ_j c_j) * base_entropy``,
      where ``w = normalized_weights * weights_sum`` (``weights_sum`` defaults to 1)
      and ``c`` are the token counts. The terms sum to
      ``Σ w h_new / Σ c w - base_entropy``, which presumably matches the run's
      ``delta_h_true`` (TODO confirm against the producer of results.json).
    - ``linear_seq[i] = gdotv_i * eta`` from ``approx.per_sequence``.

    Both normalisers fall back to 1 when they are zero, so degenerate runs yield
    finite numbers instead of NaN/inf.

    Raises:
        KeyError: unknown *run_id*.
        ValueError: *eta* not present in the run (from ``_select_eta_entry``).
    """
    data = runs_data[run_id]
    eta = float(eta)
    diag = _select_eta_entry(data, eta)['diagnostics']

    weights_sum = diag.get('weights_sum')
    weights_sum = 1.0 if weights_sum is None else float(weights_sum)
    weights = np.asarray(diag['normalized_weights'], dtype=float) * weights_sum
    token_counts = np.asarray(diag['token_counts'], dtype=float)
    h_new = np.asarray(diag['h_new'], dtype=float)
    base_entropy = float(diag['base_entropy'])

    denom = float(np.dot(token_counts, weights)) if token_counts.size else 1.0
    if denom == 0.0:
        denom = 1.0
    total_tokens = float(token_counts.sum())
    if total_tokens == 0.0:
        total_tokens = 1.0

    true_contrib = (weights * h_new) / denom - (token_counts / total_tokens) * base_entropy

    linear_seq = np.asarray(
        [float(rec['gdotv']) * eta for rec in data['approx']['per_sequence']],
        dtype=float,
    )
    return true_contrib, linear_seq


In [None]:

def summarize_standard_errors(runs=None) -> pd.DataFrame:
    """Tabulate ΔH point estimates and their standard errors for every (run, η).

    Args:
        runs: mapping of run id -> parsed results.json dict. Defaults to the
            notebook-level ``runs_data``.

    Returns:
        One row per true-ΔH entry. ``se_linear_raw`` / ``se_gdotv_raw`` hold the
        per-learning-rate SEs as stored in the results file; ``se_linear``,
        ``se_linear_jackknife`` and ``se_gdotv`` are those SEs multiplied by η.
        Missing values appear as None/NaN.
    """
    if runs is None:
        runs = runs_data
    rows = []
    for run_id, data in runs.items():
        approx = data['approx']
        approx_var = approx.get('variance') or {}
        # `quadratic` may be present but null, so coalesce to {} before descending.
        quad_var = (approx.get('quadratic') or {}).get('variance') or {}
        se_linear_raw = approx_var.get('se_shard')
        se_jackknife_raw = approx_var.get('se_jackknife')
        se_gdotv_raw = quad_var.get('se_gdotv')
        for entry in data['true']['entries']:
            eta = float(entry['eta'])
            rows.append({
                'run': run_id,
                'eta': eta,
                'delta_true': float(entry['delta_h_true']),
                'delta_linear': float(approx['delta_h_per_lr']) * eta,
                'se_linear_raw': se_linear_raw,
                'se_linear': se_linear_raw * eta if se_linear_raw is not None else None,
                'se_linear_jackknife': se_jackknife_raw * eta if se_jackknife_raw is not None else None,
                'se_gdotv_raw': se_gdotv_raw,
                'se_gdotv': se_gdotv_raw * eta if se_gdotv_raw is not None else None,
                'se_vHvv': quad_var.get('se_vHvv'),
                'ess_true': (entry.get('diagnostics') or {}).get('ess'),
                'num_shards': approx_var.get('num_shards'),
            })
    return pd.DataFrame(rows)

# Per-(run, η) point estimates alongside their standard errors; the `se_linear*`
# and `se_gdotv` columns are the stored per-learning-rate SEs multiplied by η.
se_df = summarize_standard_errors()
se_df



### Per-Sequence Error Histogram

The helper below plots `|ΔH_true_i - ΔH_linear_i| / |ΔH_true|` for a chosen run and learning rate.


In [None]:

def plot_normalized_abs_error(run_id: str, eta: float, bins: int = 60) -> None:
    """Histogram of per-sequence |true - linear| contribution errors, scaled by |ΔH_true|.

    Args:
        run_id: key into ``runs_data``.
        eta: learning rate to inspect (must be present in the run).
        bins: number of histogram bins.

    When the run's total ΔH_true is exactly zero the errors are left unscaled.
    """
    eta = float(eta)
    true_seq, linear_seq = get_per_sequence_contributions(run_id, eta)
    delta_true = float(_select_eta_entry(runs_data[run_id], eta)['delta_h_true'])
    norm = abs(delta_true) if delta_true != 0 else 1.0

    norm_abs_error = np.abs(true_seq - linear_seq) / norm

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(norm_abs_error, bins=bins, color='tab:blue', alpha=0.8)
    ax.set_xlabel("|ΔH_true_i - ΔH_linear_i| / |ΔH_true|")
    ax.set_ylabel("Count")
    ax.set_title(f"Normalized per-sequence error: {run_id}, η={eta:.2e}")
    ax.grid(True, ls=':', alpha=0.5)
    plt.show()



#### Example Usage

```python
plot_normalized_abs_error('run_01', 8e-06)
```



### Agreement Within Error Bars

The table below computes standard errors for both the true and linear estimates (using per-sequence variability for the true value) and checks whether the two agree within one combined standard error, $\sqrt{\mathrm{SE}_\text{true}^2 + \mathrm{SE}_\text{linear}^2}$.


In [None]:

def compute_true_se(run_id: str, eta: float) -> float:
    """Naive standard error of the true ΔH derived from per-sequence variability.

    ΔH_true is treated as the *sum* of the per-sequence contributions (the
    truncation/clipping helpers in this notebook use the same convention). Under an
    i.i.d. approximation the SE of that total is ``sqrt(n) * sd(contributions)``;
    ``sd / sqrt(n)`` would be the SE of their *mean*, which is n times smaller.
    Correlation introduced by the shared SNIS denominator is ignored.

    Returns:
        0.0 when there are no sequences, NaN when there is exactly one (the sample
        standard deviation is undefined), otherwise the naive SE of the sum.
    """
    true_seq, _ = get_per_sequence_contributions(run_id, eta)
    n = true_seq.size
    if n == 0:
        return 0.0
    if n == 1:
        return float('nan')
    return float(np.std(true_seq, ddof=1) * np.sqrt(n))


def build_agreement_table() -> pd.DataFrame:
    """Check, for every (run, η), whether the linear and true ΔH agree within 1σ.

    The combined σ is ``sqrt(se_true**2 + se_linear**2)``: ``se_linear`` is the
    stored shard SE (treated as 0 when absent) scaled by η, and ``se_true`` comes
    from ``compute_true_se``. ``within_1sigma`` is NaN whenever the combined σ is
    not strictly positive.
    """
    records = []
    for run_id, run in runs_data.items():
        variance = run['approx'].get('variance') or {}
        for entry in run['true']['entries']:
            eta = float(entry['eta'])
            delta_true = float(entry['delta_h_true'])
            delta_linear = float(run['approx']['delta_h_per_lr']) * eta
            se_linear = (variance.get('se_shard') or 0.0) * eta
            se_true = compute_true_se(run_id, eta)
            diff = delta_linear - delta_true
            combined = (se_linear ** 2 + se_true ** 2) ** 0.5

            if combined > 0:
                within = abs(diff) <= combined
            else:
                within = np.nan

            records.append(dict(
                run=run_id,
                eta=eta,
                delta_true=delta_true,
                delta_linear=delta_linear,
                se_true=se_true,
                se_linear=se_linear,
                diff=diff,
                combined_se=combined,
                within_1sigma=within,
            ))
    return pd.DataFrame(records)

# One row per (run, η); `within_1sigma` flags |linear - true| <= sqrt(se_true² + se_linear²).
agreement_df = build_agreement_table()
agreement_df



### Per-Sequence Contribution Histograms

Visualize the distribution of contributions for the true estimate, the linear approximation, and the lin+quad approximation (if curvature diagnostics were captured).


In [None]:

def plot_contribution_histograms(run_id: str, eta: float, bins: int = 60) -> None:
    """Side-by-side histograms of per-sequence ΔH contributions for one run and η.

    Panels: the true contributions and the linear approximation (both from
    ``get_per_sequence_contributions``), plus lin+quad
    (``linear + 0.5 * eta**2 * vHv_i``) when ``approx.quadratic.per_sequence_vhvv``
    was recorded. Each panel uses symmetric x-limits so the zero line is centred.
    """
    eta = float(eta)
    data = runs_data[run_id]
    # Reuse the shared helper so this plot can never drift from the tables above.
    true_contrib, linear_seq = get_per_sequence_contributions(run_id, eta)

    datasets = [("ΔH true contrib", true_contrib), ("ΔH linear", linear_seq)]
    quad = data['approx'].get('quadratic')
    if quad and 'per_sequence_vhvv' in quad:
        vhvv = np.asarray(quad['per_sequence_vhvv'], dtype=float)
        datasets.append(("ΔH lin+quad", linear_seq + 0.5 * (eta ** 2) * vhvv))

    num = len(datasets)
    # squeeze=False always yields a 2-D array of axes, whatever the panel count.
    fig, axes = plt.subplots(1, num, figsize=(5 * num, 4), sharey=True, squeeze=False)
    axes = axes[0]

    for ax, (label, values) in zip(axes, datasets):
        ax.hist(values, bins=bins, color='tab:blue', alpha=0.8)
        ax.axvline(0.0, color='black', linestyle='--', linewidth=1)
        ax.set_title(f"{label} (η={eta:.2e})")
        ax.set_xlabel("Contribution per sequence")
        ax.grid(True, ls=':', alpha=0.5)
        # np.max raises on empty input, and NaN/inf limits make set_xlim fail.
        max_abs = float(np.max(np.abs(values))) if values.size else 0.0
        if not np.isfinite(max_abs) or max_abs == 0:
            max_abs = 1e-12
        padding = 0.05 * max_abs
        ax.set_xlim(-max_abs - padding, max_abs + padding)

    axes[0].set_ylabel("Count")
    fig.suptitle(f"Per-sequence contributions for {run_id}")
    fig.tight_layout()
    plt.show()



### Truncation and Clipping Experiments

The helpers below explore how removing or clipping the largest-magnitude approximate contributions affects agreement between the true and approximate ΔH estimates.


In [None]:

def summarize_truncation(run_id: str, eta: float, keep_quantile: float = 0.99) -> dict:
    """Sum per-sequence ΔH contributions after dropping the largest |linear| terms.

    Sequences whose |ΔH linear| exceeds the ``keep_quantile`` quantile of
    |ΔH linear| are excluded from every sum (true, linear, and lin+quad when
    ``per_sequence_vhvv`` is available), so all estimates cover the same subset.
    """
    true_seq, linear_seq = get_per_sequence_contributions(run_id, eta)
    quadratic = runs_data[run_id]['approx'].get('quadratic')
    has_curvature = bool(quadratic) and 'per_sequence_vhvv' in quadratic
    if has_curvature:
        vhvv = np.asarray(quadratic['per_sequence_vhvv'], dtype=float)
        linquad_seq = linear_seq + 0.5 * (eta ** 2) * vhvv

    abs_linear = np.abs(linear_seq)
    threshold = np.quantile(abs_linear, keep_quantile)
    keep = abs_linear <= threshold

    result = dict(
        run=run_id,
        eta=eta,
        keep_quantile=keep_quantile,
        threshold=threshold,
        kept_percent=keep.mean() * 100,
        delta_true_trunc=float(true_seq[keep].sum()),
        delta_linear_trunc=float(linear_seq[keep].sum()),
    )
    if has_curvature:
        result['delta_linquad_trunc'] = float(linquad_seq[keep].sum())
    return result


def summarize_clipping(run_id: str, eta: float, clip_value: float) -> dict:
    """Sum per-sequence ΔH contributions after clipping each to [-clip_value, clip_value].

    The same symmetric clip is applied to the true, linear and (when
    ``per_sequence_vhvv`` is available) lin+quad contributions.
    """
    true_seq, linear_seq = get_per_sequence_contributions(run_id, eta)
    quadratic = runs_data[run_id]['approx'].get('quadratic')
    has_curvature = bool(quadratic) and 'per_sequence_vhvv' in quadratic
    if has_curvature:
        vhvv = np.asarray(quadratic['per_sequence_vhvv'], dtype=float)
        linquad_seq = linear_seq + 0.5 * (eta ** 2) * vhvv

    low, high = -clip_value, clip_value
    result = dict(
        run=run_id,
        eta=eta,
        clip_value=clip_value,
        delta_true_clip=float(np.clip(true_seq, low, high).sum()),
        delta_linear_clip=float(np.clip(linear_seq, low, high).sum()),
    )
    if has_curvature:
        result['delta_linquad_clip'] = float(np.clip(linquad_seq, low, high).sum())
    return result


In [None]:

# Example usage:
# trunc_result = summarize_truncation('run_01', 8e-06, keep_quantile=0.95)
# clip_result = summarize_clipping('run_01', 8e-06, clip_value=5e-4)
# trunc_result, clip_result


In [None]:

def batch_truncation_summary(keep_quantiles=(0.99, 0.95, 0.90)) -> pd.DataFrame:
    """Truncation analysis for every run, every η and every keep-quantile.

    Mirrors ``summarize_truncation`` (same columns and masking rule), but the
    per-sequence contributions and the lin+quad sequence are built once per
    (run, η) rather than once per quantile.
    """
    rows = []
    for run_id, data in runs_data.items():
        quad = data['approx'].get('quadratic')
        vhvv = None
        if quad and 'per_sequence_vhvv' in quad:
            vhvv = np.asarray(quad['per_sequence_vhvv'], dtype=float)
        for entry in data['true']['entries']:
            eta = float(entry['eta'])
            true_seq, linear_seq = get_per_sequence_contributions(run_id, eta)
            abs_linear = np.abs(linear_seq)
            # Independent of q, so hoisted out of the quantile loop.
            linquad_seq = None if vhvv is None else linear_seq + 0.5 * (eta ** 2) * vhvv
            for q in keep_quantiles:
                thr = np.quantile(abs_linear, q)
                mask = abs_linear <= thr
                row = {
                    'run': run_id,
                    'eta': eta,
                    'keep_quantile': q,
                    'threshold': thr,
                    'kept_percent': mask.mean() * 100,
                    'delta_true_trunc': float(true_seq[mask].sum()),
                    'delta_linear_trunc': float(linear_seq[mask].sum()),
                }
                if linquad_seq is not None:
                    row['delta_linquad_trunc'] = float(linquad_seq[mask].sum())
                rows.append(row)
    return pd.DataFrame(rows)

# Truncation results for every run and η at the default keep-quantiles (0.99, 0.95, 0.90).
trunc_df = batch_truncation_summary()
trunc_df



### Clip Sweep Visualization

The helper below sweeps clipping thresholds (using quantiles of |ΔH linear|) and shows how the clipped totals compare between the true, linear, and lin+quad estimates.


In [None]:

def plot_clip_sweep(run_id: str, eta: float, base_quantile: float = 0.999, steps: int = 10) -> None:
    """Plot clipped ΔH totals as the clipping threshold is swept.

    Thresholds are the ``base_quantile ** k`` quantiles of |ΔH linear| for
    k = steps, ..., 0 (k = 0 is the maximum, i.e. no clipping). Each threshold is
    applied symmetrically to the true, linear and, when ``per_sequence_vhvv`` is
    available, lin+quad per-sequence contributions.
    """
    true_seq, linear_seq = get_per_sequence_contributions(run_id, eta)
    quadratic = runs_data[run_id]['approx'].get('quadratic')
    if quadratic and 'per_sequence_vhvv' in quadratic:
        vhvv = np.asarray(quadratic['per_sequence_vhvv'], dtype=float)
        linquad_seq = linear_seq + 0.5 * (eta ** 2) * vhvv
    else:
        linquad_seq = None

    abs_linear = np.abs(linear_seq)
    thresholds = [np.quantile(abs_linear, base_quantile ** k) for k in reversed(range(steps + 1))]

    def clipped_total(values, limit):
        return float(np.clip(values, -limit, limit).sum())

    clipped_true = [clipped_total(true_seq, thr) for thr in thresholds]
    clipped_linear = [clipped_total(linear_seq, thr) for thr in thresholds]
    clipped_linquad = (
        [clipped_total(linquad_seq, thr) for thr in thresholds]
        if linquad_seq is not None
        else []
    )

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(thresholds, clipped_true, label='ΔH true (clipped)', marker='o')
    ax.plot(thresholds, clipped_linear, label='ΔH linear (clipped)', marker='s')
    if clipped_linquad:
        ax.plot(thresholds, clipped_linquad, label='ΔH lin+quad (clipped)', marker='^')

    ax.set_xscale('log')
    ax.set_xlabel('Clipping threshold (|ΔH contribution|)')
    ax.set_ylabel('Clipped sum')
    ax.set_title(f'Clip sweep for {run_id}, η={eta:.2e}')
    ax.grid(True, which='both', ls=':', alpha=0.6)
    ax.legend()
    plt.show()



### Reweighted Estimates by Filtering Strategy

Compare raw ΔH estimates with versions produced by truncating or clipping the largest |ΔH linear| contributions.


In [None]:

def compare_filtered_estimates(run_id: str, eta: float, keep_quantiles=(0.99, 0.95), clip_thresholds=(5e-4, 1e-4)) -> pd.DataFrame:
    """Compare raw ΔH estimates with truncated and clipped variants for one run/η.

    Rows:
    - ``raw``: stored true ΔH, the linear estimate ``delta_h_per_lr * eta`` and the
      lin+quad estimate ``gdotv_quad * eta + 0.5 * vHvv * eta**2``. ``gdotv_quad``
      is ``quadratic.gdotv`` when recorded, otherwise ``delta_h_per_lr``; this is
      the same definition used for ``summary_df``. lin+quad is None without vHvv.
    - ``truncate@q`` for each q in *keep_quantiles* (see ``summarize_truncation``).
    - ``clip@c`` for each c in *clip_thresholds* (see ``summarize_clipping``).
    """
    eta = float(eta)
    data = runs_data[run_id]
    approx = data['approx']
    delta_true = float(_select_eta_entry(data, eta)['delta_h_true'])
    gdotv_linear = float(approx['delta_h_per_lr'])
    delta_linear = gdotv_linear * eta

    quad = approx.get('quadratic') or {}
    vHvv = quad.get('vHvv')
    delta_linquad = None
    if vHvv is not None:
        gdotv_quad = float(quad.get('gdotv', gdotv_linear))
        delta_linquad = gdotv_quad * eta + 0.5 * float(vHvv) * (eta ** 2)

    rows = [{'strategy': 'raw', 'delta_true': delta_true, 'delta_linear': delta_linear, 'delta_linquad': delta_linquad}]

    for q in keep_quantiles:
        trunc = summarize_truncation(run_id, eta, keep_quantile=q)
        rows.append({
            'strategy': f'truncate@{q}',
            'delta_true': trunc['delta_true_trunc'],
            'delta_linear': trunc['delta_linear_trunc'],
            'delta_linquad': trunc.get('delta_linquad_trunc'),
        })

    for clip in clip_thresholds:
        clipped = summarize_clipping(run_id, eta, clip)
        rows.append({
            'strategy': f'clip@{clip}',
            'delta_true': clipped['delta_true_clip'],
            'delta_linear': clipped['delta_linear_clip'],
            'delta_linquad': clipped.get('delta_linquad_clip'),
        })

    return pd.DataFrame(rows)

# Example usage:
# compare_filtered_estimates('run_01', 8e-06)



### Truncation Sweep Plot

Plot the ratio of truncated ΔH estimates (linear and lin+quad) relative to the true truncated value across a sweep of keep-quantiles.


In [None]:

def plot_truncation_ratio(run_id: str, eta: float, keep_quantiles=None) -> None:
    """Plot truncated linear (and lin+quad, when available) ΔH relative to truncated true ΔH.

    Args:
        run_id: key into ``runs_data``.
        eta: learning rate to inspect.
        keep_quantiles: quantiles passed to ``summarize_truncation``. Defaults to
            ``0.999**k`` for k = 10..1 plus 1.0 (no truncation), each listed once.

    Points are sorted by quantile so the line is drawn monotonically in x. Ratios
    whose truncated true total is exactly zero are left as gaps instead of ±inf.

    Raises:
        ValueError: if *keep_quantiles* is empty.
    """
    if keep_quantiles is None:
        keep_quantiles = [0.999 ** k for k in range(10, 0, -1)] + [1.0]
    if len(keep_quantiles) == 0:
        raise ValueError("keep_quantiles must contain at least one quantile")

    trunc = (
        pd.DataFrame([summarize_truncation(run_id, eta, keep_quantile=q) for q in keep_quantiles])
        .sort_values('keep_quantile')
        .reset_index(drop=True)
    )
    delta_true = trunc['delta_true_trunc'].replace(0.0, np.nan)

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(trunc['keep_quantile'], trunc['delta_linear_trunc'] / delta_true, marker='o', label='linear / true')
    if 'delta_linquad_trunc' in trunc:
        ax.plot(trunc['keep_quantile'], trunc['delta_linquad_trunc'] / delta_true, marker='^', label='lin+quad / true')

    ax.set_xscale('log')
    ax.set_xlabel('keep_quantile (log scale)')
    ax.set_ylabel('ratio to true')
    ax.set_title(f'Truncation sweep for {run_id}, η={eta:.2e}')
    ax.grid(True, which='both', ls=':', alpha=0.5)
    ax.legend()
    plt.show()
