# Phase 2 — Full Pipeline with Reranking

Independent pipeline that extracts its own memories, builds its own database and test cases,
then runs retrieval experiments with cross-encoder reranking.

Unlike `phase1_reranking_comparison.ipynb` (which reuses Phase 1 data), this notebook
owns its entire pipeline end-to-end, allowing independent prompt/extraction iteration.

In [None]:
# Cell 1 — Setup & Imports
import os
from pathlib import Path

from memory_retrieval.experiments.query_generation import (
    QueryGenerationConfig,
    generate_all_queries,
)
from memory_retrieval.experiments.runner import ExperimentConfig, run_all_experiments
from memory_retrieval.experiments.test_cases import build_test_cases
from memory_retrieval.infra.figures import create_figure_session, save_figure
from memory_retrieval.infra.io import load_json
from memory_retrieval.infra.runs import (
    PHASE2,
    create_run,
    update_run_status,
    get_latest_run,
)
from memory_retrieval.memories.extractor import ExtractionConfig, SituationFormat, extract_memories
from memory_retrieval.search.reranker import Reranker
from memory_retrieval.search.vector import VectorBackend

# Locate the project root: the closest directory, starting at the current
# working directory and moving upward, that contains pyproject.toml.
_start_dir = Path.cwd()
for _candidate in (_start_dir, *_start_dir.parents):
    if (_candidate / "pyproject.toml").exists():
        PROJECT_ROOT = _candidate
        break
else:
    # Reached the filesystem root without finding the marker file.
    raise RuntimeError("Could not find project root (pyproject.toml)")
# Later cells use project-relative paths such as "data/review_data".
os.chdir(PROJECT_ROOT)

# Report whether the OpenRouter key is available (only a warning, not an error).
_api_key_message = (
    "OPENROUTER_API_KEY is set."
    if os.environ.get("OPENROUTER_API_KEY")
    else "WARNING: OPENROUTER_API_KEY is not set. Memory building and query generation will fail."
)
print(_api_key_message)

print(f"Project root: {PROJECT_ROOT}")
print("Imports OK.")

In [None]:
# Cell 2 — Configuration

# Prompt set and LLMs
PROMPT_VERSION = "3.0.0"
MODEL_MEMORIES = "anthropic/claude-haiku-4.5"  # used for memory extraction
MODEL_EXPERIMENT = "anthropic/claude-sonnet-4.5"  # used for query generation

RAW_DATA_DIR = "data/review_data"

# Search & reranking configuration
SEARCH_LIMIT = 20  # vector search candidates per query
DISTANCE_THRESHOLD = 1.1  # for pre-rerank metrics comparison

# Rerank text strategies: each maps a candidate dict to the text handed to the
# reranker, so reranking on situation-only vs situation+lesson can be compared.
RERANK_TEXT_STRATEGIES = dict(
    situation_only=lambda c: c["situation"],
    situation_and_lesson=lambda c: f"situation: {c['situation']}; lesson: {c.get('lesson', '')}",
)

# Run selection
#   New run:       RUN_DIR = None (the run is created in Step 1)
#   Latest run:    RUN_DIR = get_latest_run(PHASE2)
#   List runs:     print(list_runs(PHASE2))
#   Specific run:  RUN_DIR = get_run(PHASE2, "run_20260209_120000")
# NOTE(review): list_runs / get_run are not imported in Cell 1 — import them
# (presumably from memory_retrieval.infra.runs; verify) before using them.
RUN_DIR = None

# Paths derived from the run directory; skipped until a run exists.
if RUN_DIR is not None:
    _memories_path = RUN_DIR / "memories"
    MEMORIES_DIR = str(_memories_path)
    DB_PATH = str(_memories_path / "memories.db")
    TEST_CASES_DIR, QUERIES_DIR, RESULTS_DIR = (
        str(RUN_DIR / _subdir) for _subdir in ("test_cases", "queries", "results")
    )

# Initialize backends
vector_backend = VectorBackend()

_run_label = RUN_DIR.name if RUN_DIR else "None (will create new)"
print(
    "\n".join(
        [
            "Configuration:",
            f"  Using run: {_run_label}",
            f"  Prompt version: {PROMPT_VERSION}",
            f"  Model (memories): {MODEL_MEMORIES}",
            f"  Model (experiment): {MODEL_EXPERIMENT}",
            f"  Search limit: {SEARCH_LIMIT}",
            f"  Distance threshold: {DISTANCE_THRESHOLD}",
            f"  Rerank strategies: {list(RERANK_TEXT_STRATEGIES.keys())}",
            f"  Raw data dir: {RAW_DATA_DIR}",
        ]
    )
)

## Step 1 — Build Memories

Extracts structured memories from raw code review data via LLM.
Each memory contains a **situation description** (25-60 words) and an **actionable lesson** (max 160 chars).

Uses Phase 2 prompts from `data/prompts/phase2`.

Requires `OPENROUTER_API_KEY`.

In [None]:
# Cell 4 — Build Memories: Single File

# Create the run (and the paths derived from it) if Cell 2 did not select one.
if RUN_DIR is None:
    run_id, RUN_DIR = create_run(PHASE2)
    _memories_path = RUN_DIR / "memories"
    MEMORIES_DIR = str(_memories_path)
    DB_PATH = str(_memories_path / "memories.db")
    TEST_CASES_DIR, QUERIES_DIR, RESULTS_DIR = (
        str(RUN_DIR / _subdir) for _subdir in ("test_cases", "queries", "results")
    )
    print(f"Created new run: {run_id}")

raw_data_path = Path(RAW_DATA_DIR)
raw_files = sorted(raw_data_path.glob("*.json"))

print(f"Found {len(raw_files)} raw data files:")
for file_index, listed_file in enumerate(raw_files):
    print(f"  [{file_index}] {listed_file.name}")

if not raw_files:
    print("No raw data files found.")
else:
    # Only the first file (in sorted order) is processed in this cell.
    target_file = raw_files[0]
    print(f"\nProcessing: {target_file.name}")
    extraction_config = ExtractionConfig(
        situation_format=SituationFormat.SINGLE,
        prompts_dir="data/prompts/phase2",
        prompt_version=PROMPT_VERSION,
        model=MODEL_MEMORIES,
    )
    output_path = extract_memories(
        raw_path=str(target_file),
        out_dir=MEMORIES_DIR,
        config=extraction_config,
    )
    print(f"Output saved to: {output_path}")

In [None]:
# Cell 5 — Build Memories: All Files

# Create the run (and the paths derived from it) if none exists yet.
if RUN_DIR is None:
    run_id, RUN_DIR = create_run(PHASE2)
    _memories_path = RUN_DIR / "memories"
    MEMORIES_DIR = str(_memories_path)
    DB_PATH = str(_memories_path / "memories.db")
    TEST_CASES_DIR, QUERIES_DIR, RESULTS_DIR = (
        str(RUN_DIR / _subdir) for _subdir in ("test_cases", "queries", "results")
    )
    print(f"Created new run: {run_id}")

raw_data_path = Path(RAW_DATA_DIR)
raw_files = sorted(raw_data_path.glob("*.json"))

