# Text Retrieval and Search Engines - Assignment 3
**Gal Noy** · 209346486

## 1. Setup & Dependencies

In [42]:
import pandas as pd
import json
import re
import string
from collections import Counter
from typing import List, Dict, Tuple, Optional, Literal
from tqdm.auto import tqdm
tqdm.pandas()
import warnings
warnings.filterwarnings('ignore')

print("✓ Dependencies imported")

✓ Dependencies imported


### Install Required Packages

In [43]:
# !apt-get update
# !apt-get install -y openjdk-21-jdk
# !update-alternatives --install /usr/bin/java java /usr/lib/jvm/java-21-openjdk-amd64/bin/java 1
# !update-alternatives --install /usr/bin/javac javac /usr/lib/jvm/java-21-openjdk-amd64/bin/javac 1
# !update-alternatives --set java /usr/lib/jvm/java-21-openjdk-amd64/bin/java
# !update-alternatives --set javac /usr/lib/jvm/java-21-openjdk-amd64/bin/javac

In [44]:
import os

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]

!java -version

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


openjdk version "21.0.9" 2025-10-21
OpenJDK Runtime Environment (build 21.0.9+10-Ubuntu-122.04)
OpenJDK 64-Bit Server VM (build 21.0.9+10-Ubuntu-122.04, mixed mode, sharing)


In [45]:
# !pip install torch torchvision torchaudio
# !pip install pyserini==0.36.0
# !pip install accelerate
# !pip install transformers
# !pip install tqdm
# !pip install python-dotenv

### Hugging Face Authentication

In [46]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

load_dotenv()

login(os.getenv('HUGGING_FACE_TOKEN'))
print("✓ Logged into Hugging Face")

✓ Logged into Hugging Face


## 2. Data Loading

In [47]:
df_train = pd.read_csv("./data/train.csv", converters={"answers": json.loads})
df_test = pd.read_csv("./data/test.csv")

print(f"Train set: {len(df_train)} questions")
print(f"Test set: {len(df_test)} questions")
print(f"\nSample question: {df_train.iloc[0]['question']}")
print(f"Sample answers: {df_train.iloc[0]['answers']}")

Train set: 3778 questions
Test set: 2032 questions

Sample question: what is the name of justin bieber brother?
Sample answers: ['Jazmyn Bieber', 'Jaxon Bieber']


## 3. Retrieval Functions

### Pyserini index

In [48]:
from pyserini.search import SimpleSearcher
from pyserini.index.lucene import IndexReader

print("Loading Pyserini index...")
searcher = SimpleSearcher.from_prebuilt_index('wikipedia-kilt-doc')
index_reader = IndexReader.from_prebuilt_index('wikipedia-kilt-doc')

print(f"✓ Index loaded: {index_reader.stats()['documents']} documents")

Loading Pyserini index...
SimpleSearcher class has been deprecated, please use LuceneSearcher from pyserini.search.lucene instead
✓ Index loaded: 5903530 documents


### Cross-encoder reranker

In [49]:
from sentence_transformers import CrossEncoder
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading cross-encoder reranker...")
reranker = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
    model_kwargs={"torch_dtype": torch.float16},
    device=device,
    max_length=512,
)
print("✓ Cross-encoder loaded")

Loading cross-encoder reranker...
✓ Cross-encoder loaded


### Retrieval manager

In [50]:
from dataclasses import dataclass
from typing import List, Dict
from functools import lru_cache
import json

@lru_cache(maxsize=1000)
def get_doc_content(docid: str) -> str:
    """Return cached raw document text."""
    try:
        doc = searcher.doc(docid)
        return json.loads(doc.raw()).get("contents", "").replace("\n", " ")
    except Exception:
        return ""
    

