# Phase 1 Reranking Comparison

Applies cross-encoder reranking on top of **existing Phase 1 experiment results**.
Reuses the exact same queries, vector search candidates, and ground truth from the Phase 1 run.

This notebook does **not** generate new queries or run new searches — it loads the saved
Phase 1 results and only adds the reranking step to measure its isolated impact on retrieval quality.

In [None]:
# Cell 1 — Setup & Imports
import os
from pathlib import Path

from retrieval_metrics.compute import compute_set_metrics

from memory_retrieval.experiments.metrics_adapter import metric_point_to_dict
from memory_retrieval.infra.figures import create_figure_session, save_figure
from memory_retrieval.infra.io import load_json, save_json
from memory_retrieval.infra.runs import (
    PHASE1,
    PHASE2,
    create_run,
    get_latest_run,
    update_run_status,
)
from memory_retrieval.memories.schema import FIELD_DISTANCE, FIELD_RERANK_SCORE, FIELD_SITUATION
from memory_retrieval.search.reranker import Reranker

# Locate the project root: the nearest directory, starting at the current
# working directory and walking up through its ancestors, that contains
# pyproject.toml. The notebook then chdirs there so relative paths resolve
# against the project root regardless of where Jupyter was launched.
_cwd = Path.cwd()
PROJECT_ROOT = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if (candidate / "pyproject.toml").exists()),
    None,
)
if PROJECT_ROOT is None:
    raise RuntimeError("Could not find project root (pyproject.toml)")
os.chdir(PROJECT_ROOT)

def compute_metrics(retrieved_ids, ground_truth_ids):
    """Set-based precision / recall / F1 for a single retrieval.

    Args:
        retrieved_ids: IDs returned by the retriever.
        ground_truth_ids: IDs that should have been retrieved.

    Returns:
        dict with exactly the keys ``"precision"``, ``"recall"`` and ``"f1"``,
        taken from ``metric_point_to_dict(..., round_digits=4)``.
    """
    metric_dict = metric_point_to_dict(
        compute_set_metrics(retrieved_ids, ground_truth_ids),
        round_digits=4,
    )
    return {key: metric_dict[key] for key in ("precision", "recall", "f1")}


print(f"Project root: {PROJECT_ROOT}")
print("Imports OK.")

In [None]:
# Cell 2 — Configuration

# Analysis parameter for top-N comparison tables
ANALYSIS_TOP_N = 4
# Pre-rerank distance threshold; recorded in the Phase 2 run config and in each
# per-test-case result. Must match the value used in Phase 1.
DISTANCE_THRESHOLD = 1.1

# Phase 1 run selection (source of results, DB, and test cases)
# To see available runs: print(list_runs(PHASE1))
# To select specific run: PHASE1_RUN = get_run(PHASE1, "run_20260208_143022")
# NOTE: list_runs / get_run are not imported in Cell 1 — presumably they live in
# memory_retrieval.infra.runs next to get_latest_run; import them before use.
PHASE1_RUN = get_latest_run(PHASE1)

# Derived paths from Phase 1
PHASE1_RESULTS_DIR = PHASE1_RUN / "results"
phase1_result_files = sorted(PHASE1_RESULTS_DIR.glob("*.json"))

print("Configuration:")
print(f"  Phase 1 run: {PHASE1_RUN.name}")
print(f"  Phase 1 results dir: {PHASE1_RESULTS_DIR}")
print(f"  Phase 1 result files: {len(phase1_result_files)}")
print(f"  Analysis top-n: {ANALYSIS_TOP_N}")
print(f"  Distance threshold: {DISTANCE_THRESHOLD}")

# Fail fast: every later cell needs Phase 1 results, and Step 2 would otherwise
# still create (and record status for) an empty Phase 2 run.
if not phase1_result_files:
    raise FileNotFoundError(
        f"No Phase 1 result files found in {PHASE1_RESULTS_DIR}. Run phase1.ipynb first."
    )

## Step 1 — Load Phase 1 Results & Initialize Reranker

Load existing Phase 1 experiment results (queries + vector search candidates).
Initialize the cross-encoder reranker (`bge-reranker-v2-m3`).

In [None]:
# Cell 4 — Load Phase 1 Results & Initialize Reranker

# Keep exactly one Phase 1 result per test case. When several files share a
# test_case_id, the one whose experiment_id compares greatest with `>` wins
# (missing IDs compare as ""). Presumably experiment IDs embed a sortable
# timestamp so this picks the latest run — verify if the ID format changes.
phase1_by_test_case: dict[str, dict] = {}
for result_file in phase1_result_files:
    loaded = load_json(str(result_file))
    case_id = loaded.get("test_case_id", "?")
    incumbent = phase1_by_test_case.get(case_id)
    if incumbent is None or loaded.get("experiment_id", "") > incumbent.get("experiment_id", ""):
        phase1_by_test_case[case_id] = loaded

print(f"Loaded {len(phase1_by_test_case)} unique test case results from Phase 1\n")
for case_id in sorted(phase1_by_test_case):
    case_data = phase1_by_test_case[case_id]
    gt_count = case_data.get("ground_truth", {}).get("count", 0)
    phase1_f1 = case_data.get("metrics", {}).get("f1", 0)
    query_count = len(case_data.get("queries", []))
    print(f"  {case_id}: {query_count} queries, {gt_count} GT memories, F1={phase1_f1:.3f}")

# Build the cross-encoder and score one dummy pair so the model load (and any
# first-use download) happens here instead of inside the Step 2 loop.
reranker = Reranker()

print("\nLoading reranker model (first use triggers download)...")
_ = reranker.score_pairs("test", ["test document"])
print("Reranker ready.")

## Step 2 — Apply Reranking to Phase 1 Results

For each Phase 1 test case result:
1. Pool and deduplicate vector search candidates from all queries
2. Compute the **top-N by vector distance** baseline (apples-to-apples with reranking)
3. Rerank each query's candidates against that query with the cross-encoder, keep each memory's highest rerank score, and take the top-N
4. Compare both at the same N — the only difference is the ranking method

