In [None]:
!pip install torch faiss-cpu faiss-gpu-cu12 tqdm numpy==1.26.4

!pip install pyarrow==14.0.1 datasets==2.14.6 transformers==4.35.2 accelerate==0.24.1

!pip -q install ipywidgets

In [None]:
pip uninstall faiss-cpu -y

In [None]:
import torch
import logging
import string
import re
import time
import matplotlib.pyplot as plt
from transformers import (
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration
)
from datasets import load_dataset
from tqdm import tqdm

# ==========================================
# 1. SETUP & HELPER FUNCTIONS
# ==========================================

# Set device
# Use the GPU when PyTorch reports that CUDA is available, otherwise the CPU.
# `device` is a module-level name that RAGModelManager reads when it moves
# the model and the tokenized inputs.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")
# Only let log records of level ERROR or above through the "transformers" logger.
logging.getLogger("transformers").setLevel(logging.ERROR)

In [None]:
def normalize_answer(s):
    """Return a canonical form of ``s`` used when comparing answers.

    Steps, applied in this order:
      1. lower-case the text;
      2. delete every character listed in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs into single spaces and trim both ends.
    """
    lowered = s.lower()

    # A translation table whose third argument lists characters to delete.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    without_punctuation = lowered.translate(strip_punctuation)

    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)

    # split() with no argument drops leading/trailing and repeated whitespace.
    return ' '.join(without_articles.split())

def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth

def calculate_em(predictions, references):
    """Compute the exact-match (EM) percentage for a set of predictions.

    Args:
        predictions: Sized sequence of predicted answer strings.
        references: Iterable aligned with ``predictions``; each item is an
            iterable of acceptable ground-truth strings for that prediction.

    Returns:
        A number between 0 and 100: the share of predictions that match at
        least one of their references under ``exact_match_score``. Returns
        0.0 when ``predictions`` is empty instead of dividing by zero.

    Note:
        Pairs are formed with ``zip``, so predictions beyond the end of
        ``references`` are never compared, yet they still count in the
        denominator (i.e. they are scored as misses).
    """
    total = len(predictions)
    if total == 0:
        # Nothing to score; avoid ZeroDivisionError below.
        return 0.0

    hits = sum(
        1
        for pred, refs in zip(predictions, references)
        if any(exact_match_score(pred, gt) for gt in refs)
    )
    return 100 * (hits / total)

In [None]:
# ==========================================
# 2. RAG MODEL MANAGER
# ==========================================

