# 1.5 Large Corpus RAG with Query Rewriting

Same setup as 1.4 -- BEIR/NQ dataset, shared 10K+ doc ChromaDB collection, Recall@K -- but with one change: before searching, we rewrite the user's question into a search-optimized query using an LLM.

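User questions are conversational ("who sang that song in the movie?"). Embedding search works better with keyword-rich, declarative queries ("1993 film Aerosmith song soundtrack director"). The rewrite model transforms one into the other. As this example shows, the rewrite can also add entities the user never mentioned, drawn from the rewrite model's own knowledge. That can help retrieval, but it also means part of any gain over 1.4 may come from the rewrite model's prior knowledge rather than from rephrasing alone.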

We search with the **rewritten** query but answer with the **original** question. The rewrite is optimized for retrieval, not comprehension. Everything else (corpus, collection, scoring) is identical to 1.4, so the comparison is controlled.


In [15]:
import os
import random
import time
import hashlib
import json
import openai
import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
from dotenv import load_dotenv
from datasets import load_dataset
from typing import Any, cast
from concurrent.futures import ThreadPoolExecutor, as_completed

# Load variables from a local .env file into os.environ so the API keys
# read later via os.getenv (OPENROUTER_API_KEY, OPENAI_API_KEY) are available.
load_dotenv()

True

## Configuration


In [16]:
# Initialize OpenRouter client with the OpenAI SDK.
# Fail fast when the key is missing: openai.OpenAI(api_key=None) silently
# falls back to the OPENAI_API_KEY environment variable. That would send the
# OpenAI key to the OpenRouter endpoint and only surface later as a
# confusing authentication error on the first request.
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not openrouter_api_key:
    raise RuntimeError(
        "OPENROUTER_API_KEY is not set; add it to your environment or .env file"
    )
client = openai.OpenAI(
    api_key=openrouter_api_key,
    base_url="https://openrouter.ai/api/v1",
)

# Model being evaluated (generates answers from the retrieved documents)
EVAL_MODEL = "moonshotai/kimi-k2.5"

# Scoring model (LLM judge that grades each answer 0-100)
SCORING_MODEL = "google/gemini-3-flash-preview"

# Query rewriting model (turns questions into search-optimized queries)
REWRITE_MODEL = "google/gemini-3-flash-preview"

# RAG retrieval settings
TOP_K = 5                  # documents retrieved per query
CORPUS_SIZE = 10000        # gold docs + random distractors in the collection
NUM_EXAMPLES = 50          # number of queries evaluated
EMBEDDING_MODEL = "text-embedding-3-small"
CHROMA_DIR = ".chroma_nq"  # on-disk ChromaDB storage, reused across runs

# Embedding function for ChromaDB (uses OpenAI directly, not OpenRouter)
embedding_fn = OpenAIEmbeddingFunction(
    api_key=os.getenv("OPENAI_API_KEY"),
    model_name=EMBEDDING_MODEL,
)

## Load BEIR/NQ Dataset


In [17]:
def load_data(corpus_size: int, num_examples: int) -> tuple[list[dict], list[dict], dict]:
    """
    Load BEIR Natural Questions: corpus, queries, and relevance judgments.

    Builds a corpus subset that includes all gold-relevant documents for
    the queries we'll evaluate, plus randomly sampled Wikipedia passages
    as distractors up to corpus_size. Documents with empty text are skipped
    because the OpenAI embeddings API rejects them.

    The distractor sample uses a fixed seed (42), so the same arguments give
    the same corpus subset. build_collection names its collection from the
    doc IDs, so this keeps its on-disk cache valid across runs.

    Args:
        corpus_size: Total number of documents in the corpus subset
        num_examples: Number of queries to evaluate

    Returns:
        Tuple of (corpus_subset, eval_queries, qrels)
        - corpus_subset: list of dicts with _id, title, text
        - eval_queries: list of dicts with _id, text
        - qrels: dict mapping query_id -> set of relevant doc_ids
    """
    print("Loading BEIR/NQ queries and relevance judgments...")
    # The BeIR/nq dataset uses a custom loading script that datasets v4+
    # no longer supports. Load directly from the auto-converted parquet files.
    queries_ds = load_dataset(
        "parquet",
        data_files="hf://datasets/BeIR/nq@refs/convert/parquet/queries/queries/0000.parquet",
        split="train",
    )
    qrels_ds = load_dataset("BeIR/nq-qrels", split="test")

    # Build qrels lookup: query_id -> set of relevant corpus doc_ids
    qrels: dict[str, set[str]] = {}
    for row in qrels_ds:
        qrels.setdefault(row["query-id"], set()).add(row["corpus-id"])

    # Select queries that have relevance judgments
    eval_queries = [q for q in queries_ds if q["_id"] in qrels][:num_examples]

    # Collect all gold-relevant doc IDs for the queries we'll evaluate
    gold_doc_ids: set[str] = set()
    for q in eval_queries:
        gold_doc_ids.update(qrels[q["_id"]])

    print(
        f"Selected {len(eval_queries)} queries, {len(gold_doc_ids)} gold documents")

    # Load the full corpus (HuggingFace uses memory-mapped Arrow, so this
    # doesn't load 2.68M docs into RAM -- it's a lazy view)
    print("Loading BEIR/NQ corpus (2.68M Wikipedia passages)...")
    corpus_ds = load_dataset(
        "parquet",
        data_files=[
            "hf://datasets/BeIR/nq@refs/convert/parquet/corpus/corpus/0000.parquet",
            "hf://datasets/BeIR/nq@refs/convert/parquet/corpus/corpus/0001.parquet",
            "hf://datasets/BeIR/nq@refs/convert/parquet/corpus/corpus/0002.parquet",
        ],
        split="train",
    )

    # Find gold documents using HF's optimized batched filter
    print("Locating gold documents in corpus...")
    gold_docs_ds = corpus_ds.filter(
        lambda batch: [did in gold_doc_ids for did in batch["_id"]],
        batched=True,
        batch_size=10000,
    )
    print(f"Found {len(gold_docs_ds)}/{len(gold_doc_ids)} gold documents")

    # Build corpus subset: gold docs + random distractors.
    # Skip docs with empty text (OpenAI embeddings API rejects them).
    # Rows are iterated once rather than re-indexing the Dataset for every
    # field; each ds[i] call materializes a fresh row dict.
    corpus_subset = [
        {"_id": row["_id"], "text": row["text"], "title": row["title"]}
        for row in gold_docs_ds
        if row["text"]
    ]
    # Count the gold docs actually included, not the ones found. Empty-text
    # gold docs were dropped above.
    gold_count = len(corpus_subset)
    if gold_count < len(gold_docs_ds):
        print(f"Warning: skipped {len(gold_docs_ds) - gold_count} gold documents "
              f"with empty text (their queries cannot reach full recall)")

    # Sample random distractor documents from the corpus
    fill_count = max(0, corpus_size - gold_count)
    if fill_count > 0:
        # A dedicated Random(42) draws the same sample as random.seed(42) +
        # random.sample (so the corpus, and its cached collection, are
        # unchanged) without resetting the global RNG state.
        rng = random.Random(42)
        # Over-sample 2x so there are spares after dropping gold/empty docs
        candidate_indices = rng.sample(
            range(len(corpus_ds)), min(fill_count * 2, len(corpus_ds)))

        for row in corpus_ds.select(candidate_indices):
            if row["_id"] not in gold_doc_ids and row["text"]:
                corpus_subset.append({
                    "_id": row["_id"],
                    "text": row["text"],
                    "title": row["title"],
                })
            if len(corpus_subset) >= corpus_size:
                break

    if len(corpus_subset) < corpus_size:
        print(f"Warning: corpus has only {len(corpus_subset)} docs "
              f"(requested {corpus_size})")

    fill_actual = len(corpus_subset) - gold_count
    print(f"Corpus: {len(corpus_subset)} docs "
          f"({gold_count} gold + {fill_actual} distractors)\n")

    return corpus_subset, eval_queries, qrels

In [18]:
# Gold docs + seeded distractors, plus the queries and qrels to evaluate
corpus_subset, eval_queries, qrels = load_data(
    corpus_size=CORPUS_SIZE, num_examples=NUM_EXAMPLES)

Loading BEIR/NQ queries and relevance judgments...
Selected 50 queries, 61 gold documents
Loading BEIR/NQ corpus (2.68M Wikipedia passages)...
Locating gold documents in corpus...
Found 61/61 gold documents
Corpus: 10000 docs (61 gold + 9939 distractors)



## Build ChromaDB Collection


In [19]:
def build_collection(
    corpus_subset: list[dict],
    chroma_dir: str,
) -> chromadb.Collection:
    """
    Build or load a persistent ChromaDB collection from the corpus subset.

    Uses a stable collection name based on a hash of the document IDs, so
    re-running with the same corpus skips embedding entirely (even across
    sessions). If the corpus changes, a new collection is created. A cached
    collection whose document count doesn't match the corpus (e.g. from an
    interrupted build) is deleted and rebuilt.

    Args:
        corpus_subset: List of document dicts with _id, text, title
        chroma_dir: Path to ChromaDB persistent storage directory

    Returns:
        A ChromaDB Collection ready for querying
    """
    # Build a stable name from the sorted doc IDs so the same corpus
    # always maps to the same collection (and we skip re-embedding)
    id_hash = hashlib.sha256(
        ",".join(sorted(d["_id"] for d in corpus_subset)).encode()
    ).hexdigest()[:12]
    collection_name = f"nq_{len(corpus_subset)}_{id_hash}"

    chroma_client = chromadb.PersistentClient(path=chroma_dir)
    total = len(corpus_subset)

    # Only the lookup is guarded. The "collection not found" exception type
    # differs between chromadb versions, so the catch is broad here, but
    # failures in count()/delete_collection() must surface, not be swallowed.
    try:
        cached = chroma_client.get_collection(
            name=collection_name,
            embedding_function=cast(Any, embedding_fn),
        )
    except Exception:
        cached = None

    if cached is not None:
        cached_count = cached.count()
        if cached_count == total:
            print(f"Cache hit! Reusing collection '{collection_name}' "
                  f"({cached_count:,} docs, no re-embedding needed)\n")
            return cached
        # Size mismatch -- rebuild
        print(f"Collection size mismatch "
              f"({cached_count} vs {total}), rebuilding...")
        chroma_client.delete_collection(collection_name)

    print(
        f"Embedding {total:,} documents into collection '{collection_name}'...")
    print("(This is a one-time cost -- cached on disk for future runs)\n")

    # Embed and add documents in batches with timing stats
    BATCH_SIZE = 500
    total_batches = (total + BATCH_SIZE - 1) // BATCH_SIZE
    collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=cast(Any, embedding_fn),
    )

    start_time = time.time()
    for batch_num, i in enumerate(range(0, total, BATCH_SIZE), 1):
        batch = corpus_subset[i:i + BATCH_SIZE]
        batch_start = time.time()

        collection.add(
            ids=[d["_id"] for d in batch],
            documents=[d["text"] for d in batch],
            metadatas=[{"title": d["title"]} for d in batch],
        )

        done = min(i + BATCH_SIZE, total)
        batch_time = time.time() - batch_start
        elapsed = time.time() - start_time
        # Guard the division: a coarse clock can report zero elapsed time
        rate = done / elapsed if elapsed > 0 else 0.0  # docs per second
        remaining = (total - done) / rate if rate > 0 else 0

        print(f"  [{batch_num}/{total_batches}] {done:,}/{total:,} docs | "
              f"batch: {batch_time:.1f}s | "
              f"rate: {rate:.0f} docs/s | "
              f"ETA: {remaining:.0f}s")

    total_time = time.time() - start_time
    print(f"\nCollection ready: {collection.count():,} documents "
          f"(embedded in {total_time:.1f}s)\n")
    return collection

In [20]:
# Reuses the on-disk collection when this exact corpus was embedded before
collection = build_collection(corpus_subset=corpus_subset, chroma_dir=CHROMA_DIR)

Cache hit! Reusing collection 'nq_10000_b1bf36b34ec3' (10,000 docs, no re-embedding needed)



## Query Rewriting


In [21]:
def rewrite_query(question: str) -> str:
    """
    Rewrite a user question into a search-optimized query.

    Conversational questions don't always match well against document
    embeddings. This function asks an LLM to rewrite the question into
    a keyword-rich, declarative form that's better suited for embedding
    similarity search.

    If the model returns no usable text, the original question is returned
    so retrieval still runs with a meaningful query.

    Args:
        question: The original user question

    Returns:
        A search-optimized query string (or the original question as fallback)
    """
    # NOTE(review): rules 1 and 3 let the rewrite model add entities from its
    # own knowledge. Sample outputs show answer names (e.g. "Doug Ford") that
    # were not in the question. Gains over plain retrieval may partly reflect
    # that prior knowledge rather than reformulation alone.
    response = client.chat.completions.create(
        model=REWRITE_MODEL,
        messages=[
            {"role": "system", "content": (
                "You are a search query optimizer. Given a user question, "
                "rewrite it as a search query optimized for semantic similarity "
                "search against a Wikipedia corpus. Follow these rules:\n"
                "1. Extract key entities, names, dates, and concepts\n"
                "2. Use declarative, keyword-rich phrasing instead of question form\n"
                "3. Expand abbreviations and add synonyms where helpful\n"
                "4. Remove filler words (who, what, how, please, etc.)\n"
                "5. Keep it concise -- one line, no explanation\n"
                "6. Output ONLY the rewritten query, nothing else"
            )},
            {"role": "user", "content": question},
        ],
        max_tokens=100,
    )
    # message.content may be None (e.g. a cut-off or empty completion).
    # Avoid crashing on .strip() and never search with an empty query.
    rewritten = (response.choices[0].message.content or "").strip()
    return rewritten or question

## Generate & Score Answers


In [22]:
def generate_answer(question: str, retrieved_docs: list[str]) -> str:
    """
    Ask the eval model to answer `question` from the retrieved passages.

    Each passage is labelled "[Document N]" (1-based) and concatenated into
    the user prompt. The system prompt restricts the model to the documents
    and tells it to say so explicitly when the answer isn't there.

    Args:
        question: The question to answer
        retrieved_docs: Document texts returned by the collection query

    Returns:
        The model's answer text, or "" if the response had no content
    """
    context = "\n\n".join(
        f"[Document {position}]\n{text}"
        for position, text in enumerate(retrieved_docs, start=1)
    )

    instructions = (
        "You are a helpful assistant that answers questions based on provided documents.\n"
        "Your task is to:\n"
        "1. Carefully read all provided documents\n"
        "2. Find the information needed to answer the question\n"
        "3. Provide a clear, concise answer based ONLY on the documents\n"
        "\n"
        "If the answer cannot be found in the documents, say so explicitly."
    )

    prompt = (
        "Documents:\n"
        f"{context}\n"
        "\n"
        f"Question: {question}\n"
        "\n"
        "Please answer the question based on the provided documents."
    )

    chat_messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat.completions.create(
        model=EVAL_MODEL,
        messages=chat_messages,
        max_tokens=4096,
    )

    answer = completion.choices[0].message.content
    return answer if answer else ""


def _parse_score_response(raw: str) -> tuple[int, str]:
    """Parse the judge's JSON reply into (score clamped to 0-100, reasoning), or (0, error text)."""
    error = (0, "Error parsing score response")
    # Judges often wrap JSON in ```json fences or add a preamble, which makes
    # json.loads on the raw text fail. Parse the outermost {...} span instead.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end < start:
        return error
    try:
        score_data = json.loads(raw[start:end + 1])
    except json.JSONDecodeError:
        return error
    try:
        # Accept numeric strings and floats ("85", 85.0) as well as ints.
        # OverflowError covers Infinity, which json.loads accepts by default.
        score = int(round(float(score_data.get("score", 0))))
    except (TypeError, ValueError, OverflowError):
        return error
    # Clamp so an out-of-range score can't skew averages or the distribution
    score = max(0, min(100, score))
    return score, str(score_data.get("reasoning", ""))


def score_answer(
    question: str,
    documents: list[str],
    generated_answer: str,
) -> tuple[int, str]:
    """
    Score the generated answer using the scoring model.

    Args:
        question: The original question
        documents: The retrieved documents used to generate the answer
        generated_answer: The answer generated by the eval model

    Returns:
        Tuple of (score out of 100, explanation). The score is an int clamped
        to 0-100. If the reply cannot be parsed, returns
        (0, "Error parsing score response").
    """
    context = "\n\n".join(
        [f"[Document {i+1}]\n{doc}" for i, doc in enumerate(documents)])

    system_prompt = """You are an expert evaluator assessing the quality of answers to questions.
Evaluate the answer on these criteria:
1. Correctness (0-25 points): Is the answer factually accurate based on the documents?
2. Completeness (0-25 points): Does it fully answer the question? Are important details included?
3. Faithfulness (0-25 points): Does it only use information from the documents? No hallucinations?
4. Clarity (0-25 points): Is the answer clear, well-organized, and easy to understand?

Respond with a JSON object containing:
{
    "score": <integer from 0-100>,
    "reasoning": "<brief explanation of the score>"
}"""

    eval_prompt = f"""Documents:
{context}

Question: {question}

Generated Answer:
{generated_answer}"""

    response = client.chat.completions.create(
        model=SCORING_MODEL,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": eval_prompt},
        ],
        max_tokens=300,
    )

    return _parse_score_response(response.choices[0].message.content or "")

## Evaluation Loop


In [23]:
def evaluate_single_example(
    query: dict,
    example_index: int,
    collection: chromadb.Collection,
    qrels: dict[str, set[str]],
    top_k: int,
) -> tuple[int, dict]:
    """
    Run rewrite -> retrieve -> answer -> score for one query.

    Retrieval uses the LLM-rewritten query. Answer generation and scoring use
    the ORIGINAL question text. Recall@K is the fraction of the query's gold
    (qrels) documents that appear among the top-k results, or 0.0 when the
    query has no gold documents.

    Returns:
        (example_index, result dict), so parallel callers can restore order.
    """
    qid = query["_id"]
    original_question = query["text"]
    relevant = qrels.get(qid, set())

    # Search with the rewritten query, not the user's phrasing
    search_query = rewrite_query(original_question)
    payload = collection.query(query_texts=[search_query], n_results=top_k)

    id_lists = payload["ids"]
    doc_lists = payload["documents"]
    top_ids = id_lists[0] if id_lists else []
    top_docs = doc_lists[0] if doc_lists else []

    found = len(relevant.intersection(top_ids))
    recall = found / len(relevant) if relevant else 0.0

    # Answer and judge against the original question
    answer = generate_answer(original_question, top_docs)
    score, reasoning = score_answer(original_question, top_docs, answer)

    return example_index, {
        "query_id": qid,
        "question": original_question,
        "rewritten_query": search_query,
        "generated_answer": answer,
        "score": score,
        "scoring_reasoning": reasoning,
        "recall_at_k": recall,
        "gold_docs_found": found,
        "gold_docs_total": len(relevant),
    }


def run_evaluation(
    eval_queries: list[dict],
    collection: chromadb.Collection,
    qrels: dict[str, set[str]],
    top_k: int,
    max_workers: int = 8,
) -> dict:
    """
    Run evaluation on all queries using parallel workers and aggregate results.

    Examples that raise are reported and skipped; the count is returned as
    "num_examples_failed".

    Args:
        eval_queries: Query dicts with _id and text
        collection: ChromaDB collection to retrieve from
        qrels: query_id -> set of relevant doc_ids
        top_k: Number of documents retrieved per query
        max_workers: Thread pool size

    Returns:
        Summary dict. "individual_scores", "individual_recalls" and
        "detailed_results" all follow eval_queries order (failed examples
        omitted), so index i of each list refers to the same query.
    """
    eval_size = len(eval_queries)

    print(f"Running evaluation on {eval_size} queries with "
          f"{max_workers} parallel workers...\n")

    results_by_index: dict[int, dict] = {}
    failed: list[int] = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_idx = {
            executor.submit(
                evaluate_single_example, query, i, collection, qrels, top_k
            ): i
            for i, query in enumerate(eval_queries)
        }

        # Count every finished future (success or failure) so progress is accurate
        processed = 0
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            processed += 1
            try:
                example_idx, result = future.result()
            except Exception as e:
                # One failed API call shouldn't abort the whole run; record it
                # so the summary shows how many examples are missing.
                failed.append(idx)
                print(f"[Error] Example {idx + 1} failed: {e}\n")
                continue

            results_by_index[example_idx] = result

            # Print progress
            print(
                f"[{processed}/{eval_size}] Example {example_idx + 1}")
            print(f"  Question:  {result['question'][:80]}...")
            print(f"  Rewritten: {result['rewritten_query'][:80]}")
            print(f"  Recall@{top_k}: {result['recall_at_k']:.2f} "
                  f"({result['gold_docs_found']}/{result['gold_docs_total']} gold)")
            print(f"  Score: {result['score']}/100")
            print(f"  Reasoning: {result['scoring_reasoning']}\n")

    # Futures complete in arbitrary order. Derive the per-example lists from
    # the index-ordered results so they line up with detailed_results.
    results = [results_by_index[i] for i in sorted(results_by_index)]
    scores = [r["score"] for r in results]
    recalls = [r["recall_at_k"] for r in results]

    avg_score = sum(scores) / len(scores) if scores else 0
    avg_recall = sum(recalls) / len(recalls) if recalls else 0

    return {
        "model_evaluated": EVAL_MODEL,
        "scoring_model": SCORING_MODEL,
        "rewrite_model": REWRITE_MODEL,
        "dataset": "BeIR/NQ",
        "corpus_size": collection.count(),
        "top_k": top_k,
        "num_examples_evaluated": len(scores),
        "num_examples_failed": len(failed),
        "overall_score": round(avg_score, 2),
        "avg_recall_at_k": round(avg_recall, 4),
        "individual_scores": scores,
        "individual_recalls": recalls,
        "score_distribution": {
            "90-100": sum(1 for s in scores if s >= 90),
            "80-89": sum(1 for s in scores if 80 <= s < 90),
            "70-79": sum(1 for s in scores if 70 <= s < 80),
            "60-69": sum(1 for s in scores if 60 <= s < 70),
            "below-60": sum(1 for s in scores if s < 60),
        },
        "detailed_results": results,
    }

In [24]:
# Parallel rewrite -> retrieve -> answer -> score over every eval query
results = run_evaluation(eval_queries, collection, qrels, top_k=TOP_K)

Running evaluation on 50 queries with 8 parallel workers...

[1/50] Example 3
  Question:  who sings love will keep us alive by the eagles...
  Rewritten: Timothy B. Schmit lead vocals Love Will Keep Us Alive Eagles Hell Freezes Over b
  Recall@5: 1.00 (1/1 gold)
  Score: 100/100
  Reasoning: The answer is perfectly accurate, complete, and faithful to the provided documents. Document 1 explicitly states that the song features lead vocals by bassist Timothy B. Schmit. The answer is also clear and concise.

[2/50] Example 4
  Question:  who is the leader of the ontario pc party...
  Rewritten: Doug Ford leader Ontario Progressive Conservative Party PC Party of Ontario MPP 
  Recall@5: 1.00 (2/2 gold)
  Score: 100/100
  Reasoning: The answer is entirely accurate based on the provided documents. Documents 1 and 2 explicitly state that Patrick Brown is the leader of the Progressive Conservative Party of Ontario. The answer is clear, faithful to the text, and complete.

[3/50] Example 2
  Qu

## Results Summary


In [12]:
# Report the run configuration and aggregate metrics stored in `results`.
print("=" * 80)
print("EVALUATION SUMMARY")
print("=" * 80)
print(f"Model evaluated:  {results['model_evaluated']}")
print(f"Scoring model:    {results['scoring_model']}")
print(f"Rewrite model:    {results['rewrite_model']}")
print(f"Dataset:          {results['dataset']}")
print(f"Corpus size:      {results['corpus_size']:,} documents")
print(f"Top-K:            {results['top_k']}")
print(f"Queries evaluated: {results['num_examples_evaluated']}")
# Label recall with the K this run actually used (stored in results) rather
# than the global TOP_K, which may have been changed after the run.
print(f"\nAvg Recall@{results['top_k']}:    {results['avg_recall_at_k']:.4f}")
print(f"Overall Score:    {results['overall_score']}/100")
print("\nScore Distribution:")
for range_label, count in results['score_distribution'].items():
    print(f"  {range_label}: {count} examples")

EVALUATION SUMMARY
Model evaluated:  moonshotai/kimi-k2.5
Scoring model:    google/gemini-3-flash-preview
Rewrite model:    google/gemini-3-flash-preview
Dataset:          BeIR/NQ
Corpus size:      10,000 documents
Top-K:            5
Queries evaluated: 50

Avg Recall@5:    0.9800
Overall Score:    96.0/100

Score Distribution:
  90-100: 48 examples
  80-89: 0 examples
  70-79: 0 examples
  60-69: 0 examples
  below-60: 2 examples


## Save Results


In [13]:
# Persist the complete results dict (summary metrics + per-query details)
# as pretty-printed JSON under ./evals.
results_dir = "evals"
eval_path = f"{results_dir}/1.5_large_corpus_rag_with_query_rewriting.json"
os.makedirs(results_dir, exist_ok=True)
with open(eval_path, "w") as out_file:
    json.dump(results, out_file, indent=2)
print(f"Results saved to {eval_path}")

Results saved to evals/1.5_large_corpus_rag_with_query_rewriting.json