No new LLM calls or vector searches — everything is reused from Phase 1.

In [None]:
# Cell 6 — Apply Reranking to Phase 1 Results

from retrieval_metrics.compute import compute_threshold_metrics, compute_top_n_metrics
from retrieval_metrics.sweeps import find_optimal_entry

from memory_retrieval.experiments.metrics import pool_and_deduplicate_by_distance
from memory_retrieval.experiments.metrics_adapter import (
    restriction_evaluation_to_dict,
    threshold_sweep_from_experiments as sweep_threshold,
    top_n_sweep_from_experiments as sweep_top_n,
)


def compute_metrics_at_top_n(ranked_results, ground_truth_ids, top_n, id_field="id"):
    """Evaluate a ranked result list cut off at ``top_n``.

    Args:
        ranked_results: result dicts, presumably already in rank order (best first).
        ground_truth_ids: set of relevant memory IDs.
        top_n: rank cutoff passed to ``compute_top_n_metrics``.
        id_field: key under which each result stores its ID.

    Returns:
        The dict produced by ``restriction_evaluation_to_dict``.
    """
    top_n_evaluation = compute_top_n_metrics(
        ranked_results,
        ground_truth_ids,
        top_n,
        id_key=id_field,
    )
    return restriction_evaluation_to_dict(top_n_evaluation)


def compute_metrics_at_threshold(
    ranked_results,
    ground_truth_ids,
    threshold,
    score_field,
    higher_is_better,
    id_field="id",
):
    """Evaluate the subset of results accepted by a score threshold.

    Args:
        ranked_results: result dicts carrying ``score_field`` and ``id_field``.
        ground_truth_ids: set of relevant memory IDs.
        threshold: cutoff value on ``score_field``.
        score_field: key of the score to threshold (e.g. "distance", "rerank_score").
        higher_is_better: True when larger scores are better (rerank score),
            False when smaller is better (vector distance).
        id_field: key under which each result stores its ID.

    Returns:
        The dict from ``restriction_evaluation_to_dict`` including ``accepted_count``.
    """
    threshold_options = {
        "score_key": score_field,
        "higher_is_better": higher_is_better,
        "id_key": id_field,
    }
    threshold_evaluation = compute_threshold_metrics(
        ranked_results, ground_truth_ids, threshold, **threshold_options
    )
    return restriction_evaluation_to_dict(threshold_evaluation, include_accepted_count=True)


def find_optimal_threshold(sweep_results, metric="f1"):
    """Return the sweep entry ``find_optimal_entry`` selects as best for ``metric``.

    The returned entry is read for its ``"index"`` and ``"threshold"`` keys in
    the threshold-sweep step.
    """
    best_entry = find_optimal_entry(sweep_results, metric_key=metric)
    return best_entry


def _rerank_each_query(query_results, ground_truth_ids):
    """Rerank every query's own Phase 1 candidates against that query.

    Each query is scored independently, so a memory is judged against the query
    that retrieved it. Per the original design, the reranker is given only the
    "situation" text of each memory. Returns the concatenation of all per-query
    reranked lists; a memory retrieved by several queries appears several times.
    """
    reranked_all_queries = []
    for query_result in query_results:
        candidates = [
            {
                "id": result["id"],
                FIELD_SITUATION: result.get("situation", result.get(FIELD_SITUATION, "")),
                FIELD_DISTANCE: result.get("distance", 0),
                "is_ground_truth": result.get("is_ground_truth", result["id"] in ground_truth_ids),
            }
            for result in query_result.get("results", [])
        ]
        if not candidates:
            # Nothing to score for this query; don't hand the model an empty batch.
            continue
        reranked_all_queries.extend(
            reranker.rerank(query_result["query"], candidates, top_n=None)
        )
    return reranked_all_queries


def _dedupe_by_best_rerank_score(reranked_results):
    """Keep one entry per memory ID (its highest rerank score), sorted best first."""
    best_by_memory_id: dict[str, dict] = {}
    for result in reranked_results:
        current_best = best_by_memory_id.get(result["id"])
        if current_best is None or result[FIELD_RERANK_SCORE] > current_best[FIELD_RERANK_SCORE]:
            best_by_memory_id[result["id"]] = result
    return sorted(
        best_by_memory_id.values(), key=lambda entry: entry[FIELD_RERANK_SCORE], reverse=True
    )


def _top_n_summary(metrics, ground_truth_ids):
    """Compact, JSON-friendly top-N summary: P/R/F1, N, and #GT retrieved."""
    return {
        "precision": metrics["precision"],
        "recall": metrics["recall"],
        "f1": metrics["f1"],
        "n": ANALYSIS_TOP_N,
        # set() so this works whether retrieved_ids comes back as a set or a list
        "ground_truth_retrieved": len(set(metrics["retrieved_ids"]) & ground_truth_ids),
    }


# Create Phase 2 run for results
run_id, PHASE2_RUN = create_run(
    PHASE2,
    description=f"Reranking on phase1 results ({PHASE1_RUN.name})",
)
RESULTS_DIR = str(PHASE2_RUN / "results")
# Create the output directory once, before the per-test-case loop.
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

update_run_status(
    PHASE2_RUN,
    "config",
    {
        "phase1_run_id": PHASE1_RUN.name,
        "reranker_model": reranker.model_name,
        "analysis_top_n": ANALYSIS_TOP_N,
        "distance_threshold": DISTANCE_THRESHOLD,
        "mode": "reuse_phase1_results",
    },
)

print(f"Phase 2 run: {run_id}")
print(f"Results dir: {RESULTS_DIR}\n")

all_results: list[dict] = []
total_cases = len(phase1_by_test_case)

