
# Batch Size Comparison Diagnostics

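Compare DeltaH estimates from actual batches of size 256 (the E=256 runs) against synthetic 256-sample subsets drawn from E=2048 runs. Focus on the relative error `err_rel = (DeltaH_linear - DeltaH_true) / DeltaH_true_full` for selected eta values, where the linear and true totals are summed over the batch (or subset) and `DeltaH_true_full` is the true total taken from the E=2048 data.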


In [None]:

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Plot styling shared by every figure produced below.
plt.style.use('seaborn-v0_8')

# Result folders are resolved relative to the parent of the working directory.
ROOT = Path('..')
ROOT_2048 = ROOT / 'E_2048'
ROOT_256 = ROOT / 'E_256'

# One results.json per run_* folder, in sorted path order.
paths_2048, paths_256 = (
    sorted(base.glob('run_*/results.json')) for base in (ROOT_2048, ROOT_256)
)
print(f"Found {len(paths_2048)} E=2048 runs and {len(paths_256)} E=256 runs")


In [None]:

def load_runs(paths):
    """Load each results file and key its parsed JSON by the run folder name.

    Args:
        paths: Iterable of ``pathlib.Path`` objects pointing at JSON files.

    Returns:
        dict mapping ``path.parent.name`` to the decoded JSON content. When two
        paths share a parent folder name, the later one overwrites the earlier.
    """
    runs = {}
    for path in paths:
        # Decode as UTF-8 explicitly instead of relying on the locale default,
        # which differs between platforms.
        with path.open(encoding='utf-8') as f:
            runs[path.parent.name] = json.load(f)
    return runs

# Parse every discovered results file, then list the run ids per batch size.
runs_2048, runs_256 = load_runs(paths_2048), load_runs(paths_256)
print("E=2048 run ids:", list(runs_2048))
print("E=256 run ids:", list(runs_256))


In [None]:

# Reference true totals: gather every delta_h_true reported by the E=2048 runs,
# grouped by eta, then collapse each group to its mean across runs.
reference_true_totals = {}
for data in runs_2048.values():
    for entry in data['true']['entries']:
        key = float(entry['eta'])
        bucket = reference_true_totals.setdefault(key, [])
        bucket.append(float(entry['delta_h_true']))
reference_true_totals = {
    key: float(np.mean(bucket)) for key, bucket in reference_true_totals.items()
}
print('Reference true totals (E=2048):')
for eta, val in reference_true_totals.items():
    print(f"  eta={eta:.2e}: delta_h_true={val:.6e}")


In [None]:

def select_eta_entry(data: dict, eta: float, *, tol: float = 1e-12) -> dict:
    """Return the first record in ``data['true']['entries']`` matching ``eta``.

    A record matches when its ``'eta'`` value (converted with ``float``) lies
    within ``max(tol, tol * |eta|, 1e-6 * |eta|)`` of the requested ``eta``.

    Raises:
        ValueError: if no record matches.
    """
    hits = (
        rec
        for rec in data['true']['entries']
        if abs(float(rec['eta']) - eta) <= max(tol, tol * abs(eta), 1e-6 * abs(eta))
    )
    found = next(hits, None)
    if found is None:
        raise ValueError(f"eta={eta} not found")
    return found


def per_sequence_contributions(data: dict, eta: float):
    """Split one run's true and linearised DeltaH into per-sequence terms.

    The true term for each sequence is
    ``weight * h_new / sum(token_counts * weight)`` minus
    ``token_count / sum(token_counts) * base_entropy``, read from the
    ``'diagnostics'`` of the entry selected for ``eta``. The linear term is
    ``gdotv * eta`` for each record in ``data['approx']['per_sequence']``.

    Returns:
        ``(true_seq, linear_seq)`` as float numpy arrays.

    Raises:
        ValueError: propagated from ``select_eta_entry`` when ``eta`` is absent.
    """
    diagnostics = select_eta_entry(data, eta)['diagnostics']

    # Stored weights are normalised; rescale them by their recorded sum
    # (treated as 1.0 when the run did not store one).
    scale = float(diagnostics.get('weights_sum', 1.0))
    w = np.asarray(diagnostics['normalized_weights'], dtype=float) * scale
    n_tokens = np.asarray(diagnostics['token_counts'], dtype=float)
    entropy_new = np.asarray(diagnostics['h_new'], dtype=float)
    entropy_base = float(diagnostics['base_entropy'])

    if n_tokens.size:
        weighted_tokens = np.dot(n_tokens, w)
        token_total = n_tokens.sum()
    else:
        weighted_tokens = 1.0
        token_total = 1.0
    # A zero weighted-token sum falls back to 1.0 to avoid dividing by zero.
    # NOTE(review): token_total has no such fallback; token counts that sum
    # to zero yield NaN in the baseline term.
    if weighted_tokens == 0.0:
        weighted_tokens = 1.0

    new_term = (w * entropy_new) / weighted_tokens
    base_term = (n_tokens / token_total) * entropy_base
    true_seq = new_term - base_term

    step = float(eta)
    records = data['approx']['per_sequence']
    linear_seq = np.asarray([rec['gdotv'] * step for rec in records], dtype=float)
    return true_seq, linear_seq


def aggregate_err_rel(true_seq_subset: np.ndarray, linear_seq_subset: np.ndarray, true_total_full: float) -> float:
    """Return (sum of linear terms - sum of true terms) / ``true_total_full``.

    Yields ``np.nan`` when ``true_total_full`` equals zero, since the ratio is
    undefined in that case.
    """
    if true_total_full != 0:
        gap = linear_seq_subset.sum() - true_seq_subset.sum()
        return gap / true_total_full
    return np.nan


def aggregate_err_rel_from_run(data: dict, eta: float, true_total_full: float) -> float:
    """Compute ``err_rel`` over every sequence of one run at the given ``eta``.

    Per-sequence terms come from ``per_sequence_contributions`` and are reduced
    by ``aggregate_err_rel`` against ``true_total_full``.
    """
    contributions = per_sequence_contributions(data, eta)
    return aggregate_err_rel(*contributions, true_total_full=true_total_full)


In [None]:

def simulate_err_rel_from_2048(run_data: dict, eta: float, sample_size: int = 256, *, num_samples: int = 1000, seed: int = 0) -> np.ndarray:
    """Draw random sequence subsets from one run and return their ``err_rel``.

    Each of the ``num_samples`` draws picks ``sample_size`` distinct sequence
    indices and normalises the subset's linear-minus-true gap by the true
    total summed over all sequences of the run.

    Raises:
        ValueError: if the run holds fewer than ``sample_size`` sequences, or
            (from ``per_sequence_contributions``) if ``eta`` is not present.
    """
    rng = np.random.default_rng(seed)
    seq_true, seq_linear = per_sequence_contributions(run_data, eta)
    available = seq_true.size
    if sample_size > available:
        raise ValueError(f"Sample size {sample_size} exceeds available sequences {available}")
    full_total = seq_true.sum()

    def one_draw():
        # Sampling without replacement keeps every subset free of duplicates.
        picked = rng.choice(available, size=sample_size, replace=False)
        return aggregate_err_rel(seq_true[picked], seq_linear[picked], true_total_full=full_total)

    return np.asarray([one_draw() for _ in range(num_samples)])


In [None]:

def collect_err_rel_across_runs(runs: dict, eta: float, true_total_ref: float) -> pd.DataFrame:
    """Compute one full-run ``err_rel`` per run at the given ``eta``.

    Args:
        runs: Mapping of run id to loaded run data.
        eta: Step size to evaluate.
        true_total_ref: Denominator used to normalise every run's error.

    Returns:
        DataFrame with columns ``run``, ``eta`` and ``err_rel``. The columns
        are present even when no run contributes a row.
    """
    rows = []
    for run_id, data in runs.items():
        try:
            true_seq, linear_seq = per_sequence_contributions(data, eta)
        except ValueError:
            # Runs that raise ValueError here (e.g. eta not present) are left out.
            continue
        err = aggregate_err_rel(true_seq, linear_seq, true_total_full=true_total_ref)
        rows.append({'run': run_id, 'eta': eta, 'err_rel': err})
    # Pin the schema so an empty result still exposes the expected columns.
    return pd.DataFrame(rows, columns=['run', 'eta', 'err_rel'])


In [None]:

def compare_err_rel_distributions(eta: float, *, sample_size: int = 256, num_samples: int = 2000, seed: int = 0) -> None:
    """Plot actual E=256 ``err_rel`` values over the synthetic E=2048 histogram.

    Synthetic values come from ``simulate_err_rel_from_2048`` applied to every
    run in ``runs_2048`` (each call receives the same ``seed``); actual values
    come from ``collect_err_rel_across_runs`` on ``runs_256``, normalised by
    the reference total stored in ``reference_true_totals``. A one-row summary
    table is displayed after the figure.

    Args:
        eta: Step size to compare.
        sample_size: Number of sequences per synthetic subset.
        num_samples: Number of synthetic subsets drawn per E=2048 run.
        seed: Seed forwarded to the subset sampler.

    Raises:
        ValueError: if no reference total, no E=256 run, or no usable E=2048
            run exists for ``eta``.
    """
    if eta in reference_true_totals:
        true_total_ref = reference_true_totals[eta]
    else:
        # Fall back to the tolerance that select_eta_entry applies by default,
        # so the reference lookup and the per-run lookup accept the same eta.
        slack = max(1e-12, 1e-6 * abs(eta))
        near = [key for key in reference_true_totals if abs(key - eta) <= slack]
        if not near:
            raise ValueError(f"No reference DeltaH_true for eta={eta}")
        closest = min(near, key=lambda key: abs(key - eta))
        true_total_ref = reference_true_totals[closest]

    actual_df = collect_err_rel_across_runs(runs_256, eta, true_total_ref=true_total_ref)
    if actual_df.empty:
        raise ValueError(f"No E=256 runs found for eta={eta}")

    synthetic_errs = []
    for run_id, data in runs_2048.items():
        try:
            errs = simulate_err_rel_from_2048(data, eta, sample_size=sample_size, num_samples=num_samples, seed=seed)
        except ValueError:
            # Runs lacking this eta, or with too few sequences, are skipped.
            continue
        synthetic_errs.append(errs)
    if not synthetic_errs:
        raise ValueError(f"No E=2048 runs found for eta={eta}")
    synthetic_errs = np.concatenate(synthetic_errs)

    plt.figure(figsize=(7, 4))
    plt.hist(synthetic_errs, bins=60, alpha=0.7, label='Synthetic (from E=2048)', color='tab:blue')
    for i, row in actual_df.iterrows():
        # Label only the first line so the legend shows a single entry.
        plt.axvline(row['err_rel'], color='tab:orange', linestyle='--', linewidth=1.5,
                    label='Actual E=256' if i == actual_df.index[0] else None)

    plt.xlabel('err_rel = (DeltaH_linear - DeltaH_true_subset) / DeltaH_true_full')
    plt.ylabel('Frequency')
    plt.title(f'err_rel distribution comparison (eta={eta:.2e})')
    plt.grid(True, ls=':', alpha=0.5)
    plt.legend()
    plt.show()

    summary = {
        'eta': eta,
        'synthetic_mean': float(np.nanmean(synthetic_errs)),
        'synthetic_std': float(np.nanstd(synthetic_errs, ddof=1)),
        'synthetic_min': float(np.nanmin(synthetic_errs)),
        'synthetic_max': float(np.nanmax(synthetic_errs)),
        'actual_values': actual_df['err_rel'].tolist(),
    }
    display(pd.DataFrame([summary]))



### Example Usage

```python
compare_err_rel_distributions(eta=1.6e-6, num_samples=2000)
compare_err_rel_distributions(eta=3.2e-6, num_samples=2000)
```



### Subsampling Error Distribution

Sample random subsets (size `subset_size`) from each E=2048 run, compute `err_rel = (DeltaH_linear_subset - DeltaH_true_subset) / DeltaH_true_full` for every subset, and examine the distribution along with mean/variance summaries.


In [None]:

def sample_err_rel_across_runs(runs: dict, eta_values, *, subset_size: int = 256, samples_per_eta: int = 200, seed: int = 0) -> pd.DataFrame:
    """Sample ``err_rel`` over random sequence subsets for each eta and run.

    A single random generator seeded with ``seed`` is shared by all draws. For
    every (eta, run) pair, ``samples_per_eta`` subsets of ``subset_size``
    distinct sequences are drawn, and each subset's error is normalised by the
    run's true total over all of its sequences.

    Args:
        runs: Mapping of run id to loaded run data.
        eta_values: Iterable of step sizes to evaluate.
        subset_size: Number of sequences per subset.
        samples_per_eta: Number of subsets drawn per (eta, run) pair.
        seed: Seed for the shared random generator.

    Returns:
        DataFrame with columns ``run``, ``eta`` and ``err_rel``. The columns
        are present even when no sample could be drawn.
    """
    rng = np.random.default_rng(seed)
    rows = []
    for eta in eta_values:
        for run_id, data in runs.items():
            try:
                true_seq, linear_seq = per_sequence_contributions(data, eta)
            except ValueError:
                # Runs that raise ValueError here (e.g. eta not present) are skipped.
                continue
            n = true_seq.size
            if n < subset_size:
                # Not enough sequences to sample without replacement.
                continue
            true_total_full = true_seq.sum()
            for _ in range(samples_per_eta):
                idx = rng.choice(n, size=subset_size, replace=False)
                err = aggregate_err_rel(true_seq[idx], linear_seq[idx], true_total_full=true_total_full)
                rows.append({'run': run_id, 'eta': eta, 'err_rel': err})
    # Pin the schema: a column-less empty frame would make later
    # df['eta'] / groupby('eta') lookups fail with KeyError.
    return pd.DataFrame(rows, columns=['run', 'eta', 'err_rel'])


def summarize_err_rel_distribution(df: pd.DataFrame) -> pd.DataFrame:
    """Summarise the non-NaN ``err_rel`` values of ``df`` for each ``eta``.

    Args:
        df: Frame with at least the columns ``eta`` and ``err_rel``.

    Returns:
        DataFrame with one row per eta that has at least one non-NaN value and
        the columns ``eta``, ``mean_err_rel``, ``std_err_rel``,
        ``var_err_rel``, ``min_err_rel``, ``max_err_rel`` and ``num_samples``.
        Standard deviation and variance use ``ddof=1`` and are NaN for a
        single-sample group. The columns are present even with no rows.
    """
    columns = [
        'eta',
        'mean_err_rel',
        'std_err_rel',
        'var_err_rel',
        'min_err_rel',
        'max_err_rel',
        'num_samples',
    ]
    rows = []
    for eta, group in df.groupby('eta'):
        vals = group['err_rel'].dropna().to_numpy()
        if vals.size == 0:
            continue
        # With one sample the ddof=1 estimators divide by zero; report NaN
        # directly instead of triggering numpy's RuntimeWarning.
        has_spread = vals.size > 1
        rows.append({
            'eta': eta,
            'mean_err_rel': float(np.mean(vals)),
            'std_err_rel': float(np.std(vals, ddof=1)) if has_spread else float('nan'),
            'var_err_rel': float(np.var(vals, ddof=1)) if has_spread else float('nan'),
            'min_err_rel': float(np.min(vals)),
            'max_err_rel': float(np.max(vals)),
            'num_samples': int(vals.size),
        })
    return pd.DataFrame(rows, columns=columns)


def plot_err_rel_histograms(df: pd.DataFrame, eta_values) -> None:
    """Draw one ``err_rel`` histogram per eta, side by side on a shared y-axis.

    Rows of ``df`` are selected by exact equality on the ``eta`` column and
    NaN errors are dropped before plotting.
    """
    num = len(eta_values)
    # squeeze=False always yields a 2-D grid, so the single-panel case needs
    # no special handling; row 0 holds every panel.
    _fig, grid = plt.subplots(1, num, figsize=(5 * num, 4), sharey=True, squeeze=False)
    panels = grid[0]
    for panel, eta in zip(panels, eta_values):
        values = df.loc[df['eta'] == eta, 'err_rel'].dropna()
        panel.hist(values, bins=60, color='tab:blue', alpha=0.75)
        panel.set_title(f'eta={eta:.2e}')
        panel.set_xlabel('err_rel = (DeltaH_linear_subset - DeltaH_true_subset) / DeltaH_true_full')
        panel.grid(True, ls=':', alpha=0.5)
    panels[0].set_ylabel('Frequency')
    plt.tight_layout()
    plt.show()



#### Example Usage

```python
eta_values = [1.6e-6, 3.2e-6]
# Synthetic subsets from E=2048 runs
subset_df = sample_err_rel_across_runs(runs_2048, eta_values, subset_size=256, samples_per_eta=500, seed=42)
plot_err_rel_histograms(subset_df, eta_values)
summarize_err_rel_distribution(subset_df)

# Actual vs synthetic comparison at a single eta
compare_err_rel_distributions(eta=1.6e-6, num_samples=2000)  # shows the plot and summary table; returns None
```
