# Rerank Strategy Comparison

Compare the pre-rerank vector-search baseline against post-rerank results for every rerank strategy.

**What this notebook answers:**
- Does reranking actually improve over vector search alone?
- Which reranking strategy performs best?
- For which test cases does reranking help — or hurt?
- What are the optimal threshold and top-N settings per strategy?

When multiple runs are loaded, metrics are averaged across them to smooth out run-to-run variance from LLM query generation.  
Requires runs with reranking data (`phase2.ipynb` or `phase1_reranking_comparison.ipynb`).

In [None]:
# Cell 0 — Setup, Imports & Helpers
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from memory_retrieval.experiments.comparison import (
    group_runs_by_fingerprint,
    load_run_summaries,
    load_subrun_summaries,
    reconstruct_fingerprint_from_run,
)
from memory_retrieval.experiments.metrics_adapter import extract_metric_from_nested
from memory_retrieval.infra.io import load_json
from memory_retrieval.infra.runs import (
    PHASE1,
    PHASE2,
    list_runs,
    update_config_fingerprint,
)

# Locate the project root: the nearest directory — the current working directory
# itself or one of its ancestors — that contains a pyproject.toml. Fail loudly if
# none exists, otherwise make it the working directory.
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (Path.cwd(), *Path.cwd().parents)
        if (candidate / "pyproject.toml").exists()
    ),
    None,
)
if PROJECT_ROOT is None:
    raise RuntimeError("Could not find project root")
os.chdir(PROJECT_ROOT)


def is_reranking_summary(summary):
    """Return True when *summary* has a ``baseline`` or ``rerank_strategies`` section."""
    return any(section_key in summary for section_key in ("baseline", "rerank_strategies"))


def get_all_strategies(summaries):
    """Return the sorted union of rerank strategy names found in *summaries*.

    Names are collected from both ``rerank_strategies`` and
    ``macro_averaged.post_rerank`` of every summary.
    """
    return sorted(
        {
            name
            for summary in summaries
            for strategy_section in (
                summary.get("rerank_strategies", {}),
                summary.get("macro_averaged", {}).get("post_rerank", {}),
            )
            for name in strategy_section
        }
    )


def extract_macro_stage_metrics(summary, stage_key):
    """Return the macro-averaged P/R/F1/MRR dict recorded for *stage_key*.

    Recognised stage keys:
    - ``pre_rerank_distance``: pre-rerank at the optimal distance threshold
    - ``pre_rerank_top_n``: pre-rerank at the optimal top-N
    - ``post_rerank_threshold:<strategy>``: post-rerank at the optimal rerank threshold
    - ``post_rerank_top_n:<strategy>``: post-rerank at the optimal top-N

    The strategy name is everything after the first ``:``. Unknown keys or
    missing sections yield ``{}``.
    """
    macro = summary.get("macro_averaged", {})

    pre_rerank_sections = {
        "pre_rerank_distance": "at_optimal_distance_threshold",
        "pre_rerank_top_n": "at_optimal_top_n",
    }
    if stage_key in pre_rerank_sections:
        return macro.get("pre_rerank", {}).get(pre_rerank_sections[stage_key], {})

    post_rerank_sections = {
        "post_rerank_threshold": "at_optimal_threshold",
        "post_rerank_top_n": "at_optimal_top_n",
    }
    stage_prefix, separator, strategy_name = stage_key.partition(":")
    if separator and stage_prefix in post_rerank_sections:
        strategy_block = macro.get("post_rerank", {}).get(strategy_name, {})
        return strategy_block.get(post_rerank_sections[stage_prefix], {})

    return {}


def average_stage_metrics_across_runs(summaries, stage_key):
    """Average macro P/R/F1/MRR for *stage_key* over every run that reports an F1.

    Runs whose stage metrics are empty or lack ``f1`` are ignored. Each metric is
    averaged only over the remaining runs that contain it, rounded to 4 decimals,
    and omitted entirely if no such run has it. Returns ``{}`` when no run qualifies.
    """
    usable_runs = []
    for summary in summaries:
        stage_metrics = extract_macro_stage_metrics(summary, stage_key)
        if stage_metrics and "f1" in stage_metrics:
            usable_runs.append(stage_metrics)

    if not usable_runs:
        return {}

    averaged = {}
    for metric_key in ("precision", "recall", "f1", "mrr"):
        observed = [run_metrics[metric_key] for run_metrics in usable_runs if metric_key in run_metrics]
        if observed:
            averaged[metric_key] = round(float(np.mean(observed)), 4)
    return averaged


def extract_per_test_case_stage_f1(summary, stage_key):
    """Map each test case in ``summary["per_test_case"]`` to its F1 for *stage_key*.

    Stage keys are ``pre_rerank_distance``, ``pre_rerank_top_n``,
    ``post_rerank_threshold:<strategy>`` and ``post_rerank_top_n:<strategy>``.
    Test cases whose F1 is absent (``None``), and every test case for an
    unrecognised stage key, are left out of the result.
    """
    # Resolve the stage once into a reader that is applied to each test case record.
    if stage_key == "pre_rerank_distance":

        def read_f1(per_tc):
            return extract_metric_from_nested(per_tc, "pre_rerank", None, metric_key="f1")

    elif stage_key == "pre_rerank_top_n":

        def read_f1(per_tc):
            return per_tc.get("pre_rerank", {}).get("top_n", {}).get("at_optimal", {}).get("f1")

    elif stage_key.startswith("post_rerank_threshold:"):
        strategy_name = stage_key.split(":", 1)[1]

        def read_f1(per_tc):
            return extract_metric_from_nested(per_tc, "post_rerank", strategy_name, metric_key="f1")

    elif stage_key.startswith("post_rerank_top_n:"):
        strategy_name = stage_key.split(":", 1)[1]

        def read_f1(per_tc):
            strategy_section = per_tc.get("post_rerank", {}).get(strategy_name, {})
            return strategy_section.get("top_n", {}).get("at_optimal", {}).get("f1")

    else:

        def read_f1(per_tc):
            return None

    f1_by_test_case = {}
    for test_case_id, per_tc in summary.get("per_test_case", {}).items():
        stage_f1 = read_f1(per_tc)
        if stage_f1 is not None:
            f1_by_test_case[test_case_id] = stage_f1
    return f1_by_test_case