for i, (test_case_id, phase1_data) in enumerate(sorted(phase1_by_test_case.items()), 1):
    print(f"[{i}/{total_cases}] {test_case_id}")

    ground_truth_ids = set(phase1_data.get("ground_truth", {}).get("memory_ids", []))
    query_results = phase1_data.get("queries", [])

    # Pool and deduplicate across all queries (sorted by distance, best first)
    pooled = pool_and_deduplicate_by_distance(query_results)
    print(f"  Deduplicated memories count: {len(pooled)}")

    # Baseline: top-N by vector distance (same N as the rerank comparison)
    distance_top_n_metrics = compute_metrics_at_top_n(pooled, ground_truth_ids, ANALYSIS_TOP_N)

    # Rerank each query's candidates independently, then keep each memory's best score
    all_reranked = _dedupe_by_best_rerank_score(
        _rerank_each_query(query_results, ground_truth_ids)
    )
    rerank_top_n_metrics = compute_metrics_at_top_n(all_reranked, ground_truth_ids, ANALYSIS_TOP_N)

    f1_delta = rerank_top_n_metrics["f1"] - distance_top_n_metrics["f1"]
    print(
        f"  Top-{ANALYSIS_TOP_N} by distance F1={distance_top_n_metrics['f1']:.3f} | Top-{ANALYSIS_TOP_N} by rerank F1={rerank_top_n_metrics['f1']:.3f} ({f1_delta:+.3f})"
    )

    case_result = {
        "test_case_id": test_case_id,
        "source_file": phase1_data.get("source_file", "unknown"),
        "phase1_experiment_id": phase1_data.get("experiment_id", "unknown"),
        "model": phase1_data.get("model", "unknown"),
        "prompt_version": phase1_data.get("prompt_version", "unknown"),
        "reranker_model": reranker.model_name,
        "rerank_queries": [query_result["query"] for query_result in query_results],
        "distance_threshold": DISTANCE_THRESHOLD,
        "ground_truth": phase1_data.get("ground_truth", {}),
        "queries": query_results,
        "pooled_candidate_count": len(pooled),
        "distance_top_n_metrics": _top_n_summary(distance_top_n_metrics, ground_truth_ids),
        # Stored alongside the baseline so readers need not recompute it.
        "rerank_top_n_metrics": _top_n_summary(rerank_top_n_metrics, ground_truth_ids),
        "reranked_results": [
            {
                "id": entry["id"],
                "rerank_score": entry[FIELD_RERANK_SCORE],
                "distance": entry.get(FIELD_DISTANCE, 0),
                "situation": entry.get(FIELD_SITUATION, ""),
                "is_ground_truth": entry["id"] in ground_truth_ids,
            }
            for entry in all_reranked  # Store ALL reranked (for sweep analysis)
        ],
        "distance_top_n_results": [
            {
                "id": entry["id"],
                "distance": entry.get("distance", 0),
                "situation": entry.get("situation", entry.get(FIELD_SITUATION, "")),
                "is_ground_truth": entry.get("is_ground_truth", entry["id"] in ground_truth_ids),
            }
            for entry in pooled[:ANALYSIS_TOP_N]
        ],
    }

    save_json(case_result, Path(RESULTS_DIR) / f"rerank_{test_case_id}.json")
    all_results.append(case_result)

# Summary
successful = [result for result in all_results if "reranked_results" in result]
avg_distance_f1 = 0
avg_rerank_f1 = 0
if successful:
    avg_distance_f1 = sum(result["distance_top_n_metrics"]["f1"] for result in successful) / len(
        successful
    )
    # rerank_top_n_metrics was computed from the same ordering stored in
    # reranked_results, so it is reused instead of recomputed.
    avg_rerank_f1 = sum(result["rerank_top_n_metrics"]["f1"] for result in successful) / len(
        successful
    )

    print(f"\n{'=' * 55}")
    print(f"SUMMARY ({len(successful)} test cases, top-{ANALYSIS_TOP_N})")
    print(f"{'=' * 55}")
    print(f"  Avg distance top-{ANALYSIS_TOP_N} F1: {avg_distance_f1:.3f}")
    print(f"  Avg rerank top-{ANALYSIS_TOP_N} F1:   {avg_rerank_f1:.3f}")
    print(f"  Delta: {avg_rerank_f1 - avg_distance_f1:+.3f}")

# Update run status (averages stay 0 when there were no results)
update_run_status(
    PHASE2_RUN,
    "experiment",
    {
        "count": len(successful),
        "avg_f1_distance_top_n": round(avg_distance_f1, 4),
        "avg_f1_rerank_top_n": round(avg_rerank_f1, 4),
        "analysis_top_n": ANALYSIS_TOP_N,
    },
)

## Step 3 — Results Summary Table

Per-test-case comparison: **top-N by distance** vs **top-N by rerank score** (same N, same candidates).

In [None]:
# Cell 8 — Results Summary Table

# |F1 delta| at or below this counts as "no change". Absorbs rounding noise:
# rerank metrics here come from compute_metrics (rounded to 4 digits), while the
# stored distance metrics may not be rounded the same way.
F1_TIE_TOLERANCE = 0.001


def _delta_marker(delta, tolerance=F1_TIE_TOLERANCE):
    """'+' when reranking helped, '-' when it hurt, '=' when within tolerance."""
    if delta > tolerance:
        return "+"
    if delta < -tolerance:
        return "-"
    return "="


def _rerank_top_n_metrics(result):
    """Top-N P/R/F1 for one stored result, from the head of its reranked_results."""
    ground_truth_ids = set(result.get("ground_truth", {}).get("memory_ids", []))
    rerank_top_ids = {entry["id"] for entry in result["reranked_results"][:ANALYSIS_TOP_N]}
    return compute_metrics(rerank_top_ids, ground_truth_ids)


successful = [result for result in all_results if "reranked_results" in result]
# Computed once per test case and reused for the rows, averages, and counts.
rerank_metrics_by_case = [_rerank_top_n_metrics(result) for result in successful]

print(f"Top-{ANALYSIS_TOP_N} comparison: vector distance vs cross-encoder reranking\n")
print(
    f"{'Test Case':<25} {'Dist F1':>8} {'Rank F1':>8} {'Delta':>8} {'Dist P':>7} {'Rank P':>7} {'Dist R':>7} {'Rank R':>7}"
)
print("-" * 85)

