## 1. Setup & Dependencies

In [1]:
import pandas as pd
import json
import re
import string
from collections import Counter
from typing import List, Dict, Tuple, Optional, Literal
from dataclasses import dataclass
from tqdm.auto import tqdm
import itertools
tqdm.pandas()
import warnings
warnings.filterwarnings('ignore')

print("✓ Dependencies imported")

✓ Dependencies imported


  from .autonotebook import tqdm as notebook_tqdm


### Install Required Packages

In [2]:
# !pip install torch torchvision torchaudio
# !pip install pyserini==0.36.0
# !pip install accelerate
# !pip install transformers
# !pip install tqdm
# !pip install python-dotenv

In [3]:
import os

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]

!java -version

openjdk version "21.0.9" 2025-10-21
OpenJDK Runtime Environment (build 21.0.9+10-Ubuntu-122.04)
OpenJDK 64-Bit Server VM (build 21.0.9+10-Ubuntu-122.04, mixed mode, sharing)


In [4]:
# !pip install torch torchvision torchaudio
# !pip install pyserini==0.36.0
# !pip install accelerate
# !pip install transformers
# !pip install tqdm
# !pip install python-dotenv

### Hugging Face Authentication

In [5]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

load_dotenv()

login(os.getenv('HUGGING_FACE_TOKEN'))
print("✓ Logged into Hugging Face")

✓ Logged into Hugging Face


## 2. Data Loading & Preparation

In [6]:
# Load datasets
df_train = pd.read_csv("./data/train.csv", converters={"answers": json.loads})
df_test = pd.read_csv("./data/test.csv")

print(f"Train set: {len(df_train)} questions")
print(f"Test set: {len(df_test)} questions")
print(f"\nSample question: {df_train.iloc[0]['question']}")
print(f"Sample answers: {df_train.iloc[0]['answers']}")

Train set: 3778 questions
Test set: 2032 questions

Sample question: what is the name of justin bieber brother?
Sample answers: ['Jazmyn Bieber', 'Jaxon Bieber']


## 3. Retrieval Functions

In [7]:
from pyserini.search import SimpleSearcher
from pyserini.index.lucene import IndexReader

# Load Pyserini index
print("Loading Pyserini index...")
searcher = SimpleSearcher.from_prebuilt_index('wikipedia-kilt-doc')
index_reader = IndexReader.from_prebuilt_index('wikipedia-kilt-doc')

print(f"✓ Index loaded: {index_reader.stats()['documents']} documents")