print(f"Processing all {len(raw_files)} raw data files...\n")

extraction_config = ExtractionConfig(
    situation_format=SituationFormat.SINGLE,
    prompts_dir="data/prompts/phase2",
    prompt_version=PROMPT_VERSION,
    model=MODEL_MEMORIES,
)

# One record per raw file; "status" is "ok" on success, otherwise the error text.
extraction_results = []
for raw_file in raw_files:
    print(f"Processing: {raw_file.name}")
    try:
        output_path = extract_memories(
            raw_path=str(raw_file),
            out_dir=MEMORIES_DIR,
            config=extraction_config,
        )
    except Exception as exc:
        # Best effort: record the failure and continue with the next file.
        extraction_results.append(dict(file=raw_file.name, output=None, status=str(exc)))
        print(f"  ERROR: {exc}")
    else:
        extraction_results.append(dict(file=raw_file.name, output=output_path, status="ok"))

success_count = [record["status"] for record in extraction_results].count("ok")
print(f"\nSummary: {success_count}/{len(extraction_results)} files processed successfully.")

# Update run status
_build_summary = {
    "count": success_count,
    "failed": len(extraction_results) - success_count,
    "prompt_version": PROMPT_VERSION,
}
update_run_status(RUN_DIR, "build_memories", _build_summary)

## Step 2 — Create Database

Builds a SQLite database with **sqlite-vec** for vector similarity search.
Loads all accepted memories from JSONL files and indexes their situation descriptions
as 1024-dimensional embeddings (via Ollama `mxbai-embed-large`).

Requires Ollama running locally with the `mxbai-embed-large` model.

In [None]:
# Cell 7 — Rebuild Database
print(f"Rebuilding database for run: {RUN_DIR.name}...")
vector_backend.rebuild_database(memories_dir=MEMORIES_DIR, db_path=DB_PATH)

# Report how many memories the backend counts in the rebuilt database.
count = vector_backend.get_memory_count(DB_PATH)
print(f"Database rebuilt. Total memories indexed: {count}")

# Record the result in the run status.
update_run_status(RUN_DIR, "db", dict(memory_count=count))

In [None]:
# Cell 8 — Verify Database: Sample Search
sample_query = "error handling in async functions"
print(f'Sample search: "{sample_query}"\n')

results = vector_backend.search(db_path=DB_PATH, query=sample_query, limit=5)

if not results:
    print("No results found. Check that the database is populated and Ollama is running.")
else:
    # Results are numbered from 1 for display.
    for rank, result in enumerate(results, start=1):
        print(f"--- Result {rank} (distance: {result.raw_score:.4f}) ---")
        print(f"  ID: {result.id}")
        print(f"  Situation: {result.situation}")
        print(f"  Lesson: {result.lesson}")
        print()

## Step 3 — Create Test Cases

Matches raw PR data to extracted memories to build **ground truth** test cases.
Each test case contains the filtered diff, PR context, and the set of memory IDs that should be retrieved.
PRs with no matching memories are skipped.

In [None]:
# Cell 10 — Build Test Cases
print(f"Building test cases for run: {RUN_DIR.name}...\n")
build_test_cases(raw_dir=RAW_DATA_DIR, memories_dir=MEMORIES_DIR, output_dir=TEST_CASES_DIR)

# List what was written, with the ground-truth size of each test case.
test_case_files = sorted(Path(TEST_CASES_DIR).glob("*.json"))
print(f"\nGenerated {len(test_case_files)} test cases:")
for test_case_file in test_case_files:
    test_case = load_json(str(test_case_file))
    # Prefer the stored count; otherwise count the listed memory ids.
    _listed_id_count = len(test_case.get("ground_truth_memory_ids", []))
    ground_truth_count = test_case.get("ground_truth_count", _listed_id_count)
    print(f"  {test_case_file.name} — {ground_truth_count} ground truth memories")

# Record the result in the run status.
update_run_status(RUN_DIR, "test_cases", dict(count=len(test_case_files)))

## Step 4 — Generate Queries

Generates search queries from each test case's PR context and diff via LLM.
Queries are saved as separate JSON files in the `queries/` directory so they can be
reused across multiple experiment runs without re-calling the API.

Requires `OPENROUTER_API_KEY`.

In [None]:
# Cell 12 — Generate Queries for All Test Cases
print(f"Generating queries for run: {RUN_DIR.name}...\n")

query_config = QueryGenerationConfig(
    prompts_dir="data/prompts/phase2",
    prompt_version=PROMPT_VERSION,
    model=MODEL_EXPERIMENT,
)
all_query_data = generate_all_queries(
    test_cases_dir=TEST_CASES_DIR,
    queries_dir=QUERIES_DIR,
    config=query_config,
    db_path=DB_PATH,
    search_backend=vector_backend,
)

# Entries without a "queries" key are treated as failed generations.
successful_queries = []
total_queries = 0
for _query_entry in all_query_data:
    if "queries" in _query_entry:
        successful_queries.append(_query_entry)
        total_queries += len(_query_entry["queries"])
print(
    f"\nGenerated queries for {len(successful_queries)} test cases ({total_queries} total queries)"
)

# Record the result in the run status.
_query_summary = {
    "count": len(successful_queries),
    "total_queries": total_queries,
    "model": MODEL_EXPERIMENT,
    "prompt_version": PROMPT_VERSION,
}
update_run_status(RUN_DIR, "query_generation", _query_summary)

## Step 5 — Run Experiments with Reranking

For each test case:
1. Load pre-generated search queries from the `queries/` directory
2. Run a vector search for each query (top-20 candidates)
3. Pool and deduplicate results across queries
4. Rerank all candidates with cross-encoder (bge-reranker-v2-m3)
5. Store all reranked results for downstream analysis (top-N sweep, threshold sweep)
6. Compute pre-rerank metrics as baseline

Requires Ollama with `mxbai-embed-large`. Does NOT require `OPENROUTER_API_KEY`.

In [None]:
# Cell 13 — Run All Experiments with Reranking
print(f"Running all experiments for run: {RUN_DIR.name}...\n")

reranker = Reranker()

config = ExperimentConfig(
    search_backend=vector_backend,
    search_limit=SEARCH_LIMIT,
    distance_threshold=DISTANCE_THRESHOLD,
    reranker=reranker,
    rerank_text_strategies=RERANK_TEXT_STRATEGIES,
)
all_results = run_all_experiments(
    test_cases_dir=TEST_CASES_DIR,
    queries_dir=QUERIES_DIR,
    db_path=DB_PATH,
    results_dir=RESULTS_DIR,
    config=config,
)

# A result counts as successful when it carries "pre_rerank_metrics".
successful = []
for _experiment_result in all_results:
    if "pre_rerank_metrics" in _experiment_result:
        successful.append(_experiment_result)

# Record the result in the run status.
_experiment_summary = {
    "count": len(successful),
    "failed": len(all_results) - len(successful),
    "rerank_strategies": list(RERANK_TEXT_STRATEGIES.keys()),
}
update_run_status(RUN_DIR, "experiment", _experiment_summary)

## Step 6 — Results Analysis