for result, rerank_metrics in zip(successful, rerank_metrics_by_case):
    name = result.get("test_case_id", "?")[:25]
    distance_metrics = result["distance_top_n_metrics"]
    delta = rerank_metrics["f1"] - distance_metrics["f1"]
    marker = _delta_marker(delta)
    print(
        f"{name:<25} {distance_metrics['f1']:>8.3f} {rerank_metrics['f1']:>8.3f} {delta:>+7.3f}{marker} {distance_metrics['precision']:>7.3f} {rerank_metrics['precision']:>7.3f} {distance_metrics['recall']:>7.3f} {rerank_metrics['recall']:>7.3f}"
    )

if successful:
    num_cases = len(successful)
    avg_distance_f1 = sum(result["distance_top_n_metrics"]["f1"] for result in successful) / num_cases
    avg_distance_precision = (
        sum(result["distance_top_n_metrics"]["precision"] for result in successful) / num_cases
    )
    avg_distance_recall = (
        sum(result["distance_top_n_metrics"]["recall"] for result in successful) / num_cases
    )

    rerank_f1_scores = [metrics["f1"] for metrics in rerank_metrics_by_case]
    rerank_precisions = [metrics["precision"] for metrics in rerank_metrics_by_case]
    rerank_recalls = [metrics["recall"] for metrics in rerank_metrics_by_case]

    avg_rerank_f1 = sum(rerank_f1_scores) / num_cases
    avg_rerank_precision = sum(rerank_precisions) / num_cases
    avg_rerank_recall = sum(rerank_recalls) / num_cases

    # Same tolerance as the per-row markers, so the AVERAGE row is judged consistently.
    delta = avg_rerank_f1 - avg_distance_f1
    marker = _delta_marker(delta)
    print("-" * 85)
    print(
        f"{'AVERAGE':<25} {avg_distance_f1:>8.3f} {avg_rerank_f1:>8.3f} {delta:>+7.3f}{marker} {avg_distance_precision:>7.3f} {avg_rerank_precision:>7.3f} {avg_distance_recall:>7.3f} {avg_rerank_recall:>7.3f}"
    )

    # Single pass: bucket each test case by its marker.
    outcome_counts = {"+": 0, "=": 0, "-": 0}
    for result, rerank_f1 in zip(successful, rerank_f1_scores):
        outcome_counts[_delta_marker(rerank_f1 - result["distance_top_n_metrics"]["f1"])] += 1
    improved, same, worse = outcome_counts["+"], outcome_counts["="], outcome_counts["-"]
    print(
        f"\nReranking helped: {improved}/{len(successful)} | Same: {same}/{len(successful)} | Hurt: {worse}/{len(successful)}"
    )

## Step 4 — Top-N Sweep: Distance vs Reranking

Compare top-N by vector distance vs top-N by rerank score across different N values.
Shows whether reranking improves ranking quality at every cutoff point.

In [None]:
# Cell 10 — Top-N Sweep: Distance vs Reranking
import matplotlib.pyplot as plt
import numpy as np

# Reuse the figure session across cells; recreate it only when it is missing or
# belongs to a different Phase 2 run.
if "FIGURE_SESSION" not in globals() or FIGURE_SESSION.context.get("run_id") != PHASE2_RUN.name:
    FIGURE_SESSION = create_figure_session(
        root_dir=PHASE2_RUN / "figures",
        notebook_slug="phase1_reranking_comparison",
        context_key=PHASE2_RUN.name,
        context={
            "phase": PHASE2,
            "run_id": PHASE2_RUN.name,
            "phase1_run_id": PHASE1_RUN.name,
        },
    )
print(f"Figure export session: {FIGURE_SESSION.session_dir}")

# In-memory Step 2 output; later cells also read `all_data`.
all_data = all_results
num_test_cases = len(all_data)

# Sweep N = 1..SWEEP_MAX_N; the printed table shows only TABLE_N_VALUES.
SWEEP_MAX_N = 20
TABLE_N_VALUES = [1, 2, 3, 4, 5, 6, 8, 10, 15, 20]

# F1-delta bar colours: improvement, regression, and exact tie.
COLOR_IMPROVED, COLOR_WORSE, COLOR_TIE = "#2ecc71", "#e74c3c", "#95a5a6"

if not all_data:
    print("No results found. Run Step 2 first.")
