In [1]:
# Environment setup. NOTE(review): `%pip` installs into the running kernel's
# environment; `!pip` may target a different interpreter — consider switching.
# NOTE(review): torch / faiss-cpu / tqdm are unpinned, so a fresh run may pick up
# a torch release newer than what transformers==4.35.2 expects; the
# `_register_pytree_node` warnings in later cell outputs likely come from that pairing.
!pip install torch faiss-cpu tqdm

# Pinned data/model stack this notebook was run against (RAG classes come from transformers).
!pip install pyarrow==14.0.1 datasets==2.14.6 transformers==4.35.2 accelerate==0.24.1

# Presumably for Jupyter progress-bar widgets — confirm it is still required.
!pip -q install ipywidgets



In [2]:
import logging
import os

import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration,
    RagTokenizer,
)

# Choose the compute device once; every model and input batch is moved here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on: {device}")

# Only surface errors from transformers so progress bars and results stay readable.
logging.getLogger("transformers").setLevel(logging.ERROR)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Running on: cuda


  _torch_pytree._register_pytree_node(


In [3]:
import string
import re

def normalize_answer(s):
    """Canonicalize an answer string for exact-match comparison.

    Applied in this order:
      1. lowercase,
      2. delete every character in ``string.punctuation``,
      3. replace the standalone words "a", "an", "the" with a space,
      4. collapse all whitespace runs to single spaces and trim the ends.
    """
    text = s.lower()
    # Deleting via a translation table is equivalent to filtering chars one by one.
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def calculate_em(predictions, references):
    """Compute corpus-level Exact Match (EM) as a percentage.

    A prediction is a hit when ``exact_match_score`` accepts it against at
    least one of its reference answers.

    Args:
        predictions: list of predicted answer strings.
        references: list aligned with ``predictions``; each entry is a list of
            acceptable answer strings (one question can have several valid
            answers). A bare string is treated as a single-answer list rather
            than being iterated character by character.

    Returns:
        EM in [0, 100]. Returns 0.0 for empty input instead of dividing by zero.

    Raises:
        ValueError: if ``predictions`` and ``references`` differ in length
            (``zip`` would otherwise silently drop the extras while the
            denominator still counted them).
    """
    if len(predictions) != len(references):
        raise ValueError(
            f"predictions ({len(predictions)}) and references ({len(references)}) "
            "must have the same length"
        )
    if len(predictions) == 0:
        return 0.0

    total_em = 0
    for pred, refs in zip(predictions, references):
        if isinstance(refs, str):
            refs = [refs]
        # If the prediction matches ANY of the valid references, it's a hit
        if any(exact_match_score(pred, gt) for gt in refs):
            total_em += 1

    return 100 * (total_em / len(predictions))

In [4]:
class RAGModelManager:
    """Wrap a Hugging Face RAG checkpoint (tokenizer, retriever, generator) for batch QA.

    Args:
        model_name: Hub id or local directory passed to every ``from_pretrained`` call.
        rag_type: ``"sequence"`` loads ``RagSequenceForGeneration``; ``"token"``
            loads ``RagTokenForGeneration``.
        n_docs: K, the number of retrieved documents used per question at generation time.
        use_dummy: forwarded to ``RagRetriever.from_pretrained`` as ``use_dummy_dataset``.
        index_name: forwarded to ``RagRetriever.from_pretrained``.
        num_beams: beam width passed to ``generate`` (1 means greedy).
        max_new_tokens: upper bound on generated answer length.

    Raises:
        ValueError: if ``rag_type`` is not in ``_VALID_RAG_TYPES`` or ``n_docs < 1``.
            Both checks run before any tokenizer, index, or weights are loaded.
    """

    _VALID_RAG_TYPES = ("sequence", "token")

    def __init__(
        self,
        model_name,
        rag_type,  # "sequence" or "token"
        n_docs,
        use_dummy=False,
        index_name="exact",
        num_beams=1,
        max_new_tokens=16,
    ):
        # Explicit checks instead of `assert` (stripped under `python -O`), done
        # up front so a bad argument fails before the expensive loads below.
        if rag_type not in self._VALID_RAG_TYPES:
            raise ValueError(
                f"rag_type must be one of {self._VALID_RAG_TYPES}, got {rag_type!r}"
            )
        self._check_n_docs(n_docs)

        self.model_name = model_name
        self.rag_type = rag_type

        self.n_docs = n_docs
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        print(f"Loading Model: {model_name} ({rag_type})...")

        # Load Tokenizer
        self.tokenizer = RagTokenizer.from_pretrained(model_name)

        # Load Retriever
        self.retriever = RagRetriever.from_pretrained(
            model_name,
            index_name=index_name,
            use_dummy_dataset=use_dummy,
        )

        # Load RAG Model: the two generation classes share the same loading call.
        model_cls = RagTokenForGeneration if rag_type == "token" else RagSequenceForGeneration
        self.model = model_cls.from_pretrained(
            model_name,
            retriever=self.retriever,
        ).to(device)

        # Route through set_n_docs so model.config, the retriever, and the
        # retriever's config all agree on K from the start, exactly as they do
        # after a later set_n_docs() call.
        self.set_n_docs(n_docs)

        self.model.eval()
        print("Model loaded successfully.")
        print(f"Configured n_docs={self.model.config.n_docs}, num_beams={self.num_beams}, max_new_tokens={self.max_new_tokens}")

    @staticmethod
    def _check_n_docs(n_docs):
        """Raise ValueError unless n_docs is at least 1."""
        if n_docs < 1:
            raise ValueError(f"n_docs must be >= 1, got {n_docs}")

    def set_n_docs(self, n_docs: int):
        """Set K, the number of retrieved docs RAG uses at generation time.

        Updates ``self.n_docs`` (passed explicitly to ``generate``),
        ``model.config.n_docs``, and, when those attributes exist,
        ``retriever.n_docs`` and ``retriever.config.n_docs``.

        Raises:
            ValueError: if ``n_docs < 1``; nothing is modified in that case.
        """
        self._check_n_docs(n_docs)
        self.n_docs = n_docs
        self.model.config.n_docs = n_docs
        if hasattr(self.retriever, "n_docs"):
            self.retriever.n_docs = n_docs
        if hasattr(self.retriever, "config"):
            self.retriever.config.n_docs = n_docs

    def generate_answers(self, questions, batch_size=4):
        """Generate answers for ``questions`` in batches of ``batch_size``.

        Args:
            questions: sliceable sequence of question strings.
            batch_size: questions per ``generate`` call; must be >= 1.

        Returns:
            List of decoded strings concatenated across batches, in input order
            (empty list for empty input). Assumes the model returns one sequence
            per input, which callers rely on when zipping with references —
            confirm if ``num_return_sequences`` is ever changed.

        Raises:
            ValueError: if ``batch_size < 1``. A negative step would otherwise
                make the batch loop run zero times and silently return ``[]``.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if len(questions) == 0:
            return []

        self.model.eval()
        all_answers = []

        for start in tqdm(range(0, len(questions), batch_size), desc="Generating"):
            batch_questions = questions[start: start + batch_size]

            # Tokenize
            inputs = self.tokenizer(
                batch_questions,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)

            # Generate (no autograd bookkeeping needed at inference time)
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    n_docs=self.n_docs,
                    num_beams=self.num_beams,
                    do_sample=False,
                    max_new_tokens=self.max_new_tokens,
                    min_length=1,
                    early_stopping=True,
                )

            # Decode
            all_answers.extend(
                self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            )

        return all_answers

In [5]:
def _load_qa_examples(task_name, num_samples=None):
    """Load aligned (questions, answers) lists for a supported QA task.

    Args:
        task_name: "Natural Questions" (``nq_open`` validation split) or
            "WebQuestions" (``stanfordnlp/web_questions`` test split).
        num_samples: stop once this many examples are kept; falsy loads all.

    Returns:
        (questions, answers) where ``answers[i]`` is the gold-answer list read
        from the row that produced ``questions[i]``.

    Raises:
        ValueError: for an unsupported ``task_name``. Without this check the
            caller would go on to score zero samples.
    """
    if task_name == "Natural Questions":
        # Load NQ Open (Simplified version); its gold answers live under 'answer'.
        dataset = load_dataset("nq_open", split="validation")
        print("Processing Natural Questions dataset...")
        answer_key, skip_unanswered = "answer", True
    elif task_name == "WebQuestions":
        # WebQuestions keeps its gold answers under 'answers'.
        dataset = load_dataset("stanfordnlp/web_questions", split="test")
        print("Processing WebQuestions dataset...")
        answer_key, skip_unanswered = "answers", False
    else:
        raise ValueError(
            f"Unsupported task_name {task_name!r}; expected 'Natural Questions' or 'WebQuestions'"
        )

    questions, answers = [], []
    for row in tqdm(dataset, desc="Loading Data"):
        gold = row[answer_key]
        # NQ rows with an empty answer list are dropped; WQ rows are kept as-is.
        if gold or not skip_unanswered:
            questions.append(row['question'])
            answers.append(gold)

        if num_samples and len(questions) >= num_samples:
            break

    return questions, answers


def _results_filename(task_name, model_manager):
    """Build a results filename unique per (task, model, rag_type) so runs don't overwrite each other."""
    raw_model = str(getattr(model_manager, "model_name", "model"))
    # Model names can be Hub ids or local paths ("org/name", "dir/sub"); keep them filename-safe.
    model_tag = re.sub(r"[^A-Za-z0-9._-]+", "_", raw_model)
    rag_type = getattr(model_manager, "rag_type", "rag")
    return f"{task_name.replace(' ', '_')}__{model_tag}__{rag_type}_results.txt"


def run_experiment(task_name, model_manager, num_samples=None, batch_size=8):
    """Load a QA task, generate answers with ``model_manager``, and report EM.

    Args:
        task_name: "Natural Questions" or "WebQuestions".
        model_manager: object exposing ``generate_answers(questions, batch_size=...)``
            (e.g. ``RAGModelManager``); its ``model_name``/``rag_type`` tag the results file.
        num_samples: optional cap on evaluated examples; ``None`` evaluates all.
        batch_size: questions per generation batch (default 8, as before).

    Returns:
        Exact Match score as a percentage.

    Side effects:
        Writes a summary to a file named from the task AND the model, so
        evaluating two models on the same task no longer overwrites the
        first model's results.

    Raises:
        ValueError: for an unsupported ``task_name``.
    """
    print(f"\n{'='*40}")
    print(f"STARTING EXPERIMENT: {task_name}")
    print(f"{'='*40}")

    # DATA LOADING
    questions, answers = _load_qa_examples(task_name, num_samples)
    print(f"Loaded {len(questions)} TOTAL samples for {task_name}.")

    # EXECUTION
    predictions = model_manager.generate_answers(questions, batch_size=batch_size)

    # EVALUATION
    score = calculate_em(predictions, answers)

    print(f"\nRESULTS for {task_name}:")
    print(f"Exact Match (EM): {score:.2f}%")

    # Save results to a file
    results_path = _results_filename(task_name, model_manager)
    model_label = getattr(model_manager, "model_name", "unknown")
    rag_label = getattr(model_manager, "rag_type", "unknown")
    with open(results_path, "w", encoding="utf-8") as f:
        f.write(f"Task: {task_name}\n")
        f.write(f"Model: {model_label} ({rag_label})\n")
        f.write(f"Samples: {len(questions)}\n")
        f.write(f"EM Score: {score:.2f}%\n")
    print(f"Saved results to {results_path}")

    return score


In [6]:
# CONFIGURATION — knobs shared by every experiment cell below.

# Retrieval index choice passed to RagRetriever as use_dummy_dataset:
#   True  -> tiny dummy index; quick prototype run, EM will be roughly 0%.
#   False -> real index (the original note says this downloads ~75GB).
USE_DUMMY: bool = False

# Evaluation size: None runs every example in the split; an int (e.g. 50) caps it for testing.
NUM_SAMPLES = None

# Forwarded to RagRetriever.from_pretrained(index_name=...).
INDEX_NAME: str = "exact"

In [7]:
# ============================================================
# Natural Questions (NQ Open) with RAG-Sequence
#   checkpoint : facebook/rag-sequence-nq (HuggingFace release)
#   metric     : Exact Match (EM)
#   paper row  : Table 1, "RAG-Seq." on NQ
#   main knob  : K = number of retrieved docs (n_docs)
#   goal       : reproduce a comparable EM trend with the released checkpoint
# ============================================================
rag_seq_nq = RAGModelManager(
    model_name="facebook/rag-sequence-nq",
    rag_type="sequence",
    index_name=INDEX_NAME,   # "exact" selects the uncompressed wiki_dpr index
    use_dummy=USE_DUMMY,
    n_docs=15,               # K; RAG-Sequence tends to gain from more docs at test time
    num_beams=1,             # greedy decoding
    max_new_tokens=16,       # answers are short; bounds generation time
)

print("Experiment: Natural Questions (RAG-Sequence baseline reproduction)")
nq_score_seq = run_experiment("Natural Questions", rag_seq_nq, num_samples=NUM_SAMPLES)

Loading Model: facebook/rag-sequence-nq (sequence)...


  table = cls._concat_blocks(blocks, axis=0)
  _torch_pytree._register_pytree_node(


Model loaded successfully.
Configured n_docs=15, num_beams=1, max_new_tokens=16
Experiment: Natural Questions (RAG-Sequence baseline reproduction)

STARTING EXPERIMENT: Natural Questions
Processing Natural Questions dataset...


Loading Data: 100%|██████████| 3610/3610 [00:00<00:00, 37666.20it/s]


Loaded 3610 TOTAL samples for Natural Questions.


Generating: 100%|██████████| 452/452 [4:10:30<00:00, 33.25s/it]  


RESULTS for Natural Questions:
Exact Match (EM): 38.31%





In [8]:
# ============================================================
# Task: Natural Questions (NQ Open) — evaluate Exact Match (EM)
# Model: RAG-Token NQ checkpoint (facebook/rag-token-nq)
# Paper mapping: Table 1 “RAG-Token” on NQ; key knob is K retrieved docs (n_docs)
# Goal: Reproduce the baseline RAG-Token behavior and EM score with HF checkpoint
# ============================================================
rag_tok_nq = RAGModelManager(
    model_name="facebook/rag-token-nq",
    rag_type="token",
    n_docs=15,  # Paper-style setting: RAG-Token commonly reported with ~15 docs for NQ test
    use_dummy=USE_DUMMY,
    index_name=INDEX_NAME,
    num_beams=1,  # Greedy decoding for QA
    max_new_tokens=16  # Keep outputs short; aligns with EM evaluation
)

print("Experiment: Natural Questions (RAG-Token baseline reproduction)")
nq_score_tok = run_experiment("Natural Questions", rag_tok_nq, num_samples=NUM_SAMPLES)

Loading Model: facebook/rag-token-nq (token)...




Model loaded successfully.
Configured n_docs=15, num_beams=1, max_new_tokens=16
Experiment: Natural Questions (RAG-Token baseline reproduction)

STARTING EXPERIMENT: Natural Questions
Processing Natural Questions dataset...


Loading Data: 100%|██████████| 3610/3610 [00:00<00:00, 47529.09it/s]


Loaded 3610 TOTAL samples for Natural Questions.


Generating: 100%|██████████| 452/452 [29:51<00:00,  3.96s/it]


RESULTS for Natural Questions:
Exact Match (EM): 39.86%





In [9]:
# ============================================================
# Task: WebQuestions (WQ) — evaluate Exact Match (EM)
# Model: Fine-tuned RAG-Sequence checkpoint (local folder path via FT_WQ_SEQ_DIR)
# Paper mapping: Table 1 “RAG-Seq.” on WQ; paper initializes WQ from an NQ-trained RAG model, then fine-tunes
# Goal: Reproduce WQ EM after task-specific fine-tuning (or, if the folder is unset or
#       missing on disk, run a transfer baseline using the NQ checkpoint)
# ============================================================
FT_WQ_SEQ_DIR = "WQ_models/facebook_rag-sequence-nq__sequence__wq_ft__nDocs10__20260118_124326"
rag_seq_ft_wq = None

# Require the folder to exist: a missing local path would otherwise be handed to
# from_pretrained and fail there instead of falling back to the NQ model below.
if FT_WQ_SEQ_DIR and os.path.isdir(FT_WQ_SEQ_DIR):
    rag_seq_ft_wq = RAGModelManager(
        model_name=FT_WQ_SEQ_DIR,
        rag_type="sequence",
        use_dummy=USE_DUMMY,
        index_name=INDEX_NAME,
        # NOTE(review): folder name says nDocs10 but evaluation uses 15 — confirm intended.
        n_docs=15,
        num_beams=1,
        max_new_tokens=16
    )
elif FT_WQ_SEQ_DIR:
    print(f"FT_WQ_SEQ_DIR not found on disk: {FT_WQ_SEQ_DIR}")

print("Experiment: WebQuestions (RAG-Sequence baseline reproduction)")
if rag_seq_ft_wq is not None:
    wb_score_seq = run_experiment("WebQuestions", rag_seq_ft_wq, num_samples=NUM_SAMPLES)
else:
    print("No usable FT_WQ_SEQ_DIR -> Evaluating NQ sequence model on WQ (expecting low EM)")
    wb_score_seq = run_experiment("WebQuestions", rag_seq_nq, num_samples=NUM_SAMPLES)

Loading Model: WQ_models/facebook_rag-sequence-nq__sequence__wq_ft__nDocs10__20260118_124326 (sequence)...
Model loaded successfully.
Configured n_docs=15, num_beams=1, max_new_tokens=16
Experiment: WebQuestions (RAG-Sequence baseline reproduction)

STARTING EXPERIMENT: WebQuestions
Processing WebQuestions dataset...


Loading Data: 100%|██████████| 2032/2032 [00:00<00:00, 36784.52it/s]


Loaded 2032 TOTAL samples for WebQuestions.


Generating: 100%|██████████| 254/254 [2:22:05<00:00, 33.57s/it]  


RESULTS for WebQuestions:
Exact Match (EM): 37.25%





In [10]:
# ============================================================
# Task: WebQuestions (WQ) — evaluate Exact Match (EM)
# Model: Fine-tuned RAG-Token checkpoint (local folder path via FT_WQ_TOKEN_DIR)
# Paper mapping: Table 1 “RAG-Token” on WQ; paper initializes WQ from NQ RAG, then fine-tunes and evaluates EM
# Goal: Reproduce WQ EM for a token-level RAG model after fine-tuning (or, if the folder is
#       unset or missing on disk, run a transfer baseline using the NQ token checkpoint)
# ============================================================
FT_WQ_TOKEN_DIR = "WQ_models/facebook_rag-token-nq__token__wq_ft__nDocs10__20260117_163413"
rag_tok_ft_wq = None

# Require the folder to exist: a missing local path would otherwise be handed to
# from_pretrained and fail there instead of falling back to the NQ model below.
if FT_WQ_TOKEN_DIR and os.path.isdir(FT_WQ_TOKEN_DIR):
    rag_tok_ft_wq = RAGModelManager(
        model_name=FT_WQ_TOKEN_DIR,
        rag_type="token",
        use_dummy=USE_DUMMY,
        index_name=INDEX_NAME,
        # NOTE(review): folder name says nDocs10 but evaluation uses 15 — confirm intended.
        n_docs=15,
        num_beams=1,
        max_new_tokens=16
    )
elif FT_WQ_TOKEN_DIR:
    print(f"FT_WQ_TOKEN_DIR not found on disk: {FT_WQ_TOKEN_DIR}")

print("Experiment: WebQuestions (RAG-Token baseline reproduction)")
if rag_tok_ft_wq is not None:
    wb_score_tok = run_experiment("WebQuestions", rag_tok_ft_wq, num_samples=NUM_SAMPLES)
else:
    print("No usable FT_WQ_TOKEN_DIR -> Evaluating NQ token model on WQ (expecting low EM)")
    wb_score_tok = run_experiment("WebQuestions", rag_tok_nq, num_samples=NUM_SAMPLES)


Loading Model: WQ_models/facebook_rag-token-nq__token__wq_ft__nDocs10__20260117_163413 (token)...


  table = cls._concat_blocks(blocks, axis=0)
  _torch_pytree._register_pytree_node(


Model loaded successfully.
Configured n_docs=15, num_beams=1, max_new_tokens=16
Experiment: WebQuestions (RAG-Token baseline reproduction)

STARTING EXPERIMENT: WebQuestions
Processing WebQuestions dataset...


Loading Data: 100%|██████████| 2032/2032 [00:00<00:00, 39253.63it/s]


Loaded 2032 TOTAL samples for WebQuestions.


Generating: 100%|██████████| 254/254 [35:11<00:00,  8.31s/it]


RESULTS for WebQuestions:
Exact Match (EM): 32.92%



