## 1. Setup & Dependencies

In [12]:
import pandas as pd
import json
import re
import string
from collections import Counter
from typing import List, Dict, Tuple, Optional, Literal
from dataclasses import dataclass
from tqdm.auto import tqdm
import itertools
tqdm.pandas()
import warnings
warnings.filterwarnings('ignore')

print("✓ Dependencies imported")

✓ Dependencies imported


### Install Required Packages

In [13]:
# !pip install torch torchvision torchaudio
# !pip install pyserini==0.36.0
# !pip install accelerate
# !pip install transformers
# !pip install tqdm
# !pip install python-dotenv

In [14]:
import os

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]

!java -version

openjdk version "21.0.9" 2025-10-21
OpenJDK Runtime Environment (build 21.0.9+10-Ubuntu-122.04)
OpenJDK 64-Bit Server VM (build 21.0.9+10-Ubuntu-122.04, mixed mode, sharing)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [15]:
# !pip install torch torchvision torchaudio
# !pip install pyserini==0.36.0
# !pip install accelerate
# !pip install transformers
# !pip install tqdm
# !pip install python-dotenv

### Hugging Face Authentication

In [16]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

load_dotenv()

login(os.getenv('HUGGING_FACE_TOKEN'))
print("✓ Logged into Hugging Face")

✓ Logged into Hugging Face


## 2. Data Loading & Preparation

In [17]:
# Load datasets
df_train = pd.read_csv("./data/train.csv", converters={"answers": json.loads})
df_test = pd.read_csv("./data/test.csv")

print(f"Train set: {len(df_train)} questions")
print(f"Test set: {len(df_test)} questions")
print(f"\nSample question: {df_train.iloc[0]['question']}")
print(f"Sample answers: {df_train.iloc[0]['answers']}")

Train set: 3778 questions
Test set: 2032 questions

Sample question: what is the name of justin bieber brother?
Sample answers: ['Jazmyn Bieber', 'Jaxon Bieber']


## 3. Retrieval Functions

In [19]:
from pyserini.search import SimpleSearcher
from pyserini.index.lucene import IndexReader

# Load Pyserini index
print("Loading Pyserini index...")
searcher = SimpleSearcher.from_prebuilt_index('wikipedia-kilt-doc')
index_reader = IndexReader.from_prebuilt_index('wikipedia-kilt-doc')

print(f"✓ Index loaded: {index_reader.stats()['documents']} documents")

Loading Pyserini index...
SimpleSearcher class has been deprecated, please use LuceneSearcher from pyserini.search.lucene instead
✓ Index loaded: 5903530 documents


In [20]:
from sentence_transformers import SentenceTransformer
import torch

# Load bi-encoder
print("Loading bi-encoder...")
device = "cuda" if torch.cuda.is_available() else "cpu"
bi_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

print("✓ Bi-encoder loaded")

Loading bi-encoder...
✓ Bi-encoder loaded


In [None]:
from dataclasses import dataclass
from typing import List, Literal
import json
import torch
from sentence_transformers import util


@dataclass
class RetrievalManager:
    """
    Manages passage-based retrieval with optional bi-encoder reranking.
    """
    # Configuration parameters
    k_docs: int = 5
    k_passages: int = 3
    method: Literal["bm25", "qld"] = "qld"
    use_rerank: bool = True

    mu: int = 1000          # QLD smoothing
    k1: float = 0.9         # BM25
    b: float = 0.4          # BM25

    window: int = 150
    overlap: int = 50
    min_passage_words: int = 30

    def __str__(self):
        method_str = (
            f"QLD(mu={self.mu})"
            if self.method == "qld"
            else f"BM25(k1={self.k1}, b={self.b})"
        )
        rerank_str = "BiEncoder" if self.use_rerank else "NoRerank"
        return (
            f"{method_str} → {rerank_str} | "
            f"k_docs={self.k_docs}, k_passages={self.k_passages}"
        )

    def extract_passages(
        self,
        text: str,
    ) -> List[str]:
        """
        Split text into overlapping word windows.
        """
        if not text:
            return []

        words = text.split()
        if len(words) < self.min_passage_words:
            return []

        step = max(1, self.window - self.overlap)
        passages = []

        for i in range(0, len(words), step):
            chunk = words[i:i + self.window]
            if len(chunk) < self.min_passage_words:
                break
            passages.append(" ".join(chunk))

        return passages

    def rerank(
        self,
        query: str,
        passages: List[str],
    ) -> List[str]:
        """
        Rerank passages using bi-encoder cosine similarity.
        """
        if not passages:
            return []

        q_emb = bi_encoder.encode(query, convert_to_tensor=True, device=device)
        p_embs = bi_encoder.encode(passages, convert_to_tensor=True, device=device)

        scores = util.cos_sim(q_emb, p_embs).squeeze(0)
        top_k = min(self.k_passages, len(passages))
        idx = torch.topk(scores, k=top_k).indices.tolist()

        return [passages[i] for i in idx]

    def retrieve_context(self, query: str) -> List[str]:
        """
        Retrieve passages using BM25 or QLD, optionally followed by bi-encoder reranking.
        """
        if self.method == "bm25":
            searcher.set_bm25(self.k1, self.b)
        else:
            searcher.set_qld(self.mu)

        hits = searcher.search(query, self.k_docs)
        passages: List[str] = []

        for hit in hits:
            try:
                doc = searcher.doc(hit.docid)
                content = json.loads(doc.raw()).get("contents", "").replace("\n", " ")
                passages.extend(self.extract_passages(content))
            except Exception:
                continue

        if not self.use_rerank:
            return passages[:self.k_passages]

        return self.rerank(query, passages)