else:
    max_n = SWEEP_MAX_N
    n_values = list(range(1, max_n + 1))

    # Distance ranking is rebuilt from the Phase 1 queries; the rerank ranking is
    # the stored Step 2 order (highest rerank score first).
    distance_experiments = [
        {
            "ground_truth_ids": set(data.get("ground_truth", {}).get("memory_ids", [])),
            "ranked_results": pool_and_deduplicate_by_distance(data.get("queries", [])),
        }
        for data in all_data
    ]
    rerank_experiments = [
        {
            "ground_truth_ids": set(data.get("ground_truth", {}).get("memory_ids", [])),
            "ranked_results": data.get("reranked_results", []),
        }
        for data in all_data
    ]

    distance_sweep = sweep_top_n(distance_experiments, n_values)
    rerank_sweep = sweep_top_n(rerank_experiments, n_values)

    distance_f1_scores = [entry["f1"] for entry in distance_sweep]
    distance_precisions = [entry["precision"] for entry in distance_sweep]
    distance_recalls = [entry["recall"] for entry in distance_sweep]
    rerank_f1_scores = [entry["f1"] for entry in rerank_sweep]
    rerank_precisions = [entry["precision"] for entry in rerank_sweep]
    rerank_recalls = [entry["recall"] for entry in rerank_sweep]

    # Print table
    print(f"Top-N Sweep: Distance vs Reranking (averaged over {num_test_cases} test cases)\n")
    print(
        f"{'N':>4} {'Dist P':>8} {'Rank P':>8} {'Dist R':>8} {'Rank R':>8} {'Dist F1':>8} {'Rank F1':>8} {'F1 Delta':>9}"
    )
    print("-" * 70)
    for top_n in TABLE_N_VALUES:
        if top_n > max_n:
            continue
        index = top_n - 1
        delta = rerank_f1_scores[index] - distance_f1_scores[index]
        print(
            f"{top_n:>4} {distance_precisions[index]:>8.3f} {rerank_precisions[index]:>8.3f} {distance_recalls[index]:>8.3f} {rerank_recalls[index]:>8.3f} {distance_f1_scores[index]:>8.3f} {rerank_f1_scores[index]:>8.3f} {delta:>+9.3f}"
        )

    # Plot: F1 comparison
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    ax.plot(
        n_values,
        distance_f1_scores,
        label="F1 (distance)",
        color="#3498db",
        linewidth=2,
        marker="o",
        markersize=4,
    )
    ax.plot(
        n_values,
        rerank_f1_scores,
        label="F1 (reranked)",
        color="#e74c3c",
        linewidth=2,
        marker="^",
        markersize=4,
    )
    ax.axvline(
        x=ANALYSIS_TOP_N,
        color="gray",
        linestyle="--",
        alpha=0.5,
        label=f"Default N={ANALYSIS_TOP_N}",
    )
    ax.set_xlabel("Top-N")
    ax.set_ylabel("F1 Score")
    ax.set_title(f"F1: Distance vs Reranking (avg over {num_test_cases} test cases)")
    ax.set_xticks(n_values)
    ax.set_ylim(0, 1.05)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot: F1 delta
    ax = axes[1]
    f1_deltas = [
        rerank_f1 - distance_f1 for rerank_f1, distance_f1 in zip(rerank_f1_scores, distance_f1_scores)
    ]
    # Ties get a neutral colour instead of being drawn as regressions.
    colors = [
        COLOR_IMPROVED if delta > 0 else COLOR_WORSE if delta < 0 else COLOR_TIE
        for delta in f1_deltas
    ]
    ax.bar(n_values, f1_deltas, color=colors, alpha=0.7)
    ax.axhline(y=0, color="black", linewidth=0.5)
    ax.set_xlabel("Top-N")
    ax.set_ylabel("F1 Delta (rerank - distance)")
    ax.set_title("Reranking F1 Improvement by N")
    ax.set_xticks(n_values)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    saved_paths = save_figure(
        fig,
        FIGURE_SESSION,
        "rerank_topn_sweep",
        title="Top-N Sweep: Distance vs Reranking",
    )
    plt.show()
    print(f"Saved: {saved_paths['png']}")

## Step 5 — Rerank Score Threshold Analysis

Mirrors the Phase 1 distance threshold analysis, but uses **rerank scores** as the filtering signal.
Instead of a fixed top-N cutoff, we sweep rerank score thresholds to find the optimal cutoff.

All metrics are **macro-averaged**: computed per test case, then averaged.
Within each test case, each query's candidates are reranked against that query, then deduplicated by memory ID (keeping the highest rerank score).

In [None]:
# Cell 12 — Rerank Score Distribution & Threshold Sweep
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

if "FIGURE_SESSION" not in globals() or FIGURE_SESSION.context.get("run_id") != PHASE2_RUN.name:
    FIGURE_SESSION = create_figure_session(
        root_dir=PHASE2_RUN / "figures",
        notebook_slug="phase1_reranking_comparison",
        context_key=PHASE2_RUN.name,
        context={
            "phase": PHASE2,
            "run_id": PHASE2_RUN.name,
            "phase1_run_id": PHASE1_RUN.name,
        },
    )

# In-memory Step 2 output (the same object Cell 10 binds to `all_data`); bound
# here too so this cell does not depend on Cell 10 having been run.
all_data = all_results

# Collect per-observation rerank scores (one entry per test_case x memory pair)
ground_truth_scores_all = []  # (score, test_case_id, memory_id)
non_ground_truth_scores_all = []  # (score, test_case_id, memory_id)
# (test_case_id, memory_id) -> vector distance, for the per-GT listing at the end
distance_by_observation: dict[tuple[str, str], float] = {}

# Build per-experiment deduped rerank data (mirrors phase1 threshold analysis structure)
experiments_reranked = []
for data in all_data:
    test_case_id = data["test_case_id"]
    ground_truth_ids = set(data.get("ground_truth", {}).get("memory_ids", []))
    # reranked_results contains ALL candidates sorted by rerank score
    reranked = data.get("reranked_results", [])
    scores_by_id = {result["id"]: result for result in reranked}

    experiments_reranked.append(
        {
            "test_case_id": test_case_id,
            "ground_truth_ids": ground_truth_ids,
            "scores_by_id": scores_by_id,
            "reranked": reranked,
        }
    )

    for result in reranked:
        entry = (result["rerank_score"], test_case_id, result["id"])
        if result.get("is_ground_truth"):
            ground_truth_scores_all.append(entry)
            distance_by_observation.setdefault((test_case_id, result["id"]), result["distance"])
        else:
            non_ground_truth_scores_all.append(entry)

ground_truth_scores = np.array([score for score, _, _ in ground_truth_scores_all])
non_ground_truth_scores = np.array([score for score, _, _ in non_ground_truth_scores_all])

# min/max/median below are undefined on empty arrays; fail with a readable message.
if ground_truth_scores.size == 0 or non_ground_truth_scores.size == 0:
    raise ValueError(
        "Rerank score analysis needs both GT and non-GT observations "
        f"(got {ground_truth_scores.size} GT, {non_ground_truth_scores.size} non-GT). "
        "Run Step 2 first."
    )

unique_ground_truth_memories = len({memory_id for _, _, memory_id in ground_truth_scores_all})

print(f"Experiments: {len(experiments_reranked)} test cases")
print(
    f"GT observations:     {len(ground_truth_scores)} ({unique_ground_truth_memories} unique memory IDs)"
)
print(f"Non-GT observations: {len(non_ground_truth_scores)}")
print()
print(
    f"GT rerank score range:     [{ground_truth_scores.min():.4f}, {ground_truth_scores.max():.4f}]"
)
print(
    f"Non-GT rerank score range: [{non_ground_truth_scores.min():.4f}, {non_ground_truth_scores.max():.4f}]"
)
print(f"GT mean: {ground_truth_scores.mean():.4f}, median: {np.median(ground_truth_scores):.4f}")
print(
    f"Non-GT mean: {non_ground_truth_scores.mean():.4f}, median: {np.median(non_ground_truth_scores):.4f}"
)
print(f"Mean separation: {ground_truth_scores.mean() - non_ground_truth_scores.mean():.4f}")

