In [None]:
# pip install -U sentence-transformers faiss-cpu datasets tqdm transformers accelerate pyarrow pandas

import os, time, math, sys
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# ----------------------------
# Config
# ----------------------------
# Input/output locations for this cell.
PARQUET_PATH = "exp1_rev.parquet"          # your corpus parquet
FAISS_INDEX_PATH = "exp1_300m_desc_index_g.faiss"        # your saved FAISS index
QUERIES_CSV = "final_benchmark.csv"        # csv with a 'query' or 'prompt' column
OUTPUT_CSV = "exp1_300m_rerank600m_desc_results.csv"  # results are written here, one row per query

MODEL_NAME = "google/embeddinggemma-300m"  # encoder model you used for the index
BATCH_SIZE = 128  # batch size passed to the query encoder
# NOTE(review): EMB_DIM and USE_FLOAT16_DISK are not referenced by any visible code.
EMB_DIM = 768
USE_FLOAT16_DISK = True

# HNSW params (info only; already baked into your saved index)
# NOTE(review): none of these three values is applied to `index` in the visible code.
HNSW_M = 32
HNSW_EF_CONSTRUCTION = 200
HNSW_EF_SEARCH = 128

K = 100              # retrieve top-K from FAISS
RERANK_TOP = 3       # keep only top-3 after reranking
RERANK_BS = 8       # reranker batch size
RERANKER_NAME = "Qwen/Qwen3-Reranker-0.6B"  # model id loaded as both tokenizer and causal LM

# If your parquet has a different text field name, list options here (first match wins)
DOC_FIELD_CANDIDATES = ["combined", "text", "document", "content", "passage", "body"]

# ----------------------------
# Utils
# ----------------------------
def pick_doc_field(hfds):
    """Return the first name in DOC_FIELD_CANDIDATES that is a column of `hfds`.

    Args:
        hfds: Object exposing a `column_names` iterable.

    Returns:
        The matching column name; candidate order decides between several hits.

    Raises:
        ValueError: If no candidate is present among the dataset's columns.
    """
    available = set(hfds.column_names)
    chosen = next((name for name in DOC_FIELD_CANDIDATES if name in available), None)
    if chosen is None:
        raise ValueError(f"None of {DOC_FIELD_CANDIDATES} found in dataset columns: {sorted(available)}")
    return chosen

def load_queries(path):
    """Read the query strings from a CSV file.

    Args:
        path: CSV location, passed straight to `pd.read_csv`.

    Returns:
        A `(texts, column)` tuple: the chosen column's values cast to `str`,
        and the name of that column. 'query' is preferred over 'prompt'.

    Raises:
        ValueError: If the CSV has neither a 'query' nor a 'prompt' column.
    """
    table = pd.read_csv(path)
    if "query" in table.columns:
        chosen = "query"
    elif "prompt" in table.columns:
        chosen = "prompt"
    else:
        raise ValueError(f"{path} must contain a 'query' or 'prompt' column.")
    # astype(str) stringifies every cell, so missing values become text too.
    return table[chosen].astype(str).tolist(), chosen

def ensure_float32(x):
    """Return `x` as float32, reusing the same array when it already is float32."""
    if x.dtype == np.float32:
        return x
    return x.astype("float32")

# ----------------------------
# 1) Load corpus + queries + index + encoder
# ----------------------------
print("[1/5] Loading corpus parquet…")
# The parquet file is read through the "train" split of the loaded dataset.
dataset = load_dataset("parquet", data_files=PARQUET_PATH)["train"]
DOC_FIELD = pick_doc_field(dataset)
print(f"     Using DOC_FIELD='{DOC_FIELD}'")

print("[2/5] Loading queries…")
# query_col records which CSV column the queries came from ('query' or 'prompt').
query_texts, query_col = load_queries(QUERIES_CSV)
print(f"     Num queries: {len(query_texts):,}")

print("[3/5] Loading FAISS index…")
# NOTE(review): ids returned by this index are used further down as row positions
# into `dataset`, but nothing here checks index.ntotal against len(dataset).
index = faiss.read_index(FAISS_INDEX_PATH)
print(f"     [OK] Loaded FAISS index with ntotal={index.ntotal:,}")