# Test the RetrievalManager
query = "Who wrote Harry Potter?"

configs = [
    RetrievalManager(method="qld", use_rerank=False),
    RetrievalManager(method="qld", use_rerank=True),
    RetrievalManager(method="bm25", use_rerank=False),
    RetrievalManager(method="bm25", use_rerank=True),
]

for cfg in configs:
    print(cfg)
    test_passages = cfg.retrieve_context(query)
    for i, p in enumerate(test_passages, 1):
        print(f"{i}. {p[:100]}...")
    print()


QLD(mu=1000) → NoRerank | k_docs=5, k_passages=3
1. Harry Potter Harry Potter is a series of fantasy novels written by British author J. K. Rowling. The...
2. June 1997, the books have found immense popularity, critical acclaim and commercial success worldwid...
3. English by two major publishers, Bloomsbury in the United Kingdom and Scholastic Press in the United...

QLD(mu=1000) → BiEncoder | k_docs=5, k_passages=3
1. the film, has denied that Rowling ever saw it before writing her book. Rowling has said on record mu...
2. Harry Potter Harry Potter is a series of fantasy novels written by British author J. K. Rowling. The...
3. by Emily Brontë, "Charlie and the Chocolate Factory" by Roald Dahl, "Robinson Crusoe" by Daniel Defo...

BM25(k1=0.9, b=0.4) → NoRerank | k_docs=5, k_passages=3
1. Bonnie Wright Bonnie Francesca Wright (born 17 February 1991) is an English actress, film director, ...
2. the Deathly Hallows – Part 1" and "Part 2", she began attending London's University of the 

## 4. LLM Generation

In [22]:
import transformers
import torch
import logging

# Suppress transformers warnings
transformers.logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)

print("Loading LLM model...")
model_id = "meta-llama/Llama-3.2-1B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.float16},
    device_map="auto"
)

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Set pad_token for batch processing
pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token

print(f"✓ Model loaded on: {'GPU' if torch.cuda.is_available() else 'CPU'}")

Loading LLM model...





✓ Model loaded on: GPU


In [23]:
DEFAULT_SYSTEM_PROMPT = (
    "You must respond based strictly on the information in provided passages."
    "Do not incorporate any external knowledge or infer any details beyond what is given."
    "If the answer is not in the context, return 'I dont know'."
    "Do not include explanations, only the final answer!"
)

DEFAULT_USER_PROMPT = (
    "Based on the following documents, provide a concise answer to the question.\n\n"
    "{context}\n\n"
    "Question: {question}\n\n"
    "Answer:"
)

@dataclass
class PromptManager:
    """Manages prompt generation and LLM answer generation."""
    system_prompt: str = DEFAULT_SYSTEM_PROMPT
    user_prompt: str = DEFAULT_USER_PROMPT
    temperature: float = 0.1
    top_p: float = 0.9
    max_new_tokens: int = 256
    do_sample: bool = True
    prompt_id: str = "default"  # For later use in prompt tuning
    
    def __str__(self):
        return f"temp={self.temperature}, top_p={self.top_p}, max_tokens={self.max_new_tokens}"

    @staticmethod
    def clean_answer(answer: str) -> str:
        """Clean and standardize the generated answer."""
        answer = re.sub(r'^(Answer|The answer is|Based on the .*?,):?\s*', '', answer, flags=re.I)
        answer = answer.rstrip('.')
        if any(phrase in answer.lower() for phrase in ["dont know", "don't know", "do not know", "unknown"]):
            return "unknown"
        return answer.strip()

    def create_messages(self, question: str, contexts: List[str]) -> List[Dict]:
        """Create messages for the LLM based on the question and contexts."""
        if not contexts:
            context_str = "No relevant documents found."
        else:
            context_str = '\n\n'.join([f"Document {i+1}: {ctx}" for i, ctx in enumerate(contexts)])
        
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.user_prompt.format(context=context_str, question=question)}
        ]

    def generate_answer(self, question: str, contexts: List[str]) -> str:
        """Generate an answer using the LLM based on the question and contexts."""
        messages = self.create_messages(question, contexts)
        
        outputs = pipeline(
            messages,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=terminators,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        
        answer = outputs[0]["generated_text"][-1].get('content', '')
        return self.clean_answer(answer)

    def batch_generate_answers(self, questions: List[str], contexts_list: List[List[str]]) -> List[str]:
        """Generate answers for multiple questions in batch."""
        # Create messages for all questions
        batch_messages = [self.create_messages(q, ctx) for q, ctx in zip(questions, contexts_list)]
        
        # Process batch through pipeline
        outputs = pipeline(
            batch_messages,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=terminators,
            do_sample=self.do_sample,
            temperature=self.temperature,
            top_p=self.top_p
        )
        
        # Extract and clean answers
        answers = []
        for output in outputs:
            answer = output[0]["generated_text"][-1].get('content', '')
            answers.append(self.clean_answer(answer))
        
        return answers


# Test the PromptManager
test_prompt_manager = PromptManager()
print(f"Testing: {test_prompt_manager}")
test_answer = test_prompt_manager.generate_answer(query, test_passages)
print(f"✓ Generated answer: '{test_answer}'")

Testing: temp=0.1, top_p=0.9, max_tokens=256
✓ Generated answer: 'J. K. Rowling'


## 5. Evaluation Metrics

In [24]:
def normalize_answer(s: str) -> str:
    """Normalize answer for comparison"""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    
    def white_space_fix(text):
        return ' '.join(text.split())
    
    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))
    
    def lower(text):
        return text.lower()
    
    return white_space_fix(remove_articles(remove_punc(lower(s))))