Fair comparison of distance-based vs rerank-based retrieval at the same N,
top-N sweep with distance baseline, score distribution analysis,
and **rerank score threshold sweep** (main deliverable).

In [None]:
# Cell A — Fair Top-N Comparison (same N for distance and rerank)
import numpy as np

from retrieval_metrics.compute import compute_threshold_metrics, compute_top_n_metrics
from retrieval_metrics.sweeps import find_optimal_entry

from memory_retrieval.experiments.metrics import pool_and_deduplicate_by_distance
from memory_retrieval.experiments.metrics_adapter import (
    restriction_evaluation_to_dict,
    threshold_sweep_from_experiments as sweep_threshold,
    top_n_sweep_from_experiments as sweep_top_n,
)


def compute_metrics_at_top_n(ranked_results, ground_truth_ids, top_n, id_field="id"):
    """Compute top-N metrics for one ranked list and return them as a dict.

    Thin adapter: forwards to ``compute_top_n_metrics`` (``id_field`` is passed
    as its ``id_key``) and converts the evaluation with
    ``restriction_evaluation_to_dict``. Callers in this notebook read the
    "precision", "recall" and "f1" entries of the result.
    """
    return restriction_evaluation_to_dict(
        compute_top_n_metrics(ranked_results, ground_truth_ids, top_n, id_key=id_field)
    )


def compute_metrics_at_threshold(
    ranked_results,
    ground_truth_ids,
    threshold,
    score_field,
    higher_is_better,
    id_field="id",
):
    """Compute threshold-based metrics for one ranked list and return them as a dict.

    Thin adapter: forwards to ``compute_threshold_metrics`` (``score_field`` is
    passed as ``score_key`` and ``id_field`` as ``id_key``) and converts the
    evaluation with ``restriction_evaluation_to_dict``, asking it to include
    the accepted count.
    """
    key_options = {
        "score_key": score_field,
        "higher_is_better": higher_is_better,
        "id_key": id_field,
    }
    evaluation = compute_threshold_metrics(
        ranked_results, ground_truth_ids, threshold, **key_options
    )
    return restriction_evaluation_to_dict(evaluation, include_accepted_count=True)


def find_optimal_threshold(sweep_results, metric="f1"):
    """Return the optimal entry of a threshold sweep for ``metric``.

    Delegates to ``find_optimal_entry`` with ``metric`` passed as ``metric_key``.
    The threshold-sweep cell reads the "index" and "threshold" entries of the
    returned value.
    """
    best_entry = find_optimal_entry(sweep_results, metric_key=metric)
    return best_entry

# Analysis parameter: the N at which distance and rerank rankings are compared
ANALYSIS_TOP_N = 4


# --- Helper (strategy-specific, stays in notebook) ---
def get_reranked_results(data, strategy_name):
    """Return the reranked results stored for ``strategy_name`` in one result dict.

    Multi-strategy results keep one list per strategy under
    ``data["rerank_strategies"][name]["reranked_results"]``. When that entry is
    absent, the top-level ``"reranked_results"`` list is used instead
    (single-strategy format), defaulting to an empty list.
    """
    if "rerank_strategies" not in data or strategy_name not in data["rerank_strategies"]:
        return data.get("reranked_results", [])
    return data["rerank_strategies"][strategy_name]["reranked_results"]


# --- Load results ---
from pathlib import Path

results_path = Path(RESULTS_DIR)
result_files = sorted(results_path.glob("*.json"))
all_data = [load_json(str(file_path)) for file_path in result_files]

# Only results that carry "pre_rerank_metrics" are analysed.
successful = [data for data in all_data if "pre_rerank_metrics" in data]
if not successful:
    print(f"WARNING: no successful experiment results found in {results_path}")

has_strategies = bool(successful) and "rerank_strategies" in successful[0]
if has_strategies:
    # Take the strategy names from the stored results rather than from the
    # current RERANK_TEXT_STRATEGIES config: an earlier run may have used a
    # different strategy set, and a name missing from the files would silently
    # produce empty rerank lists (F1 = 0) in every analysis cell.
    strategy_names = list(successful[0]["rerank_strategies"]) or ["default"]
else:
    strategy_names = ["default"]

top_n = ANALYSIS_TOP_N

def _trend_marker(delta):
    """Return "+", "-" or "=" for an F1 delta, using a ±0.001 tolerance."""
    if delta > 0.001:
        return "+"
    if delta < -0.001:
        return "-"
    return "="


# One table per strategy: distance-ranked vs rerank-ranked metrics at the same N.
for strategy_name in strategy_names:
    print(f"\n{'=' * 100}")
    print(f"Fair Top-{top_n} Comparison: distance vs rerank — Strategy: {strategy_name}")
    print(f"{'=' * 100}")

    if not successful:
        # np.mean([]) would emit a RuntimeWarning and print "nan" averages
        # plus a meaningless 0/0 summary, so say what happened instead.
        print("No successful experiment results to compare.")
        continue

    print(
        f"{'Test Case':<25} {'Dist F1':>8} {'Rank F1':>8} {'Delta':>8} {'Dist P':>7} {'Rank P':>7} {'Dist R':>7} {'Rank R':>7} {'GT':>4}"
    )
    print("-" * 90)

    distance_f1_scores, rerank_f1_scores = [], []

    for data in successful:
        test_case_id = data.get("test_case_id", "?")[:25]
        ground_truth_ids = set(data.get("ground_truth", {}).get("memory_ids", []))

        # Distance baseline: pool + dedup + top-N by distance
        pooled_by_dist = pool_and_deduplicate_by_distance(data.get("queries", []))
        distance_metrics = compute_metrics_at_top_n(pooled_by_dist, ground_truth_ids, top_n)

        # Rerank: top-N by rerank score
        reranked = get_reranked_results(data, strategy_name)
        rerank_metrics = compute_metrics_at_top_n(reranked, ground_truth_ids, top_n)

        delta = rerank_metrics["f1"] - distance_metrics["f1"]
        marker = _trend_marker(delta)
        print(
            f"{test_case_id:<25} {distance_metrics['f1']:>8.3f} {rerank_metrics['f1']:>8.3f} {delta:>+7.3f}{marker} {distance_metrics['precision']:>7.3f} {rerank_metrics['precision']:>7.3f} {distance_metrics['recall']:>7.3f} {rerank_metrics['recall']:>7.3f} {len(ground_truth_ids):>4}"
        )

        distance_f1_scores.append(distance_metrics["f1"])
        rerank_f1_scores.append(rerank_metrics["f1"])

    avg_distance_f1 = np.mean(distance_f1_scores)
    avg_rerank_f1 = np.mean(rerank_f1_scores)
    delta = avg_rerank_f1 - avg_distance_f1
    marker = _trend_marker(delta)
    print("-" * 90)
    print(f"{'AVERAGE':<25} {avg_distance_f1:>8.3f} {avg_rerank_f1:>8.3f} {delta:>+7.3f}{marker}")

    # Count test cases where reranking helped, made no difference, or hurt,
    # using the same 0.001 tolerance as the per-row markers.
    f1_pairs = list(zip(distance_f1_scores, rerank_f1_scores))
    improved = sum(1 for dist_f1, rank_f1 in f1_pairs if rank_f1 > dist_f1 + 0.001)
    same = sum(1 for dist_f1, rank_f1 in f1_pairs if abs(rank_f1 - dist_f1) <= 0.001)
    worse = sum(1 for dist_f1, rank_f1 in f1_pairs if rank_f1 < dist_f1 - 0.001)
    print(
        f"\nReranking helped: {improved}/{len(successful)} | Same: {same}/{len(successful)} | Hurt: {worse}/{len(successful)}"
    )

