# Cross-Run Comparison

Compare experiment runs across configs to:
1. **Quantify reproducibility** — how much does F1 vary across runs with identical config?
2. **Compare configurations** — when changing prompt/reranker/model, is the improvement real?
3. **Track trends** — are results improving over time?

Works with Phase 1 and Phase 2 runs. Backfills fingerprints for existing runs automatically.

In [None]:
# Cell 0 — Setup, Imports & Backfill
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from memory_retrieval.experiments.comparison import (
    compare_configs,
    compute_variance_report,
    fingerprint_diff,
    group_runs_by_fingerprint,
    load_run_summaries,
    reconstruct_fingerprint_from_run,
)
from memory_retrieval.experiments.metrics_adapter import extract_metric_from_nested
from memory_retrieval.infra.figures import create_figure_session, save_figure
from memory_retrieval.infra.io import load_json
from memory_retrieval.infra.runs import (
    PHASE1,
    PHASE2,
    list_runs,
    update_config_fingerprint,
)

# Locate the project root: the nearest directory, starting at the current
# working directory and walking upwards, that contains a pyproject.toml.
PROJECT_ROOT = Path.cwd()
for PROJECT_ROOT in (PROJECT_ROOT, *PROJECT_ROOT.parents):
    if (PROJECT_ROOT / "pyproject.toml").exists():
        break
else:
    # Reached the filesystem root without finding a pyproject.toml.
    raise RuntimeError("Could not find project root")
# Make the project root the working directory for the rest of the notebook.
os.chdir(PROJECT_ROOT)


# --- Helpers for extracting metrics from run_summary.json ---
def extract_f1_from_summary(macro, metric_path, strategy=None):
    """Look up the F1 value inside a run summary's macro-averaged metrics.

    Thin wrapper that forwards ``macro``, ``metric_path`` and ``strategy`` to
    ``extract_metric_from_nested`` with ``metric_key="f1"`` and returns its
    result unchanged.
    """
    f1_value = extract_metric_from_nested(macro, metric_path, strategy, metric_key="f1")
    return f1_value


def extract_f1_from_per_test_case(per_tc, metric_path, strategy=None):
    """Look up the F1 value inside a single per-test-case entry.

    Thin wrapper that forwards ``per_tc``, ``metric_path`` and ``strategy`` to
    ``extract_metric_from_nested`` with ``metric_key="f1"`` and returns its
    result unchanged.
    """
    f1_value = extract_metric_from_nested(per_tc, metric_path, strategy, metric_key="f1")
    return f1_value


# --- Backfill fingerprints for existing runs ---
# A run is backfilled only when its run.json has no "config_fingerprint" entry
# and its results/ directory contains at least one *.json file.
backfilled_count = 0
for phase in (PHASE1, PHASE2):
    for run in list_runs(phase):
        run_dir = run["run_dir"]
        run_metadata = load_json(run_dir / "run.json")
        if "config_fingerprint" in run_metadata:
            continue  # already fingerprinted

        results_dir = run_dir / "results"
        has_result_files = results_dir.exists() and any(results_dir.glob("*.json"))
        if not has_result_files:
            continue  # no result files present for this run

        fingerprint = reconstruct_fingerprint_from_run(run_dir)
        update_config_fingerprint(run_dir, fingerprint)
        backfilled_count += 1
        print(f"  Backfilled: {run['run_id']} -> {fingerprint['fingerprint_hash']}")

if backfilled_count:
    print(f"\nBackfilled {backfilled_count} runs.")
else:
    print("All runs already have fingerprints.")

print("\nSetup complete.")

In [None]:
# Cell 1 — Configuration
# User-editable settings read by every later cell. Re-run the cells below
# after changing any value here.

# Which phase to analyze
PHASE = PHASE2

# Optional: restrict to specific run IDs (None = all runs)
SELECTED_RUN_IDS = None  # e.g., ["run_20260212_181821", "run_20260213_173243"]

# Rerank strategy to focus on: "situation_and_lesson" or "situation_only".
# Only meaningful for reranking runs, i.e. when METRIC_PATH is "post_rerank".
STRATEGY = "situation_only"

# Metric path: "post_rerank" for reranking runs, "pre_rerank" for vector-only baseline
METRIC_PATH = "post_rerank"

# For A/B comparison (Cell 5): select two fingerprint hashes to compare.
# Set after viewing Cell 2 output. Both must be set to take effect; otherwise
# Cell 5 auto-selects the first and last config group.
COMPARE_HASH_A = None
COMPARE_HASH_B = None

# Derived label for titles/logs — avoids showing a strategy name when using pre_rerank
METRIC_LABEL = STRATEGY if METRIC_PATH == "post_rerank" else "pre-rerank (distance)"

# Echo the effective configuration so it is recorded in the cell output.
print(f"Phase: {PHASE}")
print(f"Strategy: {STRATEGY}")
print(f"Metric path: {METRIC_PATH}")
print(f"Metric label: {METRIC_LABEL}")