print("[4/5] Loading encoder model for queries…")
# NOTE(review): device_enc is assigned but not used by any visible code.
device_enc = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enc_model = SentenceTransformer(
    MODEL_NAME,
    trust_remote_code=True,
    # bfloat16 weights when CUDA is available, float32 otherwise.
    model_kwargs={"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32},
)
# Note: SentenceTransformer handles device internally; no .to(device) needed.

# ----------------------------
# 2) Encode queries and search FAISS
# ----------------------------
print("[5/5] Encoding & retrieving…")
# The timer starts before encoding, so the reported "retrieval" time covers
# both query encoding and the FAISS search.
t0 = time.time()
q_emb = enc_model.encode_query(
    query_texts,
    batch_size=BATCH_SIZE,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True,
)
# Cast to float32 before searching, whatever dtype the encoder returned.
q_emb = ensure_float32(np.array(q_emb))
# One row of K scores / K ids per query.
scores, ids = index.search(q_emb, k=K)
elapsed_all = time.time() - t0
# Amortised over the whole batch; no per-query measurement is taken here.
avg_latency = elapsed_all / max(1, len(query_texts))
print(f"     Retrieval time: {elapsed_all:.3f}s total | ~{avg_latency:.4f}s/query")
# ----------------------------
# 3) Fetch every retrieved document once
# ----------------------------
print("[*] Fetching top-K documents for all queries…")
# Negative ids are skipped (presumably FAISS's marker for "no neighbour" — confirm).
# Ids are de-duplicated across all queries so each document is read only once.
unique_ids = sorted(set(int(x) for row in ids for x in row if x >= 0))
# corpus id -> position inside `subset_texts`
id_to_pos = {cid: i for i, cid in enumerate(unique_ids)}
subset = dataset.select(unique_ids)  # keeps order of unique_ids
subset_texts = subset[DOC_FIELD]

def get_doc(doc_id: int) -> str:
    """Return the text of a retrieved document.

    Args:
        doc_id: Corpus id; must be a key of `id_to_pos`.

    Returns:
        The matching entry of `subset_texts`.

    Raises:
        KeyError: If `doc_id` is not a key of `id_to_pos`.
    """
    position = id_to_pos[int(doc_id)]
    return subset_texts[position]

# ----------------------------
# 4) Reranker (Qwen/Qwen3-Reranker-0.6B)
# ----------------------------
print("[*] Loading reranker…")
device_rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Left padding keeps each row's final real token at position -1, which is
# where _compute_yes_scores reads the logits.
ranker_tokenizer = AutoTokenizer.from_pretrained(RERANKER_NAME, padding_side='left')
# NOTE(review): no torch_dtype is passed, so the library default applies
# (presumably float32 — confirm memory headroom).
ranker_model = AutoModelForCausalLM.from_pretrained(
    RERANKER_NAME,
    # Uncomment if you have Flash-Attn 2 installed and want extra speed:
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
).to(device_rank).eval()

# Vocabulary ids of the two answer tokens whose logits are compared.
token_false_id = ranker_tokenizer.convert_tokens_to_ids("no")
token_true_id  = ranker_tokenizer.convert_tokens_to_ids("yes")
# Fail fast if either answer token resolved to the tokenizer's unknown-token id.
if token_false_id == ranker_tokenizer.unk_token_id or token_true_id == ranker_tokenizer.unk_token_id:
    raise RuntimeError("Reranker tokenizer is missing 'yes'/'no' tokens.")

# Upper bound on tokens per row, including the prefix and suffix below.
max_length = 8192
# Chat-template text placed before each formatted (instruction, query, doc) string.
prefix = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
    "Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
)
# Closes the user turn and opens the assistant turn with an empty <think> block.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Tokenised once here and reused for every row in _process_inputs.
prefix_tokens = ranker_tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = ranker_tokenizer.encode(suffix, add_special_tokens=False)

def format_instruction(instruction, query, doc):
    """Build the user-turn text scored by the reranker.

    Args:
        instruction: Task description; `None` selects a generic web-search default.
        query: Query text.
        doc: Candidate document text.

    Returns:
        Three newline-separated sections tagged <Instruct>, <Query> and <Document>.
    """
    effective = (
        'Given a web search query, retrieve relevant passages that answer the query'
        if instruction is None
        else instruction
    )
    sections = (f"<Instruct>: {effective}", f"<Query>: {query}", f"<Document>: {doc}")
    return "\n".join(sections)

def _process_inputs(pairs):
    """Tokenize formatted prompts, wrap them in the chat prefix/suffix, and pad.

    Args:
        pairs: Strings produced by `format_instruction`.

    Returns:
        The padded tokenizer output, as "pt" tensors moved to `device_rank`.
    """
    # Budget left for the prompt body once the fixed prefix/suffix are counted.
    body_budget = max_length - len(prefix_tokens) - len(suffix_tokens)
    inputs = ranker_tokenizer(
        pairs,
        padding=False,
        truncation='longest_first',
        # The mask must NOT be built here: input_ids grow by the prefix and
        # suffix below, so a mask computed now would be
        # len(prefix_tokens) + len(suffix_tokens) entries shorter than the ids.
        # The pad() call below builds the mask from the final ids instead.
        return_attention_mask=False,
        max_length=body_budget,
        add_special_tokens=True,
    )
    # wrap with prefix/suffix then pad (left-padding already set)
    for i, ids_list in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ids_list + suffix_tokens
    inputs = ranker_tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for k in inputs:
        inputs[k] = inputs[k].to(device_rank)
    return inputs