# Strategy comparison summary (only meaningful with more than one strategy)
if len(strategy_names) > 1:
    _banner = "=" * 60
    print(f"\n{_banner}")
    print(f"STRATEGY COMPARISON (fair top-{top_n} avg F1)")
    print(f"{_banner}")

    # Distance baseline (shared by every strategy)
    distance_experiments = []
    for data in successful:
        distance_experiments.append(
            dict(
                ground_truth_ids=set(data.get("ground_truth", {}).get("memory_ids", [])),
                ranked_results=pool_and_deduplicate_by_distance(data.get("queries", [])),
            )
        )
    distance_sweep_at_n = sweep_top_n(distance_experiments, [top_n])
    _baseline_f1 = distance_sweep_at_n[0]["f1"]
    print(f"  Distance top-{top_n}:        {_baseline_f1:.3f}")

    # One line per strategy: its F1 at the same N and the gap to the baseline.
    for strategy_name in strategy_names:
        strategy_experiments = []
        for data in successful:
            strategy_experiments.append(
                dict(
                    ground_truth_ids=set(data.get("ground_truth", {}).get("memory_ids", [])),
                    ranked_results=get_reranked_results(data, strategy_name),
                )
            )
        strategy_sweep_at_n = sweep_top_n(strategy_experiments, [top_n])
        _strategy_f1 = strategy_sweep_at_n[0]["f1"]
        delta = _strategy_f1 - _baseline_f1
        print(f"  {strategy_name:<22} {_strategy_f1:.3f} ({delta:+.3f} vs distance)")

In [None]:
# Cell B — Top-N Sweep: Distance Baseline vs Rerank Strategies
import matplotlib.pyplot as plt
import numpy as np

# Keep the figure session while it belongs to the current run; otherwise
# (first use, or the run changed) open a new one under <run>/figures.
_session_is_stale = (
    "FIGURE_SESSION" not in globals()
    or FIGURE_SESSION.context.get("run_id") != RUN_DIR.name
)
if _session_is_stale:
    FIGURE_SESSION = create_figure_session(
        root_dir=RUN_DIR / "figures",
        notebook_slug="phase2",
        context_key=RUN_DIR.name,
        context=dict(phase=PHASE2, run_id=RUN_DIR.name),
    )
print(f"Figure export session: {FIGURE_SESSION.session_dir}")

def _sweep_input(data, ranked_results):
    """Pair one result's ground-truth id set with a ranked list for sweep_top_n."""
    return {
        "ground_truth_ids": set(data.get("ground_truth", {}).get("memory_ids", [])),
        "ranked_results": ranked_results,
    }


num_test_cases = len(successful)
max_n = 20
n_values = list(range(1, max_n + 1))

# Distance baseline sweep: candidates pooled across queries, ranked by distance.
distance_experiments = [
    _sweep_input(data, pool_and_deduplicate_by_distance(data.get("queries", [])))
    for data in successful
]
distance_sweep = sweep_top_n(distance_experiments, n_values)
distance_f1_scores = []
for _sweep_entry in distance_sweep:
    distance_f1_scores.append(_sweep_entry["f1"])

# Per-strategy rerank sweeps over the same N values.
strategy_sweeps = {}
for key in strategy_names:
    rerank_experiments = [
        _sweep_input(data, get_reranked_results(data, key)) for data in successful
    ]
    strategy_sweeps[key] = sweep_top_n(rerank_experiments, n_values)

# Print table: one row per selected N, one F1/Delta column pair per strategy.
print(f"Top-N Sweep: Distance vs Rerank (averaged over {num_test_cases} test cases)\n")
_header_cells = [f"{'N':>4} {'Dist F1':>8}"]
for key in strategy_names:
    _header_cells.append(f" {key[:16] + ' F1':>18} {'Delta':>7}")
header = "".join(_header_cells)
print(header)
print("-" * (14 + 27 * len(strategy_names)))
for top_n in (1, 2, 3, 4, 5, 6, 8, 10, 15, 20):
    if top_n > max_n:
        continue
    index = top_n - 1  # sweep entries are ordered N = 1..max_n
    _baseline_f1 = distance_f1_scores[index]
    _row_cells = [f"{top_n:>4} {_baseline_f1:>8.3f}"]
    for key in strategy_names:
        strategy_f1 = strategy_sweeps[key][index]["f1"]
        delta = strategy_f1 - _baseline_f1
        _row_cells.append(f" {strategy_f1:>18.3f} {delta:>+7.3f}")
    row = "".join(_row_cells)
    print(row)

# --- Plots ---
colors_rerank = ["#9b59b6", "#e67e22", "#2ecc71", "#e74c3c"]
fig, axes = plt.subplots(1, 2, figsize=(16, 5))
_left_ax, _right_ax = axes[0], axes[1]

# Left panel: F1 of the distance baseline and of every rerank strategy.
_left_ax.plot(
    n_values,
    distance_f1_scores,
    label="Distance (baseline)",
    color="#3498db",
    linewidth=2.5,
    marker="o",
    markersize=5,
    linestyle="--",
)
for i, key in enumerate(strategy_names):
    strategy_f1_values = [entry["f1"] for entry in strategy_sweeps[key]]
    _left_ax.plot(
        n_values,
        strategy_f1_values,
        label=f"Rerank: {key}",
        color=colors_rerank[i % len(colors_rerank)],  # cycle when strategies outnumber colors
        linewidth=2,
        marker="^",
        markersize=4,
    )

# Right panel: precision / recall / F1 of the strategy with the best F1 at
# ANALYSIS_TOP_N, with the distance F1 curve as a reference.
best_strategy = max(strategy_names, key=lambda k: strategy_sweeps[k][ANALYSIS_TOP_N - 1]["f1"])
best_sweep = strategy_sweeps[best_strategy]
best_precisions = [entry["precision"] for entry in best_sweep]
best_recalls = [entry["recall"] for entry in best_sweep]
best_f1_values = [entry["f1"] for entry in best_sweep]

_right_ax.plot(
    n_values,
    distance_f1_scores,
    label="F1 (distance)",
    color="#3498db",
    linewidth=2,
    linestyle="--",
    alpha=0.7,
)
_best_series = (
    (
        best_precisions,
        f"Precision ({best_strategy})",
        dict(color="#e74c3c", linewidth=1.5, marker="o", markersize=3),
    ),
    (
        best_recalls,
        f"Recall ({best_strategy})",
        dict(color="#2ecc71", linewidth=1.5, marker="s", markersize=3),
    ),
    (
        best_f1_values,
        f"F1 ({best_strategy})",
        dict(color="#9b59b6", linewidth=2.5, marker="^", markersize=4),
    ),
)
for _values, _series_label, _line_style in _best_series:
    _right_ax.plot(n_values, _values, label=_series_label, **_line_style)