def average_per_test_case_stage_f1_across_runs(summaries, stage_key):
    """Average per-test-case F1 across multiple runs for a given stage key.

    Args:
        summaries: Re-iterable collection of run summaries.
        stage_key: Stage key understood by ``extract_per_test_case_stage_f1``.

    Returns:
        Dict mapping each test case ID (sorted) to its F1 averaged over the runs
        that report it, rounded to 4 decimals. Test cases with no F1 in any run
        are omitted.
    """
    # Extract each run's per-test-case F1 map exactly once. Re-extracting it inside
    # the per-test-case loop made this O(test_cases² × runs).
    per_run_f1_maps = [extract_per_test_case_stage_f1(summary, stage_key) for summary in summaries]

    all_test_case_ids = set()
    for summary in summaries:
        all_test_case_ids.update(summary.get("per_test_case", {}).keys())

    averaged_f1 = {}
    for test_case_id in sorted(all_test_case_ids):
        f1_values = [f1_map[test_case_id] for f1_map in per_run_f1_maps if test_case_id in f1_map]
        if f1_values:
            averaged_f1[test_case_id] = round(float(np.mean(f1_values)), 4)
    return averaged_f1


def _sweep_section_for(summary, sweep_type):
    """Locate one run's sweep section for *sweep_type*.

    Returns ``(section, x_key, optimal_key)``, or ``None`` if *sweep_type* is not a
    recognised form. A missing section comes back as ``{}``.
    """
    if sweep_type == "distance_threshold":
        section = summary.get("baseline", {}).get("distance_threshold_sweep", {})
        return section, "threshold", "optimal_threshold"
    if sweep_type == "pre_top_n":
        section = summary.get("baseline", {}).get("top_n_sweep", {})
        return section, "top_n", "optimal_n"

    sweep_kind, separator, strategy_name = sweep_type.partition(":")
    if separator and sweep_kind == "rerank_threshold":
        strategy_block = summary.get("rerank_strategies", {}).get(strategy_name, {})
        return strategy_block.get("threshold_sweep", {}), "threshold", "optimal_threshold"
    if separator and sweep_kind == "post_top_n":
        strategy_block = summary.get("rerank_strategies", {}).get(strategy_name, {})
        return strategy_block.get("top_n_sweep", {}), "top_n", "optimal_n"
    return None


def get_averaged_sweep_curve(summaries, sweep_type):
    """Average a sweep's F1 curve across runs.

    sweep_type options:
    - distance_threshold: pre-rerank distance threshold sweep
    - pre_top_n: pre-rerank top-N sweep
    - rerank_threshold:<strategy>: post-rerank rerank score threshold sweep
    - post_top_n:<strategy>: post-rerank top-N sweep

    Sweep x values are rounded to 4 decimals so runs line up on a shared grid.
    At each x, the mean/min/max use only the runs that evaluated that x.

    Returns (x_values, mean_f1, min_f1, max_f1, mean_optimal_x). The last item is
    the mean of the per-run optima, or None if no run reports one. Without any
    sweep data it returns ``[]``, three empty arrays and None.
    """
    per_run_curves = []
    per_run_optima = []

    for summary in summaries:
        located = _sweep_section_for(summary, sweep_type)
        if located is None:
            continue
        section, x_key, optimal_key = located
        per_run_optima.append(section.get(optimal_key))
        sweep_entries = section.get("full_sweep", [])
        if sweep_entries:
            per_run_curves.append({round(entry[x_key], 4): entry["f1"] for entry in sweep_entries})

    if not per_run_curves:
        return [], np.array([]), np.array([]), np.array([]), None

    x_grid = sorted({x for curve in per_run_curves for x in curve})
    mean_curve, low_curve, high_curve = [], [], []
    for x in x_grid:
        # Every grid point comes from at least one curve, so this is never empty.
        samples = [curve[x] for curve in per_run_curves if x in curve]
        mean_curve.append(np.mean(samples))
        low_curve.append(min(samples))
        high_curve.append(max(samples))

    known_optima = [value for value in per_run_optima if value is not None]
    mean_optimum = float(np.mean(known_optima)) if known_optima else None

    return (
        x_grid,
        np.array(mean_curve, dtype=float),
        np.array(low_curve, dtype=float),
        np.array(high_curve, dtype=float),
        mean_optimum,
    )


def load_selected_run_summaries(
    phase,
    run_source,
    selected_parent_run_ids=None,
    selected_subrun_ids=None,
):
    """Load run summaries for *phase* according to the RUN_SOURCE setting.

    Args:
        phase: Phase identifier passed through to the loaders.
        run_source: ``"parent"``, ``"subrun"`` or ``"both"``.
        selected_parent_run_ids: Optional parent-run filter. It applies to parent
            loading and also constrains which parents' subruns are loaded.
        selected_subrun_ids: Optional subrun filter (subrun loading only).

    Returns:
        Parent summaries followed by subrun summaries. Each is a shallow copy
        tagged with ``source_kind`` set to ``"parent"`` or ``"subrun"``.

    Raises:
        ValueError: if *run_source* is not one of the accepted values.
    """
    valid_sources = {"parent", "subrun", "both"}
    if run_source not in valid_sources:
        allowed = ", ".join(sorted(valid_sources))
        raise ValueError(f"Invalid RUN_SOURCE={run_source!r}. Expected one of: {allowed}")

    combined_summaries = []

    if run_source in ("parent", "both"):
        for summary in load_run_summaries(phase, run_ids=selected_parent_run_ids):
            combined_summaries.append({**summary, "source_kind": "parent"})

    if run_source in ("subrun", "both"):
        subruns = load_subrun_summaries(
            phase,
            parent_run_ids=selected_parent_run_ids,
            subrun_ids=selected_subrun_ids,
        )
        combined_summaries.extend({**summary, "source_kind": "subrun"} for summary in subruns)

    return combined_summaries


# --- Backfill fingerprints for existing runs (same as cross_run_comparison) ---
# Any run whose run.json lacks "config_fingerprint" gets one rebuilt from its run
# directory and written back via update_config_fingerprint. Runs without at least one
# results/*.json file are skipped.
backfilled_count = 0
for phase in [PHASE1, PHASE2]:
    for run in list_runs(phase):
        run_dir = run["run_dir"]
        run_metadata = load_json(run_dir / "run.json")
        if "config_fingerprint" in run_metadata:
            continue  # already fingerprinted
        results_dir = run_dir / "results"
        if not results_dir.exists() or next(iter(results_dir.glob("*.json")), None) is None:
            continue  # no result files to reconstruct from
        fingerprint = reconstruct_fingerprint_from_run(run_dir)
        update_config_fingerprint(run_dir, fingerprint)
        backfilled_count += 1
        run_id = run["run_id"]
        fingerprint_hash = fingerprint["fingerprint_hash"]
        print(f"  Backfilled: {run_id} -> {fingerprint_hash}")