@dataclass
class RetrievalManager:
    """
    Hybrid retrieval with:
      - BM25 document retrieval
      - QLD document retrieval
      - RRF on documents (recall stage)
      - Passage segmentation
      - Cross-encoder passage reranking (precision stage)
    """

    # Retrieval
    k_docs: int = 100
    k_passages: int = 7
    rrf_k: int = 60

    # BM25 / QLD
    mu: int = 1000
    k1: float = 0.9
    b: float = 0.4

    # Passage extraction
    window: int = 150
    overlap: int = 50
    min_passage_words: int = 30

    def __str__(self):
        return (
            f"RetrievalManager("
            f"k_docs={self.k_docs}, k_passages={self.k_passages}, "
            f"RRF_k={self.rrf_k}, "
            f"window={self.window}, overlap={self.overlap})"
        )

    def extract_passages(self, text: str) -> List[str]:
        """Split document text into overlapping word windows."""
        if not text:
            return []

        words = text.split()
        if len(words) < self.min_passage_words:
            return []

        step = max(1, self.window - self.overlap)
        passages = []

        for i in range(0, len(words), step):
            chunk = words[i:i + self.window]
            if len(chunk) < self.min_passage_words:
                break
            passages.append(" ".join(chunk))

        return passages

    def rerank_passages(self, query: str, passages: List[str]) -> List[str]:
        """Rerank passages using cross-encoder (ordering only)."""
        if not passages:
            return []

        pairs = [(query, p) for p in passages]
        
        scores = reranker.predict(
            pairs,
            batch_size=16,
            show_progress_bar=False,
        )
        
        ranked = sorted(
            zip(passages, scores),
            key=lambda x: x[1],
            reverse=True,
        )

        return [p for p, _ in ranked[:self.k_passages]]
    
    def retrieve_context(self, query: str) -> List[str]:
        """Return top answer-bearing passages."""

        # Lexical retrieval
        searcher.set_bm25(self.k1, self.b)
        bm25_docids = [h.docid for h in searcher.search(query, self.k_docs)]

        searcher.set_qld(self.mu)
        qld_docids = [h.docid for h in searcher.search(query, self.k_docs)]

        # RRF on documents
        doc_scores: Dict[str, float] = {}

        for rank, docid in enumerate(bm25_docids):
            doc_scores[docid] = doc_scores.get(docid, 0.0) + 1.0 / (self.rrf_k + rank + 1)

        for rank, docid in enumerate(qld_docids):
            doc_scores[docid] = doc_scores.get(docid, 0.0) + 1.0 / (self.rrf_k + rank + 1)

        ranked_docids = sorted(
            doc_scores,
            key=doc_scores.get,
            reverse=True,
        )

        # Passage extraction
        passages: List[str] = []

        for docid in ranked_docids:
            content = get_doc_content(docid)
            if not content:
                continue

            passages.extend(self.extract_passages(content))

            # implicit cap: enough passages for reranking
            if len(passages) >= self.k_docs * 5:
                break

        # Cross-encoder reranking
        return self.rerank_passages(query, passages)
    

query = "Who wrote Harry Potter?"

rm = RetrievalManager()
print(rm)

passages = rm.retrieve_context(query)
for i, p in enumerate(passages, 1):
    print(f"{i}. {p[:120]}...")


RetrievalManager(k_docs=100, k_passages=7, RRF_k=60, window=150, overlap=50)
1. Harry Potter Harry Potter is a series of fantasy novels written by British author J. K. Rowling. The novels chronicle th...
2. J. K. Rowling Joanne Rowling ( "rolling"; born 31 July 1965), better known by her pen name J. K. Rowling, is a British n...
3. The Magical Worlds of Harry Potter The Magical Worlds of Harry Potter: A Treasury of Myths, Legends, and Fascinating Fac...
4. English by two major publishers, Bloomsbury in the United Kingdom and Scholastic Press in the United States. A play, "Ha...
5. the lives of the surviving characters and the effects of Voldemort's death on the Wizarding World. In the epilogue, Harr...
6. Fry reading the UK editions and Jim Dale voicing the series for the American editions. Section::::Adaptations.:Stage pro...
7. Harry Potter influences and analogues Writer J. K. Rowling cites several writers as influences in her creation of her be...


## 4. LLM Generation

### Load LLM model

In [51]:
import transformers
import torch
import logging

# Suppress transformers warnings
transformers.logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)

print("Loading LLM model...")
model_id = "meta-llama/Llama-3.2-1B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.float16},
    device=0
)

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Set pad_token for batch processing
pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token

print(f"✓ Model loaded on: {'GPU' if torch.cuda.is_available() else 'CPU'}")

Loading LLM model...
✓ Model loaded on: GPU


### Prompt manager

In [52]:
SYSTEM_PROMPT = (
    "You are a strict, grounded Question Answering system.\n"
    "You are given documents and a question.\n"
    "Answer ONLY using information that appears in the documents.\n"
    "Your answer must be ONLY the entity or value that answers the question.\n"
    "Do NOT return sentences, clauses, or descriptions.\n"
    "If the answer cannot be verified in the documents, return: unknown."
)


USER_PROMPT = (
    "Documents:\n"
    "{context}\n\n"
    "Task:\n"
    "- Answer the question using only the documents.\n"
    "- Return ONE single answer only (not a list or multiple items).\n"
    "- The answer must match the question type (person, place, date, number).\n"
    "- Return the shortest complete answer that answers the question.\n"
    "- The answer must appear verbatim or as a clear entity in the documents.\n"
    "- Do NOT return explanations, relations, or full sentences.\n"
    "- If you cannot verify the answer in the documents, output: unknown.\n\n"
    "Question: {question}\n"
    "Answer:"
)