def compute_token_metrics(prediction: str, ground_truth: str) -> Tuple[float, float, float]:
    """
    Compute precision, recall, and F1 score for token-level comparison.
    Returns: (precision, recall, f1)
    """
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    
    # Handle empty cases
    if len(pred_tokens) == 0 or len(gt_tokens) == 0:
        match = int(pred_tokens == gt_tokens)
        return match, match, match
    
    # Compute overlap
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    
    if num_same == 0:
        return 0.0, 0.0, 0.0
    
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    
    return precision, recall, f1


def evaluate_predictions(df_gold: pd.DataFrame, predictions: Dict[int, str]) -> Dict:
    """Evaluate predictions against ground truth."""
    f1_scores = []
    precision_scores = []
    recall_scores = []
    exact_matches = []
    
    for _, row in df_gold.iterrows():
        qid = row['id']
        
        # Handle missing predictions
        if qid not in predictions:
            f1_scores.append(0.0)
            precision_scores.append(0.0)
            recall_scores.append(0.0)
            exact_matches.append(0)
            continue
        
        prediction = predictions[qid]
        ground_truths = row['answers']
        
        # Normalize once
        norm_prediction = normalize_answer(prediction)
        
        # Find best match across all ground truths
        best_f1 = 0.0
        best_precision = 0.0
        best_recall = 0.0
        is_exact = 0
        
        for gt in ground_truths:
            norm_gt = normalize_answer(gt)
            
            # Compute metrics
            prec, rec, f1 = compute_token_metrics(prediction, gt)
            
            # Track best scores
            if f1 > best_f1:
                best_f1 = f1
                best_precision = prec
                best_recall = rec
            
            # Check exact match
            if norm_prediction == norm_gt:
                is_exact = 1
        
        f1_scores.append(best_f1)
        precision_scores.append(best_precision)
        recall_scores.append(best_recall)
        exact_matches.append(is_exact)
    
    return {
        'f1': 100.0 * sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
        'precision': 100.0 * sum(precision_scores) / len(precision_scores) if precision_scores else 0.0,
        'recall': 100.0 * sum(recall_scores) / len(recall_scores) if recall_scores else 0.0,
        'exact_match': 100.0 * sum(exact_matches) / len(exact_matches) if exact_matches else 0.0,
        'f1_scores': f1_scores,
        'precision_scores': precision_scores,
        'recall_scores': recall_scores,
        'exact_matches': exact_matches
    }


# Test evaluation
test_predictions = {1: "J.K. Rowling", 2: "Paris", 3: "Shakespeare"}
test_gold = pd.DataFrame({
    'id': [1, 2, 3],
    'answers': [["J.K. Rowling", "Rowling"], ["Earth"], ["William Shakespeare", "Shakespeare"]]
})

test_metrics = evaluate_predictions(test_gold, test_predictions)
print(f"✓ Evaluation test: F1={test_metrics['f1']:.2f}, P={test_metrics['precision']:.2f}, R={test_metrics['recall']:.2f}, EM={test_metrics['exact_match']:.2f}")

✓ Evaluation test: F1=66.67, P=66.67, R=66.67, EM=66.67


## 6. Experiment Framework

In [25]:
def run_experiment(
    name: str,
    df_data: pd.DataFrame,
    retrieval_manager: RetrievalManager,
    prompt_manager: PromptManager,
    max_questions: Optional[int] = None,
    batch_size: int = 4,
    verbose: bool = True
) -> Dict:
    if max_questions:
        df_data = df_data.head(max_questions)
    
    predictions = {}
    
    # Process in batches
    num_batches = (len(df_data) + batch_size - 1) // batch_size
    iterator = tqdm(range(num_batches), desc=name) if verbose else range(num_batches)
    
    for batch_idx in iterator:
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(df_data))
        batch_df = df_data.iloc[start_idx:end_idx]
        
        # Retrieve contexts for all questions in batch
        batch_questions = []
        batch_qids = []
        batch_contexts = []
        
        for _, row in batch_df.iterrows():
            question = row['question']
            qid = row['id']
            contexts = retrieval_manager.retrieve_context(question)
            
            batch_questions.append(question)
            batch_qids.append(qid)
            batch_contexts.append(contexts)
        
        # Generate answers in batch
        batch_answers = prompt_manager.batch_generate_answers(batch_questions, batch_contexts)
        
        # Store predictions
        for qid, answer in zip(batch_qids, batch_answers):
            predictions[qid] = answer
    
    metrics = evaluate_predictions(df_data, predictions)
    
    result = {
        'name': name,
        'retrieval': retrieval_manager,
        'prompt': prompt_manager,
        'f1_score': metrics['f1'],
        'precision': metrics['precision'],
        'recall': metrics['recall'],
        'exact_match': metrics['exact_match'],
        'num_questions': len(df_data),
        'predictions': predictions,
        'f1_scores': metrics['f1_scores'],
        'precision_scores': metrics['precision_scores'],
        'recall_scores': metrics['recall_scores'],
        'exact_matches': metrics['exact_matches']
    }
    
    if verbose:
        print(f"\n{name}")
        print(f"   Retrieval: {retrieval_manager}")
        print(f"   Prompt: {prompt_manager}")
        print(f"   F1={metrics['f1']:.2f} | P={metrics['precision']:.2f} | R={metrics['recall']:.2f} | EM={metrics['exact_match']:.2f}")
        print(f"   Questions: {len(df_data)}\n")
    
    return result

