In [130]:
# ===== CONFIG =====
# DOMAIN is interpolated into every per-domain artifact path built in the next cell
# (embeddings, id array, spans CSV, FAISS index, reranker folder).
DOMAIN = "money-heist"     # change only this

# TOP_K is the default `top_k` passed to the FAISS search helper;
# TOP_FINAL is the default number of rows printed by the result-display helper.
TOP_K = 20      # candidates returned from FAISS
TOP_FINAL = 5   # after reranker (top-N to show)
# ==================

from pathlib import Path
import json
import textwrap
import random
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

pd.set_option("display.max_colwidth", 200)

import faiss

try:
    from sentence_transformers import SentenceTransformer, CrossEncoder
except ImportError:
    SentenceTransformer = None
    CrossEncoder = None
    print("⚠️ sentence-transformers not installed. Free-form queries or reranker may fail.")

In [131]:
# The repository root is taken to be the parent of the current working directory.
REPO_ROOT = Path("..").resolve()

# Top-level folders.
DATA_DIR = REPO_ROOT / "data"
MODEL_DIR = REPO_ROOT / "models"

# Data sub-folders.
EMB_DIR = DATA_DIR / "embeddings"
PROC_DIR = DATA_DIR / "processed"
INDEX_DIR = DATA_DIR / "indexes"

# Per-domain artifacts.
EMB_PATH = EMB_DIR / f"spans_{DOMAIN}.npy"
ID_PATH = EMB_DIR / f"spans_{DOMAIN}.index_ids.npy"
SPANS_PATH = PROC_DIR / f"spans_{DOMAIN}.csv"
FAISS_PATH = INDEX_DIR / f"faiss_flat_{DOMAIN}.index"   # adjust name if needed
RERANKER_ROOT = MODEL_DIR / "reranker" / DOMAIN   # e.g. models/reranker/money-heist

# Report each location together with whether it exists on disk.
print(
    "\n".join(
        f"{label} {artifact} {artifact.exists()}"
        for label, artifact in (
            ("Embeddings:", EMB_PATH),
            ("IDs:", ID_PATH),
            ("Spans CSV:", SPANS_PATH),
            ("FAISS:", FAISS_PATH),
            ("RERANKER_ROOT:", RERANKER_ROOT),
        )
    )
)

Embeddings: /data/sundeep/Fandom_SI/data/embeddings/spans_money-heist.npy True
IDs: /data/sundeep/Fandom_SI/data/embeddings/spans_money-heist.index_ids.npy True
Spans CSV: /data/sundeep/Fandom_SI/data/processed/spans_money-heist.csv True
FAISS: /data/sundeep/Fandom_SI/data/indexes/faiss_flat_money-heist.index True
RERANKER_ROOT: /data/sundeep/Fandom_SI/models/reranker/money-heist True


In [132]:
def find_reranker_dir(root: Path) -> Path:
    """
    Locate a model directory (one that contains a config.json) directly under `root`.

    Selection order:
      1. '<root>/best' if it contains config.json.
      2. Otherwise the first qualifying subdirectory in sorted path order, so the
         choice is the same on every run (Path.iterdir() yields entries in
         arbitrary order).

    Args:
        root: Directory whose immediate subdirectories are inspected.

    Returns:
        Path of the selected model directory.

    Raises:
        ValueError: if `root` does not exist, or no subdirectory has a config.json.
    """
    if not root.exists():
        raise ValueError(f"RERANKER_ROOT does not exist: {root}")

    # 1) Prefer '<root>/best'
    best_dir = root / "best"
    if (best_dir / "config.json").exists():
        print("Using reranker dir:", best_dir)
        return best_dir

    # 2) Otherwise collect every subdir holding a config.json, in a stable order
    candidates = sorted(
        sub for sub in root.iterdir()
        if sub.is_dir() and (sub / "config.json").exists()
    )

    if not candidates:
        raise ValueError(f"No HF model (config.json) found under {root}")

    chosen = candidates[0]
    print("Using reranker dir:", chosen)
    return chosen


RERANKER_DIR = find_reranker_dir(RERANKER_ROOT)

# rerank_results() calls a module-level `reranker`, but no visible cell creates it,
# so a fresh kernel (Restart & Run All) would fail with NameError. Load it here.
# NOTE(review): constructor arguments (e.g. max_length, device) are assumed to be
# the defaults — confirm against how the reranker was trained/saved.
if CrossEncoder is not None:
    reranker = CrossEncoder(str(RERANKER_DIR))
    print("Reranker loaded.")