[0;93m2025-12-14 20:29:53.657930216 [W:onnxruntime:Default, device_discovery.cc:164 DiscoverDevicesForPlatform] GPU device discovery failed: device_discovery.cc:89 ReadFileContents Failed to open file: "/sys/class/drm/card0/device/vendor"[m


Loading Pyserini index...


Dec 14, 2025 8:29:54 PM org.apache.lucene.store.MemorySegmentIndexInputProvider <init>
INFO: Using MemorySegmentIndexInput with Java 21; to disable start with -Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false


SimpleSearcher class has been deprecated, please use LuceneSearcher from pyserini.search.lucene instead
✓ Index loaded: 5903530 documents


In [8]:
from sentence_transformers import SentenceTransformer
import torch

# Load bi-encoder
print("Loading bi-encoder...")
device = "cuda" if torch.cuda.is_available() else "cpu"
bi_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

print("✓ Bi-encoder loaded")

Loading bi-encoder...
✓ Bi-encoder loaded


In [9]:
from dataclasses import dataclass
from typing import List, Dict
from functools import lru_cache
import json
import torch
from sentence_transformers import util


# --------------------------------------------------
# Cached document access
# --------------------------------------------------

@lru_cache(maxsize=1000)
def get_doc_content(docid: str) -> str:
    """Return cached raw document text."""
    try:
        doc = searcher.doc(docid)
        return json.loads(doc.raw()).get("contents", "").replace("\n", " ")
    except Exception:
        return ""


# --------------------------------------------------
# Retrieval Manager
# --------------------------------------------------

@dataclass
class RetrievalManager:
    """
    Hybrid retrieval with:
      - QLD document retrieval
      - BM25 document retrieval
      - Passage extraction
      - Dense passage reranking
      - 3-way RRF on passages
    """
    k_docs: int = 10
    k_passages: int = 5
    rrf_k: int = 60

    mu: int = 1000
    k1: float = 0.9
    b: float = 0.4

    window: int = 150
    overlap: int = 50
    min_passage_words: int = 30
    max_dense_passages: int = 100

    def __str__(self):
        return (
            f"Retrieval(RRF_k={self.rrf_k}, μ={self.mu}, "
            f"k1={self.k1}, b={self.b}) | "
            f"k_docs={self.k_docs}, k_passages={self.k_passages} | "
            f"window={self.window}, overlap={self.overlap}"
        )

    # --------------------------------------------------
    # Passage extraction
    # --------------------------------------------------

    def extract_passages(self, text: str) -> List[str]:
        """Split document text into overlapping word windows."""
        if not text:
            return []

        words = text.split()
        if len(words) < self.min_passage_words:
            return []

        step = max(1, self.window - self.overlap)
        passages = []

        for i in range(0, len(words), step):
            chunk = words[i:i + self.window]
            if len(chunk) < self.min_passage_words:
                break
            passages.append(" ".join(chunk))

        return passages

    # --------------------------------------------------
    # Dense ranking
    # --------------------------------------------------

    def dense_rank(self, query: str, passages: List[str]) -> List[str]:
        """Rank passages by bi-encoder cosine similarity."""
        if not passages:
            return []

        with torch.no_grad():
            q_emb = bi_encoder.encode(
                query,
                convert_to_tensor=True,
                device=device
            )
            p_embs = bi_encoder.encode(
                passages,
                convert_to_tensor=True,
                device=device
            )

            scores = util.cos_sim(q_emb, p_embs).squeeze(0)
            order = torch.argsort(scores, descending=True).tolist()

        return [passages[i] for i in order]

    # --------------------------------------------------
    # Main retrieval
    # --------------------------------------------------

    def retrieve_context(self, query: str) -> List[str]:
        """Return top passages using 3-way RRF (QLD, BM25, Dense)."""

        # ---------- Lexical retrieval ----------
        searcher.set_qld(self.mu)
        qld_docids = [h.docid for h in searcher.search(query, self.k_docs)]

        searcher.set_bm25(self.k1, self.b)
        bm25_docids = [h.docid for h in searcher.search(query, self.k_docs)]

        # ---------- Extract passages per document ----------
        doc_passages: Dict[str, List[str]] = {}
        all_passages: List[str] = []

        for docid in dict.fromkeys(qld_docids + bm25_docids):
            content = get_doc_content(docid)
            if not content:
                continue

            passages = self.extract_passages(content)
            if passages:
                doc_passages[docid] = passages
                all_passages.extend(passages)

        # Deduplicate passages
        all_passages = list(dict.fromkeys(all_passages))

        # ---------- Dense reranking (limited) ----------
        dense_candidates = all_passages[:self.max_dense_passages]
        dense_ranked = self.dense_rank(query, dense_candidates)

        # ---------- RRF scoring on passages ----------
        scores: Dict[str, float] = {}

        # QLD → passages
        for rank, docid in enumerate(qld_docids):
            for p in doc_passages.get(docid, []):
                scores[p] = scores.get(p, 0.0) + 1.0 / (self.rrf_k + rank + 1)

        # BM25 → passages
        for rank, docid in enumerate(bm25_docids):
            for p in doc_passages.get(docid, []):
                scores[p] = scores.get(p, 0.0) + 1.0 / (self.rrf_k + rank + 1)

        # Dense → passages
        for rank, p in enumerate(dense_ranked):
            scores[p] = scores.get(p, 0.0) + 1.0 / (self.rrf_k + rank + 1)

        # ---------- Final ranking ----------
        ranked_passages = sorted(
            scores,
            key=scores.get,
            reverse=True
        )[:self.k_passages]

        return ranked_passages


# --------------------------------------------------
# Quick sanity check
# --------------------------------------------------

query = "Who wrote Harry Potter?"

rm = RetrievalManager()
print(rm)

passages = rm.retrieve_context(query)
for i, p in enumerate(passages, 1):
    print(f"{i}. {p[:100]}...")


Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=10, k_passages=5 | window=150, overlap=50
1. Harry Potter Harry Potter is a series of fantasy novels written by British author J. K. Rowling. The...
2. nine to eleven. On the eve of publishing, Rowling was asked by her publishers to adopt a more gender...
3. time, justified"), while "The Guardian" called it "a richly textured novel given lift-off by an inve...
4. suddenly "fell into her head". Rowling gives an account of the experience on her website saying: Row...
5. 1999 Whitbread Awards. His overall view of the series was negative – "the Potter saga was essentiall...


## 4. LLM Generation

In [10]:
import transformers
import torch
import logging

# Suppress transformers warnings
transformers.logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)

print("Loading LLM model...")
model_id = "meta-llama/Llama-3.2-1B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.float16},
    device=0
)

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Set pad_token for batch processing
pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token

print(f"✓ Model loaded on: {'GPU' if torch.cuda.is_available() else 'CPU'}")

Loading LLM model...
✓ Model loaded on: GPU


In [None]:
SYSTEM_PROMPT = (
    "Answer using the provided documents."
    "Prefer short factual answers."
    "If the answer cannot be reasonably inferred from the documents, return 'unknown'."
    "Do not explain your reasoning."
)