# Test experiment
test_retrieval = RetrievalManager()
test_prompt = PromptManager()
print(f"Testing experiment with:")
print(f"  Retrieval: {test_retrieval}")
print(f"  Prompt: {test_prompt}")

test_exp = run_experiment(
    "Quick Test",
    df_val.head(5),
    test_retrieval,
    test_prompt,
    verbose=True
)

print(f"✓ Experiment framework ready")

Testing experiment with:
  Retrieval: QLD(mu=1000) → BiEncoder | k_docs=5, k_passages=3
  Prompt: temp=0.1, top_p=0.9, max_tokens=256


Quick Test: 100%|██████████| 2/2 [00:08<00:00,  4.48s/it]


Quick Test
   Retrieval: QLD(mu=1000) → BiEncoder | k_docs=5, k_passages=3
   Prompt: temp=0.1, top_p=0.9, max_tokens=256
   F1=20.00 | P=24.00 | R=30.00 | EM=0.00
   Questions: 5

✓ Experiment framework ready





## 7. Experiments

### Experiments global config

In [26]:
EXPERIMENT_SEED = 42
EXPERIMENT_QUESTIONS = 100

EXPERIMENT_LOG_PATH = (
    f"./results/grid_search_results_q{EXPERIMENT_QUESTIONS}.csv"
)

validation_data = df_val.sample(
    n=EXPERIMENT_QUESTIONS,
    random_state=EXPERIMENT_SEED
).reset_index(drop=True)

print("=" * 80)
print("PHASED RETRIEVAL + GENERATION EXPERIMENT FRAMEWORK")
print("=" * 80)
print(f"Validation questions per config: {EXPERIMENT_QUESTIONS}")
print(f"Random seed: {EXPERIMENT_SEED}")
print(f"Results cache: {EXPERIMENT_LOG_PATH}")
print("=" * 80)


PHASED RETRIEVAL + GENERATION EXPERIMENT FRAMEWORK
Validation questions per config: 100
Random seed: 42
Results cache: ./results/grid_search_results_q100.csv


### Experiments utils

In [None]:
def build_retrieval_manager(base: dict, override: dict) -> RetrievalManager:
    """Build RetrievalManager safely."""
    return RetrievalManager(**{**base, **override})


def generate_config_key(
    retrieval_mgr: RetrievalManager,
    prompt_mgr: PromptManager,
) -> str:
    """Unique experiment key: retrieval + prompt + decoding + dataset size."""
    method_part = f"alpha{retrieval_mgr.alpha}_mu{retrieval_mgr.mu}_k1{retrieval_mgr.k1}_b{retrieval_mgr.b}"

    return (
        f"HYBRID_"
        f"kdocs{retrieval_mgr.k_docs}_"
        f"kpass{retrieval_mgr.k_passages}_"
        f"{method_part}_"
        f"win{retrieval_mgr.window}_ovl{retrieval_mgr.overlap}_"
        f"prompt{prompt_mgr.prompt_id}_"
        f"sample{int(prompt_mgr.do_sample)}_"
        f"temp{prompt_mgr.temperature}_"
        f"q{EXPERIMENT_QUESTIONS}"
    )


def save_results_to_csv(result: dict, key: str, path: str):
    os.makedirs(os.path.dirname(path), exist_ok=True)

    row = {
        "config_key": key,
        "f1": result["f1_score"],
        "precision": result["precision"],
        "recall": result["recall"],
        "exact_match": result["exact_match"],
        "num_questions": result["num_questions"],
    }

    df = pd.DataFrame([row])
    if not os.path.exists(path):
        df.to_csv(path, index=False)
    else:
        df.to_csv(path, mode="a", header=False, index=False)


def load_completed_configs(path: str) -> set[str]:
    if not os.path.exists(path):
        return set()
    return set(pd.read_csv(path)["config_key"])


### Best-config selector

In [None]:
def select_best_per_alpha_range(
    retrieval_managers: list[RetrievalManager],
    prompt_managers: list[PromptManager],
):
    """Select the best (highest-F1) config for low, mid, and high alpha values."""
    df = pd.read_csv(EXPERIMENT_LOG_PATH)

    # Define alpha ranges
    best = {
        "low_alpha": {"f1": -1.0, "entry": None},    # alpha < 0.4 (BM25-dominant)
        "mid_alpha": {"f1": -1.0, "entry": None},    # 0.4 <= alpha <= 0.6 (balanced)
        "high_alpha": {"f1": -1.0, "entry": None},   # alpha > 0.6 (QLD-dominant)
    }

    for r_mgr, p_mgr in zip(retrieval_managers, prompt_managers):
        key = generate_config_key(r_mgr, p_mgr)
        row = df[df["config_key"] == key]
        if row.empty:
            continue

        f1 = row.iloc[0]["f1"]
        alpha = r_mgr.alpha

        # Categorize by alpha value
        if alpha < 0.4:
            category = "low_alpha"
        elif alpha <= 0.6:
            category = "mid_alpha"
        else:
            category = "high_alpha"

        if f1 > best[category]["f1"]:
            best[category] = {
                "f1": f1,
                "entry": {
                    "retrieval_mgr": r_mgr,
                    "prompt_mgr": p_mgr,
                    "f1": f1,
                },
            }

    return [
        v["entry"]
        for v in best.values()
        if v["entry"] is not None
    ]


