# Subrun Comparison

Compare parent runs to their subruns — the **deterministic swap** use case:
- Parent runs use Config A (e.g., original reranker or embedder)
- Subruns reuse the same memories, test cases, and queries, but run experiments with Config B
- Because everything except the experiment config is identical, any F1 difference is purely due to the config change

**Typical workflow:**
1. Run a full pipeline batch → parent runs
2. Run a subrun batch pointing at those parent run IDs → subruns
3. Open this notebook, paste the parent run IDs into `PARENT_RUN_IDS` (Cell 1), and run the cells in order to compare

In [None]:
# Cell 0 — Setup & Imports
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from memory_retrieval.experiments.comparison import (
    compare_configs,
    fingerprint_diff,
    group_runs_by_fingerprint,
    load_run_summaries,
    load_subrun_summaries,
)
from memory_retrieval.infra.figures import create_figure_session, save_figure
from memory_retrieval.infra.runs import PHASE1, PHASE2

# Locate the project root: the closest directory, starting at the current
# working directory and moving upward, that contains a pyproject.toml.
# The process then switches its working directory to that root.
_start_dir = Path.cwd()
for PROJECT_ROOT in (_start_dir, *_start_dir.parents):
    if (PROJECT_ROOT / "pyproject.toml").exists():
        break
else:
    # Every ancestor up to the filesystem root was checked without a match.
    raise RuntimeError("Could not find project root")
os.chdir(PROJECT_ROOT)


def _extract_f1_from_summary(
    macro: dict, metric_path: str, strategy: str | None = None
) -> float | None:
    """Extract the primary optimal F1 from a macro_averaged dict."""
    if metric_path == "post_rerank" and strategy:
        return (
            macro.get("post_rerank", {}).get(strategy, {}).get("at_optimal_threshold", {}).get("f1")
        )
    elif metric_path == "pre_rerank":
        return macro.get("pre_rerank", {}).get("at_optimal_distance_threshold", {}).get("f1")
    else:
        return macro.get("metrics", {}).get("f1")


def _extract_f1_from_per_test_case(
    per_tc: dict, metric_path: str, strategy: str | None = None
) -> float | None:
    """Extract F1 from a per_test_case entry."""
    if metric_path == "post_rerank" and strategy:
        return (
            per_tc.get("post_rerank", {})
            .get(strategy, {})
            .get("rerank_threshold", {})
            .get("at_optimal", {})
            .get("f1")
        )
    elif metric_path == "pre_rerank":
        return (
            per_tc.get("pre_rerank", {})
            .get("distance_threshold", {})
            .get("at_optimal", {})
            .get("f1")
        )
    else:
        return per_tc.get("metrics", {}).get("f1")


# Reaching this line means the imports above resolved and the working
# directory has been switched to PROJECT_ROOT.
print("Setup complete.")

In [None]:
# Cell 1 — Configuration
#
# Edit the constants below, then re-run this cell and every cell after it.

# Which phase to analyze
PHASE = PHASE2

# Parent run IDs to compare.
# Copy these from run_batch.py output or data/<phase>/batches/<batch_id>/batch_manifest.json.
# Set to None to load all parent runs that have subruns.
PARENT_RUN_IDS = None  # e.g., ["run_20260218_143022", "run_20260218_145301"]

# Rerank strategy — must match what was used in the subrun's ExperimentConfig.
# The F1 helpers in Cell 0 only consult this value when METRIC_PATH is "post_rerank";
# the threshold-sweep cell uses it to select the strategy's sweep data.
STRATEGY = "default"

# Metric path: "post_rerank" for reranking experiments, "metrics" for standard vector-only.
# The F1 helpers in Cell 0 also understand "pre_rerank"; any other value is read from "metrics".
# NOTE(review): confirm that compare_configs() accepts the same set of values.
METRIC_PATH = "post_rerank"

# Echo the active configuration so it is recorded in the notebook output.
print(f"Phase: {PHASE}")
print(f"Strategy: {STRATEGY}")
print(f"Metric path: {METRIC_PATH}")

In [None]:
# Cell 2 — Load Parent & Subrun Summaries
#
# Loads both groups of summaries and opens a figure export session whose key
# encodes which parents/subruns, strategy and metric path were compared.