# Shared decoration, applied after each panel's curves so legends pick them up.
_panel_text = (
    (
        _left_ax,
        "F1 Score",
        f"F1 vs Top-N: Distance vs Rerank (avg over {num_test_cases} test cases)",
    ),
    (
        _right_ax,
        "Score",
        f"P/R/F1 vs Top-N: {best_strategy} (with distance F1 baseline)",
    ),
)
for ax, _y_label, _panel_title in _panel_text:
    ax.axvline(
        x=ANALYSIS_TOP_N,
        color="gray",
        linestyle=":",
        alpha=0.5,
        label=f"Default N={ANALYSIS_TOP_N}",
    )
    ax.set_xlabel("Top-N (results kept)")
    ax.set_ylabel(_y_label)
    ax.set_title(_panel_title)
    ax.set_xticks(n_values)
    ax.set_ylim(0, 1.05)
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
saved_paths = save_figure(
    fig,
    FIGURE_SESSION,
    "rerank_topn_sweep",
    title="F1 vs Top-N: Distance vs Rerank",
)
plt.show()
print(f"Saved: {saved_paths['png']}")

In [None]:
# Cell C — Rerank Score Distribution (Per Strategy)
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

# Keep the figure session while it belongs to the current run; otherwise
# (first use, or the run changed) open a new one under <run>/figures.
_session_is_stale = (
    "FIGURE_SESSION" not in globals()
    or FIGURE_SESSION.context.get("run_id") != RUN_DIR.name
)
if _session_is_stale:
    FIGURE_SESSION = create_figure_session(
        root_dir=RUN_DIR / "figures",
        notebook_slug="phase2",
        context_key=RUN_DIR.name,
        context=dict(phase=PHASE2, run_id=RUN_DIR.name),
    )

# One row of three panels per strategy:
#   column 1 — score distribution (non-GT histogram/KDE, GT rug marks)
#   column 2 — rerank score vs vector distance
#   column 3 — per-test-case scores (non-GT box, GT markers)
n_strategies = len(strategy_names)
fig, axes = plt.subplots(n_strategies, 3, figsize=(20, 5 * n_strategies), squeeze=False)

for row, key in enumerate(strategy_names):
    gt_scores, non_gt_scores = [], []
    gt_distances, non_gt_distances = [], []
    # Keyed by the FULL test case id; ids are shortened only for display, so
    # two ids sharing their first 20 characters no longer overwrite each other.
    per_test_case_gt_scores = {}
    per_test_case_non_gt_scores = {}

    for data in successful:
        test_case_id = data.get("test_case_id", "?")
        reranked = get_reranked_results(data, key)
        tc_gt, tc_non_gt = [], []
        for result in reranked:
            if result.get("is_ground_truth"):
                gt_scores.append(result["rerank_score"])
                gt_distances.append(result["distance"])
                tc_gt.append(result["rerank_score"])
            else:
                non_gt_scores.append(result["rerank_score"])
                non_gt_distances.append(result["distance"])
                tc_non_gt.append(result["rerank_score"])
        if tc_gt or tc_non_gt:
            per_test_case_gt_scores[test_case_id] = tc_gt
            per_test_case_non_gt_scores[test_case_id] = tc_non_gt

    gt_arr = np.array(gt_scores)
    non_gt_arr = np.array(non_gt_scores)

    # --- Column 1: Density histogram (non-GT) + KDE overlay + GT strip markers ---
    ax = axes[row][0]
    all_scores = np.concatenate([gt_arr, non_gt_arr])
    if all_scores.size:
        score_min, score_max = all_scores.min(), all_scores.max()
    else:
        # No reranked results at all: .min()/.max() would raise on an empty
        # array, so fall back to a placeholder range and draw empty panels.
        score_min, score_max = 0.0, 1.0
    if score_min == score_max:
        # Identical scores would collapse every bin edge onto one value.
        score_min, score_max = score_min - 0.5, score_max + 0.5
    bins = np.linspace(score_min, score_max, 30)

    if len(non_gt_arr):
        ax.hist(
            non_gt_arr,
            bins=bins,
            alpha=0.5,
            density=True,
            label=f"Non-GT (n={len(non_gt_arr)})",
            color="#e74c3c",
            edgecolor="white",
            linewidth=0.5,
        )
        # KDE overlay for non-GT; gaussian_kde cannot be fitted to
        # zero-variance data, so skip it when all scores are equal.
        if len(non_gt_arr) > 3 and np.ptp(non_gt_arr) > 0:
            kde = gaussian_kde(non_gt_arr)
            x_kde = np.linspace(score_min, score_max, 200)
            ax.plot(x_kde, kde(x_kde), color="#c0392b", linewidth=2, label="Non-GT KDE")
        ax.axvline(
            np.median(non_gt_arr),
            color="#c0392b",
            linestyle="--",
            linewidth=1.5,
            label=f"Non-GT median: {np.median(non_gt_arr):.4f}",
        )

    if len(gt_arr):
        # GT as rug/strip plot — individual vertical lines at the bottom, sized
        # relative to the y-range set by whatever was drawn above.
        y_max = ax.get_ylim()[1]
        strip_height = y_max * 0.08
        for score in gt_arr:
            ax.plot(
                [score, score],
                [0, strip_height],
                color="#2ecc71",
                linewidth=2.5,
                alpha=0.8,
                zorder=5,
            )
        # Invisible scatter for legend entry
        ax.scatter(
            [],
            [],
            color="#2ecc71",
            marker="|",
            s=100,
            linewidth=2.5,
            label=f"GT individuals (n={len(gt_arr)})",
        )
        ax.axvline(
            np.median(gt_arr),
            color="#27ae60",
            linestyle="--",
            linewidth=1.5,
            label=f"GT median: {np.median(gt_arr):.4f}",
        )

    ax.set_xlabel("Rerank Score")
    ax.set_ylabel("Density")
    ax.set_title(f"Rerank Score Distribution: {key}")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # --- Column 2: Rerank score vs distance scatter ---
    ax = axes[row][1]
    if gt_scores:
        ax.scatter(
            gt_distances,
            gt_scores,
            alpha=0.7,
            label="Ground Truth",
            color="#2ecc71",
            s=60,
            zorder=3,
            edgecolors="white",
            linewidth=0.3,
        )
    if non_gt_scores:
        ax.scatter(
            non_gt_distances,
            non_gt_scores,
            alpha=0.4,
            label="Non-GT",
            color="#e74c3c",
            s=30,
            zorder=2,
            edgecolors="white",
            linewidth=0.3,
        )
    ax.set_xlabel("Vector Distance")
    ax.set_ylabel("Rerank Score")
    ax.set_title(f"Rerank Score vs Distance: {key}")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # --- Column 3: Per-test-case box/strip comparison ---
    ax = axes[row][2]
    test_case_ids = sorted(per_test_case_gt_scores.keys() | per_test_case_non_gt_scores.keys())
    test_case_labels = [tc_id[:20] for tc_id in test_case_ids]  # shortened for the axis
    y_positions = list(range(len(test_case_ids)))

    for y_pos, tc_id in enumerate(test_case_ids):
        tc_gt = per_test_case_gt_scores.get(tc_id, [])
        tc_non_gt = per_test_case_non_gt_scores.get(tc_id, [])

        # Non-GT as box plot
        if tc_non_gt:
            ax.boxplot(
                [tc_non_gt],
                positions=[y_pos],
                vert=False,
                widths=0.4,
                patch_artist=True,
                boxprops=dict(facecolor="#e74c3c", alpha=0.3, edgecolor="#c0392b"),
                medianprops=dict(color="#c0392b", linewidth=1.5),
                whiskerprops=dict(color="#c0392b", alpha=0.5),
                capprops=dict(color="#c0392b", alpha=0.5),
                flierprops=dict(marker=".", markerfacecolor="#e74c3c", markersize=3, alpha=0.4),
                manage_ticks=False,
            )

        # GT as individual strip markers
        if tc_gt:
            ax.scatter(
                tc_gt,
                [y_pos] * len(tc_gt),
                color="#2ecc71",
                marker="D",
                s=50,
                zorder=5,
                edgecolors="white",
                linewidth=0.5,
                alpha=0.9,
            )

    ax.set_yticks(y_positions)
    ax.set_yticklabels(test_case_labels, fontsize=7)
    ax.set_xlabel("Rerank Score")
    ax.set_title(f"Per-Test-Case Scores: {key}")
    ax.grid(True, alpha=0.3, axis="x")
    # Legend (empty artists that exist only to create the entries)
    ax.scatter([], [], color="#2ecc71", marker="D", s=50, label="GT")
    ax.plot([], [], color="#e74c3c", linewidth=6, alpha=0.3, label="Non-GT (box)")
    ax.legend(fontsize=8, loc="lower right")

    # --- Statistics ---
    if len(gt_arr):
        print(
            f"[{key}] GT rerank scores:     n={len(gt_arr):>3}, min={gt_arr.min():.4f}, max={gt_arr.max():.4f}, mean={gt_arr.mean():.4f}, median={np.median(gt_arr):.4f}"
        )
    if len(non_gt_arr):
        print(
            f"[{key}] Non-GT rerank scores: n={len(non_gt_arr):>3}, min={non_gt_arr.min():.4f}, max={non_gt_arr.max():.4f}, mean={non_gt_arr.mean():.4f}, median={np.median(non_gt_arr):.4f}"
        )
    if len(gt_arr) and len(non_gt_arr):
        separation = gt_arr.mean() - non_gt_arr.mean()
        print(
            f"[{key}] Mean separation: {separation:.4f} ({'good' if separation > 1.0 else 'moderate' if separation > 0.5 else 'weak'})"
        )
    # Per-test-case counts
    # NOTE(review): the totals printed below are len(gt_arr) / len(non_gt_arr),
    # i.e. scored results summed over test cases, not de-duplicated memories.
    gt_per_tc = [len(per_test_case_gt_scores.get(tc_id, [])) for tc_id in test_case_ids]
    non_gt_per_tc = [len(per_test_case_non_gt_scores.get(tc_id, [])) for tc_id in test_case_ids]
    print(f"[{key}] Per-test-case GT counts: {gt_per_tc} (total unique memories: {len(gt_arr)})")
    print(
        f"[{key}] Per-test-case Non-GT counts: {non_gt_per_tc} (total unique memories: {len(non_gt_arr)})"
    )
    print()