### Phase runner

In [None]:
def run_phase(
    *,
    phase_name: str,
    grid: list[dict],
    validation_data,
    select_best=True,
):
    """
    Run a single experiment phase.
    Each grid item must contain:
      - retrieval_mgr: RetrievalManager
      - prompt_mgr: PromptManager
    """
    print("\n" + "=" * 80)
    print(phase_name)
    print("=" * 80)

    completed = load_completed_configs(EXPERIMENT_LOG_PATH)

    pending = [
        g for g in grid
        if generate_config_key(g["retrieval_mgr"], g["prompt_mgr"]) not in completed
    ]

    print(f"Total configs: {len(grid)}")
    print(f"Pending configs: {len(pending)}")
    print("-" * 80)

    for i, entry in enumerate(pending, start=1):
        retrieval_mgr = entry["retrieval_mgr"]
        prompt_mgr = entry["prompt_mgr"]

        key = generate_config_key(retrieval_mgr, prompt_mgr)
        print(f"[{i}/{len(pending)}] Running: {key}")

        result = run_experiment(
            name=key,
            df_data=validation_data,
            retrieval_manager=retrieval_mgr,
            prompt_manager=prompt_mgr,
            verbose=True,
        )

        save_results_to_csv(result, key, EXPERIMENT_LOG_PATH)
        print(f"✓ F1={result['f1_score']:.2f}")

    if not select_best:
        return grid

    best = select_best_per_alpha_range(
        [g["retrieval_mgr"] for g in grid],
        [g["prompt_mgr"] for g in grid],
    )

    print("\nBest configs selected:")
    for entry in best:
        print(
            f"✓ {generate_config_key(entry['retrieval_mgr'], entry['prompt_mgr'])} | "
            f"F1={entry['f1']:.2f}"
        )

    return best


### Phase 1: Retrieval Method

Goal: Tune QLD μ and BM25 k1/b independently
Reranking disabled

In [30]:
PHASE_1_GRID = []

BASE_RETRIEVAL = {
    "k_docs": 5,
    "k_passages": 3,
    "window": 150,
    "overlap": 50,
}

# QLD
for mu in [500, 1000, 2000]:
    for rerank in [True, False]:
        PHASE_1_GRID.append({
            "retrieval_mgr": build_retrieval_manager(
                BASE_RETRIEVAL,
                {"method": "qld", "mu": mu, "use_rerank": rerank},
            ),
            "prompt_mgr": PromptManager(),
        })

# BM25
for k1, b in itertools.product([0.6, 0.9, 1.2], [0.4, 0.75]):
    for rerank in [True, False]:
        PHASE_1_GRID.append({
            "retrieval_mgr": build_retrieval_manager(
                BASE_RETRIEVAL,
                {"method": "bm25", "k1": k1, "b": b, "use_rerank": rerank},
            ),
            "prompt_mgr": PromptManager(),
        })

BEST_PHASE_1 = run_phase(
    phase_name="PHASE 1 — RETRIEVAL METHOD & RERANKING",
    grid=PHASE_1_GRID,
    validation_data=validation_data,
)



PHASE 1 — RETRIEVAL METHOD & RERANKING
Total configs: 18
Pending configs: 0
--------------------------------------------------------------------------------

Best configs selected:
✓ QLD_kdocs5_kpass3_RERANK_QLD_mu1000_win150_ovl50_promptdefault_sample1_temp0.1_q100 | F1=13.81
✓ BM25_kdocs5_kpass3_RERANK_BM25_k11.2_b0.4_win150_ovl50_promptdefault_sample1_temp0.1_q100 | F1=12.34


### Phase 2: Passage Segmentation
Goal: tune window / overlap with best retrieval params

In [31]:
PHASE_2_GRID = []

for best in BEST_PHASE_1:
    for window, overlap in itertools.product([100, 150, 200, 250, 300], [30, 50, 70]):
        PHASE_2_GRID.append({
            "retrieval_mgr": build_retrieval_manager(
                best["retrieval_mgr"].__dict__,
                {"window": window, "overlap": overlap},
            ),
            "prompt_mgr": best["prompt_mgr"],
        })

BEST_PHASE_2 = run_phase(
    phase_name="PHASE 2 — PASSAGE WINDOW & OVERLAP",
    grid=PHASE_2_GRID,
    validation_data=validation_data,
)



PHASE 2 — PASSAGE WINDOW & OVERLAP
Total configs: 30
Pending configs: 0
--------------------------------------------------------------------------------

Best configs selected:
✓ QLD_kdocs5_kpass3_RERANK_QLD_mu1000_win250_ovl70_promptdefault_sample1_temp0.1_q100 | F1=23.39
✓ BM25_kdocs5_kpass3_RERANK_BM25_k11.2_b0.4_win250_ovl70_promptdefault_sample1_temp0.1_q100 | F1=21.62


### Phase 3 — k_docs / k_passages Tradeoff
Goal: tune recall vs precision tradeoff

In [32]:
PHASE_3_GRID = []

K_VARIANTS = [
    {"k_docs": 3,  "k_passages": 2},
    {"k_docs": 5,  "k_passages": 3},   # baseline
    {"k_docs": 8,  "k_passages": 3},
    {"k_docs": 10, "k_passages": 5},
    {"k_docs": 15, "k_passages": 5},
    {"k_docs": 20, "k_passages": 7},
    {"k_docs": 25, "k_passages": 10},
]