@dataclass
class PromptManager:
    """Manages prompt generation and LLM answer generation."""
    system_prompt: str = SYSTEM_PROMPT
    user_prompt: str = USER_PROMPT
    temperature: float = 0.0
    top_p: float = 1.0
    max_new_tokens: int = 64
    do_sample: bool = False
    
    def __str__(self):
        return f"temp={self.temperature}, top_p={self.top_p}, max_tokens={self.max_new_tokens}"

    @staticmethod
    def clean_answer(answer: str) -> str:
        """Clean and standardize the generated answer."""
        answer = re.sub(r'^(Answer|The answer is|Based on the .*?,):?\s*', '', answer, flags=re.I)
        answer = answer.rstrip('.')
        if any(phrase in answer.lower() for phrase in ["dont know", "don't know", "do not know", "unknown"]):
            return "unknown"
        return answer.strip()

    def create_messages(self, question: str, contexts: List[str]) -> List[Dict]:
        """Create messages for the LLM based on the question and contexts."""
        if not contexts:
            context_str = "No relevant documents found."
        else:
            context_str = '\n\n'.join([f"Document {i+1}: {ctx}" for i, ctx in enumerate(contexts)])
        
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.user_prompt.format(context=context_str, question=question)}
        ]

    def generate_answer(self, question: str, contexts: List[str]) -> str:
        """Generate an answer using the LLM based on the question and contexts."""
        messages = self.create_messages(question, contexts)
        
        outputs = pipeline(
            messages,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=terminators,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p
        )
        
        answer = outputs[0]["generated_text"][-1].get('content', '')
        return self.clean_answer(answer)


test_prompt_manager = PromptManager()
print(f"Testing: {test_prompt_manager}")
test_answer = test_prompt_manager.generate_answer(query, passages)
print(f"✓ Generated answer: '{test_answer}'")

Testing: temp=0.0, top_p=1.0, max_tokens=64
✓ Generated answer: 'J. K. Rowling'


## 5. Evaluation Metrics

In [53]:
def normalize_answer(s: str) -> str:
    """Normalize answer for comparison"""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    
    def white_space_fix(text):
        return ' '.join(text.split())
    
    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))
    
    def lower(text):
        return text.lower()
    
    return white_space_fix(remove_articles(remove_punc(lower(s))))


def compute_token_metrics(prediction: str, ground_truth: str) -> Tuple[float, float, float]:
    """
    Compute precision, recall, and F1 score for token-level comparison.
    Returns: (precision, recall, f1)
    """
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    
    # Handle empty cases
    if len(pred_tokens) == 0 or len(gt_tokens) == 0:
        match = int(pred_tokens == gt_tokens)
        return match, match, match
    
    # Compute overlap
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    
    if num_same == 0:
        return 0.0, 0.0, 0.0
    
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    
    return precision, recall, f1


def evaluate_predictions(df_gold: pd.DataFrame, predictions: Dict[int, str]) -> Dict:
    """Evaluate predictions against ground truth."""
    f1_scores = []
    precision_scores = []
    recall_scores = []
    exact_matches = []
    
    for _, row in df_gold.iterrows():
        qid = row['id']
        
        # Handle missing predictions
        if qid not in predictions:
            f1_scores.append(0.0)
            precision_scores.append(0.0)
            recall_scores.append(0.0)
            exact_matches.append(0)
            continue
        
        prediction = predictions[qid]
        ground_truths = row['answers']
        
        # Normalize once
        norm_prediction = normalize_answer(prediction)
        
        # Find best match across all ground truths
        best_f1 = 0.0
        best_precision = 0.0
        best_recall = 0.0
        is_exact = 0
        
        for gt in ground_truths:
            norm_gt = normalize_answer(gt)
            
            # Compute metrics
            prec, rec, f1 = compute_token_metrics(prediction, gt)
            
            # Track best scores
            if f1 > best_f1:
                best_f1 = f1
                best_precision = prec
                best_recall = rec
            
            # Check exact match
            if norm_prediction == norm_gt:
                is_exact = 1
        
        f1_scores.append(best_f1)
        precision_scores.append(best_precision)
        recall_scores.append(best_recall)
        exact_matches.append(is_exact)
    
    return {
        'f1': 100.0 * sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
        'precision': 100.0 * sum(precision_scores) / len(precision_scores) if precision_scores else 0.0,
        'recall': 100.0 * sum(recall_scores) / len(recall_scores) if recall_scores else 0.0,
        'exact_match': 100.0 * sum(exact_matches) / len(exact_matches) if exact_matches else 0.0,
        'f1_scores': f1_scores,
        'precision_scores': precision_scores,
        'recall_scores': recall_scores,
        'exact_matches': exact_matches
    }


# Test evaluation
test_predictions = {1: "J.K. Rowling", 2: "Paris", 3: "Shakespeare"}
test_gold = pd.DataFrame({
    'id': [1, 2, 3],
    'answers': [["J.K. Rowling", "Rowling"], ["Earth"], ["William Shakespeare", "Shakespeare"]]
})

test_metrics = evaluate_predictions(test_gold, test_predictions)
print(f"✓ Evaluation test: F1={test_metrics['f1']:.2f}, P={test_metrics['precision']:.2f}, R={test_metrics['recall']:.2f}, EM={test_metrics['exact_match']:.2f}")

✓ Evaluation test: F1=66.67, P=66.67, R=66.67, EM=66.67


## 6. Experiment Framework