else:
    # Keep the best-effort behaviour of the import cell: warn instead of crashing here.
    reranker = None
    print("⚠️ CrossEncoder unavailable; rerank_results() will fail until sentence-transformers is installed.")

Using reranker dir: /data/sundeep/Fandom_SI/models/reranker/money-heist/best


In [133]:
# embeddings + ids
embeddings = np.load(EMB_PATH)
# NOTE(review): allow_pickle=True makes np.load unpickle object arrays, which can run
# arbitrary code. Only point ID_PATH at files produced by this project's own pipeline.
index_ids  = np.load(ID_PATH, allow_pickle=True).astype(str)

print("Embeddings shape:", embeddings.shape)
print("Index_ids shape :", index_ids.shape)

# The id array is presumably row-aligned with the embedding matrix (row i <-> index_ids[i]);
# a length mismatch would silently attach the wrong span to a vector, so fail early.
assert embeddings.shape[0] == index_ids.shape[0], (
    f"Embeddings rows ({embeddings.shape[0]}) != number of index ids ({index_ids.shape[0]})."
)

# spans
df_spans = pd.read_csv(SPANS_PATH)
df_spans["span_id"] = df_spans["span_id"].astype(str)

print("Spans rows:", len(df_spans))
print("Spans columns:", df_spans.columns.tolist())

assert "text" in df_spans.columns, "Expected span text column 'text' in spans CSV."

# span_id -> positional row number in df_spans (a repeated span_id keeps its last row)
spanid_to_row = {sid: i for i, sid in enumerate(df_spans["span_id"].tolist())}

# Ids that have a vector but no row in the CSV; reported, not fatal.
missing = [sid for sid in index_ids if sid not in spanid_to_row]
print("IDs in index_ids but not in spans.csv:", len(missing))
print("Example missing:", missing[:5])

Embeddings shape: (3189, 384)
Index_ids shape : (3189,)
Spans rows: 3189
Spans columns: ['span_id', 'article_id', 'page_name', 'title', 'section', 'span_index', 'start_char', 'end_char', 'len_chars', 'num_sents', 'text', 'url', 'source_path']
IDs in index_ids but not in spans.csv: 0
Example missing: []


In [134]:
MODEL_INFO_PATH = EMB_DIR / f"model_info_{DOMAIN}.json"

# Start from the fallback model; a model_info JSON next to the embeddings may override it
# through its "model_name" key.
encoder_name = "sentence-transformers/all-MiniLM-L6-v2"
if MODEL_INFO_PATH.exists():
    with MODEL_INFO_PATH.open("r", encoding="utf-8") as info_file:
        model_info = json.load(info_file)
    encoder_name = model_info.get("model_name", encoder_name)

print("Query encoder model:", encoder_name)
encoder = SentenceTransformer(encoder_name)
print("Encoder loaded.")

Query encoder model: sentence-transformers/all-MiniLM-L6-v2
Encoder loaded.


In [135]:
index = faiss.read_index(str(FAISS_PATH))
print("FAISS index:", FAISS_PATH)
print("Index dimension:", index.d)

# The search helper turns a FAISS row position into a span id via index_ids[pos];
# if the index and the id array disagree in size that lookup is wrong or out of range.
assert index.ntotal == len(index_ids), (
    f"FAISS index holds {index.ntotal} vectors but index_ids has {len(index_ids)} entries."
)

# make sure embeddings are float32 for FAISS
if embeddings.dtype != np.float32:
    embeddings = embeddings.astype("float32")

FAISS index: /data/sundeep/Fandom_SI/data/indexes/faiss_flat_money-heist.index
Index dimension: 384


In [136]:
def get_span_text(span_id: str):
    """
    Return the 'text' cell of the span with the given id.

    Looks the id up in `spanid_to_row` to get a positional row number in `df_spans`.

    Returns:
        The text as a str, or None when the id is unknown or the text cell is
        missing (NaN/None). Returning None for a missing cell — instead of the
        literal string "nan" produced by str(NaN) — lets callers' `text or ""`
        fallbacks work.
    """
    row_pos = spanid_to_row.get(span_id)
    if row_pos is None:
        return None

    # Single-cell positional access; avoids materialising the whole row as a Series.
    value = df_spans["text"].iat[row_pos]
    if pd.isna(value):
        return None
    return str(value)


def encode_query_text(text: str) -> np.ndarray:
    """
    Encode one query string with the global `encoder` and return the result as float32.

    The text is passed to the encoder as a one-element list, and the encoder output
    is returned unchanged unless its dtype is not float32, in which case it is cast.
    """
    embedded = encoder.encode([text], convert_to_numpy=True)
    return embedded if embedded.dtype == np.float32 else embedded.astype("float32")