parent_summaries = load_run_summaries(PHASE, run_ids=PARENT_RUN_IDS)
subrun_summaries = load_subrun_summaries(PHASE, parent_run_ids=PARENT_RUN_IDS)

# Only parents that carry a non-empty run_id take part in the export key.
loaded_parent_run_ids = sorted(
    summary["run_id"] for summary in parent_summaries if summary.get("run_id")
)

parent_count = len(loaded_parent_run_ids)
if PARENT_RUN_IDS is None:
    parents_part = f"parents-all-n{parent_count}"
elif parent_count:
    first_parent_id = loaded_parent_run_ids[0]
    last_parent_id = loaded_parent_run_ids[-1]
    parents_part = f"parents-{first_parent_id}-{last_parent_id}-n{parent_count}"
else:
    parents_part = "parents-none-none-n0"

comparison_key = "__".join(
    [
        parents_part,
        f"subruns-n{len(subrun_summaries)}",
        f"strategy-{STRATEGY}",
        f"metric-{METRIC_PATH}",
    ]
)

# Metadata stored alongside the exported figures.
session_context = {
    "phase": PHASE,
    "strategy": STRATEGY,
    "metric_path": METRIC_PATH,
    "parent_run_ids": PARENT_RUN_IDS,
    "loaded_parent_run_ids": loaded_parent_run_ids,
    "subrun_ids": [summary.get("run_id", "unknown") for summary in subrun_summaries],
}
FIGURE_SESSION = create_figure_session(
    root_dir=PROJECT_ROOT / "data" / "comparisons" / "exports",
    notebook_slug=f"subrun_comparison/{PHASE}",
    context_key=comparison_key,
    context=session_context,
)
print(f"Figure export session: {FIGURE_SESSION.session_dir}")

# Overview: either explain how to create subruns, or print one row per loaded
# parent run listing its F1 and the F1 of each of its subruns.
if not subrun_summaries:
    print("No subruns found. Run a subrun batch first:")
    print()
    print(
        "  from memory_retrieval.experiments.batch_runner import BatchSubrunConfig, run_subrun_batch"
    )
    print("  outcome = run_subrun_batch(BatchSubrunConfig(")
    print("      phase=PHASE2,")
    print("      parent_run_ids=[...],")
    print(
        "      experiment_config=ExperimentConfig(search_backend=VectorBackend(), reranker=Reranker(...)),"
    )
    print("  ))")
else:
    print(f"Loaded {len(parent_summaries)} parent runs")
    print(f"Loaded {len(subrun_summaries)} subruns")
    print()

    # Group subruns by parent for the overview. A missing or empty
    # parent_run_id is filed under "unknown" so it can never match a parent.
    subruns_by_parent: dict[str, list[dict]] = {}
    for summary in subrun_summaries:
        parent_key = summary.get("parent_run_id") or "unknown"
        subruns_by_parent.setdefault(parent_key, []).append(summary)

    print(f"{'Parent Run ID':<30} {'Parent F1':>10} {'Subruns':>8}  Subrun IDs & F1")
    print("-" * 90)
    for parent_summary in parent_summaries:
        raw_parent_id = parent_summary.get("run_id")
        # str(... or "?") keeps the row printable when run_id is missing or None
        # (None does not accept the "<30" format spec).
        parent_run_id = str(raw_parent_id or "?")
        parent_f1 = _extract_f1_from_summary(
            parent_summary.get("macro_averaged", {}), METRIC_PATH, STRATEGY
        )
        f1_display = f"{parent_f1:.4f}" if parent_f1 is not None else "n/a"
        subruns = subruns_by_parent.get(raw_parent_id, [])
        print(f"{parent_run_id:<30} {f1_display:>10} {len(subruns):>8}", end="")

        for subrun_summary in subruns:
            subrun_f1 = _extract_f1_from_summary(
                subrun_summary.get("macro_averaged", {}), METRIC_PATH, STRATEGY
            )
            subrun_f1_display = f"{subrun_f1:.4f}" if subrun_f1 is not None else "n/a"
            # .get() instead of ['run_id']: a subrun summary without a run_id
            # must not abort the whole overview with a KeyError.
            subrun_label = subrun_summary.get("run_id", "?")
            print(f"  {subrun_label} ({subrun_f1_display})", end="")
        print()