USER_PROMPT = (
    "Answer the question using ONLY a short phrase or named entity found verbatim in the documents.\n"
    "If the answer is not explicitly stated, output: unknown.\n\n"
    "{context}\n\n"
    "Question: {question}\n"
    "Answer:"
)

@dataclass
class PromptManager:
    """Manages prompt generation and LLM answer generation."""
    system_prompt: str = SYSTEM_PROMPT
    user_prompt: str = USER_PROMPT
    temperature: float = 0.0
    top_p: float = 1.0
    max_new_tokens: int = 256
    do_sample: bool = False
    prompt_id: str = "default"  # For later use in prompt tuning
    
    def __str__(self):
        return f"temp={self.temperature}, top_p={self.top_p}, max_tokens={self.max_new_tokens}"

    @staticmethod
    def clean_answer(answer: str) -> str:
        """Clean and standardize the generated answer."""
        answer = re.sub(r'^(Answer|The answer is|Based on the .*?,):?\s*', '', answer, flags=re.I)
        answer = answer.rstrip('.')
        if any(phrase in answer.lower() for phrase in ["dont know", "don't know", "do not know", "unknown"]):
            return "unknown"
        return answer.strip()

    def create_messages(self, question: str, contexts: List[str]) -> List[Dict]:
        """Create messages for the LLM based on the question and contexts."""
        if not contexts:
            context_str = "No relevant documents found."
        else:
            context_str = '\n\n'.join([f"Document {i+1}: {ctx}" for i, ctx in enumerate(contexts)])
        
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.user_prompt.format(context=context_str, question=question)}
        ]

    def generate_answer(self, question: str, contexts: List[str]) -> str:
        """Generate an answer using the LLM based on the question and contexts."""
        messages = self.create_messages(question, contexts)
        
        outputs = pipeline(
            messages,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=terminators,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p
        )
        
        answer = outputs[0]["generated_text"][-1].get('content', '')
        return self.clean_answer(answer)

    def batch_generate_answers(self, questions: List[str], contexts_list: List[List[str]]) -> List[str]:
        """Generate answers for multiple questions in batch."""
        # Create messages for all questions
        batch_messages = [self.create_messages(q, ctx) for q, ctx in zip(questions, contexts_list)]
        
        # Process batch through pipeline
        outputs = pipeline(
            batch_messages,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=terminators,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p,
            batch_size=len(questions)
        )
        
        # Extract and clean answers
        answers = []
        for output in outputs:
            answer = output[0]["generated_text"][-1].get('content', '')
            answers.append(self.clean_answer(answer))
        
        return answers


# Test the PromptManager
test_prompt_manager = PromptManager()
print(f"Testing: {test_prompt_manager}")
test_answer = test_prompt_manager.generate_answer(query, passages)
print(f"✓ Generated answer: '{test_answer}'")

Testing: temp=0.0, top_p=1.0, max_tokens=256
✓ Generated answer: 'J. K. Rowling'


## 5. Evaluation Metrics

In [12]:
def normalize_answer(s: str) -> str:
    """Normalize answer for comparison"""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    
    def white_space_fix(text):
        return ' '.join(text.split())
    
    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))
    
    def lower(text):
        return text.lower()
    
    return white_space_fix(remove_articles(remove_punc(lower(s))))


def compute_token_metrics(prediction: str, ground_truth: str) -> Tuple[float, float, float]:
    """
    Compute precision, recall, and F1 score for token-level comparison.
    Returns: (precision, recall, f1)
    """
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    
    # Handle empty cases
    if len(pred_tokens) == 0 or len(gt_tokens) == 0:
        match = int(pred_tokens == gt_tokens)
        return match, match, match
    
    # Compute overlap
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    
    if num_same == 0:
        return 0.0, 0.0, 0.0
    
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    
    return precision, recall, f1


def evaluate_predictions(df_gold: pd.DataFrame, predictions: Dict[int, str]) -> Dict:
    """Evaluate predictions against ground truth."""
    f1_scores = []
    precision_scores = []
    recall_scores = []
    exact_matches = []
    
    for _, row in df_gold.iterrows():
        qid = row['id']
        
        # Handle missing predictions
        if qid not in predictions:
            f1_scores.append(0.0)
            precision_scores.append(0.0)
            recall_scores.append(0.0)
            exact_matches.append(0)
            continue
        
        prediction = predictions[qid]
        ground_truths = row['answers']
        
        # Normalize once
        norm_prediction = normalize_answer(prediction)
        
        # Find best match across all ground truths
        best_f1 = 0.0
        best_precision = 0.0
        best_recall = 0.0
        is_exact = 0
        
        for gt in ground_truths:
            norm_gt = normalize_answer(gt)
            
            # Compute metrics
            prec, rec, f1 = compute_token_metrics(prediction, gt)
            
            # Track best scores
            if f1 > best_f1:
                best_f1 = f1
                best_precision = prec
                best_recall = rec
            
            # Check exact match
            if norm_prediction == norm_gt:
                is_exact = 1
        
        f1_scores.append(best_f1)
        precision_scores.append(best_precision)
        recall_scores.append(best_recall)
        exact_matches.append(is_exact)
    
    return {
        'f1': 100.0 * sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
        'precision': 100.0 * sum(precision_scores) / len(precision_scores) if precision_scores else 0.0,
        'recall': 100.0 * sum(recall_scores) / len(recall_scores) if recall_scores else 0.0,
        'exact_match': 100.0 * sum(exact_matches) / len(exact_matches) if exact_matches else 0.0,
        'f1_scores': f1_scores,
        'precision_scores': precision_scores,
        'recall_scores': recall_scores,
        'exact_matches': exact_matches
    }