def faiss_search(query_vec: np.ndarray, top_k: int = TOP_K):
    """
    Search the global FAISS `index` with one query vector and describe the hits.

    A 1-D vector is promoted to a single-row batch and non-float32 input is cast.
    Only the first row of the search output is used. Entries whose returned
    position is negative are skipped.

    Returns:
        A list of dicts with keys "rank" (position in the raw FAISS output),
        "faiss_score" (the value FAISS returned for that hit), "array_idx",
        "span_id" (looked up in `index_ids`) and "text" (from get_span_text,
        may be None).
    """
    batch = query_vec[None, :] if query_vec.ndim == 1 else query_vec
    if batch.dtype != np.float32:
        batch = batch.astype("float32")

    all_scores, all_positions = index.search(batch, top_k)
    row_scores, row_positions = all_scores[0], all_positions[0]

    hits = []
    for rank, (score, pos) in enumerate(zip(row_scores, row_positions)):
        if pos >= 0:
            span_id = index_ids[pos]
            hits.append(
                {
                    "rank": rank,
                    "faiss_score": float(score),
                    "array_idx": int(pos),
                    "span_id": span_id,
                    "text": get_span_text(span_id),
                }
            )
    return hits


def rerank_results(query_text: str, candidates):
    """
    Score (query, candidate text) pairs with the global `reranker` and sort by score.

    Args:
        query_text: The query string paired with every candidate.
        candidates: List of dicts carrying a "text" entry (None is treated as "").

    Returns:
        A new list of shallow-copied candidate dicts, each with an added
        "rerank_score" float, sorted by that score in descending order.
        The input dicts are left untouched so the pre-rerank list stays as it was.
        An empty input yields an empty list without calling the reranker.
    """
    if len(candidates) == 0:
        return []

    # candidate["text"] already has span text
    pairs = [(query_text, c["text"] or "") for c in candidates]
    scores = reranker.predict(pairs)

    # Copy each dict instead of writing "rerank_score" into the caller's objects.
    scored = [{**c, "rerank_score": float(s)} for c, s in zip(candidates, scores)]

    # sort by reranker score descending (higher = more relevant)
    return sorted(scored, key=lambda item: item["rerank_score"], reverse=True)


def pretty_block(title: str):
    """Print `title` centred between two 30-character '=' bars, with a blank line above and below."""
    bar = "=" * 30
    print(f"\n{bar} {title} {bar}\n")


def show_results(results, score_key: str, top_n: int = TOP_FINAL):
    """
    Print the first `top_n` entries of `results`.

    Each entry gets a header line with its stored "rank", its "span_id" and the
    value under `score_key` (4 decimals), followed by its text with newlines
    replaced by spaces and wrapped at 110 columns, then a blank line.
    A text of None is printed as an empty string.
    """
    for entry in results[:top_n]:
        header = f"[rank {entry['rank']}] span_id={entry['span_id']} {score_key}={entry[score_key]:.4f}"
        print(header)
        flattened = (entry["text"] or "").replace("\n", " ")
        print(textwrap.fill(flattened, width=110))
        print()

In [137]:
def compare(query_text: str):
    """
    Print a side-by-side view of retrieval for one query: the query itself, the
    top TOP_FINAL FAISS hits (out of TOP_K retrieved), and the top TOP_FINAL of the
    same candidates after they have been re-scored by rerank_results().
    """
    pretty_block("QUERY")
    print(query_text)

    # Stage 1: embed the query and pull TOP_K candidates from FAISS.
    candidates = faiss_search(encode_query_text(query_text), top_k=TOP_K)

    pretty_block("BEFORE RERANKER (FAISS top-k)")
    show_results(candidates, score_key="faiss_score", top_n=TOP_FINAL)

    # Stage 2: re-score exactly those candidates.
    by_reranker = rerank_results(query_text, candidates)

    pretty_block("AFTER RERANKER (CrossEncoder top-k)")
    show_results(by_reranker, score_key="rerank_score", top_n=TOP_FINAL)

In [138]:
# Qualitative spot check: print the FAISS hits and the reranked hits for one free-form question.
compare("What was the Professor's main plan in the first heist?")



What was the Professor's main plan in the first heist?


[rank 0] span_id=money-heist_span_0002949 faiss_score=0.6567
The Professor teaches the robbers everything they need to know about the heist in that room.

[rank 1] span_id=money-heist_span_0002904 faiss_score=0.6567
The Professor teaches the robbers everything they need to know about the heist in that room.