In [None]:
# Cell 2 — Load Runs & Overview Table
# Loads the run summaries for PHASE, groups them by config fingerprint and
# prints one table row per group followed by one row per run.
# Defines `summaries`, `groups`, `sorted_hashes` and `group_labels`, which
# later cells read.


def _group_label(index):
    """Return a spreadsheet-style label for a 0-based index: A..Z, AA, AB, ...

    The first 26 labels are the single letters A-Z; beyond that the label
    grows instead of running into non-letter characters.
    """
    label = ""
    index += 1
    while index > 0:
        index, remainder = divmod(index - 1, 26)
        label = chr(ord("A") + remainder) + label
    return label


summaries = load_run_summaries(PHASE, run_ids=SELECTED_RUN_IDS)
print(f"Loaded {len(summaries)} run summaries for {PHASE}\n")

groups = group_runs_by_fingerprint(summaries)
print(f"Found {len(groups)} config groups:\n")

# Overview table
header = f"{'Group':>8} {'Hash':>10} {'N Runs':>7} {'Avg F1':>8} {'Std':>8} {'Key Config Differences':>40}"
print(header)
print("-" * len(header))

group_labels = {}  # hash -> label for later cells
# Order groups by the run_id of their first run so labels are stable.
sorted_hashes = sorted(groups.keys(), key=lambda h: groups[h][0].get("run_id", ""))

for group_index, fingerprint_hash in enumerate(sorted_hashes):
    group_summaries = groups[fingerprint_hash]
    label = _group_label(group_index)
    group_labels[fingerprint_hash] = label

    # Extract F1 once per run (None when the metric is absent for that run).
    run_f1_values = [
        extract_f1_from_summary(summary.get("macro_averaged", {}), METRIC_PATH, STRATEGY)
        for summary in group_summaries
    ]
    f1_values = [f1 for f1 in run_f1_values if f1 is not None]

    avg_f1 = np.mean(f1_values) if f1_values else None
    # Sample standard deviation needs at least two values.
    std_f1 = np.std(f1_values, ddof=1) if len(f1_values) > 1 else 0.0
    # Show "n/a" rather than a fake 0.000 when no run in the group has the metric.
    avg_text = f"{avg_f1:.3f}" if avg_f1 is not None else "n/a"

    # Key config from fingerprint
    fingerprint = group_summaries[0].get("config_fingerprint", {})
    config_desc = f"ext={fingerprint.get('extraction_prompt_version', '?')}, q={fingerprint.get('query_prompt_version', '?')}"

    print(
        f"{label:>8} {fingerprint_hash:>10} {len(group_summaries):>7} "
        f"{avg_text:>8} {std_f1:>8.4f} {config_desc:>40}"
    )

    # Print individual runs
    for summary, run_f1 in zip(group_summaries, run_f1_values):
        run_f1_text = f"{run_f1:.3f}" if run_f1 is not None else "n/a"
        # run_id is treated as optional elsewhere in this cell, so do not index it directly.
        run_id_text = summary.get("run_id", "unknown")
        print(f"{'':>8} {'':>10} {'':>7} {run_f1_text:>8} {'':>8} {run_id_text}")

print(f"\nGroup labels for A/B comparison: {group_labels}")

# Build a descriptive key for this comparison and open a figure export
# session under data/comparisons/exports for the plots in later cells.
loaded_run_ids = sorted(
    summary.get("run_id", "unknown") for summary in summaries if summary.get("run_id")
)

# Part 1: which runs were loaded (explicit selection vs. all runs).
if not SELECTED_RUN_IDS:
    selected_part = f"selected-all-runs-n{len(loaded_run_ids)}"
elif loaded_run_ids:
    selected_part = f"selected-{loaded_run_ids[0]}-{loaded_run_ids[-1]}-n{len(loaded_run_ids)}"
else:
    selected_part = "selected-none-none-n0"

# Part 2: which config groups are covered (first/last hash, truncated to 8 chars).
groups_part = (
    f"groups-{sorted_hashes[0][:8]}-{sorted_hashes[-1][:8]}-n{len(sorted_hashes)}"
    if sorted_hashes
    else "groups-none-none-n0"
)

comparison_key = f"{selected_part}__{groups_part}__metric-{METRIC_PATH}__label-{METRIC_LABEL}"

session_context = {
    "phase": PHASE,
    "strategy": STRATEGY,
    "metric_path": METRIC_PATH,
    "metric_label": METRIC_LABEL,
    "selected_run_ids": SELECTED_RUN_IDS,
    "loaded_run_ids": loaded_run_ids,
    "group_hashes": sorted_hashes,
}
FIGURE_SESSION = create_figure_session(
    root_dir=PROJECT_ROOT / "data" / "comparisons" / "exports",
    notebook_slug=f"cross_run_comparison/{PHASE}",
    context_key=comparison_key,
    context=session_context,
)
print(f"Figure export session: {FIGURE_SESSION.session_dir}")

In [None]:
# Cell 3 — Run Timeline
# Each run gets its own x position (sorted by time), so batch runs never overlap.
# X tick labels show the actual timestamp. Dashed lines connect consecutive runs
# within the same config group to make trends visible across sequential experiments.