@torch.inference_mode()
def _compute_yes_scores(inputs):
    """Score a padded batch with the reranker.

    Args:
        inputs: Keyword arguments for `ranker_model`, as built by `_process_inputs`.

    Returns:
        One float per row: the "yes" share of a softmax taken over only the
        "no" and "yes" logits at the last sequence position.
    """
    final_step = ranker_model(**inputs).logits[:, -1, :]     # last position
    # Column 0 = "no", column 1 = "yes".
    answer_pair = torch.stack(
        [final_step[:, token_false_id], final_step[:, token_true_id]],
        dim=1,
    )
    log_probs = torch.nn.functional.log_softmax(answer_pair, dim=1)
    return log_probs[:, 1].exp().tolist()

def rerank_query(query: str, docs: list[str], task: str = None, top_n: int = 3) -> tuple[list[str], list[float], list[int]]:
    """Score every document against the query and keep the best `top_n`.

    Args:
        query: Query text.
        docs: Candidate document texts.
        task: Instruction forwarded to `format_instruction` (`None` = its default).
        top_n: Maximum number of results to return.

    Returns:
        `(docs, scores, indices)`, highest score first; `indices` are
        positions into the `docs` argument.
    """
    all_scores = []
    # Documents are scored RERANK_BS at a time.
    for lo in range(0, len(docs), RERANK_BS):
        prompts = [format_instruction(task, query, text) for text in docs[lo:lo + RERANK_BS]]
        all_scores.extend(_compute_yes_scores(_process_inputs(prompts)))
    # argsort is ascending, so reverse before taking the head.
    order = np.argsort(all_scores)[::-1][:top_n]
    picked = order.tolist()
    return [docs[i] for i in picked], [all_scores[i] for i in picked], picked

# ----------------------------
# 5) Build results with top-3 per query
# ----------------------------

print("[*] Reranking to top-3 per query…")
# NOTE(review): the instruction text reads "answer the all the query's requirements";
# it is sent to the reranker as-is. Fixing the wording would change model input.
task = "Given a user's query, retrieve relevant passages that answer the all the query's requirements"
rows = []

rerank_latencies = []   # per-query rerank latency (seconds)

for i, q in enumerate(tqdm(query_texts, desc="Reranking")):
    # gather FAISS top-K docs for query i
    q_ids = [int(d) for d in ids[i] if d >= 0]
    q_docs = [get_doc(d) for d in q_ids]

    # time the rerank for this query (document lookup above is excluded)
    t0 = time.perf_counter()
    top_docs, top_scores, order = rerank_query(q, q_docs, task=task, top_n=RERANK_TOP)
    dt = time.perf_counter() - t0
    rerank_latencies.append(dt)

    # map back to original corpus ids (`order` indexes into q_docs / q_ids)
    top_doc_ids = [q_ids[j] for j in order]

    # human-readable string
    # NOTE(review): the "3" is hard-coded; the list holds len(top_docs) items.
    rec_text = "I have 3 recommendations: " + " ".join([f"{j+1}. {top_docs[j]}" for j in range(len(top_docs))])

    rows.append({
        "prompt": q,
        "recommendations": rec_text,
        "doc_ids_top3": "|".join(str(x) for x in top_doc_ids),
        "scores_top3": "|".join(f"{s:.4f}" for s in top_scores),
    })

# aggregate rerank timing
# NOTE(review): with zero queries rerank_latencies is empty; confirm how
# np.percentile behaves on empty input before relying on this path.
rerank_total = float(np.sum(rerank_latencies))
rerank_avg   = rerank_total / max(1, len(rerank_latencies))
rerank_p50   = float(np.percentile(rerank_latencies, 50))
rerank_p99   = float(np.percentile(rerank_latencies, 99))

print(f"     Rerank time: {rerank_total:.3f}s total | ~{rerank_avg:.4f}s/query | "
      f"P50={rerank_p50:.4f}s | P99={rerank_p99:.4f}s")

# ----------------------------
# 6) Save CSV (+ P50/P99)
# ----------------------------
results_df = pd.DataFrame(rows)

# If you have per-query retrieval latencies collected as `retrieval_latencies`,
# you can compute their P50/P99 similarly. If not, we’ll keep using your avg_latency.
# Each assignment below stores one scalar, repeated on every row.
results_df["retrieval_latency_sec_per_query"] = avg_latency
results_df["rerank_latency_sec_per_query"]    = rerank_avg
results_df["rerank_latency_p50_sec"]          = rerank_p50
results_df["rerank_latency_p99_sec"]          = rerank_p99

print("\nSample rows:")
print(results_df.head(3))

results_df.to_csv(OUTPUT_CSV, index=False)
print(f"\n[OK] Saved -> {OUTPUT_CSV}")



In [None]:
# NOTE(review): this cell rebinds the path constants but contains no code that
# reloads the corpus or FAISS index or re-runs the search. The `query_texts`,
# `ids`, `get_doc` and `avg_latency` it uses further down must already exist
# from an earlier-executed cell, so they reflect whichever index that cell
# searched, not necessarily FAISS_INDEX_PATH below.
PARQUET_PATH = "exp1_rev.parquet"          # your corpus parquet
FAISS_INDEX_PATH = "300m_rev.faiss"        # your saved FAISS index
QUERIES_CSV = "final_benchmark.csv"        # csv with a 'query' or 'prompt' column
OUTPUT_CSV = "exp1_300m_rerank4B_rev_results.csv"  # results are written here, one row per query