[rank 2] span_id=money-heist_span_0002850 faiss_score=0.6516
The Professor led a second heist in the Bank of Spain , in an effort to pressure the Spanish government who
had arrested Rio . This was a heist that he and his half-brother Berlin had planned years ago.

[rank 3] span_id=money-heist_span_0002507 faiss_score=0.6516
The Professor led a second heist in the Bank of Spain , in an effort to pressure the Spanish government who
had arrested Rio . This was a heist that he and his half-brother Berlin had planned years ago.

[rank 4] span_id=money-heist_span_0002478 faiss_score=0.6516
The Professor led a second heist in the 

In [139]:
def ndcg_at_k_for_ranking(ranked_ids, gold_ids, k=10):
    """
    NDCG@k with binary relevance (an id is relevant iff it is in gold_ids).

    DCG sums 1 / log2(position + 1) over the relevant ids in the first k ranks
    (positions are 1-based). It is normalised by the ideal DCG, i.e. the DCG of a
    ranking that places min(len(set(gold_ids)), k) relevant ids at the top, so the
    result stays within [0, 1] even when several gold ids are retrieved.
    With a single gold id the ideal DCG is 1, so the value equals the plain DCG.

    Args:
        ranked_ids: Ids in ranked order, best first.
        gold_ids: Iterable of relevant ids.
        k: Cut-off rank.

    Returns:
        NDCG@k as a float; 0.0 when no relevant id appears in the top k.
    """
    gold_set = set(gold_ids)
    topk = ranked_ids[:k]
    rels = np.array([1.0 if sid in gold_set else 0.0 for sid in topk], dtype=np.float64)

    if rels.sum() == 0:
        return 0.0

    # DCG over the retrieved list
    discounts = 1.0 / np.log2(np.arange(len(rels), dtype=np.float64) + 2.0)
    dcg = np.sum(rels * discounts)

    # IDCG: every gold id that could fit in the top k, packed at the top.
    # max(1, ...) keeps the divisor at the single-relevant value for degenerate k.
    n_ideal = max(1, min(len(gold_set), k))
    idcg = np.sum(1.0 / np.log2(np.arange(n_ideal, dtype=np.float64) + 2.0))

    return float(dcg / idcg)


