# MRAG + Contriever Integration on ChroniclingAmericaQA

One-stop Colab-ready notebook that compares baseline Contriever vs time-aware Contriever (FineWeb-trained) with and without MRAG re-ranking on ChroniclingAmericaQA. Training is assumed done elsewhere; this notebook focuses on evaluation and integration.


## 1. Setup & Installs


In [None]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
!pip -q install transformers

In [None]:
!pip -q install transformers[sentencepiece] datasets faiss-cpu pandas tqdm scikit-learn nltk


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m74.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)


Using device: cuda


In [None]:
!unzip "/content/MRAG_TimeQA_Corpus_Checkpoint 2.zip"

Archive:  /content/MRAG_TimeQA_Corpus_Checkpoint 2.zip
   creating: MRAG_TimeQA_Corpus_Checkpoint/
  inflating: __MACOSX/._MRAG_TimeQA_Corpus_Checkpoint  
  inflating: MRAG_TimeQA_Corpus_Checkpoint/timeqa_contriever_timeaware_no_chunk.index  
  inflating: __MACOSX/MRAG_TimeQA_Corpus_Checkpoint/._timeqa_contriever_timeaware_no_chunk.index  
   creating: MRAG_TimeQA_Corpus_Checkpoint/.config/
  inflating: __MACOSX/MRAG_TimeQA_Corpus_Checkpoint/._.config  
   creating: MRAG_TimeQA_Corpus_Checkpoint/contriever_mining_index_fineweb_20k/
  inflating: __MACOSX/MRAG_TimeQA_Corpus_Checkpoint/._contriever_mining_index_fineweb_20k  
  inflating: MRAG_TimeQA_Corpus_Checkpoint/.DS_Store  
  inflating: __MACOSX/MRAG_TimeQA_Corpus_Checkpoint/._.DS_Store  
  inflating: MRAG_TimeQA_Corpus_Checkpoint/caqa_contriever_timeaware.index  
  inflating: __MACOSX/MRAG_TimeQA_Corpus_Checkpoint/._caqa_contriever_timeaware.index  
   creating: MRAG_TimeQA_Corpus_Checkpoint/contriever_finetuned_NEW_20k/
  inflating