In [54]:
def run_experiment(
    name: str,
    df_data: pd.DataFrame,
    retrieval_manager: RetrievalManager,
    prompt_manager: PromptManager,
    max_questions: Optional[int] = None,
    verbose: bool = True
) -> Dict:
    """
    Run a full experiment: retrieval + prompting + evaluation.
    Args:
        name: Name of the experiment.
        df_data: DataFrame with questions and answers.
        retrieval_manager: RetrievalManager instance.
        prompt_manager: PromptManager instance.
        max_questions: Optional limit on number of questions to process.
        verbose: Whether to print progress and results.
    """
    if max_questions:
        df_data = df_data.head(max_questions)

    predictions = {}

    iterator = tqdm(df_data.iterrows(), total=len(df_data), desc=name) if verbose else df_data.iterrows()

    for _, row in iterator:
        question = row['question']
        qid = row['id']

        contexts = retrieval_manager.retrieve_context(question)
        answer = prompt_manager.generate_answer(question, contexts)

        predictions[qid] = answer

    metrics = evaluate_predictions(df_data, predictions)

    result = {
        'name': name,
        'retrieval': retrieval_manager,
        'prompt': prompt_manager,
        'f1_score': metrics['f1'],
        'precision': metrics['precision'],
        'recall': metrics['recall'],
        'exact_match': metrics['exact_match'],
        'num_questions': len(df_data),
        'predictions': predictions,
        'f1_scores': metrics['f1_scores'],
        'precision_scores': metrics['precision_scores'],
        'recall_scores': metrics['recall_scores'],
        'exact_matches': metrics['exact_matches']
    }

    if verbose:
        print(f"\n{name}")
        print(f"   Retrieval: {retrieval_manager}")
        print(f"   Prompt: {prompt_manager}")
        print(
            f"   F1={metrics['f1']:.2f} | "
            f"P={metrics['precision']:.2f} | "
            f"R={metrics['recall']:.2f} | "
            f"EM={metrics['exact_match']:.2f}"
        )
        print(f"   Questions: {len(df_data)}\n")

    return result


test_retrieval = RetrievalManager()
test_prompt = PromptManager()
print(f"Testing experiment with:")
print(f"  Retrieval: {test_retrieval}")
print(f"  Prompt: {test_prompt}")

test_exp = run_experiment(
    "Quick Test",
    df_train.head(25),
    test_retrieval,
    test_prompt,
    verbose=True
)

print(f"✓ Experiment framework ready")

Testing experiment with:
  Retrieval: RetrievalManager(k_docs=100, k_passages=7, RRF_k=60, window=150, overlap=50)
  Prompt: temp=0.0, top_p=1.0, max_tokens=64


Quick Test: 100%|██████████| 25/25 [01:57<00:00,  4.70s/it]


Quick Test
   Retrieval: RetrievalManager(k_docs=100, k_passages=7, RRF_k=60, window=150, overlap=50)
   Prompt: temp=0.0, top_p=1.0, max_tokens=64
   F1=30.33 | P=36.00 | R=27.73 | EM=20.00
   Questions: 25

✓ Experiment framework ready





## 7. Experiments

### Experiments global config

In [55]:
EXPERIMENT_SEED = 42
EXPERIMENTS_NUM_QUESTIONS = 100
DEFAULT_PROMPT_MANAGER = PromptManager()

### Experiments utils

In [56]:
def get_experiment_log_path(num_questions: int) -> str:
    """Get path for experiment log CSV based on number of questions."""
    return f"./results/grid_search_results_q{num_questions}.csv"


def generate_config_key(
    retrieval_mgr: RetrievalManager,
    prompt_mgr: PromptManager,
) -> str:
    """Generate unique config key for RRF-based retrieval."""
    return (
        f"RRF_k{retrieval_mgr.rrf_k}_"
        f"mu{retrieval_mgr.mu}_"
        f"k1{retrieval_mgr.k1}_b{retrieval_mgr.b}_"
        f"kdocs{retrieval_mgr.k_docs}_"
        f"kpass{retrieval_mgr.k_passages}_"
        f"win{retrieval_mgr.window}_ovl{retrieval_mgr.overlap}"
    )
    
    
def build_retrieval_manager(base: dict, override: dict) -> RetrievalManager:
    """Build RetrievalManager safely."""
    return RetrievalManager(**{**base, **override})


def save_results_to_csv(result: dict, key: str, path: str):
    """Save experiment results to a CSV file."""
    os.makedirs(os.path.dirname(path), exist_ok=True)

    row = {
        "config_key": key,
        "f1": result["f1_score"],
        "precision": result["precision"],
        "recall": result["recall"],
        "exact_match": result["exact_match"],
        "num_questions": result["num_questions"],
    }

    df = pd.DataFrame([row])
    if not os.path.exists(path):
        df.to_csv(path, index=False)
    else:
        df.to_csv(path, mode="a", header=False, index=False)


def load_completed_configs(path: str) -> set[str]:
    """Load set of completed experiment config keys from CSV."""
    if not os.path.exists(path):
        return set()
    return set(pd.read_csv(path)["config_key"])


### Print utils

In [57]:
def print_grid_results_table(grid: list[dict], *, num_questions: int):
    """Print results table for a given experiment grid."""
    pd.set_option("display.max_colwidth", None)

    path = get_experiment_log_path(num_questions)
    if not os.path.exists(path):
        print("No results file found.")
        return

    df = pd.read_csv(path)

    keys = [
        generate_config_key(g["retrieval_mgr"], g["prompt_mgr"])
        for g in grid
    ]

    grid_df = (
        df[df["config_key"].isin(keys)]
        .sort_values("f1", ascending=False)
        .reset_index(drop=True)
    )

    if grid_df.empty:
        print("No completed configs found for this grid.")
        return

    display_cols = [
        "config_key",
        "f1",
        "precision",
        "recall",
        "exact_match",
    ]

    print("\nGrid results (sorted by F1):")
    display(grid_df[display_cols])