plt.tight_layout()
saved_paths = save_figure(
    fig,
    FIGURE_SESSION,
    "rerank_score_distribution",
    title="Rerank Score Distribution",
)
plt.show()
print(f"Saved: {saved_paths['png']}")

In [None]:
# Cell D — Rerank Score Threshold Sweep (MAIN DELIVERABLE)
import matplotlib.pyplot as plt
import numpy as np

# Open a fresh figure session unless one already exists for the current run.
session_matches_run = (
    "FIGURE_SESSION" in globals() and FIGURE_SESSION.context.get("run_id") == RUN_DIR.name
)
if not session_matches_run:
    FIGURE_SESSION = create_figure_session(
        root_dir=RUN_DIR / "figures",
        notebook_slug="phase2",
        context_key=RUN_DIR.name,
        context={"phase": PHASE2, "run_id": RUN_DIR.name},
    )


def _sweep_record(data, strategy_key):
    """Build the sweep input for one test case under one rerank strategy.

    The reranked list is stored under both "ranked_results" and "reranked";
    both keys are read later in this cell.
    """
    reranked = get_reranked_results(data, strategy_key)
    return {
        "test_case_id": data["test_case_id"],
        "ground_truth_ids": set(data.get("ground_truth", {}).get("memory_ids", [])),
        "ranked_results": reranked,
        "reranked": reranked,
    }


def _sweep_strategy(strategy_experiments):
    """Sweep every threshold for one strategy and record its best-F1 point."""
    sweep = sweep_threshold(
        strategy_experiments,
        sweep_thresholds,
        score_field="rerank_score",
        higher_is_better=True,
    )
    optimum = find_optimal_threshold(sweep, metric="f1")
    return {
        "sweep": sweep,
        "best_f1_index": optimum["index"],
        "best_threshold": optimum["threshold"],
    }


# Build per-strategy experiment lists for sweep (one record per successful test case)
experiments_by_strategy = {
    key: [_sweep_record(data, key) for data in successful] for key in strategy_names
}

# Determine sweep range from data: the highest rerank score seen in any
# non-empty result list bounds the sweep (1.0 when there are no results at all).
all_max_scores = [
    max(result["rerank_score"] for result in experiment_result["ranked_results"])
    for key in strategy_names
    for experiment_result in experiments_by_strategy[key]
    if experiment_result["ranked_results"]
]
global_max = max(all_max_scores, default=1.0)
sweep_thresholds = list(np.arange(0.0, global_max + 0.01, 0.005))

# Sweep thresholds per strategy using centralized metrics
strategy_sweep_results = {
    key: _sweep_strategy(experiments_by_strategy[key]) for key in strategy_names
}

# --- Print optimal thresholds ---
# One line per strategy at its best-F1 sweep entry. The first strategy with the
# strictly highest F1 becomes the overall winner.
print("OPTIMAL THRESHOLDS PER STRATEGY")
best_overall_strategy, best_overall_f1 = None, -1
for key in strategy_names:
    sweep_data = strategy_sweep_results[key]
    best_entry = sweep_data["sweep"][sweep_data["best_f1_index"]]
    strategy_label = f"{key + ':':30}"
    print(
        f"  {strategy_label} threshold={sweep_data['best_threshold']:.4f}"
        f"  F1={best_entry['f1']:.3f}  P={best_entry['precision']:.3f}"
        f"  R={best_entry['recall']:.3f}  MRR={best_entry['mrr']:.3f}"
    )
    if best_entry["f1"] > best_overall_f1:
        best_overall_strategy, best_overall_f1 = key, best_entry["f1"]

print(f"\nBest strategy: {best_overall_strategy}")

