## 1. Setup & Dependencies

In [80]:
import pandas as pd
import json
import re
import string
from collections import Counter
from typing import List, Dict, Tuple, Optional, Literal
from dataclasses import dataclass
from tqdm.auto import tqdm
import itertools
tqdm.pandas()
import warnings
warnings.filterwarnings('ignore')

print("✓ Dependencies imported")

✓ Dependencies imported


### Install Required Packages

In [81]:
# !pip install torch torchvision torchaudio
# !pip install pyserini==0.36.0
# !pip install accelerate
# !pip install transformers
# !pip install tqdm
# !pip install python-dotenv

In [None]:
import os

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]

!java -version

In [None]:
# !pip install torch torchvision torchaudio
# !pip install pyserini==0.36.0
# !pip install accelerate
# !pip install transformers
# !pip install tqdm
# !pip install python-dotenv

### Hugging Face Authentication

In [82]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

load_dotenv()

login(os.getenv('HUGGING_FACE_TOKEN'))
print("✓ Logged into Hugging Face")

✓ Logged into Hugging Face


## 2. Data Loading & Preparation

In [83]:
# Load datasets
df_train = pd.read_csv("./data/train.csv", converters={"answers": json.loads})
df_test = pd.read_csv("./data/test.csv")

print(f"Train set: {len(df_train)} questions")
print(f"Test set: {len(df_test)} questions")
print(f"\nSample question: {df_train.iloc[0]['question']}")
print(f"Sample answers: {df_train.iloc[0]['answers']}")

Train set: 3778 questions
Test set: 2032 questions

Sample question: what is the name of justin bieber brother?
Sample answers: ['Jazmyn Bieber', 'Jaxon Bieber']


In [84]:
# Create train/validation split for experiments
RANDOM_SEED = 42
VAL_SIZE = 0.2

df_train_split = df_train.sample(frac=1-VAL_SIZE, random_state=RANDOM_SEED)
df_val = df_train.drop(df_train_split.index).reset_index(drop=True)
df_train_split = df_train_split.reset_index(drop=True)

print(f"✓ Split data:")
print(f"   Training: {len(df_train_split)} questions")
print(f"   Validation: {len(df_val)} questions")

✓ Split data:
   Training: 3022 questions
   Validation: 756 questions


## 3. Retrieval Functions

In [85]:
from pyserini.search import SimpleSearcher
from pyserini.index.lucene import IndexReader

# Load Pyserini index
print("Loading Pyserini index...")
searcher = SimpleSearcher.from_prebuilt_index('wikipedia-kilt-doc')
index_reader = IndexReader.from_prebuilt_index('wikipedia-kilt-doc')

print(f"✓ Index loaded: {index_reader.stats()['documents']} documents")

Loading Pyserini index...
SimpleSearcher class has been deprecated, please use LuceneSearcher from pyserini.search.lucene instead
✓ Index loaded: 5903530 documents


In [None]:
from sentence_transformers import SentenceTransformer
import torch

# Load bi-encoder
print("Loading bi-encoder...")
device = "cuda" if torch.cuda.is_available() else "cpu"
bi_encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)

print("✓ Bi-encoder loaded")

In [86]:
from dataclasses import dataclass
from typing import List, Literal
from collections import Counter
import json
import re
import math


@dataclass
class RetrievalConfig:
    """Configuration for document/passage retrieval."""
    k: int = 3
    method: Literal['bm25', 'qld'] = 'qld'
    mu: int = 1000
    k1: float = 0.9
    b: float = 0.4
    use_passages: bool = False
    window: int = 150
    overlap: int = 50
    min_passage_words: int = 30
    max_passages: int = 3
    max_chars: int = 300
    
    def __str__(self):
        method_params = ""
        if self.method == 'qld':
            method_params = f"mu={self.mu}"
        elif self.method == 'bm25':
            method_params = f"k1={self.k1}, b={self.b}"        
        
        if self.use_passages:
            mode_info = f"passages: window={self.window}, overlap={self.overlap}, max={self.max_passages}"
        else:
            mode_info = f"docs: {self.max_chars}chars"
        
        return f"{self.method.upper()}(k={self.k}, {method_params}) | {mode_info}"