### Best-config selector

In [58]:
def select_top_k_configs(
    retrieval_managers: list[RetrievalManager],
    prompt_managers: list[PromptManager],
    *,
    num_questions: int,
    top_k: int = 5,
):
    """Select top-k configurations from experiment results."""
    path = get_experiment_log_path(num_questions)
    df = pd.read_csv(path)

    scored_entries = []

    for r_mgr, p_mgr in zip(retrieval_managers, prompt_managers):
        key = generate_config_key(r_mgr, p_mgr)
        row = df[df["config_key"] == key]
        if row.empty:
            continue

        scored_entries.append({
            "retrieval_mgr": r_mgr,
            "prompt_mgr": p_mgr,
            "f1": float(row.iloc[0]["f1"]),
            "config_key": key,
        })

    scored_entries.sort(
        key=lambda x: (x["f1"], x["config_key"]),
        reverse=True,
    )

    return scored_entries[:top_k]


### Phase runner

In [59]:
def run_phase(
    *,
    phase_name: str,
    grid: list[dict],
    df_train: pd.DataFrame,
    num_questions: int,
    top_k: int | None = None,
):
    """
    Run a phase of experiments over a grid of configurations.
    Args:
        phase_name: Name of the experiment phase.
        grid: List of configurations (dicts with 'retrieval_mgr' and 'prompt_mgr').
        df_train: DataFrame with training questions and answers.
        num_questions: Number of questions to use per configuration.
        top_k: If specified, return only the top-k configurations after running.
    """
    print("\n" + "=" * 80)
    print(phase_name)
    print(f"Questions per config: {num_questions}")
    print("=" * 80)

    path = get_experiment_log_path(num_questions)

    validation_data = df_train.sample(
        n=num_questions,
        random_state=EXPERIMENT_SEED,
    ).reset_index(drop=True)

    completed = load_completed_configs(path)

    pending = [
        g for g in grid
        if generate_config_key(g["retrieval_mgr"], g["prompt_mgr"]) not in completed
    ]

    print(f"Total configs: {len(grid)}")
    print(f"Completed: {len(grid) - len(pending)}")
    print(f"Pending: {len(pending)}")
    print("-" * 80)

    for i, entry in enumerate(pending, start=1):
        r_mgr = entry["retrieval_mgr"]
        p_mgr = entry["prompt_mgr"]
        key = generate_config_key(r_mgr, p_mgr)

        print(f"[{i}/{len(pending)}] Running: {key}")

        result = run_experiment(
            name=key,
            df_data=validation_data,
            retrieval_manager=r_mgr,
            prompt_manager=p_mgr,
            verbose=True,
        )

        save_results_to_csv(result, key, path)
        print(f"✓ F1={result['f1_score']:.4f}")

    print_grid_results_table(grid, num_questions=num_questions)

    if top_k is None:
        return grid

    return select_top_k_configs(
        [g["retrieval_mgr"] for g in grid],
        [g["prompt_mgr"] for g in grid],
        num_questions=num_questions,
        top_k=top_k,
    )


### Phase 1 - Capacity (k_docs / k_passages)

In [60]:
PHASE_1_GRID = []

BASE_RETRIEVAL_PARAMS = {
    "window": 150,
    "overlap": 50,
    "mu": 1000,
    "k1": 0.9,
    "b": 0.4,
}

K_DOCS = [10, 50, 100, 250, 500]
K_PASSAGES = [5, 7, 10]

CAPACITY_PAIRS = [
    (k_docs, k_passages)
    for k_docs in K_DOCS
    for k_passages in K_PASSAGES
]

for k_docs, k_passages in CAPACITY_PAIRS:
    PHASE_1_GRID.append({
        "retrieval_mgr": RetrievalManager(
            k_docs=k_docs,
            k_passages=k_passages,
            **BASE_RETRIEVAL_PARAMS,
        ),
        "prompt_mgr": DEFAULT_PROMPT_MANAGER,
    })

PHASE_1_TOP_CONFIGS = run_phase(
    phase_name="PHASE 1 — Retrieval Capacity",
    grid=PHASE_1_GRID,
    df_train=df_train,
    num_questions=EXPERIMENTS_NUM_QUESTIONS,
    top_k=3,
)



PHASE 1 — Retrieval Capacity
Questions per config: 100
Total configs: 15
Completed: 15
Pending: 0
--------------------------------------------------------------------------------

Grid results (sorted by F1):