# --- Plots ---
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Left: P/R/F1/MRR vs threshold for best strategy
best_strategy_sweep = strategy_sweep_results[best_overall_strategy]
best_sweep = best_strategy_sweep["sweep"]
best_threshold = best_strategy_sweep["best_threshold"]
best_precisions, best_recalls, best_f1_values, best_mrr_values = (
    [entry[metric] for entry in best_sweep] for metric in ("precision", "recall", "f1", "mrr")
)

ax = axes[0]
# (curve values, line styling) in drawing order.
metric_curves = (
    (best_precisions, {"label": "Precision", "color": "#3498db", "linewidth": 2}),
    (best_recalls, {"label": "Recall", "color": "#2ecc71", "linewidth": 2}),
    (best_f1_values, {"label": "F1", "color": "#9b59b6", "linewidth": 2.5}),
    (best_mrr_values, {"label": "MRR", "color": "#e67e22", "linewidth": 1.5, "alpha": 0.7}),
)
for curve_values, line_style in metric_curves:
    ax.plot(sweep_thresholds, curve_values, **line_style)
# Dashed vertical line at the best-F1 threshold of the winning strategy.
ax.axvline(
    best_threshold,
    color="#e74c3c",
    linestyle="--",
    linewidth=1.5,
    label=f"Optimal @ {best_threshold:.4f}",
)
ax.set_xlabel("Rerank Score Threshold (accept >= threshold)")
ax.set_ylabel("Score")
ax.set_title(f"P/R/F1/MRR vs Threshold: {best_overall_strategy}")
ax.legend(fontsize=8)
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)

# Right: F1 vs threshold for ALL strategies
ax = axes[1]
colors_strategy = ["#9b59b6", "#e67e22", "#2ecc71", "#e74c3c"]
for i, key in enumerate(strategy_names):
    # The palette wraps around when there are more strategies than colors.
    strategy_color = colors_strategy[i % len(colors_strategy)]
    strategy_sweep = strategy_sweep_results[key]
    strategy_f1_values = [entry["f1"] for entry in strategy_sweep["sweep"]]
    ax.plot(
        sweep_thresholds,
        strategy_f1_values,
        label=key,
        color=strategy_color,
        linewidth=2,
    )
    # Dotted marker at this strategy's own best-F1 threshold.
    ax.axvline(
        strategy_sweep["best_threshold"],
        color=strategy_color,
        linestyle=":",
        alpha=0.5,
        linewidth=1,
    )
ax.set_xlabel("Rerank Score Threshold")
ax.set_ylabel("F1 Score")
ax.set_title("F1 vs Threshold: All Strategies")
ax.legend(fontsize=8)
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)

plt.tight_layout()
saved_paths = save_figure(
    fig,
    FIGURE_SESSION,
    "rerank_threshold_sweep",
    title="Rerank Score Threshold Sweep",
)
plt.show()
print(f"Saved: {saved_paths['png']}")
# --- Threshold table for best strategy ---
# Every column of a row is evaluated at the threshold printed in that row:
# metrics are computed at exactly the listed thresholds (not borrowed from the
# nearest 0.005-step sweep entry), so P/R/F1/MRR and the accepted / GT-kept
# counts always describe the same cutoff.
print(
    f"\nThreshold table: {best_overall_strategy} (macro-averaged over {len(successful)} test cases)"
)
print(
    f"{'Threshold':>10} {'Precision':>10} {'Recall':>8} {'F1':>8} {'MRR':>8} {'Avg Accepted':>13} {'Avg GT Kept':>12}"
)
print("-" * 75)

# The optimal threshold is rounded to the printed precision so that it
# collapses onto a preset value instead of producing a near-duplicate row.
optimal_row_threshold = round(best_threshold, 4)
table_thresholds = sorted(
    {
        0.001,
        0.005,
        0.01,
        0.02,
        0.03,
        0.05,
        0.08,
        0.10,
        0.15,
        0.20,
        0.30,
        0.50,
        optimal_row_threshold,
    }
)
# Thresholds beyond the highest observed score would accept nothing; skip them.
shown_thresholds = [
    threshold for threshold in table_thresholds if not threshold > global_max + 0.01
]

experiment_results = experiments_by_strategy[best_overall_strategy]
# sweep_threshold yields one metrics entry per threshold, in input order.
# NOTE(review): assumes it accepts a result when score >= threshold, matching
# the count columns computed below — confirm against its implementation.
table_entries = (
    sweep_threshold(
        experiment_results,
        shown_thresholds,
        score_field="rerank_score",
        higher_is_better=True,
    )
    if shown_thresholds
    else []
)

for threshold, sweep_entry in zip(shown_thresholds, table_entries):
    accepted_counts = []
    ground_truth_kept_counts = []
    for experiment_result in experiment_results:
        accepted = [
            result
            for result in experiment_result["reranked"]
            if result["rerank_score"] >= threshold
        ]
        accepted_counts.append(len(accepted))
        ground_truth_kept_counts.append(
            len({result["id"] for result in accepted} & experiment_result["ground_truth_ids"])
        )
    avg_accepted = np.mean(accepted_counts)
    avg_ground_truth_kept = np.mean(ground_truth_kept_counts)
    # Exact comparison is safe: the optimal row holds this very value, so
    # exactly one row is flagged.
    marker = " <-- optimal" if threshold == optimal_row_threshold else ""
    print(
        f"{threshold:>10.4f} {sweep_entry['precision']:>10.3f} {sweep_entry['recall']:>8.3f} {sweep_entry['f1']:>8.3f} {sweep_entry['mrr']:>8.3f} {avg_accepted:>13.1f} {avg_ground_truth_kept:>12.1f}{marker}"
    )

In [None]:
# Cell E — Per-Test-Case Impact at Optimal Threshold
import numpy as np

top_n = ANALYSIS_TOP_N
key = best_overall_strategy
best_threshold = strategy_sweep_results[key]["best_threshold"]
experiment_results = experiments_by_strategy[key]


def _trend_marker(f1_delta):
    """Return "+" for a gain, "-" for a loss, "=" for a change within +/-0.001."""
    if f1_delta > 0.001:
        return "+"
    if f1_delta < -0.001:
        return "-"
    return "="


def _first_rerank_score(reranked_results, memory_id):
    """Return the rerank score of the first result with this id, or None if absent."""
    for result in reranked_results:
        if result["id"] == memory_id:
            return result["rerank_score"]
    return None


print(
    "\n".join(
        [
            "Per-test-case comparison at best strategy's optimal threshold",
            f"  Strategy: {key}",
            f"  Threshold: score >= {best_threshold:.4f}",
            f"  Distance baseline: top-{top_n} by distance",
        ]
    )
)
print()

print(
    f"{'Test Case':<20} {'Dist F1':>8} {'Thr F1':>8} {'Delta':>8} "
    f"{'Dist P':>7} {'Thr P':>7} {'Dist R':>7} {'Thr R':>7} "
    f"{'D Acc':>6} {'T Acc':>6} {'GT':>4}"
)
print("-" * 105)

distance_f1_scores, threshold_f1_scores = [], []

