In [2]:
!pip install torch faiss-cpu tqdm

!pip install pyarrow==14.0.1 datasets==2.14.6 transformers==4.35.2 faiss-cpu



In [3]:

import torch
import logging
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset
from tqdm import tqdm
import numpy as np

# Choose the compute device once; the model and every input batch are moved to it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on: {device}")

# Keep only ERROR-level messages from transformers so progress output stays readable.
logging.getLogger("transformers").setLevel(logging.ERROR)

  _torch_pytree._register_pytree_node(


Running on: cuda


  _torch_pytree._register_pytree_node(


In [4]:
import string
import re

def normalize_answer(s):
    """Canonicalise an answer string for exact-match comparison.

    Applied in order: lowercase, strip every character in
    ``string.punctuation``, replace the whole words "a", "an" and "the" with a
    space, then collapse whitespace runs to single spaces and trim both ends.
    """
    punctuation = frozenset(string.punctuation)
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def calculate_em(predictions, references):
    """Corpus-level Exact Match (EM) as a percentage in [0, 100].

    A prediction counts as a hit when it matches ANY of its references under
    ``exact_match_score`` (one question can have several valid answers).

    Args:
        predictions: list of predicted answer strings.
        references: list aligned with ``predictions``; each item is a list of
            acceptable answer strings. A bare string is treated as a single
            reference rather than being iterated character by character.

    Returns:
        ``100 * hits / len(predictions)``, or 0.0 when ``predictions`` is empty.

    Raises:
        ValueError: if ``predictions`` and ``references`` differ in length.
    """
    # zip() would silently drop the unmatched tail while we still divide by
    # len(predictions), skewing the score -- fail loudly instead.
    if len(predictions) != len(references):
        raise ValueError(
            f"predictions ({len(predictions)}) and references "
            f"({len(references)}) must have the same length"
        )
    if not predictions:
        return 0.0

    hits = 0
    for pred, refs in zip(predictions, references):
        if isinstance(refs, str):
            refs = [refs]
        # A hit if the prediction matches ANY of the valid references.
        if any(exact_match_score(pred, gt) for gt in refs):
            hits += 1

    return 100 * (hits / len(predictions))

In [5]:
class RAGModelManager:
    """Load a HuggingFace RAG-Sequence checkpoint and answer questions with it.

    Bundles the ``RagTokenizer``, ``RagRetriever`` and
    ``RagSequenceForGeneration`` loaded from the same ``model_name``.
    """

    def __init__(self, model_name="facebook/rag-sequence-nq", use_dummy=True, device=None):
        """Load tokenizer, retriever and generator.

        Args:
            model_name: HF hub id of the RAG-Sequence checkpoint.
            use_dummy: passed to ``RagRetriever.from_pretrained`` as
                ``use_dummy_dataset``. False pulls the full retrieval index
                (a very large download with high RAM requirements).
            device: device string that the generator and every input batch are
                moved to. None (default) picks "cuda" when available, else
                "cpu" -- the same rule as the notebook-level ``device`` -- so
                the class does not depend on a global defined in another cell.
        """
        self.model_name = model_name
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading Model: {model_name}...")

        # Tokenizer
        self.tokenizer = RagTokenizer.from_pretrained(model_name)

        # Retriever
        self.retriever = RagRetriever.from_pretrained(
            model_name,
            index_name="exact",
            use_dummy_dataset=use_dummy,
        )

        # Sequence generator. It is handed the retriever so that generate() can
        # fetch passages from input_ids alone (generate_answers passes no context).
        self.model = RagSequenceForGeneration.from_pretrained(
            model_name,
            retriever=self.retriever,
        ).to(self.device)

        print("Model loaded successfully.")

    def generate_answers(self, questions, batch_size=4, num_beams=2, min_length=1, max_length=50):
        """Generate answers for ``questions`` in batches, without gradients.

        Args:
            questions: list of question strings.
            batch_size: number of questions tokenized and generated per call.
            num_beams, min_length, max_length: forwarded to ``model.generate``.
                Defaults equal the values previously hard-coded here.

        Returns:
            list[str]: decoded generations (special tokens stripped), batch by
            batch in input order; [] for empty input. Presumably one answer per
            question given generate's default of a single returned sequence --
            confirm if ``num_return_sequences`` is ever changed.
        """
        self.model.eval()
        all_answers = []

        for start in tqdm(range(0, len(questions), batch_size), desc="Generating"):
            batch_questions = questions[start:start + batch_size]

            inputs = self.tokenizer(
                batch_questions,
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(self.device)

            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    num_beams=num_beams,
                    min_length=min_length,
                    max_length=max_length,
                )

            all_answers.extend(
                self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            )

        return all_answers

        

In [6]:
# Loading recipe per task: HF dataset id, split, answer column, and whether rows
# whose answer list is empty are dropped (only NQ drops them).
_QA_TASKS = {
    "Natural Questions": {  # NQ Open (simplified version)
        "path": "nq_open",
        "split": "validation",
        "answer_key": "answer",
        "skip_unanswered": True,
    },
    "WebQuestions": {
        "path": "stanfordnlp/web_questions",
        "split": "test",
        "answer_key": "answers",
        "skip_unanswered": False,
    },
}


def _resolve_sample_limit(num_samples):
    """Map the ``num_samples`` knob to a row cap; None means "no cap".

    None, 0 and negative values (the documented -1) all mean "full dataset".
    An explicit comparison is needed: ``if num_samples and ...`` treats -1 as
    truthy, which capped loading at a single question.
    """
    if num_samples is None or num_samples <= 0:
        return None
    return num_samples


def _collect_qa_pairs(rows, answer_key, skip_unanswered, limit):
    """Gather aligned (questions, answers) lists from ``rows``, stopping at ``limit``."""
    questions, answers = [], []
    for row in rows:
        refs = row[answer_key]
        if skip_unanswered and not refs:
            continue
        questions.append(row["question"])
        answers.append(refs)
        if limit is not None and len(questions) >= limit:
            break
    return questions, answers


def run_experiment(task_name, model_manager, num_samples=None, batch_size=8):
    """Load a QA benchmark, generate answers with ``model_manager`` and score EM.

    Args:
        task_name: "Natural Questions" (nq_open validation) or "WebQuestions"
            (stanfordnlp/web_questions test).
        model_manager: object exposing ``generate_answers(questions, batch_size=...)``,
            e.g. ``RAGModelManager``.
        num_samples: maximum number of questions to evaluate. None, 0 or a
            negative value (e.g. -1) runs the FULL dataset.
        batch_size: forwarded to ``generate_answers``.

    Returns:
        Exact-match score in percent. A summary file named
        ``<task_name with spaces -> underscores>_results.txt`` is also written
        to the working directory.

    Raises:
        ValueError: for an unknown ``task_name``, or if no questions were loaded.
    """
    spec = _QA_TASKS.get(task_name)
    if spec is None:
        raise ValueError(
            f"Unknown task {task_name!r}; expected one of {sorted(_QA_TASKS)}"
        )

    print(f"\n{'='*40}")
    print(f"STARTING EXPERIMENT: {task_name}")
    print(f"{'='*40}")

    # --- DATA LOADING ---
    dataset = load_dataset(spec["path"], split=spec["split"])
    print(f"Processing {task_name} dataset...")
    questions, answers = _collect_qa_pairs(
        tqdm(dataset, desc="Loading Data"),
        spec["answer_key"],
        spec["skip_unanswered"],
        _resolve_sample_limit(num_samples),
    )
    print(f"Loaded {len(questions)} TOTAL samples for {task_name}.")
    if not questions:
        raise ValueError(f"No samples loaded for {task_name}; cannot compute EM.")

    # --- EXECUTION ---
    # The manager handles batching and its own progress bar.
    predictions = model_manager.generate_answers(questions, batch_size=batch_size)

    # --- EVALUATION ---
    score = calculate_em(predictions, answers)

    print(f"\nRESULTS for {task_name}:")
    print(f"Exact Match (EM): {score:.2f}%")

    # Persist a small summary for reproducibility.
    with open(f"{task_name.replace(' ', '_')}_results.txt", "w", encoding="utf-8") as f:
        f.write(f"Task: {task_name}\n")
        f.write(f"Samples: {len(questions)}\n")
        f.write(f"EM Score: {score:.2f}%\n")

    return score

In [8]:
# --- CONFIGURATION ---
# USE_DUMMY = True  -> dummy retrieval index: runs fast (prototype only), EM ~0%.
# USE_DUMMY = False -> full retrieval index. In the run logged below this meant
#                      157 passage shards of ~0.54 GB each (~85 GB) plus a 38 GB
#                      index file; it needs a lot of RAM (>64 GB suggested).
#                      Observed EM in that run: 37.45% on NQ, 15.99% on WebQuestions.
#                      The ~44.5% NQ figure (presumably the RAG paper's reported
#                      number) was NOT reached with these generation settings.
USE_DUMMY = False

# None -> run the FULL evaluation sets. This is slow: in the run logged below,
#         generation took ~1h50m for NQ (3,610 questions) and ~59m for
#         WebQuestions (2,032 questions) on CUDA.
# int  -> evaluate only the first N (answered) questions of each set, e.g. 50 for testing.
NUM_SAMPLES = None

# 1. Initialize Model
# Note: with USE_DUMMY=False, a crash here is most likely the kernel running out of RAM.
rag = RAGModelManager(model_name="facebook/rag-sequence-nq", use_dummy=USE_DUMMY)

# 2. Natural Questions validation set (all of it when NUM_SAMPLES is None)
nq_score = run_experiment("Natural Questions", rag, num_samples=NUM_SAMPLES)

# 3. WebQuestions test set (all of it when NUM_SAMPLES is None).
# NOTE(review): the checkpoint name suggests NQ fine-tuning, so this is presumably a
# transfer evaluation rather than a WebQuestions-fine-tuned setup -- confirm.
wb_score = run_experiment("WebQuestions", rag, num_samples=NUM_SAMPLES)

Loading Model: facebook/rag-sequence-nq...




Downloading data files:   0%|          | 0/157 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/546M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/546M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/546M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/546M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/537M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/530M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/538M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/546M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/545M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/544M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/543M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/542M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/157 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/21015300 [00:00<?, ? examples/s]

  table = cls._concat_blocks(blocks, axis=0)


Downloading data files:   0%|          | 0/157 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/157 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/21015300 [00:00<?, ? examples/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/38.0G [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Model loaded successfully.

STARTING EXPERIMENT: Natural Questions
Processing Natural Questions dataset...


Loading Data: 100%|██████████| 3610/3610 [00:00<00:00, 36902.52it/s]


Loaded 3610 TOTAL samples for Natural Questions.


Generating: 100%|██████████| 452/452 [1:49:39<00:00, 14.56s/it]



RESULTS for Natural Questions:
Exact Match (EM): 37.45%

STARTING EXPERIMENT: WebQuestions


Downloading readme: 0.00B [00:00, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/260k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/142k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/3778 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2032 [00:00<?, ? examples/s]

Processing WebQuestions dataset...


Loading Data: 100%|██████████| 2032/2032 [00:00<00:00, 35491.37it/s]


Loaded 2032 TOTAL samples for WebQuestions.


Generating: 100%|██████████| 254/254 [59:15<00:00, 14.00s/it]


RESULTS for WebQuestions:
Exact Match (EM): 15.99%