from datetime import datetime

# Color map shared with Cell 6
colors_map = plt.cm.Set2(np.linspace(0, 1, max(len(groups), 3)))
hash_to_color = {
    fingerprint_hash: colors_map[i] for i, fingerprint_hash in enumerate(sorted_hashes)
}

# Collect all runs with timestamps and F1 values
run_records = []
skipped_run_ids = []  # runs whose ID could not be parsed into a timestamp
for fingerprint_hash, group_summaries in groups.items():
    for summary in group_summaries:
        run_id = summary.get("run_id", "")
        try:
            # Strip only the leading "run_" and take the date+time parts (first two
            # underscore-separated segments), so both old format (YYYYMMDD_HHMMSS)
            # and new format (YYYYMMDD_HHMMSS_ffffff) parse correctly.
            run_id_parts = run_id.removeprefix("run_").split("_")
            timestamp = datetime.strptime("_".join(run_id_parts[:2]), "%Y%m%d_%H%M%S")
        except (ValueError, AttributeError):
            # Remember the run so the omission is reported instead of silent.
            skipped_run_ids.append(run_id)
            continue
        macro = summary.get("macro_averaged", {})
        f1 = extract_f1_from_summary(macro, METRIC_PATH, STRATEGY)
        if f1 is not None:
            run_records.append(
                {
                    "timestamp": timestamp,
                    "f1": f1,
                    "fingerprint_hash": fingerprint_hash,
                }
            )

run_records.sort(key=lambda record: record["timestamp"])

if skipped_run_ids:
    print(
        f"Skipped {len(skipped_run_ids)} run(s) with unparseable run IDs: {skipped_run_ids}"
    )

# Plot one point per collected run; run_records holds only runs that had both
# a parseable timestamp and an F1 value for the selected metric.
if not run_records:
    print("No runs with valid timestamps found.")
else:
    fig, ax = plt.subplots(figsize=(14, 5))

    # One dot per run at its sorted index — no overlap regardless of how close timestamps are
    for run_index, record in enumerate(run_records):
        color = hash_to_color[record["fingerprint_hash"]]
        ax.scatter(
            run_index, record["f1"], color=color, s=80, zorder=5, edgecolors="white", linewidth=0.5
        )

    # Dashed trend line connecting runs of the same config group in time order
    for fingerprint_hash in sorted_hashes:
        # (x index, F1) pairs of this group's runs; x is the run's position in run_records.
        group_indices = [
            (i, r["f1"])
            for i, r in enumerate(run_records)
            if r["fingerprint_hash"] == fingerprint_hash
        ]
        # A line needs at least two points.
        if len(group_indices) > 1:
            xs, ys = zip(*group_indices)
            ax.plot(
                xs,
                ys,
                color=hash_to_color[fingerprint_hash],
                linewidth=1,
                alpha=0.35,
                linestyle="--",
            )

    # X tick labels: actual timestamps so temporal context is not lost
    ax.set_xticks(range(len(run_records)))
    ax.set_xticklabels(
        [r["timestamp"].strftime("%m-%d %H:%M") for r in run_records],
        rotation=45,
        ha="right",
        fontsize=7,
    )

    # Legend: one empty scatter per group, used only to create a legend entry
    for fingerprint_hash in sorted_hashes:
        label = group_labels[fingerprint_hash]
        fingerprint = groups[fingerprint_hash][0].get("config_fingerprint", {})
        desc = (
            f"Group {label}: ext={fingerprint.get('extraction_prompt_version', '?')}, "
            f"q={fingerprint.get('query_prompt_version', '?')}"
        )
        ax.scatter([], [], color=hash_to_color[fingerprint_hash], s=80, label=desc)

    ax.set_xlabel("Run (sorted by time)")
    ax.set_ylabel(f"Macro F1 ({METRIC_LABEL})")
    ax.set_title(f"F1 Over Time — {PHASE}")
    ax.legend(fontsize=8, loc="best")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    # Export through the figure session created in Cell 2, then display inline.
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        "f1_over_time",
        title=f"F1 Over Time — {PHASE}",
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")

In [None]:
# Cell 4 — Reproducibility (Same-Config Variance)
# For every config group with 2+ runs: box plot of run-level macro F1 (left)
# and a text panel with the variance report (right).
# Defines `multi_run_groups`, which Cell 4b also reads.

print("VARIANCE ANALYSIS (same-config runs)\n")

# Find groups with multiple runs
multi_run_groups = {h: g for h, g in groups.items() if len(g) >= 2}

if not multi_run_groups:
    print("No config groups with 2+ runs found. Run more experiments with the same config.")