# `successful` and `experiment_results` are walked in lockstep, one pair per test case.
for data, experiment_result in zip(successful, experiment_results):
    test_case_id = data["test_case_id"]
    ground_truth_ids = experiment_result["ground_truth_ids"]
    reranked_results = experiment_result["reranked"]

    # Distance baseline: top-N of the pooled, de-duplicated query results
    distance_pool = pool_and_deduplicate_by_distance(data.get("queries", []))
    distance_metrics = compute_metrics_at_top_n(distance_pool, ground_truth_ids, top_n)

    # Threshold-based: reranked results cut at the optimal rerank score
    threshold_metrics = compute_metrics_at_threshold(
        reranked_results,
        ground_truth_ids,
        best_threshold,
        score_field="rerank_score",
        higher_is_better=True,
    )

    delta = threshold_metrics["f1"] - distance_metrics["f1"]
    print(
        f"{test_case_id:<20} {distance_metrics['f1']:>8.3f} {threshold_metrics['f1']:>8.3f} "
        f"{delta:>+7.3f}{_trend_marker(delta)} "
        f"{distance_metrics['precision']:>7.1%} {threshold_metrics['precision']:>7.1%} "
        f"{distance_metrics['recall']:>7.1%} {threshold_metrics['recall']:>7.1%} "
        f"{len(distance_metrics['retrieved_ids']):>6} {threshold_metrics['accepted_count']:>6} "
        f"{len(ground_truth_ids):>4}"
    )

    # Show missed GT with their rerank scores; a GT id without any score never
    # reached the reranked candidate list.
    for memory_id in sorted(ground_truth_ids - threshold_metrics["retrieved_ids"]):
        score = _first_rerank_score(reranked_results, memory_id)
        if score is None:
            print(f"  {'':20} Missed: {memory_id}  NOT IN CANDIDATE POOL (query miss)")
        else:
            print(
                f"  {'':20} Missed: {memory_id}  score={score:.4f} "
                f"(below threshold {best_threshold:.4f})"
            )

    distance_f1_scores.append(distance_metrics["f1"])
    threshold_f1_scores.append(threshold_metrics["f1"])

avg_distance_f1 = np.mean(distance_f1_scores)
avg_threshold_f1 = np.mean(threshold_f1_scores)
delta = avg_threshold_f1 - avg_distance_f1
print("-" * 105)
print(
    f"{'AVERAGE':<20} {avg_distance_f1:>8.3f} {avg_threshold_f1:>8.3f} "
    f"{delta:>+7.3f}{_trend_marker(delta)}"
)

# Tally test cases by whether thresholding moved F1 by more than 0.001.
f1_pairs = list(zip(distance_f1_scores, threshold_f1_scores))
improved = sum(1 for dist_f1, thr_f1 in f1_pairs if thr_f1 > dist_f1 + 0.001)
same = sum(1 for dist_f1, thr_f1 in f1_pairs if abs(thr_f1 - dist_f1) <= 0.001)
worse = sum(1 for dist_f1, thr_f1 in f1_pairs if thr_f1 < dist_f1 - 0.001)
total_cases = len(successful)
print(
    f"\nThreshold helped: {improved}/{total_cases} | Same: {same}/{total_cases} "
    f"| Hurt: {worse}/{total_cases}"
)

## Step 7 — Store Config Fingerprint & Run Summary

Stores a config fingerprint in `run.json` and generates a `run_summary.json` for
cross-run comparison. The fingerprint captures all pipeline parameters that affect results.
The summary precomputes metrics for fast comparison without loading full result files.

See `notebooks/comparison/cross_run_comparison.ipynb` for cross-run analysis.

In [None]:
# Cell F — Store Config Fingerprint & Generate Run Summary
from memory_retrieval.experiments.comparison import (
    build_config_fingerprint,
    generate_run_summary,
)
from memory_retrieval.infra.runs import update_config_fingerprint

def _format_prf(metrics):
    """Format F1 / precision / recall from a metrics dict; missing keys print as 0."""
    return (
        f"F1={metrics.get('f1', 0):.3f} "
        f"P={metrics.get('precision', 0):.3f} R={metrics.get('recall', 0):.3f}"
    )


# Build fingerprint from this run's configuration
# NOTE(review): the embedding and reranker model names are literals here —
# confirm they match the models the pipeline cells actually used.
fingerprint = build_config_fingerprint(
    extraction_prompt_version=PROMPT_VERSION,
    embedding_model="mxbai-embed-large",
    search_backend="vector",
    search_limit=SEARCH_LIMIT,
    distance_threshold=DISTANCE_THRESHOLD,
    query_model=MODEL_EXPERIMENT,
    query_prompt_version=PROMPT_VERSION,
    reranker_model="BAAI/bge-reranker-v2-m3",
    rerank_text_strategies=list(RERANK_TEXT_STRATEGIES.keys()),
)

# Store in run.json, then echo every field except the hash itself
update_config_fingerprint(RUN_DIR, fingerprint)
print(f"Config fingerprint stored: {fingerprint['fingerprint_hash']}")
for key, value in fingerprint.items():
    if key == "fingerprint_hash":
        continue
    print(f"  {key}: {value}")

# Generate run summary
summary = generate_run_summary(RUN_DIR, strategies=list(RERANK_TEXT_STRATEGIES.keys()))
macro = summary.get("macro_averaged", {})
print(f"\nRun summary generated: {RUN_DIR / 'run_summary.json'}")
print(f"  Test cases: {summary['num_test_cases']}")

# Pre-rerank: overfetched (raw) metrics
pre_rerank = macro.get("pre_rerank", {})
overfetched = pre_rerank.get("overfetched", {})
print(f"\n  Pre-rerank (overfetched, raw): {_format_prf(overfetched)}")

# Pre-rerank: at optimal distance threshold (printed only when present)
at_optimal_distance = pre_rerank.get("at_optimal_distance_threshold", {})
if at_optimal_distance:
    print(
        f"  Pre-rerank (optimal distance threshold): {_format_prf(at_optimal_distance)} "
        f"@ threshold={at_optimal_distance.get('optimal_threshold', 0):.4f}"
    )

# Pre-rerank: at optimal top-N (printed only when present)
at_optimal_top_n = pre_rerank.get("at_optimal_top_n", {})
if at_optimal_top_n:
    print(
        f"  Pre-rerank (optimal top-N): {_format_prf(at_optimal_top_n)} "
        f"@ N={at_optimal_top_n.get('optimal_n', 0)}"
    )

# Post-rerank per strategy
post_rerank = macro.get("post_rerank", {})
for strategy_name, strategy_data in post_rerank.items():
    at_threshold = strategy_data.get("at_optimal_threshold", {})
    at_top_n = strategy_data.get("at_optimal_top_n", {})
    print(f"\n  Post-rerank ({strategy_name}):")
    if at_threshold:
        print(
            f"    At optimal threshold: {_format_prf(at_threshold)} "
            f"@ threshold={at_threshold.get('optimal_threshold', 0):.4f}"
        )
    if at_top_n:
        print(
            f"    At optimal top-N:     {_format_prf(at_top_n)} "
            f"@ N={at_top_n.get('optimal_n', 0)}"
        )