for best in BEST_PHASE_2:
    for k_cfg in K_VARIANTS:
        PHASE_3_GRID.append({
            "retrieval_mgr": build_retrieval_manager(
                best["retrieval_mgr"].__dict__,
                k_cfg,
            ),
            "prompt_mgr": best["prompt_mgr"],
        })

BEST_PHASE_3 = run_phase(
    phase_name="PHASE 3 — RETRIEVAL DEPTH",
    grid=PHASE_3_GRID,
    validation_data=validation_data,
)



PHASE 3 — RETRIEVAL DEPTH
Total configs: 14
Pending configs: 0
--------------------------------------------------------------------------------

Best configs selected:
✓ QLD_kdocs5_kpass3_RERANK_QLD_mu1000_win250_ovl70_promptdefault_sample1_temp0.1_q100 | F1=23.39
✓ BM25_kdocs10_kpass5_RERANK_BM25_k11.2_b0.4_win250_ovl70_promptdefault_sample1_temp0.1_q100 | F1=22.90


In [33]:
PROMPT_VARIANTS = {
    "default": {
        "system": DEFAULT_SYSTEM_PROMPT,
        "user": DEFAULT_USER_PROMPT,
    },
    "strict": {
        "system": (
            "You are a strict information extraction system.\n"
            "Answer the question using ONLY the provided passages.\n"
            "The answer MUST be an exact span copied verbatim from the passages.\n"
            "Do NOT paraphrase, infer, or add information.\n"
            "If the answer does not appear explicitly in the passages, respond with: unknown."
        ),
        "user": (
            "PASSAGES:\n"
            "{context}\n\n"
            "QUESTION:\n"
            "{question}\n\n"
            "INSTRUCTIONS:\n"
            "- Answer with a single concise phrase, copied or extracted directly from the passages.\n"
            "- If no answer is present, answer: unknown.\n\n"
            "ANSWER:"
        ),
    },
}

DECODING_VARIANTS = [
    {"temperature": 0.1, "do_sample": True},    # baseline
    {"temperature": 0.1, "do_sample": False},   # greedy
    {"temperature": 0.3, "do_sample": True},    # diverse
    {"temperature": 0.6, "do_sample": True},    # more diverse
]

PHASE_4_GRID = []

for best in BEST_PHASE_3:
    for prompt_id, prompts in PROMPT_VARIANTS.items():
        for decode in DECODING_VARIANTS:
            PHASE_4_GRID.append({
                "retrieval_mgr": best["retrieval_mgr"],
                "prompt_mgr": PromptManager(
                    system_prompt=prompts["system"],
                    user_prompt=prompts["user"],
                    temperature=decode["temperature"],
                    do_sample=decode["do_sample"],
                    prompt_id=prompt_id,
                ),
            })

BEST_PHASE_4 = run_phase(
    phase_name="PHASE 4 — PROMPT & DECODING",
    grid=PHASE_4_GRID,
    validation_data=validation_data,
)



PHASE 4 — PROMPT & DECODING
Total configs: 16
Pending configs: 0
--------------------------------------------------------------------------------

Best configs selected:
✓ QLD_kdocs5_kpass3_RERANK_QLD_mu1000_win250_ovl70_promptdefault_sample0_temp0.1_q100 | F1=24.35
✓ BM25_kdocs10_kpass5_RERANK_BM25_k11.2_b0.4_win250_ovl70_promptdefault_sample1_temp0.1_q100 | F1=22.90


### Hybrid approach test

In [34]:
from dataclasses import dataclass
from typing import List, Literal
import json


@dataclass
class HybridRetrievalManager(RetrievalManager):
    """
    Hybrid retrieval manager combining QLD and BM25
    using rank-based score interpolation.
    """
    method: Literal["hybrid"] = "hybrid"
    alpha: float = 0.5  # alpha * QLD + (1 - alpha) * BM25

    def __str__(self):
        method_str = f"HYBRID(α={self.alpha}, μ={self.mu}, k1={self.k1}, b={self.b})"
        rerank_str = "BiEncoder" if self.use_rerank else "NoRerank"
        return (
            f"{method_str} → {rerank_str} | "
            f"k_docs={self.k_docs}, k_passages={self.k_passages}"
        )

    def retrieve_context(self, query: str) -> List[str]:
        """
        Retrieve passages using hybrid QLD + BM25,
        optionally followed by bi-encoder reranking.
        """
        # Hybrid document retrieval
        def collect_hits(hits, weight, scores):
            for rank, hit in enumerate(hits):
                score = weight / (rank + 1)
                scores[hit.docid] = max(scores.get(hit.docid, 0.0), score)

        scores = {}

        # QLD
        searcher.set_qld(self.mu)
        qld_hits = searcher.search(query, self.k_docs)
        collect_hits(qld_hits, self.alpha, scores)

        # BM25
        searcher.set_bm25(self.k1, self.b)
        bm25_hits = searcher.search(query, self.k_docs)
        collect_hits(bm25_hits, 1.0 - self.alpha, scores)

        # Top-k documents by combined score
        top_docids = sorted(scores, key=scores.get, reverse=True)[:self.k_docs]
        docs = [searcher.doc(docid) for docid in top_docids]

        passages: List[str] = []
        
        for doc in docs:
            try:
                content = (
                    json.loads(doc.raw())
                    .get("contents", "")
                    .replace("\n", " ")
                )
                passages.extend(self.extract_passages(content))
            except Exception:
                continue

        if not self.use_rerank:
            return passages[:self.k_passages]

        return self.rerank(query, passages)