Unnamed: 0,config_key,f1,precision,recall,exact_match
0,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass5_win150_ovl50,34.382749,38.434524,34.1,22.0
1,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win150_ovl50,33.289076,36.833333,32.933333,21.0
2,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win150_ovl50,32.163701,34.758766,32.933333,20.0
3,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass7_win150_ovl50,31.639891,35.047727,32.6,17.0
4,RRF_k60_mu1000_k10.9_b0.4_kdocs50_kpass10_win150_ovl50,30.414461,34.841782,30.433333,17.0
5,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass7_win150_ovl50,28.314412,32.123214,29.516667,15.0
6,RRF_k60_mu1000_k10.9_b0.4_kdocs100_kpass10_win150_ovl50,27.426702,31.623462,28.85,14.0
7,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass10_win150_ovl50,26.462324,30.593452,26.742857,14.0
8,RRF_k60_mu1000_k10.9_b0.4_kdocs100_kpass7_win150_ovl50,26.452391,28.940591,28.35,14.0
9,RRF_k60_mu1000_k10.9_b0.4_kdocs100_kpass5_win150_ovl50,26.335129,29.825,26.85,14.0


### Phase 2 - Passage segmentation (window / overlap)

In [61]:
PHASE_2_GRID = []

WINDOW_OVERLAP_PAIRS = [
    (100, 30),
    (150, 50),
    (200, 50),
    (250, 60),
]

for entry in PHASE_1_TOP_CONFIGS:
    base = entry["retrieval_mgr"]
    for w, o in WINDOW_OVERLAP_PAIRS:
        PHASE_2_GRID.append({
            "retrieval_mgr": RetrievalManager(
                k_docs=base.k_docs,
                k_passages=base.k_passages,
                window=w,
                overlap=o,
                mu=base.mu,
                k1=base.k1,
                b=base.b,
            ),
            "prompt_mgr": DEFAULT_PROMPT_MANAGER,
        })

PHASE_2_TOP_CONFIGS = run_phase(
    phase_name="PHASE 2 — Passage Segmentation",
    grid=PHASE_2_GRID,
    df_train=df_train,
    num_questions=EXPERIMENTS_NUM_QUESTIONS,
    top_k=3,
)



PHASE 2 — Passage Segmentation
Questions per config: 100
Total configs: 12
Completed: 12
Pending: 0
--------------------------------------------------------------------------------

Grid results (sorted by F1):


Unnamed: 0,config_key,f1,precision,recall,exact_match
0,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass5_win150_ovl50,34.382749,38.434524,34.1,22.0
1,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win250_ovl60,33.762698,37.403846,33.35,22.0
2,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win200_ovl50,33.656606,35.1625,34.766667,23.0
3,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win200_ovl50,33.610317,36.775,33.266667,24.0
4,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win150_ovl50,33.289076,36.833333,32.933333,21.0
5,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass5_win200_ovl50,32.956606,34.829167,33.766667,22.0
6,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win150_ovl50,32.163701,34.758766,32.933333,20.0
7,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win250_ovl60,31.002525,33.622222,31.433333,19.0
8,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win100_ovl30,30.872619,33.45,32.266667,20.0
9,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass5_win250_ovl60,27.988889,30.034799,28.766667,18.0


### Phase 3 - BM25 / QLD parameters

In [62]:
PHASE_3_GRID = []

BM25_PARAMS = [
    {"k1": 0.6, "b": 0.3},
    {"k1": 0.9, "b": 0.4},
    {"k1": 1.2, "b": 0.6},
]

QLD_PARAMS = [
    {"mu": 1000},
    {"mu": 2000},
]

for entry in PHASE_2_TOP_CONFIGS:
    base = entry["retrieval_mgr"]
    for bm25 in BM25_PARAMS:
        for qld in QLD_PARAMS:
            PHASE_3_GRID.append({
                "retrieval_mgr": RetrievalManager(
                    k_docs=base.k_docs,
                    k_passages=base.k_passages,
                    window=base.window,
                    overlap=base.overlap,
                    k1=bm25["k1"],
                    b=bm25["b"],
                    mu=qld["mu"],
                ),
                "prompt_mgr": DEFAULT_PROMPT_MANAGER,
            })

PHASE_3_TOP_CONFIGS = run_phase(
    phase_name="PHASE 3 — Lexical Hyperparameters",
    grid=PHASE_3_GRID,
    df_train=df_train,
    num_questions=EXPERIMENTS_NUM_QUESTIONS,
    top_k=None,
)



PHASE 3 — Lexical Hyperparameters
Questions per config: 100
Total configs: 18
Completed: 18
Pending: 0
--------------------------------------------------------------------------------

Grid results (sorted by F1):


Unnamed: 0,config_key,f1,precision,recall,exact_match
0,RRF_k60_mu1000_k11.2_b0.6_kdocs250_kpass5_win150_ovl50,36.616082,40.791667,36.766667,22.0
1,RRF_k60_mu2000_k11.2_b0.6_kdocs250_kpass5_win150_ovl50,34.949415,39.291667,34.766667,21.0
2,RRF_k60_mu2000_k11.2_b0.6_kdocs500_kpass5_win200_ovl50,34.847082,37.245833,34.766667,24.0
3,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass5_win150_ovl50,34.382749,38.434524,34.1,22.0
4,RRF_k60_mu1000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50,33.78994,35.829167,34.266667,23.0
5,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win250_ovl60,33.762698,37.403846,33.35,22.0
6,RRF_k60_mu1000_k11.2_b0.6_kdocs500_kpass5_win200_ovl50,33.680416,35.579167,34.266667,23.0
7,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win200_ovl50,33.656606,35.1625,34.766667,23.0
8,RRF_k60_mu2000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50,33.48994,35.329167,34.1,24.0
9,RRF_k60_mu1000_k11.2_b0.6_kdocs500_kpass5_win250_ovl60,33.096032,36.103846,33.6,22.0