# Test evaluation
test_predictions = {1: "J.K. Rowling", 2: "Paris", 3: "Shakespeare"}
test_gold = pd.DataFrame({
    'id': [1, 2, 3],
    'answers': [["J.K. Rowling", "Rowling"], ["Earth"], ["William Shakespeare", "Shakespeare"]]
})

test_metrics = evaluate_predictions(test_gold, test_predictions)
print(f"✓ Evaluation test: F1={test_metrics['f1']:.2f}, P={test_metrics['precision']:.2f}, R={test_metrics['recall']:.2f}, EM={test_metrics['exact_match']:.2f}")

✓ Evaluation test: F1=66.67, P=66.67, R=66.67, EM=66.67


## 6. Experiment Framework

In [21]:
def run_experiment(
    name: str,
    df_data: pd.DataFrame,
    retrieval_manager: RetrievalManager,
    prompt_manager: PromptManager,
    max_questions: Optional[int] = None,
    verbose: bool = True
) -> Dict:
    if max_questions:
        df_data = df_data.head(max_questions)

    predictions = {}

    iterator = tqdm(df_data.iterrows(), total=len(df_data), desc=name) if verbose else df_data.iterrows()

    for _, row in iterator:
        question = row['question']
        qid = row['id']

        contexts = retrieval_manager.retrieve_context(question)
        answer = prompt_manager.generate_answer(question, contexts)

        predictions[qid] = answer

    metrics = evaluate_predictions(df_data, predictions)

    result = {
        'name': name,
        'retrieval': retrieval_manager,
        'prompt': prompt_manager,
        'f1_score': metrics['f1'],
        'precision': metrics['precision'],
        'recall': metrics['recall'],
        'exact_match': metrics['exact_match'],
        'num_questions': len(df_data),
        'predictions': predictions,
        'f1_scores': metrics['f1_scores'],
        'precision_scores': metrics['precision_scores'],
        'recall_scores': metrics['recall_scores'],
        'exact_matches': metrics['exact_matches']
    }

    if verbose:
        print(f"\n{name}")
        print(f"   Retrieval: {retrieval_manager}")
        print(f"   Prompt: {prompt_manager}")
        print(
            f"   F1={metrics['f1']:.2f} | "
            f"P={metrics['precision']:.2f} | "
            f"R={metrics['recall']:.2f} | "
            f"EM={metrics['exact_match']:.2f}"
        )
        print(f"   Questions: {len(df_data)}\n")

    return result


# Test experiment
test_retrieval = RetrievalManager(
    k_docs=20,
    k_passages=7
)
test_prompt = PromptManager()
print(f"Testing experiment with:")
print(f"  Retrieval: {test_retrieval}")
print(f"  Prompt: {test_prompt}")

test_exp = run_experiment(
    "Quick Test",
    df_train.head(100),
    test_retrieval,
    test_prompt,
    verbose=True
)

print(f"✓ Experiment framework ready")

Testing experiment with:
  Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=20, k_passages=7 | window=150, overlap=50
  Prompt: temp=0.0, top_p=1.0, max_tokens=256


Quick Test: 100%|██████████| 100/100 [01:07<00:00,  1.48it/s]


Quick Test
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=20, k_passages=7 | window=150, overlap=50
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=25.69 | P=31.19 | R=26.47 | EM=13.00
   Questions: 100

✓ Experiment framework ready





## 7. Experiments

In [14]:
test_exp