MODEL_NAME = "google/embeddinggemma-300m"  # encoder model you used for the index
BATCH_SIZE = 128
# NOTE(review): EMB_DIM and USE_FLOAT16_DISK are not referenced by any visible code.
EMB_DIM = 768
USE_FLOAT16_DISK = True

# HNSW params (info only; already baked into your saved index)
HNSW_M = 32
HNSW_EF_CONSTRUCTION = 200
HNSW_EF_SEARCH = 128

K = 100              # retrieve top-K from FAISS
RERANK_TOP = 3       # keep only top-3 after reranking
RERANK_BS = 4       # reranker batch size
RERANKER_NAME = "Qwen/Qwen3-Reranker-4B"  # model id loaded as both tokenizer and causal LM


print("[*] Loading reranker…")
device_rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Left padding keeps each row's final real token at position -1, which is
# where _compute_yes_scores reads the logits.
ranker_tokenizer = AutoTokenizer.from_pretrained(RERANKER_NAME, padding_side='left')
# NOTE(review): if a previous cell already bound `ranker_model`, that model stays
# referenced until this assignment completes, so both may be resident at once.
ranker_model = AutoModelForCausalLM.from_pretrained(
    RERANKER_NAME,
    # Uncomment if you have Flash-Attn 2 installed and want extra speed:
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
).to(device_rank).eval()

# Vocabulary ids of the two answer tokens whose logits are compared.
token_false_id = ranker_tokenizer.convert_tokens_to_ids("no")
token_true_id  = ranker_tokenizer.convert_tokens_to_ids("yes")
# Fail fast if either answer token resolved to the tokenizer's unknown-token id.
if token_false_id == ranker_tokenizer.unk_token_id or token_true_id == ranker_tokenizer.unk_token_id:
    raise RuntimeError("Reranker tokenizer is missing 'yes'/'no' tokens.")

# Upper bound on tokens per row, including the prefix and suffix below.
max_length = 8192
# Chat-template text placed before each formatted (instruction, query, doc) string.
prefix = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
    "Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
)
# Closes the user turn and opens the assistant turn with an empty <think> block.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Tokenised once here and reused for every row in _process_inputs.
prefix_tokens = ranker_tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = ranker_tokenizer.encode(suffix, add_special_tokens=False)

def format_instruction(instruction, query, doc):
    """Compose the <Instruct>/<Query>/<Document> prompt body for one candidate.

    Args:
        instruction: Task description, or `None` for the generic web-search default.
        query: Query text.
        doc: Candidate document text.

    Returns:
        The three tagged sections joined by newlines.
    """
    if instruction is not None:
        head = instruction
    else:
        head = 'Given a web search query, retrieve relevant passages that answer the query'
    return "<Instruct>: " + f"{head}" + "\n<Query>: " + f"{query}" + "\n<Document>: " + f"{doc}"

def _process_inputs(pairs):
    """Tokenize formatted prompts, wrap them in the chat prefix/suffix, and pad.

    Args:
        pairs: Strings produced by `format_instruction`.

    Returns:
        The padded tokenizer output, as "pt" tensors moved to `device_rank`.
    """
    wrapper_len = len(prefix_tokens) + len(suffix_tokens)
    inputs = ranker_tokenizer(
        pairs,
        padding=False,
        truncation='longest_first',
        # Do not request the mask at this stage: the ids are extended by
        # `wrapper_len` tokens right below, so a mask computed now would no
        # longer match their length. pad() builds it from the final ids.
        return_attention_mask=False,
        max_length=max_length - wrapper_len,
        add_special_tokens=True,
    )
    # wrap with prefix/suffix then pad (left-padding already set)
    for i, ids_list in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ids_list + suffix_tokens
    inputs = ranker_tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for k in inputs:
        inputs[k] = inputs[k].to(device_rank)
    return inputs

@torch.inference_mode()
def _compute_yes_scores(inputs):
    """Return, per row, the "yes" probability of a two-way yes/no softmax.

    Args:
        inputs: Keyword arguments for `ranker_model`, as built by `_process_inputs`.

    Returns:
        A list of floats, one per row of the batch.
    """
    outputs = ranker_model(**inputs)
    tail = outputs.logits[:, -1, :]     # last position
    no_logit = tail[:, token_false_id]
    yes_logit = tail[:, token_true_id]
    # Normalise over just the two answer tokens; index 1 is "yes".
    stacked = torch.stack([no_logit, yes_logit], dim=1)
    yes_log_prob = torch.nn.functional.log_softmax(stacked, dim=1)[:, 1]
    return yes_log_prob.exp().tolist()