print(
    f"\nBackfilled {backfilled_count} runs."
    if backfilled_count > 0
    else "All runs already have fingerprints."
)

print("\nSetup complete.")

In [None]:
# Cell 1 — Configuration

# Phase whose runs are analyzed.
PHASE = PHASE2

# Where summaries come from:
#   "parent" — parent runs only (default, backward compatible)
#   "subrun" — subruns only
#   "both"   — parent runs and subruns together
RUN_SOURCE = "parent"

# Parent-run filter; None keeps every parent run in PHASE. Applies when RUN_SOURCE
# includes parents, and also limits which parents' subruns get loaded.
SELECTED_PARENT_RUN_IDS = None  # e.g., ['run_20260218_171450', 'run_20260218_172142']

# Subrun filter; None keeps every subrun under the selected parents.
# Only consulted when RUN_SOURCE includes subruns.
SELECTED_SUBRUN_IDS = None  # e.g., ['run_20260219_131817']

# Legacy alias: if a non-None SELECTED_RUN_IDS exists in the notebook namespace and no
# parent filter is set, adopt it as SELECTED_PARENT_RUN_IDS. This check runs every time
# the cell executes, so a stale SELECTED_RUN_IDS left in the kernel keeps applying.
if SELECTED_PARENT_RUN_IDS is None and globals().get("SELECTED_RUN_IDS") is not None:
    SELECTED_PARENT_RUN_IDS = globals()["SELECTED_RUN_IDS"]
    print("Deprecated: SELECTED_RUN_IDS is now SELECTED_PARENT_RUN_IDS.")

# Config fingerprint filter; None analyzes every loaded config together.
# To isolate one config: inspect the fingerprint listing printed by Cell 2, set the
# hash here, then re-run this cell and every cell after it.
# SELECTED_FINGERPRINT_HASH = '9c3ca8dc'  # Qwen 0.6B reranker
SELECTED_FINGERPRINT_HASH = None

print(f"Phase: {PHASE}")
print(f"Run source: {RUN_SOURCE}")
print(f"Parent run filter: {SELECTED_PARENT_RUN_IDS or 'all parents'}")
print(f"Subrun filter: {SELECTED_SUBRUN_IDS or 'all subruns (within parent filter)'}")
print(f"Fingerprint filter: {SELECTED_FINGERPRINT_HASH or 'none (all configs)'}")

In [None]:
# Cell 2 — Load Data & Overview Table
# Loads the selected summaries and keeps only those with reranking data. Optionally
# narrows them to one config fingerprint, then prints macro P/R/F1/MRR for every
# stage, averaged across the remaining runs.

all_summaries = load_selected_run_summaries(
    PHASE,
    RUN_SOURCE,
    selected_parent_run_ids=SELECTED_PARENT_RUN_IDS,
    selected_subrun_ids=SELECTED_SUBRUN_IDS,
)

parent_count = sum(1 for summary in all_summaries if summary.get("source_kind") == "parent")
subrun_count = sum(1 for summary in all_summaries if summary.get("source_kind") == "subrun")

print(
    f"Loaded {len(all_summaries)} run summaries for {PHASE} "
    f"(parents={parent_count}, subruns={subrun_count}, source={RUN_SOURCE})"
)

if all_summaries:
    fingerprint_counts = {}
    for summary in all_summaries:
        fingerprint = summary.get("config_fingerprint", {})
        fingerprint_hash = fingerprint.get("fingerprint_hash", "unknown")
        fingerprint_counts[fingerprint_hash] = fingerprint_counts.get(fingerprint_hash, 0) + 1

    print("Fingerprint coverage:")
    for fingerprint_hash, count in sorted(fingerprint_counts.items()):
        print(f"  {fingerprint_hash}: {count} run(s)")
else:
    print("Fingerprint coverage: no runs loaded")

reranking_summaries = [summary for summary in all_summaries if is_reranking_summary(summary)]
non_reranking_count = len(all_summaries) - len(reranking_summaries)
print(
    f"Reranking summaries: {len(reranking_summaries)} (skipped {non_reranking_count} non-reranking)"
)

if not reranking_summaries:
    raise RuntimeError(
        "No reranking summaries found. This notebook requires runs with reranking data.\n"
        "Run phase2.ipynb or phase1_reranking_comparison.ipynb first."
    )

# Group once; both the filtered and unfiltered paths below need the grouping.
config_groups = group_runs_by_fingerprint(reranking_summaries)

if SELECTED_FINGERPRINT_HASH:
    # Optionally filter to a specific fingerprint
    if SELECTED_FINGERPRINT_HASH not in config_groups:
        available_hashes = list(config_groups.keys())
        print(f"Fingerprint not found. Available: {available_hashes}")
        raise RuntimeError(f"Fingerprint {SELECTED_FINGERPRINT_HASH} not found")
    analysis_summaries = config_groups[SELECTED_FINGERPRINT_HASH]
    print(
        f"Filtered to {len(analysis_summaries)} runs with fingerprint {SELECTED_FINGERPRINT_HASH}"
    )
else:
    if len(config_groups) > 1:
        print(
            f"\nWarning: {len(config_groups)} different configs found — averaging across ALL configs."
        )
        print("Set SELECTED_FINGERPRINT_HASH to focus on one config. Available hashes:")
        for fingerprint_hash, group_summaries in config_groups.items():
            fingerprint = group_summaries[0].get("config_fingerprint", {})
            extraction_version = fingerprint.get("extraction_prompt_version", "?")
            query_version = fingerprint.get("query_prompt_version", "?")
            print(
                f"  {fingerprint_hash}: {len(group_summaries)} runs, ext={extraction_version}, q={query_version}"
            )
    analysis_summaries = reranking_summaries

# Show config info
if analysis_summaries:
    config_fingerprint = analysis_summaries[0].get("config_fingerprint", {})
    if config_fingerprint:
        print("\nConfig (first run):")
        print(f"  Extraction prompt: v{config_fingerprint.get('extraction_prompt_version', '?')}")
        print(f"  Query prompt:      v{config_fingerprint.get('query_prompt_version', '?')}")
        print(f"  Query model:       {config_fingerprint.get('query_model', '?')}")
        print(f"  Reranker model:    {config_fingerprint.get('reranker_model', '?')}")