In [35]:
import numpy as np

HYBRID_ALPHAS = np.round(np.arange(0.0, 1.1, 0.1), 2).tolist()

# Extract best QLD and BM25 configs from Phase 4
qld_best = next(e for e in BEST_PHASE_4 if e["retrieval_mgr"].method == "qld")
bm25_best = next(e for e in BEST_PHASE_4 if e["retrieval_mgr"].method == "bm25")

# Extract retrieval-specific parameters (will be same for all hybrid configs)
mu = qld_best["retrieval_mgr"].mu
k1 = bm25_best["retrieval_mgr"].k1
b = bm25_best["retrieval_mgr"].b

# Extract common parameters from both configs
COMMON_PARAMS = ["k_docs", "k_passages", "window", "overlap", "use_rerank"]
qld_params = {p: getattr(qld_best["retrieval_mgr"], p) for p in COMMON_PARAMS}
bm25_params = {p: getattr(bm25_best["retrieval_mgr"], p) for p in COMMON_PARAMS}

# Report differences
diff_params = [p for p in COMMON_PARAMS if qld_params[p] != bm25_params[p]]
if diff_params:
    print(f"⚠️  Parameter differences between best QLD and BM25:")
    for p in diff_params:
        print(f"   {p}: QLD={qld_params[p]}, BM25={bm25_params[p]}")
    print(f"   → Testing both configurations for each alpha")
else:
    print(f"✓ QLD and BM25 share identical common parameters")

# Build hybrid grid: each alpha tested with both QLD and BM25 parameter sets
HYBRID_GRID = []
for alpha in HYBRID_ALPHAS:
    for params in [qld_params, bm25_params]:
        HYBRID_GRID.append({
            "retrieval_mgr": HybridRetrievalManager(
                mu=mu, k1=k1, b=b, alpha=alpha, **params
            ),
            "prompt_mgr": qld_best["prompt_mgr"],
        })

print(f"\n✓ Created {len(HYBRID_GRID)} hybrid experiments:")
print(f"   {len(HYBRID_ALPHAS)} alpha values × 2 parameter sets = {len(HYBRID_GRID)} configs")
print(f"   Hybrid params: μ={mu}, k1={k1}, b={b}")

BEST_HYBRID = run_phase(
    phase_name="PHASE 5 — HYBRID RETRIEVAL",
    grid=HYBRID_GRID,
    validation_data=validation_data,
    select_best=False,
)

⚠️  Parameter differences between best QLD and BM25:
   k_docs: QLD=5, BM25=10
   k_passages: QLD=3, BM25=5
   → Testing both configurations for each alpha

✓ Created 22 hybrid experiments:
   11 alpha values × 2 parameter sets = 22 configs
   Hybrid params: μ=1000, k1=1.2, b=0.4

PHASE 5 — HYBRID RETRIEVAL
Total configs: 22
Pending configs: 1
--------------------------------------------------------------------------------
[1/1] Running: HYBRID_kdocs10_kpass5_RERANK_HYBRID_alpha1.0_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100


HYBRID_kdocs10_kpass5_RERANK_HYBRID_alpha1.0_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100: 100%|██████████| 25/25 [07:14<00:00, 17.39s/it]


HYBRID_kdocs10_kpass5_RERANK_HYBRID_alpha1.0_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100
   Retrieval: HYBRID(α=1.0, μ=1000, k1=1.2, b=0.4) → BiEncoder | k_docs=10, k_passages=5
   Prompt: temp=0.1, top_p=0.9, max_tokens=256
   F1=18.38 | P=18.28 | R=30.20 | EM=9.00
   Questions: 100

✓ F1=18.38





## Identifying best configuration

In [36]:
ALL_CONFIGS = (
    PHASE_1_GRID +
    PHASE_2_GRID +
    PHASE_3_GRID +
    PHASE_4_GRID +
    HYBRID_GRID
)

print(f"✓ Total configurations collected: {len(ALL_CONFIGS)}")


✓ Total configurations collected: 100


In [37]:
df_results = pd.read_csv(EXPERIMENT_LOG_PATH)

top10_df = (
    df_results
    .sort_values("f1", ascending=False)
    .head(10)
    .reset_index(drop=True)
)

top10_keys = set(top10_df["config_key"])

print("✓ Top-10 configs by validation F1:")
for i, row in top10_df.iterrows():
    print(f"  {i+1}. {row['config_key']} (F1={row['f1']:.2f})")


✓ Top-10 configs by validation F1:
  1. HYBRID_kdocs5_kpass3_RERANK_HYBRID_alpha0.7_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100 (F1=24.71)
  2. HYBRID_kdocs5_kpass3_RERANK_HYBRID_alpha0.8_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100 (F1=24.71)
  3. HYBRID_kdocs5_kpass3_RERANK_HYBRID_alpha0.9_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100 (F1=24.35)
  4. QLD_kdocs5_kpass3_RERANK_QLD_mu1000_win250_ovl70_promptdefault_sample0_temp0.1_q100 (F1=24.35)
  5. HYBRID_kdocs5_kpass3_RERANK_HYBRID_alpha1.0_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100 (F1=24.35)
  6. QLD_kdocs20_kpass7_RERANK_QLD_mu1000_win200_ovl50_promptstrict_sample1_temp0.3_q100 (F1=24.32)
  7. HYBRID_kdocs5_kpass3_RERANK_HYBRID_alpha0.6_mu1000_k11.2_b0.4_win250_ovl70_promptdefault_sample0_temp0.1_q100 (F1=23.85)
  8. QLD_kdocs20_kpass7_RERANK_QLD_mu1000_win200_ovl50_promptdefault_sample1_temp0.3_q100 (F1=23.40)
  9. QLD_kdocs5_kpass3_RERANK_QLD_