else:
    # Box plots of per-run macro F1 per config group
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: Box plot of run-level F1
    ax = axes[0]
    box_data = []
    box_labels = []
    for fingerprint_hash, group_summaries in multi_run_groups.items():
        label = group_labels.get(fingerprint_hash, fingerprint_hash[:8])
        run_f1_values = (
            extract_f1_from_summary(summary.get("macro_averaged", {}), METRIC_PATH, STRATEGY)
            for summary in group_summaries
        )
        f1_values = [f1 for f1 in run_f1_values if f1 is not None]
        if f1_values:
            box_data.append(f1_values)
            box_labels.append(f"Group {label}\n(n={len(f1_values)})")

    if box_data:
        bp = ax.boxplot(box_data, tick_labels=box_labels, patch_artist=True)
        for patch in bp["boxes"]:
            patch.set_facecolor("#3498db")
            patch.set_alpha(0.4)
        # Overlay individual points with a small horizontal jitter. The generator
        # is seeded so the exported figure is identical across re-runs.
        jitter_rng = np.random.default_rng(0)
        for group_index, values in enumerate(box_data):
            # Box positions are 1-based, hence group_index + 1.
            x_positions = jitter_rng.normal(group_index + 1, 0.04, len(values))
            ax.scatter(x_positions, values, color="#2c3e50", s=40, zorder=5, alpha=0.7)
        ax.set_ylabel("Macro F1")
        ax.set_title("Run-Level F1 Variance (Same Config)")
        ax.grid(True, alpha=0.3, axis="y")

    # Right: Per-test-case variance table as text
    ax = axes[1]
    ax.axis("off")

    variance_lines = []
    for fingerprint_hash, group_summaries in multi_run_groups.items():
        label = group_labels.get(fingerprint_hash, fingerprint_hash[:8])
        report = compute_variance_report(
            group_summaries, strategy=STRATEGY, metric_path=METRIC_PATH
        )

        run_level = report.get("run_level", {})
        variance_lines.append(f"Group {label} ({report['num_runs']} runs):")
        variance_lines.append(
            f"  F1: {run_level.get('mean', 0):.3f} +/- {run_level.get('std', 0):.4f}"
        )
        variance_lines.append(
            f"  95% CI: [{run_level.get('ci_95_lower', 0):.3f}, {run_level.get('ci_95_upper', 0):.3f}]"
        )
        variance_lines.append(f"  CV: {run_level.get('cv', 0):.3f}")
        variance_lines.append("")

        # Per-test-case with highest variance (top 5 by coefficient of variation)
        per_tc = report.get("per_test_case", {})
        if per_tc:
            sorted_by_cv = sorted(
                per_tc.items(), key=lambda item: item[1].get("cv", 0), reverse=True
            )
            variance_lines.append("  Most variable test cases:")
            for test_case_id, stats in sorted_by_cv[:5]:
                # Use .get with a 0 default throughout, matching the sort key above.
                variance_lines.append(
                    f"    {test_case_id}: F1={stats.get('mean', 0):.3f} "
                    f"+/- {stats.get('std', 0):.4f} (CV={stats.get('cv', 0):.3f})"
                )
            variance_lines.append("")

    # Every line, including the blank separators, ends with a newline.
    variance_text = "".join(f"{line}\n" for line in variance_lines)

    ax.text(
        0,
        1,
        variance_text,
        transform=ax.transAxes,
        fontsize=8,
        verticalalignment="top",
        fontfamily="monospace",
    )
    ax.set_title("Variance Details")

    plt.tight_layout()
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        "same_config_variance",
        title="Run-Level F1 Variance (Same Config)",
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")

In [None]:
# Cell 4b — Pre-Rerank vs Post-Rerank Effect Analysis (Within Config Group)
#
# For each config group that has reranking data, compares pre-rerank (vector-only)
# F1 against post-rerank F1 using per-run paired analysis via
# retrieval_metrics.compute_effect_size.
#
# Each run contributes one pre-rerank F1 and one post-rerank F1 — a natural
# paired design since both metrics come from the same run.
#
# Note: These are Level 1 F1 values (per-test-case optimal threshold).
# Use for: "does reranking help?" — NOT for estimating production F1.
#
# Requires `multi_run_groups` from Cell 4, so only groups with 2+ runs are analyzed.

from retrieval_metrics import compute_effect_size

if not multi_run_groups:
    print("No config groups with 2+ runs found.")