In [None]:
# Cell 3 — Config Diff (What Changed)
#
# The diff is computed between the FIRST parent summary and the FIRST subrun
# summary only. When a group contains more than one distinct fingerprint hash
# a warning is printed, because the diff then does not describe every run.

if parent_summaries and subrun_summaries:
    parent_fingerprint = parent_summaries[0].get("config_fingerprint", {})
    subrun_fingerprint = subrun_summaries[0].get("config_fingerprint", {})

    if not parent_fingerprint:
        print("Parent runs have no config fingerprint.")
        print("Run update_config_fingerprint() on parent runs, or re-run them via batch_runner.")
    elif not subrun_fingerprint:
        print("Subruns have no config fingerprint.")
        print("Subruns created via run_subrun_batch() always store a fingerprint.")
    else:
        diff = fingerprint_diff(parent_fingerprint, subrun_fingerprint)

        if diff:
            print("Config changes (parent → subrun):")
            print()
            # Pad field names to a common width so the values line up.
            max_field_len = max(len(field) for field in diff)
            for field, values in diff.items():
                print(f"  {field:<{max_field_len}}  {str(values['a']):<40} → {values['b']}")
        else:
            print("No config differences detected.")
            print(
                "Both parent and subrun have the same fingerprint hash:",
                parent_fingerprint.get("fingerprint_hash"),
            )
            print(
                "This means the subrun ran with identical config — results should be very similar."
            )

        print()
        print(f"Parent fingerprint hash: {parent_fingerprint.get('fingerprint_hash', 'n/a')}")
        print(f"Subrun fingerprint hash: {subrun_fingerprint.get('fingerprint_hash', 'n/a')}")

        # Flag heterogeneous groups: the diff above reflects index 0 only.
        for group_label, group_summaries in (
            ("parent runs", parent_summaries),
            ("subruns", subrun_summaries),
        ):
            distinct_hashes = {
                (summary.get("config_fingerprint") or {}).get("fingerprint_hash")
                for summary in group_summaries
            }
            if len(distinct_hashes) > 1:
                print()
                print(
                    f"WARNING: the loaded {group_label} carry {len(distinct_hashes)} different "
                    "config fingerprints; the diff above only reflects the first one."
                )
else:
    print("Cannot compute diff: missing parent or subrun summaries.")

In [None]:
# Cell 4 — Statistical Comparison (Parent vs Subrun)
#
# Parent summaries are passed as Config A and subrun summaries as Config B.
# Per the notebook's design, each test case's F1 is averaged across all parent
# runs (A) and across all subruns (B), and a paired Wilcoxon signed-rank test
# is run on the per-test-case deltas.

if not (parent_summaries and subrun_summaries):
    print("Cannot compare: missing parent or subrun summaries.")
else:
    comparison = compare_configs(
        parent_summaries,
        subrun_summaries,
        strategy=STRATEGY,
        metric_path=METRIC_PATH,
    )

    if "error" not in comparison:
        rule = "=" * 60
        print(rule)
        print("PARENT vs SUBRUN — STATISTICAL COMPARISON")
        print(rule)
        print(f"  Common test cases:       {comparison['num_common_test_cases']}")
        print(f"  Parent runs (n):         {comparison['num_runs_a']}")
        print(f"  Subruns (n):             {comparison['num_runs_b']}")
        print()
        print(f"  Mean F1 — parent:        {comparison['mean_f1_a']:.4f}")
        print(f"  Mean F1 — subrun:        {comparison['mean_f1_b']:.4f}")
        print(f"  Mean delta (sub - par):  {comparison['mean_diff']:+.4f}")
        print(f"  Direction:               {comparison['direction']}")
        print()
        ci = comparison["bootstrap_ci"]
        print(f"  Bootstrap 95% CI:        [{ci['lower']:.4f}, {ci['upper']:.4f}]")
        print(f"  Excludes zero:           {ci['excludes_zero']}")
        print(f"  Wilcoxon p-value:        {comparison['wilcoxon']['p_value']:.4f}")
        print(f"  Significance:            {comparison['significance']}")
        print()

        # Plain-language verdict: anything other than "b_better" names the parent.
        direction = comparison["direction"]
        significance = comparison["significance"]
        winner = "parent" if direction != "b_better" else "subrun"

        if significance == "significant":
            verdict_line = f"  ✓ {winner.upper()} is significantly better (|delta| ≥ 0.03, p < 0.05)"
        elif significance == "marginal":
            verdict_line = f"  ~ {winner.upper()} shows marginal improvement (|delta| ≥ 0.01, p < 0.10)"
        else:
            verdict_line = "  ✗ No significant difference — subrun config change has no measurable effect"
        print(verdict_line)
    else:
        print(f"Comparison error: {comparison['error']}")