all_strategies = get_all_strategies(analysis_summaries)
print(f"\nRerank strategies: {all_strategies}")
print(f"Runs being averaged: {len(analysis_summaries)}")

# --- Build stage keys for all comparisons ---
# These variables are used in all subsequent cells.
comparison_stage_keys = ["pre_rerank_distance", "pre_rerank_top_n"]
for strategy_name in all_strategies:
    comparison_stage_keys.append(f"post_rerank_threshold:{strategy_name}")
    comparison_stage_keys.append(f"post_rerank_top_n:{strategy_name}")

stage_display_names = {
    "pre_rerank_distance": "Pre-Rerank (Distance Threshold)",
    "pre_rerank_top_n": "Pre-Rerank (Top-N)",
}
for strategy_name in all_strategies:
    stage_display_names[f"post_rerank_threshold:{strategy_name}"] = (
        f"Post-Rerank [{strategy_name}] (Threshold)"
    )
    stage_display_names[f"post_rerank_top_n:{strategy_name}"] = (
        f"Post-Rerank [{strategy_name}] (Top-N)"
    )

# Compute averaged metrics for every stage
stage_averaged_metrics = {
    stage_key: average_stage_metrics_across_runs(analysis_summaries, stage_key)
    for stage_key in comparison_stage_keys
}

# --- Print overview table ---
# Deltas are only meaningful against a real baseline. When the pre-rerank distance
# stage has no F1, show "n/a" instead of silently comparing against 0.0.
# pre_rerank_baseline_f1 keeps its 0.0 fallback so it stays a float for later cells.
pre_rerank_baseline_metrics = stage_averaged_metrics.get("pre_rerank_distance", {})
has_pre_rerank_baseline = "f1" in pre_rerank_baseline_metrics
pre_rerank_baseline_f1 = pre_rerank_baseline_metrics.get("f1", 0.0)

print(f"\n{'Stage':<50} {'Precision':>10} {'Recall':>8} {'F1':>8} {'Delta':>9} {'MRR':>8}")
print("-" * 98)

for stage_key in comparison_stage_keys:
    metrics = stage_averaged_metrics[stage_key]
    if not metrics:
        continue
    display_name = stage_display_names[stage_key]
    f1 = metrics.get("f1", 0.0)
    if stage_key == "pre_rerank_distance":
        delta_str = "—"
    elif has_pre_rerank_baseline:
        delta_str = f"{f1 - pre_rerank_baseline_f1:+.3f}"
    else:
        delta_str = "n/a"
    print(
        f"{display_name:<50} "
        f"{metrics.get('precision', 0.0):>10.4f} "
        f"{metrics.get('recall', 0.0):>8.4f} "
        f"{f1:>8.4f} "
        f"{delta_str:>9} "
        f"{metrics.get('mrr', 0.0):>8.4f}"
    )
    if stage_key == "pre_rerank_top_n":
        print()  # blank line between pre-rerank and post-rerank sections

if has_pre_rerank_baseline:
    print(f"\nBaseline F1 (pre-rerank at optimal distance threshold): {pre_rerank_baseline_f1:.4f}")
    print("Delta is relative to this baseline.")
else:
    print("\nBaseline F1 (pre-rerank at optimal distance threshold): unavailable — deltas shown as n/a.")
print()
print("How to read these F1 values — there are 3 levels of optimism:")
print("  [most optimistic]  subrun_comparison Cell 4: each test case picks its own threshold")
print("  [this table]       each RUN picks one threshold best for all its test cases combined,")
# The number of averaged runs is data-dependent; it was previously hard-coded as 30.
print(f"                     then we average those {len(analysis_summaries)} per-run best results — not")
print("                     deployable, because in production you pick a threshold before")
print("                     seeing any queries")
print("  [most realistic]   Cell 4 sweep below: one fixed threshold for all runs — this is")
print("                     what you actually set in config and the F1 you can expect to see")

In [None]:
# Cell 3 — Metrics Bar Chart: Pre-Rerank vs All Strategies
# Shows F1, Precision, Recall, MRR side by side for each stage.
# Only the threshold-optimized stages are plotted: pre-rerank at its optimal distance
# threshold, and each strategy at its optimal rerank threshold. Top-N variants are not
# plotted here; they appear in the Cell 2 table and the Cell 5 top-N sweep.
# A metric missing from a stage's averaged metrics is drawn as a 0.0 bar
# (via the `.get(metric_name, 0.0)` default below).

# Stages to show in bars (threshold-optimized versions = the recommended metrics)
bar_stage_keys = ["pre_rerank_distance"]
for strategy_name in all_strategies:
    bar_stage_keys.append(f"post_rerank_threshold:{strategy_name}")

bar_short_labels = ["Pre-Rerank\n(Dist. Thresh.)"]
for strategy_name in all_strategies:
    bar_short_labels.append(f"{strategy_name}\n(Rerank Thresh.)")

# Color scheme: pre-rerank = gray, each strategy = distinct color
# (sampled from the Set2 colormap, at least 3 samples so small counts stay distinct)
strategy_palette = plt.cm.Set2(np.linspace(0, 0.8, max(len(all_strategies), 3)))
bar_colors = ["#95a5a6"]  # gray for pre-rerank
for strategy_index in range(len(all_strategies)):
    bar_colors.append(strategy_palette[strategy_index])

metric_names = ["f1", "precision", "recall", "mrr"]
metric_display_names = {"f1": "F1", "precision": "Precision", "recall": "Recall", "mrr": "MRR"}

# One subplot per metric.
fig, axes = plt.subplots(1, 4, figsize=(18, 5))