else:
    print("=" * 92)
    print(f"PRE-RERANK vs POST-RERANK EFFECT ANALYSIS (per-run paired, strategy={STRATEGY})")
    print("=" * 92)
    print()
    print("Note: F1 values are Level 1 (per-test-case optimal threshold, most optimistic).")
    print("Use for comparing pre vs post rerank, NOT for estimating production F1.")

    for fingerprint_hash, group_summaries in multi_run_groups.items():
        label = group_labels.get(fingerprint_hash, fingerprint_hash[:8])

        # Check if this group has reranking data (at least one run with a post-rerank F1)
        has_rerank = any(
            extract_f1_from_summary(
                summary.get("macro_averaged", {}), "post_rerank", STRATEGY
            ) is not None
            for summary in group_summaries
        )

        if not has_rerank:
            print(f"\nGroup {label} ({fingerprint_hash}): No post-rerank data found. Skipping.")
            continue

        # Collect per-run paired differences: post_rerank - pre_rerank
        pre_rerank_f1_values = []
        post_rerank_f1_values = []
        paired_differences = []

        for summary in group_summaries:
            macro = summary.get("macro_averaged", {})
            # The pre-rerank lookup passes no strategy; the post-rerank lookup uses STRATEGY.
            pre_f1 = extract_f1_from_summary(macro, "pre_rerank", None)
            post_f1 = extract_f1_from_summary(macro, "post_rerank", STRATEGY)

            # Only runs that have both values form a pair.
            if pre_f1 is not None and post_f1 is not None:
                pre_rerank_f1_values.append(pre_f1)
                post_rerank_f1_values.append(post_f1)
                paired_differences.append(post_f1 - pre_f1)

        num_pairs = len(paired_differences)
        if num_pairs < 2:
            print(
                f"\nGroup {label}: Too few runs with both pre/post data ({num_pairs}). "
                f"Need at least 2."
            )
            continue

        mean_pre = sum(pre_rerank_f1_values) / num_pairs
        mean_post = sum(post_rerank_f1_values) / num_pairs

        # Use the library for all statistical calculations
        effect = compute_effect_size(paired_differences)

        # Report block: every statistic below is read from the `effect` result object.
        print(f"\n{'─' * 92}")
        print(f"  Group {label} ({fingerprint_hash}) — {num_pairs} runs")
        print(f"  Pre-rerank (vector only) vs Post-rerank ({STRATEGY})")
        print(f"{'─' * 92}")
        print(f"  Pre-rerank mean F1:     {mean_pre:.4f}")
        print(f"  Post-rerank mean F1:    {mean_post:.4f}")
        print(f"  Mean diff (post-pre):   {effect.mean_difference:+.4f}")
        print(f"  Std of paired diffs:    {effect.std_of_differences:.4f}")
        print(f"  Standard error:         {effect.standard_error:.4f}")
        print(f"  t-statistic:            {effect.t_statistic:.3f}")
        print(f"  95% CI:                 [{effect.ci_lower:+.4f}, {effect.ci_upper:+.4f}]")
        print(f"  CI excludes zero:       {effect.ci_excludes_zero}")
        print()
        print(f"  Cohen's d:              {effect.cohens_d:+.3f}  ({effect.effect_category})")
        print()
        print(f"  Runs post > pre:        {effect.count_positive}/{effect.n}")
        print(f"  Runs post < pre:        {effect.count_negative}/{effect.n}")
        print(f"  Runs tied:              {effect.count_zero}/{effect.n}")
        print()
        print(f"  MDE at n={effect.n}:           {effect.minimum_detectable_effect:.4f}")
        print(f"  Observed |effect|:      {abs(effect.mean_difference):.4f}")

        if effect.power_sufficient:
            print("  -> Effect > MDE: DETECTABLE at 80% power")
        else:
            print("  -> Effect < MDE: BELOW detection threshold (power < 80%)")

        if effect.required_n is not None:
            print(f"  Required n for 80% power: {effect.required_n}")

        # Verdict is driven solely by whether the reported CI excludes zero.
        if effect.ci_excludes_zero:
            direction = "IMPROVEMENT" if effect.mean_difference > 0 else "DEGRADATION"
            print(f"\n  VERDICT: Statistically significant {direction}")
        else:
            print(f"\n  VERDICT: NOT significant (CI includes zero)")

In [None]:
# Cell 5 — A/B Comparison (Select Two Groups)
# Compares two config groups: prints their fingerprint differences and the
# statistics returned by compare_configs, then draws a paired dot plot with
# one row per test case.

# Group selection:
# - both COMPARE_HASH_A and COMPARE_HASH_B set -> compare exactly those groups
# - otherwise -> auto-select the first and last group in sorted_hashes
hash_a, hash_b = None, None
if COMPARE_HASH_A and COMPARE_HASH_B:
    unknown_hashes = [h for h in (COMPARE_HASH_A, COMPARE_HASH_B) if h not in groups]
    if unknown_hashes:
        # Give a readable message instead of a KeyError on groups[...] below.
        print(f"Unknown fingerprint hash(es) for A/B comparison: {unknown_hashes}")
        print(f"Available hashes: {sorted_hashes}")
    else:
        hash_a, hash_b = COMPARE_HASH_A, COMPARE_HASH_B
else:
    if COMPARE_HASH_A or COMPARE_HASH_B:
        print("Only one of COMPARE_HASH_A / COMPARE_HASH_B is set; using auto-selection instead.")
    if len(sorted_hashes) >= 2:
        hash_a, hash_b = sorted_hashes[0], sorted_hashes[-1]
    else:
        print("Need at least 2 config groups for A/B comparison.")

