
# Fisher vs True ΔH Diagnostics

This notebook compares the ground-truth entropy change (ΔH) with its linear and linear+quadratic approximations across recent runs. It also provides helpers for drilling into per-sequence contributions once interesting run/η pairs have been identified.


In [None]:

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

plt.style.use('seaborn-v0_8')

# Run folders are resolved relative to the parent of the working directory;
# every `run_*` folder containing a results.json counts as one run.
RESULTS_ROOT = Path('..')
result_paths = sorted(
    RESULTS_ROOT.glob('run_*/results.json')
)
print(f"Found {len(result_paths)} runs")


In [None]:

def load_run(path: Path):
    """Parse a run's ``results.json`` and return the decoded JSON object.

    The file is decoded as UTF-8 explicitly. JSON text is UTF-8 by
    specification, while ``open()`` without ``encoding`` falls back to the
    platform locale encoding (e.g. cp1252 on Windows), which would mis-decode
    any non-ASCII content.
    """
    with path.open(encoding='utf-8') as f:
        return json.load(f)

# Map each run folder name (e.g. "run_01") to its parsed results.json payload.
runs_data = dict(
    (results_path.parent.name, load_run(results_path)) for results_path in result_paths
)
print(f"Loaded runs: {list(runs_data)}")


In [None]:

# One row per (run, eta): the measured entropy change next to the first-order
# (g·v · eta) and second-order (g·v · eta + 0.5 · vHvv · eta²) predictions.
summary_rows = []
for run_id, data in runs_data.items():
    approx = data['approx']
    # Per-learning-rate slope; multiplied by eta below to get the linear prediction.
    gdotv_linear = float(approx['delta_h_per_lr'])
    quad = approx.get('quadratic') or {}
    # The quadratic block may carry its own gdotv; fall back to the linear slope
    # when it is absent or null (float(None) would otherwise raise).
    quad_gdotv = quad.get('gdotv')
    gdotv_quad = gdotv_linear if quad_gdotv is None else float(quad_gdotv)
    vhvv = quad.get('vHvv')

    for entry in data['true']['entries']:
        eta = float(entry['eta'])
        delta_true = float(entry['delta_h_true'])
        delta_linear = gdotv_linear * eta
        # Second-order prediction is only available when vHvv was recorded.
        delta_linquad = None
        if vhvv is not None:
            delta_linquad = gdotv_quad * eta + 0.5 * float(vhvv) * (eta ** 2)

        summary_rows.append({
            'run': run_id,
            'eta': eta,
            'delta_true': delta_true,
            'delta_linear': delta_linear,
            'delta_linquad': delta_linquad,
        })

# Pin the column list so an empty result set (no runs found) still produces a
# frame with the expected columns instead of a KeyError from sort_values.
summary_df = (
    pd.DataFrame(
        summary_rows,
        columns=['run', 'eta', 'delta_true', 'delta_linear', 'delta_linquad'],
    )
    .sort_values(['run', 'eta'])
    .reset_index(drop=True)
)
summary_df


In [None]:

def plot_loglog_comparison(df: pd.DataFrame) -> None:
    """Plot |ΔH_true| against the linear and lin+quad approximations, one log-log figure per run.

    Args:
        df: Summary frame with columns ``run``, ``eta``, ``delta_true``,
            ``delta_linear`` and ``delta_linquad``; ``delta_linquad`` may be
            None/NaN when no quadratic term was recorded for a run.

    Only points where both the true and the approximate magnitude are strictly
    positive (and not NaN) are drawn, since log axes cannot show zero. Each
    point is annotated with its η, and a dashed y = x line marks perfect
    agreement. Runs with nothing drawable are reported with a print instead.
    """
    for run_id, group in df.groupby('run'):
        etas = group['eta'].to_numpy(dtype=float)
        true_vals = np.abs(group['delta_true'].to_numpy(dtype=float))
        linear_vals = np.abs(group['delta_linear'].to_numpy(dtype=float))
        # astype(float) turns None placeholders into NaN.
        linquad_vals = np.abs(group['delta_linquad'].astype(float).to_numpy())

        mask_linear = (true_vals > 0) & (linear_vals > 0)
        mask_linquad = (true_vals > 0) & ~np.isnan(linquad_vals) & (linquad_vals > 0)

        if not (mask_linear.any() or mask_linquad.any()):
            print(f"Run {run_id} has no positive entries to plot.")
            continue

        # (label, marker, |approx| values, mask, annotation offset, extra annotate kwargs)
        series = [
            ('Linear', 'o', linear_vals, mask_linear, (4, 4), {}),
            ('Lin+Quad', 'x', linquad_vals, mask_linquad, (4, -10), {'color': 'tab:orange'}),
        ]

        _, ax = plt.subplots(figsize=(6, 6))
        bounds = []
        for label, marker, approx_vals, mask, offset, text_kwargs in series:
            if not mask.any():
                continue
            xs, ys = true_vals[mask], approx_vals[mask]
            ax.loglog(xs, ys, marker, label=label)
            for x, y, eta in zip(xs, ys, etas[mask]):
                ax.annotate(f"η={eta:.2e}", (x, y), textcoords='offset points',
                            xytext=offset, fontsize=8, **text_kwargs)
            bounds.extend([xs, ys])

        # Shared square range padded on both sides so edge points are not clipped.
        bounds = np.concatenate(bounds)
        lim_min = bounds.min() * 0.8
        lim_max = bounds.max() * 1.2
        ax.plot([lim_min, lim_max], [lim_min, lim_max], '--', color='gray', label='y = x')
        ax.set_xlim(lim_min, lim_max)
        ax.set_ylim(lim_min, lim_max)

        ax.set_title(f"{run_id}: |ΔH_true| vs approximations")
        ax.set_xlabel("|ΔH_true|")
        ax.set_ylabel("|ΔH_approx|")
        ax.legend()
        ax.grid(True, which='both', ls=':', alpha=0.5)
        plt.show()

# Draw one log-log comparison figure per run from the summary table.
plot_loglog_comparison(df=summary_df)


In [None]:

def _select_eta_entry(data: dict, eta: float) -> dict:
    for entry in data['true']['entries']:
        if abs(float(entry['eta']) - eta) <= max(1e-12, 1e-6 * eta):
            return entry
    raise ValueError(f"η={eta} not found in run")


def get_per_sequence_contributions(run_id: str, eta: float):
    """Return per-sequence ΔH contributions for one run at one step size.

    Args:
        run_id: Key into ``runs_data`` (e.g. ``'run_01'``).
        eta: Step size; resolved to a recorded entry via ``_select_eta_entry``.

    Returns:
        ``(true_contrib, linear_seq, linquad_seq)`` as float arrays.
        ``true_contrib`` is derived from the η entry's ``diagnostics``;
        ``linear_seq`` is ``gdotv * eta`` per sequence and ``linquad_seq`` adds
        ``0.5 * eta**2 * vhvv``. ``linquad_seq`` is ``None`` when every
        per-sequence ``vhvv`` is missing or exactly zero.
    """
    eta = float(eta)
    data = runs_data[run_id]
    eta_entry = _select_eta_entry(data, eta)
    diag = eta_entry['diagnostics']

    # True contribution per sequence, aligned with SNIS per-token averaging:
    # rescale the normalized weights back by their recorded sum (1.0 if absent).
    weights_sum = float(diag['weights_sum']) if diag.get('weights_sum') is not None else 1.0
    w = np.asarray(diag['normalized_weights'], dtype=float) * weights_sum
    token_counts = np.asarray(diag['token_counts'], dtype=float)
    h_new = np.asarray(diag['h_new'], dtype=float)
    base_entropy = float(diag['base_entropy'])

    # Guard both denominators against empty or all-zero inputs, which would
    # otherwise produce inf/NaN contributions.
    total_tokens = float(token_counts.sum())
    if total_tokens == 0.0:
        total_tokens = 1.0
    denom = float(np.dot(token_counts, w))
    if denom == 0.0:
        denom = 1.0

    true_contrib = (w * h_new) / denom - (token_counts / total_tokens) * base_entropy

    # NOTE(review): assumes the diagnostics arrays and approx['per_sequence']
    # list sequences in the same order — confirm against the results writer.
    approx_seq = data['approx']['per_sequence']
    linear_seq = np.asarray([float(rec['gdotv']) * eta for rec in approx_seq], dtype=float)
    vhvv_vals = np.asarray([rec.get('vhvv', 0.0) for rec in approx_seq], dtype=float)

    # Exact-zero test: np.allclose(..., 0.0) would use atol=1e-8 and discard
    # genuinely small (but non-zero) curvature values.
    if not np.any(vhvv_vals):
        linquad_seq = None
    else:
        linquad_seq = linear_seq + 0.5 * (eta ** 2) * vhvv_vals

    return true_contrib, linear_seq, linquad_seq