class RAGModelManager:
    """Bundle a RAG tokenizer, retriever and generation model.

    The manager loads the three pretrained components for ``model_name``,
    moves the model to the module-level ``device``, and exposes helpers to
    change the number of retrieved documents and to generate answers for a
    list of questions in batches.
    """

    # The only values of ``rag_type`` that map to a model class.
    _RAG_TYPES = ("token", "sequence")

    def __init__(
        self,
        model_name,
        rag_type,  # "sequence" or "token"
        n_docs_init=5,
        use_dummy=False,
        index_name="exact",
        num_beams=1,
        max_new_tokens=16,
    ):
        """Load tokenizer, retriever and model, then put the model in eval mode.

        Args:
            model_name: Identifier passed to every ``from_pretrained`` call.
            rag_type: ``"token"`` loads ``RagTokenForGeneration``;
                ``"sequence"`` loads ``RagSequenceForGeneration``.
            n_docs_init: Initial number of documents to retrieve per question.
            use_dummy: Forwarded to the retriever as ``use_dummy_dataset``.
            index_name: Forwarded to the retriever as ``index_name``.
            num_beams: Forwarded to ``generate`` on every call.
            max_new_tokens: Forwarded to ``generate`` on every call.

        Raises:
            ValueError: If ``rag_type`` is not ``"token"`` or ``"sequence"``.
        """
        # Reject unknown variants before any loading starts, rather than
        # silently falling back to the sequence model on a typo.
        if rag_type not in self._RAG_TYPES:
            raise ValueError(
                f"rag_type must be one of {self._RAG_TYPES}, got {rag_type!r}"
            )

        self.model_name = model_name
        self.rag_type = rag_type
        self.n_docs = n_docs_init
        self.num_beams = num_beams
        self.max_new_tokens = max_new_tokens

        print(f"Loading Model: {model_name} ({rag_type})...")
        self.tokenizer = RagTokenizer.from_pretrained(model_name)
        self.retriever = RagRetriever.from_pretrained(
            model_name,
            index_name=index_name,
            use_dummy_dataset=use_dummy
        )

        if rag_type == "token":
            model_cls = RagTokenForGeneration
        else:
            model_cls = RagSequenceForGeneration
        self.model = model_cls.from_pretrained(
            model_name, retriever=self.retriever
        ).to(device)

        # Initial config
        self.set_n_docs(n_docs_init)
        self.model.eval()

    def set_n_docs(self, n_docs: int):
        """Updates the number of retrieved documents (k) dynamically.

        The value is written to this manager, to the model config, and to the
        retriever (its ``n_docs`` attribute and its config) when those
        attributes exist.
        """
        self.n_docs = n_docs
        self.model.config.n_docs = n_docs
        if hasattr(self.retriever, "n_docs"):
            self.retriever.n_docs = n_docs
        if hasattr(self.retriever, "config"):
            self.retriever.config.n_docs = n_docs

    def generate_answers(self, questions, batch_size=4):
        """Generate one decoded answer string per generated sequence.

        Args:
            questions: Sliceable sequence of question strings.
            batch_size: Number of questions tokenized and generated together.

        Returns:
            List of decoded strings (special tokens removed), accumulated
            batch by batch in the order of ``questions``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1. A negative step
                would otherwise yield an empty range and silently return no
                answers.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.model.eval()
        all_answers = []

        # Batch processing
        for i in tqdm(range(0, len(questions), batch_size), desc=f"Generating (k={self.n_docs})"):
            batch_questions = questions[i: i + batch_size]

            inputs = self.tokenizer(
                batch_questions,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(device)

            # Inference only: no gradients are needed.
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    n_docs=self.n_docs,
                    num_beams=self.num_beams,
                    max_new_tokens=self.max_new_tokens
                )

            batch_answers = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            all_answers.extend(batch_answers)

        return all_answers

In [None]:
# ==========================================
# 3. EXPERIMENT CONFIGURATION
# ==========================================

# Settings
# USE_DUMMY is passed to RAGModelManager as `use_dummy`, which forwards it to
# the retriever's `use_dummy_dataset` argument.
USE_DUMMY = False  # Set False for real evaluation (downloads large index)
# A falsy value (None or 0) disables the sample cap in the loading loop below.
NUM_SAMPLES = None   # Use None for full dataset, or int (e.g., 100) for testing
K_VALUES = [3, 5, 7, 10, 20, 30] # The list of 'k' values to evaluate
# Number of questions handed to generate_answers per batch.
BATCH_SIZE = 8
# Both time budgets are in seconds (they are compared against time.time()
# differences); None disables the corresponding check.
MAX_LOOP_TIME = None        # budget for one model's whole k sweep, e.g. 1800 for 30 minutes
MAX_ITER_TIME = 8400        # budget for a single k (8400 s = 140 minutes), e.g. 300 for 5 minutes


# Load Data Once
print("Loading Natural Questions Validation Set...")
dataset = load_dataset("nq_open", split="validation")
# Parallel lists: questions[i] is paired with answers[i].
questions = []
answers = []

for row in dataset:
    questions.append(row['question'])
    # Stored unchanged; calculate_em later iterates each entry as the set of
    # accepted answers (presumably a list of strings; verify against the dataset schema).
    answers.append(row['answer'])
    # Stop as soon as NUM_SAMPLES rows are collected. A falsy NUM_SAMPLES
    # (None or 0) never triggers the break, so the whole split is kept.
    if NUM_SAMPLES and len(questions) >= NUM_SAMPLES:
        break

print(f"Loaded {len(questions)} samples.")

In [None]:
# ==========================================
# 4. RUN EVALUATION LOOP
# ==========================================

# Per-model results. "k" is filled in as each k finishes, so it always has
# the same length as "scores" even when a sweep stops early on a timeout.
results = {
    "rag-token": {"scores": [], "k": []},
    "rag-sequence": {"scores": [], "k": []}
}


def run_k_sweep(
    manager,
    label,
    result_entry,
    k_values,
    questions,
    references,
    batch_size,
    max_loop_time=None,
    max_iter_time=None,
):
    """Evaluate one model over several retrieval sizes and record EM scores.

    Args:
        manager: Object providing ``set_n_docs(k)`` and
            ``generate_answers(questions, batch_size=...)``.
        label: Human-readable model name used in progress messages.
        result_entry: Dict with list values under "scores" and "k"; one score
            and its k are appended for every completed iteration.
        k_values: Retrieval sizes to evaluate, in order.
        questions: Questions passed to ``manager.generate_answers``.
        references: Ground truths passed to ``calculate_em``.
        batch_size: Batch size passed to ``manager.generate_answers``.
        max_loop_time: Budget in seconds for the whole sweep, checked before
            each iteration starts; None disables the check.
        max_iter_time: Budget in seconds for one iteration, checked after the
            iteration has finished (its score is kept, later k values are
            skipped); None disables the check.
    """
    loop_start_time = time.time()

    for k in k_values:
        iter_start_time = time.time()

        # ---- Global loop timeout ----
        elapsed_loop_time = iter_start_time - loop_start_time
        if max_loop_time is not None and elapsed_loop_time > max_loop_time:
            print(
                f"[INTERRUPTED] Loop exceeded max time "
                f"({elapsed_loop_time:.1f}s > {max_loop_time}s)"
            )
            break

        print(f"\nProcessing k={k} for {label}...")

        manager.set_n_docs(k)
        predictions = manager.generate_answers(
            questions, batch_size=batch_size
        )
        score = calculate_em(predictions, references)
        result_entry["scores"].append(score)
        result_entry["k"].append(k)

        iter_time = time.time() - iter_start_time
        print(
            f"{label} (k={k}) EM: {score:.2f}% "
            f"| Iter time: {iter_time:.1f}s"
        )

        # ---- Per-iteration timeout (optional) ----
        if max_iter_time is not None and iter_time > max_iter_time:
            print(
                f"[INTERRUPTED] {label} iteration exceeded "
                f"max iteration time ({iter_time:.1f}s > {max_iter_time}s)"
            )
            break


# Settings shared by both sweeps, taken from the configuration constants.
sweep_settings = dict(
    k_values=K_VALUES,
    questions=questions,
    references=answers,
    batch_size=BATCH_SIZE,
    max_loop_time=MAX_LOOP_TIME,
    max_iter_time=MAX_ITER_TIME,
)

print("\n=== Evaluating RAG-Token ===")
rag_token_mgr = RAGModelManager(
    model_name="facebook/rag-token-nq",
    rag_type="token",
    use_dummy=USE_DUMMY
)
run_k_sweep(rag_token_mgr, "RAG-Token", results["rag-token"], **sweep_settings)

# Free memory before next model
del rag_token_mgr
torch.cuda.empty_cache()

print("\n=== Evaluating RAG-Sequence ===")
rag_seq_mgr = RAGModelManager(
    model_name="facebook/rag-sequence-nq",
    rag_type="sequence",
    use_dummy=USE_DUMMY
)
run_k_sweep(rag_seq_mgr, "RAG-Sequence", results["rag-sequence"], **sweep_settings)


In [None]:
# ==========================================
# 5. PLOTTING RESULTS
# ==========================================

fig, ax = plt.subplots(figsize=(10, 6))

# (results key, legend label, marker, line style, colour) for each curve.
series_specs = [
    ("rag-token", "RAG-Token", "o", "-", "blue"),
    ("rag-sequence", "RAG-Sequence", "s", "--", "green"),
]

for result_key, label, marker, linestyle, color in series_specs:
    scores = results[result_key]["scores"]
    # A sweep that stopped early has fewer scores than k values, and
    # plotting x and y of different lengths raises an error. Scores are
    # appended in k order, so pair them with the leading k values only.
    k_evaluated = list(results[result_key]["k"])[:len(scores)]
    ax.plot(
        k_evaluated,
        scores,
        marker=marker,
        label=label,
        linestyle=linestyle,
        color=color
    )

ax.set_title(f'RAG Performance on NQ (Exact Match) vs Retrieved Docs (k)\n(Samples: {len(questions)})')
ax.set_xlabel('Number of Retrieved Documents (k)')
ax.set_ylabel('Exact Match (EM) Score %')
ax.grid(True, linestyle='--', alpha=0.7)
ax.legend()
ax.set_xticks(K_VALUES)  # Ensure all k values are shown on x-axis

# Save plot or show
fig.savefig("k_retrieval_performance.png")
print("Plot saved as 'k_retrieval_performance.png'")
plt.show()