if hash_a and hash_b:
    label_a = group_labels.get(hash_a, hash_a[:8])
    label_b = group_labels.get(hash_b, hash_b[:8])
    summaries_a = groups[hash_a]
    summaries_b = groups[hash_b]

    print(f"Comparing Group {label_a} ({hash_a}) vs Group {label_b} ({hash_b})")

    # Show config diff (fingerprint taken from the first run of each group)
    fp_a = summaries_a[0].get("config_fingerprint", {})
    fp_b = summaries_b[0].get("config_fingerprint", {})
    diff = fingerprint_diff(fp_a, fp_b)
    if diff:
        print("\nConfig differences:")
        for field, values in diff.items():
            print(f"  {field}: {values['a']} -> {values['b']}")

    # Run comparison
    comparison = compare_configs(
        summaries_a,
        summaries_b,
        strategy=STRATEGY,
        metric_path=METRIC_PATH,
    )

    if "error" in comparison:
        print(f"\nError: {comparison['error']}")
    else:
        print(f"\n{'=' * 60}")
        print("STATISTICAL COMPARISON")
        print(f"{'=' * 60}")
        print(f"  Mean F1 (A): {comparison['mean_f1_a']:.3f}")
        print(f"  Mean F1 (B): {comparison['mean_f1_b']:.3f}")
        print(f"  Mean diff (B-A): {comparison['mean_diff']:+.4f}")
        print(f"  Direction: {comparison['direction']}")
        print(
            f"  Bootstrap 95% CI: [{comparison['bootstrap_ci']['lower']:.4f}, {comparison['bootstrap_ci']['upper']:.4f}]"
        )
        print(f"  Excludes zero: {comparison['bootstrap_ci']['excludes_zero']}")
        print(f"  Wilcoxon p-value: {comparison['wilcoxon']['p_value']:.4f}")
        print(f"  Significance: {comparison['significance']}")

        # Paired dot plot: one horizontal segment per test case from F1(A) to F1(B)
        fig, ax = plt.subplots(figsize=(10, 6))
        paired = comparison["paired_differences"]
        test_case_ids = [entry["test_case_id"] for entry in paired]
        y_positions = list(range(len(test_case_ids)))

        for y_pos, entry in enumerate(paired):
            # Differences within +/-0.001 are drawn as "no change".
            if entry["diff"] > 0.001:
                color = "#2ecc71"
            elif entry["diff"] < -0.001:
                color = "#e74c3c"
            else:
                color = "#95a5a6"
            ax.plot(
                [entry["f1_a"], entry["f1_b"]],
                [y_pos, y_pos],
                color=color,
                linewidth=2,
                alpha=0.6,
            )
            ax.scatter(entry["f1_a"], y_pos, color="#3498db", s=60, zorder=5, edgecolors="white")
            ax.scatter(entry["f1_b"], y_pos, color="#e67e22", s=60, zorder=5, edgecolors="white")

        ax.set_yticks(y_positions)
        ax.set_yticklabels(test_case_ids, fontsize=8)
        ax.set_xlabel("F1 Score")
        ax.set_title(f"Paired Comparison: Group {label_a} (blue) vs Group {label_b} (orange)")
        # Empty artists exist only to create legend entries.
        ax.scatter([], [], color="#3498db", s=60, label=f"Group {label_a}")
        ax.scatter([], [], color="#e67e22", s=60, label=f"Group {label_b}")
        ax.plot([], [], color="#2ecc71", linewidth=2, label="B improved")
        ax.plot([], [], color="#e74c3c", linewidth=2, label="B degraded")
        ax.plot([], [], color="#95a5a6", linewidth=2, label="No change")
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3, axis="x")
        plt.tight_layout()
        saved_paths = save_figure(
            fig,
            FIGURE_SESSION,
            "ab_paired_comparison",
            title=f"Paired Comparison: Group {label_a} vs Group {label_b}",
        )
        plt.show()
        print(f"Saved: {saved_paths['png']}")

In [None]:
# Cell 6 — Threshold Sweep Overlay
# Supports both metric paths:
# - post_rerank: rerank score threshold sweep from rerank_strategies[STRATEGY].threshold_sweep
# - pre_rerank: distance threshold sweep from baseline.distance_threshold_sweep
#
# Draws one mean-F1 curve per config group, a min/max band when the group has
# several runs, and marks the peak of each mean curve.
# Requires `hash_to_color` from Cell 3.

if not sorted_hashes:
    print("Cannot plot: no config groups loaded.")