# --- Figure 1: Score distribution histogram ---
fig1, ax1 = plt.subplots(figsize=(10, 5))
all_scores_combined = np.concatenate([ground_truth_scores, non_ground_truth_scores])
bins = np.linspace(all_scores_combined.min(), all_scores_combined.max(), 40)
ax1.hist(
    ground_truth_scores,
    bins=bins,
    alpha=0.6,
    density=True,
    label=f"GT (n={len(ground_truth_scores)})",
    color="#2ecc71",
    edgecolor="white",
    linewidth=0.5,
)
ax1.hist(
    non_ground_truth_scores,
    bins=bins,
    alpha=0.6,
    density=True,
    label=f"Non-GT (n={len(non_ground_truth_scores)})",
    color="#e74c3c",
    edgecolor="white",
    linewidth=0.5,
)
ax1.axvline(
    np.median(ground_truth_scores),
    color="#27ae60",
    linestyle="--",
    linewidth=1.5,
    label=f"GT median: {np.median(ground_truth_scores):.4f}",
)
ax1.axvline(
    np.median(non_ground_truth_scores),
    color="#c0392b",
    linestyle="--",
    linewidth=1.5,
    label=f"Non-GT median: {np.median(non_ground_truth_scores):.4f}",
)
ax1.set_xlabel("Rerank Score")
ax1.set_ylabel("Density")
ax1.set_title("Rerank Score Distribution (normalized)")
ax1.legend(fontsize=8)
fig1.tight_layout()
saved_paths = save_figure(
    fig1,
    FIGURE_SESSION,
    "rerank_score_distribution",
    title="Rerank Score Distribution (normalized)",
)
plt.show()
print(f"Saved: {saved_paths['png']}")

# --- Rerank score vs distance scatter ---
fig2, ax2 = plt.subplots(figsize=(10, 5))
scatter_points = [
    (result["distance"], result["rerank_score"], bool(result.get("is_ground_truth")))
    for data in all_data
    for result in data.get("reranked_results", [])
]
# One scatter call per class instead of one per point; non-GT is drawn first so
# GT points stay visible on top.
for is_ground_truth, color, alpha, size in (
    (False, "#e74c3c", 0.3, 25),
    (True, "#2ecc71", 0.7, 60),
):
    class_points = [(x, y) for x, y, flag in scatter_points if flag == is_ground_truth]
    if not class_points:
        continue
    xs, ys = zip(*class_points)
    ax2.scatter(
        xs,
        ys,
        color=color,
        alpha=alpha,
        s=size,
        edgecolors="white",
        linewidth=0.3,
    )

legend_elements = [
    Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        markerfacecolor="#2ecc71",
        markersize=8,
        label="Ground Truth",
    ),
    Line2D(
        [0], [0], marker="o", color="w", markerfacecolor="#e74c3c", markersize=8, label="Non-GT"
    ),
]
ax2.legend(handles=legend_elements)
ax2.set_xlabel("Vector Distance")
ax2.set_ylabel("Rerank Score")
ax2.set_title("Rerank Score vs Vector Distance")
ax2.grid(True, alpha=0.3)
fig2.tight_layout()
saved_paths = save_figure(
    fig2,
    FIGURE_SESSION,
    "rerank_vs_distance_scatter",
    title="Rerank Score vs Vector Distance",
)
plt.show()
print(f"Saved: {saved_paths['png']}")
print("\nPer GT memory details (sorted by rerank score, descending):")
for score, test_case_id, memory_id in sorted(ground_truth_scores_all, key=lambda x: -x[0]):
    # O(1) lookup built during collection (replaces a nested scan over all results)
    distance = distance_by_observation.get((test_case_id, memory_id), 0)
    print(f"  score={score:.4f}  dist={distance:.4f}  {memory_id}  ({test_case_id})")

## Step 5b — Rerank Score Threshold Sweep

Sweep rerank score thresholds: accept all candidates with score >= threshold.
A higher threshold means stricter filtering (typically higher precision and lower recall).
Compare with Phase 1 distance threshold sweep.

In [None]:
# Cell 14 — Rerank Score Threshold Sweep
import matplotlib.pyplot as plt
import numpy as np

if "FIGURE_SESSION" not in globals() or FIGURE_SESSION.context.get("run_id") != PHASE2_RUN.name:
    FIGURE_SESSION = create_figure_session(
        root_dir=PHASE2_RUN / "figures",
        notebook_slug="phase1_reranking_comparison",
        context_key=PHASE2_RUN.name,
        context={
            "phase": PHASE2,
            "run_id": PHASE2_RUN.name,
            "phase1_run_id": PHASE1_RUN.name,
        },
    )

# Build experiment list for threshold sweep directly from the in-memory Step 2
# output, so this cell does not depend on Cells 10 / 12 having been run.
rerank_experiments = []
for data in all_results:
    ground_truth_ids = set(data.get("ground_truth", {}).get("memory_ids", []))
    reranked = data.get("reranked_results", [])
    rerank_experiments.append(
        {
            "test_case_id": data["test_case_id"],
            "ground_truth_ids": ground_truth_ids,
            "ranked_results": reranked,
            "reranked": reranked,
        }
    )

# Every observed rerank score (all test cases, GT and non-GT); sets the sweep range.
all_scores_flat = np.array(
    [result["rerank_score"] for experiment in rerank_experiments for result in experiment["reranked"]]
)
if all_scores_flat.size == 0:
    raise ValueError("No rerank scores found. Run Step 2 first.")

# Sweep rerank score thresholds (higher = stricter, opposite direction from distance)
sweep_thresholds_list = list(np.arange(0.0, all_scores_flat.max() + 0.01, 0.005))