In [None]:
# Cell 5 — Per-Test-Case Paired Plot
#
# Left: paired dot plot (parent blue, subrun orange) — shows absolute F1 per test case
# Right: delta bar chart — shows improvement or degradation per test case
#
# Both panels share one colour rule so a test case never appears "unchanged"
# on the left and "degraded"/"improved" on the right.


def _delta_color(delta: float, tolerance: float = 0.001) -> str:
    """Map an F1 delta to a colour: green = improved, red = degraded, gray = unchanged.

    Args:
        delta: Subrun F1 minus parent F1 for one test case.
        tolerance: Absolute deltas up to this size count as unchanged.

    Returns:
        A hex colour string.
    """
    if delta > tolerance:
        return "#2ecc71"
    if delta < -tolerance:
        return "#e74c3c"
    return "#95a5a6"


if (
    parent_summaries
    and subrun_summaries
    # Cell 4 may not have run yet, in which case `comparison` does not exist.
    and "comparison" in globals()
    and "paired_differences" in comparison
):
    paired = comparison["paired_differences"]

    # Figure height grows with the number of test cases so labels stay legible.
    fig, axes = plt.subplots(1, 2, figsize=(16, max(5, len(paired) * 0.55 + 2)))

    test_case_ids = [entry["test_case_id"] for entry in paired]
    y_positions = list(range(len(test_case_ids)))

    # Left: paired dot plot
    ax = axes[0]
    for y_pos, entry in enumerate(paired):
        ax.plot(
            [entry["f1_a"], entry["f1_b"]],
            [y_pos, y_pos],
            color=_delta_color(entry["diff"]),
            linewidth=2.5,
            alpha=0.6,
        )
        ax.scatter(
            entry["f1_a"], y_pos, color="#3498db", s=70, zorder=5, edgecolors="white", linewidth=0.5
        )
        ax.scatter(
            entry["f1_b"], y_pos, color="#e67e22", s=70, zorder=5, edgecolors="white", linewidth=0.5
        )

    ax.set_yticks(y_positions)
    ax.set_yticklabels(test_case_ids, fontsize=8)
    ax.set_xlabel("F1 Score")
    ax.set_title("Per-Test-Case F1: Parent (blue) vs Subrun (orange)")
    # Empty artists exist only to create legend entries.
    ax.scatter([], [], color="#3498db", s=70, label="Parent")
    ax.scatter([], [], color="#e67e22", s=70, label="Subrun")
    ax.plot([], [], color="#2ecc71", linewidth=2.5, label="Improved")
    ax.plot([], [], color="#e74c3c", linewidth=2.5, label="Degraded")
    ax.plot([], [], color="#95a5a6", linewidth=2.5, label="Unchanged")
    ax.legend(fontsize=8, loc="lower right")
    ax.set_xlim(-0.05, 1.05)
    ax.grid(True, alpha=0.3, axis="x")

    # Right: delta bar chart
    ax = axes[1]
    diffs = [entry["diff"] for entry in paired]
    bar_colors = [_delta_color(delta) for delta in diffs]
    ax.barh(y_positions, diffs, color=bar_colors, alpha=0.75, edgecolor="white", linewidth=0.3)
    ax.axvline(0, color="black", linewidth=0.8)
    ax.axvline(
        comparison["mean_diff"],
        color="#2c3e50",
        linewidth=1.5,
        linestyle="--",
        label=f"Mean delta: {comparison['mean_diff']:+.3f}",
    )
    ax.set_yticks(y_positions)
    ax.set_yticklabels(test_case_ids, fontsize=8)
    ax.set_xlabel("F1 Delta (Subrun − Parent)")
    ax.set_title("Per-Test-Case Delta")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis="x")

    plt.suptitle(
        f"Subrun Comparison — {PHASE} ({STRATEGY})  |  {comparison['significance'].replace('_', ' ')}",
        fontsize=12,
        fontweight="bold",
    )
    plt.tight_layout()
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        "per_test_case_paired_plot",
        title=f"Subrun Comparison — {PHASE} ({STRATEGY})",
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")
else:
    print("Run Cell 4 first to generate the comparison result.")