### Phase 4 - Best configs comparison over more questions

In [63]:
PHASE_4_F1_RATIO = 0.90          # promote configs within 90% of best
PHASE_4_MAX = 10                 # safety cap
PHASE_4_NUM_QUESTIONS = 300

experiments_path = get_experiment_log_path(
    num_questions=EXPERIMENTS_NUM_QUESTIONS
)
experiments_df = pd.read_csv(experiments_path)

assert not experiments_df.empty, "No results found for experiments"

# Sort globally by F1
experiments_df = experiments_df.sort_values(
    "f1", ascending=False
).reset_index(drop=True)

best_f1 = experiments_df.iloc[0]["f1"]
threshold_f1 = best_f1 * PHASE_4_F1_RATIO

print(f"Best F1 ({EXPERIMENTS_NUM_QUESTIONS}q): {best_f1:.4f}")
print(f"Promotion threshold: {threshold_f1:.4f} "
      f"({PHASE_4_F1_RATIO:.0%} of best)")

# Promote configs close by ratio
promoted_df = (
    experiments_df[experiments_df["f1"] >= threshold_f1]
    .head(PHASE_4_MAX)
    .reset_index(drop=True)
)

print(f"\nPromoted configs: {len(promoted_df)}")
for i, row in promoted_df.iterrows():
    print(
        f"{i+1}. {row['config_key']} | "
        f"F1={row['f1']:.4f}"
    )


Best F1 (100q): 36.6161
Promotion threshold: 32.9545 (90% of best)

Promoted configs: 10
1. RRF_k60_mu1000_k11.2_b0.6_kdocs250_kpass5_win150_ovl50 | F1=36.6161
2. RRF_k60_mu2000_k11.2_b0.6_kdocs250_kpass5_win150_ovl50 | F1=34.9494
3. RRF_k60_mu2000_k11.2_b0.6_kdocs500_kpass5_win200_ovl50 | F1=34.8471
4. RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass5_win150_ovl50 | F1=34.3827
5. RRF_k60_mu1000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50 | F1=33.7899
6. RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win250_ovl60 | F1=33.7627
7. RRF_k60_mu1000_k11.2_b0.6_kdocs500_kpass5_win200_ovl50 | F1=33.6804
8. RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win200_ovl50 | F1=33.6566
9. RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win200_ovl50 | F1=33.6103
10. RRF_k60_mu2000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50 | F1=33.4899


In [64]:
# Collect all unique configs by config_key
ALL_CONFIGS = {}

for entry in PHASE_1_GRID + PHASE_2_GRID + PHASE_3_GRID:
    key = generate_config_key(
        entry["retrieval_mgr"],
        entry["prompt_mgr"],
    )
    ALL_CONFIGS[key] = entry   # dedupe by key

# Build Phase 4 grid directly from promoted config_keys
PHASE_4_GRID = [
    {
        "retrieval_mgr": ALL_CONFIGS[key]["retrieval_mgr"],
        "prompt_mgr": DEFAULT_PROMPT_MANAGER,
    }
    for key in promoted_df["config_key"].values
    if key in ALL_CONFIGS
]

print(f"\n✓ Phase 4 grid size: {len(PHASE_4_GRID)}")

PHASE_4_RESULTS = run_phase(
    phase_name="PHASE 4 — Stabilized Comparison",
    grid=PHASE_4_GRID,
    df_train=df_train,
    num_questions=PHASE_4_NUM_QUESTIONS,
    top_k=None,
)



✓ Phase 4 grid size: 10

PHASE 4 — Stabilized Comparison
Questions per config: 300
Total configs: 10
Completed: 10
Pending: 0
--------------------------------------------------------------------------------

Grid results (sorted by F1):


Unnamed: 0,config_key,f1,precision,recall,exact_match
0,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win200_ovl50,32.857235,35.653101,33.657828,21.666667
1,RRF_k60_mu1000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50,30.748316,32.406665,32.057828,20.0
2,RRF_k60_mu2000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50,30.641472,32.831548,32.168939,19.666667
3,RRF_k60_mu2000_k11.2_b0.6_kdocs500_kpass5_win200_ovl50,30.493208,32.539998,31.653066,19.333333
4,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win200_ovl50,30.266834,32.039998,31.668939,19.333333
5,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass5_win250_ovl60,30.024503,32.795068,30.991162,19.0
6,RRF_k60_mu1000_k11.2_b0.6_kdocs500_kpass5_win200_ovl50,29.979533,32.095554,31.224495,18.666667
7,RRF_k60_mu1000_k11.2_b0.6_kdocs250_kpass5_win150_ovl50,28.79106,31.315219,29.886111,17.666667
8,RRF_k60_mu2000_k11.2_b0.6_kdocs250_kpass5_win150_ovl50,28.410799,31.099875,29.497222,16.333333
9,RRF_k60_mu1000_k10.9_b0.4_kdocs250_kpass5_win150_ovl50,26.948318,29.322684,28.163889,16.0