{'name': 'Quick Test',
 'retrieval': RetrievalManager(k_docs=20, k_passages=7, rrf_k=60, mu=1000, k1=0.9, b=0.4, window=150, overlap=50, min_passage_words=30, max_dense_passages=100),
 'prompt': PromptManager(system_prompt="Answer using the provided documents.Prefer short factual answers.If the answer cannot be reasonably inferred from the documents, return 'unknown'.Do not explain your reasoning.", user_prompt='Answer the question using ONLY a short phrase or named entity found verbatim in the documents.\nIf the answer is not explicitly stated, output: unknown.\n\n{context}\n\nQuestion: {question}\nAnswer:', temperature=0.0, top_p=1.0, max_new_tokens=256, do_sample=False, prompt_id='default'),
 'f1_score': 34.885714285714286,
 'precision': 40.99999999999999,
 'recall': 35.733333333333334,
 'exact_match': 16.0,
 'num_questions': 25,
 'predictions': {1: 'unknown',
  2: 'Padmé',
  3: 'Texas',
  4: 'unknown',
  5: 'unknown',
  6: 'Gimli',
  7: 'Memphis Grizzlies',
  8: 'Washington',
  9: 

### Experiments global config

In [15]:
EXPERIMENT_SEED = 42
EXPERIMENT_QUESTIONS = 100

EXPERIMENT_LOG_PATH = (
    f"./results/grid_search_results_q{EXPERIMENT_QUESTIONS}.csv"
)

validation_data = df_train.sample(
    n=EXPERIMENT_QUESTIONS,
    random_state=EXPERIMENT_SEED
).reset_index(drop=True)

print("=" * 80)
print("PHASED RETRIEVAL + GENERATION EXPERIMENT FRAMEWORK")
print("=" * 80)
print(f"Validation questions per config: {EXPERIMENT_QUESTIONS}")
print(f"Random seed: {EXPERIMENT_SEED}")
print(f"Results cache: {EXPERIMENT_LOG_PATH}")
print("=" * 80)


PHASED RETRIEVAL + GENERATION EXPERIMENT FRAMEWORK
Validation questions per config: 100
Random seed: 42
Results cache: ./results/grid_search_results_q100.csv


### Experiments utils

In [16]:
def build_retrieval_manager(base: dict, override: dict) -> RetrievalManager:
    """Build RetrievalManager safely."""
    return RetrievalManager(**{**base, **override})


def generate_config_key(
    retrieval_mgr: RetrievalManager,
    prompt_mgr: PromptManager,
) -> str:
    """Generate unique config key for RRF-based retrieval."""
    return (
        f"RRF_k{retrieval_mgr.rrf_k}_"
        f"mu{retrieval_mgr.mu}_"
        f"k1{retrieval_mgr.k1}_b{retrieval_mgr.b}_"
        f"kdocs{retrieval_mgr.k_docs}_"
        f"kpass{retrieval_mgr.k_passages}_"
        f"win{retrieval_mgr.window}_ovl{retrieval_mgr.overlap}_"
        f"prompt{prompt_mgr.prompt_id}"
    )


def save_results_to_csv(result: dict, key: str, path: str):
    os.makedirs(os.path.dirname(path), exist_ok=True)

    row = {
        "config_key": key,
        "f1": result["f1_score"],
        "precision": result["precision"],
        "recall": result["recall"],
        "exact_match": result["exact_match"],
        "num_questions": result["num_questions"],
    }

    df = pd.DataFrame([row])
    if not os.path.exists(path):
        df.to_csv(path, index=False)
    else:
        df.to_csv(path, mode="a", header=False, index=False)


def load_completed_configs(path: str) -> set[str]:
    if not os.path.exists(path):
        return set()
    return set(pd.read_csv(path)["config_key"])


### Best-config selector

In [17]:
def select_top_k_configs(
    retrieval_managers: list[RetrievalManager],
    prompt_managers: list[PromptManager],
    *,
    top_k: int = 5,
):
    """
    Return the top-K configurations by validation F1 score.
    Always sorted by descending F1.
    """
    df = pd.read_csv(EXPERIMENT_LOG_PATH)

    scored_entries = []

    for r_mgr, p_mgr in zip(retrieval_managers, prompt_managers):
        key = generate_config_key(r_mgr, p_mgr)
        row = df[df["config_key"] == key]
        if row.empty:
            continue

        scored_entries.append({
            "retrieval_mgr": r_mgr,
            "prompt_mgr": p_mgr,
            "f1": float(row.iloc[0]["f1"]),
            "config_key": key,
        })

    # Sort before slicing
    scored_entries.sort(
        key=lambda x: (x["f1"], x["config_key"]),
        reverse=True,
    )

    return scored_entries[:top_k]


### Phase runner

In [18]:
def run_phase(
    *,
    phase_name: str,
    grid: list[dict],
    validation_data,
    top_k: int | None = None,
):
    """
    Run a single experiment phase.

    Each grid item must contain:
      - retrieval_mgr: RetrievalManager
      - prompt_mgr: PromptManager

    Returns:
      - top-K configs sorted by F1 (if top_k is provided)
      - otherwise, the full grid sorted by F1
    """
    print("\n" + "=" * 80)
    print(phase_name)
    print("=" * 80)

    completed = load_completed_configs(EXPERIMENT_LOG_PATH)

    pending = [
        g for g in grid
        if generate_config_key(g["retrieval_mgr"], g["prompt_mgr"]) not in completed
    ]

    print(f"Total configs: {len(grid)}")
    print(f"Completed configs: {len(grid) - len(pending)}")
    print(f"Pending configs: {len(pending)}")
    print("-" * 80)

    for i, entry in enumerate(pending, start=1):
        retrieval_mgr = entry["retrieval_mgr"]
        prompt_mgr = entry["prompt_mgr"]

        key = generate_config_key(retrieval_mgr, prompt_mgr)
        print(f"[{i}/{len(pending)}] Running: {key}")

        result = run_experiment(
            name=key,
            df_data=validation_data,
            retrieval_manager=retrieval_mgr,
            prompt_manager=prompt_mgr,
            verbose=True,
        )

        save_results_to_csv(result, key, EXPERIMENT_LOG_PATH)
        print(f"✓ F1={result['f1_score']:.4f}")

    # Load results once for consistent sorting
    df = pd.read_csv(EXPERIMENT_LOG_PATH)

    def get_f1(entry):
        key = generate_config_key(entry["retrieval_mgr"], entry["prompt_mgr"])
        row = df[df["config_key"] == key]
        return float(row.iloc[0]["f1"]) if not row.empty else -1.0

    # Sort full grid by F1
    sorted_grid = sorted(
        grid,
        key=lambda g: (get_f1(g), generate_config_key(g["retrieval_mgr"], g["prompt_mgr"])),
        reverse=True,
    )

    if top_k is None:
        return sorted_grid

    top_configs = select_top_k_configs(
        [g["retrieval_mgr"] for g in sorted_grid],
        [g["prompt_mgr"] for g in sorted_grid],
        top_k=top_k,
    )

    print("\nTop configs selected:")
    for i, entry in enumerate(top_configs, 1):
        print(
            f"{i}. {entry['config_key']} | "
            f"F1={entry['f1']:.4f}"
        )

    return top_configs


In [None]:
# ============================================================
# PHASE 1 — Retrieval Capacity (paired k_docs, k_passages)
# ============================================================

PHASE_1_GRID = []

BASE_RETRIEVAL_PARAMS = {
    "window": 150,
    "overlap": 50,
    "mu": 1000,
    "k1": 0.9,
    "b": 0.4,
}

CAPACITY_PAIRS = [
    (5, 3),
    (5, 10),
    (10, 5),
    (15, 5),
    (15, 7),
    (20, 7),
    (20, 10),
    (30, 10),
    (50, 5),
    (100, 5)
]

for k_docs, k_passages in CAPACITY_PAIRS:
    PHASE_1_GRID.append({
        "retrieval_mgr": RetrievalManager(
            k_docs=k_docs,
            k_passages=k_passages,
            **BASE_RETRIEVAL_PARAMS,
        ),
        "prompt_mgr": PromptManager(),
    })

print(f"✓ Phase 1 grid size: {len(PHASE_1_GRID)}")

PHASE_1_TOP_CONFIGS = run_phase(
    phase_name="PHASE 1 — Retrieval Capacity",
    grid=PHASE_1_GRID,
    validation_data=validation_data,
    top_k=3,
)


✓ Phase 1 grid size: 10

PHASE 1 — Retrieval Capacity
Total configs: 10
Completed configs: 9
Pending configs: 1
--------------------------------------------------------------------------------
[1/1] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs100_kpass5_win150_ovl50_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs100_kpass5_win150_ovl50_promptdefault:  22%|██▏       | 22/100 [00:19<01:01,  1.26it/s]

In [None]:
# ============================================================
# PHASE 2 — Passage Segmentation (window / overlap pairs)
# ============================================================

PHASE_2_GRID = []

# Base + proportional variants
WINDOW_OVERLAP_PAIRS = [
    (100, 30),
    (150, 50),  # Base
    (200, 50),
    (250, 60),
]

for entry in PHASE_1_TOP_CONFIGS:
    base_mgr = entry["retrieval_mgr"]

    for window, overlap in WINDOW_OVERLAP_PAIRS:
        PHASE_2_GRID.append({
            "retrieval_mgr": RetrievalManager(
                k_docs=base_mgr.k_docs,
                k_passages=base_mgr.k_passages,
                window=window,
                overlap=overlap,
                mu=base_mgr.mu,
                k1=base_mgr.k1,
                b=base_mgr.b,
            ),
            "prompt_mgr": PromptManager(),
        })

print(f"✓ Phase 2 grid size: {len(PHASE_2_GRID)}")

PHASE_2_TOP_CONFIGS = run_phase(
    phase_name="PHASE 2 — Passage Segmentation",
    grid=PHASE_2_GRID,
    validation_data=validation_data,
    top_k=2,
)


✓ Phase 2 grid size: 12

PHASE 2 — Passage Segmentation
Total configs: 12
Completed configs: 3
Pending configs: 9
--------------------------------------------------------------------------------
[1/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win100_ovl30_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win100_ovl30_promptdefault: 100%|██████████| 100/100 [00:43<00:00,  2.30it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win100_ovl30_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=10, k_passages=5 | window=100, overlap=30
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=17.60 | P=18.82 | R=18.75 | EM=11.00
   Questions: 100

✓ F1=17.6016
[2/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win200_ovl50_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win200_ovl50_promptdefault: 100%|██████████| 100/100 [01:07<00:00,  1.48it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win200_ovl50_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=10, k_passages=5 | window=200, overlap=50
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=21.26 | P=25.43 | R=21.87 | EM=11.00
   Questions: 100

✓ F1=21.2564
[3/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win250_ovl60_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win250_ovl60_promptdefault: 100%|██████████| 100/100 [01:08<00:00,  1.46it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win250_ovl60_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=10, k_passages=5 | window=250, overlap=60
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=16.37 | P=19.98 | R=17.95 | EM=8.00
   Questions: 100

✓ F1=16.3718
[4/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win100_ovl30_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win100_ovl30_promptdefault: 100%|██████████| 100/100 [00:44<00:00,  2.22it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win100_ovl30_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=15, k_passages=5 | window=100, overlap=30
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=17.60 | P=18.82 | R=18.75 | EM=11.00
   Questions: 100

✓ F1=17.6016
[5/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win200_ovl50_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win200_ovl50_promptdefault: 100%|██████████| 100/100 [01:08<00:00,  1.45it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win200_ovl50_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=15, k_passages=5 | window=200, overlap=50
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=21.92 | P=26.43 | R=22.37 | EM=11.00
   Questions: 100

✓ F1=21.9231
[6/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win250_ovl60_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win250_ovl60_promptdefault: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win250_ovl60_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=15, k_passages=5 | window=250, overlap=60
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=17.21 | P=20.98 | R=18.62 | EM=8.00
   Questions: 100

✓ F1=17.2051
[7/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win100_ovl30_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win100_ovl30_promptdefault: 100%|██████████| 100/100 [01:04<00:00,  1.54it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win100_ovl30_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=15, k_passages=7 | window=100, overlap=30
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=22.66 | P=25.60 | R=25.23 | EM=12.00
   Questions: 100

✓ F1=22.6645
[8/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win200_ovl50_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win200_ovl50_promptdefault: 100%|██████████| 100/100 [01:21<00:00,  1.23it/s]



RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win200_ovl50_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=15, k_passages=7 | window=200, overlap=50
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=18.06 | P=20.44 | R=21.37 | EM=7.00
   Questions: 100

✓ F1=18.0610
[9/9] Running: RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win250_ovl60_promptdefault


RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win250_ovl60_promptdefault: 100%|██████████| 100/100 [01:25<00:00,  1.17it/s]


RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win250_ovl60_promptdefault
   Retrieval: Retrieval(RRF_k=60, μ=1000, k1=0.9, b=0.4) | k_docs=15, k_passages=7 | window=250, overlap=60
   Prompt: temp=0.0, top_p=1.0, max_tokens=256
   F1=11.23 | P=12.62 | R=12.33 | EM=7.00
   Questions: 100

✓ F1=11.2334

Top configs selected:
1. RRF_k60_mu1000_k10.9_b0.4_kdocs10_kpass5_win150_ovl50_promptdefault | F1=26.4648
2. RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass5_win150_ovl50_promptdefault | F1=24.9648
3. RRF_k60_mu1000_k10.9_b0.4_kdocs15_kpass7_win100_ovl30_promptdefault | F1=22.6645





In [None]:
# ============================================================
# PHASE 3 — Lexical Hyperparameters
# ============================================================

PHASE_3_GRID = []

BM25_PARAMS = [
    {"k1": 0.6, "b": 0.3},
    {"k1": 0.9, "b": 0.4},   # baseline
    {"k1": 1.2, "b": 0.6},
]

QLD_PARAMS = [
    {"mu": 1000},           # baseline
    {"mu": 2000},
]

for entry in PHASE_2_TOP_CONFIGS:
    base_mgr = entry["retrieval_mgr"]

    for bm25 in BM25_PARAMS:
        for qld in QLD_PARAMS:
            PHASE_3_GRID.append({
                "retrieval_mgr": RetrievalManager(
                    k_docs=base_mgr.k_docs,
                    k_passages=base_mgr.k_passages,
                    window=base_mgr.window,
                    overlap=base_mgr.overlap,
                    k1=bm25["k1"],
                    b=bm25["b"],
                    mu=qld["mu"],
                ),
                "prompt_mgr": PromptManager(),
            })

print(f"✓ Phase 3 grid size: {len(PHASE_3_GRID)}")

PHASE_3_TOP_CONFIGS = run_phase(
    phase_name="PHASE 3 — Lexical Hyperparameters",
    grid=PHASE_3_GRID,
    validation_data=validation_data,
    top_k=1,
)


✓ Phase 3 grid size: 27

PHASE 3 — Lexical Hyperparameters
Total configs: 27
Completed configs: 7
Pending configs: 20
--------------------------------------------------------------------------------
[1/20] Running: RRF_k60_mu2000_k10.9_b0.4_kdocs15_kpass5_win150_ovl50_promptdefault


RRF_k60_mu2000_k10.9_b0.4_kdocs15_kpass5_win150_ovl50_promptdefault:  12%|█▏        | 3/25 [00:24<02:57,  8.06s/it]


KeyboardInterrupt: 

In [None]:
# ============================================================
# FINAL SELECTION — Top 3 configs on 1,000 TRAIN questions
# ============================================================

FINAL_SELECTION_SEED = 123
FINAL_SELECTION_SIZE = 1000

# Take top-3 configs from Phase 3 (already sorted by F1)
TOP_3_CONFIGS = PHASE_3_TOP_CONFIGS[:3]

final_validation_data = df_train.sample(
    n=FINAL_SELECTION_SIZE,
    random_state=FINAL_SELECTION_SEED,
).reset_index(drop=True)

print("=" * 80)
print("FINAL MODEL SELECTION ON 1,000 TRAIN QUESTIONS")
print("=" * 80)

FINAL_SELECTION_RESULTS = []

for i, entry in enumerate(TOP_3_CONFIGS, 1):
    retrieval_mgr = entry["retrieval_mgr"]

    prompt_mgr = PromptManager(
        system_prompt=SYSTEM_PROMPT,
        user_prompt=USER_PROMPT,
        temperature=0.0,
        do_sample=False,
        top_p=1.0,
        prompt_id="final",
    )

    config_key = generate_config_key(retrieval_mgr, prompt_mgr)

    print(f"\n[{i}/3] Evaluating config: {config_key}")

    result = run_experiment(
        name=f"{config_key}_final_select",
        df_data=final_validation_data,
        retrieval_manager=retrieval_mgr,
        prompt_manager=prompt_mgr,
        verbose=False,
    )

    print(
        f"✓ F1={result['f1_score']:.4f} | "
        f"EM={result['exact_match']:.4f} | "
        f"P={result['precision']:.4f} | "
        f"R={result['recall']:.4f}"
    )

    FINAL_SELECTION_RESULTS.append({
        "retrieval_mgr": retrieval_mgr,
        "prompt_mgr": prompt_mgr,
        "config_key": config_key,
        **result,
    })

# Select the best config by F1
FINAL_SELECTION_RESULTS.sort(
    key=lambda x: x["f1_score"],
    reverse=True,
)

BEST_FINAL_CONFIG = FINAL_SELECTION_RESULTS[0]

print("\n" + "=" * 80)
print("✓ BEST FINAL CONFIG SELECTED")
print("=" * 80)
print(
    f"{BEST_FINAL_CONFIG['config_key']} | "
    f"F1={BEST_FINAL_CONFIG['f1_score']:.4f}"
)


In [None]:
# ============================================================
# KAGGLE SUBMISSION — Final system on TEST set
# ============================================================

BEST_RETRIEVAL_MGR = BEST_FINAL_CONFIG["retrieval_mgr"]
BEST_PROMPT_MGR = BEST_FINAL_CONFIG["prompt_mgr"]

FINAL_CONFIG_KEY = BEST_FINAL_CONFIG["config_key"]

print("=" * 80)
print("KAGGLE SUBMISSION GENERATION")
print("=" * 80)
print(f"Using final config: {FINAL_CONFIG_KEY}")
print("=" * 80)

# Run inference only (no labels needed)
test_questions = df_test["question"].tolist()

print(f"Generating answers for {len(test_questions)} test questions...")

test_contexts = [
    BEST_RETRIEVAL_MGR.retrieve_context(q)
    for q in test_questions
]

test_answers = BEST_PROMPT_MGR.batch_generate_answers(
    questions=test_questions,
    contexts_list=test_contexts,
)

# Build Kaggle submission file
submission_df = pd.DataFrame({
    "id": df_test["id"],
    "answer": test_answers,
})

SUBMISSION_PATH = "./results/kaggle_submission.csv"
submission_df.to_csv(SUBMISSION_PATH, index=False)

print(f"✓ Kaggle submission file saved to: {SUBMISSION_PATH}")
print(f"✓ Total rows: {len(submission_df)}")