def evaluate_reranker_on_dev(max_queries=200, ks=(1, 5, 10)):
    """
    Compare FAISS-only ranking with FAISS + reranker ranking on the dev queries.

    Reads the module-level `eval_queries` (first `max_queries` entries) and
    `gold_by_qid`; queries without an entry in `gold_by_qid` are skipped. For each
    remaining query the TOP_K FAISS candidates are ranked both ways and scored with
    metrics_for_ranking (reciprocal rank + per-k hits) and NDCG@10.

    Prints the number of evaluated queries, FAISS coverage (share of queries with a
    gold id somewhere in the TOP_K candidates) and, when at least one query is
    covered, the metrics averaged over covered queries only.

    Returns:
        A DataFrame with columns "metric", "faiss", "rerank" holding the global
        rows followed by the conditional rows (if any), or None when no query had
        gold labels.
    """
    systems = ("faiss", "rerank")

    def new_totals():
        # One running-sum record per system.
        return {
            name: {"rr": 0.0, "ndcg10": 0.0, "hits": {k: 0.0 for k in ks}}
            for name in systems
        }

    def accumulate(totals, per_query):
        # Add one query's (rr, hits, ndcg10) triple per system into the running sums.
        for name in systems:
            rr, hits, ndcg10 = per_query[name]
            totals[name]["rr"] += rr
            totals[name]["ndcg10"] += ndcg10
            for k in ks:
                totals[name]["hits"][k] += hits[k]

    def averaged(totals, n):
        # Turn running sums into means over n queries.
        return {
            name: {
                "mrr": totals[name]["rr"] / n,
                "recall": {k: totals[name]["hits"][k] / n for k in ks},
                "ndcg10": totals[name]["ndcg10"] / n,
            }
            for name in systems
        }

    def metric_rows(means, scope):
        # Table rows in display order: MRR, Recall@k for each k, NDCG@10.
        rows_out = [{
            "metric": f"MRR@TOP_K ({scope})",
            "faiss": means["faiss"]["mrr"],
            "rerank": means["rerank"]["mrr"],
        }]
        for k in ks:
            rows_out.append({
                "metric": f"Recall@{k} ({scope})",
                "faiss": means["faiss"]["recall"][k],
                "rerank": means["rerank"]["recall"][k],
            })
        rows_out.append({
            "metric": f"NDCG@10 ({scope})",
            "faiss": means["faiss"]["ndcg10"],
            "rerank": means["rerank"]["ndcg10"],
        })
        return rows_out

    overall = new_totals()
    # Sums restricted to queries where FAISS retrieved a gold id somewhere in TOP_K.
    covered_totals = new_totals()
    n_used = 0
    n_covered = 0

    for query in eval_queries[:max_queries]:
        qid = query["query_id"]
        text = query["text"]

        if qid not in gold_by_qid:
            continue

        gold_ids = gold_by_qid[qid]
        gold_set = set(gold_ids)

        # FAISS-only ranking
        faiss_results = faiss_search(encode_query_text(text), top_k=TOP_K)
        faiss_ranked_ids = [hit["span_id"] for hit in faiss_results]

        # FAISS + reranker ranking over the same candidates
        rerank_ranked_ids = [hit["span_id"] for hit in rerank_results(text, faiss_results)]

        faiss_rr, faiss_hits = metrics_for_ranking(faiss_ranked_ids, gold_ids, ks)
        rerank_rr, rerank_hits = metrics_for_ranking(rerank_ranked_ids, gold_ids, ks)
        faiss_ndcg10 = ndcg_at_k_for_ranking(faiss_ranked_ids, gold_ids, k=10)
        rerank_ndcg10 = ndcg_at_k_for_ranking(rerank_ranked_ids, gold_ids, k=10)

        per_query = {
            "faiss": (faiss_rr, faiss_hits, faiss_ndcg10),
            "rerank": (rerank_rr, rerank_hits, rerank_ndcg10),
        }
        accumulate(overall, per_query)

        if not gold_set.isdisjoint(faiss_ranked_ids):
            n_covered += 1
            accumulate(covered_totals, per_query)

        n_used += 1

    if n_used == 0:
        print("No queries with gold labels found.")
        return None

    # ---------- Global metrics ----------
    rows = metric_rows(averaged(overall, n_used), "global")

    # ---------- Coverage & conditional metrics ----------
    coverage = n_covered / n_used
    print(f"Evaluated on {n_used} queries.")
    print(f"FAISS coverage@TOP_K (gold present in top-{TOP_K}): {coverage:.4f}")

    if n_covered > 0:
        cond = averaged(covered_totals, n_covered)

        print("\nConditional metrics (only queries where FAISS retrieved the gold somewhere in TOP_K):")
        print(f"  FAISS  MRR@TOP_K: {cond['faiss']['mrr']:.4f}")
        print(f"  RERANK MRR@TOP_K: {cond['rerank']['mrr']:.4f}")
        for k in ks:
            print(f"  Recall@{k}: FAISS={cond['faiss']['recall'][k]:.4f}, RERANK={cond['rerank']['recall'][k]:.4f}")
        print(f"  NDCG@10: FAISS={cond['faiss']['ndcg10']:.4f}, RERANK={cond['rerank']['ndcg10']:.4f}")

        rows.extend(metric_rows(cond, "conditional"))

    return pd.DataFrame(rows)

In [140]:
# Run the evaluation over at most the first 200 dev queries; the function prints coverage and
# conditional metrics and returns the table (or None when no query had gold labels).
df_metrics = evaluate_reranker_on_dev(max_queries=200, ks=(1, 5, 10))
# Bare last expression so the notebook renders the metrics table.
df_metrics

Evaluated on 159 queries.
FAISS coverage@TOP_K (gold present in top-20): 0.3082

Conditional metrics (only queries where FAISS retrieved the gold somewhere in TOP_K):
  FAISS  MRR@TOP_K: 0.4074
  RERANK MRR@TOP_K: 0.5683
  Recall@1: FAISS=0.2449, RERANK=0.4082
  Recall@5: FAISS=0.5918, RERANK=0.7755
  Recall@10: FAISS=0.8163, RERANK=0.9184
  NDCG@10: FAISS=0.4935, RERANK=0.6474


Unnamed: 0,metric,faiss,rerank
0,MRR@TOP_K (global),0.125554,0.17514
1,Recall@1 (global),0.075472,0.125786
2,Recall@5 (global),0.18239,0.238994
3,Recall@10 (global),0.251572,0.283019
4,NDCG@10 (global),0.152096,0.199502
5,MRR@TOP_K (conditional),0.407408,0.568312
6,Recall@1 (conditional),0.244898,0.408163
7,Recall@5 (conditional),0.591837,0.77551
8,Recall@10 (conditional),0.816327,0.918367
9,NDCG@10 (conditional),0.493537,0.647365