In [None]:
# Cell 6 — Threshold Sweep Overlay (Parent vs Subrun)
#
# Plots mean F1 against the rerank score threshold for each group.
# When a group has several runs, a shaded band spans their min/max F1,
# and a dotted vertical line marks the threshold with the highest mean F1.

if not (parent_summaries and subrun_summaries):
    print("Cannot plot: missing parent or subrun summaries.")
else:
    fig, ax = plt.subplots(figsize=(12, 5))

    for group_summaries, color, group_label in (
        (parent_summaries, "#3498db", "Parent"),
        (subrun_summaries, "#e67e22", "Subrun"),
    ):
        # One {threshold: f1} mapping per summary that has sweep data for STRATEGY.
        all_sweep_dicts: list[dict[float, float]] = []
        for summary in group_summaries:
            strategy_block = summary.get("rerank_strategies", {}).get(STRATEGY, {})
            sweep_points = strategy_block.get("threshold_sweep", {}).get("full_sweep", [])
            if sweep_points:
                all_sweep_dicts.append(
                    {point["threshold"]: point["f1"] for point in sweep_points}
                )

        if not all_sweep_dicts:
            print(f"No threshold sweep data found for {group_label}. Is STRATEGY correct?")
            continue

        # Every threshold seen in any run of the group, in ascending order.
        all_thresholds = sorted(
            {threshold for sweep in all_sweep_dicts for threshold in sweep}
        )

        # For each threshold, the F1 values of the runs that evaluated it; each
        # threshold originates from at least one run, so no list is empty.
        f1_per_threshold = [
            [sweep[threshold] for sweep in all_sweep_dicts if threshold in sweep]
            for threshold in all_thresholds
        ]
        mean_f1_values = [np.mean(values) for values in f1_per_threshold]
        min_f1_values = [min(values) for values in f1_per_threshold]
        max_f1_values = [max(values) for values in f1_per_threshold]

        run_count = len(all_sweep_dicts)
        ax.plot(
            all_thresholds,
            mean_f1_values,
            label=f"{group_label} (n={run_count})",
            color=color,
            linewidth=2,
        )

        if run_count > 1:
            ax.fill_between(all_thresholds, min_f1_values, max_f1_values, color=color, alpha=0.15)

        # Mark the threshold whose mean F1 is highest.
        optimal_index = int(np.argmax(mean_f1_values))
        ax.axvline(
            all_thresholds[optimal_index],
            color=color,
            linestyle=":",
            alpha=0.6,
            linewidth=1.2,
        )

    sweep_title = f"Threshold Sweep — {PHASE} ({STRATEGY})"
    ax.set_xlabel("Rerank Score Threshold")
    ax.set_ylabel("Macro F1")
    ax.set_title(sweep_title)
    ax.legend(fontsize=9)
    ax.set_ylim(0, 1.05)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        "threshold_sweep_overlay",
        title=sweep_title,
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")

In [None]:
# Cell 7 — Per-Test-Case Heatmap (Parent vs Subrun Side-by-Side)
#
# Builds a (test case × [parent, subrun, delta]) matrix of group-averaged F1,
# draws one heatmap for the absolute values and one for the delta, then prints
# the test cases ordered from largest improvement to largest degradation.

if not (parent_summaries and subrun_summaries):
    print("Cannot plot: missing parent or subrun summaries.")