In [None]:
!mv /content/MRAG_TimeQA_Corpus_Checkpoint/* /content/

In [None]:
!rmdir /content/MRAG_TimeQA_Corpus_Checkpoint
print('Contents moved successfully and original folder removed.')

rmdir: failed to remove '/content/MRAG_TimeQA_Corpus_Checkpoint': Directory not empty
Contents moved successfully and original folder removed.


## 2. Imports & Shared Config


In [None]:
import os
import re
import gc
import math
import random
import shutil
import numpy as np
import pandas as pd
from typing import List, Tuple

from transformers import (
    AutoModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import faiss
from tqdm import tqdm

from torch.cuda.amp import GradScaler

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_grad_enabled(False)

CONTRIEVER_BASE = "facebook/contriever-msmarco"
# Update to the actual path/name produced by the fine-tuned model (folder or HF repo)
CONTRIEVER_TIMEAWARE = "contriever_finetuned_NEW_20k"
FT_OUT_DIR = CONTRIEVER_TIMEAWARE  # where fine-tuning will save if recreated here

CAQA_SPLIT = "validation"  # change to "test" or "train" if desired
YEAR_REGEX = re.compile(r"\b(18[0-9]{2}|19[0-9]{2}|20[0-2][0-9])\b")
#YEAR_REGEX = re.compile(r"\b(19[0-9]{2}|20[0-2][0-9])\b")

ENCODING_BATCH_SIZE = 64
MAX_LENGTH = 256
RETRIEVE_TOPK = 100
EVAL_KS = (1, 5, 10, 20, 50)
YEAR_SUBSET_LIMIT = None  # evaluate all year-explicit questions

RESULTS = []  # collects metrics for the final summary table

# LLM summarization model (mirrors baseline notebook choice)
LLM_SUMMARY_MODEL = "microsoft/Phi-3.5-mini-instruct"
LLM_SUMMARY_MAX_NEW_TOKENS = 120
LLM_SUMMARY_TEMP = 0.2
LLM_SUMMARY_TOP_P = 0.95
LLM_SUMMARY_DEVICE = DEVICE
LLM_SUMMARY_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

LLM_SUMMARY_MODEL_OBJ = None
LLM_SUMMARY_TOKENIZER = None


## 3. Full FineWeb Temporal Training Pipeline (from NLP_Fineweb)
This section mirrors the full pipeline: build a temporal FineWeb subset, generate temporal questions with T5, split train/test, mine temporal hard negatives with Contriever, blend with MS MARCO triplets, and fine-tune the time-aware Contriever saved to `FT_OUT_DIR`.


In [None]:
import torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
AMP_DTYPE = torch.float16
BASELINE_MODEL = CONTRIEVER_BASE
T5_QG_MODEL = "valhalla/t5-base-qg-hl"
FINEWEB_SAMPLE_SIZE = 500000  # target number of passages to collect
MAX_PASSAGE_CHARS = 1000      # truncate for T5 stability
NUM_QG_PASSAGES = 15000       # passages sampled for QG
QG_BATCH_SIZE = 64
MINING_POOL_K = 100
MAX_POSITIVES = 3
MAX_NEGATIVES = 6
SEMANTIC_THRESHOLD = 0.7
TRAIN_BATCH_SIZE = 64         # for encoding and some loaders
# Hybrid training hyperparams (mirroring original notebook)
TRAIN_EPOCHS_HYBRID = 14
MICRO_BATCH_SIZE = 32
GRAD_ACC_STEPS = 8
TRAIN_LR_HYBRID = 1e-5
TRIPLET_MARGIN = 1.0

print("\n--- Step 2: Initializing Constants ---")
print(f"Using Device: {DEVICE}")
print(f"Targeting {FINEWEB_SAMPLE_SIZE} FineWeb passages.")
print(f"Training for {TRAIN_EPOCHS_HYBRID} epochs (hybrid triplet phase).")

# ---- Helper functions (from NLP_Fineweb) ----
def _norm(s: str) -> str:
    s = re.sub(r"[^a-z0-9 ]+", " ", (s or "").lower())
    return re.sub(r"\s+", " ", s).strip()


def get_years_from_text(text: str) -> set:
    return set(YEAR_REGEX.findall(text))


def mean_pooling(last_hidden_state, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


@torch.no_grad()
def encode_contriever(model, tokenizer, texts, max_len=256, batch=64):
    model.eval()
    outs = []
    for i in tqdm(range(0, len(texts), batch), desc="Encoding"):
        batch_texts = texts[i:i + batch]
        tok = tokenizer(batch_texts, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
        tok = {k: v.to(DEVICE) for k, v in tok.items()}
        with torch.autocast(DEVICE, dtype=AMP_DTYPE, enabled=torch.cuda.is_available()):
            outputs = model(**tok)
            embeddings = mean_pooling(outputs.last_hidden_state, tok['attention_mask'])
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        outs.append(embeddings.cpu().numpy().astype("float32"))
    return np.vstack(outs) if outs else np.zeros((0, model.config.hidden_size), "float32")


def build_faiss_index(model, tokenizer, passages_list, passage_ids_list, out_dir, index_path, max_len=256):
    print(f"Building FAISS index in {out_dir}...")
    dim = model.config.hidden_size
    index_flat = faiss.IndexFlatIP(dim)
    ids = np.array(passage_ids_list, dtype=np.int64)
    embs = encode_contriever(model, tokenizer, passages_list, batch=TRAIN_BATCH_SIZE * 2, max_len=max_len)
    index_idmap = faiss.IndexIDMap2(index_flat)
    index_idmap.add_with_ids(embs, ids)
    faiss.write_index(index_idmap, index_path)
    print(f"Built FLAT index: {index_idmap.ntotal:,} vectors")
    return index_idmap



--- Step 2: Initializing Constants ---
Using Device: cuda
Targeting 500000 FineWeb passages.
Training for 14 epochs (hybrid triplet phase).


In [None]:
# =========================== #
#  5. PREPARE FINEWEB DATASET
# =========================== #
print("\n--- Step 4: Preparing FineWeb Data ---")
print("Streaming HuggingFaceFW/fineweb-edu (sample-10BT)...")

dataset_stream = load_dataset(
    "HuggingFaceFW/fineweb-edu",
    name="sample-10BT",
    split="train",
    streaming=True,
)

train_passages_all = []
seen_texts = set()
current_id = 0

print(f"Filtering stream for passages containing years (1900-2029)...")
pbar = tqdm(total=FINEWEB_SAMPLE_SIZE, desc="Collecting Passages")

for row in dataset_stream:
    if len(train_passages_all) >= FINEWEB_SAMPLE_SIZE:
        break
    raw_text = row.get('text', "")
    if not raw_text:
        continue
    text_slice = raw_text[:MAX_PASSAGE_CHARS]
    if not get_years_from_text(text_slice):
        continue
    norm_text = _norm(text_slice[:100])
    if norm_text in seen_texts:
        continue
    seen_texts.add(norm_text)
    train_passages_all.append((current_id, text_slice, f"fineweb_{current_id}"))
    current_id += 1
    pbar.update(1)

pbar.close()
print(f"Clean training set size (FineWeb): {len(train_passages_all)}")




--- Step 4: Preparing FineWeb Data ---
Streaming HuggingFaceFW/fineweb-edu (sample-10BT)...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Filtering stream for passages containing years (1900-2029)...


Collecting Passages: 100%|██████████| 500000/500000 [08:40<00:00, 959.97it/s]

Clean training set size (FineWeb): 500000





In [None]:
# =========================== #
#  6. SYNTHETIC TEMPORAL QG WITH T5
# =========================== #
print("\n--- Step 5: Generating Synthetic TEMPORAL Data ---")
print(f"Loading T5 model: {T5_QG_MODEL}...")
qg_tokenizer = T5Tokenizer.from_pretrained(T5_QG_MODEL)
qg_model = T5ForConditionalGeneration.from_pretrained(T5_QG_MODEL).to(DEVICE)
qg_model.eval()

if len(train_passages_all) > NUM_QG_PASSAGES:
    print(f"Sampling {NUM_QG_PASSAGES} passages for QG...")
    passages_to_gen = random.sample(train_passages_all, NUM_QG_PASSAGES)
else:
    passages_to_gen = train_passages_all

synthetic_pairs = []  # (question, passage_text, passage_id)
passage_batch, passage_info, year_batch = [], [], []

@torch.no_grad()
def generate_temporal_questions_batch(qg_model, qg_tok, passages, years, max_new_tokens=64):
    prompts = [f"generate question about {y}: {p}" for p, y in zip(passages, years)]

    '''prompts = [
        (
            f"Write one factual question about this passage that must be answered using what happened in the year {y}. "
            f"Start the question with 'In {y},' or 'As of {y},'. Include the year {y} exactly once. "
            f"Do not mention 'passage', 'question', or 'given year'. Use only the passage text. Passage: {p}"
        )
        for p, y in zip(passages, years)
    ]'''
    #prompts = [f"generate one question that explicitly asks about the given year {y}.The question must mention the year in the wording. Use only information from {p} " for p, y in zip(passages, years)]
    inputs = qg_tok(prompts, padding="longest", truncation=True, max_length=512, return_tensors="pt")
    inputs = {k: v.to(qg_model.device) for k, v in inputs.items()}
    with torch.autocast("cuda", dtype=AMP_DTYPE, enabled=torch.cuda.is_available()):
        outputs = qg_model.generate(**inputs, max_length=max_new_tokens, num_beams=4, early_stopping=True)
    return qg_tok.batch_decode(outputs, skip_special_tokens=True)

print(f"Generating {len(passages_to_gen)} synthetic TEMPORAL questions...")
for (pid, text, title) in tqdm(passages_to_gen):
    years = get_years_from_text(text)
    if not years:
        continue
    first_year = sorted(list(years))[0]
    passage_batch.append(text)
    year_batch.append(first_year)
    passage_info.append((pid, text))
    if len(passage_batch) >= QG_BATCH_SIZE:
        generated_questions = generate_temporal_questions_batch(qg_model, qg_tokenizer, passage_batch, year_batch)
        for i, q in enumerate(generated_questions):
            if q:
                p_id, p_text = passage_info[i]
                synthetic_pairs.append((q, p_text, p_id))
        passage_batch, passage_info, year_batch = [], [], []

if passage_batch:
    generated_questions = generate_temporal_questions_batch(qg_model, qg_tokenizer, passage_batch, year_batch)
    for i, q in enumerate(generated_questions):
        if q:
            p_id, p_text = passage_info[i]
            synthetic_pairs.append((q, p_text, p_id))

print(f"Created {len(synthetic_pairs)} synthetic TEMPORAL (question, positive_passage) pairs.")
del qg_model, qg_tokenizer
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# =========================== #
#  7. CREATE 80/20 SPLIT
# =========================== #
print("\n--- Step 6: Creating 80/20 Train/Test Split ---")
train_set, test_set = train_test_split(synthetic_pairs, test_size=0.2, random_state=42)
print(f"Temporal Training set size: {len(train_set)}")
print(f"Temporal Test set size: {len(test_set)}")

corpus_passages_map = {pid: text for (q, text, pid) in synthetic_pairs}
corpus_passages_list = list(corpus_passages_map.values())
corpus_passage_ids_list = list(corpus_passages_map.keys())
print(f"Total passages in our T5 dataset: {len(corpus_passages_map)}")


--- Step 5: Generating Synthetic TEMPORAL Data ---
Loading T5 model: valhalla/t5-base-qg-hl...


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/15.0 [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

Sampling 15000 passages for QG...
Generating 15000 synthetic TEMPORAL questions...



  0%|          | 0/15000 [00:00<?, ?it/s][A
  0%|          | 64/15000 [00:04<15:39, 15.89it/s][A
  1%|          | 128/15000 [00:06<11:42, 21.18it/s][A
  1%|▏         | 192/15000 [00:08<09:40, 25.53it/s][A
  2%|▏         | 256/15000 [00:11<10:06, 24.31it/s][A
  2%|▏         | 320/15000 [00:12<08:46, 27.86it/s][A
  3%|▎         | 384/15000 [00:15<08:50, 27.53it/s][A
  3%|▎         | 448/15000 [00:17<08:48, 27.56it/s][A
  3%|▎         | 512/15000 [00:19<08:12, 29.45it/s][A
  4%|▍         | 576/15000 [00:21<07:49, 30.69it/s][A
  4%|▍         | 640/15000 [00:23<07:57, 30.08it/s][A
  5%|▍         | 704/15000 [00:25<08:02, 29.65it/s][A
  5%|▌         | 768/15000 [00:27<08:08, 29.13it/s][A
  6%|▌         | 832/15000 [00:29<07:49, 30.20it/s][A
  6%|▌         | 896/15000 [00:32<07:51, 29.93it/s][A
  6%|▋         | 960/15000 [00:34<07:42, 30.35it/s][A
  7%|▋         | 1024/15000 [00:36<07:38, 30.49it/s][A
  7%|▋         | 1088/15000 [00:38<07:18, 31.73it/s][A
  8%|▊         | 1

Created 15000 synthetic TEMPORAL (question, positive_passage) pairs.

--- Step 6: Creating 80/20 Train/Test Split ---
Temporal Training set size: 12000
Temporal Test set size: 3000
Total passages in our T5 dataset: 15000


In [None]:
generated_questions

['What year was the Detroit News published?',
 'What did Chief Red Cloud say about Sioux in Washington in 1870?',
 'What year did the Future of Jobs Survey take place?',
 'When did Bryant Tyrrell discover a bald eagle nest?',
 'What is one of the topics addressed in Reading the American Past volume 1 To: 1877 chapter four?',
 'When was the National Seashore formed?',
 'What article did Frederick Herzberg write about job enrichment?',
 'What year did the Institute of Medicine, Food and Nutrition Board publish its Dietary Reference Intakes for energy, carbohydrate, fiber, fat, protein and amino acids?',
 "What was the last song from U2's 1984 album, The Unforgettable Fire?",
 'What was the first time major architecture from the Iron Age has been found on the west coast?',
 'What are the conditions exacerbating the drought crisis?',
 'Did you know that 12,800 kids were seen in emergency rooms due to TV and furniture tip-overs in 2011?',
 'What is the UPSC Prelims - 2020?',
 'What was the 

In [None]:
synthetic_pairs[0:10]

[('What is the name of the book that Charles Bamforth wrote?',
  'Food, Fermentation and Micro-organisms\nNovember 2005, Wiley-Blackwell\nIn his engaging style Professor Charles Bamforth covers all known food applications of fermentation. Beginning with the science underpinning food fermentations, Professor Bamforth looks at the relevant aspects of microbiology and microbial physiology, moving on to cover individual food products, how they are made, what is the role of fermentation and what possibilities exist for future development.\n- Internationally respected author\n- Coverage of all major uses of fermentation in the food industry\n- Practical coverage of food processing in relation to fermentation\nA comprehensive guide for all food scientists, technologists and microbiologists in the food industry and academia, this book will be an important addition to all libraries in food companies, research establishments and universities where food studies, food science, food technology and 

In [None]:
# =========================== #
#  8. AUGMENTED TEMPORAL HARD NEGATIVE MINING
# =========================== #

print("\n--- Step 7: Mining *Augmented* Temporal Hard Negatives (train split) ---")
print("Loading BASELINE Contriever model for mining...")
contriever_tokenizer = AutoTokenizer.from_pretrained(BASELINE_MODEL)
contriever_model = AutoModel.from_pretrained(BASELINE_MODEL).to(DEVICE)
contriever_model.eval()

print(f"Building FAISS index for {len(corpus_passages_map)} passages...")
MINING_DIR = "contriever_mining_index_fineweb_20k"
MINING_INDEX_PATH = os.path.join(MINING_DIR, "mining.index")
shutil.rmtree(MINING_DIR, ignore_errors=True)
os.makedirs(MINING_DIR, exist_ok=True)
index_mining = build_faiss_index(
    contriever_model,
    contriever_tokenizer,
    corpus_passages_list,
    corpus_passage_ids_list,
    MINING_DIR,
    MINING_INDEX_PATH,
)

print("Mining for augmented (1-to-N) temporal hard negatives...")
triplet_examples = []
questions_to_mine = [ex[0] for ex in train_set]
q_embs = encode_contriever(contriever_model, contriever_tokenizer, questions_to_mine)
search_results_D, search_results_I = index_mining.search(q_embs, MINING_POOL_K)

for i in tqdm(range(len(train_set)), desc="Finding negatives"):
    q, p_pos_text, p_pos_id = train_set[i]
    pos_years = get_years_from_text(p_pos_text)
    if not pos_years:
        continue
    scores, passage_ids = search_results_D[i], search_results_I[i]
    other_positives, hard_negatives = [p_pos_text], []
    for score, pid in zip(scores, passage_ids):
        if pid == -1 or score < SEMANTIC_THRESHOLD:
            break
        if pid == p_pos_id:
            continue
        p_cand_text = corpus_passages_map.get(pid)
        if not p_cand_text:
            continue
        cand_years = get_years_from_text(p_cand_text)
        if not cand_years:
            continue
        if pos_years == cand_years and len(other_positives) < MAX_POSITIVES:
            other_positives.append(p_cand_text)
        elif pos_years != cand_years:
            hard_negatives.append(p_cand_text)
    if not hard_negatives:
        continue
    for p_pos in other_positives:
        for p_neg in hard_negatives[:MAX_NEGATIVES]:
            triplet_examples.append((q, p_pos, p_neg))

print(f"Created {len(triplet_examples)} augmented triplet training examples.")
del contriever_model, index_mining
if torch.cuda.is_available():
    torch.cuda.empty_cache()


--- Step 7: Mining *Augmented* Temporal Hard Negatives (train split) ---
Loading BASELINE Contriever model for mining...


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Building FAISS index for 15000 passages...
Building FAISS index in contriever_mining_index_fineweb_20k...


Encoding:   0%|          | 0/118 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Encoding: 100%|██████████| 118/118 [00:40<00:00,  2.90it/s]


Built FLAT index: 15,000 vectors
Mining for augmented (1-to-N) temporal hard negatives...


Encoding: 100%|██████████| 188/188 [00:03<00:00, 52.98it/s]
Finding negatives: 100%|██████████| 12000/12000 [00:00<00:00, 37889.50it/s]

Created 782 augmented triplet training examples.





In [None]:
# =========================== #
#  9. LOAD MS MARCO TRIPLETS & COMBINE
# =========================== #
print("\n--- Step 8: Loading MS MARCO (General Domain) Triplets ---")
my_triplets = triplet_examples
print(f"Existing Temporal Triplets: {len(my_triplets)}")

print("Streaming 'sentence-transformers/msmarco-msmarco-distilbert-base-tas-b' (Config: triplet-hard)...")
msmarco_stream = load_dataset(
    "sentence-transformers/msmarco-msmarco-distilbert-base-tas-b",
    "triplet-hard",
    split="train",
    streaming=True,
)

msmarco_triplets = []
target_count = 1000  # match original balance intent
pbar = tqdm(total=target_count, desc="Collecting MS MARCO")
for row in msmarco_stream:
    if len(msmarco_triplets) >= target_count:
        break
    q = row.get('query')
    p = row.get('positive')
    n = row.get('negative')
    if isinstance(n, list):
        n = random.choice(n)
    if q and p and n:
        msmarco_triplets.append((q, p, n))
        pbar.update(1)
pbar.close()
print(f"Collected {len(msmarco_triplets)} MS MARCO Triplets.")

combined_triplets = my_triplets + msmarco_triplets
random.shuffle(combined_triplets)
print(f"Final Hybrid Training Set Size: {len(combined_triplets)} Triplets")

del msmarco_triplets, msmarco_stream
gc.collect()

# =========================== #
#  10. HYBRID TRAINING LOOP
# =========================== #
print("\n--- Step 9: Training Model (Hybrid 50/50) ---")

MAX_LEN = 256

class TripletDataset(torch.utils.data.Dataset):
    def __init__(self, examples):
        self.examples = examples
    def __len__(self):
        return len(self.examples)
    def __getitem__(self, idx):
        return self.examples[idx]


def collate_triplets(batch):
    questions = [ex[0] for ex in batch]
    texts_pos = [ex[1] for ex in batch]
    texts_neg = [ex[2] for ex in batch]
    q_inputs = contriever_tokenizer(questions, padding="longest", truncation=True, max_length=MAX_LEN, return_tensors="pt")
    p_pos_inputs = contriever_tokenizer(texts_pos, padding="longest", truncation=True, max_length=MAX_LEN, return_tensors="pt")
    p_neg_inputs = contriever_tokenizer(texts_neg, padding="longest", truncation=True, max_length=MAX_LEN, return_tensors="pt")
    return {"q_inputs": q_inputs, "p_pos_inputs": p_pos_inputs, "p_neg_inputs": p_neg_inputs}

print("Loading Fresh Model for Training...")
# Enable grads for the training block (we set global no-grad earlier for eval helpers)
torch.set_grad_enabled(True)

contriever_tokenizer = AutoTokenizer.from_pretrained(BASELINE_MODEL)
contriever_model_train = AutoModel.from_pretrained(BASELINE_MODEL).to(DEVICE)
contriever_model_train.train()
contriever_model_train.gradient_checkpointing_enable()

train_dataset = TripletDataset(combined_triplets)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=MICRO_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_triplets,
    num_workers=1,
    pin_memory=True,
)

params = contriever_model_train.parameters()
optimizer = torch.optim.AdamW(params, lr=TRAIN_LR_HYBRID)
num_train_steps = len(train_dataloader) // GRAD_ACC_STEPS * TRAIN_EPOCHS_HYBRID
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=num_train_steps)
scaler = GradScaler(enabled=(DEVICE == 'cuda'))
triplet_loss_fct = torch.nn.MarginRankingLoss(margin=TRIPLET_MARGIN, reduction='mean')

print(f"Starting Training: {len(combined_triplets)} triplets, {TRAIN_EPOCHS_HYBRID} epochs")
for epoch in range(TRAIN_EPOCHS_HYBRID):
    total_loss = 0
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{TRAIN_EPOCHS_HYBRID}")
    for step, batch in enumerate(pbar):
        q_inputs = {k: v.to(DEVICE) for k, v in batch["q_inputs"].items()}
        p_pos_inputs = {k: v.to(DEVICE) for k, v in batch["p_pos_inputs"].items()}
        p_neg_inputs = {k: v.to(DEVICE) for k, v in batch["p_neg_inputs"].items()}
        with torch.autocast("cuda", dtype=AMP_DTYPE, enabled=torch.cuda.is_available()):
            q_vectors = contriever_model_train(**q_inputs).last_hidden_state
            p_pos_vectors = contriever_model_train(**p_pos_inputs).last_hidden_state
            p_neg_vectors = contriever_model_train(**p_neg_inputs).last_hidden_state
            def quick_pool(last_hidden, mask):
                mask_exp = mask.unsqueeze(-1).expand(last_hidden.size()).float()
                return torch.sum(last_hidden * mask_exp, 1) / torch.clamp(mask_exp.sum(1), min=1e-9)
            q_emb = quick_pool(q_vectors, q_inputs['attention_mask'])
            p_pos_emb = quick_pool(p_pos_vectors, p_pos_inputs['attention_mask'])
            p_neg_emb = quick_pool(p_neg_vectors, p_neg_inputs['attention_mask'])
            pos_scores = (q_emb * p_pos_emb).sum(1)
            neg_scores = (q_emb * p_neg_emb).sum(1)
            loss = triplet_loss_fct(pos_scores, neg_scores, torch.ones(q_emb.size(0)).to(DEVICE)) / GRAD_ACC_STEPS
        scaler.scale(loss).backward()
        total_loss += loss.item() * GRAD_ACC_STEPS
        if (step + 1) % GRAD_ACC_STEPS == 0:
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        pbar.set_postfix({"Loss": loss.item() * GRAD_ACC_STEPS})
    print(f"Epoch {epoch+1} Mean Loss: {total_loss / len(train_dataloader):.4f}")

print("\n--- Saving Model ---")
os.makedirs(FT_OUT_DIR, exist_ok=True)
contriever_model_train.save_pretrained(FT_OUT_DIR)
contriever_tokenizer.save_pretrained(FT_OUT_DIR)
print(f"Saved to {FT_OUT_DIR}")

# Optional: zip for download
if os.path.exists(FT_OUT_DIR):
    shutil.make_archive(FT_OUT_DIR, 'zip', FT_OUT_DIR)
    print(f"Packaged fine-tuned model to {FT_OUT_DIR}.zip")

# Cleanup
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Disable grads again for later inference sections
torch.set_grad_enabled(False)


--- Step 8: Loading MS MARCO (General Domain) Triplets ---
Existing Temporal Triplets: 782
Streaming 'sentence-transformers/msmarco-msmarco-distilbert-base-tas-b' (Config: triplet-hard)...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/24 [00:00<?, ?it/s]

Collecting MS MARCO: 100%|██████████| 1000/1000 [00:00<00:00, 1294.33it/s]


Collected 1000 MS MARCO Triplets.
Final Hybrid Training Set Size: 1782 Triplets

--- Step 9: Training Model (Hybrid 50/50) ---
Loading Fresh Model for Training...


  scaler = GradScaler(enabled=(DEVICE == 'cuda'))


Starting Training: 1782 triplets, 14 epochs


Epoch 1/14: 100%|██████████| 56/56 [00:26<00:00,  2.11it/s, Loss=0.5]


Epoch 1 Mean Loss: 0.7612


Epoch 2/14: 100%|██████████| 56/56 [00:26<00:00,  2.15it/s, Loss=0.362]


Epoch 2 Mean Loss: 0.5833


Epoch 3/14: 100%|██████████| 56/56 [00:25<00:00,  2.16it/s, Loss=0.442]


Epoch 3 Mean Loss: 0.4679


Epoch 4/14: 100%|██████████| 56/56 [00:25<00:00,  2.17it/s, Loss=0.267]


Epoch 4 Mean Loss: 0.3819


Epoch 5/14: 100%|██████████| 56/56 [00:25<00:00,  2.19it/s, Loss=0.291]


Epoch 5 Mean Loss: 0.3076


Epoch 6/14: 100%|██████████| 56/56 [00:25<00:00,  2.19it/s, Loss=0.179]


Epoch 6 Mean Loss: 0.2420


Epoch 7/14: 100%|██████████| 56/56 [00:25<00:00,  2.19it/s, Loss=0.0815]


Epoch 7 Mean Loss: 0.1951


Epoch 8/14: 100%|██████████| 56/56 [00:25<00:00,  2.19it/s, Loss=0.248]


Epoch 8 Mean Loss: 0.1524


Epoch 9/14: 100%|██████████| 56/56 [00:25<00:00,  2.21it/s, Loss=0.0758]


Epoch 9 Mean Loss: 0.1304


Epoch 10/14: 100%|██████████| 56/56 [00:25<00:00,  2.21it/s, Loss=0.136]


Epoch 10 Mean Loss: 0.1175


Epoch 11/14: 100%|██████████| 56/56 [00:25<00:00,  2.22it/s, Loss=0.14]


Epoch 11 Mean Loss: 0.1010


Epoch 12/14: 100%|██████████| 56/56 [00:25<00:00,  2.21it/s, Loss=0.0352]


Epoch 12 Mean Loss: 0.0924


Epoch 13/14: 100%|██████████| 56/56 [00:25<00:00,  2.22it/s, Loss=0.00326]


Epoch 13 Mean Loss: 0.0879


Epoch 14/14: 100%|██████████| 56/56 [00:25<00:00,  2.22it/s, Loss=0.00206]


Epoch 14 Mean Loss: 0.0872

--- Saving Model ---
Saved to contriever_finetuned_NEW_20k
Packaged fine-tuned model to contriever_finetuned_NEW_20k.zip


torch.autograd.grad_mode.set_grad_enabled(mode=False)

In [None]:
len(combined_triplets)


1782

In [None]:
xae = 4
print('Question:',my_triplets[xae][0])
print('==================================================================================================')
print('Negative:',my_triplets[xae][1])
print('--------------------------------------------------------------------------------------------------')
print('Positive:',my_triplets[xae][2])

Question: What year was the first African American History Week?
Negative: “Those who have no record of what their forebears have accomplished lose the inspiration which comes from the teaching of biography and history.” -Carter G. Woodson
February is Black History Month, also known as African American History Month, and is a time to celebrate the achievements and acknowledge the inspiration of African Americans in the United States. Around the world; countries such as Canada and those in the United Kingdom also dedicate a month to celebrate African/African American history. Unlike those countries, the U.S. began celebrating Black History Month in 1979. The Association for the Study of African American Life and History (ASALH) was founded by University of Chicago graduate Carter G. Woodson in 1915, and was the first step to creating such an important commemoration. The ASALH initiated the first African American History Week in February, 1926. This week was selected because it included 

## 4. Load Contriever Models (Baseline & Time-Aware)


In [None]:
def load_contriever(name: str):
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModel.from_pretrained(name)
    mdl.to(DEVICE)
    mdl.eval()
    return mdl, tok

base_model, base_tokenizer = load_contriever(CONTRIEVER_BASE)
time_model, time_tokenizer = load_contriever(CONTRIEVER_TIMEAWARE)

# Quick sanity check
with torch.no_grad():
    dummy = ["who won the 1900 election?"]
    toks = base_tokenizer(dummy, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
    outputs = base_model(**toks)
    pooled = (outputs.last_hidden_state * toks.attention_mask.unsqueeze(-1)).sum(1)
    pooled = pooled / toks.attention_mask.sum(1, keepdim=True)
    print("Loaded baseline embedding dim:", pooled.shape[-1])



tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 75ea8b57-ddba-468d-a726-cda931986a1e)')' thrown while requesting HEAD https://huggingface.co/facebook/contriever-msmarco/resolve/main/added_tokens.json
Retrying in 1s [Retry 1/5].


special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Loaded baseline embedding dim: 768


In [None]:
def _pick_field(ex, candidates):
    for f in candidates:
        if f in ex and ex[f] is not None:
            txt = str(ex[f]).strip()
            if txt:
                return txt
    return None

In [None]:
def mean_pooling(last_hidden_state, attention_mask):
    masked = last_hidden_state * attention_mask.unsqueeze(-1)
    summed = masked.sum(dim=1)
    counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
    return summed / counts


def encode_texts(model, tokenizer, texts: List[str], batch_size: int = ENCODING_BATCH_SIZE, max_len: int = MAX_LENGTH):
    all_vecs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Encoding", leave=False):
        batch = texts[i:i + batch_size]
        toks = tokenizer(batch, padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            out = model(**toks)
            pooled = mean_pooling(out.last_hidden_state, toks.attention_mask)
            pooled = torch.nn.functional.normalize(pooled, dim=-1)
        all_vecs.append(pooled.cpu())
    if not all_vecs:
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)
    return torch.cat(all_vecs, dim=0).numpy().astype(np.float32)


def build_faiss_index(model, tokenizer, passages: List[str]):
    print("Building FAISS index (IP) with", len(passages), "passages")
    embs = encode_texts(model, tokenizer, passages)
    dim = embs.shape[1]
    index = faiss.IndexIDMap2(faiss.IndexFlatIP(dim))
    ids = np.arange(len(passages))
    index.add_with_ids(embs, ids)
    return index, dim


def retrieve_candidates_caqa(index, model, tokenizer, question_texts: List[str], top_k: int = RETRIEVE_TOPK):
    q_embs = encode_texts(model, tokenizer, question_texts)
    scores, ids = index.search(q_embs, top_k)
    return scores, ids


In [None]:
def _hit_at_k(ranked_ids, gold_id, k):
    return 1.0 if gold_id in ranked_ids[:k] else 0.0

def _mrr_at_k(ranked_ids, gold_id, k):
    for rank, pid in enumerate(ranked_ids[:k], start=1):
        if pid == gold_id:
            return 1.0 / rank
    return 0.0

## MRAG Version 1 TC Score:

In [None]:
import re
from dataclasses import dataclass
from typing import List, Optional

# -------- Temporal utils --------

@dataclass
class TemporalConstraintV1:
    """
    Very lightweight temporal constraint representation.
    type:
      - 'point' : a specific year (e.g. 'in 2019', 'as of 2019')
      - 'range' : between two years (e.g. 'between 2010 and 2015')
      # (You can extend with 'before', 'after', etc. later.)
    """
    type: str
    year: Optional[int] = None
    start_year: Optional[int] = None
    end_year: Optional[int] = None

# Regexes for temporal expressions
TEMP_RANGE_RE = re.compile(
    r"(between|from)\s+(\d{4})\s+(and|to|-)\s+(\d{4})",
    flags=re.IGNORECASE,
)
TEMP_POINT_RE = re.compile(
    r"(as of|in|on|around|by)\s+(\d{4})",
    flags=re.IGNORECASE,
)
YEAR_ANY_RE = re.compile(
    r"\b(1[5-9]\d{2}|20\d{2}|2100)\b"  # 1500–2100, adjust if you want
)

def decompose_question_temporal_v1(question_text: str):
    """
    Heuristic decomposition into:
      - main content (MC) text with temporal phrase removed
      - TemporalConstraint object (or None if no year)
    """

    q = question_text.strip()

    # 1) Range like 'between 2010 and 2015'
    m = TEMP_RANGE_RE.search(q)
    if m:
        y1, y2 = int(m.group(2)), int(m.group(4))
        start, end = min(y1, y2), max(y1, y2)
        tc = TemporalConstraintV1(type="range", start_year=start, end_year=end)

        mc = (q[:m.start()] + " " + q[m.end():]).strip()
        mc = re.sub(r"\s+", " ", mc)
        if not mc:
            mc = q
        return mc, tc

    # 2) Point constraint: 'as of 2019', 'in 2019', etc.
    m = TEMP_POINT_RE.search(q)
    if m:
        y = int(m.group(2))
        tc = TemporalConstraintV1(type="point", year=y)

        mc = (q[:m.start()] + " " + q[m.end():]).strip()
        mc = re.sub(r"\s+", " ", mc)
        if not mc:
            mc = q
        return mc, tc

    # 3) Fallback: first standalone 4-digit year
    m = YEAR_ANY_RE.search(q)
    if m:
        y = int(m.group(1))
        tc = TemporalConstraintV1(type="point", year=y)

        mc = (q[:m.start()] + " " + q[m.end():]).strip()
        mc = re.sub(r"\s+", " ", mc)
        if not mc:
            mc = q
        return mc, tc

    # 4) No temporal constraint
    return q, None

YEAR_PATTERN = re.compile(r"\b(1[5-9]\d{2}|20\d{2}|2100)\b")

DOC_YEARS_CACHE = {}

def extract_years_from_text(text: str) -> List[int]:
    if not text:
        return []
    years = [int(m.group(1)) for m in YEAR_PATTERN.finditer(text)]
    return sorted(set(years))

def get_doc_years(doc_id: int, doc_text: str) -> List[int]:
    if doc_id in DOC_YEARS_CACHE:
        return DOC_YEARS_CACHE[doc_id]
    years = extract_years_from_text(doc_text)
    DOC_YEARS_CACHE[doc_id] = years
    return years

def compute_temporal_score(
    tc: Optional[TemporalConstraintV1],
    doc_years: List[int],
    max_span: int = 20,
    neutral_if_missing: float = 0.5,
) -> float:
    """
    Returns a temporal relevance score in [0,1].
    - tc is TemporalConstraint or None
    - doc_years is a list of years found in the passage (or metadata)
    """

    if tc is None:
        # No explicit temporal constraint => treat as temporally neutral
        return 1.0

    if not doc_years:
        # No years in the doc => partial credit but penalized
        return neutral_if_missing

    # Helper: triangular kernel over distance
    def triangular(distance: float) -> float:
        # 1 at distance=0, linearly decays to 0 at distance >= max_span
        return max(0.0, 1.0 - (distance / float(max_span)))

    if tc.type == "point" and tc.year is not None:
        y0 = tc.year
        # Closest year in the doc
        diff = min(abs(y - y0) for y in doc_years)
        return triangular(diff)

    if tc.type == "range" and tc.start_year is not None and tc.end_year is not None:
        start, end = tc.start_year, tc.end_year
        # If any year lies inside [start, end], perfect temporal match
        if any(start <= y <= end for y in doc_years):
            return 1.0

        # Else use distance to nearest edge of the range
        distances = []
        for y in doc_years:
            if y < start:
                distances.append(start - y)
            elif y > end:
                distances.append(y - end)
        if not distances:
            return neutral_if_missing
        diff = min(distances)
        return triangular(diff)

    # Fallback: unknown type => neutral
    return 1.0

In [None]:
def mrag_rerank_1(
    question_text: str,
    candidate_passages: List[str],
    candidate_ids: List[int],
    model,
    tokenizer,
    base_scores: np.ndarray = None,
    use_llm_summaries: bool = False,  # still ignored in this fast version
    blend_weight: float = 0.0,
    temporal_weight: float = 1.0,     # how strongly to apply temporal scoring
    **kwargs,
):
    """
    MRAG-style reranking:
      1) Decompose question into main content (MC) and temporal constraint (TC).
      2) Semantic score: dense MaxSim over windows using MC.
      3) Temporal score: from TC + doc years.
      4) Hybrid score = semantic_score * (temporal_weight * temporal_score + (1 - temporal_weight)).

    When temporal_weight=1.0, this reduces to semantic_score * temporal_score.
    When temporal_weight=0.0, it collapses back to pure semantic_score.
    """

    if not candidate_passages:
        return [], []

    # --- 1. Decompose question into MC + TC ---
    mc_text, tc = decompose_question_temporal_v1(question_text)

    # --- 2. Pre-tokenize windows if needed (same as your original) ---
    # Assuming _get_doc_id and PRETOKENIZED_WINDOWS / pretokenize_passages exist already
    first_doc_id = _get_doc_id(candidate_ids[0])
    if first_doc_id not in PRETOKENIZED_WINDOWS:
        pretokenize_passages(candidate_passages, candidate_ids)

    # --- 3. Dense MaxSim over windows using MC instead of full question ---
    granular_scores = get_dense_maxsim_scores(
        mc_text, candidate_ids, model, tokenizer
    )
    granular_scores = np.array(granular_scores[: len(candidate_ids)], dtype=np.float32)

    # --- 4. Normalize base scores and granular scores to [0,1] ---
    base_scores_norm = np.zeros(len(candidate_ids), dtype=np.float32)
    if base_scores is not None and len(base_scores) > 0:
        bs = np.array(base_scores[: len(candidate_ids)], dtype=np.float32)
        if bs.max() > bs.min():
            base_scores_norm = (bs - bs.min()) / (bs.max() - bs.min())
        else:
            base_scores_norm = bs

    if granular_scores.max() > granular_scores.min():
        granular_scores_norm = (granular_scores - granular_scores.min()) / (
            granular_scores.max() - granular_scores.min()
        )
    else:
        granular_scores_norm = granular_scores

    # --- 5. Precompute doc years for each candidate ---
    doc_years_list: List[List[int]] = []
    for cid, text in zip(candidate_ids, candidate_passages):
        doc_years_list.append(get_doc_years(cid, text))

    # --- 6. Combine semantic + temporal into hybrid scores ---
    final_scores = {}
    for i, cid in enumerate(candidate_ids):
        # Semantic score: blend base index score and dense MaxSim
        semantic_score = (
            blend_weight * float(base_scores_norm[i])
            + (1.0 - blend_weight) * float(granular_scores_norm[i])
        )

        # Temporal score in [0,1]
        t_score = compute_temporal_score(tc, doc_years_list[i]) if tc else 1.0

        # Hybrid MRAG-style score
        # If temporal_weight=1.0 -> purely multiplicative semantic * t_score
        # If temporal_weight<1.0 -> soften temporal influence
        hybrid_temporal_factor = (
            temporal_weight * t_score + (1.0 - temporal_weight)
        )
        hybrid_score = semantic_score * hybrid_temporal_factor

        '''print('-----------')
        print('semantic score:',semantic_score)
        print('hybrid_temporal_factor:',hybrid_temporal_factor)
        print('hybrid_score:',hybrid_score)
        print('-----------')'''

        final_scores[cid] = hybrid_score

    # --- 7. Sort and return ---
    ranked_ids = sorted(final_scores, key=final_scores.get, reverse=True)
    ranked_scores = [final_scores[cid] for cid in ranked_ids]

    return ranked_ids, ranked_scores

## 7A. In-domain FineWeb Test (20% holdout)
Evaluate both Contriever variants on the held-out synthetic temporal QG split that was reserved from FineWeb (the `test_set`). This keeps the evaluation in-distribution before moving to CAQA.

In [None]:
# --- FineWeb 20% test set evaluation (baseline vs time-aware) ---

import os, shutil, numpy as np, faiss

def evaluate_fineweb_test(index, model, tokenizer, test_pairs, desc, k_list=EVAL_KS):
    questions = [q for q, _, _ in test_pairs]
    gold_ids = [pid for _, _, pid in test_pairs]

    q_embs = encode_contriever(
        model,
        tokenizer,
        questions,
        max_len=MAX_LENGTH,
        batch=ENCODING_BATCH_SIZE,
    )
    scores, ids = index.search(q_embs, max(k_list))

    metrics = {f"hit@{k}": 0.0 for k in k_list}
    metrics.update({f"mrr@{k}": 0.0 for k in k_list})

    for ranked, gold in zip(ids, gold_ids):
        for k in k_list:
            metrics[f"hit@{k}"] += _hit_at_k(ranked, gold, k)
            metrics[f"mrr@{k}"] += _mrr_at_k(ranked, gold, k)

    total = float(len(gold_ids)) if gold_ids else 1.0
    metrics = {k: v / total for k, v in metrics.items()}

    print(f"[EVAL] {desc} | q={len(questions)} | ks={k_list}")
    print(metrics)
    RESULTS.append({
        "Model": desc,
        "Split": "fineweb_test",
        **{k.upper(): v for k, v in metrics.items()},
    })
    return metrics

def build_faiss_index_with_ids(model, tokenizer, passages_list, passage_ids_list, out_dir, index_path, max_len=MAX_LENGTH):
    """ID-preserving FAISS builder for FineWeb test eval."""
    shutil.rmtree(out_dir, ignore_errors=True)
    os.makedirs(out_dir, exist_ok=True)

    embs = encode_contriever(
        model,
        tokenizer,
        passages_list,
        max_len=max_len,
        batch=ENCODING_BATCH_SIZE,
    )
    dim = embs.shape[1]
    index = faiss.IndexIDMap2(faiss.IndexFlatIP(dim))
    ids = np.array(passage_ids_list, dtype=np.int64)
    index.add_with_ids(embs, ids)
    faiss.write_index(index, index_path)
    print(f"Built index at {index_path} with {index.ntotal} vectors")
    return index

print("\n--- Evaluating on held-out FineWeb temporal test set (20%) ---")

FW_TEST_BASE_DIR = "contriever_fineweb_eval_base"
FW_TEST_TIME_DIR = "contriever_fineweb_eval_timeaware"
FW_TEST_BASE_PATH = os.path.join(FW_TEST_BASE_DIR, "fineweb_eval_base.index")
FW_TEST_TIME_PATH = os.path.join(FW_TEST_TIME_DIR, "fineweb_eval_time.index")

fineweb_base_index = build_faiss_index_with_ids(
    base_model,
    base_tokenizer,
    corpus_passages_list,
    corpus_passage_ids_list,
    FW_TEST_BASE_DIR,
    FW_TEST_BASE_PATH,
    max_len=MAX_LENGTH,
)

fineweb_time_index = build_faiss_index_with_ids(
    time_model,
    time_tokenizer,
    corpus_passages_list,
    corpus_passage_ids_list,
    FW_TEST_TIME_DIR,
    FW_TEST_TIME_PATH,
    max_len=MAX_LENGTH,
)

metrics_fineweb_base = evaluate_fineweb_test(
    fineweb_base_index,
    base_model,
    base_tokenizer,
    test_set,
    desc="FineWeb 20% test [BASE]",
    k_list=EVAL_KS,
)

metrics_fineweb_time = evaluate_fineweb_test(
    fineweb_time_index,
    time_model,
    time_tokenizer,
    test_set,
    desc="FineWeb 20% test [TIME-AWARE]",
    k_list=EVAL_KS,
)

In [None]:
def extract_mc(question: str) -> str:
    mc_text, tc = decompose_question_temporal_v1(question)
    return mc_text

## 8. MRAG Re-Ranking Logic (ported from MRAG Baseline Implementation)


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import nltk
import re
from typing import List, Tuple

# --- Configuration ---
MRAG_BASE_WEIGHT_DEFAULT = 0.5  # Adjust this: 0.8 often works better if the Base retriever is strong
WINDOW_SIZE = 3                 # Number of sentences per window (preserves context)
WINDOW_STRIDE = 1               # Overlap (1 means we shift by 1 sentence)

# Ensure NLTK resources
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)

# Global Cache to prevent re-tokenizing every call
PRETOKENIZED_WINDOWS = {}
MRAG_BASE_WEIGHT_DEFAULT = 0.0

def _get_doc_id(pid):
    return str(pid)

def mean_pooling(last_hidden_state, attention_mask):
    """Standard mean pooling for Contriever/BERT models."""
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def pretokenize_passages(passages: List[str], ids: List[int]):
    """
    Splits passages into overlapping WINDOWS (e.g., 3 sentences).
    This creates 'Granular Evidence' that retains context.
    """
    global PRETOKENIZED_WINDOWS
    PRETOKENIZED_WINDOWS = {}

    for pid, text in zip(ids, passages):
        # 1. Split into sentences
        try:
            snts = nltk.sent_tokenize(text or "")
        except Exception:
            snts = re.split(r"(?<=[.!?])\s+", text or "")
        snts = [s.strip() for s in snts if s.strip()]

        # 2. Create Sliding Windows (Contextual Granularity)
        windows = []
        if not snts:
            windows = [""]
        else:
            # If doc is shorter than window, take whole doc
            if len(snts) <= WINDOW_SIZE:
                windows.append(" ".join(snts))
            else:
                # Sliding window
                for i in range(0, len(snts) - WINDOW_SIZE + 1, WINDOW_STRIDE):
                    window_text = " ".join(snts[i : i + WINDOW_SIZE])
                    windows.append(window_text)

        PRETOKENIZED_WINDOWS[_get_doc_id(pid)] = windows

@torch.no_grad()
def get_dense_maxsim_scores(
    question_text: str,
    candidate_ids: List[int],
    model,
    tokenizer,
    batch_size: int = 128
) -> List[float]:
    """
    Computes the MaxSim score using Dense Embeddings over Sliding Windows.
    Returns the maximum similarity found in any window of the document.
    """
    model.eval()

    # 1. Flatten all windows from all candidates
    all_windows = []

    for cid in candidate_ids:
        windows = PRETOKENIZED_WINDOWS.get(_get_doc_id(cid), [])
        if not windows:
            windows = [""]
        all_windows.extend(windows)

    # 2. Encode Question
    q_tok = tokenizer([question_text], padding=True, truncation=True, max_length=128, return_tensors="pt").to(model.device)
    q_out = model(**q_tok)
    q_emb = mean_pooling(q_out.last_hidden_state, q_tok['attention_mask'])
    q_emb = F.normalize(q_emb, p=2, dim=1) # (1, H)

    # 3. Encode Windows (Batched)
    all_window_embs = []
    # Batch processing to avoid OOM
    for i in range(0, len(all_windows), batch_size):
        batch_text = all_windows[i : i + batch_size]
        tok = tokenizer(batch_text, padding=True, truncation=True, max_length=128, return_tensors="pt").to(model.device)
        out = model(**tok)
        emb = mean_pooling(out.last_hidden_state, tok['attention_mask'])
        emb = F.normalize(emb, p=2, dim=1)
        all_window_embs.append(emb)

    if not all_window_embs:
        return [0.0] * len(candidate_ids)

    all_window_embs = torch.cat(all_window_embs, dim=0) # (Total_Windows, H)

    # 4. Compute Similarities (Vectorized)
    # (Total_Windows, H) @ (H, 1) -> (Total_Windows, 1)
    sims = torch.mm(all_window_embs, q_emb.T).squeeze(1)
    sims_cpu = sims.cpu().numpy()

    # 5. Map back to Documents (Max-Pooling)
    doc_max_scores = np.zeros(len(candidate_ids), dtype=np.float32)
    current_offset = 0

    for i, cid in enumerate(candidate_ids):
        windows = PRETOKENIZED_WINDOWS.get(_get_doc_id(cid), [])
        count = len(windows) if windows else 1

        # Get slice of scores for this document's windows
        chunk = sims_cpu[current_offset : current_offset + count]
        if chunk.size > 0:
            doc_max_scores[i] = float(chunk.max())
        else:
            doc_max_scores[i] = 0.0

        current_offset += count

    return doc_max_scores.tolist()

## 11. TempRAGEval + ATLAS 2021 Corpus (Optional)
This section mirrors the baseline notebook flow to evaluate on TempRAGEval using your custom ATLAS 2021 Wikipedia corpus. Set the paths below before running.


### MRAG Version 2 TC Score: True to MRAG and TempRAGEval

In [None]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class TemporalConstraintV2:
    kind: str          # 'first_before', 'last_before', 'first_after', 'last_after',
                       # 'first_between', 'last_between'
    t0: Optional[int]  # main boundary (e.g., 1981)
    t1: Optional[int]  # second boundary when 'between' (e.g., 1999)

import re

YEAR_RE = re.compile(r"\b(1[0-9]{3}|20[0-9]{2})\b")

def parse_temporal_constraint(tc: str) -> Optional[TemporalConstraintV2]:
    tc_lower = tc.lower()

    years = [int(y) for y in YEAR_RE.findall(tc_lower)]
    if not years:
        return None

    # If two years mentioned → likely "between"
    if len(years) >= 2:
        y0, y1 = sorted(years[:2])
        if any(w in tc_lower for w in ["earliest", "first", "earlier", "from"]):
            return TemporalConstraintV2("first_between", y0, y1)
        if any(w in tc_lower for w in ["latest", "last", "as of", "until", "up to"]):
            return TemporalConstraintV2("last_between", y0, y1)
        # fallback
        return TemporalConstraintV2("last_between", y0, y1)

    # Single year
    y = years[0]
    is_firstish = any(w in tc_lower for w in ["earliest", "first", "initial"])
    is_lastish  = any(w in tc_lower for w in ["latest", "last", "most recent", "as of"])
    has_before  = "before" in tc_lower
    has_after   = "after" in tc_lower or "since" in tc_lower

    if is_firstish and has_before:
        return TemporalConstraintV2("first_before", y, None)
    if is_lastish and has_before:
        return TemporalConstraintV2("last_before", y, None)
    if is_firstish and (has_after or "after" in tc_lower):
        return TemporalConstraintV2("first_after", y, None)
    if is_lastish and (has_after or "after" in tc_lower):
        return TemporalConstraintV2("last_after", y, None)

    # “as of 2021” → usually "last_before"
    if "as of" in tc_lower or "by " in tc_lower:
        return TemporalConstraintV2("last_before", y, None)

    # fallback: treat as "last_before"
    return TemporalConstraintV2("last_before", y, None)

def extract_years(text: str):
    return [int(y) for y in YEAR_RE.findall(text)]


def clamp01(x: float) -> float:
    return max(0.0, min(1.0, x))

def score_last_before(t: int, T: int, width_past: int = 80, width_future: int = 20) -> float:
    # baseline floor
    floor = 0.2
    # from (T - width_past) → T, linearly from floor → 1
    if t <= T:
        if t <= T - width_past:
            return floor
        return floor + (1.0 - floor) * (t - (T - width_past)) / width_past
    # after T, decay from 1 → ~0.4 over width_future
    if t >= T + width_future:
        return 0.4
    return 1.0 - (1.0 - 0.4) * (t - T) / width_future

def score_first_before(t: int, T: int, width_past: int = 80) -> float:
    # prefer earlier dates; linear decay towards T
    if t >= T:
        return 0.2
    if t <= T - width_past:
        return 1.0
    # from (T - width_past) → T, 1 → 0.2
    return 1.0 - (1.0 - 0.2) * (t - (T - width_past)) / width_past

def score_last_after(t: int, T: int, width_future: int = 80) -> float:
    # before T: small
    if t <= T:
        return 0.2
    if t >= T + width_future:
        return 0.4
    # from T → T+width_future, 0.6 → 1 then back to 0.4-ish
    # simple hill shape: peak at T + width_future/2
    mid = T + width_future / 2.0
    if t <= mid:
        return 0.6 + (1.0 - 0.6) * (t - T) / (mid - T)
    else:
        return 1.0 - (1.0 - 0.4) * (t - mid) / (T + width_future - mid)

def score_first_after(t: int, T: int, width_future: int = 80) -> float:
    # best just after T, then decay
    if t <= T:
        return 0.2
    if t >= T + width_future:
        return 0.2
    # from T → T+width_future, 1 → 0.2
    return 1.0 - (1.0 - 0.2) * (t - T) / width_future

def score_last_between(t: int, a: int, b: int) -> float:
    if t < a:
        return 0.2
    if t > b + 40:
        return 0.4
    if t > b:
        # slight decay after b
        return 1.0 - (1.0 - 0.4) * (t - b) / 40.0
    # inside [a, b], ramp up from 0.2 at a to 1 at b
    return 0.2 + 0.8 * (t - a) / (b - a + 1e-6)

def score_first_between(t: int, a: int, b: int) -> float:
    if t < a - 40:
        return 0.2
    if t < a:
        # slight ramp up as you approach a
        return 0.2 + 0.8 * (t - (a - 40)) / 40.0
    if t > b:
        return 0.2
    # inside [a, b], ramp down from 1 at a to 0.2 at b
    return 1.0 - (1.0 - 0.2) * (t - a) / (b - a + 1e-6)

def temporal_score_mrag_style(tc: TemporalConstraintV2, years_in_sent):
    if not years_in_sent or tc is None:
        return 0.2  # small floor

    scores = []
    for y in years_in_sent:
        if tc.kind == "last_before":
            scores.append(score_last_before(y, tc.t0))
        elif tc.kind == "first_before":
            scores.append(score_first_before(y, tc.t0))
        elif tc.kind == "last_after":
            scores.append(score_last_after(y, tc.t0))
        elif tc.kind == "first_after":
            scores.append(score_first_after(y, tc.t0))
        elif tc.kind == "last_between":
            scores.append(score_last_between(y, tc.t0, tc.t1))
        elif tc.kind == "first_between":
            scores.append(score_first_between(y, tc.t0, tc.t1))
    return max(scores) if scores else 0.2

def decompose_question_temporal_v2(question: str):
    """
    Decompose a question into:
      - mc_text: main content (question with years stripped out)
      - tc_struct: TemporalConstraint(kind, t0, t1) or None

    This is wired to your *current* TemporalConstraint dataclass and
    parse_temporal_constraint() implementation.
    """
    # 1) Build MC text by removing years from the question
    mc_text = YEAR_RE.sub("", question)
    mc_text = re.sub(r"\s+", " ", mc_text).strip()
    if not mc_text:
        mc_text = question

    # 2) Parse temporal constraint from the full question text
    tc_struct = parse_temporal_constraint(question)  # may return None

    return mc_text, tc_struct

def v1_to_v2_temporal_constraint(tc_v1: Optional[TemporalConstraintV1]) -> Optional[TemporalConstraintV2]:
    """
    Convert from V1's simple 'point'/'range' TC to a V2 MRAG-style constraint.
    Heuristic:
      - point(year)   -> 'last_before' year  (as-of / before)
      - range(a, b)   -> 'last_between' [a,b]
    You can refine this later with 'first_*' variants if you want.
    """
    if tc_v1 is None:
        return None

    if tc_v1.type == "point" and tc_v1.year is not None:
        return TemporalConstraintV2(kind="last_before", t0=tc_v1.year, t1=None)

    if tc_v1.type == "range" and tc_v1.start_year is not None and tc_v1.end_year is not None:
        return TemporalConstraintV2(
            kind="last_between",
            t0=min(tc_v1.start_year, tc_v1.end_year),
            t1=max(tc_v1.start_year, tc_v1.end_year),
        )

    # Fallback: treat unknown as 'last_before' on a single year if available
    if tc_v1.year is not None:
        return TemporalConstraintV2(kind="last_before", t0=tc_v1.year, t1=None)

    return None

In [None]:
def mrag_rerank_2(
    question_text: str,
    candidate_passages: List[str],
    candidate_ids: List[int],
    model,
    tokenizer,
    base_scores: np.ndarray = None,
    use_llm_summaries: bool = False,  # still ignored in this fast version
    blend_weight: float = 0.0,
    temporal_weight: float = 1.0,     # how strongly to apply temporal scoring
    **kwargs,
):
    """
    MRAG-style reranking with:
      - Semantic part matching Version 1's logic:
          * MC from V1 decompose_question_temporal()
          * Dense MaxSim over windows using MC
          * Blend with base FAISS scores using blend_weight
      - Temporal part using MRAG-style temporal_score_mrag_style() on a
        V2 TemporalConstraint converted from V1's TC.
    """

    if not candidate_passages:
        return [], []

    # --- 1. Decompose question into MC + V1-style TC -------------------
    mc_text, tc_v1 = decompose_question_temporal_v1(question_text)

    # Convert to V2 MRAG-style TC for temporal scoring
    tc_v2 = v1_to_v2_temporal_constraint(tc_v1)

    # --- 2. Pre-tokenize windows if needed (same as original) ----------
    first_doc_id = _get_doc_id(candidate_ids[0])
    if first_doc_id not in PRETOKENIZED_WINDOWS:
        pretokenize_passages(candidate_passages, candidate_ids)

    # --- 3. Dense MaxSim over windows using MC (V1 behavior) ----------
    granular_scores = get_dense_maxsim_scores(
        mc_text, candidate_ids, model, tokenizer
    )
    granular_scores = np.array(granular_scores[: len(candidate_ids)], dtype=np.float32)

    # --- 4. Normalize base scores and granular scores to [0,1] ---------
    base_scores_norm = np.zeros(len(candidate_ids), dtype=np.float32)
    if base_scores is not None and len(base_scores) > 0:
        bs = np.array(base_scores[: len(candidate_ids)], dtype=np.float32)
        if bs.max() > bs.min():
            base_scores_norm = (bs - bs.min()) / (bs.max() - bs.min())
        else:
            base_scores_norm = bs

    if granular_scores.max() > granular_scores.min():
        granular_scores_norm = (granular_scores - granular_scores.min()) / (
            granular_scores.max() - granular_scores.min()
        )
    else:
        granular_scores_norm = granular_scores

    # --- 5. Precompute doc years for each candidate passage ------------
    doc_years_list: List[List[int]] = []
    for cid, text in zip(candidate_ids, candidate_passages):
        years = get_doc_years(cid, text)  # V1 behavior
        doc_years_list.append(years)

    # --- 6. Combine semantic + MRAG-style temporal into hybrid scores --
    final_scores = {}
    for i, cid in enumerate(candidate_ids):
        # Semantic score: SAME logic as Version 1
        semantic_score = (
            blend_weight * float(base_scores_norm[i])
            + (1.0 - blend_weight) * float(granular_scores_norm[i])
        )

        # Temporal score in [0,1], MRAG-style:
        if tc_v2 is not None and temporal_weight > 0.0:
            t_score_raw = temporal_score_mrag_style(tc_v2, doc_years_list[i])
        else:
            t_score_raw = 1.0  # neutral if no temporal constraint

        hybrid_temporal_factor = (
            temporal_weight * t_score_raw + (1.0 - temporal_weight)
        )
        hybrid_score = semantic_score * hybrid_temporal_factor

        '''print('=====================')
        print('Semantic score::',semantic_score)
        print('Hybrid_temporal_factor::',hybrid_temporal_factor)
        print('Hybrid_score::',hybrid_score)
        print('=====================')'''

        final_scores[cid] = hybrid_score

    # --- 7. Sort and return -------------------------------------------
    ranked_ids = sorted(final_scores, key=final_scores.get, reverse=True)
    ranked_scores = [final_scores[cid] for cid in ranked_ids]

    return ranked_ids, ranked_scores


### Loading in TempRAGEval and Custom Corpus and Testing:

In [None]:
# =========================
# TempRAGEval + Custom Corpus (MRAG v2)
# =========================

# Paths for TempRAGEval + ATLAS corpus
TEMPRAGEVAL_DATA_PATH = None  # if None, use HF hub "siyue/TempRAGEval"; else point to local json/jsonl
ATLAS_CORPUS_FILES = [
    "/content/infobox.jsonl",
    "/content/text-list-100-sec.jsonl"
]  # list of json/jsonl files for ATLAS 2021 Wikipedia dump (your custom corpus)

DOC_ID_FIELD = "id"     # adjust if your corpus uses a different field
TITLE_FIELD = "title"
TEXT_FIELD = "text"

# Retrieval settings for TempRAGEval
TEMPRAGEVAL_TOPK = 100
TEMPRAGEVAL_KS = (1, 5, 10, 20)

MAX_TEMPRAGE_QUESTIONS = 100  # change to None or bigger number later

from nltk.tokenize import word_tokenize
from datasets import load_dataset
from tqdm.auto import tqdm
import faiss
import json
import numpy as np

ANSWER_TOKENIZER = None

# ---------- Utility: MC extraction (V1 decomposition) ----------

def extract_mc_temprage(q: str) -> str:
    """
    Use your existing decompose_question_temporal_v1() to get MC-only text
    for retrieval.
    """
    mc_text, _ = decompose_question_temporal_v1(q)
    return mc_text

# ---------- Loading corpus & dataset ----------

def _load_json_or_jsonl(path):
    data = []
    if path.endswith('.jsonl') or path.endswith('.jsonl.gz'):
        import gzip
        opener = gzip.open if path.endswith('.gz') else open
        print(f"[TempRAGE] Loading JSONL file: {path}")
        with opener(path, 'rt', encoding='utf-8') as f:
            for line in tqdm(f, desc=f"Reading {path}", unit="lines"):
                line = line.strip()
                if not line:
                    continue
                data.append(json.loads(line))
    else:
        print(f"[TempRAGE] Loading JSON file: {path}")
        with open(path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
            if isinstance(loaded, dict):
                loaded = loaded.get('data', loaded.get('docs', loaded.get('corpus', [])))
            data = loaded
    print(f"[TempRAGE] Loaded {len(data):,} records from {path}")
    return data


def load_atlas_corpus(file_list, doc_id_field=DOC_ID_FIELD, title_field=TITLE_FIELD, text_field=TEXT_FIELD):
    corpus = []
    for fp in file_list:
        print(f"[TempRAGE] Loading corpus file: {fp}")
        records = _load_json_or_jsonl(fp)
        for rec in tqdm(records, desc=f"Indexing records from {fp}", unit="docs"):
            orig_id = rec.get(doc_id_field)
            title = rec.get(title_field, "") or ""
            text = rec.get(text_field, "") or ""
            doc_num = len(corpus)
            doc_id_str = str(doc_num)
            corpus.append({
                "id": doc_id_str,
                "orig_id": orig_id,
                "title": title,
                "text": text
            })
    print(f"[TempRAGE] Loaded {len(corpus):,} corpus passages total.")
    return corpus


def load_temprageval(path=None, split=None):
    split = split or globals().get("TEMPRAGEVAL_SPLIT", "test")
    if path is None:
        print(f"[TempRAGE] Loading TempRAGEval from HF hub: siyue/TempRAGEval split={split}")
        ds = load_dataset("siyue/TempRAGEval", split=split)
    else:
        print(f"[TempRAGE] Loading TempRAGEval from local: {path}")
        ds = load_dataset('json', data_files=path, split="train")
    questions = []
    answers = []
    print("[TempRAGE] Extracting questions & answers from dataset...")
    for ex in tqdm(ds, desc="Reading TempRAGE examples", unit="ex"):
        q = ex.get('question') or ex.get('query') or ex.get('input') or ""
        ans_list = ex.get('answers') or ex.get('answer') or []
        if isinstance(ans_list, str):
            ans_list = [ans_list]
        questions.append(q)
        answers.append(ans_list)
    print(f"[TempRAGE] Loaded {len(questions)} TempRAGEval questions")
    return questions, answers

# ---------- Answer match / qrels ----------

def _get_answer_tokenizer():
    global ANSWER_TOKENIZER
    if ANSWER_TOKENIZER is None:
        from transformers import AutoTokenizer
        print(f"[TempRAGE] Initializing answer tokenizer from {CONTRIEVER_BASE}")
        ANSWER_TOKENIZER = AutoTokenizer.from_pretrained(CONTRIEVER_BASE)
    return ANSWER_TOKENIZER


def has_answer(answers, text):
    # Baseline-style substring check; could be replaced with tokenizer-based match
    t = (text or "").lower()
    for a in answers:
        if a and a.lower() in t:
            return True
    return False


def build_temprageval_qrels(corpus, questions, answers):
    # Baseline builds qrels by matching any answer string in any passage text
    print("[TempRAGE] Building TempRAGEval qrels by answer string match...")
    qrels = []
    for qi, ans_list in enumerate(
        tqdm(answers, desc="Building qrels over questions", unit="q")
    ):
        rel = set()
        # simple linear scan; this is heavy, so tqdm is helpful
        for doc in corpus:
            if has_answer(ans_list, doc.get('text', '')):
                rel.add(doc['id'])
        qrels.append(rel)
    evaluable = sum(1 for r in qrels if r)
    print(f"[TempRAGE] Qrels built: {evaluable}/{len(qrels)} questions have at least 1 matching passage")
    return qrels

# ---------- FAISS Building ----------

def build_faiss_for_corpus(model, tokenizer, passages):
    print("[TempRAGE] Encoding corpus for FAISS index...")
    texts = [p.get('title', '') + ' ' + p.get('text', '') for p in passages]
    embs = encode_texts(model, tokenizer, texts)
    dim = embs.shape[1]
    print(f"[TempRAGE] Building FAISS index (dim={dim}, n_docs={len(passages):,})...")
    ids = np.arange(len(passages), dtype=np.int64)  # enforce sequential numeric IDs
    index = faiss.IndexIDMap2(faiss.IndexFlatIP(dim))
    index.add_with_ids(embs, ids)
    print("[TempRAGE] FAISS index ready.")
    return index

# ---------- Eval with MRAG v2 (mrag_rerank_2) + MC/full split ----------

def evaluate_temprageval(
    corpus,
    questions,
    answers,
    index,
    model,
    tokenizer,
    use_mrag=False,
    k_list=TEMPRAGEVAL_KS,
    desc="",
    use_llm_summaries=False,
    top_k_candidates=TEMPRAGEVAL_TOPK,
    retrieval_questions=None,
    full_questions=None,
    qrels=None,
):
    """
    Generalized TempRAGEval evaluator:

    - retrieval_questions: what gets encoded and sent to FAISS
    - full_questions: what gets passed to MRAG (for TC parsing). If None, we
      fall back to retrieval_questions.
    - qrels: optional precomputed qrels; if None, we build them from answers.
    """

    # Defaults
    if retrieval_questions is None:
        retrieval_questions = questions
    if full_questions is None:
        full_questions = questions

    # Build qrels once if not provided
    if qrels is None:
        print("[TempRAGE] No qrels provided, building now...")
        qrels = build_temprageval_qrels(corpus, questions, answers)

    print(
        f"[EVAL-TempRAGE] {desc} | q={len(retrieval_questions)} | ks={k_list} | "
        f"top_k_candidates={top_k_candidates} | use_mrag={use_mrag}"
    )

    # Encode retrieval questions and retrieve
    print("[TempRAGE] Encoding retrieval questions...")
    q_embs = encode_texts(model, tokenizer, retrieval_questions)

    print("[TempRAGE] Running FAISS search for all questions...")
    scores, ids = index.search(q_embs, top_k_candidates)

    metrics = {f"hit@{k}": 0.0 for k in k_list}
    metrics.update({f"mrr@{k}": 0.0 for k in k_list})

    print("[TempRAGE] Scoring / reranking over all questions...")
    for qi, rel_set in enumerate(
        tqdm(qrels, desc=f"Scoring ({desc})", unit="q")
    ):
        if not rel_set:
            continue

        cand_ids = [int(cid) for cid in ids[qi] if cid >= 0]
        cand_scores = scores[qi][: len(cand_ids)] if scores is not None else None

        cand_texts = []
        for cid in cand_ids:
            if 0 <= cid < len(corpus):
                cand_texts.append(corpus[cid]['title'] + ' ' + corpus[cid]['text'])
            else:
                cand_texts.append("")

        # Reranking
        if use_mrag:
            q_for_mrag = full_questions[qi]
            reranked_ids, _ = mrag_rerank_2(
                q_for_mrag,
                cand_texts,
                cand_ids,
                model,
                tokenizer,
                base_scores=cand_scores,
                use_llm_summaries=False,  # summaries disabled
                blend_weight=0.0,         # semantic = granular MaxSim only
                temporal_weight=1.0,      # pure semantic * temporal (tune if needed)
            )
        else:
            reranked_ids = cand_ids

        # Compute metrics
        for k in k_list:
            topk = reranked_ids[:k]
            # standard MRR / Hit@k over a set of relevant doc ids
            found = 0.0
            for rank, cid in enumerate(topk):
                if str(cid) in rel_set or cid in rel_set:
                    found = 1.0 / (rank + 1)
                    break
            if found > 0:
                metrics[f"hit@{k}"] += 1.0
            metrics[f"mrr@{k}"] += found

    total = sum(1 for r in qrels if r)
    total = total if total > 0 else 1
    metrics = {k: v / total for k, v in metrics.items()}
    print(f"[EVAL-TempRAGE] Completed: {desc}")
    print("  " + " | ".join(f"{k}: {v:.4f}" for k, v in metrics.items()))
    RESULTS.append({
        "Model": desc,
        "Split": "temprageval",
        **{k.upper(): v for k, v in metrics.items()}
    })
    return metrics

# ---------- Run TempRAGEval Experiments (4 conditions) ----------

if ATLAS_CORPUS_FILES:
    # 1) Load corpus + pretokenize for MRAG windows
    print("[TempRAGE] Step 1: Loading ATLAS corpus...")
    atlas_corpus = load_atlas_corpus(ATLAS_CORPUS_FILES)

    print("[TempRAGE] Step 2: Pretokenizing corpus for MRAG windows...")
    pretokenize_passages(
        [c['title'] + ' ' + c['text'] for c in atlas_corpus],
        list(range(len(atlas_corpus)))
    )

    # 2) Load TempRAGEval (FULL), then slice for a small test
    print("[TempRAGE] Step 3: Loading TempRAGEval dataset...")
    temprage_questions_full, temprage_answers_full = load_temprageval(TEMPRAGEVAL_DATA_PATH, split="test")

    if MAX_TEMPRAGE_QUESTIONS is not None:
        temprage_questions = temprage_questions_full[:MAX_TEMPRAGE_QUESTIONS]
        temprage_answers = temprage_answers_full[:MAX_TEMPRAGE_QUESTIONS]
    else:
        temprage_questions = temprage_questions_full
        temprage_answers = temprage_answers_full

    print(f"[TempRAGE] Using {len(temprage_questions)} TempRAGEval questions for this run "
          f"(MAX_TEMPRAGE_QUESTIONS={MAX_TEMPRAGE_QUESTIONS})")

    # 3) Build qrels once (for the sliced subset)
    print("[TempRAGE] Step 4: Building qrels for selected questions...")
    temprage_qrels = build_temprageval_qrels(atlas_corpus, temprage_questions, temprage_answers)

    # 4) Build FAISS indexes for BASE & TIME-aware models
    print("[TempRAGE] Step 5: Building FAISS index (BASE Contriever)...")
    atlas_index_base = build_faiss_for_corpus(base_model, base_tokenizer, atlas_corpus)

    print("[TempRAGE] Step 6: Building FAISS index (TIME-AWARE Contriever)...")
    atlas_index_time = build_faiss_for_corpus(time_model, time_tokenizer, atlas_corpus)

    # Prepare MC-only vs full questions (on the subset)
    print("[TempRAGE] Step 7: Extracting MC-only versions of questions...")
    temprage_questions_mc = [extract_mc_temprage(q) for q in tqdm(temprage_questions, desc="MC extraction", unit="q")]

    print("[TempRAGE] Step 8: Running 4 evaluation conditions...")

    # ---- A. BASE Contriever only — MC-only RETRIEVAL ----
    print("\n[A] TempRAGEval [BASE Only, MC-only retrieval]")
    evaluate_temprageval(
        atlas_corpus,
        temprage_questions,
        temprage_answers,
        atlas_index_base,
        base_model,
        base_tokenizer,
        use_mrag=False,
        desc="TempRAGEval [BASE Only, MC-only retrieval]",
        retrieval_questions=temprage_questions_mc,
        full_questions=temprage_questions,   # unused since use_mrag=False
        qrels=temprage_qrels,
    )

    # ---- B. TIME-AWARE Contriever only — FULL question RETRIEVAL ----
    print("\n[B] TempRAGEval [TIME-AWARE Only, full-question retrieval]")
    evaluate_temprageval(
        atlas_corpus,
        temprage_questions,
        temprage_answers,
        atlas_index_time,
        time_model,
        time_tokenizer,
        use_mrag=False,
        desc="TempRAGEval [TIME-AWARE Only, full-question retrieval]",
        retrieval_questions=temprage_questions,  # full question
        full_questions=temprage_questions,
        qrels=temprage_qrels,
    )

    # ---- C. MRAG + BASE — MC-only RETRIEVAL, full question for MRAG ----
    print("\n[C] TempRAGEval [MRAG v2 + BASE, MC-only retrieval]")
    evaluate_temprageval(
        atlas_corpus,
        temprage_questions,
        temprage_answers,
        atlas_index_base,
        base_model,
        base_tokenizer,
        use_mrag=True,
        desc="TempRAGEval [MRAG v2 + BASE, MC-only retrieval]",
        use_llm_summaries=False,
        retrieval_questions=temprage_questions_mc,  # MC-only
        full_questions=temprage_questions,          # full question for temporal parse
        qrels=temprage_qrels,
    )

    # ---- D. MRAG + TIME-AWARE — FULL question RETRIEVAL & MRAG ----
    print("\n[D] TempRAGEval [MRAG v2 + TIME-AWARE, full-question retrieval]")
    evaluate_temprageval(
        atlas_corpus,
        temprage_questions,
        temprage_answers,
        atlas_index_time,
        time_model,
        time_tokenizer,
        use_mrag=True,
        desc="TempRAGEval [MRAG v2 + TIME-AWARE, full-question retrieval]",
        use_llm_summaries=False,
        retrieval_questions=temprage_questions,  # full
        full_questions=temprage_questions,      # full for MRAG as well
        qrels=temprage_qrels,
    )

else:
    print("Skip TempRAGEval: please set ATLAS_CORPUS_FILES to your custom corpus json/jsonl files.")

[TempRAGE] Step 1: Loading ATLAS corpus...
[TempRAGE] Loading corpus file: /content/infobox.jsonl
[TempRAGE] Loading JSONL file: /content/infobox.jsonl


Reading /content/infobox.jsonl: 0lines [00:00, ?lines/s]

[TempRAGE] Loaded 273 records from /content/infobox.jsonl


Indexing records from /content/infobox.jsonl:   0%|          | 0/273 [00:00<?, ?docs/s]

[TempRAGE] Loading corpus file: /content/text-list-100-sec.jsonl
[TempRAGE] Loading JSONL file: /content/text-list-100-sec.jsonl


Reading /content/text-list-100-sec.jsonl: 0lines [00:00, ?lines/s]

[TempRAGE] Loaded 219,667 records from /content/text-list-100-sec.jsonl


Indexing records from /content/text-list-100-sec.jsonl:   0%|          | 0/219667 [00:00<?, ?docs/s]

[TempRAGE] Loaded 219,940 corpus passages total.
[TempRAGE] Step 2: Pretokenizing corpus for MRAG windows...
[TempRAGE] Step 3: Loading TempRAGEval dataset...
[TempRAGE] Loading TempRAGEval from HF hub: siyue/TempRAGEval split=test


README.md:   0%|          | 0.00/2.23k [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/470k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/1244 [00:00<?, ? examples/s]

[TempRAGE] Extracting questions & answers from dataset...


Reading TempRAGE examples:   0%|          | 0/1244 [00:00<?, ?ex/s]

[TempRAGE] Loaded 1244 TempRAGEval questions
[TempRAGE] Using 100 TempRAGEval questions for this run (MAX_TEMPRAGE_QUESTIONS=100)
[TempRAGE] Step 4: Building qrels for selected questions...
[TempRAGE] Building TempRAGEval qrels by answer string match...


Building qrels over questions:   0%|          | 0/100 [00:00<?, ?q/s]

[TempRAGE] Qrels built: 56/100 questions have at least 1 matching passage
[TempRAGE] Step 5: Building FAISS index (BASE Contriever)...
[TempRAGE] Encoding corpus for FAISS index...


Encoding:   0%|          | 0/3437 [00:00<?, ?it/s]

[TempRAGE] Building FAISS index (dim=768, n_docs=219,940)...
[TempRAGE] FAISS index ready.
[TempRAGE] Step 6: Building FAISS index (TIME-AWARE Contriever)...
[TempRAGE] Encoding corpus for FAISS index...


Encoding:   0%|          | 0/3437 [00:00<?, ?it/s]

[TempRAGE] Building FAISS index (dim=768, n_docs=219,940)...
[TempRAGE] FAISS index ready.
[TempRAGE] Step 7: Extracting MC-only versions of questions...


MC extraction:   0%|          | 0/100 [00:00<?, ?q/s]

[TempRAGE] Step 8: Running 4 evaluation conditions...

[A] TempRAGEval [BASE Only, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval [BASE Only, MC-only retrieval] | q=100 | ks=(1, 5, 10, 20) | top_k_candidates=100 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/2 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval [BASE Only, MC-only retrieval]):   0%|          | 0/100 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval [BASE Only, MC-only retrieval]
  hit@1: 0.4107 | hit@5: 0.7500 | hit@10: 0.9821 | hit@20: 1.0000 | mrr@1: 0.4107 | mrr@5: 0.5500 | mrr@10: 0.5853 | mrr@20: 0.5865

[B] TempRAGEval [TIME-AWARE Only, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval [TIME-AWARE Only, full-question retrieval] | q=100 | ks=(1, 5, 10, 20) | top_k_candidates=100 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/2 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval [TIME-AWARE Only, full-question retrieval]):   0%|          | 0/100 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval [TIME-AWARE Only, full-question retrieval]
  hit@1: 0.3214 | hit@5: 0.8214 | hit@10: 0.9464 | hit@20: 0.9821 | mrr@1: 0.3214 | mrr@5: 0.4970 | mrr@10: 0.5131 | mrr@20: 0.5156

[C] TempRAGEval [MRAG v2 + BASE, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval [MRAG v2 + BASE, MC-only retrieval] | q=100 | ks=(1, 5, 10, 20) | top_k_candidates=100 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/2 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval [MRAG v2 + BASE, MC-only retrieval]):   0%|          | 0/100 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval [MRAG v2 + BASE, MC-only retrieval]
  hit@1: 0.4643 | hit@5: 0.8393 | hit@10: 0.8750 | hit@20: 0.8929 | mrr@1: 0.4643 | mrr@5: 0.5985 | mrr@10: 0.6035 | mrr@20: 0.6051

[D] TempRAGEval [MRAG v2 + TIME-AWARE, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval [MRAG v2 + TIME-AWARE, full-question retrieval] | q=100 | ks=(1, 5, 10, 20) | top_k_candidates=100 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/2 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval [MRAG v2 + TIME-AWARE, full-question retrieval]):   0%|          | 0/100 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval [MRAG v2 + TIME-AWARE, full-question retrieval]
  hit@1: 0.5179 | hit@5: 0.8214 | hit@10: 0.8750 | hit@20: 0.9286 | mrr@1: 0.5179 | mrr@5: 0.6348 | mrr@10: 0.6426 | mrr@20: 0.6459


**Big Improvement in Hit@1 metric with TempRAGEval**

### More Testing:

In [None]:
import faiss
import os

SAVE_DIR = "/content/faiss_indexes"  # or "/content/drive/MyDrive/faiss_indexes"
os.makedirs(SAVE_DIR, exist_ok=True)

base_idx_path = os.path.join(SAVE_DIR, "atlas_index_base.faiss")
time_idx_path = os.path.join(SAVE_DIR, "atlas_index_time.faiss")

print(f"Saving BASE index to: {base_idx_path}")
faiss.write_index(atlas_index_base, base_idx_path)

print(f"Saving TIME-AWARE index to: {time_idx_path}")
faiss.write_index(atlas_index_time, time_idx_path)

print("✅ Done saving FAISS indexes.")


Saving BASE index to: /content/faiss_indexes/atlas_index_base.faiss
Saving TIME-AWARE index to: /content/faiss_indexes/atlas_index_time.faiss
✅ Done saving FAISS indexes.


In [None]:
import os
from google.colab import files

# Ensure we are in /content (optional)
os.chdir('/content')

# Zip ONLY the faiss_indexes folder
!zip -r /content/faiss_indexes.zip faiss_indexes

# Download the zip file
files.download('/content/faiss_indexes.zip')

  adding: faiss_indexes/ (stored 0%)
  adding: faiss_indexes/atlas_index_time.faiss (deflated 8%)
  adding: faiss_indexes/atlas_index_base.faiss (deflated 8%)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# ============================================
# Re-run 4 TempRAGEval experiments with 500 Qs
# (Reuses existing FAISS indexes & corpus)
# ============================================

from tqdm.auto import tqdm

TEMPRAGEVAL_TOPK = 10           # how many passages to retrieve
TEMPRAGEVAL_KS   = (1, 5, 10)   # which k's to compute metrics for

N_QUESTIONS = 500  # change as you like

# --- Safety checks ---
assert 'atlas_corpus' in globals(), "atlas_corpus not found. Run the setup cell first."
assert 'atlas_index_base' in globals(), "atlas_index_base not found. Build FAISS indexes first."
assert 'atlas_index_time' in globals(), "atlas_index_time not found. Build FAISS indexes first."
assert 'evaluate_temprageval' in globals(), "evaluate_temprageval() not defined. Run the previous cell."
assert 'extract_mc_temprage' in globals(), "extract_mc_temprage() not defined. Run the previous cell."

# Use full questions if they exist; otherwise fall back to whatever you used before
if 'temprage_questions_full' in globals():
    base_questions = temprage_questions_full
    base_answers  = temprage_answers_full
else:
    base_questions = temprage_questions
    base_answers  = temprage_answers
    print("[Warn] temprage_questions_full not found; using current 'temprage_questions' as base.")

N = min(N_QUESTIONS, len(base_questions))
print(f"[TempRAGE] Preparing subset of {N} questions (requested {N_QUESTIONS})")

temprage_questions_1k = base_questions[:N]
temprage_answers_1k   = base_answers[:N]

# --- Build qrels for this subset only ---
print(f"[TempRAGE] Building qrels for {N} subset...")
temprage_qrels_1k = build_temprageval_qrels(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k
)

# --- MC-only questions for retrieval (using V1 decomposition) ---
print("[TempRAGE] Extracting MC-only questions for retrieval...")
temprage_questions_mc_1k = [
    extract_mc_temprage(q)
    for q in tqdm(temprage_questions_1k, desc=f"MC extraction ({N})", unit="q")
]

print(f"[TempRAGE] Running 4 evaluation conditions on {N} subset "
      f"with topk={TEMPRAGEVAL_TOPK}, metrics ks={TEMPRAGEVAL_KS}...")

# ---- A. BASE Contriever only — MC-only RETRIEVAL ----
print(f"\n[A] TempRAGEval-{N} [BASE Only, MC-only retrieval]")
metrics_temprage_base_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_base,
    base_model,
    base_tokenizer,
    use_mrag=False,
    desc=f"TempRAGEval-{N} [BASE Only, MC-only retrieval, topk={TEMPRAGEVAL_TOPK}]",
    retrieval_questions=temprage_questions_mc_1k,
    full_questions=temprage_questions_1k,   # unused when use_mrag=False
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,                 # <-- explicitly control ks
    top_k_candidates=TEMPRAGEVAL_TOPK,     # <-- explicitly control topk
)

# ---- B. TIME-AWARE Contriever only — FULL question RETRIEVAL ----
print(f"\n[B] TempRAGEval-{N} [TIME-AWARE Only, full-question retrieval]")
metrics_temprage_time_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_time,
    time_model,
    time_tokenizer,
    use_mrag=False,
    desc=f"TempRAGEval-{N} [TIME-AWARE Only, full-question retrieval, topk={TEMPRAGEVAL_TOPK}]",
    retrieval_questions=temprage_questions_1k,  # full question
    full_questions=temprage_questions_1k,
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

# ---- C. MRAG v2 + BASE — MC-only RETRIEVAL, full question for MRAG ----
print(f"\n[C] TempRAGEval-{N} [MRAG v2 + BASE, MC-only retrieval]")
metrics_temprage_mrag_base_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_base,
    base_model,
    base_tokenizer,
    use_mrag=True,
    desc=f"TempRAGEval-{N} [MRAG v2 + BASE, MC-only retrieval, topk={TEMPRAGEVAL_TOPK}]",
    use_llm_summaries=False,
    retrieval_questions=temprage_questions_mc_1k,  # MC-only retrieval
    full_questions=temprage_questions_1k,          # full question for temporal parsing
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

# ---- D. MRAG v2 + TIME-AWARE — FULL question RETRIEVAL & MRAG ----
print(f"\n[D] TempRAGEval-{N} [MRAG v2 + TIME-AWARE, full-question retrieval]")
metrics_temprage_mrag_time_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_time,
    time_model,
    time_tokenizer,
    use_mrag=True,
    desc=f"TempRAGEval-{N} [MRAG v2 + TIME-AWARE, full-question retrieval, topk={TEMPRAGEVAL_TOPK}]",
    use_llm_summaries=False,
    retrieval_questions=temprage_questions_1k,  # full-question retrieval
    full_questions=temprage_questions_1k,       # full question for MRAG
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

print(f"\n[TempRAGE] Done with {N}-question experiments (topk={TEMPRAGEVAL_TOPK}, ks={TEMPRAGEVAL_KS}).")


[TempRAGE] Preparing subset of 500 questions (requested 500)
[TempRAGE] Building qrels for 500 subset...
[TempRAGE] Building TempRAGEval qrels by answer string match...


Building qrels over questions:   0%|          | 0/500 [00:00<?, ?q/s]

[TempRAGE] Qrels built: 233/500 questions have at least 1 matching passage
[TempRAGE] Extracting MC-only questions for retrieval...


MC extraction (500):   0%|          | 0/500 [00:00<?, ?q/s]

[TempRAGE] Running 4 evaluation conditions on 500 subset with topk=10, metrics ks=(1, 5, 10)...

[A] TempRAGEval-500 [BASE Only, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval-500 [BASE Only, MC-only retrieval, topk=10] | q=500 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/8 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-500 [BASE Only, MC-only retrieval, topk=10]):   0%|          | 0/500 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval-500 [BASE Only, MC-only retrieval, topk=10]
  hit@1: 0.4034 | hit@5: 0.7339 | hit@10: 0.9099 | mrr@1: 0.4034 | mrr@5: 0.5273 | mrr@10: 0.5507

[B] TempRAGEval-500 [TIME-AWARE Only, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval-500 [TIME-AWARE Only, full-question retrieval, topk=10] | q=500 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/8 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-500 [TIME-AWARE Only, full-question retrieval, topk=10]):   0%|          | 0/500 [00:00<?…

[EVAL-TempRAGE] Completed: TempRAGEval-500 [TIME-AWARE Only, full-question retrieval, topk=10]
  hit@1: 0.3262 | hit@5: 0.6652 | hit@10: 0.8069 | mrr@1: 0.3262 | mrr@5: 0.4524 | mrr@10: 0.4715

[C] TempRAGEval-500 [MRAG v2 + BASE, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval-500 [MRAG v2 + BASE, MC-only retrieval, topk=10] | q=500 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/8 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-500 [MRAG v2 + BASE, MC-only retrieval, topk=10]):   0%|          | 0/500 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval-500 [MRAG v2 + BASE, MC-only retrieval, topk=10]
  hit@1: 0.4378 | hit@5: 0.7554 | hit@10: 0.9099 | mrr@1: 0.4378 | mrr@5: 0.5558 | mrr@10: 0.5761

[D] TempRAGEval-500 [MRAG v2 + TIME-AWARE, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval-500 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=10] | q=500 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/8 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-500 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=10]):   0%|          | 0/500 [00…

[EVAL-TempRAGE] Completed: TempRAGEval-500 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=10]
  hit@1: 0.4506 | hit@5: 0.7082 | hit@10: 0.8069 | mrr@1: 0.4506 | mrr@5: 0.5456 | mrr@10: 0.5584

[TempRAGE] Done with 500-question experiments (topk=10, ks=(1, 5, 10)).


In [None]:
# ============================================
# Re-run 4 TempRAGEval experiments with 500 Qs
# (Reuses existing FAISS indexes & corpus)
# ============================================

from tqdm.auto import tqdm

TEMPRAGEVAL_TOPK = 20           # how many passages to retrieve
TEMPRAGEVAL_KS   = (1, 5, 10, 20)   # which k's to compute metrics for

N_QUESTIONS = 1244  # change as you like

# --- Safety checks ---
assert 'atlas_corpus' in globals(), "atlas_corpus not found. Run the setup cell first."
assert 'atlas_index_base' in globals(), "atlas_index_base not found. Build FAISS indexes first."
assert 'atlas_index_time' in globals(), "atlas_index_time not found. Build FAISS indexes first."
assert 'evaluate_temprageval' in globals(), "evaluate_temprageval() not defined. Run the previous cell."
assert 'extract_mc_temprage' in globals(), "extract_mc_temprage() not defined. Run the previous cell."

# Use full questions if they exist; otherwise fall back to whatever you used before
if 'temprage_questions_full' in globals():
    base_questions = temprage_questions_full
    base_answers  = temprage_answers_full
else:
    base_questions = temprage_questions
    base_answers  = temprage_answers
    print("[Warn] temprage_questions_full not found; using current 'temprage_questions' as base.")

N = min(N_QUESTIONS, len(base_questions))
print(f"[TempRAGE] Preparing subset of {N} questions (requested {N_QUESTIONS})")

temprage_questions_1k = base_questions[:N]
temprage_answers_1k   = base_answers[:N]

# --- Build qrels for this subset only ---
print(f"[TempRAGE] Building qrels for {N} subset...")
temprage_qrels_1k = build_temprageval_qrels(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k
)

# --- MC-only questions for retrieval (using V1 decomposition) ---
print("[TempRAGE] Extracting MC-only questions for retrieval...")
temprage_questions_mc_1k = [
    extract_mc_temprage(q)
    for q in tqdm(temprage_questions_1k, desc=f"MC extraction ({N})", unit="q")
]

print(f"[TempRAGE] Running 4 evaluation conditions on {N} subset "
      f"with topk={TEMPRAGEVAL_TOPK}, metrics ks={TEMPRAGEVAL_KS}...")

# ---- A. BASE Contriever only — MC-only RETRIEVAL ----
print(f"\n[A] TempRAGEval-{N} [BASE Only, MC-only retrieval]")
metrics_temprage_base_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_base,
    base_model,
    base_tokenizer,
    use_mrag=False,
    desc=f"TempRAGEval-{N} [BASE Only, MC-only retrieval, topk={TEMPRAGEVAL_TOPK}]",
    retrieval_questions=temprage_questions_mc_1k,
    full_questions=temprage_questions_1k,   # unused when use_mrag=False
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,                 # <-- explicitly control ks
    top_k_candidates=TEMPRAGEVAL_TOPK,     # <-- explicitly control topk
)

# ---- B. TIME-AWARE Contriever only — FULL question RETRIEVAL ----
print(f"\n[B] TempRAGEval-{N} [TIME-AWARE Only, full-question retrieval]")
metrics_temprage_time_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_time,
    time_model,
    time_tokenizer,
    use_mrag=False,
    desc=f"TempRAGEval-{N} [TIME-AWARE Only, full-question retrieval, topk={TEMPRAGEVAL_TOPK}]",
    retrieval_questions=temprage_questions_1k,  # full question
    full_questions=temprage_questions_1k,
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

# ---- C. MRAG v2 + BASE — MC-only RETRIEVAL, full question for MRAG ----
print(f"\n[C] TempRAGEval-{N} [MRAG v2 + BASE, MC-only retrieval]")
metrics_temprage_mrag_base_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_base,
    base_model,
    base_tokenizer,
    use_mrag=True,
    desc=f"TempRAGEval-{N} [MRAG v2 + BASE, MC-only retrieval, topk={TEMPRAGEVAL_TOPK}]",
    use_llm_summaries=False,
    retrieval_questions=temprage_questions_mc_1k,  # MC-only retrieval
    full_questions=temprage_questions_1k,          # full question for temporal parsing
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

# ---- D. MRAG v2 + TIME-AWARE — FULL question RETRIEVAL & MRAG ----
print(f"\n[D] TempRAGEval-{N} [MRAG v2 + TIME-AWARE, full-question retrieval]")
metrics_temprage_mrag_time_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_time,
    time_model,
    time_tokenizer,
    use_mrag=True,
    desc=f"TempRAGEval-{N} [MRAG v2 + TIME-AWARE, full-question retrieval, topk={TEMPRAGEVAL_TOPK}]",
    use_llm_summaries=False,
    retrieval_questions=temprage_questions_1k,  # full-question retrieval
    full_questions=temprage_questions_1k,       # full question for MRAG
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

print(f"\n[TempRAGE] Done with {N}-question experiments (topk={TEMPRAGEVAL_TOPK}, ks={TEMPRAGEVAL_KS}).")


[TempRAGE] Preparing subset of 1244 questions (requested 1244)
[TempRAGE] Building qrels for 1244 subset...
[TempRAGE] Building TempRAGEval qrels by answer string match...


Building qrels over questions:   0%|          | 0/1244 [00:00<?, ?q/s]

[TempRAGE] Qrels built: 736/1244 questions have at least 1 matching passage
[TempRAGE] Extracting MC-only questions for retrieval...


MC extraction (1244):   0%|          | 0/1244 [00:00<?, ?q/s]

[TempRAGE] Running 4 evaluation conditions on 1244 subset with topk=20, metrics ks=(1, 5, 10, 20)...

[A] TempRAGEval-1244 [BASE Only, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval-1244 [BASE Only, MC-only retrieval, topk=20] | q=1244 | ks=(1, 5, 10, 20) | top_k_candidates=20 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/20 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-1244 [BASE Only, MC-only retrieval, topk=20]):   0%|          | 0/1244 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval-1244 [BASE Only, MC-only retrieval, topk=20]
  hit@1: 0.3329 | hit@5: 0.7283 | hit@10: 0.8872 | hit@20: 0.9429 | mrr@1: 0.3329 | mrr@5: 0.4780 | mrr@10: 0.4997 | mrr@20: 0.5038

[B] TempRAGEval-1244 [TIME-AWARE Only, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval-1244 [TIME-AWARE Only, full-question retrieval, topk=20] | q=1244 | ks=(1, 5, 10, 20) | top_k_candidates=20 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/20 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-1244 [TIME-AWARE Only, full-question retrieval, topk=20]):   0%|          | 0/1244 [00:00…

[EVAL-TempRAGE] Completed: TempRAGEval-1244 [TIME-AWARE Only, full-question retrieval, topk=20]
  hit@1: 0.2880 | hit@5: 0.6834 | hit@10: 0.8315 | hit@20: 0.9049 | mrr@1: 0.2880 | mrr@5: 0.4351 | mrr@10: 0.4554 | mrr@20: 0.4607

[C] TempRAGEval-1244 [MRAG v2 + BASE, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval-1244 [MRAG v2 + BASE, MC-only retrieval, topk=20] | q=1244 | ks=(1, 5, 10, 20) | top_k_candidates=20 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/20 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-1244 [MRAG v2 + BASE, MC-only retrieval, topk=20]):   0%|          | 0/1244 [00:00<?, ?q/…

[EVAL-TempRAGE] Completed: TempRAGEval-1244 [MRAG v2 + BASE, MC-only retrieval, topk=20]
  hit@1: 0.4429 | hit@5: 0.7976 | hit@10: 0.8832 | hit@20: 0.9429 | mrr@1: 0.4429 | mrr@5: 0.5740 | mrr@10: 0.5853 | mrr@20: 0.5898

[D] TempRAGEval-1244 [MRAG v2 + TIME-AWARE, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval-1244 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=20] | q=1244 | ks=(1, 5, 10, 20) | top_k_candidates=20 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/20 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-1244 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=20]):   0%|          | 0/1244 […

[EVAL-TempRAGE] Completed: TempRAGEval-1244 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=20]
  hit@1: 0.4348 | hit@5: 0.7310 | hit@10: 0.8125 | hit@20: 0.9049 | mrr@1: 0.4348 | mrr@5: 0.5455 | mrr@10: 0.5568 | mrr@20: 0.5635

[TempRAGE] Done with 1244-question experiments (topk=20, ks=(1, 5, 10, 20)).


In [149]:
# ============================================
# Re-run 4 TempRAGEval experiments with 500 Qs
# (Reuses existing FAISS indexes & corpus)
# ============================================

from tqdm.auto import tqdm

TEMPRAGEVAL_TOPK = 10           # how many passages to retrieve
TEMPRAGEVAL_KS   = (1, 5, 10)   # which k's to compute metrics for

N_QUESTIONS = 700  # change as you like

# --- Safety checks ---
assert 'atlas_corpus' in globals(), "atlas_corpus not found. Run the setup cell first."
assert 'atlas_index_base' in globals(), "atlas_index_base not found. Build FAISS indexes first."
assert 'atlas_index_time' in globals(), "atlas_index_time not found. Build FAISS indexes first."
assert 'evaluate_temprageval' in globals(), "evaluate_temprageval() not defined. Run the previous cell."
assert 'extract_mc_temprage' in globals(), "extract_mc_temprage() not defined. Run the previous cell."

# Use full questions if they exist; otherwise fall back to whatever you used before
if 'temprage_questions_full' in globals():
    base_questions = temprage_questions_full
    base_answers  = temprage_answers_full
else:
    base_questions = temprage_questions
    base_answers  = temprage_answers
    print("[Warn] temprage_questions_full not found; using current 'temprage_questions' as base.")

N = min(N_QUESTIONS, len(base_questions))
print(f"[TempRAGE] Preparing subset of {N} questions (requested {N_QUESTIONS})")

temprage_questions_1k = base_questions[:N]
temprage_answers_1k   = base_answers[:N]

# --- Build qrels for this subset only ---
print(f"[TempRAGE] Building qrels for {N} subset...")
temprage_qrels_1k = build_temprageval_qrels(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k
)

# --- MC-only questions for retrieval (using V1 decomposition) ---
print("[TempRAGE] Extracting MC-only questions for retrieval...")
temprage_questions_mc_1k = [
    extract_mc_temprage(q)
    for q in tqdm(temprage_questions_1k, desc=f"MC extraction ({N})", unit="q")
]

print(f"[TempRAGE] Running 4 evaluation conditions on {N} subset "
      f"with topk={TEMPRAGEVAL_TOPK}, metrics ks={TEMPRAGEVAL_KS}...")

# ---- A. BASE Contriever only — MC-only RETRIEVAL ----
print(f"\n[A] TempRAGEval-{N} [BASE Only, MC-only retrieval]")
metrics_temprage_base_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_base,
    base_model,
    base_tokenizer,
    use_mrag=False,
    desc=f"TempRAGEval-{N} [BASE Only, MC-only retrieval, topk={TEMPRAGEVAL_TOPK}]",
    retrieval_questions=temprage_questions_mc_1k,
    full_questions=temprage_questions_1k,   # unused when use_mrag=False
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,                 # <-- explicitly control ks
    top_k_candidates=TEMPRAGEVAL_TOPK,     # <-- explicitly control topk
)

# ---- B. TIME-AWARE Contriever only — FULL question RETRIEVAL ----
print(f"\n[B] TempRAGEval-{N} [TIME-AWARE Only, full-question retrieval]")
metrics_temprage_time_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_time,
    time_model,
    time_tokenizer,
    use_mrag=False,
    desc=f"TempRAGEval-{N} [TIME-AWARE Only, full-question retrieval, topk={TEMPRAGEVAL_TOPK}]",
    retrieval_questions=temprage_questions_1k,  # full question
    full_questions=temprage_questions_1k,
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

# ---- C. MRAG v2 + BASE — MC-only RETRIEVAL, full question for MRAG ----
print(f"\n[C] TempRAGEval-{N} [MRAG v2 + BASE, MC-only retrieval]")
metrics_temprage_mrag_base_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_base,
    base_model,
    base_tokenizer,
    use_mrag=True,
    desc=f"TempRAGEval-{N} [MRAG v2 + BASE, MC-only retrieval, topk={TEMPRAGEVAL_TOPK}]",
    use_llm_summaries=False,
    retrieval_questions=temprage_questions_mc_1k,  # MC-only retrieval
    full_questions=temprage_questions_1k,          # full question for temporal parsing
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

# ---- D. MRAG v2 + TIME-AWARE — FULL question RETRIEVAL & MRAG ----
print(f"\n[D] TempRAGEval-{N} [MRAG v2 + TIME-AWARE, full-question retrieval]")
metrics_temprage_mrag_time_1k = evaluate_temprageval(
    atlas_corpus,
    temprage_questions_1k,
    temprage_answers_1k,
    atlas_index_time,
    time_model,
    time_tokenizer,
    use_mrag=True,
    desc=f"TempRAGEval-{N} [MRAG v2 + TIME-AWARE, full-question retrieval, topk={TEMPRAGEVAL_TOPK}]",
    use_llm_summaries=False,
    retrieval_questions=temprage_questions_1k,  # full-question retrieval
    full_questions=temprage_questions_1k,       # full question for MRAG
    qrels=temprage_qrels_1k,
    k_list=TEMPRAGEVAL_KS,
    top_k_candidates=TEMPRAGEVAL_TOPK,
)

print(f"\n[TempRAGE] Done with {N}-question experiments (topk={TEMPRAGEVAL_TOPK}, ks={TEMPRAGEVAL_KS}).")


[TempRAGE] Preparing subset of 700 questions (requested 700)
[TempRAGE] Building qrels for 700 subset...
[TempRAGE] Building TempRAGEval qrels by answer string match...


Building qrels over questions:   0%|          | 0/700 [00:00<?, ?q/s]

[TempRAGE] Qrels built: 337/700 questions have at least 1 matching passage
[TempRAGE] Extracting MC-only questions for retrieval...


MC extraction (700):   0%|          | 0/700 [00:00<?, ?q/s]

[TempRAGE] Running 4 evaluation conditions on 700 subset with topk=10, metrics ks=(1, 5, 10)...

[A] TempRAGEval-700 [BASE Only, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval-700 [BASE Only, MC-only retrieval, topk=10] | q=700 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/11 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-700 [BASE Only, MC-only retrieval, topk=10]):   0%|          | 0/700 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval-700 [BASE Only, MC-only retrieval, topk=10]
  hit@1: 0.3887 | hit@5: 0.7507 | hit@10: 0.9110 | mrr@1: 0.3887 | mrr@5: 0.5213 | mrr@10: 0.5433

[B] TempRAGEval-700 [TIME-AWARE Only, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval-700 [TIME-AWARE Only, full-question retrieval, topk=10] | q=700 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=False
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/11 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-700 [TIME-AWARE Only, full-question retrieval, topk=10]):   0%|          | 0/700 [00:00<?…

[EVAL-TempRAGE] Completed: TempRAGEval-700 [TIME-AWARE Only, full-question retrieval, topk=10]
  hit@1: 0.3056 | hit@5: 0.6736 | hit@10: 0.8249 | mrr@1: 0.3056 | mrr@5: 0.4452 | mrr@10: 0.4657

[C] TempRAGEval-700 [MRAG v2 + BASE, MC-only retrieval]
[EVAL-TempRAGE] TempRAGEval-700 [MRAG v2 + BASE, MC-only retrieval, topk=10] | q=700 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/11 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-700 [MRAG v2 + BASE, MC-only retrieval, topk=10]):   0%|          | 0/700 [00:00<?, ?q/s]

[EVAL-TempRAGE] Completed: TempRAGEval-700 [MRAG v2 + BASE, MC-only retrieval, topk=10]
  hit@1: 0.4659 | hit@5: 0.7804 | hit@10: 0.9110 | mrr@1: 0.4659 | mrr@5: 0.5835 | mrr@10: 0.6009

[D] TempRAGEval-700 [MRAG v2 + TIME-AWARE, full-question retrieval]
[EVAL-TempRAGE] TempRAGEval-700 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=10] | q=700 | ks=(1, 5, 10) | top_k_candidates=10 | use_mrag=True
[TempRAGE] Encoding retrieval questions...


Encoding:   0%|          | 0/11 [00:00<?, ?it/s]

[TempRAGE] Running FAISS search for all questions...
[TempRAGE] Scoring / reranking over all questions...


Scoring (TempRAGEval-700 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=10]):   0%|          | 0/700 [00…

[EVAL-TempRAGE] Completed: TempRAGEval-700 [MRAG v2 + TIME-AWARE, full-question retrieval, topk=10]
  hit@1: 0.4748 | hit@5: 0.7359 | hit@10: 0.8249 | mrr@1: 0.4748 | mrr@5: 0.5685 | mrr@10: 0.5802

[TempRAGE] Done with 700-question experiments (topk=10, ks=(1, 5, 10)).