def extract_passages(text, window=150, overlap=50, min_words=30):
    """Extract overlapping passages from text."""
    if not text:
        return []
    
    words = text.split()
    if len(words) < min_words:
        return []
    
    passages = []
    step = max(1, window - overlap)
    
    for start in range(0, len(words), step):
        chunk = words[start:start + window]
        
        if len(chunk) < min_words:
            if passages:
                passages[-1] += " " + " ".join(chunk)
            else:
                passages.append(" ".join(chunk))
            break
        
        passages.append(" ".join(chunk))
    
    return passages


def tokenize(text: str):
    return re.findall(r"\b\w+\b", text.lower())


def compute_idf(passages: List[str]):
    df = Counter()
    N = len(passages)

    for p in passages:
        df.update(set(tokenize(p)))

    return {
        t: math.log((N + 1) / (df[t] + 1))
        for t in df
    }


def rerank_passages(query: str, passages: List[str]) -> List[str]:
    """Rerank passages using TF×IDF-weighted query overlap."""
    q_tokens = Counter(tokenize(query))
    idf = compute_idf(passages)
    scored = []

    for p in passages:
        p_tokens = Counter(tokenize(p))
        score = sum(
            q_tokens[t] * p_tokens.get(t, 0) * idf.get(t, 0.0)
            for t in q_tokens
        )
        scored.append((score, p))

    scored.sort(key=lambda x: x[0], reverse=True)
    return [p for _, p in scored]


def retrieve_context(query: str, config: RetrievalConfig) -> List[str]:
    """Retrieve documents or passages for a given query based on the retrieval configuration."""
    if config.method == 'bm25':
        searcher.set_bm25(config.k1, config.b)
    else:
        searcher.set_qld(config.mu)
    
    hits = searcher.search(query, config.k)
    contexts = []

    for hit in hits:
        try:
            doc = searcher.doc(hit.docid)
            data = json.loads(doc.raw())
            content = data['contents'].replace('\n', ' ')
            
            if config.use_passages:
                passages = extract_passages(
                    content,
                    config.window,
                    config.overlap,
                    config.min_passage_words
                )
                contexts.extend(passages)
            else:
                contexts.append(content[:config.max_chars])
        except:
            continue

    if config.use_passages:
        contexts = rerank_passages(query, contexts)
        return contexts[:config.max_passages]

    return contexts


# Test retrieval
query = "Who wrote Harry Potter?"

print("=" * 60)
print("Testing QLD")
test_config_qld = RetrievalConfig(k=3, method='qld', mu=1000)
print(f"{test_config_qld}")
test_docs_qld = retrieve_context(query, test_config_qld)
print(f"✓ Retrieved {len(test_docs_qld)} documents")
for i, doc in enumerate(test_docs_qld, 1):
    print(f"  {i}. {doc[:80]}...")
print()

print("=" * 60)
print("Testing BM25")
test_config_bm25 = RetrievalConfig(k=3, method='bm25', k1=0.9, b=0.4)
print(f"{test_config_bm25}")
test_docs_bm25 = retrieve_context(query, test_config_bm25)
print(f"✓ Retrieved {len(test_docs_bm25)} documents")
for i, doc in enumerate(test_docs_bm25, 1):
    print(f"  {i}. {doc[:80]}...")
print()

print("=" * 60)
test_passages = test_docs_qld


Testing QLD
QLD(k=3, mu=1000) | docs: 300chars
✓ Retrieved 3 documents
  1. Harry Potter Harry Potter is a series of fantasy novels written by British autho...
  2. Bonnie Wright Bonnie Francesca Wright (born 17 February 1991) is an English actr...
  3. Politics of Harry Potter There are many published theories about the politics of...

Testing BM25
BM25(k=3, k1=0.9, b=0.4) | docs: 300chars
✓ Retrieved 3 documents
  1. Bonnie Wright Bonnie Francesca Wright (born 17 February 1991) is an English actr...
  2. The Magical Worlds of Harry Potter The Magical Worlds of Harry Potter: A Treasur...
  3. Harry Potter Harry Potter is a series of fantasy novels written by British autho...



## 4. LLM Generation

In [87]:
import transformers
import torch
import logging

# Suppress transformers warnings
transformers.logging.set_verbosity_error()
logging.getLogger("transformers").setLevel(logging.ERROR)

print("Loading LLM model...")
model_id = "meta-llama/Llama-3.2-1B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Set pad_token for batch processing
pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token

print(f"✓ Model loaded on: {'GPU' if torch.cuda.is_available() else 'CPU'}")