else:
    all_test_case_ids = sorted(
        {
            test_case_id
            for summary in parent_summaries + subrun_summaries
            for test_case_id in summary.get("per_test_case", {})
        }
    )

    if not all_test_case_ids:
        # Without rows there is nothing to draw, and the NaN-aware reductions
        # below would raise on an empty array.
        print("Cannot plot: the loaded summaries contain no per-test-case results.")
    else:
        # Columns: 0 = parent mean F1, 1 = subrun mean F1, 2 = delta (subrun − parent).
        # Cells stay NaN when no run in the group reports an F1 for that test case.
        heatmap_data = np.full((len(all_test_case_ids), 3), np.nan)

        for col_index, group_summaries in enumerate((parent_summaries, subrun_summaries)):
            for row_index, test_case_id in enumerate(all_test_case_ids):
                f1_values = []
                for summary in group_summaries:
                    per_tc = summary.get("per_test_case", {}).get(test_case_id, {})
                    f1 = _extract_f1_from_per_test_case(per_tc, METRIC_PATH, STRATEGY)
                    if f1 is not None:
                        f1_values.append(f1)
                if f1_values:
                    heatmap_data[row_index, col_index] = np.mean(f1_values)

        # NaN propagates through the subtraction, so the delta is only defined
        # where both groups have a value.
        heatmap_data[:, 2] = heatmap_data[:, 1] - heatmap_data[:, 0]
        delta_column = heatmap_data[:, 2]

        figure_height = max(6, len(all_test_case_ids) * 0.55)

        fig, ax = plt.subplots(figsize=(8, figure_height))
        sns.heatmap(
            heatmap_data[:, :2],
            annot=True,
            fmt=".2f",
            xticklabels=["Parent (avg F1)", "Subrun (avg F1)"],
            yticklabels=all_test_case_ids,
            cmap="RdYlGn",
            vmin=0,
            vmax=1,
            linewidths=0.5,
            ax=ax,
        )
        ax.set_title(f"Per-Test-Case F1 — {PHASE} ({STRATEGY})")
        plt.tight_layout()
        saved_paths = save_figure(
            fig,
            FIGURE_SESSION,
            "per_test_case_f1_heatmap",
            title=f"Per-Test-Case F1 — {PHASE} ({STRATEGY})",
        )
        plt.show()
        print(f"Saved: {saved_paths['png']}")

        # Separate delta heatmap (diverging colormap centered at 0).
        # The colour range is symmetric around zero and at least ±0.01 wide;
        # it falls back to ±0.01 when no delta is defined, instead of passing
        # NaN limits to the heatmap.
        finite_deltas = delta_column[~np.isnan(delta_column)]
        if finite_deltas.size:
            delta_max = max(float(np.abs(finite_deltas).max()), 0.01)
        else:
            delta_max = 0.01

        fig, ax = plt.subplots(figsize=(3, figure_height))
        sns.heatmap(
            heatmap_data[:, 2:3],
            annot=True,
            fmt="+.2f",
            xticklabels=["Delta (sub−par)"],
            yticklabels=all_test_case_ids,
            cmap="RdYlGn",
            vmin=-delta_max,
            vmax=delta_max,
            center=0,
            linewidths=0.5,
            ax=ax,
        )
        ax.set_title("F1 Delta (Subrun − Parent)")
        plt.tight_layout()
        saved_paths = save_figure(
            fig,
            FIGURE_SESSION,
            "per_test_case_delta_heatmap",
            title="F1 Delta (Subrun - Parent)",
        )
        plt.show()
        print(f"Saved: {saved_paths['png']}")

        # Summary: which test cases improved most / degraded most.
        # Only rows with a defined delta are listed, best improvement first.
        valid_indices = [
            row_index for row_index, delta in enumerate(delta_column) if not np.isnan(delta)
        ]
        sorted_by_delta = sorted(valid_indices, key=lambda i: delta_column[i], reverse=True)

        print("\nTest case breakdown:")
        print(f"  {'Test Case':<35} {'Parent':>8} {'Subrun':>8} {'Delta':>8}")
        print(f"  {'-' * 35} {'-' * 8} {'-' * 8} {'-' * 8}")
        for i in sorted_by_delta:
            test_case_id = all_test_case_ids[i]
            parent_f1, subrun_f1, delta = heatmap_data[i]
            print(f"  {test_case_id:<35} {parent_f1:>8.3f} {subrun_f1:>8.3f} {delta:>+8.3f}")