for metric_index, metric_name in enumerate(metric_names):
    ax = axes[metric_index]

    bar_values = [
        stage_averaged_metrics.get(stage_key, {}).get(metric_name, 0.0)
        for stage_key in bar_stage_keys
    ]

    bars = ax.bar(
        range(len(bar_stage_keys)),
        bar_values,
        color=bar_colors[: len(bar_stage_keys)],
        alpha=0.85,
        edgecolor="white",
        linewidth=0.8,
    )

    # Baseline reference line at pre-rerank level (skipped when the baseline is 0/missing)
    baseline_value = stage_averaged_metrics.get("pre_rerank_distance", {}).get(metric_name, 0.0)
    if baseline_value > 0:
        ax.axhline(baseline_value, color="#7f8c8d", linestyle="--", linewidth=1, alpha=0.6)

    # Value annotations on bars
    for bar in bars:
        bar_height = bar.get_height()
        ax.annotate(
            f"{bar_height:.3f}",
            xy=(bar.get_x() + bar.get_width() / 2, bar_height),
            xytext=(0, 4),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    ax.set_xticks(range(len(bar_stage_keys)))
    ax.set_xticklabels(bar_short_labels, fontsize=7.5)
    ax.set_title(metric_display_names[metric_name], fontsize=11)
    # Headroom above 1.0 leaves space for the value annotations.
    ax.set_ylim(0, 1.12)
    ax.grid(True, axis="y", alpha=0.3)
    ax.set_ylabel(metric_display_names[metric_name] if metric_index == 0 else "")

fig.suptitle(
    f"Pre-Rerank vs Rerank Strategies — {PHASE}\n"
    f"(averaged across {len(analysis_summaries)} run(s), dashed line = pre-rerank baseline)",
    fontsize=11,
)
plt.tight_layout()
plt.show()

In [None]:
# Cell 4 — Threshold Sweep Comparison
# Left: pre-rerank distance threshold sweep (how F1 changes with distance cutoff)
# Right: per-strategy rerank score threshold sweep
# The axes are intentionally separate — distance and rerank score are different scales.
# Each line is the mean F1 across runs at each threshold. The shaded band is the min-max
# across the runs that evaluated that threshold. The dotted vertical line marks the mean
# of the per-run optimal thresholds; its F1 label is read from the averaged curve at the
# nearest evaluated threshold.


def count_runs_with_sweep_data(summaries, sweep_type):
    """Return how many summaries contain a non-empty sweep of *sweep_type*."""
    return sum(1 for summary in summaries if get_averaged_sweep_curve([summary], sweep_type)[0])


fig, axes = plt.subplots(1, 2, figsize=(16, 6))
strategy_palette = plt.cm.Set2(np.linspace(0, 0.8, max(len(all_strategies), 3)))

# --- Left: Pre-rerank distance threshold sweep ---
ax_left = axes[0]

distance_x_values, distance_mean_f1, distance_min_f1, distance_max_f1, optimal_distance = (
    get_averaged_sweep_curve(analysis_summaries, "distance_threshold")
)

if distance_x_values:
    # n counts only runs that actually contributed distance-sweep points. It previously
    # counted every loaded run, overstating n when some runs lacked the sweep.
    num_runs_with_data = count_runs_with_sweep_data(analysis_summaries, "distance_threshold")
    ax_left.plot(
        distance_x_values,
        distance_mean_f1,
        color="#2c3e50",
        linewidth=2.5,
        label=f"Pre-Rerank (n={num_runs_with_data})",
    )
    if num_runs_with_data > 1:
        ax_left.fill_between(
            distance_x_values,
            distance_min_f1,
            distance_max_f1,
            color="#2c3e50",
            alpha=0.12,
            label="Min-max range",
        )
    if optimal_distance is not None:
        optimal_index = int(np.argmin(np.abs(np.array(distance_x_values) - optimal_distance)))
        optimal_f1_value = distance_mean_f1[optimal_index]
        ax_left.axvline(optimal_distance, color="#2c3e50", linestyle=":", alpha=0.7, linewidth=1.5)
        ax_left.annotate(
            f"d={optimal_distance:.3f}\nF1={optimal_f1_value:.3f}",
            xy=(optimal_distance, optimal_f1_value),
            xytext=(8, -30),
            textcoords="offset points",
            fontsize=8,
            color="#2c3e50",
        )
    ax_left.legend(fontsize=8)
else:
    ax_left.text(0.5, 0.5, "No distance sweep data", transform=ax_left.transAxes, ha="center")

ax_left.set_xlabel("Distance Threshold (lower = stricter filter, fewer results)")
ax_left.set_ylabel("Macro F1")
ax_left.set_title("Pre-Rerank: Distance Threshold Sweep")
ax_left.set_ylim(0, 1.05)
ax_left.grid(True, alpha=0.3)

# --- Right: Per-strategy rerank score threshold sweep ---
ax_right = axes[1]
line_styles = ["-", "--", "-.", ":"]
# Track what was actually drawn: strategies can exist without threshold-sweep data,
# and legend() with no labeled artists only emits a warning.
plotted_strategy_count = 0

for strategy_index, strategy_name in enumerate(all_strategies):
    rerank_sweep_type = f"rerank_threshold:{strategy_name}"
    rerank_x_values, rerank_mean_f1, rerank_min_f1, rerank_max_f1, optimal_rerank = (
        get_averaged_sweep_curve(analysis_summaries, rerank_sweep_type)
    )
    if not rerank_x_values:
        continue
    plotted_strategy_count += 1

    color = strategy_palette[strategy_index]
    line_style = line_styles[strategy_index % len(line_styles)]

    ax_right.plot(
        rerank_x_values,
        rerank_mean_f1,
        color=color,
        linewidth=2.5,
        linestyle=line_style,
        label=strategy_name,
    )
    if count_runs_with_sweep_data(analysis_summaries, rerank_sweep_type) > 1:
        ax_right.fill_between(
            rerank_x_values, rerank_min_f1, rerank_max_f1, color=color, alpha=0.12
        )
    if optimal_rerank is not None:
        optimal_index = int(np.argmin(np.abs(np.array(rerank_x_values) - optimal_rerank)))
        optimal_f1_value = rerank_mean_f1[optimal_index]
        ax_right.axvline(optimal_rerank, color=color, linestyle=":", alpha=0.6, linewidth=1.5)
        # Stagger annotations per strategy so their labels don't overlap.
        offset_y = -30 - strategy_index * 28
        ax_right.annotate(
            f"t={optimal_rerank:.3f}\nF1={optimal_f1_value:.3f}",
            xy=(optimal_rerank, optimal_f1_value),
            xytext=(8, offset_y),
            textcoords="offset points",
            fontsize=8,
            color=color,
        )

if plotted_strategy_count == 0:
    ax_right.text(0.5, 0.5, "No rerank strategy data", transform=ax_right.transAxes, ha="center")

ax_right.set_xlabel("Rerank Score Threshold (higher = stricter filter)")
ax_right.set_ylabel("Macro F1")
ax_right.set_title("Post-Rerank: Rerank Score Threshold Sweep")
if plotted_strategy_count:
    ax_right.legend(fontsize=8)
ax_right.set_ylim(0, 1.05)
ax_right.grid(True, alpha=0.3)

num_runs = len(analysis_summaries)
fig.suptitle(
    f"Threshold Sweep Comparison — {PHASE} (averaged across {num_runs} run(s))",
    fontsize=11,
)
plt.tight_layout()
plt.show()
print("Annotations (d=.../t=..., F1=...) = the most realistic deployment estimate [level 3/3].")
print("One fixed threshold is chosen (average of per-run optima), F1 is read off the averaged")
print("curve at that point. Lower than the table above because every run is forced to share the")
print("same threshold — nobody gets to pick their own. What you see here is what you will get.")

In [None]:
# Cell 5 — Top-N Sweep Comparison
# All stages on the same chart (same x-axis: number of returned results).
# Shows how F1 evolves as you return more results, for each stage.
# Lines are the mean F1 across runs at each N. Shaded bands are the min-max across the
# runs that evaluated that N, drawn whenever more than one run is loaded.
# The dotted vertical line is placed at the mean of the per-run optimal N, which may be
# fractional. Its "N=" label is that mean rounded to an int, and the label's y position
# is the averaged F1 at the nearest evaluated N.

fig, ax = plt.subplots(figsize=(14, 6))
strategy_palette = plt.cm.Set2(np.linspace(0, 0.8, max(len(all_strategies), 3)))
line_styles = ["-", "--", "-.", ":"]

# Pre-rerank top-N sweep
pre_top_n_x, pre_top_n_mean_f1, pre_top_n_min_f1, pre_top_n_max_f1, optimal_pre_top_n = (
    get_averaged_sweep_curve(analysis_summaries, "pre_top_n")
)

if pre_top_n_x:
    # zorder=10 keeps the baseline curve drawn above the strategy curves.
    ax.plot(
        pre_top_n_x,
        pre_top_n_mean_f1,
        color="#2c3e50",
        linewidth=2.5,
        linestyle="-",
        label="Pre-Rerank (distance-sorted)",
        zorder=10,
    )
    if len(analysis_summaries) > 1:
        ax.fill_between(pre_top_n_x, pre_top_n_min_f1, pre_top_n_max_f1, color="#2c3e50", alpha=0.1)
    if optimal_pre_top_n is not None:
        optimal_index = int(np.argmin(np.abs(np.array(pre_top_n_x) - optimal_pre_top_n)))
        ax.axvline(optimal_pre_top_n, color="#2c3e50", linestyle=":", alpha=0.5, linewidth=1)
        ax.annotate(
            f"N={int(round(optimal_pre_top_n))}",
            xy=(optimal_pre_top_n, pre_top_n_mean_f1[optimal_index]),
            xytext=(5, 6),
            textcoords="offset points",
            fontsize=8,
            color="#2c3e50",
        )

# Per-strategy top-N sweeps (strategies without top-N sweep data are skipped)
for strategy_index, strategy_name in enumerate(all_strategies):
    post_top_n_x, post_top_n_mean_f1, post_top_n_min_f1, post_top_n_max_f1, optimal_post_top_n = (
        get_averaged_sweep_curve(analysis_summaries, f"post_top_n:{strategy_name}")
    )
    if not post_top_n_x:
        continue

    color = strategy_palette[strategy_index]
    line_style = line_styles[strategy_index % len(line_styles)]

    ax.plot(
        post_top_n_x,
        post_top_n_mean_f1,
        color=color,
        linewidth=2.5,
        linestyle=line_style,
        label=strategy_name,
    )
    if len(analysis_summaries) > 1:
        ax.fill_between(post_top_n_x, post_top_n_min_f1, post_top_n_max_f1, color=color, alpha=0.1)
    if optimal_post_top_n is not None:
        optimal_index = int(np.argmin(np.abs(np.array(post_top_n_x) - optimal_post_top_n)))
        ax.axvline(optimal_post_top_n, color=color, linestyle=":", alpha=0.5, linewidth=1)
        # Raise each strategy's label a little higher so labels at the same N don't collide.
        ax.annotate(
            f"N={int(round(optimal_post_top_n))}",
            xy=(optimal_post_top_n, post_top_n_mean_f1[optimal_index]),
            xytext=(5, 6 + strategy_index * 14),
            textcoords="offset points",
            fontsize=8,
            color=color,
        )

ax.set_xlabel("Top-N (number of results returned)")
ax.set_ylabel("Macro F1")
num_runs = len(analysis_summaries)
ax.set_title(
    f"Top-N Sweep: Pre-Rerank vs All Strategies — {PHASE}\n"
    f"(averaged across {num_runs} run(s), dotted vertical = optimal N per stage)"
)
# NOTE(review): if no curve was plotted above, legend() has no labeled artists and warns.
ax.legend(fontsize=9, loc="best")
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Cell 6 — Per-Test-Case Heatmap
# Rows: test cases | Columns: pre-rerank + each strategy (at optimal threshold)
# Color: red (low F1) → green (high F1). Gray = no data for that test case.
# Each cell is that test case's F1 averaged over the runs that report it.

# Collect all test case IDs across all summaries
all_test_case_ids = set()
for summary in analysis_summaries:
    all_test_case_ids.update(summary.get("per_test_case", {}).keys())
all_test_case_ids = sorted(all_test_case_ids)

# Stage columns for heatmap: pre-rerank baseline + each strategy at optimal threshold
heatmap_stage_keys = ["pre_rerank_distance"]
for strategy_name in all_strategies:
    heatmap_stage_keys.append(f"post_rerank_threshold:{strategy_name}")

heatmap_column_labels = ["Pre-Rerank\n(Dist. Thresh.)"]
for strategy_name in all_strategies:
    heatmap_column_labels.append(f"{strategy_name}\n(Rerank Thresh.)")

# Compute averaged per-test-case F1 for each stage
per_stage_per_tc_f1 = {}  # stage_key -> {test_case_id: averaged_f1}
for stage_key in heatmap_stage_keys:
    per_stage_per_tc_f1[stage_key] = average_per_test_case_stage_f1_across_runs(
        analysis_summaries, stage_key
    )

# Build heatmap matrix; NaN marks "no data for this test case at this stage"
heatmap_values = np.full((len(all_test_case_ids), len(heatmap_stage_keys)), np.nan)
for column_index, stage_key in enumerate(heatmap_stage_keys):
    averaged_f1_map = per_stage_per_tc_f1[stage_key]
    for row_index, test_case_id in enumerate(all_test_case_ids):
        if test_case_id in averaged_f1_map:
            heatmap_values[row_index, column_index] = averaged_f1_map[test_case_id]

num_runs = len(analysis_summaries)
if not all_test_case_ids:
    print("No per-test-case data found — skipping heatmap.")
else:
    fig, ax = plt.subplots(
        figsize=(
            max(9, len(heatmap_stage_keys) * 3),
            max(6, len(all_test_case_ids) * 0.65),
        )
    )
    sns.heatmap(
        heatmap_values,
        annot=True,
        fmt=".2f",
        xticklabels=heatmap_column_labels,
        yticklabels=all_test_case_ids,
        cmap="RdYlGn",
        vmin=0,
        vmax=1,
        linewidths=0.5,
        ax=ax,
    )
    # seaborn masks NaN cells and leaves the axes background showing through (white by
    # default). Paint the background gray so "no data" is distinguishable, as promised above.
    ax.set_facecolor("#d5d8dc")
    ax.set_title(
        f"Per-Test-Case F1: Pre-Rerank vs Strategies — {PHASE}\n(averaged across {num_runs} run(s))"
    )
    ax.set_ylabel("Test Case")
    ax.set_xlabel("Stage")
    plt.tight_layout()
    plt.show()

# Print macro column averages (mean over test cases that have data; n/a if none do)
print("Column averages (macro F1):")
for column_index, column_label in enumerate(heatmap_column_labels):
    column_values = heatmap_values[:, column_index]
    label_oneliner = column_label.replace("\n", " ")
    if np.isnan(column_values).all():
        # Avoid np.nanmean's "Mean of empty slice" RuntimeWarning on an all-NaN column.
        column_avg = np.nan
        print(f"  {label_oneliner:<45} F1 = n/a")
    else:
        column_avg = np.nanmean(column_values)
        print(f"  {label_oneliner:<45} F1 = {column_avg:.4f}")

In [None]:
# Cell 7 — Improvement Analysis: Per-Test-Case Delta vs Baseline
# Horizontal bar chart showing the F1 delta (post-rerank minus pre-rerank baseline).
# Bar fill: green = improved, red = degraded, gray = equivalent (within ±0.01).
# Bar outline color identifies the strategy (legend shown when several are plotted).
# A test case missing either its baseline or its post-rerank F1 gets no bar. Substituting
# 0.0 for the missing side would fabricate a large fake improvement or degradation.
from matplotlib.patches import Patch

baseline_per_tc_f1 = per_stage_per_tc_f1.get("pre_rerank_distance", {})

bar_height_per_strategy = 0.75 / max(len(all_strategies), 1)
test_case_y_positions = np.arange(len(all_test_case_ids))
# One outline color per strategy (at least 3 samples so small counts get distinct hues).
strategy_palette = plt.cm.Set2(np.linspace(0, 0.8, max(len(all_strategies), 3)))

fig_height = max(6, len(all_test_case_ids) * 0.65 * max(len(all_strategies), 1))
fig, ax = plt.subplots(figsize=(13, fig_height))

for strategy_index, strategy_name in enumerate(all_strategies):
    post_rerank_stage_key = f"post_rerank_threshold:{strategy_name}"
    post_rerank_per_tc_f1 = per_stage_per_tc_f1.get(post_rerank_stage_key, {})

    delta_f1_values = []
    bar_colors = []
    for test_case_id in all_test_case_ids:
        baseline_f1 = baseline_per_tc_f1.get(test_case_id)
        post_rerank_f1 = post_rerank_per_tc_f1.get(test_case_id)
        if baseline_f1 is None or post_rerank_f1 is None:
            # No comparable pair: a NaN width draws no bar for this test case.
            delta_f1_values.append(np.nan)
            bar_colors.append("#95a5a6")
            continue
        delta = post_rerank_f1 - baseline_f1
        delta_f1_values.append(delta)
        if delta > 0.01:
            bar_colors.append("#2ecc71")
        elif delta < -0.01:
            bar_colors.append("#e74c3c")
        else:
            bar_colors.append("#95a5a6")

    # Spread this strategy's bars within each test case's row so the group is centered.
    y_offset = (strategy_index - len(all_strategies) / 2 + 0.5) * bar_height_per_strategy
    ax.barh(
        test_case_y_positions + y_offset,
        delta_f1_values,
        height=bar_height_per_strategy * 0.85,
        color=bar_colors,
        alpha=0.8,
        label=strategy_name,
        edgecolor=strategy_palette[strategy_index],
        linewidth=1.5,
    )

ax.axvline(0, color="black", linewidth=1, zorder=5)
ax.set_yticks(test_case_y_positions)
ax.set_yticklabels(all_test_case_ids, fontsize=9)
ax.set_xlabel("ΔF1  (post-rerank minus pre-rerank baseline)")
num_runs = len(analysis_summaries)
ax.set_title(
    f"F1 Improvement Over Pre-Rerank Baseline — {PHASE}\n"
    f"Green = improved, Red = degraded (n={num_runs} runs averaged)"
)
if len(all_strategies) > 1:
    # Fill colors encode improvement class, so the default legend (which copies the first
    # bar's fill) cannot tell strategies apart. Use proxy handles keyed on outline color.
    strategy_handles = [
        Patch(facecolor="white", edgecolor=strategy_palette[index], linewidth=1.5, label=name)
        for index, name in enumerate(all_strategies)
    ]
    ax.legend(handles=strategy_handles, fontsize=9)
ax.grid(True, alpha=0.3, axis="x")
plt.tight_layout()
plt.show()

# Print detailed delta table
strategy_header = "".join(f"  {name:<28}" for name in all_strategies)
print(f"{'Test Case':<32}{strategy_header}")
print("-" * (32 + 30 * len(all_strategies)))

for test_case_id in all_test_case_ids:
    baseline_f1 = baseline_per_tc_f1.get(test_case_id, 0.0)
    row_parts = [f"{test_case_id:<32}"]
    for strategy_name in all_strategies:
        post_rerank_stage_key = f"post_rerank_threshold:{strategy_name}"
        post_rerank_f1 = per_stage_per_tc_f1.get(post_rerank_stage_key, {}).get(test_case_id, 0.0)
        delta = post_rerank_f1 - baseline_f1
        row_parts.append(f"  {baseline_f1:.3f} -> {post_rerank_f1:.3f} ({delta:+.3f})")
    print("".join(row_parts))

In [None]:
# Cell 8 — Summary Dashboard
# Prints the run configuration, the pre-rerank baseline, each strategy's F1 at
# its optimal threshold and at its optimal top-N, the overall winner, the
# hardest test cases, and the largest per-test-case regressions. Stages or test
# cases without data are reported as "n/a" (or skipped) rather than being
# scored as F1 = 0.0.

print("=" * 65)
print("SUMMARY DASHBOARD")
print("=" * 65)

# Config info
config_fingerprint = analysis_summaries[0].get("config_fingerprint", {})
if config_fingerprint:
    fingerprint_hash = config_fingerprint.get("fingerprint_hash", "?")
    print(f"\nConfig fingerprint: {fingerprint_hash}")
    print(f"  Extraction prompt:  v{config_fingerprint.get('extraction_prompt_version', '?')}")
    print(f"  Query prompt:       v{config_fingerprint.get('query_prompt_version', '?')}")
    print(f"  Query model:        {config_fingerprint.get('query_model', '?')}")
    print(f"  Reranker model:     {config_fingerprint.get('reranker_model', '?')}")
    print(f"  Runs averaged:      {len(analysis_summaries)}")

# Pre-rerank baseline
baseline_metrics = stage_averaged_metrics.get("pre_rerank_distance", {})
baseline_f1 = baseline_metrics.get("f1", 0.0)
baseline_precision = baseline_metrics.get("precision", 0.0)
baseline_recall = baseline_metrics.get("recall", 0.0)
baseline_mrr = baseline_metrics.get("mrr", 0.0)

print("\nBaseline — Pre-Rerank (at optimal distance threshold):")
if not baseline_metrics:
    print("  WARNING: no pre-rerank baseline metrics found; values below default to 0.0.")
print(f"  F1:        {baseline_f1:.4f}")
print(f"  Precision: {baseline_precision:.4f}")
print(f"  Recall:    {baseline_recall:.4f}")
print(f"  MRR:       {baseline_mrr:.4f}")

# Per-strategy results. Each strategy is evaluated at its optimal score
# threshold and at its optimal top-N; the better of the two competes for "best"
# (ties keep the threshold mode and the earlier strategy).
rerank_modes = (
    ("threshold", "post_rerank_threshold"),
    ("top-N", "post_rerank_top_n"),
)
print("\nPost-Rerank Strategies:")
best_strategy_name = None
best_strategy_mode = None
best_strategy_f1 = -1.0

for strategy_name in all_strategies:
    print(f"\n  [{strategy_name}]")
    for mode_label, stage_prefix in rerank_modes:
        mode_metrics = stage_averaged_metrics.get(f"{stage_prefix}:{strategy_name}", {})
        row_label = f"At optimal {mode_label}:"
        if "f1" not in mode_metrics:
            # Missing stage: don't report a fake F1 of 0.0 or let it compete.
            print(f"    {row_label:<22}F1 = n/a (no data)")
            continue
        mode_f1 = mode_metrics["f1"]
        mode_delta = mode_f1 - baseline_f1
        print(f"    {row_label:<22}F1 = {mode_f1:.4f}  (delta {mode_delta:+.4f})")
        if mode_f1 > best_strategy_f1:
            best_strategy_f1 = mode_f1
            best_strategy_name = strategy_name
            best_strategy_mode = mode_label

# Winner
print(f"\n{'=' * 65}")
if best_strategy_name is None:
    print("No post-rerank strategy results available.")
elif best_strategy_f1 > baseline_f1:
    overall_improvement = best_strategy_f1 - baseline_f1
    print(f"Best strategy: {best_strategy_name} (at optimal {best_strategy_mode})")
    print(f"  Best F1:        {best_strategy_f1:.4f}")
    if baseline_f1 > 0:
        relative_improvement = overall_improvement / baseline_f1 * 100
        print(
            f"  Improvement:    {overall_improvement:+.4f}  ({relative_improvement:+.1f}% relative)"
        )
    else:
        print(f"  Improvement:    {overall_improvement:+.4f}  (relative n/a: baseline F1 is 0)")
else:
    print("No rerank strategy outperformed the pre-rerank baseline.")

# Per-test-case baseline F1, recomputed here so this cell does not depend on
# variables left behind by the improvement-analysis cell.
baseline_per_tc_f1 = per_stage_per_tc_f1.get("pre_rerank_distance", {})

# Hardest test cases (lowest pre-rerank F1), among those with baseline data.
print("\nHardest test cases (lowest pre-rerank baseline F1):")
hardest_test_cases = sorted(
    (
        (test_case_id, baseline_per_tc_f1[test_case_id])
        for test_case_id in all_test_case_ids
        if test_case_id in baseline_per_tc_f1
    ),
    key=lambda item: item[1],
)[:3]
if not hardest_test_cases:
    print("  (no per-test-case baseline data)")
for rank, (test_case_id, f1_value) in enumerate(hardest_test_cases, start=1):
    print(f"  {rank}. {test_case_id}: pre-rerank F1 = {f1_value:.3f}")

# Biggest regressions: only test cases with data on both sides can regress.
if all_strategies:
    print("\nTest cases with largest degradation after reranking:")
    regressions = []
    for test_case_id in all_test_case_ids:
        if test_case_id not in baseline_per_tc_f1:
            continue
        baseline_f1_value = baseline_per_tc_f1[test_case_id]
        for strategy_name in all_strategies:
            post_rerank_per_tc_f1 = per_stage_per_tc_f1.get(
                f"post_rerank_threshold:{strategy_name}", {}
            )
            if test_case_id not in post_rerank_per_tc_f1:
                continue
            delta = post_rerank_per_tc_f1[test_case_id] - baseline_f1_value
            if delta < -0.001:
                regressions.append((test_case_id, strategy_name, delta))
    worst_regressions = sorted(regressions, key=lambda item: item[2])[:3]
    if not worst_regressions:
        print("  (none)")
    for test_case_id, strategy_name, delta in worst_regressions:
        print(f"  {test_case_id} [{strategy_name}]: delta = {delta:+.3f}")

print(f"\nTotal test cases: {len(all_test_case_ids)}")
print(f"Runs analyzed: {len(analysis_summaries)} (phase: {PHASE})")