Loading LLM model...
✓ Model loaded on: GPU


In [88]:
DEFAULT_SYSTEM_PROMPT = (
    "You must respond based strictly on the information in provided passages."
    "Do not incorporate any external knowledge or infer any details beyond what is given."
    "If the answer is not in the context, return 'I dont know'."
    "Do not include explanations, only the final answer!"
)

DEFAULT_USER_PROMPT = (
    "Based on the following documents, provide a concise answer to the question.\n\n"
    "{context}\n\n"
    "Question: {question}\n\n"
    "Answer:"
)

@dataclass
class PromptConfig:
    """Configuration for prompt generation and LLM parameters."""
    system_prompt: str = DEFAULT_SYSTEM_PROMPT
    user_prompt: str = DEFAULT_USER_PROMPT
    temperature: float = 0.1
    top_p: float = 0.9
    max_new_tokens: int = 256
    do_sample: bool = True
    
    def __str__(self):
        return f"temp={self.temperature}, top_p={self.top_p}, max_tokens={self.max_new_tokens}"

def clean_answer(answer: str) -> str:
    """Clean and standardize the generated answer."""
    answer = re.sub(r'^(Answer|The answer is|Based on the .*?,):?\s*', '', answer, flags=re.I)
    answer = answer.rstrip('.')
    if any(phrase in answer.lower() for phrase in ["dont know", "don't know", "do not know", "unknown"]):
        return "unknown"
    return answer.strip()

def create_messages(question: str, contexts: List[str], config: PromptConfig) -> List[Dict]:
    """Create messages for the LLM based on the question, contexts, and prompt configuration."""
    if not contexts:
        context_str = "No relevant documents found."
    else:
        context_str = '\n\n'.join([f"Document {i+1}: {ctx}" for i, ctx in enumerate(contexts)])
    
    return [
        {"role": "system", "content": config.system_prompt},
        {"role": "user", "content": config.user_prompt.format(context=context_str, question=question)}
    ]

def generate_answer(question: str, contexts: List[str], config: PromptConfig) -> str:
    """Generate an answer using the LLM based on the question, contexts, and prompt configuration."""
    messages = create_messages(question, contexts, config)
    
    outputs = pipeline(
        messages,
        max_new_tokens=config.max_new_tokens,
        eos_token_id=terminators,
        do_sample=config.do_sample,
        temperature=config.temperature,
        top_p=config.top_p,
    )
    
    answer = outputs[0]["generated_text"][-1].get('content', '')
    return clean_answer(answer)

def batch_generate_answers(questions: List[str], contexts_list: List[List[str]], config: PromptConfig) -> List[str]:
    """Generate answers for multiple questions in batch."""
    # Create messages for all questions
    batch_messages = [create_messages(q, ctx, config) for q, ctx in zip(questions, contexts_list)]
    
    # Process batch through pipeline
    outputs = pipeline(
        batch_messages,
        max_new_tokens=config.max_new_tokens,
        eos_token_id=terminators,
        do_sample=config.do_sample,
        temperature=config.temperature,
        top_p=config.top_p
    )
    
    # Extract and clean answers
    answers = []
    for output in outputs:
        answer = output[0]["generated_text"][-1].get('content', '')
        answers.append(clean_answer(answer))
    
    return answers

test_prompt_config = PromptConfig(temperature=0.1)
print(f"Testing: {test_prompt_config}")
test_answer = generate_answer(query, test_passages, test_prompt_config)
print(f"✓ Generated answer: '{test_answer}'")

Testing: temp=0.1, top_p=0.9, max_tokens=256
✓ Generated answer: 'J. K. Rowling'


## 5. Evaluation Metrics