def rerank_query(query: str, docs: list[str], task: str = None, top_n: int = 3) -> tuple[list[str], list[float], list[int]]:
    """Rank `docs` for `query` with the reranker and return the best `top_n`.

    Args:
        query: Query text.
        docs: Candidate document texts.
        task: Instruction forwarded to `format_instruction` (`None` = its default).
        top_n: Maximum number of results to return.

    Returns:
        `(docs, scores, indices)` sorted by descending score; `indices`
        refer to positions in the `docs` argument.
    """
    all_scores = []
    for begin in range(0, len(docs), RERANK_BS):
        window = docs[begin:begin + RERANK_BS]
        batch_inputs = _process_inputs([format_instruction(task, query, passage) for passage in window])
        all_scores += _compute_yes_scores(batch_inputs)
    # Ascending argsort, reversed, then truncated to the requested size.
    ranking = np.argsort(all_scores)[::-1][:top_n]
    best_docs = [docs[idx] for idx in ranking]
    best_scores = [all_scores[idx] for idx in ranking]
    return best_docs, best_scores, ranking.tolist()

# ----------------------------
# 5) Build results with top-3 per query
# ----------------------------

print("[*] Reranking to top-3 per query…")
# NOTE(review): the instruction text reads "answer the all the query's requirements";
# it is sent to the reranker as-is. Fixing the wording would change model input.
task = "Given a user's query, retrieve relevant passages that answer the all the query's requirements"
rows = []

rerank_latencies = []   # per-query rerank latency (seconds)

# NOTE(review): query_texts, ids and get_doc are not defined in this cell; they
# are whatever an earlier-executed cell left in the session.
for i, q in enumerate(tqdm(query_texts, desc="Reranking")):
    # gather FAISS top-K docs for query i
    q_ids = [int(d) for d in ids[i] if d >= 0]
    q_docs = [get_doc(d) for d in q_ids]

    # time the rerank for this query (document lookup above is excluded)
    t0 = time.perf_counter()
    top_docs, top_scores, order = rerank_query(q, q_docs, task=task, top_n=RERANK_TOP)
    dt = time.perf_counter() - t0
    rerank_latencies.append(dt)

    # map back to original corpus ids (`order` indexes into q_docs / q_ids)
    top_doc_ids = [q_ids[j] for j in order]

    # human-readable string
    # NOTE(review): the "3" is hard-coded; the list holds len(top_docs) items.
    rec_text = "I have 3 recommendations: " + " ".join([f"{j+1}. {top_docs[j]}" for j in range(len(top_docs))])

    rows.append({
        "prompt": q,
        "recommendations": rec_text,
        "doc_ids_top3": "|".join(str(x) for x in top_doc_ids),
        "scores_top3": "|".join(f"{s:.4f}" for s in top_scores),
    })

# aggregate rerank timing
rerank_total = float(np.sum(rerank_latencies))
rerank_avg   = rerank_total / max(1, len(rerank_latencies))
rerank_p50   = float(np.percentile(rerank_latencies, 50))
rerank_p99   = float(np.percentile(rerank_latencies, 99))

print(f"     Rerank time: {rerank_total:.3f}s total | ~{rerank_avg:.4f}s/query | "
      f"P50={rerank_p50:.4f}s | P99={rerank_p99:.4f}s")

# ----------------------------
# 6) Save CSV (+ P50/P99)
# ----------------------------
results_df = pd.DataFrame(rows)

# If you have per-query retrieval latencies collected as `retrieval_latencies`,
# you can compute their P50/P99 similarly. If not, we’ll keep using your avg_latency.
# NOTE(review): avg_latency is not computed in this cell; it carries over from
# an earlier retrieval run. Each assignment stores one scalar on every row.
results_df["retrieval_latency_sec_per_query"] = avg_latency
results_df["rerank_latency_sec_per_query"]    = rerank_avg
results_df["rerank_latency_p50_sec"]          = rerank_p50
results_df["rerank_latency_p99_sec"]          = rerank_p99

print("\nSample rows:")
print(results_df.head(3))

results_df.to_csv(OUTPUT_CSV, index=False)
print(f"\n[OK] Saved -> {OUTPUT_CSV}")

In [None]:
# pip install -U sentence-transformers faiss-cpu datasets tqdm transformers accelerate pyarrow pandas

import os, time, math, sys
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# ----------------------------
# 0) Load reviews corpus from Parquet (HF Dataset)
# ----------------------------
PARQUET_PATH = "exp1_rev.parquet"          # <-- your reviews parquet
FAISS_INDEX_PATH = "300m_rev.faiss"        # <-- FAISS built over the SAME row order as this parquet
QUERIES_CSV      = "final_benchmark.csv"   # csv with 'query' or 'prompt'
OUTPUT_CSV       = "exp1_300m_rerank4b_reviews_results.csv"  # results are written here, one row per query

MODEL_NAME    = "google/embeddinggemma-300m"  # query encoder
BATCH_SIZE    = 128                           # batch size passed to the query encoder
EMB_DIM       = 768                           # NOTE(review): not referenced by any visible code
K             = 100                           # neighbours requested from FAISS per query
RERANK_TOP    = 3                             # results kept per query after reranking
RERANK_BS     = 8                             # documents scored per reranker forward pass
RERANKER_NAME = "Qwen/Qwen3-Reranker-4B"      # model id loaded as both tokenizer and causal LM