### Phase 5 - Final comparison between 2 best configs, over 1000 questions

In [65]:
PHASE_5_NUM_QUESTIONS = 1000
PHASE_5_TOP_K = 3

phase_4_path = get_experiment_log_path(num_questions=PHASE_4_NUM_QUESTIONS)
df_phase_4 = pd.read_csv(phase_4_path)

assert not df_phase_4.empty, "No Phase 4 results found"

df_phase_4 = df_phase_4.sort_values("f1", ascending=False).reset_index(drop=True)

top_3_keys = df_phase_4.head(PHASE_5_TOP_K)["config_key"].tolist()

print("Top 3 configs from Phase 4:")
for i, row in df_phase_4.head(PHASE_5_TOP_K).iterrows():
    print(f"{i+1}. {row['config_key']} | F1={row['f1']:.4f}")


PHASE_5_GRID = [
    {
        "retrieval_mgr": e["retrieval_mgr"],
        "prompt_mgr": DEFAULT_PROMPT_MANAGER,
    }
    for e in PHASE_4_GRID
    if generate_config_key(e["retrieval_mgr"], e["prompt_mgr"]) in top_3_keys
]

print(f"\n✓ Phase 5 grid size: {len(PHASE_5_GRID)}")

PHASE_5_RESULTS = run_phase(
    phase_name="PHASE 5 — Final Tie-Breaker",
    grid=PHASE_5_GRID,
    df_train=df_train,
    num_questions=PHASE_5_NUM_QUESTIONS,
    top_k=None,
)


Top 3 configs from Phase 4:
1. RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win200_ovl50 | F1=32.8572
2. RRF_k60_mu1000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50 | F1=30.7483
3. RRF_k60_mu2000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50 | F1=30.6415

✓ Phase 5 grid size: 3

PHASE 5 — Final Tie-Breaker
Questions per config: 1000
Total configs: 3
Completed: 3
Pending: 0
--------------------------------------------------------------------------------

Grid results (sorted by F1):


Unnamed: 0,config_key,f1,precision,recall,exact_match
0,RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win200_ovl50,32.784387,35.456917,33.412422,21.9
1,RRF_k60_mu1000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50,31.558194,33.554958,32.770563,20.1
2,RRF_k60_mu2000_k10.6_b0.3_kdocs500_kpass5_win200_ovl50,31.431037,33.717853,32.668896,19.9


## Kaggle Submission

In [None]:
phase_5_path = get_experiment_log_path(num_questions=PHASE_5_NUM_QUESTIONS)
df_phase_5 = pd.read_csv(phase_5_path)

assert not df_phase_5.empty, "No Phase 5 results found"

df_phase_5 = df_phase_5.sort_values("f1", ascending=False).reset_index(drop=True)

best_row = df_phase_5.iloc[0]
BEST_CONFIG_KEY = best_row["config_key"]

print("Selected best config from Phase 5:")
print(f"{BEST_CONFIG_KEY} | F1={best_row['f1']:.4f}")

best_entry = next(
    e for e in PHASE_5_GRID
    if generate_config_key(e["retrieval_mgr"], e["prompt_mgr"]) == BEST_CONFIG_KEY
)

BEST_RETRIEVAL_MGR = best_entry["retrieval_mgr"]
BEST_PROMPT_MGR = best_entry["prompt_mgr"]

print("\nGenerating answers for test set...")

test_questions = df_test["question"].tolist()
test_ids = df_test["id"].tolist()

predictions = []

for question in tqdm(test_questions, total=len(test_questions), desc="Test Questions"):
    contexts = BEST_RETRIEVAL_MGR.retrieve_context(question)
    answer = BEST_PROMPT_MGR.generate_answer(question, contexts)
    predictions.append(answer)

submission_df = pd.DataFrame({
    "id": test_ids,
    "prediction": predictions,
})

# Required serialization for Kaggle evaluation
submission_df["prediction"] = submission_df["prediction"].apply(
    lambda x: json.dumps([x], ensure_ascii=False)
)

SUBMISSION_PATH = "./results/kaggle_submission.csv"
submission_df.to_csv(SUBMISSION_PATH, index=False)

print("=" * 80)
print(f"✓ Kaggle submission saved to: {SUBMISSION_PATH}")
print(f"✓ Total rows: {len(submission_df)}")
print("=" * 80)


Selected best config from Phase 5:
RRF_k60_mu1000_k10.9_b0.4_kdocs500_kpass10_win200_ovl50 | F1=32.7844

Generating answers for test set...


Test Questions:   100%|██████████| 2032/2032 [2:07:40<00:00,  3.77s/it]
✓ Kaggle submission saved to: ./results/kaggle_submission.csv
✓ Total rows: 2032