# Phase 1 Reranking Comparison

Applies cross-encoder reranking on top of **Phase 1 data** (memories, database, test cases).
Uses Phase 1 prompts for query generation to ensure a clean apples-to-apples comparison.

This notebook does **not** extract its own memories — it reuses the Phase 1 pipeline output
and only adds the reranking step to measure its isolated impact on retrieval quality.

In [None]:
# Cell 1 — Setup & Imports
import os
from pathlib import Path

from memory_retrieval.search.reranker import Reranker
from memory_retrieval.search.vector import VectorBackend
from memory_retrieval.experiments.runner import run_experiment, run_all_experiments, ExperimentConfig
from memory_retrieval.memories.schema import FIELD_SITUATION, FIELD_DISTANCE, FIELD_RERANK_SCORE
from memory_retrieval.infra.io import load_json
from memory_retrieval.infra.runs import (
    create_run, get_latest_run, get_run, list_runs, update_run_status,
    PHASE1, PHASE2,
)

# Resolve the project root: the nearest directory, starting at the current one and
# walking up toward the filesystem root, that contains pyproject.toml.
for PROJECT_ROOT in (Path.cwd(), *Path.cwd().parents):
    if (PROJECT_ROOT / "pyproject.toml").exists():
        break
else:
    raise RuntimeError("Could not find project root (pyproject.toml)")
# Work from the project root so repo-relative paths used below resolve.
os.chdir(PROJECT_ROOT)

# Experiments need OPENROUTER_API_KEY; warn early (without stopping) if it is missing.
print(
    "OPENROUTER_API_KEY is set."
    if os.environ.get("OPENROUTER_API_KEY")
    else "WARNING: OPENROUTER_API_KEY is not set. Experiments will fail."
)

print(f"Project root: {PROJECT_ROOT}")
print("Imports OK.")

In [None]:
# Cell 2 — Configuration
# Every tunable setting for this comparison lives here. The Phase 1 run supplies
# both the memory database and the test cases; nothing is re-extracted.

# Query generation: Phase 1 prompts at this version, run with this model
PROMPT_VERSION = "2.0.0"
MODEL_EXPERIMENT = "anthropic/claude-sonnet-4.5"

# Reranking
RERANK_TOP_N = 4          # results kept after reranking
SEARCH_LIMIT = 20         # vector-search candidates fetched per query
DISTANCE_THRESHOLD = 1.1  # used for the pre-rerank metrics comparison

# Phase 1 run selection (source of DB and test cases)
#   List available runs:  print(list_runs(PHASE1))
#   Pin a specific run:   PHASE1_RUN = get_run(PHASE1, "run_20260208_143022")
PHASE1_RUN = get_latest_run(PHASE1)

# Locations inside the selected Phase 1 run
DB_PATH = str(PHASE1_RUN / "memories" / "memories.db")
TEST_CASES_DIR = str(PHASE1_RUN / "test_cases")

vector_backend = VectorBackend()

print("Configuration:")
for label, value in (
    ("Phase 1 run", PHASE1_RUN.name),
    ("Prompt version", PROMPT_VERSION),
    ("Model (experiment)", MODEL_EXPERIMENT),
    ("Rerank top-n", RERANK_TOP_N),
    ("Search limit", SEARCH_LIMIT),
    ("Distance threshold", DISTANCE_THRESHOLD),
    ("DB path", DB_PATH),
    ("Test cases dir", TEST_CASES_DIR),
):
    print(f"  {label}: {value}")
# Queried last, so any DB problem surfaces after the rest of the config is shown.
print(f"  Memory count: {vector_backend.get_memory_count(DB_PATH)}")

## Step 1 — Test Reranker

Verify the cross-encoder reranker works correctly before running full experiments.
Loads `bge-reranker-v2-m3` and tests it on a sample query.

In [None]:
# Cell 4 — Test Reranker on Sample Query
# Smoke test: fetch SEARCH_LIMIT vector-search candidates for one fixed query and
# show how the cross-encoder reorders them. `reranker` is reused by the experiment cell.
reranker = Reranker()

sample_query = "error handling in async functions"
print(f"Sample query: \"{sample_query}\"\n")

results = vector_backend.search(DB_PATH, sample_query, limit=SEARCH_LIMIT)
print(f"Vector search returned {len(results)} candidates\n")

# The reranker takes plain dicts, so flatten each search hit; the hit's raw_score
# is stored under FIELD_DISTANCE.
candidates = []
for hit in results:
    candidates.append({
        "id": hit.id,
        FIELD_SITUATION: hit.situation,
        "lesson": hit.lesson,
        FIELD_DISTANCE: hit.raw_score,
    })

# Baseline ordering: the first five candidates in vector-search order.
print("--- Top 5 by Vector Distance ---")
for rank, cand in enumerate(candidates[:5], start=1):
    snippet = cand[FIELD_SITUATION][:80]
    print(f"  [{rank}] dist={cand[FIELD_DISTANCE]:.4f} | {cand['id']} | {snippet}...")

reranked = reranker.rerank(sample_query, candidates, top_n=RERANK_TOP_N)

print(f"\n--- Top {RERANK_TOP_N} after Reranking ---")
for rank, item in enumerate(reranked, start=1):
    header = f"  [{rank}] rerank={item[FIELD_RERANK_SCORE]:.4f} | dist={item[FIELD_DISTANCE]:.4f} | {item['id']}"
    # header line, indented situation snippet, then a blank separator line
    print(f"{header}\n      {item[FIELD_SITUATION][:100]}\n")

## Step 2 — Run All Experiments

For each test case:
1. Generate search queries from the PR context via the LLM
2. Run a vector search for each query (top-`SEARCH_LIMIT` candidates, 20 by default)
3. Pool and deduplicate the results
4. Rerank the pooled candidates with the cross-encoder
5. Keep the top-`RERANK_TOP_N` results
6. Compute metrics before and after reranking

In [None]:
# Cell 6 — Run All Experiments
# Creates a new Phase 2 run and records the full experiment configuration in its
# metadata. Runs every Phase 1 test case through query generation, vector search
# and reranking, then stores a summary of the outcome on the run.

# Query-generation prompts come from Phase 1 so the comparison stays apples-to-apples.
PROMPTS_DIR = "data/prompts/phase1"

# Create Phase 2 run for results
run_id, PHASE2_RUN = create_run(
    PHASE2,
    description=f"Reranking experiment (phase1: {PHASE1_RUN.name}, top-{RERANK_TOP_N})",
)
RESULTS_DIR = str(PHASE2_RUN / "results")

# Store every setting that affects the results, so the run can be reproduced
# from its metadata alone.
update_run_status(PHASE2_RUN, "config", {
    "phase1_run_id": PHASE1_RUN.name,
    "reranker_model": reranker.model_name,
    "rerank_top_n": RERANK_TOP_N,
    "search_limit": SEARCH_LIMIT,
    "distance_threshold": DISTANCE_THRESHOLD,
    "model": MODEL_EXPERIMENT,
    "prompts_dir": PROMPTS_DIR,
    "prompt_version": PROMPT_VERSION,
})

print(f"Phase 2 run: {run_id}")
print(f"Results dir: {RESULTS_DIR}\n")

config = ExperimentConfig(
    search_backend=vector_backend,
    prompts_dir=PROMPTS_DIR,
    prompt_version=PROMPT_VERSION,
    model=MODEL_EXPERIMENT,
    search_limit=SEARCH_LIMIT,
    distance_threshold=DISTANCE_THRESHOLD,
    reranker=reranker,
    rerank_top_n=RERANK_TOP_N,
)
all_results = run_all_experiments(
    test_cases_dir=TEST_CASES_DIR,
    db_path=DB_PATH,
    results_dir=RESULTS_DIR,
    config=config,
)

# A result without post-rerank metrics counts as failed.
successful = [r for r in all_results if "post_rerank_metrics" in r]
failed_count = len(all_results) - len(successful)
avg_f1 = sum(r["post_rerank_metrics"]["f1"] for r in successful) / len(successful) if successful else 0
update_run_status(PHASE2_RUN, "experiment", {
    "count": len(successful),
    "failed": failed_count,
    "avg_f1_post_rerank": round(avg_f1, 4),
    "rerank_top_n": RERANK_TOP_N,
    "prompt_version": PROMPT_VERSION,
})
print(f"\nSucceeded: {len(successful)}  Failed: {failed_count}  Avg post-rerank F1: {avg_f1:.4f}")

## Step 3 — Results Summary Table

Per-test-case comparison of pre-rerank vs post-rerank metrics.

In [None]:
# Cell 8 — Results Summary Table
# One row per successful test case comparing pre- vs post-rerank F1, precision and
# recall. An AVERAGE row and a check against the F1 > 0.75 target follow.

successful = [r for r in all_results if "post_rerank_metrics" in r]


def _sign_marker(change):
    """Return ↑ / ↓ / = according to the sign of an F1 change."""
    if change > 0:
        return "↑"
    if change < 0:
        return "↓"
    return "="


def _summary_row(label, pre_f1, post_f1, pre_p, post_p, pre_r, post_r):
    """Format one table row; the Delta column is post_f1 - pre_f1."""
    change = post_f1 - pre_f1
    return (
        f"{label:<30} {pre_f1:>8.3f} {post_f1:>9.3f} {change:>+7.3f}{_sign_marker(change)} "
        f"{pre_p:>7.3f} {post_p:>8.3f} {pre_r:>7.3f} {post_r:>8.3f}"
    )


def _mean(stage, metric):
    """Plain arithmetic mean of successful[*][stage][metric]."""
    return sum(res[stage][metric] for res in successful) / len(successful)


print(f"{'Test Case':<30} {'Pre-F1':>8} {'Post-F1':>9} {'Delta':>8} {'Pre-P':>7} {'Post-P':>8} {'Pre-R':>7} {'Post-R':>8}")
print("-" * 95)

for res in successful:
    before = res["pre_rerank_metrics"]
    after = res["post_rerank_metrics"]
    print(_summary_row(
        res.get("test_case_id", "?")[:30],
        before["f1"], after["f1"],
        before["precision"], after["precision"],
        before["recall"], after["recall"],
    ))

if successful:
    avg_pre_f1 = _mean("pre_rerank_metrics", "f1")
    avg_post_f1 = _mean("post_rerank_metrics", "f1")
    avg_pre_p = _mean("pre_rerank_metrics", "precision")
    avg_post_p = _mean("post_rerank_metrics", "precision")
    avg_pre_r = _mean("pre_rerank_metrics", "recall")
    avg_post_r = _mean("post_rerank_metrics", "recall")
    print("-" * 95)
    print(_summary_row("AVERAGE", avg_pre_f1, avg_post_f1, avg_pre_p, avg_post_p, avg_pre_r, avg_post_r))
    print(f"\nTarget F1 > 0.75: {'✅ ACHIEVED' if avg_post_f1 > 0.75 else '❌ NOT YET'}")

## Step 4 — Rerank Top-N Sweep

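Analyze how different top-N cutoffs after reranking affect the metrics.
This reuses the saved experiment results and only varies the cutoff point. The sweep truncates the reranked list stored in each result file, so it cannot evaluate cutoffs deeper than that list.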

In [None]:
# Cell 10 — Rerank Top-N Sweep Analysis
# Re-scores every saved result at each cutoff N by truncating its stored reranked
# list, then prints and plots mean precision / recall / F1 against N.
import matplotlib.pyplot as plt
import numpy as np

FIGURES_DIR = Path("notebooks/phase2/figures")
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

SWEEP_MAX_N = 20                            # largest cutoff to sweep, if results are that deep
TABLE_N_VALUES = [1, 2, 3, 4, 5, 6, 8, 10]  # cutoffs shown in the printed table


def _prf_at_n(data, n):
    """Precision, recall and F1 for the first ``n`` reranked results of one result dict.

    Precision divides by the number of distinct ids actually kept, which is fewer
    than ``n`` when the saved list is shorter. Missing keys count as empty.
    """
    gt_ids = set(data.get("ground_truth", {}).get("memory_ids", []))
    top_n_ids = {item["id"] for item in data.get("reranked_results", [])[:n]}
    hits = len(top_n_ids & gt_ids)
    precision = hits / len(top_n_ids) if top_n_ids else 0.0
    recall = hits / len(gt_ids) if gt_ids else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1


# Load result files
results_path = Path(RESULTS_DIR)
result_files = sorted(results_path.glob("*.json"))

if not result_files:
    print("No result files found. Run experiments first.")
else:
    all_data = [load_json(str(f)) for f in result_files]
    n_cases = len(all_data)

    # Only the reranked list saved with each result is available. Cutoffs beyond the
    # longest saved list would just repeat the last point and flatten the curve.
    longest_saved = max(len(d.get("reranked_results", [])) for d in all_data)
    max_n = min(SWEEP_MAX_N, longest_saved)

    if max_n == 0:
        print("Result files contain no reranked results; nothing to sweep.")
    else:
        if max_n < SWEEP_MAX_N:
            print(f"Note: saved reranked lists hold at most {longest_saved} results, "
                  f"so the sweep stops at N={max_n}.\n")
        n_values = list(range(1, max_n + 1))

        # Average the per-case metrics at each cutoff.
        sweep_p, sweep_r, sweep_f1 = [], [], []
        for n in n_values:
            p_vals, r_vals, f1_vals = zip(*(_prf_at_n(data, n) for data in all_data))
            sweep_p.append(np.mean(p_vals))
            sweep_r.append(np.mean(r_vals))
            sweep_f1.append(np.mean(f1_vals))

        # Print table
        print(f"Rerank Top-N Sweep (averaged over {n_cases} test cases)\n")
        print(f"{'N':>4} {'Precision':>10} {'Recall':>8} {'F1':>8}")
        print("-" * 34)
        for n in TABLE_N_VALUES:
            if n <= max_n:
                i = n - 1
                print(f"{n:>4} {sweep_p[i]:>10.3f} {sweep_r[i]:>8.3f} {sweep_f1[i]:>8.3f}")

        # Plot
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(n_values, sweep_p, label="Precision", color="#3498db", linewidth=2, marker="o", markersize=4)
        ax.plot(n_values, sweep_r, label="Recall", color="#2ecc71", linewidth=2, marker="s", markersize=4)
        ax.plot(n_values, sweep_f1, label="F1", color="#9b59b6", linewidth=2, marker="^", markersize=4)
        ax.axvline(x=RERANK_TOP_N, color="gray", linestyle="--", alpha=0.5, label=f"Default N={RERANK_TOP_N}")
        ax.set_xlabel("Top-N (results kept after reranking)")
        ax.set_ylabel("Score")
        ax.set_title(f"Precision / Recall / F1 vs Rerank Top-N (avg over {n_cases} test cases)")
        ax.set_xticks(n_values)
        ax.set_ylim(0, 1.05)
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        fig.savefig(FIGURES_DIR / "rerank_topn_sweep.png", dpi=200, bbox_inches="tight")
        plt.show()
        print(f"Saved: {FIGURES_DIR / 'rerank_topn_sweep.png'}")

## Step 5 — Rerank Score Distribution

Analyze how rerank scores separate ground truth from non-ground-truth results.
This helps determine if the reranker provides a cleaner separation than vector distance.

In [None]:
# Cell 12 — Rerank Score Distribution
# Splits every saved reranked result into ground-truth vs non-ground-truth. It then
# compares their rerank scores (histogram) and rerank score vs vector distance (scatter).
import matplotlib.pyplot as plt
import numpy as np

HIST_BINS = 20  # number of shared histogram bins

if not result_files:
    print("No result files found.")
else:
    gt_scores, non_gt_scores = [], []
    gt_distances, non_gt_distances = [], []

    for data in all_data:
        for item in data.get("reranked_results", []):
            if item.get("is_ground_truth"):
                gt_scores.append(item["rerank_score"])
                gt_distances.append(item["distance"])
            else:
                non_gt_scores.append(item["rerank_score"])
                non_gt_distances.append(item["distance"])

    all_scores = gt_scores + non_gt_scores

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Rerank score distribution. Both groups share one set of bin edges. Separate
    # `bins=20` calls would bin each group over its own range, so the bars would
    # not line up and the visual overlap would be misleading.
    ax = axes[0]
    if all_scores:
        score_bins = np.histogram_bin_edges(all_scores, bins=HIST_BINS)
        if gt_scores:
            ax.hist(gt_scores, bins=score_bins, alpha=0.7, label=f"Ground Truth (n={len(gt_scores)})", color="#2ecc71")
        if non_gt_scores:
            ax.hist(non_gt_scores, bins=score_bins, alpha=0.7, label=f"Non-GT (n={len(non_gt_scores)})", color="#e74c3c")
        ax.legend()  # only when something was drawn, to avoid an empty-legend warning
    ax.set_xlabel("Rerank Score")
    ax.set_ylabel("Count")
    ax.set_title("Rerank Score Distribution")
    ax.grid(True, alpha=0.3)

    # Rerank score vs distance scatter
    ax = axes[1]
    if gt_scores:
        ax.scatter(gt_distances, gt_scores, alpha=0.7, label="Ground Truth", color="#2ecc71", s=60, zorder=3)
    if non_gt_scores:
        ax.scatter(non_gt_distances, non_gt_scores, alpha=0.5, label="Non-GT", color="#e74c3c", s=40, zorder=2)
    if all_scores:
        ax.legend()
    ax.set_xlabel("Vector Distance")
    ax.set_ylabel("Rerank Score")
    ax.set_title("Rerank Score vs Vector Distance")
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    fig.savefig(FIGURES_DIR / "rerank_score_distribution.png", dpi=200, bbox_inches="tight")
    plt.show()

    # Statistics
    if gt_scores:
        print(f"GT rerank scores:     min={min(gt_scores):.4f}, max={max(gt_scores):.4f}, mean={np.mean(gt_scores):.4f}")
    if non_gt_scores:
        print(f"Non-GT rerank scores: min={min(non_gt_scores):.4f}, max={max(non_gt_scores):.4f}, mean={np.mean(non_gt_scores):.4f}")
    if gt_scores and non_gt_scores:
        separation = np.mean(gt_scores) - np.mean(non_gt_scores)
        # NOTE(review): the 1.0 / 0.5 cut-offs assume unbounded (logit-like) rerank
        # scores; if the Reranker returns scores in [0, 1], "good" is unreachable. TODO confirm.
        print(f"Mean separation: {separation:.4f} ({'good' if separation > 1.0 else 'moderate' if separation > 0.5 else 'weak'})")

## Step 6 — Phase 1 vs Phase 2 Comparison

Side-by-side comparison of Phase 1 metrics with this run's post-rerank metrics, for the test cases present in both (skipped if no Phase 1 results are found).

In [None]:
# Cell 14 — Phase 1 vs Phase 2 Comparison
# Compares Phase 1 metrics (read from the Phase 1 run's results/ directory) with this
# run's post-rerank metrics for every test case present in both.
import numpy as np

F1_CHANGE_TOL = 0.01  # |ΔF1| at or below this is shown as "=" (no meaningful change)


def _change_marker(change, tol=F1_CHANGE_TOL):
    """Return ↑ / ↓ / = for an F1 change, treating |change| <= tol as no change."""
    if change > tol:
        return "↑"
    if change < -tol:
        return "↓"
    return "="


# Phase 2 (this run) results keyed by test case; the last file for a test case wins.
# Built before the Phase 1 check so it exists even when Phase 1 results are missing.
# `all_data` is only defined when result files were found, hence the guard.
p2_by_tc = {d.get("test_case_id", "?"): d for d in (all_data if result_files else [])}

# Load Phase 1 results
phase1_results_dir = PHASE1_RUN / "results"
phase1_result_files = sorted(phase1_results_dir.glob("*.json"))

if not phase1_result_files:
    print("No Phase 1 results found for comparison.")
else:
    phase1_data = [load_json(str(f)) for f in phase1_result_files]

    # Several Phase 1 experiments may exist per test case. Keep the one with the
    # greatest experiment_id string (presumably the newest, if ids are time-based).
    p1_by_tc = {}
    for d in phase1_data:
        tc_id = d.get("test_case_id", "?")
        if tc_id not in p1_by_tc or d.get("experiment_id", "") > p1_by_tc[tc_id].get("experiment_id", ""):
            p1_by_tc[tc_id] = d

    common_tcs = sorted(p1_by_tc.keys() & p2_by_tc.keys())

    print(f"Comparing {len(common_tcs)} test cases\n")
    print(f"{'Test Case':<25} {'P1 F1':>7} {'P2 F1':>7} {'Delta':>8} {'P1 Prec':>8} {'P2 Prec':>8} {'P1 Rec':>7} {'P2 Rec':>7}")
    print("-" * 85)

    p1_f1s, p2_f1s = [], []
    for tc_id in common_tcs:
        p1 = p1_by_tc[tc_id].get("metrics", {})
        p2 = p2_by_tc[tc_id].get("post_rerank_metrics", {})

        p1_f1 = p1.get("f1", 0)
        p2_f1 = p2.get("f1", 0)
        delta = p2_f1 - p1_f1

        p1_f1s.append(p1_f1)
        p2_f1s.append(p2_f1)

        print(f"{tc_id[:25]:<25} {p1_f1:>7.3f} {p2_f1:>7.3f} {delta:>+7.3f}{_change_marker(delta)} {p1.get('precision', 0):>8.3f} {p2.get('precision', 0):>8.3f} {p1.get('recall', 0):>7.3f} {p2.get('recall', 0):>7.3f}")

    if p1_f1s and p2_f1s:
        avg_p1 = np.mean(p1_f1s)
        avg_p2 = np.mean(p2_f1s)
        delta = avg_p2 - avg_p1
        print("-" * 85)
        # Same tolerance as the per-row markers, so "no change" shows "=" here too.
        print(f"{'AVERAGE':<25} {avg_p1:>7.3f} {avg_p2:>7.3f} {delta:>+7.3f}{_change_marker(delta)}")
        print(f"\nPhase 1 avg F1: {avg_p1:.3f}")
        print(f"Phase 2 avg F1: {avg_p2:.3f}")
        if avg_p1 > 0:
            print(f"Improvement: {delta:+.3f} ({delta/avg_p1*100:+.1f}%)")
        else:
            # A relative change from a zero baseline is undefined; report the absolute delta only.
            print(f"Improvement: {delta:+.3f} (relative change undefined: Phase 1 avg F1 is 0)")
        print(f"\nTarget F1 > 0.75: {'✅ ACHIEVED' if avg_p2 > 0.75 else '❌ NOT YET'}")

## Step 7 — Worst Cases Deep Dive

Focus on `tc_review_2` and `tc_review_5`, which were the most problematic test cases in Phase 1.

In [None]:
# Cell 16 — Worst Cases Analysis
# For each listed test case, prints pre/post-rerank metrics, the saved reranked
# results (ground truth marked) and any ground-truth ids that were missed.

WORST_CASES = ["tc_review_2", "tc_review_5"]

# Build the lookup here from the loaded result files, so this cell does not depend
# on state from the comparison cell. `all_data` only exists when result files were
# found, hence the guard. The last file for a test case wins.
worst_case_lookup = {d.get("test_case_id", "?"): d for d in (all_data if result_files else [])}

for tc_id in WORST_CASES:
    data = worst_case_lookup.get(tc_id)
    if not data:
        print(f"\n{tc_id}: not found in results")
        continue

    gt_ids = set(data.get("ground_truth", {}).get("memory_ids", []))
    pre = data.get("pre_rerank_metrics", {})
    post = data.get("post_rerank_metrics", {})

    print(f"\n{'='*60}")
    print(f"{tc_id} — Ground truth: {len(gt_ids)} memories")
    print(f"{'='*60}")
    print(f"Pre-rerank:  F1={pre.get('f1', 0):.3f}  P={pre.get('precision', 0):.3f}  R={pre.get('recall', 0):.3f}")
    print(f"Post-rerank: F1={post.get('f1', 0):.3f}  P={post.get('precision', 0):.3f}  R={post.get('recall', 0):.3f}")

    print(f"\nReranked results (top-{RERANK_TOP_N}):")
    for i, r in enumerate(data.get("reranked_results", []), 1):
        gt_marker = "✅ GT" if r.get("is_ground_truth") else "  --"
        print(f"  [{i}] {gt_marker} | rerank={r['rerank_score']:.4f} dist={r['distance']:.4f} | {r['id']}")
        print(f"       {r.get('situation', '')[:100]}")

    missed = data.get("missed_ground_truth_ids", [])
    if missed:
        print(f"\nMissed ground truth ({len(missed)}):")
        for mid in missed:
            print(f"  - {mid}")