# If your parquet has a different text field name, list options here (first match wins)
DOC_FIELD_CANDIDATES = [
    "combined", "text", "document", "content", "passage", "body",
    # common review fields
    "review_text", "review", "review_body", "reviewContent"
]

def load_reviews_dataset(path: str):
    """Load the parquet file at `path` and return its "train" split."""
    return load_dataset("parquet", data_files=path)["train"]

def pick_doc_field(hfds) -> str:
    """Decide which column (or composition mode) supplies document text.

    Args:
        hfds: Object exposing a `column_names` iterable.

    Returns:
        The first entry of DOC_FIELD_CANDIDATES present in the dataset, or the
        sentinel "__compose_reviews__" when none matches but at least one
        typical review column exists.

    Raises:
        ValueError: If neither a candidate nor a typical review column is present.
    """
    available = set(hfds.column_names)
    hits = [name for name in DOC_FIELD_CANDIDATES if name in available]
    if hits:
        return hits[0]
    # Fallback: auto-compose a combined text from common review columns if present
    typical_review_columns = (
        "review_title", "review_positive", "review_negative", "text",
        "review_text", "review_body", "pros", "cons",
    )
    if any(name in available for name in typical_review_columns):
        return "__compose_reviews__"
    raise ValueError(f"None of {DOC_FIELD_CANDIDATES} found and no typical review fields "
                     f"in dataset columns: {sorted(available)}")

def _is_nan(x):
    """Return True only when `x` is a float holding NaN."""
    if not isinstance(x, float):
        return False
    return math.isnan(x)

def _val(x):
    """Render a cell as text: "" for None or float NaN, `str(x)` otherwise."""
    if x is None or _is_nan(x):
        return ""
    return str(x)

# Loaded once; get_doc later reads rows from `dataset` by integer position.
dataset = load_reviews_dataset(PARQUET_PATH)
# May be the sentinel "__compose_reviews__" rather than a real column name.
DOC_FIELD = pick_doc_field(dataset)
# Row count, used below to bounds-check ids coming back from FAISS.
N_CORPUS = len(dataset)
print(f"[corpus] Using HF Dataset from '{PARQUET_PATH}', n_docs={N_CORPUS:,}")
print(f"[corpus] Using DOC_FIELD='{DOC_FIELD}'")

# ----------------------------
# 1) Queries
# ----------------------------
def load_queries(path):
    """Load query strings from a CSV, preferring a 'query' column over 'prompt'.

    Args:
        path: CSV location, passed straight to `pd.read_csv`.

    Returns:
        `(texts, column)`: the column's values cast to `str`, and its name.

    Raises:
        ValueError: If the CSV has neither column.
    """
    frame = pd.read_csv(path)
    for column in ("query", "prompt"):
        if column in frame.columns:
            return frame[column].astype(str).tolist(), column
    raise ValueError(f"{path} must contain a 'query' or 'prompt' column.")

print("[1/4] Loading queries…")
# query_col records which CSV column the queries came from ('query' or 'prompt').
query_texts, query_col = load_queries(QUERIES_CSV)
print(f"     Num queries: {len(query_texts):,}")

# ----------------------------
# 2) FAISS index + encoder
# ----------------------------
print("[2/4] Loading FAISS index…")
index = faiss.read_index(FAISS_INDEX_PATH)
print(f"     [OK] Loaded FAISS index with ntotal={index.ntotal:,}")
# FAISS ids are used later as row positions into `dataset`, so a size mismatch
# is a strong hint that index and parquet are out of sync. This only warns
# (on stderr); execution continues.
if index.ntotal != N_CORPUS:
    print(f"[WARN] FAISS index ntotal ({index.ntotal}) != dataset rows ({N_CORPUS}). "
          f"Make sure the index and Dataset row order match!", file=sys.stderr)

def ensure_float32(x):
    """Return `x` unchanged if it is already float32, else a float32 copy."""
    already_f32 = x.dtype == np.float32
    return x if already_f32 else x.astype("float32")