rerank_sweep = sweep_threshold(
    rerank_experiments,
    sweep_thresholds_list,
    score_field="rerank_score",
    higher_is_better=True,
)

sweep_precisions = np.array([entry["precision"] for entry in rerank_sweep])
sweep_recalls = np.array([entry["recall"] for entry in rerank_sweep])
sweep_f1_scores = np.array([entry["f1"] for entry in rerank_sweep])
sweep_mrrs = np.array([entry["mrr"] for entry in rerank_sweep])

optimal = find_optimal_threshold(rerank_sweep, metric="f1")
best_f1_index = optimal["index"]
best_threshold = optimal["threshold"]

# --- Figure: P/R/F1/MRR vs rerank score threshold ---
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(sweep_thresholds_list, sweep_precisions, label="Precision", color="#3498db", linewidth=2)
ax.plot(sweep_thresholds_list, sweep_recalls, label="Recall", color="#2ecc71", linewidth=2)
ax.plot(sweep_thresholds_list, sweep_f1_scores, label="F1", color="#9b59b6", linewidth=2)
ax.plot(sweep_thresholds_list, sweep_mrrs, label="MRR", color="#e67e22", linewidth=2)
ax.axvline(
    best_threshold,
    color="#e74c3c",
    linestyle="--",
    linewidth=1.5,
    label=f"Best F1 @ {best_threshold:.3f}",
)
ax.set_xlabel("Rerank Score Threshold (accept >= threshold)")
ax.set_ylabel("Score")
ax.set_title("P/R/F1/MRR vs Rerank Score Threshold (macro-averaged)")
ax.legend(fontsize=8)
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)
fig.tight_layout()
saved_paths = save_figure(
    fig,
    FIGURE_SESSION,
    "rerank_threshold_sweep",
    title="P/R/F1/MRR vs Rerank Score Threshold",
)
plt.show()
print(f"Saved: {saved_paths['png']}")
print(f"\nOptimal F1 threshold: {best_threshold:.4f}")
print(f"  F1:        {sweep_f1_scores[best_f1_index]:.3f}")
print(f"  Precision: {sweep_precisions[best_f1_index]:.3f}")
print(f"  Recall:    {sweep_recalls[best_f1_index]:.3f}")
print(f"  MRR:       {sweep_mrrs[best_f1_index]:.3f}")

# Threshold table
print("\nThreshold table (macro-averaged):")
print(
    f"{'Threshold':>10} {'Precision':>10} {'Recall':>8} {'F1':>8} {'MRR':>8} {'Avg Accepted':>13} {'Avg GT Kept':>12}"
)
print("-" * 75)

# Pick representative thresholds around the interesting range, plus the optimum.
best_table_threshold = round(best_threshold, 4)
table_thresholds = sorted(
    set(
        [
            0.001,
            0.005,
            0.01,
            0.02,
            0.03,
            0.05,
            0.08,
            0.10,
            0.15,
            0.20,
            0.30,
            0.50,
            best_table_threshold,
        ]
    )
)

# Evaluate each row at its exact threshold (not the nearest sweep-grid point) so
# P/R/F1/MRR and the accepted / GT-kept counts all describe the same cutoff.
table_sweep = sweep_threshold(
    rerank_experiments,
    table_thresholds,
    score_field="rerank_score",
    higher_is_better=True,
)

for threshold, sweep_entry in zip(table_thresholds, table_sweep):
    accepted_per_experiment = [
        [result for result in experiment_result["reranked"] if result["rerank_score"] >= threshold]
        for experiment_result in rerank_experiments
    ]
    avg_accepted = np.mean([len(accepted) for accepted in accepted_per_experiment])
    avg_ground_truth_kept = np.mean(
        [
            len({result["id"] for result in accepted} & experiment_result["ground_truth_ids"])
            for accepted, experiment_result in zip(accepted_per_experiment, rerank_experiments)
        ]
    )
    # Mark only the row for the optimum itself (a tolerance could flag several rows).
    marker = " <--" if threshold == best_table_threshold else ""
    print(
        f"{threshold:>10.4f} {sweep_entry['precision']:>10.3f} {sweep_entry['recall']:>8.3f} {sweep_entry['f1']:>8.3f} {sweep_entry['mrr']:>8.3f} {avg_accepted:>13.1f} {avg_ground_truth_kept:>12.1f}{marker}"
    )

## Step 5c — Per-Experiment Impact at Optimal Rerank Threshold

Show how the optimal rerank score threshold performs on each test case.
Compare with Phase 1 distance threshold results.

In [None]:
# Cell 16 — Per-Experiment Impact at Optimal Rerank Threshold
import numpy as np

# Phase 1 optimal threshold from threshold analysis notebook (vector distance;
# lower is better). NOTE(review): hard-coded rather than derived here — update it
# whenever the Phase 1 threshold analysis is re-run.
PHASE1_OPTIMAL_THRESHOLD = 0.76  # distance threshold

# Banner: which cutoff each method uses, followed by a blank line.
print(
    "Per-experiment comparison at optimal thresholds:\n"
    f"  Phase 1: distance <= {PHASE1_OPTIMAL_THRESHOLD}\n"
    f"  Rerank:  score >= {best_threshold:.4f}\n"
)

table_header = f"{'Test Case':<20} {'P1 F1':>7} {'Rk F1':>7} {'Delta':>7} {'P1 P':>6} {'Rk P':>6} {'P1 R':>6} {'Rk R':>6} {'P1 Acc':>7} {'Rk Acc':>7} {'GT':>4}"
print(table_header)
print("-" * 100)

