# MRAG Baseline Implementation

This notebook implements a baseline of MRAG (Modular Retrieval, a retrieval-augmented generation framework for time-sensitive questions) for temporal question answering on the TempRAGEval dataset.

## 1. Setup and Installation

Install required packages for the MRAG baseline implementation.

In [None]:
%%bash
# Install NumPy first so the packages below resolve against a known version.
# NumPy 1.26 is the first release line with Python 3.12 support; the earlier
# 1.24.3 pin has to be built from source there and fails, and the recorded run
# of this notebook ended up on 1.26.4 anyway.
pip install -q "numpy==1.26.4" --no-cache-dir
pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install -q transformers==4.44.2 accelerate==0.34.2 sentencepiece==0.2.0
pip install -q --no-cache-dir faiss-cpu==1.8.0.post1
pip install -q datasets==2.21.0 tqdm==4.66.5
pip install -q nltk==3.9.1
pip install -q wikipedia==1.4.0 wikipedia-api==0.7.1
# pattern is only available from GitHub and is optional: the text-processing
# cell falls back to a built-in lexeme() when it cannot be imported. Keep a
# failed build from failing the whole cell (the cell's exit status is that of
# its last command).
pip install -q git+https://github.com/clips/pattern.git@master \
  || echo "WARNING: 'pattern' could not be installed; the built-in lexeme() fallback will be used."


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.9/10.9 MB 146.9 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 43.7/43.7 kB 3.4 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.5/9.5 MB 149.3 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 324.4/324.4 kB 30.5 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 78.3 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 118.6 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 61.0/61.0 kB 49.9 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 27.0/27.0 MB 15.0 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.0/18.0 MB 324.6 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.6/57.6 kB 4.7 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 527.3/527.3 kB 37.1 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 78.4/78.4 kB 6.9 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17

  error: subprocess-exited-with-error
  
  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
error: subprocess-exited-with-error

× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
opencv-contrib-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
pytensor 2.35.1 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.
opencv-python 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
jax 0.7.2 req

In [None]:
# If numpy compatibility errors show up, run this cell and restart the runtime
# After restarting, skip the install cell above and continue from here
import numpy as np
print(f"NumPy version: {np.__version__}")

import nltk
# 'punkt' is kept for older NLTK releases; NLTK 3.8.2 and later (including the
# pinned 3.9.1) load the sentence/word tokenizer models from 'punkt_tab', so
# sent_tokenize/word_tokenize raise LookupError unless it is downloaded too.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')
nltk.download('wordnet')


NumPy version: 1.26.4


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger_eng to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger_eng is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

## 2. Utility Functions and Text Processing

Helper functions for text preprocessing, keyword expansion, and prompt generation.


In [None]:
# Helper functions for text processing and keyword expansion
import re, json, collections
import numpy as np
from tqdm import tqdm
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.tag import pos_tag
from nltk.stem import WordNetLemmatizer
import unicodedata

# Try importing pattern library for verb forms, fallback if it doesn't work
# pattern sometimes has issues with Python 3.12, so we need a backup
try:
    from pattern.en import lexeme
    # Quick test to make sure it actually works
    _test_word = lexeme('test')
    _PATTERN_AVAILABLE = True
except (ImportError, RuntimeError, StopIteration):
    _PATTERN_AVAILABLE = False
    # Fallback function for verb forms - not as good as pattern but works
    def lexeme(word):
        # Basic fallback that returns common verb forms
        common_forms = {
            'be': ['am', 'is', 'are', 'was', 'were', 'been', 'being'],
            'have': ['has', 'had', 'having'],
            'do': ['does', 'did', 'done', 'doing'],
            'go': ['goes', 'went', 'gone', 'going'],
            'get': ['gets', 'got', 'gotten', 'getting'],
            'make': ['makes', 'made', 'making'],
            'take': ['takes', 'took', 'taken', 'taking'],
            'come': ['comes', 'came', 'coming'],
            'see': ['sees', 'saw', 'seen', 'seeing'],
            'know': ['knows', 'knew', 'known', 'knowing'],
        }
        word_lower = word.lower()
        if word_lower in common_forms:
            return [word] + common_forms[word_lower]
        # Otherwise just return basic inflections
        return [word, word + 's', word + 'ed', word + 'ing', word + 'd']

# Module-level WordNet lemmatizer instance.
lemmatizer = WordNetLemmatizer()

# Digit strings ('1'..'20') and ordinal abbreviations ('1st'..'9th') mapped to
# their spelled-out English words. Keys are strings, not ints.
number_map = {
    '1': 'one','2':'two','3':'three','4':'four','5':'five','6':'six','7':'seven','8':'eight','9':'nine','10':'ten',
    '11':'eleven','12':'twelve','13':'thirteen','14':'fourteen','15':'fifteen','16':'sixteen','17':'seventeen',
    '18':'eighteen','19':'nineteen','20':'twenty','1st':'first','2nd':'second','3rd':'third','4th':'fourth',
    '5th':'fifth','6th':'sixth','7th':'seventh','8th':'eighth','9th':'ninth',
}
# Reverse lookup: spelled-out word -> digit string / ordinal abbreviation.
number_map_b = {v: k for k, v in number_map.items()}

# Full English month names -> month number (1-12).
month_to_number = {
    "january":1,"february":2,"march":3,"april":4,"may":5,"june":6,"july":7,"august":8,"september":9,"october":10,"november":11,"december":12
}
# Three-letter month abbreviations -> month number (1-12).
short_month_to_number = {
    "jan":1,"feb":2,"mar":3,"apr":4,"may":5,"jun":6,"jul":7,"aug":8,"sep":9,"oct":10,"nov":11,"dec":12
}

def find_month(w: str):
    """Return the number (1-12) of the first month named in ``w``, else None.

    The text is lower-cased and split into alphabetic words; a word counts as
    a month only when it equals a full month name, a three-letter abbreviation
    or 'sept'. Matching whole words (rather than substrings) keeps words such
    as 'decade', 'summary' or 'doctor' from being read as Dec, Mar or Oct.

    Args:
        w: Free text, typically the date part of a question.

    Returns:
        The month number of the first matching word in reading order, or
        None when the text is empty or names no month.
    """
    if not w:
        return None
    for token in re.findall(r'[a-z]+', w.lower()):
        if token in month_to_number:
            return month_to_number[token]
        if token in short_month_to_number:
            return short_month_to_number[token]
        if token == 'sept':  # common four-letter abbreviation of September
            return short_month_to_number['sep']
    return None

def remove_implicit_condition(no_time_question: str):
    """Strip an implicit ordering word from a question and report which one.

    The cue phrases (' latest', ' last', ' first', ' earliest', ' most recent',
    ' recent') are tried in that order; the first one present is removed
    everywhere it occurs and mapped to 'last' or 'first'. A cue only counts
    when it ends at a word boundary, so ' lasted' or ' recently' are left
    untouched instead of being cut down to 'ed' / 'ly'.

    Args:
        no_time_question: Question text with the explicit time expression
            already removed.

    Returns:
        (question, implicit_condition): the stripped question, terminated with
        '?' unless it already ends in '.' or '?', and 'first', 'last' or None.
        An empty or whitespace-only input yields ('', None) rather than raising.
    """
    mapping_type = {
        ' latest':'last',' last':'last',' first':'first',' earliest':'first',
        ' most recent':'last',' recent':'last',
    }
    implicit_condition = None
    for key, condition in mapping_type.items():
        cue = re.escape(key) + r'\b'
        if re.search(cue, no_time_question):
            no_time_question = re.sub(cue, '', no_time_question)
            implicit_condition = condition
            break
    no_time_question = no_time_question.strip()
    # Guard the index: an empty string has no last character to inspect.
    if no_time_question and no_time_question[-1] not in '.?':
        no_time_question += '?'
    return no_time_question, implicit_condition

def get_wordnet_pos(treebank_tag):
    """Map a Penn Treebank POS tag to a one-letter WordNet POS code.

    Tags starting with J, V, N or R give 'a', 'v', 'n' or 'r' respectively;
    anything else (including an empty tag) is treated as a noun ('n').
    """
    first_letter_to_pos = {'J': 'a', 'V': 'v', 'N': 'n', 'R': 'r'}
    return first_letter_to_pos.get(treebank_tag[:1], 'n')

def replace_dates(text: str):
    """Expand abbreviated year ranges such as '1998-99' into '1998 1999'.

    A match is a four-digit year, a hyphen or en dash, and a two-digit suffix
    that is read as sharing the first year's century. Two kinds of look-alikes
    are left unchanged instead of being rewritten:

    * a match followed by another '-<digit>' (e.g. the ISO date '2010-12-25'),
    * a match whose end year would precede its start ('2010-05', '1999-00');
      expanding these produces an empty range and would silently delete the
      year from the text.

    Args:
        text: Arbitrary text possibly containing abbreviated year ranges.

    Returns:
        The text with each recognised range replaced by its space-separated
        years (both endpoints included).
    """
    pattern = r'(\b\d{4})[–-](\d{2}\b)(?![–-]\d)'

    def repl(m):
        start = m.group(1)
        end = start[:2] + m.group(2)
        if int(end) < int(start):
            # Not a forward range within one century; keep the original text.
            return m.group(0)
        return ' '.join(str(i) for i in range(int(start), int(end) + 1))

    return re.sub(pattern, repl, text)

def expand_year_range(text: str):
    """Expand full year ranges such as '1990-1993' into '1990 1991 1992 1993'.

    Both endpoints must be exactly four digits (not part of a longer digit
    run) joined by a hyphen or en dash. A range whose end precedes its start
    is left unchanged; expanding it would produce an empty string and drop
    both years from the text.

    Args:
        text: Arbitrary text possibly containing year ranges.

    Returns:
        The text with each forward range replaced by its space-separated
        years (both endpoints included).
    """
    def repl(m):
        s, e = int(m.group(1)), int(m.group(2))
        if e < s:
            return m.group(0)
        return ' '.join(str(y) for y in range(s, e + 1))

    return re.sub(r'(?<!\d)(\d{4})[–-](\d{4})(?!\d)', repl, text)

def year_identifier(timestamp: str):
    """Extract the distinct years mentioned in a time expression.

    Abbreviated ranges are expanded with replace_dates and full ranges with
    expand_year_range (in that order) before four-digit years, optionally
    followed by 's', are collected.

    Returns:
        The distinct years as ints in ascending order, or None if there are none.
    """
    expanded = expand_year_range(replace_dates(timestamp))
    found = re.findall(r'\b(\d{4})(?:s)?\b', expanded)
    return sorted({int(year) for year in found}) if found else None

def expand_keywords(keyword_list, normalized_question, verbose=False):
    """Classify each keyword and expand it into a list of surface variants.

    Duplicated keywords are processed once (first occurrence wins) and empty
    or whitespace-only keywords are skipped. Each remaining keyword gets one
    of these types:

    * 'special'     - first character is upper-case.
    * 'numeric'     - found in number_map / number_map_b; the other spelling
                      (digit <-> word) is added as a variant.
    * 'superlative' / 'adjective' - the keyword occurs in the question and the
                      tag of its last word starts with 'J'.
    * 'general'     - everything else; when the keyword occurs in the question
                      its last word is also expanded with lexeme().

    Args:
        keyword_list: Iterable of keyword strings.
        normalized_question: Question text the keywords came from; it is
            tokenized and POS-tagged to classify non-special keywords.
        verbose: When True, print the resulting lists.

    Returns:
        (expanded_keyword_list, keyword_type_list): two parallel lists with one
        entry per distinct, non-empty keyword - the variants (original keyword
        first) and the keyword's type.
    """
    tagged = pos_tag(word_tokenize(normalized_question))
    q_words = [w.lower() for w, _ in tagged]
    q_tags = [t for _, t in tagged]

    expanded_keyword_list, keyword_type_list = [], []
    # dict.fromkeys de-duplicates while preserving first-seen order.
    for kw in dict.fromkeys(keyword_list):
        kw_list = kw.split()
        if not kw_list:
            # Nothing to classify or expand; indexing kw_list[-1] below would raise.
            continue
        new_kw = []
        if kw[0].isupper():
            kw_type = 'special'
        elif kw.lower() in number_map:
            kw_type = 'numeric'
            new_kw.append(number_map[kw.lower()])
        elif kw.lower() in number_map_b:
            kw_type = 'numeric'
            new_kw.append(number_map_b[kw.lower()])
        else:
            kw_type = 'general'
            n_words = len(kw_list)
            target = [w.lower() for w in kw_list]
            # First position where the whole keyword appears in the question.
            # The range is bounded so no out-of-range index can occur.
            index = next(
                (i for i in range(len(q_words) - n_words + 1)
                 if q_words[i:i + n_words] == target),
                None,
            )
            if index is not None:
                last_word = kw_list[-1]
                last_tag = q_tags[index + n_words - 1]
                if last_tag.startswith('J'):
                    is_superlative = last_word.endswith('est') or last_word.lower() == 'most'
                    kw_type = 'superlative' if is_superlative else 'adjective'
                else:
                    # Swap only the final word for each inflection. str.replace
                    # would also rewrite earlier words that contain it
                    # ('part art' -> 'pxx xx').
                    pos = kw.rfind(last_word)
                    head, tail = kw[:pos], kw[pos + len(last_word):]
                    new_kw += [head + form + tail for form in lexeme(last_word)]
        # Ordered de-duplication keeps the original keyword first and makes the
        # output reproducible across runs.
        tmp = list(dict.fromkeys([kw] + new_kw))
        if kw_type == 'special' and ' and ' in kw:
            tmp += [kw.replace(' and ', '&'), kw.replace(' and ', ' & '), kw.replace(' and ', " N' ")]
        if '-' in kw:
            tmp.append(kw.replace('-', ' '))
        expanded_keyword_list.append(tmp)
        keyword_type_list.append(kw_type)
    if verbose:
        print('after:', expanded_keyword_list, keyword_type_list)
    return expanded_keyword_list, keyword_type_list

def count_keyword_scores(text, expanded_keyword_list, keyword_type_list):
    """Score a text by the weighted number of keyword groups it mentions.

    A group counts once if any of its variants occurs in the text as a
    case-insensitive substring. Empty variants are ignored: the empty string
    is a substring of every text and would otherwise make the group match
    everything.

    Args:
        text: Passage text to score.
        expanded_keyword_list: One list of variants per keyword.
        keyword_type_list: Keyword types parallel to ``expanded_keyword_list``;
            each must be a key of the weights table below.

    Returns:
        The sum of the type weights of all matched groups, as a float.
    """
    text = text.lower()
    weights = {'special':1,'superlative':0.7,'general':0.4,'numeric':0.5,'adjective':0.4}
    score = 0.0
    for keywords, kw_type in zip(expanded_keyword_list, keyword_type_list):
        if any(kw and kw.lower() in text for kw in keywords):
            score += weights[kw_type]
    return score

def check_no_knowledge(s: str):
    """Return True when ``s`` looks like a "no answer" response.

    Non-string and empty inputs count as no knowledge. Otherwise the
    lower-cased text is flagged when it contains the standalone word 'not'
    or any of a fixed set of refusal phrases as a substring.
    """
    if not isinstance(s, str) or s == '':
        return True
    lowered = s.lower()
    if 'not' in lowered.split():
        return True
    refusal_markers = ('unknown', 'none', 'unable', 'no information', 'no answer', "don't have")
    for marker in refusal_markers:
        if marker in lowered:
            return True
    return False

def force_string(item):
    """Coerce ``item`` to a string.

    Strings are returned unchanged. A list is represented by its first element
    (or '' when empty). Whatever remains is passed through str(); if that
    raises, '' is returned.
    """
    if isinstance(item, str):
        return item
    candidate = item
    if isinstance(candidate, list):
        candidate = candidate[0] if len(candidate) > 0 else ''
    try:
        return str(candidate)
    except Exception:
        return ''

# Prompt templates for LLM generation

def LLMGenerations(document: str, question: str, short: bool = False) -> str:
    """Build the few-shot summarization prompt for one document/question pair.

    The prompt consists of a fixed instruction, five worked examples, and a
    final section holding ``document`` and ``question``. It ends right after
    an opening <Summarization> tag so the model's continuation is the summary.

    Args:
        document: Passage text inserted between the final <Document> tags.
        question: Question text inserted between the final <Question> tags.
            A '?' is appended to it, so pass the question without a trailing
            question mark to avoid doubling it.
        short: Not used by this function.

    Returns:
        The complete prompt string.
    """
    # Fixed part: instruction plus few-shot examples. Note that it contains
    # closing </Summarization> tags of its own.
    prompt = f"""You are a summarizer summarizing a retrieved document about a user question. Keep the key dates in the summarization. Write "None" if the document has no relevant content about the question.

There are some examples for you to refer to:
<Document>
David Beckham | As the summer 2003 transfer window approached, Manchester United appeared keen to sell Beckham to Barcelona and the two clubs even announced that they reached a deal for Beckham's transfer, but instead he joined reigning Spanish champions Real Madrid for €37 million on a four-year contract. Beckham made his Galaxy debut, coming on for Alan Gordon in the 78th minute of a 0–1 friendly loss to Chelsea as part of the World Series of Soccer on 21 July 2007.
</Document>
<Question>
David Beckham played for which team?
</Question>
<Summarization>
David Beckham played for Real Madrid from 2003 to 2007 and for LA Galaxy from July 21, 2007.
</Summarization>

<Document>
Houston Rockets | The Houston Rockets have won the NBA championship twice in their history. Their first win came in 1994, when they defeated the New York Knicks in a seven-game series. The following year, in 1995, they claimed their second title by sweeping the Orlando Magic. Despite several playoff appearances in the 2000s and 2010s, the Rockets have not reached the NBA Finals since their last championship victory in 1995.
</Document>
<Question>
When did the Houston Rockets win the NBA championship?
</Question>
<Summarization>
The Houston Rockets won the NBA championship twice in 1994 and 1995.
</Summarization>

<Document>
India | India has had several distinguished presidents throughout its history. In 21 July 1977, Neelam Sanjiva Reddy was elected as the sixth President of India. Years later, in 1997, K. R. Narayanan became the first Dalit to hold the office, serving until 2002. In 2022, Droupadi Murmu was elected as the 15th President, making her the first tribal woman to serve as the country's president.
</Document>
<Question>
Who serve as President of India?
</Question>
<Summarization>
Neelam Sanjiva Reddy became the sixth President in 21 July 1977. K. R. Narayanan, the first Dalit president, served from 1997 to 2002. In 2022, Droupadi Murmu became the 15th President and the first tribal woman to hold the position.
</Summarization>

<Document>
Doris Schröder-Köpf | Köpf and partner Sven Kuntze moved to New York City in 1990, where they had a daughter named Klara in the following year. Soon after the birth the pair separated and Köpf moved back to Bavaria with the child. In October 1997, Köpf married Gerhard Schröder, then Minister-President of Lower Saxony.
</Document>
<Question>
Who was the spouse of Doris Schröder?
</Question>
<Summarization>
Doris Schröder-Köpf married Gerhard Schröder, then Minister-President of Lower Saxony, in October 1997.
</Summarization>

<Document>
The Lost World: Jurassic Park | The Lost World: Jurassic Park is a 1997 American science fiction action film. In Thailand, The Lost World became the country's highest-grossing film of all time. It ultimately grossed $229.1 million in the U.S. and $389.5 million internationally, for a total of $618.6 million worldwide. The film sold an estimated 49,910,000 tickets in North America.
</Document>
<Question>
What was the worldwide box office of Jurassic movie?
</Question>
<Summarization>
The worldwide box office for The Lost World: Jurassic Park (1997) was $618.6 million.
</Summarization>
"""
    # Variable part: the actual document and question, left open at
    # <Summarization> for the model to complete.
    ask = f"""
Now your document and question are
<Document>
{document}
</Document>
<Question>
{question}?
</Question>
<Summarization>
"""
    return prompt + ask


def c_prompt(query: str, texts: str):
    """Build the answer-generation prompt for a question and its context.

    The prompt states the task, then wraps ``texts`` in <Context> tags and
    ``query`` in <Question> tags, and ends on an opening <Answer> tag followed
    by a newline.
    """
    instructions = "\n".join((
        "As an assistant, your task is to answer the question based on the given knowledge. "
        "Answer the given question, you can refer to the document provided. "
        "Your answer should be after <Answer>.",
        "The given knowledge will be after the <Context> tag. "
        "You can refer to the knowledge to answer the question.",
        "If the context knowledge does not contain the answer, answer the question directly.",
    ))
    sections = (
        instructions,
        "",
        "Now your question and context knowledge are",
        "<Context>",
        format(texts),
        "</Context>",
        "<Question>",
        format(query),
        "</Question>",
        "<Answer>",
        "",  # yields the trailing newline after <Answer>
    )
    return "\n".join(sections)


## 3. Model Loading

Load the language model for text generation.


In [None]:
# Load the language model for text generation
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model id of the generator LLM.
LLM_ID = "microsoft/Phi-3.5-mini-instruct"

# Module-level tokenizer/model pair read by generate_text() below.
tokenizer = AutoTokenizer.from_pretrained(LLM_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_ID,
    # Half precision when a CUDA device is available, full precision otherwise.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)

def generate_text(prompt: str, max_new_tokens: int = 200, temperature: float = 0.2, top_p: float = 0.95):
    """Generate a completion for ``prompt`` with the module-level model.

    Only the newly generated tokens are decoded. The completion is then cut at
    the first closing tag from a fixed list. Searching for stop tags in
    prompt + completion would be wrong: the few-shot summarization prompt
    itself contains </Summarization>, so the cut would land inside the prompt
    and a prompt fragment would be returned instead of the model's output.

    Args:
        prompt: Full prompt text.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature. A value <= 0 (or None) switches to
            greedy decoding, since sampling requires a positive temperature.
        top_p: Nucleus sampling threshold, used only when sampling.

    Returns:
        The generated text up to (not including) the first stop tag, stripped
        of surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        out_ids = model.generate(**inputs, **gen_kwargs)

    # For a causal LM the returned sequence starts with the prompt tokens;
    # slice them off by length instead of string-matching the decoded prompt,
    # which is not guaranteed to round-trip through the tokenizer.
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

    # Stop at whichever of these tags appears first in the completion.
    cut = len(text)
    for stopper in ['</Keywords>', '</Summarization>', '</Answer>', '</Info>', '</Sentences>', '</Sentence>', '</Response>']:
        pos = text.find(stopper)
        if pos != -1:
            cut = min(cut, pos)
    return text[:cut].strip()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/195 [00:00<?, ?B/s]

### 3.1 HuggingFace Authentication

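Authenticate with Hugging Face to access gated or private models. Authentication is optional for public models; if a token is required, run this cell before the model-loading cell in Section 3.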


In [None]:
from huggingface_hub import login
# Interactive step: displays a login widget and waits for a token to be entered.
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## 4. Dataset Loading

Load the TempRAGEval dataset for evaluation.

In [None]:
# Load TempRAGEval
from datasets import load_dataset

# Fetch the dataset from the Hugging Face Hub and keep its "test" split.
dataset = load_dataset("siyue/TempRAGEval")
test_data = dataset["test"]
# Sanity check: number of rows and the raw fields of the first one.
print(len(test_data))
print(test_data[0])


Downloading data:   0%|          | 0.00/470k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1244 [00:00<?, ? examples/s]

1244
{'id': 's_1', 'original_dataset': 'situatedqa', 'original_id': '-4.73E+18', 'original_question': 'when was the last time the dodgers played yankees in the world series as of 1991', 'original_answer': '1981', 'key_time': '1981', 'question': 'When was the time the Dodgers played the Yankees in the World Series?\n', 'exact_time': 0, 'time_relation': None, 'answer': '1981', 'gold_evidence_1': 'In 1998, the Yankees celebrated the 20th anniversary of the 1977, 1978 and 1981 World Series that they played against the Los Angeles Dodgers, and invited some members of those Dodger teams.', 'gold_evidence_2': 'although the Yankees subsequently met and beat the now-San Francisco Giants in 1962, and played the now-Los Angeles Dodgers four times, losing to them in a four-game sweep in 1963, beating them back-to-back in 1977 and 1978 and losing to them in 1981'}


## 5. Data Preprocessing

Parse time relations and normalize questions from the dataset.


In [None]:
# Preprocess examples (time relation parsing + normalization)

def preprocess_examples(examples):
    """Convert raw dataset rows into the records used by the rest of the notebook.

    For each row the question is split at the LAST occurrence of its time
    relation (e.g. 'as of', 'between') into the part before it and the date
    expression after it. Years and months are parsed from the date expression,
    the relation is classified, and implicit ordering words are removed from
    the remaining question.

    Differences from a plain ``str.split`` on the relation:

    * the relation is located case-insensitively (it is lower-cased here while
      the question keeps its original casing);
    * if the relation cannot be found, or nothing precedes it, the full
      question is kept as the question text instead of an empty string;
    * 'between' ranges are split once on a whole-word 'and' / 'to' / 'until',
      so 'to' inside 'October' is not a delimiter and ``months`` always has
      exactly two entries.

    Args:
        examples: Iterable of mapping-like rows with at least 'id' and
            'question'; the other fields are read with .get().

    Returns:
        A list of dicts with keys id, question, normalized_question,
        implicit_condition, time_relation, time_relation_type, years, months,
        answers, gold_evidences, source and exact.
    """
    range_delimiter = re.compile(r'\b(?:and|to|until)\b', flags=re.IGNORECASE)

    def month_or_zero(fragment):
        """Month number found by find_month in ``fragment``, or 0 if there is none."""
        month = find_month(fragment)
        return month if month else 0

    processed = []
    for ex in tqdm(examples, desc="Preprocessing time info"):
        question = ex["question"].strip()
        time_relation = (ex.get("time_relation") or "").strip().lower()

        time_relation_type = ""
        years, months = [], []
        no_time_question = question

        if time_relation != "":
            occurrences = list(re.finditer(re.escape(time_relation), question, flags=re.IGNORECASE))
            if occurrences:
                last = occurrences[-1]
                no_time_question = question[:last.start()]
                date = question[last.end():]
            else:
                # Relation text not present in the question: parse dates from
                # the whole question and leave the question text untouched.
                date = question
            if not no_time_question.strip():
                no_time_question = question

            yrs = year_identifier(date) or []
            if len(yrs) > 2:
                # Keep only the endpoints of an expanded range.
                yrs = [min(yrs), max(yrs)]
            years = yrs

            if len(years) > 1:
                time_relation_type = "between"
            elif time_relation in ["before", "as of", "by", "until"]:
                time_relation_type = "before"
            elif time_relation in ["from", "since", "after"]:
                time_relation_type = "after"
            else:
                time_relation_type = "other"

            if time_relation_type == "between":
                halves = range_delimiter.split(date, maxsplit=1)
                if len(halves) == 2:
                    months = [month_or_zero(half.strip()) for half in halves]
                else:
                    months = [0, 0]
            else:
                months = [month_or_zero(date.strip())]

        normalized_question, implicit_condition = remove_implicit_condition(no_time_question)
        if normalized_question and normalized_question[-1] in ".?!":
            normalized_question = normalized_question[:-1]

        processed.append({
            "id": ex["id"],
            "question": question,
            "normalized_question": normalized_question,
            "implicit_condition": implicit_condition,
            "time_relation": time_relation,
            "time_relation_type": time_relation_type,
            "years": years,
            "months": months,
            # 'answer' holds alternatives separated by ' | '; fall back to an
            # 'answers' list when it is missing or empty.
            "answers": (
                [a.strip() for a in ex.get("answer", "").split(" | ") if a.strip()]
                if ex.get("answer")
                else (ex.get("answers") or [])
            ),
            "gold_evidences": [x for x in [ex.get("gold_evidence_1",""), ex.get("gold_evidence_2","")] if x],
            "source": ex.get("original_dataset", ex.get("source","")),
            "exact": ex.get("exact_time", ex.get("exact", "")),
        })
    return processed

examples = preprocess_examples(test_data)
# Show the first processed record as a sanity check (nothing to show if empty).
if examples:
    print(examples[0])



Preprocessing time info:   0%|          | 0/1244 [00:00<?, ?it/s][A
Preprocessing time info: 100%|██████████| 1244/1244 [00:00<00:00, 7875.32it/s]

{'id': 's_1', 'question': 'When was the time the Dodgers played the Yankees in the World Series?', 'normalized_question': 'When was the time the Dodgers played the Yankees in the World Series', 'implicit_condition': None, 'time_relation': '', 'time_relation_type': '', 'years': [], 'months': [], 'answers': ['1981'], 'gold_evidences': ['In 1998, the Yankees celebrated the 20th anniversary of the 1977, 1978 and 1981 World Series that they played against the Los Angeles Dodgers, and invited some members of those Dodger teams.', 'although the Yankees subsequently met and beat the now-San Francisco Giants in 1962, and played the now-Los Angeles Dodgers four times, losing to them in a four-game sweep in 1963, beating them back-to-back in 1977 and 1978 and losing to them in 1981'], 'source': 'situatedqa', 'exact': 0}





## 6. Retrieval Setup

Initialize Contriever model and Wikipedia API for document retrieval.


In [None]:
# Contriever retrieval on Wikipedia + FAISS index
import wikipedia as wk
import wikipediaapi
from transformers import AutoModel, AutoTokenizer
import torch, faiss

# Hugging Face model id of the dense retriever encoder.
CONTRIEVER_ID = "facebook/contriever-msmarco"
ctr_tokenizer = AutoTokenizer.from_pretrained(CONTRIEVER_ID)
# Encoder is placed on the GPU when one is available, otherwise on the CPU.
ctr_model = AutoModel.from_pretrained(CONTRIEVER_ID).to("cuda" if torch.cuda.is_available() else "cpu")

# Wikipedia client used by build_wikipedia_corpus() to fetch page text.
# NOTE(review): the user_agent URL is a placeholder ('yourusername') - replace
# it with a real contact/project URL before issuing requests.
wiki_api = wikipediaapi.Wikipedia(
    user_agent='MRAG-Baseline-Colab/1.0 (https://github.com/yourusername)',
    language='en'
)

# Answer verification utilities (from contriever evaluation)
class SimpleTokenizer:
    """Regex tokenizer used for answer-string matching.

    Text is split into runs of word characters (letters, digits, underscore,
    plus combining diacritical marks so NFD-normalized accents stay attached
    to their letter) and single punctuation/symbol characters; whitespace is
    dropped. Separating punctuation from words matters for matching: with
    whitespace-only splitting the answer '1981' never matches the text
    'in 1981.' or '(1981)' because the punctuation is glued to the token.
    """

    def __init__(self):
        self.token_pattern = re.compile(r'[\w\u0300-\u036f]+|[^\w\s]')

    def tokenize(self, text, uncased=False):
        """Return the tokens of ``text``; lower-cased when ``uncased`` is True."""
        tokens = self.token_pattern.findall(text)
        if uncased:
            return [token.lower() for token in tokens]
        return tokens

def _normalize(text):
    """Normalize unicode text."""
    return unicodedata.normalize('NFD', text)

def has_answer(answers, text, tokenizer):
    """Check if a document contains an answer string.

    Both the text and each answer are NFD-normalized and tokenized uncased; an
    answer matches when its token sequence appears contiguously in the text's
    tokens. Answers that tokenize to nothing (empty or whitespace-only) are
    skipped - an empty token list would otherwise equal the empty slice at any
    position and "match" every document.

    Args:
        answers: Iterable of answer strings.
        text: Document text to search.
        tokenizer: Object exposing ``tokenize(text, uncased=...)`` -> list of tokens.

    Returns:
        True if any non-empty answer occurs in the text, else False.
    """
    text_tokens = tokenizer.tokenize(_normalize(text), uncased=True)

    for answer in answers:
        answer_tokens = tokenizer.tokenize(_normalize(answer), uncased=True)
        if not answer_tokens:
            continue
        span = len(answer_tokens)
        for start in range(len(text_tokens) - span + 1):
            if text_tokens[start:start + span] == answer_tokens:
                return True
    return False

# Shared tokenizer instance passed to has_answer() by the corpus filter below.
answer_tokenizer = SimpleTokenizer()


def embed_texts(texts, batch_size=64):  # Increased batch size for Colab Pro
    """Encode texts with the Contriever encoder into L2-normalized vectors.

    Each text is truncated to 512 tokens and encoded; the token states are
    mean-pooled over non-padding positions and then L2-normalized.

    Args:
        texts: Sequence of strings to encode.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        A float32 NumPy array with one row per text. For empty input an array
        of shape (0, hidden_size) is returned instead of failing in torch.cat.
    """
    texts = list(texts)
    if not texts:
        return np.zeros((0, ctr_model.config.hidden_size), dtype="float32")

    device = ctr_model.device
    all_vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        tok = ctr_tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512).to(device)
        with torch.no_grad():
            # last_hidden_state is always returned; asking for all hidden
            # states only costs memory.
            out = ctr_model(**tok).last_hidden_state
            attn = tok["attention_mask"]
            # Zero out padding positions, then average over the real tokens.
            masked = out.masked_fill(~attn[..., None].bool(), 0.0)
            emb = masked.sum(dim=1) / attn.sum(dim=1)[..., None]
            emb = torch.nn.functional.normalize(emb, dim=-1)
        all_vecs.append(emb.cpu())
    return torch.cat(all_vecs, dim=0).numpy().astype("float32")


def contriever_retrieve(normalized_question, ctxs, topk=50):
    """Rank candidate passages against a question with Contriever + FAISS.

    Passages are embedded as "title | text", indexed in a flat inner-product
    index, and searched with the embedded question.

    Args:
        normalized_question: Query text.
        ctxs: Sequence of dicts with at least 'title' and 'text'.
        topk: Maximum number of passages to return.

    Returns:
        Up to ``topk`` shallow copies of the input dicts, best first, each
        with an added float 'score'. Never more entries than ``ctxs`` has:
        when fewer than k neighbours exist FAISS pads the result with label
        -1, which as a Python index would silently repeat the last passage.
    """
    if len(ctxs) == 0 or topk <= 0:
        return []
    passages = [c["title"] + " | " + c["text"] for c in ctxs]
    p_emb = embed_texts(passages)
    dim = p_emb.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(p_emb)
    q_emb = embed_texts([normalized_question])
    k = min(topk, len(ctxs))
    scores, idxs = index.search(q_emb, k)
    out = []
    for i, s in zip(idxs[0], scores[0]):
        if i < 0:  # FAISS padding for "no result"
            continue
        c = dict(ctxs[int(i)])
        c["score"] = float(s)
        out.append(c)
    return out

# Optimized corpus building for Colab Pro (larger corpus)
def build_wikipedia_corpus(seed_queries=None, n_titles=500, per_title_sections=15, chunk_sents=8):
    """
    Build a larger corpus from Wikipedia for filtering and evaluation.

    Titles are collected from Wikipedia search results for each seed query (in
    query order, then result order) until ``n_titles`` distinct titles are
    available. Each page is split into chunks of ``chunk_sents`` sentences.
    Titles are de-duplicated in first-seen order, so the same search results
    always produce the same selection and passage ids (truncating an unordered
    set does not, because string hashing varies between interpreter runs).

    Search and page errors are reported with a warning and skipped.

    Args:
        seed_queries: List of queries to seed Wikipedia search (if None, uses generic topics)
        n_titles: Number of Wikipedia titles to retrieve (optimized for Colab Pro)
        per_title_sections: Number of chunks per title
        chunk_sents: Sentences per chunk

    Returns:
        List of corpus passages with 'id', 'title', 'text'
    """
    if seed_queries is None:
        # Generic seed queries covering diverse topics
        seed_queries = [
            "history", "politics", "science", "technology", "sports", "entertainment",
            "geography", "culture", "economics", "medicine", "biography", "literature",
            "art", "music", "film", "television", "world war", "country", "president",
            "prime minister", "company", "university", "city", "event", "award"
        ]

    wk.set_lang("en")
    ordered_titles = {}  # dict used as an insertion-ordered set
    corpus_passages = []
    passage_id = 0

    print(f"Building corpus with target of {n_titles} titles...")

    # Collect titles from seed queries
    for query in tqdm(seed_queries, desc="Collecting titles"):
        try:
            titles = wk.search(query, results=50)  # Get more candidates per query
            ordered_titles.update(dict.fromkeys(titles))
            if len(ordered_titles) >= n_titles:
                break
        except Exception as e:
            print(f"Warning: Error searching for '{query}': {e}")
            continue

    # Limit to n_titles
    all_titles = list(ordered_titles)[:n_titles]
    print(f"Processing {len(all_titles)} Wikipedia pages...")

    # Process each page
    for title in tqdm(all_titles, desc="Processing pages"):
        try:
            page = wiki_api.page(title)
            if not page.exists():
                continue

            text = page.text
            if len(text) < 100:  # Skip very short pages
                continue

            sents = sent_tokenize(text)
            # Only the first per_title_sections * chunk_sents sentences of a page are used.
            max_sents = min(len(sents), per_title_sections * chunk_sents)
            for i in range(0, max_sents, chunk_sents):
                chunk_text = " ".join(sents[i:i+chunk_sents])
                if len(chunk_text) < 40:
                    continue
                corpus_passages.append({
                    "id": passage_id,
                    "title": title,
                    "text": chunk_text
                })
                passage_id += 1
        except Exception as e:
            print(f"Warning: Error processing '{title}': {e}")
            continue

    print(f"Built corpus with {len(corpus_passages)} passages")
    return corpus_passages

def filter_questions_with_answers_in_corpus(examples, corpus_passages, verbose=True):
    """
    Filter examples to only those where answers exist in corpus.

    An example is kept when at least one of its answers is found by
    has_answer() in the "title text" string of at least one passage. Examples
    with no answers are dropped. Every example is checked against passages in
    order until the first hit, so the cost grows with
    len(examples) * len(corpus_passages).

    Args:
        examples: List of test examples with 'answers' field
        corpus_passages: List of dicts with 'title' and 'text' fields
        verbose: When True, print progress and a coverage summary

    Returns:
        Filtered list of examples
    """
    filtered_examples = []
    total = len(examples)

    if verbose:
        print(f"Filtering {total} questions against corpus of {len(corpus_passages)} passages...")

    for ex in tqdm(examples, desc="Filtering questions"):
        answers = ex.get('answers', [])
        if not answers:
            continue

        # Check if any answer exists in any corpus passage (stops at the first hit)
        found = any(
            has_answer(answers, passage.get('title', '') + ' ' + passage.get('text', ''), answer_tokenizer)
            for passage in corpus_passages
        )
        if found:
            filtered_examples.append(ex)

    if verbose:
        # Avoid dividing by zero when there are no examples at all.
        coverage = len(filtered_examples) / total * 100 if total else 0.0
        print(f"\nFiltered results:")
        print(f"  Original questions: {total}")
        print(f"  Questions with answers in corpus: {len(filtered_examples)}")
        print(f"  Coverage: {coverage:.1f}%")

    return filtered_examples


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

## 7. Corpus Loading

Load the pre-built corpus from JSONL files.


In [None]:
# Upload corpus files (text-list-100-sec.jsonl and infobox.jsonl)
# Colab-only: opens a file picker; `uploaded` maps each file name to its bytes.
from google.colab import files
uploaded = files.upload()

import os, io
# Write every uploaded file under atlas_slice_with_neg/, keeping its name.
os.makedirs("atlas_slice_with_neg", exist_ok=True)
for name, data in uploaded.items():
    with open(f"atlas_slice_with_neg/{name}", "wb") as f:
        f.write(data)

# Corpus files read by the loading cell below. The paths are fixed, so the
# uploaded files must carry exactly these names.
JSONL_FILES = [
    "atlas_slice_with_neg/text-list-100-sec.jsonl",
    "atlas_slice_with_neg/infobox.jsonl",
]

In [None]:
# Same list as defined at the end of the upload cell above.
# NOTE(review): presumably kept so the upload cell can be skipped when the
# files are already on disk - confirm, or remove one of the two definitions.
JSONL_FILES = [
    "atlas_slice_with_neg/text-list-100-sec.jsonl",
    "atlas_slice_with_neg/infobox.jsonl",
]

In [None]:
# Load corpus from JSONL files
import json, hashlib

def read_jsonl(path):
    """Lazily yield ``(line_number, parsed_object)`` for a UTF-8 JSONL file.

    Line numbers are 1-based and count every physical line; lines that are
    empty after stripping whitespace are skipped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        for number, raw in enumerate(handle, start=1):
            stripped = raw.strip()
            if stripped:
                yield number, json.loads(stripped)

def normalize_record(src, file_idx, lineno):
    """Map one raw JSONL object onto the notebook's passage schema.

    Field names vary between corpus files, so each output field is taken from
    the first alias in ``src`` that holds a truthy value.

    Args:
        src: Decoded JSON object (dict) for one passage.
        file_idx: Index of the source file; only used when an id must be generated.
        lineno: Line number within that file; only used when an id must be generated.

    Returns:
        Dict with keys ``id``, ``title``, ``text`` (all str), ``pageid`` (as found,
        possibly None), ``section_idx`` and ``chunk_idx`` (int, 0 when the source
        value is not a plain digit string).
    """
    def lookup(*aliases):
        # First truthy value among the aliases; when none is truthy, the value
        # found under the last alias (possibly None) is returned.
        found = None
        for alias in aliases:
            found = src.get(alias)
            if found:
                break
        return found

    def as_index(value):
        # Only plain digit strings/ints are accepted; anything else becomes 0.
        return int(value) if str(value).isdigit() else 0

    title = lookup("title", "page_title", "wiki_title") or ""
    text = lookup("text", "contents", "passage", "body") or ""
    pageid = lookup("pageid", "wikipedia_id", "page_id")
    section = lookup("section_idx", "section") or 0
    chunk = lookup("chunk_idx", "chunk_id") or 0
    pid = lookup("id", "_id")

    if not pid:
        # No usable id in the source: derive one from the record's position and
        # content so it is the same every time the same file is loaded.
        fingerprint = hashlib.md5(
            f"{file_idx}:{lineno}:{title}:{text[:256]}".encode("utf-8")
        ).hexdigest()
        page_part = "NA" if pageid is None else str(pageid)
        pid = f"{page_part}:{section}:{chunk}:{fingerprint[:10]}"

    return {
        "id": str(pid),
        "title": str(title),
        "text": str(text),
        "pageid": pageid,
        "section_idx": as_index(section),
        "chunk_idx": as_index(chunk),
    }

# Read every file in JSONL_FILES and collect the normalized passage records.
corpus_passages = []
for fi, p in enumerate(JSONL_FILES):
    for lineno, obj in read_jsonl(p):
        rec = normalize_record(obj, fi, lineno)
        # Drop passages whose text is empty or whitespace-only.
        if rec["text"].strip():
            corpus_passages.append(rec)

print(f"Loaded {len(corpus_passages)} passages from {len(JSONL_FILES)} files.")


Loaded 219940 passages from 2 files.


### 7.1 Filter Questions

Filter the dataset to include only questions for which at least one gold answer appears in some corpus passage.


In [None]:
# Keep only the questions for which filter_questions_with_answers_in_corpus finds a
# gold answer in the corpus; the test run and retrieval evaluation below use
# filtered_examples rather than the full examples list.
print("\nFiltering questions based on corpus...")
filtered_examples = filter_questions_with_answers_in_corpus(examples, corpus_passages, verbose=True)
print(f"\nUsing {len(filtered_examples)} filtered questions for evaluation (out of {len(examples)} total)")



Filtering questions based on corpus...
Filtering 1244 questions against corpus of 219940 passages...



Filtering questions:   0%|          | 0/1244 [00:00<?, ?it/s][A
Filtering questions:   1%|          | 7/1244 [00:00<00:20, 61.52it/s][A
Filtering questions:   1%|          | 14/1244 [00:07<12:44,  1.61it/s][A
Filtering questions:   2%|▏         | 22/1244 [00:08<07:10,  2.84it/s][A
Filtering questions:   2%|▏         | 25/1244 [00:09<08:04,  2.52it/s][A
Filtering questions:   2%|▏         | 27/1244 [00:10<08:34,  2.36it/s][A
Filtering questions:   2%|▏         | 29/1244 [00:10<07:07,  2.84it/s][A
Filtering questions:   2%|▏         | 31/1244 [00:11<05:48,  3.48it/s][A
Filtering questions:   2%|▏         | 31/1244 [00:21<05:48,  3.48it/s][A
Filtering questions:   4%|▎         | 46/1244 [00:24<14:36,  1.37it/s][A
Filtering questions:   4%|▍         | 47/1244 [00:38<30:36,  1.53s/it][A
Filtering questions:   4%|▍         | 48/1244 [00:52<50:17,  2.52s/it][A
Filtering questions:   4%|▍         | 49/1244 [01:05<1:13:18,  3.68s/it][A
Filtering questions:   4%|▍         | 50/1244


Filtered results:
  Original questions: 1244
  Questions with answers in corpus: 1199
  Coverage: 96.4%

Using 1199 filtered questions for evaluation (out of 1244 total)





## 8. FAISS Index Construction

Build a FAISS index over the corpus for fast retrieval.


In [None]:
# Build FAISS index for fast similarity search
import numpy as np, faiss

# Only the passage body is embedded; titles are not part of the indexed text.
corpus_texts = [p["text"] for p in corpus_passages]
# NOTE(review): presumably embed_texts returns a 2-D float32 numpy array with one row
# per text, which is what the faiss calls below expect — confirm against its definition.
corpus_vecs  = embed_texts(corpus_texts, batch_size=64)
# Normalize for cosine similarity (done in place: the return value is not used)
faiss.normalize_L2(corpus_vecs)
# Inner-product index sized to the embedding dimension; combined with the
# normalization above, search scores are cosine similarities.
index = faiss.IndexFlatIP(corpus_vecs.shape[1])
index.add(corpus_vecs)

def retrieve_from_corpus(question, topk=100):
    """Retrieve the top-k corpus passages for a question from the FAISS index.

    Args:
        question: Query string; embedded with ``embed_texts`` and L2-normalized
            so the inner-product search matches the normalized corpus vectors.
        topk: Maximum number of passages to return.

    Returns:
        List of passage dicts from ``corpus_passages`` in the order returned by
        ``index.search``. The list is shorter than ``topk`` when the index
        holds fewer than ``topk`` vectors.
    """
    # [0:1] keeps a leading batch axis of size 1 for the search call.
    qv = embed_texts([question])[0:1]
    faiss.normalize_L2(qv)
    _scores, I = index.search(qv, topk)
    # FAISS fills missing result slots with label -1. Without this filter,
    # corpus_passages[-1] (the last passage) would be returned as a bogus hit.
    return [corpus_passages[i] for i in I[0].tolist() if i >= 0]


## 9. QFS Summarization and Sentence Ranking

Implement Query-Focused Summarization (QFS) and sentence-level keyword scoring.


In [None]:
# Generate summaries and rank sentences by keyword scores

def summarize_and_rank_sentences(ex, latest_ctxs, QFS_topk=10, snt_with_title=True):
    """Summarize the top contexts with the LLM and rank all sentences by keyword score.

    Args:
        ex: Example dict; ``ex["normalized_question"]`` supplies the query.
        latest_ctxs: Retrieved passage dicts with ``id``, ``title`` and ``text``.
        QFS_topk: Number of leading contexts that get an LLM summary.
        snt_with_title: When True, each stored sentence is prefixed with the
            passage title.

    Returns:
        ``(sentence_tuples, summaries)``. ``sentence_tuples`` is a list of
        ``(passage_id, sentence, keyword_score)`` sorted by descending score.
        ``summaries`` holds one entry per summarized context, with None where
        the LLM output contained "None".
    """
    keyword_list = [w for w in word_tokenize(ex["normalized_question"]) if w.isalpha()]
    expanded_keyword_list, keyword_type_list = expand_keywords(keyword_list, ex["normalized_question"], verbose=False)

    summaries = []
    for ctx in latest_ctxs[:QFS_topk]:
        prompt = LLMGenerations(ctx["title"]+" | "+ctx["text"], ex["normalized_question"])
        summary = generate_text(prompt, max_new_tokens=200)
        # str() keeps this check from raising TypeError when generate_text returns
        # None or another non-string; summarize_and_rank_sentences_fast guards the
        # same way.
        if "None" in str(summary):
            summary = None
        summaries.append(summary)

    sentence_tuples = []
    for idx, ctx in enumerate(latest_ctxs):
        snts = sent_tokenize(ctx["text"])
        if snt_with_title:
            snts = [ctx["title"] + " " + s for s in snts]
        # Contexts beyond QFS_topk have no summary.
        summary = summaries[idx] if idx < len(summaries) else None
        if summary:
            snts.append(summary)
        for snt in snts:
            snt = snt.strip()
            # The title is always prepended to the scored text, so with
            # snt_with_title=True a regular sentence carries the title twice.
            text = ctx["title"] + " " + snt
            snt_kw_score = count_keyword_scores(text, expanded_keyword_list, keyword_type_list)
            sentence_tuples.append((ctx["id"], snt, snt_kw_score))
    sentence_tuples = sorted(sentence_tuples, key=lambda x: x[2], reverse=True)
    return sentence_tuples, summaries


## 10. End-to-End Question Answering

Main pipeline function that combines retrieval, summarization, and answer generation.


In [None]:
# Main MRAG pipeline for answering questions
import nltk
# Downloads NLTK's 'punkt_tab' tokenizer data when this cell runs; presumably needed
# by the sent_tokenize / word_tokenize calls used in this notebook — confirm.
nltk.download('punkt_tab')

def answer_question_with_mrag(ex, topn_titles=50, per_title_sections=15, topk_retrieve=100, QFS_topk=15, ctx_topk_for_answer=10):
    """Answer one example with the retrieve -> summarize -> generate pipeline.

    Args:
        ex: Example dict with ``question`` and ``normalized_question``.
        topn_titles: Accepted for interface compatibility; not used here.
        per_title_sections: Accepted for interface compatibility; not used here.
        topk_retrieve: Number of passages fetched from the index.
        QFS_topk: Number of top passages summarized by the LLM.
        ctx_topk_for_answer: Number of top passages placed in the answer prompt.

    Returns:
        Dict with ``rag_pred`` (answer string, "" when longer than 60
        whitespace-separated tokens), ``latest_ctxs``, ``summaries`` and
        ``sentence_tuples``.
    """
    normalized = ex["normalized_question"]
    retrieved = retrieve_from_corpus(normalized, topk=topk_retrieve)
    ranked_sentences, ctx_summaries = summarize_and_rank_sentences(ex, retrieved, QFS_topk=QFS_topk)

    # The answer prompt is built from the raw top passages, one "title | text" block each.
    context_blocks = []
    for ctx in retrieved[:ctx_topk_for_answer]:
        context_blocks.append(ctx["title"] + " | " + ctx["text"].strip())
    answer_prompt = c_prompt(ex["question"], "\n\n".join(context_blocks))

    answer = force_string(generate_text(answer_prompt, max_new_tokens=400))
    # Generations longer than 60 tokens are treated as non-answers.
    if len(answer.split()) > 60:
        answer = ""

    return {
        "rag_pred": answer,
        "latest_ctxs": retrieved,
        "summaries": ctx_summaries,
        "sentence_tuples": ranked_sentences,
    }

# Make sure corpus is loaded
assert 'corpus_passages' in globals() and len(corpus_passages) > 0, \
    "Load your JSONL corpus (JSONL_FILES) before this cell."

# Add IDs to passages that don't have them
# This mutates the passage dicts in place. Records built by normalize_record already
# carry a non-empty string id, so this only changes passages that were loaded some
# other way. The fallback id is "<pageid>:<section_idx>:<chunk_idx>".
for p in corpus_passages:
    if not p.get('id'):
        pid = str(p.get('pageid', 'NA'))
        sec = str(p.get('section_idx', 0))
        chk = str(p.get('chunk_idx', 0))
        p['id'] = f"{pid}:{sec}:{chk}"

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


### 10.1 Test Run

Run a quick test on a few examples to verify the pipeline works.


In [None]:
# Quick test run on a few examples
# Runs the full pipeline (including the per-passage LLM summaries) on at most five
# filtered questions and prints gold answers next to predictions for a manual check.
results = []
test_size = min(5, len(filtered_examples))
print(f"\nRunning test evaluation on {test_size} filtered examples...")

for ex in filtered_examples[:test_size]:
    # topn_titles and per_title_sections are accepted by answer_question_with_mrag
    # but are not read in its body; the other arguments control retrieval depth,
    # how many passages are summarized, and how many go into the answer prompt.
    out = answer_question_with_mrag(
        ex,
        topn_titles=50,
        per_title_sections=15,
        topk_retrieve=100,
        QFS_topk=15,
        ctx_topk_for_answer=10
    )
    # Keep (question, gold answers, prediction) triples for later inspection.
    results.append((ex["question"], ex["answers"], out["rag_pred"]))
    print("\nQ:", ex["question"])
    print("Gold:", ex["answers"])
    print("Pred:", out["rag_pred"])



Running test evaluation on 5 filtered examples...


You are not running the flash-attention implementation, expect numerical differences.



Q: When was the time the Dodgers played the Yankees in the World Series?
Gold: ['1981']
Pred: From 1941 to 1956

Q: When was the last time the Dodgers played the Yankees in the World Series as of 1981?
Gold: ['1981']
Pred: 1981

Q: When was the last time the Dodgers played the Yankees in the World Series as of 1991?
Gold: ['1981']
Pred: 1955

Q: When was the last time the Dodgers played the Yankees in the World Series before October 2, 2008?
Gold: ['1981']
Pred: 2004

Q: When was the last time the Dodgers played the Yankees in the World Series between 1979 and 1999?
Gold: ['1981']
Pred: 1981


## 11. Retrieval Evaluation

Evaluate the retrieval performance of MRAG re-ranking with Recall@K, Hit@K, MRR@K, MAP@K, and nDCG@K. A passage counts as relevant to a question when its text contains one of the gold answers.


In [None]:
# Fast retrieval evaluation with MRAG re-ranking
# Optimizations:
#   - Pre-tokenize all passages once (reused across queries)
#   - Build relevance sets in parallel
#   - Option to skip LLM summaries for faster evaluation

import os, math, time
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from nltk.tokenize import sent_tokenize

# Configuration parameters
# NOTE: several of these (KS, RETRIEVE_K, QFS_TOPK, USE_SUMMARY, VERBOSE_EVERY,
# QRELS_LOG_EVERY) are used as default argument values by the functions defined
# below, so their values are captured when those `def`s run. After changing one,
# re-run the definitions or pass the argument explicitly.
KS                 = (1,5,10,20,50,100)  # Cutoffs k at which metrics are reported
RETRIEVE_K         = 100          # Must be >= max(KS)
QFS_TOPK           = 15           # Number of top contexts that get an LLM summary (all retrieved contexts are sentence-scored)
USE_SUMMARY        = True        # Set to False to skip LLM summaries (faster but less accurate)
N_WORKERS          = max(1, min(cpu_count(), 8))  # Number of parallel workers (capped at 8)
QRELS_PARALLEL     = True         # Build relevance sets in parallel
VERBOSE_EVERY      = 20           # Print progress every N queries
QRELS_LOG_EVERY    = 20           # Print progress during relevance set building

# Get a stable document ID for a passage
def _get_doc_id(p):
    if isinstance(p, dict) and p.get("id") is not None:
        return str(p["id"])
    title = (p.get("title","") if isinstance(p, dict) else "")
    text  = (p.get("text","")  if isinstance(p, dict) else "")
    return f"{title[:50]}::{abs(hash(text)) & 0xffffffff}"

# Pre-tokenize all passages once for efficiency
# Result: PRETOKENIZED_SENTS[doc_id] = ["title s1", "title s2", ...]
def _tok_one(args):
    """Split one passage into title-prefixed sentences.

    Args:
        args: ``(doc_id, title, text)`` tuple (a single argument so the function
            can be used as a Pool task).

    Returns:
        ``(doc_id, sentences)`` where each sentence is ``title + " " + sentence``
        and sentences that are blank after stripping are dropped.
    """
    doc_id, doc_title, doc_text = args
    prefix = (doc_title or "") + " "
    titled = []
    for raw in sent_tokenize(doc_text or ""):
        sentence = raw.strip()
        if sentence:
            titled.append(prefix + sentence)
    return doc_id, titled

# Sentence-split every corpus passage once, up front, so per-query scoring can look
# sentences up by document id instead of re-tokenizing. Passages that share a
# document id overwrite one another in PRETOKENIZED_SENTS.
print(f"[PRETOKENIZE] Tokenizing {len(corpus_passages):,} passages "
      f"with {N_WORKERS} workers...")
t_tok = time.time()
args = [(_get_doc_id(p), p.get("title",""), p.get("text","")) for p in corpus_passages]
PRETOKENIZED_SENTS = {}
# Each task is a single small passage, so tasks are handed to workers in batches.
# With the default chunksize of 1, inter-process overhead dominates on a corpus
# this size.
PRETOKENIZE_CHUNKSIZE = 512
if N_WORKERS > 1:
    with Pool(processes=N_WORKERS) as pool:
        for i, (pid, sarr) in enumerate(
            pool.imap_unordered(_tok_one, args, chunksize=PRETOKENIZE_CHUNKSIZE), 1
        ):
            PRETOKENIZED_SENTS[pid] = sarr
            if i % 5000 == 0:
                print(f"[PRETOKENIZE] {i}/{len(args)} done...")
else:
    # Single-worker fallback: same work, done in this process.
    for i, a in enumerate(args, 1):
        pid, sarr = _tok_one(a)
        PRETOKENIZED_SENTS[pid] = sarr
        if i % 5000 == 0:
            print(f"[PRETOKENIZE] {i}/{len(args)} done...")
print(f"[PRETOKENIZE] Done in {time.time()-t_tok:.1f}s")

# Fast version of QFS with optional summaries
def summarize_and_rank_sentences_fast(
    ex, latest_ctxs, QFS_topk=QFS_TOPK, snt_with_title=True, use_summary=USE_SUMMARY
):
    """Rank sentences of the retrieved contexts by keyword score, using the
    pre-tokenized sentence cache and optional LLM summaries.

    Args:
        ex: Example dict; ``ex["normalized_question"]`` supplies the query.
        latest_ctxs: Retrieved passage dicts.
        QFS_topk: Number of leading contexts summarized when ``use_summary`` is set.
        snt_with_title: When True, the passage title is prepended to the text
            that gets scored (cached sentences already start with the title).
        use_summary: When False, no LLM calls are made.

    Returns:
        ``(sentence_tuples, summaries)``. ``sentence_tuples`` is a list of
        ``(doc_id, sentence, keyword_score)`` sorted by descending score.
        ``summaries`` has one slot per context (None where no summary exists).
    """
    from nltk.tokenize import word_tokenize

    question = ex["normalized_question"]
    alpha_tokens = [tok for tok in word_tokenize(question) if tok.isalpha()]
    expanded_keyword_list, keyword_type_list = expand_keywords(alpha_tokens, question, verbose=False)

    # One summary slot per context; only the first QFS_topk can be filled.
    summaries = [None] * len(latest_ctxs)
    if use_summary and QFS_topk > 0:
        for pos, ctx in enumerate(latest_ctxs[:QFS_topk]):
            generated = generate_text(
                LLMGenerations(ctx["title"] + " | " + ctx["text"], question),
                max_new_tokens=200,
            )
            # An output mentioning "None" is treated as "no relevant summary".
            summaries[pos] = None if "None" in str(generated) else generated

    scored = []
    for ctx, ctx_summary in zip(latest_ctxs, summaries):
        doc_id = _get_doc_id(ctx)
        # Copy so the shared cache entry is never mutated.
        candidates = list(PRETOKENIZED_SENTS.get(doc_id, []))
        if ctx_summary:
            candidates.append(ctx_summary)
        for sentence in candidates:
            if snt_with_title:
                scored_text = ctx.get("title", "") + " " + sentence
            else:
                scored_text = sentence
            kw_score = count_keyword_scores(scored_text, expanded_keyword_list, keyword_type_list)
            scored.append((doc_id, sentence, kw_score))

    ranked = sorted(scored, key=lambda entry: entry[2], reverse=True)
    return ranked, summaries

# Build relevance sets: find passages that contain any gold answer
def _qrels_for_example(args):
    """Collect the ids of all corpus passages whose text contains a gold answer.

    Args:
        args: ``(query_index, example, corpus)`` tuple (a single argument so the
            function can be used as a Pool task).

    Returns:
        ``(qid, rel_ids, hits)``. ``qid`` is ``example["qid"]`` or ``"q<index>"``,
        ``rel_ids`` is the set of matching document ids, and ``hits`` is the
        number of matching passages.

    NOTE(review): matching uses the passage text only, while the corpus filter
    earlier in the notebook matches against title + text, so a question can pass
    that filter and still end up with an empty relevance set here.
    """
    query_index, example, passages = args
    qid = example.get("qid", f"q{query_index}")
    matched_ids = [
        _get_doc_id(passage)
        for passage in passages
        if has_answer(example["answers"], passage["text"], answer_tokenizer)
    ]
    return qid, set(matched_ids), len(matched_ids)

# Corpus handle for qrels worker processes. It is filled once per worker by
# _qrels_pool_init, so the passages are not re-serialized for every query task.
_QRELS_WORKER_CORPUS = None

def _qrels_pool_init(corpus):
    """Pool initializer: remember the shared corpus inside this worker process."""
    global _QRELS_WORKER_CORPUS
    _QRELS_WORKER_CORPUS = corpus

def _qrels_pool_task(args):
    """Pool task: build the relevance set for one ``(query_index, example)`` pair."""
    qi, ex = args
    return _qrels_for_example((qi, ex, _QRELS_WORKER_CORPUS))

def build_relevance_sets_parallel(examples, corpus_passages, log_every=QRELS_LOG_EVERY):
    """Build per-query relevance sets using a pool of ``N_WORKERS`` processes.

    Args:
        examples: Example dicts; each must provide ``answers``.
        corpus_passages: Passages to scan; entries that are not dicts or lack a
            ``text`` key are ignored.
        log_every: Print progress every this many finished queries (falsy disables).

    Returns:
        Dict mapping query id to the set of relevant document ids.
    """
    print(f"[QRELS] Parallel build with {N_WORKERS} workers over {len(examples)} queries and "
          f"{len(corpus_passages):,} passages...")
    t0 = time.time()
    corpus = [p for p in corpus_passages if isinstance(p, dict) and "text" in p]
    qrels = {}
    total_hits = 0
    if N_WORKERS > 1:
        # The corpus goes to each worker once via the initializer. Tasks carry only
        # (index, example), instead of a full copy of the corpus per query.
        with Pool(processes=N_WORKERS,
                  initializer=_qrels_pool_init,
                  initargs=(corpus,)) as pool:
            for i, (qid, rel_ids, hits) in enumerate(
                pool.imap_unordered(_qrels_pool_task, list(enumerate(examples))),
                1
            ):
                qrels[qid] = rel_ids
                total_hits += hits
                if log_every and (i % log_every == 0):
                    print(f"[QRELS] {i}/{len(examples)} queries built "
                          f"(rel_docs this batch={hits}, elapsed={time.time()-t0:.1f}s)")
    else:
        # Single-worker fallback: scan in this process.
        for i, ex in enumerate(examples, 1):
            qid, rel_ids, hits = _qrels_for_example((i-1, ex, corpus))
            qrels[qid] = rel_ids
            total_hits += hits
            if log_every and (i % log_every == 0):
                print(f"[QRELS] {i}/{len(examples)} queries built "
                      f"(rel_docs this batch={hits}, elapsed={time.time()-t0:.1f}s)")

    evaluable = sum(1 for s in qrels.values() if len(s) > 0)
    print(f"[QRELS] Done in {time.time()-t0:.1f}s — evaluable={evaluable}/{len(examples)}, "
          f"total rel matches={total_hits}")
    return qrels

def build_relevance_sets_serial(examples, corpus_passages, log_every=QRELS_LOG_EVERY):
    """Build per-query relevance sets in the current process.

    Args:
        examples: Example dicts; each must provide ``answers``.
        corpus_passages: Passages to scan; entries that are not dicts or lack a
            ``text`` key are ignored.
        log_every: Print progress every this many finished queries (falsy disables).

    Returns:
        Dict mapping query id to the set of relevant document ids.
    """
    print(f"[QRELS] Serial build over {len(examples)} queries and {len(corpus_passages):,} passages...")
    started = time.time()
    usable = [passage for passage in corpus_passages
              if isinstance(passage, dict) and "text" in passage]
    qrels = {}
    total_hits = 0
    for query_index, example in enumerate(examples):
        done = query_index + 1
        qid, rel_ids, hits = _qrels_for_example((query_index, example, usable))
        qrels[qid] = rel_ids
        total_hits += hits
        if log_every and (done % log_every == 0):
            print(f"[QRELS] {done}/{len(examples)} queries built "
                  f"(rel_docs this batch={hits}, elapsed={time.time()-started:.1f}s)")
    # A query is evaluable when at least one passage contains its answer.
    evaluable = len([rel for rel in qrels.values() if len(rel) > 0])
    print(f"[QRELS] Done in {time.time()-started:.1f}s — evaluable={evaluable}/{len(examples)}, "
          f"total rel matches={total_hits}")
    return qrels

def build_relevance_sets(examples, corpus_passages):
    """Build relevance sets with the parallel or serial builder, as selected by
    the ``QRELS_PARALLEL`` flag at call time."""
    if QRELS_PARALLEL:
        return build_relevance_sets_parallel(examples, corpus_passages)
    return build_relevance_sets_serial(examples, corpus_passages)

# Get MRAG re-ranked passage IDs for a query
def retrieve_ranked_ids_mrag(ex, retrieve_k=RETRIEVE_K, QFS_topk=QFS_TOPK, use_summary=USE_SUMMARY):
    """Retrieve candidates for a query and re-rank them by their best sentence score.

    Args:
        ex: Example dict; ``ex["normalized_question"]`` is the query.
        retrieve_k: Number of candidates fetched from the index.
        QFS_topk: Number of top candidates that may receive an LLM summary.
        use_summary: Whether LLM summaries are generated.

    Returns:
        Document ids ordered by descending maximum sentence score, followed by
        the retrieved candidates that produced no scored sentence (in retrieval
        order).
    """
    candidates = retrieve_from_corpus(ex["normalized_question"], topk=retrieve_k)
    # Distinct candidate ids in retrieval order.
    candidate_ids = list(dict.fromkeys(_get_doc_id(ctx) for ctx in candidates))

    ranked_sentences, _summaries = summarize_and_rank_sentences_fast(
        ex, candidates, QFS_topk=QFS_topk, use_summary=use_summary
    )

    # Document score is the maximum over its sentence scores.
    best_score = {}
    for doc_id, _sentence, score in ranked_sentences:
        best_score[doc_id] = max(best_score.get(doc_id, float("-inf")), score)

    reranked = sorted(best_score, key=best_score.get, reverse=True)
    unscored = [doc_id for doc_id in candidate_ids if doc_id not in best_score]
    return reranked + unscored

# Metric calculation helpers
def _rr_at_k(ranked_ids, rel_set, k):
    for i, did in enumerate(ranked_ids[:k], start=1):
        if did in rel_set: return 1.0 / i
    return 0.0

def _ap_at_k(ranked_ids, rel_set, k):
    if not rel_set: return None
    hits = 0; sum_prec = 0.0
    for i, did in enumerate(ranked_ids[:k], start=1):
        if did in rel_set:
            hits += 1
            sum_prec += hits / i
    denom = min(len(rel_set), k)
    return sum_prec / denom if denom > 0 else 0.0

def _dcg_at_k(ranked_ids, rel_set, k):
    return sum((1.0 if did in rel_set else 0.0) / math.log2(i+1)
               for i, did in enumerate(ranked_ids[:k], start=1))

def _ndcg_at_k(ranked_ids, rel_set, k):
    if not rel_set: return None
    dcg  = _dcg_at_k(ranked_ids, rel_set, k)
    m    = min(len(rel_set), k)
    idcg = sum(1.0 / math.log2(i+1) for i in range(1, m+1))
    return (dcg / idcg) if idcg > 0 else 0.0

# Evaluate MRAG re-ranking performance
def evaluate_retrieval_mrag_fast(
    examples,
    corpus_passages,
    ks=KS,
    retrieve_k=RETRIEVE_K,
    QFS_topk=QFS_TOPK,
    use_summary=USE_SUMMARY,
    verbose_every=VERBOSE_EVERY
):
    """Evaluate MRAG re-ranking against answer-containing passages.

    Relevance sets are built from ``corpus_passages``. Each query is then
    retrieved and re-ranked, and Recall/Hit/MRR/MAP/nDCG are averaged over the
    queries that have at least one relevant passage.

    Args:
        examples: Example dicts (``normalized_question``, ``answers``, optional ``qid``).
        corpus_passages: Passage dicts used to build the relevance sets.
        ks: Cutoffs at which metrics are computed.
        retrieve_k: Number of candidates retrieved per query; cutoffs larger than
            this cannot see additional documents.
        QFS_topk: Number of top candidates that may receive an LLM summary.
        use_summary: Whether LLM summaries are generated during re-ranking.
        verbose_every: Print progress every this many queries (falsy disables).

    Returns:
        ``(metrics, details)``. ``metrics`` maps ``"RECALL@k"``, ``"HIT@k"``,
        ``"MRR@k"``, ``"MAP@k"`` and ``"NDCG@k"`` (for every k in ``ks``) to
        macro averages, plus ``NUM_QUERIES`` and ``NUM_EVALUABLE``. ``details``
        is a per-query list of dicts with the first relevant rank.
    """
    print(f"[EVAL] MRAG re-ranking over fixed pool | ks={ks}, retrieve_k={retrieve_k}, "
          f"QFS_topk={QFS_topk}, use_summary={use_summary}")
    t_all = time.time()

    # Build relevance sets from the same corpus
    qrels = build_relevance_sets(examples, corpus_passages)

    ks = sorted(set(int(k) for k in ks))
    agg = { "mrr@k": defaultdict(float), "map@k": defaultdict(float),
            "ndcg@k": defaultdict(float), "recall@k": defaultdict(float),
            "hit@k": defaultdict(float), }
    counts = defaultdict(int)
    details = []

    print("[EVAL] Scoring ranked lists...")
    t0 = time.time()
    for qi, ex in enumerate(examples, 1):
        # Must match the id scheme of _qrels_for_example (0-based "q<index>").
        qid     = ex.get("qid", f"q{qi-1}")
        rel_set = qrels.get(qid, set())

        ranked_ids = retrieve_ranked_ids_mrag(ex, retrieve_k=retrieve_k,
                                              QFS_topk=QFS_topk, use_summary=use_summary)

        # Find rank of first relevant document
        first_rel_rank = None
        for r, did in enumerate(ranked_ids, start=1):
            if did in rel_set:
                first_rel_rank = r
                break

        details.append({
            "qid": qid,
            "question": ex.get("question", ex.get("normalized_question", "")),
            "num_relevant": len(rel_set),
            "first_rel_rank": first_rel_rank,
            "retrieved": len(ranked_ids)
        })

        # Queries without any relevant passage are excluded from the averages.
        if rel_set:
            for k in ks:
                topk = ranked_ids[:k]
                found = sum(1 for did in topk if did in rel_set)
                agg["recall@k"][k] += found / len(rel_set)
                agg["hit@k"][k]    += 1.0 if found > 0 else 0.0
                agg["mrr@k"][k]    += _rr_at_k(ranked_ids, rel_set, k)

                ap_k = _ap_at_k(ranked_ids, rel_set, k)
                if ap_k is not None:
                    agg["map@k"][k]  += ap_k
                ndcg_k = _ndcg_at_k(ranked_ids, rel_set, k)
                if ndcg_k is not None:
                    agg["ndcg@k"][k] += ndcg_k
                counts[k] += 1

        if verbose_every and (qi % verbose_every == 0):
            elapsed = time.time() - t0
            print(f"[EVAL] {qi}/{len(examples)} queries processed "
                  f"(first_rel_rank={first_rel_rank}, elapsed={elapsed:.1f}s)")

    # Compute macro-averaged metrics
    # Iterate over ks rather than over the populated table keys, so every metric
    # key exists (as 0.0) even when no query was evaluable. Otherwise callers
    # that index e.g. metrics["RECALL@1"] would hit a KeyError.
    metrics = {}
    for name, table in agg.items():
        label = name.upper().replace('@K','')
        for k in ks:
            denom = counts[k] if counts[k] > 0 else 1
            metrics[f"{label}@{k}"] = table.get(k, 0.0) / denom

    metrics["NUM_QUERIES"]   = len(examples)
    metrics["NUM_EVALUABLE"] = counts[min(ks)] if ks else 0

    print(f"[EVAL] Done in {time.time()-t_all:.1f}s "
          f"(evaluable={metrics['NUM_EVALUABLE']}/{metrics['NUM_QUERIES']})")
    return metrics, details

# Run evaluation
# Evaluates on the corpus-filtered questions. With USE_SUMMARY enabled, each query
# triggers up to QFS_TOPK LLM summary generations, so this is the slow part.
metrics_mrag, per_query_mrag = evaluate_retrieval_mrag_fast(
    filtered_examples,
    corpus_passages,
    ks=KS,
    retrieve_k=RETRIEVE_K,
    QFS_topk=QFS_TOPK,
    use_summary=USE_SUMMARY,   # Set to False for faster evaluation
    verbose_every=VERBOSE_EVERY
)

# One report line per cutoff; all values are averages over the evaluable queries.
print("=== MRAG Re-ranking Retrieval Metrics (fixed pool, fast) ===")
for k in KS:
    print(f"Recall@{k}: {metrics_mrag[f'RECALL@{k}']:.4f}  |  Hit@{k}: {metrics_mrag[f'HIT@{k}']:.4f}  "
          f"|  MRR@{k}: {metrics_mrag[f'MRR@{k}']:.4f}  |  MAP@{k}: {metrics_mrag[f'MAP@{k}']:.4f}  "
          f"|  nDCG@{k}: {metrics_mrag[f'NDCG@{k}']:.4f}")
print(f"\nQueries: {metrics_mrag['NUM_QUERIES']} | Evaluable (>=1 relevant): {metrics_mrag['NUM_EVALUABLE']}")


[PRETOKENIZE] Tokenizing 219,940 passages with 8 workers...
[PRETOKENIZE] 5000/219940 done...
[PRETOKENIZE] 10000/219940 done...
[PRETOKENIZE] 15000/219940 done...
[PRETOKENIZE] 20000/219940 done...
[PRETOKENIZE] 25000/219940 done...
[PRETOKENIZE] 30000/219940 done...
[PRETOKENIZE] 35000/219940 done...
[PRETOKENIZE] 40000/219940 done...
[PRETOKENIZE] 45000/219940 done...
[PRETOKENIZE] 50000/219940 done...
[PRETOKENIZE] 55000/219940 done...
[PRETOKENIZE] 60000/219940 done...
[PRETOKENIZE] 65000/219940 done...
[PRETOKENIZE] 70000/219940 done...
[PRETOKENIZE] 75000/219940 done...
[PRETOKENIZE] 80000/219940 done...
[PRETOKENIZE] 85000/219940 done...
[PRETOKENIZE] 90000/219940 done...
[PRETOKENIZE] 95000/219940 done...
[PRETOKENIZE] 100000/219940 done...
[PRETOKENIZE] 105000/219940 done...
[PRETOKENIZE] 110000/219940 done...
[PRETOKENIZE] 115000/219940 done...
[PRETOKENIZE] 120000/219940 done...
[PRETOKENIZE] 125000/219940 done...
[PRETOKENIZE] 130000/219940 done...
[PRETOKENIZE] 135000/219