else:
    fig, ax = plt.subplots(figsize=(14, 6))

    line_styles = ["-", "--", "-.", ":"]
    # Any METRIC_PATH other than "pre_rerank" is handled as the rerank-score sweep.
    is_pre_rerank_mode = METRIC_PATH == "pre_rerank"
    threshold_label = (
        "Distance Threshold (cosine distance)" if is_pre_rerank_mode else "Rerank Score Threshold"
    )
    annotation_prefix = "d" if is_pre_rerank_mode else "t"
    plot_mode_label = "Distance" if is_pre_rerank_mode else "Rerank Score"

    # Number of groups that actually had sweep data and were drawn.
    plotted_groups = 0

    for group_index, fingerprint_hash in enumerate(sorted_hashes):
        group_summaries = groups[fingerprint_hash]
        label = group_labels.get(fingerprint_hash, fingerprint_hash[:8])
        color = hash_to_color[fingerprint_hash]

        # Collect threshold sweeps from all runs in this group.
        group_sweeps = []  # list of {threshold: f1} dicts
        for summary in group_summaries:
            if is_pre_rerank_mode:
                sweep_data = summary.get("baseline", {}).get("distance_threshold_sweep", {})
            else:
                sweep_data = (
                    summary.get("rerank_strategies", {})
                    .get(STRATEGY, {})
                    .get("threshold_sweep", {})
                )

            full_sweep = sweep_data.get("full_sweep", [])
            if full_sweep:
                sweep_dict = {entry["threshold"]: entry["f1"] for entry in full_sweep}
                group_sweeps.append(sweep_dict)

        # Groups without any sweep data are skipped entirely.
        if not group_sweeps:
            continue

        # Union of the thresholds seen in any run of this group, sorted ascending
        # (runs are expected to share the same fixed grid).
        all_thresholds = sorted(set().union(*[set(sweep.keys()) for sweep in group_sweeps]))

        # Compute mean and min/max bands as numpy arrays.
        mean_f1_values = np.zeros(len(all_thresholds))
        min_f1_values = np.zeros(len(all_thresholds))
        max_f1_values = np.zeros(len(all_thresholds))
        for threshold_index, threshold in enumerate(all_thresholds):
            # Only runs that contain this threshold contribute to its statistics.
            f1_at_threshold = [sweep[threshold] for sweep in group_sweeps if threshold in sweep]
            if f1_at_threshold:
                mean_f1_values[threshold_index] = np.mean(f1_at_threshold)
                min_f1_values[threshold_index] = min(f1_at_threshold)
                max_f1_values[threshold_index] = max(f1_at_threshold)

        ax.plot(
            all_thresholds,
            mean_f1_values,
            label=f"Group {label} (n={len(group_sweeps)})",
            color=color,
            linewidth=2,
            # Cycle through the line styles when there are more groups than styles.
            linestyle=line_styles[group_index % len(line_styles)],
        )
        plotted_groups += 1

        # Shaded band for min-max (if multiple runs).
        if len(group_sweeps) > 1:
            ax.fill_between(all_thresholds, min_f1_values, max_f1_values, color=color, alpha=0.15)

        # Vertical line + annotation at optimal threshold (mean curve peak).
        optimal_index = np.argmax(mean_f1_values)
        optimal_threshold = all_thresholds[optimal_index]
        optimal_f1 = mean_f1_values[optimal_index]
        ax.axvline(optimal_threshold, color=color, linestyle=":", alpha=0.5, linewidth=1)
        ax.annotate(
            f"{annotation_prefix}={optimal_threshold:.3f}\nF1={optimal_f1:.3f}",
            xy=(optimal_threshold, optimal_f1),
            xytext=(10, 10),
            textcoords="offset points",
            fontsize=7,
            color=color,
            ha="left",
        )

    if plotted_groups == 0:
        # Nothing was drawn: explain which data was expected and discard the empty figure.
        if is_pre_rerank_mode:
            print(
                "No distance threshold sweep data found for METRIC_PATH='pre_rerank'. "
                "Expected baseline.distance_threshold_sweep.full_sweep in run summaries."
            )
        else:
            print(
                f"No rerank threshold sweep data found for METRIC_PATH='post_rerank' "
                f"and STRATEGY='{STRATEGY}'."
            )
        plt.close(fig)
    else:
        ax.set_xlabel(threshold_label)
        ax.set_ylabel("Macro F1")
        ax.set_title(f"{plot_mode_label} Threshold Sweep Overlay — {PHASE} ({METRIC_LABEL})")
        ax.legend(fontsize=8)
        ax.set_ylim(0, 1.05)
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        saved_paths = save_figure(
            fig,
            FIGURE_SESSION,
            "threshold_sweep_overlay",
            title=f"{plot_mode_label} Threshold Sweep Overlay — {PHASE} ({METRIC_LABEL})",
        )
        plt.show()
        print(f"Saved: {saved_paths['png']}")

In [None]:
# Cell 7 — Per-Test-Case Heatmap
# Rows are test cases, columns are config groups; each cell is the mean F1 of
# that test case over the runs of that group (NaN when no run has a value).
# Defines `all_test_case_ids`, `avg_f1_per_tc` and `sorted_indices`, which
# Cell 8 also reads.

# Collect F1 per test case per group (averaged across runs in each group)
all_test_case_ids = set()
for group_summaries in groups.values():
    for summary in group_summaries:
        all_test_case_ids.update(summary.get("per_test_case", {}).keys())
all_test_case_ids = sorted(all_test_case_ids)

heatmap_data = np.full((len(all_test_case_ids), len(sorted_hashes)), np.nan)
column_labels = []