def plot_delta_histograms(run_id: str, eta: float, bins: int = 60) -> None:
    """Histogram per-sequence ΔH contributions: true, linear, and lin+quad when available.

    Args:
        run_id: Key into ``runs_data`` (e.g. ``'run_01'``).
        eta: Step size, matched against the run's recorded η values.
        bins: Number of histogram bins per panel.

    Each panel's x-range is symmetric about zero so the sign balance of the
    contributions is easy to read.
    """
    true_seq, linear_seq, linquad_seq = get_per_sequence_contributions(run_id, eta)

    datasets = [("ΔH true contrib", true_seq), ("ΔH linear", linear_seq)]
    if linquad_seq is not None:
        datasets.append(("ΔH lin+quad", linquad_seq))

    num_plots = len(datasets)
    # squeeze=False always returns a 2-D grid, so flattening works for any panel count.
    _, axes = plt.subplots(1, num_plots, figsize=(5 * num_plots, 4), sharey=True, squeeze=False)
    axes = axes.ravel()

    for ax, (label, values) in zip(axes, datasets):
        values = np.asarray(values, dtype=float)
        ax.hist(values, bins=bins, color='tab:blue', alpha=0.8)
        ax.axvline(0.0, color='black', linestyle='--', linewidth=1)
        ax.set_title(f"{label} (η={float(eta):.2e})")
        ax.set_xlabel("Contribution per sequence")
        ax.grid(True, ls=':', alpha=0.5)
        # Size the symmetric range from finite values only: np.max raises on an
        # empty array and NaN/inf limits make set_xlim fail.
        finite_abs = np.abs(values[np.isfinite(values)])
        max_abs = finite_abs.max() if finite_abs.size else 0.0
        if max_abs == 0:
            max_abs = 1e-12
        padding = 0.05 * max_abs
        ax.set_xlim(-max_abs - padding, max_abs + padding)

    axes[0].set_ylabel("Count")
    plt.suptitle(f"Per-sequence contributions for {run_id}")
    plt.tight_layout()
    plt.show()


In [None]:

def plot_quadratic_effects(run_id: str, eta: float, bins: int = 60) -> None:
    """Show the size of the second-order term per sequence and how it compares to the first-order term.

    Left panel: histogram of ``0.5 * eta**2 * vhvv`` per sequence (missing
    ``vhvv`` counts as 0). Right panel: quadratic/linear ratio against
    |linear| on a log x-axis; sequences with a zero linear term are skipped.

    Args:
        run_id: Key into ``runs_data``.
        eta: Step size used to scale ``gdotv`` and ``vhvv``.
        bins: Number of bins for the quadratic-term histogram.
    """
    step = float(eta)
    records = runs_data[run_id]['approx']['per_sequence']
    half_step_sq = 0.5 * (step ** 2)

    linear = np.array([rec['gdotv'] * step for rec in records], dtype=float)
    quad_term = np.array([half_step_sq * rec.get('vhvv', 0.0) for rec in records], dtype=float)

    _, (hist_ax, ratio_ax) = plt.subplots(1, 2, figsize=(10, 4))

    hist_ax.hist(quad_term, bins=bins, color='tab:purple', alpha=0.85)
    hist_ax.axvline(0.0, color='black', linestyle='--', linewidth=1)
    hist_ax.set_title(f"Quadratic term (η={step:.2e})")
    hist_ax.set_xlabel("0.5 η² · vHvv per sequence")
    hist_ax.set_ylabel("Count")
    hist_ax.grid(True, ls=':', alpha=0.5)

    # The ratio is only meaningful (and log-plottable) where the linear term is non-zero.
    magnitude = np.abs(linear)
    keep = magnitude > 0
    ratio_ax.scatter(
        magnitude[keep],
        quad_term[keep] / linear[keep],
        s=12,
        alpha=0.6,
        color='tab:orange',
    )
    ratio_ax.set_xscale('log')
    ratio_ax.set_title("Quadratic / Linear vs |Linear|")
    ratio_ax.set_xlabel("|Linear contribution|")
    ratio_ax.set_ylabel("Quadratic / Linear")
    ratio_ax.grid(True, which='both', ls=':', alpha=0.5)

    plt.suptitle(f"Quadratic corrections for {run_id}")
    plt.tight_layout()
    plt.show()