In [89]:
def normalize_answer(s: str) -> str:
    """Normalize answer for comparison"""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    
    def white_space_fix(text):
        return ' '.join(text.split())
    
    def remove_punc(text):
        return ''.join(ch for ch in text if ch not in set(string.punctuation))
    
    def lower(text):
        return text.lower()
    
    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Compute token-level F1 score"""
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    
    if len(pred_tokens) == 0 or len(gt_tokens) == 0:
        return int(pred_tokens == gt_tokens)
    if num_same == 0:
        return 0
    
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def evaluate_predictions(df_gold: pd.DataFrame, predictions: Dict[int, str]) -> Dict:
    """
    Evaluate predictions against ground truth.
    
    Returns:
        Dict with average metrics and individual scores:
        {
            'f1': average_f1,
            'precision': average_precision,
            'recall': average_recall,
            'exact_match': exact_match_percentage,
            'f1_scores': list of individual f1 scores,
            'precision_scores': list of individual precision scores,
            'recall_scores': list of individual recall scores,
            'exact_matches': list of individual exact match flags
        }
    """
    f1_scores = []
    precision_scores = []
    recall_scores = []
    exact_matches = []
    
    for _, row in df_gold.iterrows():
        qid = row['id']
        if qid not in predictions:
            f1_scores.append(0.0)
            precision_scores.append(0.0)
            recall_scores.append(0.0)
            exact_matches.append(0)
            continue
        
        prediction = predictions[qid]
        ground_truths = row['answers']
        
        # Compute metrics for each ground truth and take the best
        best_f1 = 0.0
        best_precision = 0.0
        best_recall = 0.0
        is_exact = 0
        
        for gt in ground_truths:
            # F1 score
            pred_tokens = normalize_answer(prediction).split()
            gt_tokens = normalize_answer(gt).split()
            
            common = Counter(pred_tokens) & Counter(gt_tokens)
            num_same = sum(common.values())
            
            if len(pred_tokens) == 0 or len(gt_tokens) == 0:
                f1 = int(pred_tokens == gt_tokens)
                prec = int(pred_tokens == gt_tokens)
                rec = int(pred_tokens == gt_tokens)
            elif num_same == 0:
                f1 = 0.0
                prec = 0.0
                rec = 0.0
            else:
                prec = num_same / len(pred_tokens)
                rec = num_same / len(gt_tokens)
                f1 = (2 * prec * rec) / (prec + rec)
            
            # Track best scores
            if f1 > best_f1:
                best_f1 = f1
                best_precision = prec
                best_recall = rec
            
            # Exact match
            if normalize_answer(prediction) == normalize_answer(gt):
                is_exact = 1
        
        f1_scores.append(best_f1)
        precision_scores.append(best_precision)
        recall_scores.append(best_recall)
        exact_matches.append(is_exact)
    
    return {
        'f1': 100.0 * sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
        'precision': 100.0 * sum(precision_scores) / len(precision_scores) if precision_scores else 0.0,
        'recall': 100.0 * sum(recall_scores) / len(recall_scores) if recall_scores else 0.0,
        'exact_match': 100.0 * sum(exact_matches) / len(exact_matches) if exact_matches else 0.0,
        'f1_scores': f1_scores,
        'precision_scores': precision_scores,
        'recall_scores': recall_scores,
        'exact_matches': exact_matches
    }


# Test evaluation
test_predictions = {1: "J.K. Rowling", 2: "Paris", 3: "Shakespeare"}
test_gold = pd.DataFrame({
    'id': [1, 2, 3],
    'answers': [["J.K. Rowling", "Rowling"], ["Earth"], ["William Shakespeare", "Shakespeare"]]
})

test_metrics = evaluate_predictions(test_gold, test_predictions)
print(f"✓ Evaluation test: F1={test_metrics['f1']:.2f}, P={test_metrics['precision']:.2f}, R={test_metrics['recall']:.2f}, EM={test_metrics['exact_match']:.2f}")

✓ Evaluation test: F1=66.67, P=66.67, R=66.67, EM=66.67


## 6. Experiment Framework

In [90]:
def run_experiment(
    name: str,
    df_data: pd.DataFrame,
    retrieval_config: RetrievalConfig,
    prompt_config: PromptConfig,
    max_questions: Optional[int] = None,
    batch_size: int = 4,
    verbose: bool = True
) -> Dict:
    if max_questions:
        df_data = df_data.head(max_questions)
    
    predictions = {}
    
    # Process in batches
    num_batches = (len(df_data) + batch_size - 1) // batch_size
    iterator = tqdm(range(num_batches), desc=name) if verbose else range(num_batches)
    
    for batch_idx in iterator:
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(df_data))
        batch_df = df_data.iloc[start_idx:end_idx]
        
        # Retrieve contexts for all questions in batch
        batch_questions = []
        batch_qids = []
        batch_contexts = []
        
        for _, row in batch_df.iterrows():
            question = row['question']
            qid = row['id']
            contexts = retrieve_context(question, retrieval_config)
            
            batch_questions.append(question)
            batch_qids.append(qid)
            batch_contexts.append(contexts)
        
        # Generate answers in batch
        batch_answers = batch_generate_answers(batch_questions, batch_contexts, prompt_config)
        
        # Store predictions
        for qid, answer in zip(batch_qids, batch_answers):
            predictions[qid] = answer
    
    metrics = evaluate_predictions(df_data, predictions)
    
    result = {
        'name': name,
        'retrieval': retrieval_config,
        'prompt': prompt_config,
        'f1_score': metrics['f1'],
        'precision': metrics['precision'],
        'recall': metrics['recall'],
        'exact_match': metrics['exact_match'],
        'num_questions': len(df_data),
        'predictions': predictions,
        'f1_scores': metrics['f1_scores'],
        'precision_scores': metrics['precision_scores'],
        'recall_scores': metrics['recall_scores'],
        'exact_matches': metrics['exact_matches']
    }
    
    if verbose:
        print(f"\n{name}")
        print(f"   Retrieval: {retrieval_config}")
        print(f"   Prompt: {prompt_config}")
        print(f"   F1={metrics['f1']:.2f} | P={metrics['precision']:.2f} | R={metrics['recall']:.2f} | EM={metrics['exact_match']:.2f}")
        print(f"   Questions: {len(df_data)}\n")
    
    return result

# Test experiment
test_retrieval = RetrievalConfig(k=3, method='qld')
test_prompt = PromptConfig(temperature=0.1)
print(f"Testing experiment with:")
print(f"  Retrieval: {test_retrieval}")
print(f"  Prompt: {test_prompt}")

test_exp = run_experiment(
    "Quick Test",
    df_val.head(5),
    test_retrieval,
    test_prompt,
    verbose=True
)

print(f"✓ Experiment framework ready")

Testing experiment with:
  Retrieval: QLD(k=3, mu=1000) | docs: 300chars
  Prompt: temp=0.1, top_p=0.9, max_tokens=256


Quick Test: 100%|██████████| 2/2 [00:30<00:00, 15.07s/it]


Quick Test
   Retrieval: QLD(k=3, mu=1000) | docs: 300chars
   Prompt: temp=0.1, top_p=0.9, max_tokens=256
   F1=14.67 | P=10.67 | R=30.00 | EM=0.00
   Questions: 5

✓ Experiment framework ready





## 7. Grid Search Experimental Framework

### Grid Search Configuration

**Hyperparameters to test:**

In [91]:
# Global experiment settings
EXPERIMENT_SEED = 42
EXPERIMENT_QUESTIONS = 100  # Number of questions per configuration
EXPERIMENT_LOG_PATH = "./results/grid_search_results.csv"

# Grid search parameter space
GRID_PARAMS = {
    # Retrieval parameters
    'method': ['qld', 'bm25'],
    'k': [3, 5, 10],
    'mu': [500, 1000, 2000],
    'k1': [0.6, 0.9, 1.2],
    'b': [0.4, 0.75],
    'use_passages': [False, True],
    'max_chars': [200, 300, 400],
    'window': [100, 150],
    'overlap': [30, 50],
    'max_passages': [3, 5],
    
    # Prompt parameters
    'temperature': [0.1, 0.6],
}

print("="*80)
print("GRID SEARCH EXPERIMENTAL FRAMEWORK")
print("="*80)
print(f"Validation questions per config: {EXPERIMENT_QUESTIONS}")
print(f"Random seed: {EXPERIMENT_SEED}")
print(f"Results cache: {EXPERIMENT_LOG_PATH}")
print("\nParameter Grid:")
for param, values in GRID_PARAMS.items():
    print(f"  {param:20s}: {values}")
print("="*80)

GRID SEARCH EXPERIMENTAL FRAMEWORK
Validation questions per config: 100
Random seed: 42
Results cache: ./results/grid_search_results.csv

Parameter Grid:
  method              : ['qld', 'bm25']
  k                   : [3, 5, 10]
  mu                  : [500, 1000, 2000]
  k1                  : [0.6, 0.9, 1.2]
  b                   : [0.4, 0.75]
  use_passages        : [False, True]
  max_chars           : [200, 300, 400]
  window              : [100, 150]
  overlap             : [30, 50]
  max_passages        : [3, 5]
  temperature         : [0.1, 0.6]


In [92]:
def generate_config_key(retrieval_config: RetrievalConfig, prompt_config: PromptConfig) -> str:
    """Generate unique key from retrieval and prompt configuration."""
    key_parts = []
    key_parts.append(f"method={retrieval_config.method}")
    key_parts.append(f"k={retrieval_config.k}")
    
    if retrieval_config.method == 'qld':
        key_parts.append(f"mu={retrieval_config.mu}")
    elif retrieval_config.method == 'bm25':
        key_parts.append(f"k1={retrieval_config.k1}")
        key_parts.append(f"b={retrieval_config.b}")
    
    if retrieval_config.use_passages:
        key_parts.append("mode=passages")
        key_parts.append(f"win={retrieval_config.window}")
        key_parts.append(f"ovlp={retrieval_config.overlap}")
        key_parts.append(f"maxp={retrieval_config.max_passages}")
    else:
        key_parts.append("mode=docs")
        key_parts.append(f"chars={retrieval_config.max_chars}")
    
    key_parts.append(f"temp={prompt_config.temperature}")
    
    return "__".join(key_parts)

def save_result_to_csv(result: Dict, config_key: str, filepath: str):
    """Save single experiment result to CSV (append mode)."""    
    retrieval = result['retrieval']
    prompt = result['prompt']
    
    row = {
        'config_key': config_key,
        'f1_score': result['f1_score'],
        'precision': result['precision'],
        'recall': result['recall'],
        'exact_match': result['exact_match'],
        'num_questions': result['num_questions'],
        
        # Retrieval params
        'method': retrieval.method,
        'k': retrieval.k,
        'mu': retrieval.mu if retrieval.method == 'qld' else None,
        'k1': retrieval.k1 if retrieval.method == 'bm25' else None,
        'b': retrieval.b if retrieval.method == 'bm25' else None,
        'use_passages': retrieval.use_passages,
        'max_chars': retrieval.max_chars if not retrieval.use_passages else None,
        'window': retrieval.window if retrieval.use_passages else None,
        'overlap': retrieval.overlap if retrieval.use_passages else None,
        'max_passages': retrieval.max_passages if retrieval.use_passages else None,
        'min_passage_words': retrieval.min_passage_words if retrieval.use_passages else None,
        
        # Prompt params
        'temperature': prompt.temperature,
        'top_p': prompt.top_p,
        'max_new_tokens': prompt.max_new_tokens,
        'do_sample': prompt.do_sample,
    }
    
    df_row = pd.DataFrame([row])
    
    # Create directory if needed
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    
    # Append to CSV (create with header if doesn't exist)
    if os.path.exists(filepath):
        df_row.to_csv(filepath, mode='a', header=False, index=False)
    else:
        df_row.to_csv(filepath, mode='w', header=True, index=False)

def load_existing_results(filepath: str) -> set:
    """Load set of already-completed configuration keys."""
    if os.path.exists(filepath):
        df = pd.read_csv(filepath)
        return set(df['config_key'].values)
    return set()

def generate_all_configs(grid_params: Dict) -> List[Tuple[str, RetrievalConfig, PromptConfig]]:
    """Generate all valid configurations from parameter grid."""
    configs = []
    
    # Generate all combinations using itertools.product
    for temp, method, k, use_passages in itertools.product(
        grid_params['temperature'],
        grid_params['method'],
        grid_params['k'],
        grid_params['use_passages']
    ):
        # Create prompt config
        prompt_config = PromptConfig(temperature=temp, top_p=0.9, max_new_tokens=256)
        
        # Method-specific parameters (mu for QLD, k1/b for BM25)
        if method == 'qld':
            method_params = [(mu, 0.9, 0.4) for mu in grid_params['mu']]
        elif method == 'bm25':
            method_params = [
                (1000, k1, b) 
                for k1, b in itertools.product(grid_params['k1'], grid_params['b'])
            ]
        else:
            continue
        
        # Generate configs for each method parameter combination
        for mu, k1, b in method_params:
            if use_passages:
                # Passage mode: combine window, overlap, max_passages
                for window, overlap, max_passages in itertools.product(
                    grid_params['window'],
                    grid_params['overlap'],
                    grid_params['max_passages']
                ):
                    config = RetrievalConfig(
                        k=k, method=method, mu=mu, k1=k1, b=b,
                        use_passages=True, window=window, overlap=overlap, 
                        max_passages=max_passages
                    )
                    key = generate_config_key(config, prompt_config)
                    configs.append((key, config, prompt_config))
            else:
                # Document mode: iterate over max_chars
                for max_chars in grid_params['max_chars']:
                    config = RetrievalConfig(
                        k=k, method=method, mu=mu, k1=k1, b=b,
                        use_passages=False, max_chars=max_chars
                    )
                    key = generate_config_key(config, prompt_config)
                    configs.append((key, config, prompt_config))
    

    return configs

print("✓ Grid search helper functions ready")

✓ Grid search helper functions ready


### Generate Configuration Matrix

In [93]:
# Generate all valid configurations
all_configs = generate_all_configs(GRID_PARAMS)

print(f"✓ Generated {len(all_configs)} unique configurations")
print(f"\nConfiguration breakdown:")
print(f"  Methods: {len(GRID_PARAMS['method'])}")
print(f"  k values: {len(GRID_PARAMS['k'])}")
print(f"  QLD mu values: {len(GRID_PARAMS['mu'])}")
print(f"  BM25 (k1 × b): {len(GRID_PARAMS['k1'])} × {len(GRID_PARAMS['b'])} = {len(GRID_PARAMS['k1']) * len(GRID_PARAMS['b'])}")
print(f"  Document mode (max_chars): {len(GRID_PARAMS['max_chars'])}")
print(f"  Passage mode (window × overlap × max_passages): {len(GRID_PARAMS['window'])} × {len(GRID_PARAMS['overlap'])} × {len(GRID_PARAMS['max_passages'])} = {len(GRID_PARAMS['window']) * len(GRID_PARAMS['overlap']) * len(GRID_PARAMS['max_passages'])}")
print(f"  Temperature values: {len(GRID_PARAMS['temperature'])}")

# Show sample configurations
print(f"\nSample configurations:")
for i in range(min(5, len(all_configs))):
    key, retrieval_config, prompt_config = all_configs[i]
    print(f"  {i+1}. {key}")
    print(f"     Retrieval: {retrieval_config}")
    print(f"     Prompt: {prompt_config}")
    print()

✓ Generated 594 unique configurations

Configuration breakdown:
  Methods: 2
  k values: 3
  QLD mu values: 3
  BM25 (k1 × b): 3 × 2 = 6
  Document mode (max_chars): 3
  Passage mode (window × overlap × max_passages): 2 × 2 × 2 = 8
  Temperature values: 2

Sample configurations:
  1. method=qld__k=3__mu=500__mode=docs__chars=200__temp=0.1
     Retrieval: QLD(k=3, mu=500) | docs: 200chars
     Prompt: temp=0.1, top_p=0.9, max_tokens=256

  2. method=qld__k=3__mu=500__mode=docs__chars=300__temp=0.1
     Retrieval: QLD(k=3, mu=500) | docs: 300chars
     Prompt: temp=0.1, top_p=0.9, max_tokens=256

  3. method=qld__k=3__mu=500__mode=docs__chars=400__temp=0.1
     Retrieval: QLD(k=3, mu=500) | docs: 400chars
     Prompt: temp=0.1, top_p=0.9, max_tokens=256

  4. method=qld__k=3__mu=1000__mode=docs__chars=200__temp=0.1
     Retrieval: QLD(k=3, mu=1000) | docs: 200chars
     Prompt: temp=0.1, top_p=0.9, max_tokens=256

  5. method=qld__k=3__mu=1000__mode=docs__chars=300__temp=0.1
     Retriev

### Run Grid Search

**Important:** This will run all configurations sequentially. Each result is cached to CSV immediately after completion, so you can safely interrupt and resume.

In [96]:
# Load already-completed configurations (if resuming)
completed_keys = load_existing_results(EXPERIMENT_LOG_PATH)
print(f"Already completed: {len(completed_keys)} configurations")

# Filter to only pending configurations
pending_configs = [(key, r_cfg, p_cfg) for key, r_cfg, p_cfg in all_configs if key not in completed_keys]
print(f"Remaining to run: {len(pending_configs)} configurations")
print("="*80)

# Run grid search
validation_data = df_val.sample(n=EXPERIMENT_QUESTIONS, random_state=EXPERIMENT_SEED).reset_index(drop=True)

for idx, (config_key, retrieval_config, prompt_config) in enumerate(pending_configs, start=1):
    print(f"\n[{idx}/{len(pending_configs)}] Running: {config_key}")
    
    try:
        result = run_experiment(
            name=config_key,
            df_data=validation_data,
            retrieval_config=retrieval_config,
            prompt_config=prompt_config,
            verbose=False
        )
        
        # Save to CSV immediately
        save_result_to_csv(result, config_key, EXPERIMENT_LOG_PATH)
        
        print(f"✓ F1={result['f1_score']:.2f} | P={result['precision']:.2f} | R={result['recall']:.2f} | EM={result['exact_match']:.2f} | Saved")
        
    except Exception as e:
        print(f"✗ ERROR: {str(e)}")
        continue

print("\n" + "="*80)
print(f"✓ Grid search complete!")
print(f"Results saved to: {EXPERIMENT_LOG_PATH}")
print("="*80)

Already completed: 122 configurations
Remaining to run: 472 configurations

[1/472] Running: method=bm25__k=3__k1=0.6__b=0.4__mode=passages__win=150__ovlp=30__maxp=5__temp=0.1
✓ F1=15.22 | P=15.32 | R=20.98 | EM=9.00 | Saved

[2/472] Running: method=bm25__k=3__k1=0.6__b=0.4__mode=passages__win=150__ovlp=50__maxp=3__temp=0.1
✓ F1=9.99 | P=10.77 | R=15.50 | EM=5.00 | Saved

[3/472] Running: method=bm25__k=3__k1=0.6__b=0.4__mode=passages__win=150__ovlp=50__maxp=5__temp=0.1


KeyboardInterrupt: 

### Analyze Results

In [97]:
# Load and analyze results
df_results = pd.read_csv(EXPERIMENT_LOG_PATH)

print(f"Total configurations tested: {len(df_results)}")
print(f"\nTop 10 configurations by F1 Score:")
print("="*80)

df_top = df_results.nlargest(10, 'f1_score')
for idx, row in df_top.iterrows():
    print(f"\nRank #{list(df_top.index).index(idx) + 1}")
    print(f"  Config: {row['config_key']}")
    print(f"  F1={row['f1_score']:.2f} | Precision={row['precision']:.2f} | Recall={row['recall']:.2f} | EM={row['exact_match']:.2f}")
    print(f"  Method: {row['method']} | k={row['k']}")
    if row['method'] == 'qld':
        print(f"  QLD: mu={row['mu']}")
    elif row['method'] == 'bm25':
        print(f"  BM25: k1={row['k1']}, b={row['b']}")
    if row['use_passages']:
        print(f"  Passages: window={row['window']}, overlap={row['overlap']}, max={row['max_passages']}")
    else:
        print(f"  Documents: max_chars={row['max_chars']}")

print("\n" + "="*80)
print("\nMethod comparison:")
print(df_results.groupby('method')['f1_score'].agg(['mean', 'std', 'max', 'min']))

print("\nDocument vs Passage mode:")
print(df_results.groupby('use_passages')['f1_score'].agg(['mean', 'std', 'max', 'min']))

Total configurations tested: 124

Top 10 configurations by F1 Score:

Rank #1
  Config: method=bm25__k=3__k1=0.6__b=0.4__mode=passages__win=150__ovlp=30__maxp=5__temp=0.1
  F1=15.22 | Precision=15.32 | Recall=20.98 | EM=9.00
  Method: bm25 | k=3
  BM25: k1=0.6, b=0.4
  Passages: window=150.0, overlap=30.0, max=5.0

Rank #2
  Config: method=qld__k=3__mu=2000__mode=passages__win=100__ovlp=30__maxp=3__temp=0.1
  F1=13.67 | Precision=11.99 | Recall=20.23 | EM=1.00
  Method: qld | k=3
  QLD: mu=2000.0
  Passages: window=100.0, overlap=30.0, max=3.0

Rank #3
  Config: method=bm25__k=3__k1=0.6__b=0.4__mode=passages__win=100__ovlp=30__maxp=3__temp=0.1
  F1=13.63 | Precision=11.21 | Recall=23.27 | EM=1.00
  Method: bm25 | k=3
  BM25: k1=0.6, b=0.4
  Passages: window=100.0, overlap=30.0, max=3.0

Rank #4
  Config: method=qld__k=3__mu=500__mode=passages__win=150__ovlp=30__maxp=5__temp=0.1
  F1=13.17 | Precision=11.05 | Recall=24.40 | EM=1.00
  Method: qld | k=3
  QLD: mu=500.0
  Passages: window=