In [38]:
TOP10_CONFIGS = []

for entry in ALL_CONFIGS:
    key = generate_config_key(
        entry["retrieval_mgr"],
        entry["prompt_mgr"],
    )
    if key in top10_keys:
        TOP10_CONFIGS.append(entry)

if len(TOP10_CONFIGS) != 10:
    raise ValueError(f"Expected 10 configs, matched {len(TOP10_CONFIGS)}")
print(f"✓ Matched {len(TOP10_CONFIGS)} configs from ALL_CONFIGS")


✓ Matched 10 configs from ALL_CONFIGS


In [39]:
FULL_VAL_CSV = "results/full_val_top10.csv"

completed_full_val = load_completed_configs(FULL_VAL_CSV)

for i, entry in enumerate(TOP10_CONFIGS, 1):
    key = generate_config_key(
        entry["retrieval_mgr"],
        entry["prompt_mgr"],
    )

    if key in completed_full_val:
        print(f"↺ Skipping already completed: {key}")
        continue

    result = run_experiment(
        name=f"FULL-VAL-TOP{i}",
        df_data=df_val,
        retrieval_manager=entry["retrieval_mgr"],
        prompt_manager=entry["prompt_mgr"],
        max_questions=1000,
        batch_size=4,
        verbose=True,
    )

    save_results_to_csv(
        result=result,
        key=key,
        path=FULL_VAL_CSV,
    )

print(f"\n✓ Full validation results saved to {FULL_VAL_CSV}")


FULL-VAL-TOP1: 100%|██████████| 189/189 [39:51<00:00, 12.65s/it] 



FULL-VAL-TOP1
   Retrieval: QLD(mu=1000) → BiEncoder | k_docs=5, k_passages=3
   Prompt: temp=0.1, top_p=0.9, max_tokens=256
   F1=18.30 | P=18.20 | R=26.07 | EM=10.05
   Questions: 756



FULL-VAL-TOP2:   6%|▌         | 11/189 [01:43<27:54,  9.41s/it]


KeyboardInterrupt: 

In [None]:
df_full_val = pd.read_csv(FULL_VAL_CSV)

best_row = df_full_val.sort_values("f1", ascending=False).iloc[0]
best_key = best_row["config_key"]

best_entry = next(
    entry for entry in TOP10_CONFIGS
    if generate_config_key(
        entry["retrieval_mgr"],
        entry["prompt_mgr"],
    ) == best_key
)

print("✓ Best config after 1000-question validation:")
print(f"  Config key: {best_key}")
print(f"  F1={best_row['f1']:.2f}")


In [None]:
FINAL_TEST_CSV = "results/final_test_best.csv"

completed_final = load_completed_configs(FINAL_TEST_CSV)

if best_key in completed_final:
    print(f"↺ Final test already completed for {best_key}")
else:
    final_result = run_experiment(
        name="FINAL-TEST",
        df_data=df_train,
        retrieval_manager=best_entry["retrieval_mgr"],
        prompt_manager=best_entry["prompt_mgr"],
        max_questions=None,
        batch_size=4,
        verbose=True,
    )

    save_results_to_csv(
        result=final_result,
        key=best_key,
        path=FINAL_TEST_CSV,
    )

    print("✓ Final test completed and saved")


In [None]:
import json
import pandas as pd
from tqdm import tqdm

SUBMISSION_CSV = "submission.csv"
TEST_CSV_PATH = "data/test.csv"

df_test = pd.read_csv(TEST_CSV_PATH)
print(f"✓ Loaded test set with {len(df_test)} questions")

retrieval_manager: RetrievalManager = best_entry["retrieval_mgr"]
prompt_manager: PromptManager = best_entry["prompt_mgr"]

print("✓ Using best configuration:")
print(f"  Retrieval: {retrieval_manager}")
print(f"  Prompt: {prompt_manager}")

rows = []

batch_size = 4
num_batches = (len(df_test) + batch_size - 1) // batch_size

for batch_idx in tqdm(range(num_batches), desc="Kaggle Submission"):
    start = batch_idx * batch_size
    end = min(start + batch_size, len(df_test))
    batch_df = df_test.iloc[start:end]

    batch_questions = []
    batch_ids = []
    batch_contexts = []

    for _, row in batch_df.iterrows():
        qid = row["id"]
        question = row["question"]

        contexts = retrieval_manager.retrieve_context(question)

        batch_ids.append(qid)
        batch_questions.append(question)
        batch_contexts.append(contexts)

    batch_answers = prompt_manager.batch_generate_answers(
        batch_questions,
        batch_contexts
    )

    for qid, answer in zip(batch_ids, batch_answers):
        rows.append({
            "id": qid,
            "prediction": answer,
        })

df_prediction = pd.DataFrame(rows)

# Kaggle-required formatting
df_prediction["prediction"] = df_prediction["prediction"].apply(
    lambda x: json.dumps([x], ensure_ascii=False)
)

df_prediction.to_csv(SUBMISSION_CSV, index=False)

print(f"\n✓ Submission written to {SUBMISSION_CSV}")
print(f"✓ Columns: {list(df_prediction.columns)}")
print("✓ Ready for Kaggle upload")