def plot_per_sequence_errors(run_id: str, eta: float, bins: int = 60) -> None:
    """Histogram per-sequence approximation errors ``true - linear`` (and ``true - lin+quad``).

    Args:
        run_id: Key into ``runs_data`` (e.g. ``'run_01'``).
        eta: Step size, matched against the run's recorded η values.
        bins: Number of histogram bins per panel.

    The lin+quad panel is only drawn when quadratic data is available. Each
    panel's x-range is symmetric about zero.
    """
    true_seq, linear_seq, linquad_seq = get_per_sequence_contributions(run_id, eta)

    datasets = [("True - Linear", true_seq - linear_seq)]
    if linquad_seq is not None:
        datasets.append(("True - Lin+Quad", true_seq - linquad_seq))

    num_plots = len(datasets)
    # squeeze=False always returns a 2-D grid, so flattening works for any panel count.
    _, axes = plt.subplots(1, num_plots, figsize=(5 * num_plots, 4), sharey=True, squeeze=False)
    axes = axes.ravel()

    for ax, (label, values) in zip(axes, datasets):
        values = np.asarray(values, dtype=float)
        ax.hist(values, bins=bins, color='tab:green', alpha=0.8)
        ax.axvline(0.0, color='black', linestyle='--', linewidth=1)
        ax.set_title(f"{label} (η={float(eta):.2e})")
        ax.set_xlabel("Error per sequence")
        ax.grid(True, ls=':', alpha=0.5)
        # Size the symmetric range from finite values only: np.max raises on an
        # empty array and NaN/inf limits make set_xlim fail.
        finite_abs = np.abs(values[np.isfinite(values)])
        max_abs = finite_abs.max() if finite_abs.size else 0.0
        if max_abs == 0:
            max_abs = 1e-12
        padding = 0.05 * max_abs
        ax.set_xlim(-max_abs - padding, max_abs + padding)

    axes[0].set_ylabel("Count")
    plt.suptitle(f"Approximation errors for {run_id}")
    plt.tight_layout()
    plt.show()


In [None]:

def plot_error_correlations(run_id: str, eta: float) -> None:
    """Plot five diagnostics of how well the linear prediction tracks the true per-sequence ΔH.

    Panels (left to right):
      0. true contribution vs. relative error ``1 - linear/true``;
      1. histogram of ``linear/true`` over its full range;
      2. the same histogram restricted to ``|linear/true| < 5``;
      3. ``linear/true`` vs. ``|true|`` on a log x-axis;
      4. ``linear/true`` vs. ``quad/linear`` with ``quad = 0.5 * eta**2 * vhvv``.

    Sequences whose true contribution is exactly zero are excluded from every
    panel that uses ``linear/true``; panel 4 additionally excludes sequences
    whose linear term is zero. Missing ``vhvv`` values count as 0.
    """
    eta_val = float(eta)
    data = runs_data[run_id]
    true_seq, linear_seq, _ = get_per_sequence_contributions(run_id, eta)

    quad_term = np.asarray([
        0.5 * (eta_val ** 2) * rec.get('vhvv', 0.0)
        for rec in data['approx']['per_sequence']
    ], dtype=float)

    # linear/true is only defined where true != 0; the remaining entries hold a
    # 0 placeholder and must be masked out wherever the ratio is plotted.
    mask_nonzero = np.abs(true_seq) > 0
    ratio_linear = np.zeros_like(linear_seq)
    ratio_linear[mask_nonzero] = linear_seq[mask_nonzero] / true_seq[mask_nonzero]
    rel_error = np.zeros_like(linear_seq)
    rel_error[mask_nonzero] = 1.0 - ratio_linear[mask_nonzero]

    _, axes = plt.subplots(1, 5, figsize=(22, 4))

    axes[0].scatter(true_seq[mask_nonzero], rel_error[mask_nonzero], s=12, alpha=0.6, color='tab:blue')
    axes[0].axhline(0.0, color='black', linestyle='--', linewidth=1)
    axes[0].axvline(0.0, color='black', linestyle='--', linewidth=1)
    axes[0].set_title(f"True vs 1 - (linear/true) (η={eta_val:.2e})")
    axes[0].set_xlabel("ΔH true contrib")
    axes[0].set_ylabel("1 - (linear / true)")
    axes[0].grid(True, ls=':', alpha=0.5)

    axes[1].hist(ratio_linear[mask_nonzero], bins=80, color='tab:orange', alpha=0.8)
    axes[1].axvline(1.0, color='black', linestyle='--', linewidth=1)
    axes[1].set_title("linear / true (full range)")
    axes[1].set_xlabel("linear / true")
    axes[1].set_ylabel("Count")
    axes[1].grid(True, ls=':', alpha=0.5)

    clipped_mask = mask_nonzero & (np.abs(ratio_linear) < 5.0)
    axes[2].hist(ratio_linear[clipped_mask], bins=80, color='tab:red', alpha=0.8)
    axes[2].axvline(1.0, color='black', linestyle='--', linewidth=1)
    axes[2].set_title("linear / true (|ratio| < 5)")
    axes[2].set_xlabel("linear / true")
    axes[2].set_ylabel("Count")
    axes[2].grid(True, ls=':', alpha=0.5)

    axes[3].scatter(np.abs(true_seq[mask_nonzero]), ratio_linear[mask_nonzero], s=12, alpha=0.6, color='tab:green')
    axes[3].set_xscale('log')
    axes[3].axhline(1.0, color='black', linestyle='--', linewidth=1)
    axes[3].set_title("linear / true vs |true|")
    axes[3].set_xlabel("|ΔH true contrib|")
    axes[3].set_ylabel("linear / true")
    axes[3].grid(True, which='both', ls=':', alpha=0.5)

    # Both ratios must be defined here: linear != 0 for quad/linear AND true != 0
    # for linear/true; otherwise true == 0 sequences would sit at the 0 placeholder.
    mask_both = mask_nonzero & (np.abs(linear_seq) > 0)
    ratio_quad = np.zeros_like(quad_term)
    ratio_quad[mask_both] = quad_term[mask_both] / linear_seq[mask_both]
    axes[4].scatter(ratio_linear[mask_both], ratio_quad[mask_both], s=12, alpha=0.6, color='tab:purple')
    axes[4].axhline(0.0, color='black', linestyle='--', linewidth=1)
    axes[4].axvline(1.0, color='black', linestyle='--', linewidth=1)
    axes[4].set_title("(linear/true) vs (quad/linear)")
    axes[4].set_xlabel("linear / true")
    axes[4].set_ylabel("quad / linear")
    axes[4].grid(True, ls=':', alpha=0.5)

    plt.suptitle(f"Linear approximation diagnostics for {run_id}")
    plt.tight_layout()
    plt.show()



### Usage

After identifying interesting run/η combinations, drill into their per-sequence behavior with:

```python
plot_delta_histograms('run_01', 3.2e-06)
plot_quadratic_effects('run_01', 3.2e-06)
plot_per_sequence_errors('run_01', 3.2e-06)
plot_error_correlations('run_01', 3.2e-06)
```
