In [2]:
!pip install torch faiss-cpu tqdm

!pip install pyarrow==14.0.1 datasets==2.14.6 transformers==4.35.2 faiss-cpu



In [3]:

import torch
import logging
from transformers import (
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration
)
from datasets import load_dataset
from tqdm import tqdm
import numpy as np

# Select the compute device: use CUDA when torch reports one, else the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on: {device}")

# Raise the transformers logger threshold to ERROR to cut log noise
logging.getLogger("transformers").setLevel(logging.ERROR)

  _torch_pytree._register_pytree_node(


Running on: cuda


  _torch_pytree._register_pytree_node(


In [4]:
import string
import re

def normalize_answer(s):
    """Canonicalize an answer string before comparison.

    Steps, in order: lowercase, drop every ASCII punctuation character,
    replace the standalone articles "a", "an" and "the" with a space, then
    collapse all whitespace runs to single spaces and trim the ends.
    """
    lowered = s.lower()

    # Delete every character listed in string.punctuation in one pass
    without_punctuation = lowered.translate(str.maketrans('', '', string.punctuation))

    # Articles are matched on word boundaries only, so "theater" is untouched
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)

    # split() with no argument discards leading/trailing/repeated whitespace
    return ' '.join(without_articles.split())

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def calculate_em(predictions, references):
    """Compute the Exact Match percentage over a set of predictions.

    Args:
        predictions: list of predicted answer strings.
        references: list parallel to ``predictions``; each entry is a list of
            acceptable answer strings (a question can have several valid
            answers). A bare string is treated as a single-answer list.

    Returns:
        Percentage (0-100, float) of predictions that match at least one of
        their references after normalization. Returns 0.0 for empty input.

    Raises:
        ValueError: if ``predictions`` and ``references`` differ in length.
    """
    total = len(predictions)
    if total != len(references):
        # zip() would silently drop the surplus items and skew the score
        raise ValueError(
            f"predictions ({total}) and references ({len(references)}) must have the same length"
        )
    if total == 0:
        # Nothing to score; avoid dividing by zero
        return 0.0

    total_em = 0
    for pred, refs in zip(predictions, references):
        if isinstance(refs, str):
            # Iterating a plain string would compare against single characters
            refs = [refs]
        # If the prediction matches ANY of the valid references, it's a hit
        if any(exact_match_score(pred, gt) for gt in refs):
            total_em += 1

    return 100 * (total_em / total)

In [5]:
class RAGModelManager:
    """Bundle a RAG tokenizer, retriever and generator behind one batch API.

    The generation settings given at construction time (``n_docs``,
    ``num_beams``, ``max_new_tokens``) are stored on the instance and reused
    by every ``generate_answers`` call. The model is moved to the
    module-level ``device``.
    """

    def __init__(
        self,
        model_name,
        rag_type,  # "sequence" or "token"
        use_dummy=True,
        index_name="exact",
        n_docs=15,
        num_beams=1,
        max_new_tokens=16,
    ):
        """Load tokenizer, retriever and RAG model from ``model_name``.

        Args:
            model_name: identifier or path passed to every ``from_pretrained`` call.
            rag_type: "sequence" loads ``RagSequenceForGeneration``,
                "token" loads ``RagTokenForGeneration``.
            use_dummy: forwarded to the retriever as ``use_dummy_dataset``.
            index_name: forwarded to the retriever as ``index_name``.
            n_docs: number of retrieved documents used per question.
            num_beams: beam count passed to ``generate``.
            max_new_tokens: generation length limit passed to ``generate``.
        """
        assert rag_type in ["sequence", "token"], (
            f"rag_type must be 'sequence' or 'token', got {rag_type!r}"
        )
        self.model_name = model_name
        self.rag_type = rag_type

        self.n_docs = n_docs
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        print(f"Loading Model: {model_name} ({rag_type})...")

        # Load Tokenizer
        self.tokenizer = RagTokenizer.from_pretrained(model_name)

        # Load Retriever
        self.retriever = RagRetriever.from_pretrained(
            model_name,
            index_name=index_name,
            use_dummy_dataset=use_dummy
        )

        # Load RAG Model (class chosen by rag_type, same arguments either way)
        if rag_type == "token":
            model_cls = RagTokenForGeneration
        else:
            model_cls = RagSequenceForGeneration
        self.model = model_cls.from_pretrained(
            model_name,
            retriever=self.retriever
        ).to(device)

        # Go through set_n_docs so the model config and the retriever agree
        self.set_n_docs(n_docs)

        self.model.eval()
        print("Model loaded successfully.")
        print(f"Configured n_docs={self.model.config.n_docs}, num_beams={self.num_beams}, max_new_tokens={self.max_new_tokens}")

    def set_n_docs(self, n_docs: int):
        """
        Controls K retrieved docs used by RAG at generation time.

        Updates the instance value, the model config and, when the retriever
        exposes them, its ``n_docs`` attribute and its config.
        """
        self.n_docs = n_docs
        self.model.config.n_docs = n_docs
        if hasattr(self.retriever, "n_docs"):
            self.retriever.n_docs = n_docs
        if hasattr(self.retriever, "config"):
            self.retriever.config.n_docs = n_docs

    def generate_answers(self, questions, batch_size=4):
        """Generate one answer string per question, batch by batch.

        Args:
            questions: sequence of question strings.
            batch_size: number of questions tokenized and generated together.

        Returns:
            List of decoded answers (special tokens stripped), in the same
            order as ``questions``. Empty input yields an empty list.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A negative step would make range() empty and silently return []
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.model.eval()
        all_answers = []

        for i in tqdm(range(0, len(questions), batch_size), desc="Generating"):
            # list() so a tuple or other sequence slice is passed as a plain list
            batch_questions = list(questions[i: i + batch_size])

            # Tokenize
            inputs = self.tokenizer(
                batch_questions,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(device)

            # Generate (greedy/beam search only, no gradient tracking)
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    n_docs=self.n_docs,
                    num_beams=self.num_beams,
                    do_sample=False,
                    max_new_tokens=self.max_new_tokens,
                    min_length=1,
                    early_stopping=True,
                )

            # Decode
            batch_answers = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            all_answers.extend(batch_answers)

        return all_answers

In [6]:
def run_experiment(task_name, model_manager, num_samples=None, results_path=None):
    """Evaluate a model on one QA benchmark and report Exact Match.

    Args:
        task_name: "Natural Questions" (loads ``nq_open``, split "validation")
            or "WebQuestions" (loads ``stanfordnlp/web_questions``, split "test").
        model_manager: object exposing ``generate_answers(questions, batch_size=...)``
            that returns one prediction string per question.
        num_samples: optional cap on the number of questions; ``None`` (or 0)
            uses every row of the split.
        results_path: optional path of the summary file. By default the name
            is built from the task and, when the manager has a ``rag_type``
            attribute, that type as well, so that runs of different model
            variants on the same task do not overwrite each other.

    Returns:
        The Exact Match score as a percentage.

    Raises:
        ValueError: if ``task_name`` is not a supported benchmark.
    """
    supported_tasks = ("Natural Questions", "WebQuestions")
    if task_name not in supported_tasks:
        # Fail early instead of evaluating on zero samples
        raise ValueError(
            f"Unknown task_name {task_name!r}; expected one of {supported_tasks}"
        )

    print(f"\n{'='*40}")
    print(f"STARTING EXPERIMENT: {task_name}")
    print(f"{'='*40}")

    # DATA LOADING
    questions = []
    answers = []

    if task_name == "Natural Questions":
        # Load NQ Open (Simplified version)
        dataset = load_dataset("nq_open", split="validation")

        print("Processing Natural Questions dataset...")
        for row in tqdm(dataset, desc="Loading Data"):
            q = row['question']
            ans_list = row['answer']
            # Rows without any reference answer are skipped
            if ans_list:
                questions.append(q)
                answers.append(ans_list)

            if num_samples and len(questions) >= num_samples:
                break

    else:
        # Load WebQuestions
        dataset = load_dataset("stanfordnlp/web_questions", split="test")

        print("Processing WebQuestions dataset...")
        for row in tqdm(dataset, desc="Loading Data"):
            questions.append(row['question'])
            answers.append(row['answers'])

            if num_samples and len(questions) >= num_samples:
                break

    print(f"Loaded {len(questions)} TOTAL samples for {task_name}.")

    # EXECUTION
    predictions = model_manager.generate_answers(questions, batch_size=8)

    # EVALUATION
    score = calculate_em(predictions, answers)

    print(f"\nRESULTS for {task_name}:")
    print(f"Exact Match (EM): {score:.2f}%")

    # Save results to a file; the default name is keyed by task AND model
    # variant so back-to-back runs on one task keep separate files
    if results_path is None:
        stem = task_name.replace(' ', '_')
        rag_type = getattr(model_manager, "rag_type", None)
        if rag_type:
            stem = f"{stem}_{rag_type}"
        results_path = f"{stem}_results.txt"

    with open(results_path, "w", encoding="utf-8") as f:
        f.write(f"Task: {task_name}\n")
        model_name = getattr(model_manager, "model_name", None)
        if model_name:
            f.write(f"Model: {model_name}\n")
        f.write(f"Samples: {len(questions)}\n")
        f.write(f"EM Score: {score:.2f}%\n")

    return score


In [None]:
# CONFIGURATION
# USE_DUMMY: True  -> runs fast (prototype), score will be ~0%
#            False -> downloads the 75GB index
USE_DUMMY = False

# NUM_SAMPLES: None runs EVERYTHING; an integer (e.g., 50) is for testing
NUM_SAMPLES = None

INDEX_NAME = "exact"


def _build_manager(checkpoint, variant, k_docs):
    """Create a RAGModelManager with the notebook-wide retrieval and decoding settings."""
    return RAGModelManager(
        model_name=checkpoint,
        rag_type=variant,
        use_dummy=USE_DUMMY,
        index_name=INDEX_NAME,
        n_docs=k_docs,
        num_beams=1,
        max_new_tokens=16
    )


# 1. Baseline reproduction: NQ checkpoints from HuggingFace
#    (sequence model uses 50 retrieved docs, token model uses 15)
rag_seq_nq = _build_manager("facebook/rag-sequence-nq", "sequence", 50)
rag_tok_nq = _build_manager("facebook/rag-token-nq", "token", 15)


# 2. Baseline reproduction: optional self-made fine-tuned WebQuestions
#    checkpoints, loaded from local folders when a directory is configured
FT_WQ_SEQ_DIR = None
FT_WQ_TOKEN_DIR = None

# Each manager stays None unless its checkpoint directory is set
rag_seq_ft_wq = _build_manager(FT_WQ_SEQ_DIR, "sequence", 50) if FT_WQ_SEQ_DIR else None
rag_tok_ft_wq = _build_manager(FT_WQ_TOKEN_DIR, "token", 15) if FT_WQ_TOKEN_DIR else None

In [None]:
print("Experiment: Natural Questions")

# Evaluate the sequence model first, then the token model, on the same task
nq_score_seq, nq_score_tok = (
    run_experiment("Natural Questions", manager, num_samples=NUM_SAMPLES)
    for manager in (rag_seq_nq, rag_tok_nq)
)

In [None]:
print("Experiment: WebQuestions")

# Sequence model: use the fine-tuned checkpoint when loaded, else the NQ baseline
if not rag_seq_ft_wq:
    print("No FT_WQ_SEQ_DIR set -> WE evaluate base NQ sequence model on WQ (we expect low EM).")
wb_score_seq = run_experiment("WebQuestions", rag_seq_ft_wq or rag_seq_nq, num_samples=NUM_SAMPLES)

# Token model: same fallback rule
if not rag_tok_ft_wq:
    print("No FT_WQ_TOKEN_DIR set -> We evaluate base NQ token model on WQ (we expect low EM).")
wb_score_tok = run_experiment("WebQuestions", rag_tok_ft_wq or rag_tok_nq, num_samples=NUM_SAMPLES)