for col_index, fingerprint_hash in enumerate(sorted_hashes):
    label = group_labels.get(fingerprint_hash, fingerprint_hash[:8])
    column_labels.append(f"Group {label}")
    group_summaries = groups[fingerprint_hash]

    for row_index, test_case_id in enumerate(all_test_case_ids):
        f1_values = []
        for summary in group_summaries:
            per_tc = summary.get("per_test_case", {}).get(test_case_id, {})
            f1 = extract_f1_from_per_test_case(per_tc, METRIC_PATH, STRATEGY)
            if f1 is not None:
                f1_values.append(f1)
        if f1_values:
            heatmap_data[row_index, col_index] = np.mean(f1_values)

if heatmap_data.size == 0:
    # No groups or no test cases: skip plotting, but still define the names
    # that the summary cell reads.
    print("Cannot plot: no config groups or per-test-case data loaded.")
    avg_f1_per_tc = np.full(len(all_test_case_ids), np.nan)
    sorted_indices = np.array([], dtype=int)
else:
    fig, ax = plt.subplots(
        figsize=(max(8, len(sorted_hashes) * 2), max(6, len(all_test_case_ids) * 0.5))
    )
    sns.heatmap(
        heatmap_data,
        annot=True,
        fmt=".2f",
        xticklabels=column_labels,
        yticklabels=all_test_case_ids,
        cmap="RdYlGn",
        vmin=0,
        vmax=1,
        linewidths=0.5,
        ax=ax,
    )
    ax.set_title(f"Per-Test-Case F1 by Config Group — {PHASE} ({METRIC_LABEL})")
    ax.set_ylabel("Test Case")
    ax.set_xlabel("Config Group")
    plt.tight_layout()
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        "per_test_case_heatmap",
        title=f"Per-Test-Case F1 by Config Group — {PHASE} ({METRIC_LABEL})",
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")

    # Hardest test cases (lowest average F1 across groups).
    # Average only rows that have at least one value: this avoids numpy's
    # all-NaN warning and keeps data-less test cases out of the ranking.
    rows_with_data = ~np.isnan(heatmap_data).all(axis=1)
    avg_f1_per_tc = np.full(len(all_test_case_ids), np.nan)
    avg_f1_per_tc[rows_with_data] = np.nanmean(heatmap_data[rows_with_data], axis=1)
    ranked_indices = np.argsort(avg_f1_per_tc)
    sorted_indices = ranked_indices[rows_with_data[ranked_indices]]

    print("\nHardest test cases (lowest avg F1 across all groups):")
    for rank, index in enumerate(sorted_indices[:5]):
        print(f"  {rank + 1}. {all_test_case_ids[index]}: avg F1 = {avg_f1_per_tc[index]:.3f}")

In [None]:
# Cell 8 — Summary Dashboard
# Prints the best config group by mean macro F1, the hardest test cases and a
# few suggested next steps.
# Requires `sorted_indices`, `all_test_case_ids` and `avg_f1_per_tc` from Cell 7.

print("=" * 60)
print("SUMMARY DASHBOARD")
print("=" * 60)

# Best config group: highest mean F1 among groups that have the selected metric.
best_hash = None
best_f1 = -1.0
for fingerprint_hash, group_summaries in groups.items():
    run_f1_values = (
        extract_f1_from_summary(summary.get("macro_averaged", {}), METRIC_PATH, STRATEGY)
        for summary in group_summaries
    )
    f1_values = [f1 for f1 in run_f1_values if f1 is not None]
    if not f1_values:
        # A group without the metric must not win with a placeholder score of 0.
        continue
    avg_f1 = np.mean(f1_values)
    if avg_f1 > best_f1:
        best_f1 = avg_f1
        best_hash = fingerprint_hash

if best_hash is not None:
    best_label = group_labels.get(best_hash, best_hash[:8])
    best_fp = groups[best_hash][0].get("config_fingerprint", {})
    print(f"\nBest config: Group {best_label} (hash: {best_hash})")
    print(f"  Avg F1: {best_f1:.3f}")
    print(f"  Extraction prompt: v{best_fp.get('extraction_prompt_version', '?')}")
    print(f"  Query prompt: v{best_fp.get('query_prompt_version', '?')}")
    print(f"  Query model: {best_fp.get('query_model', '?')}")
    print(f"  Reranker: {best_fp.get('reranker_model', 'None')}")
    print(f"  N runs: {len(groups[best_hash])}")
else:
    print("\nBest config: n/a (no config group has F1 data for the selected metric)")

# Hardest test cases
print("\nHardest test cases (avg F1 across all groups):")
for rank, index in enumerate(sorted_indices[:3]):
    print(f"  {rank + 1}. {all_test_case_ids[index]}: {avg_f1_per_tc[index]:.3f}")

# Recommended next experiments
print("\nRecommended next steps:")
multi_run_count = sum(1 for g in groups.values() if len(g) >= 2)
if multi_run_count == 0:
    print("  - Run same config again to measure LLM variance (need 2+ runs per config)")
if len(groups) < 3:
    print("  - Try a new prompt version to expand comparison space")
print("  - Investigate hardest test cases to understand failure modes")
print(f"  - Total runs analyzed: {len(summaries)} across {len(groups)} config groups")