# Per-test-case F1 accumulators filled by the loop below.
phase1_f1_scores = []
rerank_f1_list = []
for experiment_result in rerank_experiments:
    test_case_id = experiment_result["test_case_id"]
    ground_truth_ids = experiment_result["ground_truth_ids"]
    ground_truth_count = len(ground_truth_ids)

    # Phase 1: distance threshold (pool + dedup from phase1 data)
    phase1_data = phase1_by_test_case.get(test_case_id, {})
    phase1_pooled = pool_and_deduplicate_by_distance(phase1_data.get("queries", []))
    phase1_metrics = compute_metrics_at_threshold(
        phase1_pooled,
        ground_truth_ids,
        PHASE1_OPTIMAL_THRESHOLD,
        score_field="distance",
        higher_is_better=False,
    )

    # Rerank: score threshold
    rerank_metrics = compute_metrics_at_threshold(
        experiment_result["reranked"],
        ground_truth_ids,
        best_threshold,
        score_field="rerank_score",
        higher_is_better=True,
    )

    delta = rerank_metrics["f1"] - phase1_metrics["f1"]
    marker = "+" if delta > 0.001 else "-" if delta < -0.001 else "="

    phase1_f1_scores.append(phase1_metrics["f1"])
    rerank_f1_list.append(rerank_metrics["f1"])

    print(
        f"{test_case_id:<20} {phase1_metrics['f1']:>7.3f} {rerank_metrics['f1']:>7.3f} {delta:>+6.3f}{marker} {phase1_metrics['precision']:>6.1%} {rerank_metrics['precision']:>6.1%} {phase1_metrics['recall']:>6.1%} {rerank_metrics['recall']:>6.1%} {phase1_metrics['accepted_count']:>7} {rerank_metrics['accepted_count']:>7} {ground_truth_count:>4}"
    )

    # Show missed GT for rerank
    rerank_missed = ground_truth_ids - rerank_metrics["retrieved_ids"]
    if rerank_missed:
        for memory_id in sorted(rerank_missed):
            # Find rerank score for this missed memory
            score = next(
                (
                    result["rerank_score"]
                    for result in experiment_result["reranked"]
                    if result["id"] == memory_id
                ),
                None,
            )
            score_str = f"score={score:.4f}" if score is not None else "NOT IN POOL"
            print(f"  {'':20} Missed: {memory_id} ({score_str})")

print("-" * 100)
avg_phase1 = np.mean(phase1_f1_scores)
avg_rerank = np.mean(rerank_f1_list)
delta = avg_rerank - avg_phase1
print(
    f"{'AVERAGE':<20} {avg_phase1:>7.3f} {avg_rerank:>7.3f} {delta:>+6.3f}{'+' if delta > 0 else '-' if delta < 0 else '='}"
)

improved = sum(
    1
    for phase1_f1, rerank_f1 in zip(phase1_f1_scores, rerank_f1_list)
    if rerank_f1 > phase1_f1 + 0.001
)
same = sum(
    1
    for phase1_f1, rerank_f1 in zip(phase1_f1_scores, rerank_f1_list)
    if abs(rerank_f1 - phase1_f1) <= 0.001
)
worse = sum(
    1
    for phase1_f1, rerank_f1 in zip(phase1_f1_scores, rerank_f1_list)
    if rerank_f1 < phase1_f1 - 0.001
)
print(
    f"\nReranking helped: {improved}/{len(phase1_f1_scores)} | Same: {same}/{len(phase1_f1_scores)} | Hurt: {worse}/{len(phase1_f1_scores)}"
)

In [None]:
# Cell 17 — Final Summary
import numpy as np

print("=" * 70)
print("RERANKING COMPARISON SUMMARY")
print("=" * 70)

print(f"\nPhase 1 run: {PHASE1_RUN.name}")
# NOTE(review): Cell 16 iterates `rerank_experiments`; confirm `experiments_reranked`
# is the intended (and still defined) collection here.
print(f"Test cases: {len(experiments_reranked)}")
print(f"Reranker: {reranker.model_name}")

# Top-N comparison: distance-ranked top-N (precomputed per result) vs the first N
# reranked results, both scored against the ground-truth memory ids.
print(f"\n--- Top-{ANALYSIS_TOP_N} Comparison (distance vs rerank) ---")
# np.mean yields nan (with a warning) on an empty list instead of raising ZeroDivisionError.
avg_distance_f1 = np.mean([result["distance_top_n_metrics"]["f1"] for result in all_results])
rerank_f1_scores = []
for result in all_results:
    ground_truth_ids = set(result.get("ground_truth", {}).get("memory_ids", []))
    rerank_top_ids = {entry["id"] for entry in result["reranked_results"][:ANALYSIS_TOP_N]}
    rerank_metrics = compute_metrics(rerank_top_ids, ground_truth_ids)
    rerank_f1_scores.append(rerank_metrics["f1"])
avg_rerank_f1 = np.mean(rerank_f1_scores)

print(f"  Distance top-{ANALYSIS_TOP_N} avg F1: {avg_distance_f1:.3f}")
print(f"  Rerank top-{ANALYSIS_TOP_N} avg F1:   {avg_rerank_f1:.3f}")
print(f"  Delta: {avg_rerank_f1 - avg_distance_f1:+.3f}")

# Threshold comparison (avg_phase1 / avg_rerank come from Cell 16)
print("\n--- Threshold Comparison (Phase 1 distance vs rerank score) ---")
print(f"  Phase 1 distance threshold {PHASE1_OPTIMAL_THRESHOLD}: avg F1 = {avg_phase1:.3f}")
print(f"  Rerank score threshold {best_threshold:.4f}: avg F1 = {avg_rerank:.3f}")
threshold_delta = avg_rerank - avg_phase1
# A relative change is undefined when the Phase 1 baseline F1 is zero.
if avg_phase1 > 0:
    relative_str = f"{threshold_delta / avg_phase1 * 100:+.1f}%"
else:
    relative_str = "n/a, Phase 1 avg F1 is 0"
print(f"  Delta: {threshold_delta:+.3f} ({relative_str})")

# Optimal rerank threshold details (from the sweep)
print("\n--- Optimal Rerank Score Threshold ---")
print(f"  Threshold: {best_threshold:.4f}")
print(f"  F1:        {sweep_f1_scores[best_f1_index]:.3f}")
print(f"  Precision: {sweep_precisions[best_f1_index]:.3f}")
print(f"  Recall:    {sweep_recalls[best_f1_index]:.3f}")
print(f"  MRR:       {sweep_mrrs[best_f1_index]:.3f}")

print("=" * 70)