print("[3/4] Loading encoder model for queries…")
enc_model = SentenceTransformer(
    MODEL_NAME,
    trust_remote_code=True,
    # bfloat16 weights when CUDA is available, float32 otherwise.
    model_kwargs={"torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32},
)

# ----------------------------
# 3) Encode queries & search FAISS
# ----------------------------
print("[4/4] Encoding & retrieving…")
# The timer starts before encoding, so the reported "retrieval" time covers
# both query encoding and the FAISS search.
t0 = time.time()
q_emb = enc_model.encode_query(
    query_texts,
    batch_size=BATCH_SIZE,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True,
)
# Cast to float32 before searching, whatever dtype the encoder returned.
q_emb = ensure_float32(np.array(q_emb))
# One row of K scores / K ids per query.
scores, ids = index.search(q_emb, k=K)

elapsed_all = time.time() - t0
# Amortised over the whole batch; no per-query measurement is taken here.
avg_latency = elapsed_all / max(1, len(query_texts))
print(f"     Retrieval time: {elapsed_all:.3f}s total | ~{avg_latency:.4f}s/query")

# ----------------------------
# 4) Doc accessor (HF Dataset)
# ----------------------------
# If DOC_FIELD == "__compose_reviews__", we build a 'combined' string on the fly.
# compose_fields lists which of the typical review columns this dataset has.
# (The set of column names is rebuilt for each candidate; harmless at this size.)
compose_fields = [c for c in ["review_title", "review_positive", "review_negative", "text",
                              "review_text", "review_body", "pros", "cons"]
                  if c in set(dataset.column_names)]

# (column, label) pairs in the order their text is concatenated when the
# document is composed from several review columns. An empty label means the
# value is emitted as-is.
_COMPOSE_LAYOUT = (
    ("review_title", "Title: "),
    ("text", ""),
    ("review_text", ""),
    ("review_body", ""),
    ("review_positive", "Pros: "),
    ("review_negative", "Cons: "),
    ("pros", "Pros: "),
    ("cons", "Cons: "),
)

def get_doc(doc_id: int) -> str:
    """Return the text of corpus row `doc_id`.

    Args:
        doc_id: Row position in `dataset`.

    Returns:
        "" when `doc_id` is outside `[0, N_CORPUS)`. Otherwise the value of
        column DOC_FIELD rendered through `_val`, or, when DOC_FIELD is the
        sentinel "__compose_reviews__", the available review columns joined
        by single spaces in `_COMPOSE_LAYOUT` order.
    """
    if doc_id < 0 or doc_id >= N_CORPUS:
        return ""
    row = dataset[int(doc_id)]
    if DOC_FIELD != "__compose_reviews__":
        # simple single-field case
        return _val(row.get(DOC_FIELD, ""))
    parts = []
    for column, label in _COMPOSE_LAYOUT:
        if column not in compose_fields:
            continue
        raw = row.get(column, None)
        if not raw:
            continue
        # NaN is truthy but renders as "", so test the rendered text as well;
        # otherwise a labelled field would contribute a dangling "Title: " etc.
        text = _val(raw)
        if text:
            parts.append(f"{label}{text}")
    return " ".join(parts)

# ----------------------------
# 5) Reranker (Qwen3-Reranker-4B) setup + helpers
# ----------------------------
_device = "cuda" if torch.cuda.is_available() else "cpu"
# Extra keyword arguments for from_pretrained (e.g. dtype / attention backend); empty by default.
_rerank_model_kwargs = {}

# Left padding keeps each row's final real token at position -1, which is
# where _compute_yes_scores reads the logits.
tokenizer = AutoTokenizer.from_pretrained(RERANKER_NAME, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(RERANKER_NAME, **_rerank_model_kwargs).to(_device).eval()

# Vocabulary ids of the two answer tokens whose logits are compared.
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id  = tokenizer.convert_tokens_to_ids("yes")
# Fail fast if either token is absent; a None or unknown-token id would
# otherwise be used to index the logits and give meaningless scores.
if None in (token_false_id, token_true_id) or tokenizer.unk_token_id in (token_false_id, token_true_id):
    raise RuntimeError("Reranker tokenizer is missing 'yes'/'no' tokens.")

# Upper bound on tokens per row; taken from the model config, 8192 if the attribute is absent.
max_length = getattr(model.config, "max_position_embeddings", 8192)
# Chat-template text placed before each formatted (instruction, query, doc) string.
prefix = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct provided. "
    'Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
# Closes the user turn and opens the assistant turn with an empty <think> block.
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Tokenised once here and reused for every row in _process_inputs.
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

def format_instruction(instruction: str, query: str, doc: str) -> str:
    """Build the user-turn text scored by the reranker.

    Args:
        instruction: Task description; `None` selects a generic web-search default.
        query: Query text.
        doc: Candidate document text.

    Returns:
        Three newline-separated sections tagged <Instruct>, <Query> and <Document>.
    """
    if instruction is None:
        instruction = "Given a web search query, retrieve relevant passages that answer the query"
    return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"

def _process_inputs(pairs: list[str]):
    """Tokenize formatted prompts, wrap them in the chat prefix/suffix, and pad.

    Args:
        pairs: Strings produced by `format_instruction`.

    Returns:
        The padded tokenizer output, as "pt" tensors moved to `model.device`.
    """
    # Tokens left for the prompt body after the fixed wrapper, never below 16.
    wrapper_cost = len(prefix_tokens) + len(suffix_tokens)
    body_budget = int(max(16, max_length - wrapper_cost))
    encoded = tokenizer(
        pairs,
        padding=False,
        truncation="longest_first",
        # No mask yet: the ids are extended below and pad() is called afterwards.
        return_attention_mask=False,
        max_length=body_budget,
    )
    token_rows = encoded["input_ids"]
    # In-place replacement so `encoded` sees the wrapped rows.
    token_rows[:] = [prefix_tokens + row + suffix_tokens for row in token_rows]
    batch = tokenizer.pad(encoded, padding=True, return_tensors="pt", max_length=max_length)
    for key in batch:
        batch[key] = batch[key].to(model.device)
    return batch

@torch.no_grad()
def _compute_yes_scores(inputs) -> list[float]:
    """Score a padded batch with the reranker.

    Args:
        inputs: Keyword arguments for `model`, as built by `_process_inputs`.

    Returns:
        One float per row: P("yes") from a softmax restricted to the "no"
        and "yes" logits at the last sequence position.
    """
    last_logits = model(**inputs).logits[:, -1, :]
    # Column 0 = "no", column 1 = "yes".
    no_yes = torch.stack(
        [last_logits[:, token_false_id], last_logits[:, token_true_id]],
        dim=1,
    )
    log_probs = torch.nn.functional.log_softmax(no_yes, dim=1)
    return log_probs[:, 1].exp().tolist()  # P("yes")

def rerank_query(query: str, docs: list[str], task: str = None, top_n: int = 3):
    """Score every document against the query and keep the best `top_n`.

    Args:
        query: Query text.
        docs: Candidate document texts; an empty list short-circuits.
        task: Instruction forwarded to `format_instruction` (`None` = its default).
        top_n: Maximum number of results to return.

    Returns:
        `(docs, scores, indices)`, highest score first; `indices` are
        positions into the `docs` argument. Three empty lists for no docs.
    """
    if not docs:
        return [], [], []
    all_scores = []
    # Documents are scored RERANK_BS at a time.
    for lo in range(0, len(docs), RERANK_BS):
        prompts = [format_instruction(task, query, text) for text in docs[lo:lo + RERANK_BS]]
        all_scores.extend(_compute_yes_scores(_process_inputs(prompts)))
    # argsort is ascending, so reverse before taking the head.
    ranking = np.argsort(all_scores)[::-1][:top_n]
    kept_docs = [docs[pos] for pos in ranking]
    kept_scores = [all_scores[pos] for pos in ranking]
    return kept_docs, kept_scores, ranking.tolist()

# ----------------------------
# 6) Rerank top-K per query -> top-3
# ----------------------------
print("[*] Reranking to top-3 per query…")
# Instruction text sent to the reranker with every (query, document) pair.
task = "Given a user's query, retrieve relevant review passages that answer all the query's requirements"
rows = []
rerank_latencies = []   # per-query rerank latency (seconds)

for i, q in enumerate(tqdm(query_texts, desc="Reranking")):
    # Keep only ids that are valid row positions in the corpus.
    q_ids = [int(d) for d in ids[i] if d >= 0 and d < N_CORPUS]
    q_docs = [get_doc(d) for d in q_ids]

    # Time only the reranker call; document lookup above is excluded.
    t0 = time.perf_counter()
    top_docs, top_scores, order = rerank_query(q, q_docs, task=task, top_n=RERANK_TOP)
    dt = time.perf_counter() - t0
    rerank_latencies.append(dt)

    # `order` indexes into q_docs / q_ids, so this maps back to corpus ids.
    top_doc_ids = [q_ids[j] for j in order]

    # NOTE(review): the label says "Top-3" even when fewer documents are returned.
    rec_text = "Top-3 relevant review snippets: " + " ".join(
        [f"{j+1}. {top_docs[j]}" for j in range(len(top_docs))]
    )

    rows.append({
        "prompt": q,
        "recommendations": rec_text,
        "doc_ids_top3": "|".join(str(x) for x in top_doc_ids),
        "scores_top3": "|".join(f"{s:.4f}" for s in top_scores),
    })

# ----------------------------
# 7) Latency stats + save
# ----------------------------
# NOTE(review): with zero queries rerank_latencies is empty; confirm how
# np.percentile behaves on empty input before relying on this path.
rerank_total = float(np.sum(rerank_latencies))
rerank_avg   = rerank_total / max(1, len(rerank_latencies))
rerank_p50   = float(np.percentile(rerank_latencies, 50))
rerank_p99   = float(np.percentile(rerank_latencies, 99))

print(f"     Rerank time: {rerank_total:.3f}s total | ~{rerank_avg:.4f}s/query | "
      f"P50={rerank_p50:.4f}s | P99={rerank_p99:.4f}s")

results_df = pd.DataFrame(rows)
# Each assignment below stores one scalar, repeated on every row.
results_df["retrieval_latency_sec_per_query"] = avg_latency
results_df["rerank_latency_sec_per_query"]    = rerank_avg
results_df["rerank_latency_p50_sec"]          = rerank_p50
results_df["rerank_latency_p99_sec"]          = rerank_p99

print("\nSample rows:")
print(results_df.head(3))

results_df.to_csv(OUTPUT_CSV, index=False)
print(f"\n[OK] Saved -> {OUTPUT_CSV}")
