# eRAG: Evaluating Retrieval Quality in Retrieval-Augmented Generation

This notebook is an attempt to reproduce the core pipeline used to obtain the results in the paper titled "[Evaluating Retrieval Quality in Retrieval-Augmented Generation](https://arxiv.org/abs/2404.13781)."

## Downloading the NQ dataset

In [None]:
# Create a directory for data if it doesn't exist
!mkdir -p data

# Download the KILT NQ dev, train, and test (without answers) splits
!wget -O data/nq-dev-kilt.jsonl http://dl.fbaipublicfiles.com/KILT/nq-dev-kilt.jsonl
!wget -O data/nq-train-kilt.jsonl http://dl.fbaipublicfiles.com/KILT/nq-train-kilt.jsonl
!wget -O data/nq-test_without_answers-kilt.jsonl http://dl.fbaipublicfiles.com/KILT/nq-test_without_answers-kilt.jsonl
#!wget -O data/fever-dev-kilt.jsonl http://dl.fbaipublicfiles.com/KILT/fever-dev-kilt.jsonl


# Display a few lines of the file to verify the download
!head -n 5 data/nq-dev-kilt.jsonl
!head -n 5 data/nq-train-kilt.jsonl
!head -n 5 data/nq-test_without_answers-kilt.jsonl
#!head -n 5 data/fever-dev-kilt.jsonl

--2025-04-01 14:41:25--  http://dl.fbaipublicfiles.com/KILT/nq-dev-kilt.jsonl
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 108.157.254.15, 108.157.254.102, 108.157.254.124, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|108.157.254.15|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7936566 (7.6M) [text/plain]
Saving to: ‘data/nq-dev-kilt.jsonl’


2025-04-01 14:41:25 (135 MB/s) - ‘data/nq-dev-kilt.jsonl’ saved [7936566/7936566]

--2025-04-01 14:41:25--  http://dl.fbaipublicfiles.com/KILT/nq-train-kilt.jsonl
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 108.157.254.15, 108.157.254.102, 108.157.254.124, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|108.157.254.15|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 51895886 (49M) [text/plain]
Saving to: ‘data/nq-train-kilt.jsonl’


2025-04-01 14:41:26 (213 MB/s) - ‘data/nq-train-kilt.jsonl’ saved [51895886/51895886]


## Wikipedia preprocessing using the KILT source

In [None]:
import json
import requests


# Splits text into non-overlapping passages of at most max_words words
def split_into_passages(text, max_words=100):
    """Splits text into passages of up to max_words words (without overlap)."""
    words = text.split()
    passages = []
    for i in range(0, len(words), max_words):
        chunk = " ".join(words[i:i+max_words])
        passages.append(chunk)
    return passages


# Splits each paragraph of a KILT page into passages and prefixes every passage with the page title
def process_kilt_page(page_json, max_words=100):
    """
    Given a KILT knowledge source page (record),
    segments the "text" field (list of paragraphs) into passages of max_words.
    Returns a list of documents combining the title and passage.
    """
    title = page_json.get("wikipedia_title", "").strip()
    paragraphs = page_json.get("text", [])
    docs = []
    for para in paragraphs:
        para = para.strip()
        if not para:
            continue
        passages = split_into_passages(para, max_words=max_words)
        for p in passages:
            doc = f"{title} [SEP] {p}"
            docs.append(doc)
    return docs


# Output file (the KILT knowledge source itself is streamed from its URL below)
output_file = 'wikipedia_passages_sample.jsonl'

# Process 1/1000 of all the records for demonstration (the full file is too large: 5903530 records)
num_records_to_process = 5903

url = "http://dl.fbaipublicfiles.com/KILT/kilt_knowledgesource.json"

# Stream the file line by line so that only the first records are downloaded
with requests.get(url, stream=True) as r, open(output_file, 'w', encoding='utf-8') as fout:
    r.raise_for_status()  # fail early on an HTTP error instead of parsing an error page
    for i, line in enumerate(r.iter_lines()):
        if i >= num_records_to_process:
            print("Record limit reached")
            break
        try:
            page = json.loads(line)
        except json.JSONDecodeError as e:
            print(f"Error decoding record {i}: {e}")
            continue
        docs = process_kilt_page(page, max_words=100)
        for doc in docs:
            # Save each document as JSON Lines
            json.dump({"document": doc}, fout)
            fout.write("\n")

print(f"Processed {num_records_to_process} records. Output saved in '{output_file}'.")

# Display the first 5 lines of the output file
!head -n 5 {output_file}

Record limit reached
Processed 5903 records. Output saved in 'wikipedia_passages_sample.jsonl'.
{"document": "A [SEP] A"}
{"document": "A [SEP] A (named , plural \"As\", \"A's\", \"a\"s, \"a's\" or \"aes\") is the first letter and the first vowel of the modern English alphabet and the ISO basic Latin alphabet. It is similar to the Ancient Greek letter alpha, from which it derives. The uppercase version consists of the two slanting sides of a triangle, crossed in the middle by a horizontal bar. The lowercase version can be written in two forms: the double-storey a and single-storey \u0251. The latter is commonly used in handwriting and fonts based on it, especially fonts intended to be read by children, and is also"}
{"document": "A [SEP] found in italic type."}
{"document": "A [SEP] In the English grammar, \"a\", and its variant \"an\", is an indefinite article."}
{"document": "A [SEP] Section::::History."}
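The preprocessing above cuts every paragraph into non-overlapping windows of at most 100 whitespace-separated words and prefixes each window with the page title. Windows never span two paragraphs, which is why short fragments such as `A [SEP] found in italic type.` appear. The self-contained sketch below reproduces only the windowing arithmetic on a hypothetical 250-word paragraph (the words `w0` to `w249` and the title `Example title` are made up for illustration):

```python
def split_into_passages(text, max_words=100):
    """Non-overlapping windows of at most max_words whitespace-separated words."""
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

# Hypothetical 250-word paragraph: expect windows of 100, 100 and 50 words.
paragraph = " ".join(f"w{i}" for i in range(250))
passages = split_into_passages(paragraph, max_words=100)
print([len(p.split()) for p in passages])  # [100, 100, 50]

# Each window is prefixed with the page title, as process_kilt_page does.
docs = [f"Example title [SEP] {p}" for p in passages]
print(docs[2][:40])
```

The last window of a paragraph is simply shorter; no padding or overlap is added.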


## Index building

After preprocessing (each raw-text page is split into title and text, and the text is cut into passages of at most 100 words), the preprocessed passages are indexed with BM25 (a lexical model) and Contriever (a dense model).

### BM25

In [None]:
!pip install rank_bm25 nltk

import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
import string
import numpy as np
import json
from rank_bm25 import BM25Okapi

Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25
Successfully installed rank_bm25-0.2.2


In [None]:
try:
    stopwords.words('english')
except LookupError:
    print("Downloading NLTK data ('stopwords', 'punkt')...")
    nltk.download('punkt_tab', quiet=True)
    nltk.download('punkt', quiet=True)
    nltk.download('stopwords', quiet=True)

stop_words = set(stopwords.words('english'))
stemmer = PorterStemmer()
punctuation_table = str.maketrans('', '', string.punctuation)

# Simple preprocessing removing stop words and using a stemmer
def preprocess_text(text):
    lowered = text.lower()
    no_punct = lowered.translate(punctuation_table)
    tokens = word_tokenize(no_punct)
    filtered_tokens = [word for word in tokens if word not in stop_words and word.isalnum()]
    stemmed_tokens = [stemmer.stem(word) for word in filtered_tokens]
    return stemmed_tokens

# Loading the collection documents
documents = []
input_file = "wikipedia_passages_sample.jsonl"
try:
    with open(input_file, "r", encoding="utf-8") as fin:
        for i, line in enumerate(fin):
            try:
                data = json.loads(line)
                documents.append(data["document"])
            except json.JSONDecodeError:
                print(f"Warning: Skipping invalid JSON on line {i+1}")
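except FileNotFoundError:
    # Raise an error so the notebook stops here with a clear message
    raise FileNotFoundError(f"Input file '{input_file}' not found. Run the preprocessing cell first.")

if not documents:
    raise RuntimeError("No documents loaded.")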


print("Preprocessing documents...")
tokenized_docs = [preprocess_text(doc) for doc in documents]
print(f"Sample preprocessed doc: {tokenized_docs[:10] if tokenized_docs else 'N/A'}")

print("Building BM25 index...")
bm25 = BM25Okapi(tokenized_docs)

def bm25_retrieve(query, bm25_index, original_docs, k=5):
    # Apply the same preprocessing to the query
    tokenized_query = preprocess_text(query)
    if not tokenized_query:
        print("Warning: Query resulted in empty tokens after preprocessing.")
        return [], [], []

    scores = bm25_index.get_scores(tokenized_query)

    top_n_indices = np.argsort(scores)[::-1][:k]

    top_docs = [original_docs[i] for i in top_n_indices]
    top_scores = [scores[i] for i in top_n_indices]

    return top_n_indices, top_docs, top_scores

Downloading NLTK data ('stopwords', 'punkt')...
Preprocessing documents...
Sample preprocessed doc: [['sep'], ['sep', 'name', 'plural', 'ae', 'first', 'letter', 'first', 'vowel', 'modern', 'english', 'alphabet', 'iso', 'basic', 'latin', 'alphabet', 'similar', 'ancient', 'greek', 'letter', 'alpha', 'deriv', 'uppercas', 'version', 'consist', 'two', 'slant', 'side', 'triangl', 'cross', 'middl', 'horizont', 'bar', 'lowercas', 'version', 'written', 'two', 'form', 'doublestorey', 'singlestorey', 'ɑ', 'latter', 'commonli', 'use', 'handwrit', 'font', 'base', 'especi', 'font', 'intend', 'read', 'children', 'also'], ['sep', 'found', 'ital', 'type'], ['sep', 'english', 'grammar', 'variant', 'indefinit', 'articl'], ['sep', 'sectionhistori'], ['sep', 'earliest', 'certain', 'ancestor', 'aleph', 'also', 'written', 'aleph', 'first', 'letter', 'phoenician', 'alphabet', 'consist', 'entir', 'conson', 'reason', 'also', 'call', 'abjad', 'distinguish', 'true', 'alphabet', 'turn', 'ancestor', 'aleph', 'may',
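`BM25Okapi` scores each passage as a sum, over query terms, of an IDF weight times a saturated, length-normalized term frequency. As a rough, self-contained illustration (not the `rank_bm25` source), here is the textbook Okapi BM25 formula with `k1=1.5` and `b=0.75`, which we believe match the `rank_bm25` defaults. `rank_bm25` uses a slightly different IDF (it floors negative values using an `epsilon` parameter), so absolute scores will not match. The helper name `bm25_scores` and the toy corpus are hypothetical, and tokens are left unstemmed for readability:

```python
import math
from collections import Counter

def bm25_scores(query_tokens, corpus_tokens, k1=1.5, b=0.75):
    """Textbook Okapi BM25 over pre-tokenized documents (non-negative IDF variant)."""
    n_docs = len(corpus_tokens)
    avgdl = sum(len(doc) for doc in corpus_tokens) / n_docs
    # Document frequency: in how many documents each term occurs at least once
    df = Counter(term for doc in corpus_tokens for term in set(doc))
    scores = []
    for doc in corpus_tokens:
        tf = Counter(doc)
        score = 0.0
        for term in query_tokens:
            if tf[term] == 0:
                continue
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            # k1 saturates repeated terms; b penalizes longer-than-average documents
            norm = tf[term] + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * tf[term] * (k1 + 1) / norm
        scores.append(score)
    return scores

corpus = [
    ["gravitational", "time", "dilation", "clock"],
    ["pupil", "dilation", "surprise"],
    ["letter", "alphabet", "latin"],
]
scores = bm25_scores(["gravitational", "time", "dilation"], corpus)
print(scores)  # first passage highest, second small but positive, third 0.0
```

A passage that matches a single query term (here `dilation`) still receives a positive score, which is consistent with the `Surprise (emotion)` passage about pupil dilation ranking second in the BM25 example below.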

In [None]:
# BM25 example usage
query = "gravitational time dilation"
print(f"\nSearching for: '{query}'")
indices, docs_retrieved, scores_retrieved = bm25_retrieve(query, bm25, documents, k=3)

print("Top matching documents:")
for idx, doc, score in zip(indices, docs_retrieved, scores_retrieved):
    print(f"  - Index: {idx}, Score: {score:.4f}\n    Text: {doc[:200]}...")


Searching for: 'gravitational time dilation'
Top matching documents:
  - Index: 9274, Score: 15.1777
    Text: Muon [SEP] In the Rossi–Hall experiment (1941), muons were used to observe the time dilation (or, alternatively, length contraction) predicted by special relativity, for the first time....
  - Index: 139877, Score: 14.0688
    Text: Surprise (emotion) [SEP] Pupil dilation and constriction can determine the valence of surprise from the action to the reaction of the individual. Positive valence to surprise is shown through a dilati...
  - Index: 266, Score: 13.2422
    Text: International Atomic Time [SEP] In the 1970s, it became clear that the clocks participating in TAI were ticking at different rates due to gravitational time dilation, and the combined TAI scale theref...
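`bm25_retrieve` ranks every passage with a full `np.argsort`, which costs O(n log n) per query even though only the top `k` are needed. `np.argpartition` can select those `k` in linear time, after which only they need sorting. A minimal sketch, assuming a hypothetical helper `top_k_desc` and random numbers standing in for real BM25 scores:

```python
import numpy as np

def top_k_desc(scores, k):
    """Indices of the k largest scores, highest first, without sorting the whole array."""
    scores = np.asarray(scores)
    k = min(k, scores.shape[0])
    if k <= 0:
        return np.array([], dtype=np.intp)
    # argpartition moves the k largest values into the last k slots in linear time...
    candidates = np.argpartition(scores, -k)[-k:]
    # ...then only those k candidates are sorted, in descending order.
    return candidates[np.argsort(scores[candidates])[::-1]]

rng = np.random.default_rng(0)
scores = rng.random(216_122)              # random stand-ins for BM25 scores
fast = top_k_desc(scores, 5)
slow = np.argsort(scores)[::-1][:5]       # the approach used in bm25_retrieve
print(np.array_equal(fast, slow))         # True
```

With tied scores the two approaches may order the tied passages differently, which does not matter for retrieval.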


### Contriever

In [None]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 30.7/30.7 MB 46.5 MB/s eta 0:00:00
Installing collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


In [None]:
import torch
import faiss
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Model Loading
print("Loading model and tokenizer...")
model_name = "facebook/contriever"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Device setup (use GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model moved to device: {device}")

def encode_documents(docs, batch_size=16, max_length=256):
    """
    Encodes a list of documents into L2-normalized embeddings using Contriever
    (mean pooling over non-padding tokens).
    """
    all_embeddings = []
    model.eval()
    with torch.no_grad():
        for i in range(0, len(docs), batch_size):
            batch_docs = docs[i:i+batch_size]
            # Tokenize and move tensors to the correct device
            inputs = tokenizer(batch_docs, padding=True, truncation=True, return_tensors="pt", max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Get model outputs
            outputs = model(**inputs)

            # Perform mean pooling (weighted by attention mask)
            attention_mask = inputs["attention_mask"]
            token_embeddings = outputs.last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
            sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
            mean_embeddings = sum_embeddings / sum_mask

            # L2-normalize so that the inner product used by the Faiss index equals cosine similarity
            normalized_embeddings = F.normalize(mean_embeddings, p=2, dim=1)

            # Move embeddings to CPU and convert to NumPy
            all_embeddings.append(normalized_embeddings.cpu().numpy())

    # Concatenate all batch embeddings into a single NumPy array
    embeddings = np.concatenate(all_embeddings, axis=0)
    return embeddings


print("Encoding documents...")
doc_embeddings = encode_documents(documents, batch_size=16)
print("Document embedding shape:", doc_embeddings.shape)

# --- Build Faiss Index using Inner Product ---
embedding_dim = doc_embeddings.shape[1]
# Use IndexFlatIP for Inner Product similarity
index = faiss.IndexFlatIP(embedding_dim)

# IndexFlatIP ranks by raw inner product and does not require unit vectors;
# normalizing makes its scores cosine similarities. encode_documents already
# returns unit vectors, so this in-place normalization is only a safeguard.
print("Normalizing embeddings for Faiss (redundant but safe)...")
faiss.normalize_L2(doc_embeddings)

print("Adding documents to Faiss index...")
index.add(doc_embeddings)
print(f"Faiss index built with {index.ntotal} documents using Inner Product (IP).")
# --- End Faiss Indexing ---


# Dense retrieval function using Contriever and Faiss (Inner Product)
def dense_retrieve(query, index, original_docs, k=5, max_length=256):
    """
    Retrieves the top-k documents for a query using the Contriever model
    and a Faiss inner-product index built over L2-normalized document embeddings.
    """
    model.eval()
    with torch.no_grad():
        # Tokenize query and move to device
        inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=max_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Get model outputs
        outputs = model(**inputs)

        # Perform the same mean pooling as for documents
        attention_mask = inputs["attention_mask"]
        token_embeddings = outputs.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
        mean_embedding = sum_embeddings / sum_mask

        # L2-normalize the query so that inner-product scores are cosine similarities
        normalized_embedding = F.normalize(mean_embedding, p=2, dim=1)

        # Move to CPU and convert to NumPy array
        query_embedding_np = normalized_embedding.cpu().numpy()

    # Redundant safeguard: the query embedding is already unit length
    faiss.normalize_L2(query_embedding_np)

    # Search the Faiss index
    distances, indices = index.search(query_embedding_np, k)

    top_docs = [original_docs[i] for i in indices[0]]

    # Return indices and distances (scores)
    return indices[0], top_docs, distances[0]

Loading model and tokenizer...
Model moved to device: cuda
Encoding documents...
Document embedding shape: (216122, 768)
Normalizing embeddings for Faiss (redundant but safe)...
Adding documents to Faiss index...
Faiss index built with 216122 documents using Inner Product (IP).
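`encode_documents` mean-pools the last hidden states over non-padding tokens only and then L2-normalizes the result. The NumPy sketch below mirrors that computation on a toy batch (the helper names and numbers are illustrative, not Contriever outputs) to show that a padded position, here filled with 99s, has no effect on the pooled vector:

```python
import numpy as np

def masked_mean_pool(token_embeddings, attention_mask):
    """Average token vectors per sequence, ignoring padded positions (mask == 0)."""
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq, 1)
    summed = (token_embeddings * mask).sum(axis=1)                    # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                    # (batch, 1)
    return summed / counts

def l2_normalize(x):
    """Scale each row to unit Euclidean length."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

# Toy batch: 2 sequences x 3 tokens x 2 dims; the last token of sequence 2 is padding.
token_embeddings = np.array([[[1.0, 0.0], [3.0, 0.0], [5.0, 2.0]],
                             [[2.0, 2.0], [4.0, 0.0], [99.0, 99.0]]])
attention_mask = np.array([[1, 1, 1],
                           [1, 1, 0]])

pooled = masked_mean_pool(token_embeddings, attention_mask)
print(pooled)                 # rows approx. [3, 0.667] and [3, 1]: the 99s are ignored
print(l2_normalize(pooled))   # same directions, unit length
```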


In [None]:
# Dense retrieval example usage
query = "gravitational time dilation"
print(f"\nSearching for query: '{query}'")
top_k_indices, docs_retrieved, top_k_scores = dense_retrieve(query, index, original_docs=documents, k=3)

print("Top matching documents:")
for i, doc, score in zip(top_k_indices, docs_retrieved, top_k_scores):
    print(f"  - Index: {i}, Score (Dot Product): {score:.4f}\n    Text: {doc[:200]}...")


Searching for query: 'gravitational time dilation'
Top matching documents:
  - Index: 9274, Score (Dot Product): 0.5993
    Text: Muon [SEP] In the Rossi–Hall experiment (1941), muons were used to observe the time dilation (or, alternatively, length contraction) predicted by special relativity, for the first time....
  - Index: 214834, Score (Dot Product): 0.5908
    Text: Geometrodynamics [SEP] the geometry of Yang--Mills and gravitational gauge theories,...
  - Index: 1137, Score (Dot Product): 0.5904
    Text: Astronomer [SEP] time....
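Because both document and query embeddings are unit length, the scores returned by `IndexFlatIP` (for example 0.5993 above) are cosine similarities. The sketch below checks this equivalence with a brute-force NumPy inner-product search, which produces the same ranking a flat inner-product index would; the random 768-dimensional vectors are stand-ins for Contriever embeddings. To our understanding, the original Contriever setup scores with the unnormalized dot product, so the normalization used in this notebook is a design choice that may change rankings.

```python
import numpy as np

def unit(x):
    """L2-normalize each row."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def exact_ip_search(doc_vecs, query_vecs, k):
    """Brute-force maximum inner-product search: the ranking a flat IP index returns."""
    scores = query_vecs @ doc_vecs.T              # (n_queries, n_docs)
    top = np.argsort(-scores, axis=1)[:, :k]      # best k documents per query
    return np.take_along_axis(scores, top, axis=1), top

rng = np.random.default_rng(0)
doc_vecs = rng.normal(size=(1000, 768))   # stand-ins for document embeddings
query_vec = rng.normal(size=(1, 768))     # stand-in for a query embedding

ip_scores, ip_idx = exact_ip_search(unit(doc_vecs), unit(query_vec), k=3)

# Cosine similarity computed directly on the unnormalized vectors
cos = (doc_vecs @ query_vec[0]) / (np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(query_vec[0]))
print(ip_idx[0], np.argsort(-cos)[:3])    # same three documents, same order
print(ip_scores[0], np.sort(cos)[::-1][:3])
```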


## Base RAG pipeline with Gemini as Generator

In [None]:
# Retrieval function that allows method selection
def retrieve_documents(query, method='BM25', k=5):
    if method == 'BM25':
        _, docs, _ = bm25_retrieve(query, bm25, documents, k)
        return docs
    elif method == 'dense':
        _, docs, _ = dense_retrieve(query, index, original_docs=documents, k=k)
        return docs
    else:
        raise ValueError("Unsupported retrieval method. Use 'BM25' or 'dense'.")


# Text generation with Gemini API
import os
import google.generativeai as genai

# Configure the Gemini API, reading the key from the GOOGLE_API_KEY environment variable
# instead of hard-coding it in the notebook
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

# Simple generator function
def text_generator(query, retrieved_docs):
    """Generates a response using Gemini with retrieved documents as context."""
    context = " ".join(retrieved_docs)
    prompt = f"Answer this question: {query} with the least amount of words possible (also only 1). Context: {context} answer: "
    response = genai.GenerativeModel('gemini-2.0-flash').generate_content(prompt)
    return response.text


# Main RAG pipeline function
def rag_pipeline(query, retrieval_method='BM25', k=5):
    # Retrieve documents using the selected method
    retrieved_docs = retrieve_documents(query, method=retrieval_method, k=k)
    print(f"Retrieved documents ({retrieval_method}):")
    for i, doc in enumerate(retrieved_docs):
        print(f"{i+1}. {doc[:150]}...")
    # Generate response using Gemini
    answer = text_generator(query, retrieved_docs)
    return answer


# Example usage:
query = "What is gravitational time dilation?"
# Test with BM25
answer_bm25 = rag_pipeline(query, retrieval_method='BM25', k=5)
print("\nAnswer (BM25):", answer_bm25)

# Test with dense retrieval
answer_dense = rag_pipeline(query, retrieval_method='dense', k=5)
print("\nAnswer (Dense):", answer_dense)

Retrieved documents (BM25):
1. Muon [SEP] In the Rossi–Hall experiment (1941), muons were used to observe the time dilation (or, alternatively, length contraction) predicted by spec...
2. Surprise (emotion) [SEP] Pupil dilation and constriction can determine the valence of surprise from the action to the reaction of the individual. Posi...
3. International Atomic Time [SEP] In the 1970s, it became clear that the clocks participating in TAI were ticking at different rates due to gravitationa...
4. M. C. Escher [SEP] BULLET::::- "Gravitation", (1952)...
5. Surprise (emotion) [SEP] BULLET::::- Pupil dilation mydriasis or pupil constriction miosis...

Answer (BM25): Time's slowing near gravity.

Retrieved documents (dense):
1. Geometrodynamics [SEP] the geometry of Yang--Mills and gravitational gauge theories,...
2. Moment of inertia [SEP] Notice that the distance to the center of oscillation of the seconds pendulum must be adjusted to accommodate different values...
3. Moment of inertia [

## Full experiment

### Data Loading

Loading the full KILT NQ dev set and retrieving the top-50 BM25 passages for every query.

In [None]:
# Load all expected outputs from the KILT NQ dev file
def load_all_nq_expected_outputs(filename):
    """Loads all queries and their gold answers from the KILT NQ dev file.
    Returns a dict {query: [gold_answer1, gold_answer2, ...]}."""
    expected = {}
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            query = record["input"].strip()
            golds = [out["answer"].strip() for out in record.get("output", []) if "answer" in out]
            if golds:
                expected[query] = golds
    return expected

# Load the full NQ dev dataset
expected_outputs = load_all_nq_expected_outputs("data/nq-dev-kilt.jsonl")
queries = list(expected_outputs.keys())
print(f"Loaded {len(queries)} queries from NQ.")

# For each query, retrieval_results[query] = [doc1, doc2, ..., doc50]
retrieval_results = {query: retrieve_documents(query, method='BM25', k=50) for query in queries}

Loaded 2837 queries from NQ.
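`expected_outputs` keeps every gold answer for a query. When generated answers are compared with these golds, a common choice is exact match after the normalization popularized by the SQuAD evaluation script (lowercase, strip punctuation and the articles a/an/the, collapse whitespace), taking the best match over all gold answers. A hedged sketch of that metric; `normalize_answer` and `exact_match` are illustrative helpers, not functions defined elsewhere in this notebook:

```python
import re
import string

PUNCTUATION = set(string.punctuation)

def normalize_answer(text):
    """SQuAD-style normalization: lowercase, drop punctuation and articles, squash spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction, gold_answers):
    """1.0 if the normalized prediction equals any normalized gold answer, else 0.0."""
    pred = normalize_answer(prediction)
    return float(any(pred == normalize_answer(gold) for gold in gold_answers))

print(exact_match("The Eiffel Tower.", ["Eiffel Tower", "La tour Eiffel"]))  # 1.0
print(exact_match("Paris", ["Eiffel Tower"]))                                  # 0.0
```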


### T5-Small Fine-tuning

Building the fine-tuning dataset: each NQ query is paired with its top-50 BM25 passages and its first gold answer.

In [None]:
# Add retrieved documents to the dataset
def augment_with_retrieved_documents(nq_dataset, retrieval_results):
    augmented_data = []
    for query, gold_answers in nq_dataset.items():

        # Retrieve corresponding documents for the query
        retrieved_docs = retrieval_results.get(query, [])

        # NQ can list several gold answers per query; only the first is used as the training target
        target_text = gold_answers[0]

        augmented_data.append({"query": query, "retrieved_docs": retrieved_docs, "gold_answer": target_text})
    return augmented_data

# Augment the NQ dataset with retrieved documents
augmented_dataset = augment_with_retrieved_documents(expected_outputs, retrieval_results)

# Save the augmented dataset to a new file for training
with open('augmented_nq_dataset.json', 'w') as f:
    json.dump(augmented_dataset, f, indent=4)

print(f"Augmented dataset saved with {len(augmented_dataset)} samples.")

Augmented dataset saved with 2837 samples.


Training loop.
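The training cell implements Fusion-in-Decoder (FiD): the encoder processes each (question, passage) pair separately as a batch of `bsz * n_docs` sequences, and the decoder then cross-attends over all passages of an example at once. That requires regrouping the encoder states from `(bsz * n_docs, seq_len, d_model)` to `(bsz, n_docs * seq_len, d_model)` and the attention mask from `(bsz, n_docs, seq_len)` to `(bsz, n_docs * seq_len)`. A shape-level NumPy sketch with made-up sizes (in PyTorch the same regrouping is a `.reshape(bsz, n_docs * seq_len, -1)`):

```python
import numpy as np

bsz, n_docs, seq_len, d_model = 2, 3, 4, 5   # illustrative sizes only

# The encoder runs on every (question, passage) pair independently,
# so its hidden states come out as (bsz * n_docs, seq_len, d_model).
rng = np.random.default_rng(0)
enc_hidden = rng.normal(size=(bsz * n_docs, seq_len, d_model))
attention_mask = rng.integers(0, 2, size=(bsz, n_docs, seq_len))

# FiD concatenates each example's passage encodings along the sequence axis,
# so the decoder cross-attends over n_docs * seq_len positions per example.
fused_hidden = enc_hidden.reshape(bsz, n_docs * seq_len, d_model)
fused_mask = attention_mask.reshape(bsz, n_docs * seq_len)

# Positions j*seq_len .. (j+1)*seq_len - 1 of example b hold passage j of that example.
b, j = 1, 2
assert np.array_equal(fused_hidden[b, j * seq_len:(j + 1) * seq_len], enc_hidden[b * n_docs + j])
print(fused_hidden.shape, fused_mask.shape)  # (2, 12, 5) (2, 12)
```

Row-major reshaping keeps the hidden states and the mask aligned, because both were flattened from the same `(bsz, n_docs, seq_len)` layout.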

In [None]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.optim import AdamW
from functools import partial
import math
import gc

In [None]:
# Custom Dataset class
class QA_Dataset_FiD(Dataset):
    def __init__(self, augmented_data, tokenizer, max_input_length=512, max_target_length=128):
        self.data = augmented_data
        self.tokenizer = tokenizer
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        query = item["query"]
        gold_answer = item["gold_answer"]
        retrieved_docs = item["retrieved_docs"]

        input_encodings = []
        # retrieved_docs = retrieved_docs[:50]

        for doc in retrieved_docs:
            input_text = f"question: {query} context: {doc}"
            input_encoding = self.tokenizer(input_text, truncation=True, max_length=self.max_input_length)
            input_encodings.append(input_encoding)

        target_text = gold_answer
        target_encoding = self.tokenizer(target_text, truncation=True, max_length=self.max_target_length)

        return {
            'input_ids_list': [enc['input_ids'] for enc in input_encodings],
            'attention_mask_list': [enc['attention_mask'] for enc in input_encodings],
            'labels': target_encoding['input_ids']
        }

# Load the augmented dataset
with open('augmented_nq_dataset.json', 'r') as f:
    augmented_dataset = json.load(f)

# Split the dataset
train_data, test_data = train_test_split(augmented_dataset, test_size=0.2, random_state=42)

# Initialize the tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Create Datasets
train_dataset = QA_Dataset_FiD(train_data, tokenizer)
test_dataset = QA_Dataset_FiD(test_data, tokenizer)

# Custom Collate Function for FiD
def collate_fn_fid(batch, tokenizer, max_docs_per_item=10, max_input_length=512, max_target_length=128):
    actual_max_docs_in_batch = max(len(item['input_ids_list']) for item in batch) if batch else 0
    max_docs_this_batch = min(actual_max_docs_in_batch, max_docs_per_item)

    max_len_input = max_input_length
    max_len_target = max_target_length

    all_input_ids = []
    all_attention_masks = []
    all_labels = []

    pad_token_id = tokenizer.pad_token_id
    label_pad_token_id = -100

    for item in batch:
        item_input_ids = []
        item_attention_masks = []

        # Process up to max_docs_this_batch, handling items with fewer docs
        num_docs_to_process = min(len(item['input_ids_list']), max_docs_this_batch)

        # Pad up to max_docs_per_item for consistent tensor shapes across batches
        for i in range(max_docs_per_item):
            if i < num_docs_to_process:
                input_ids = item['input_ids_list'][i][:max_len_input]
                attention_mask = item['attention_mask_list'][i][:max_len_input]

                padding_length = max_len_input - len(input_ids)
                input_ids = input_ids + ([pad_token_id] * padding_length)
                attention_mask = attention_mask + ([0] * padding_length)
            else:
                # Pad with empty docs if item has < max_docs_per_item
                input_ids = [pad_token_id] * max_len_input
                attention_mask = [0] * max_len_input

            item_input_ids.append(torch.tensor(input_ids, dtype=torch.long))
            item_attention_masks.append(torch.tensor(attention_mask, dtype=torch.long))

        if len(item_input_ids) != max_docs_per_item:
            print(f"Warning: Mismatch in expected docs {max_docs_per_item} vs actual {len(item_input_ids)}")

        all_input_ids.append(torch.stack(item_input_ids))
        all_attention_masks.append(torch.stack(item_attention_masks))

        labels = item['labels'][:max_len_target]
        label_padding_length = max_len_target - len(labels)
        padded_labels = labels + ([label_pad_token_id] * label_padding_length)
        all_labels.append(torch.tensor(padded_labels, dtype=torch.long))

    if not all_input_ids:
        return {
            'input_ids': torch.empty(0, max_docs_per_item, max_len_input, dtype=torch.long),
            'attention_mask': torch.empty(0, max_docs_per_item, max_len_input, dtype=torch.long),
            'labels': torch.empty(0, max_len_target, dtype=torch.long)
        }

    batch_input_ids = torch.stack(all_input_ids)
    batch_attention_masks = torch.stack(all_attention_masks)
    batch_labels = torch.stack(all_labels)

    return {
        'input_ids': batch_input_ids,
        'attention_mask': batch_attention_masks,
        'labels': batch_labels
    }

# --- DataLoader Setup ---
per_device_batch_size = 1
num_epochs = 1 # As per paper
effective_batch_size = 64 # As per paper
max_docs_per_item = 40

# Calculate gradient accumulation steps
if effective_batch_size % per_device_batch_size != 0:
    raise ValueError("Effective batch size must be divisible by per-device batch size")
gradient_accumulation_steps = effective_batch_size // per_device_batch_size

train_dataloader = DataLoader(train_dataset, batch_size=per_device_batch_size, shuffle=True,
                              collate_fn=partial(collate_fn_fid, tokenizer=tokenizer, max_docs_per_item=max_docs_per_item))
test_dataloader = DataLoader(test_dataset, batch_size=per_device_batch_size, shuffle=False,
                             collate_fn=partial(collate_fn_fid, tokenizer=tokenizer, max_docs_per_item=max_docs_per_item))
# --- End DataLoader Setup ---


# --- Model, Optimizer, and Device Setup ---
model = T5ForConditionalGeneration.from_pretrained("t5-small")
initial_lr = 5e-5 # As per paper
optimizer = AdamW(model.parameters(), lr=initial_lr, weight_decay=1e-2) # As per paper
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# --- End Model Setup ---


# --- Warmup and Training Steps Calculation ---
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
if num_update_steps_per_epoch == 0:
    raise ValueError("Train dataloader is effectively empty with gradient accumulation. Check dataset size and batch sizes.")

total_training_steps = num_epochs * num_update_steps_per_epoch
warmup_proportion = 0.05 # 5% as per paper
num_warmup_steps = int(total_training_steps * warmup_proportion)
global_optimizer_step = 0
# --- End Warmup Calculation ---


print(f"--- Configuration ---")
print(f"Using device: {device}")
print(f"Per-device batch size: {per_device_batch_size}")
print(f"Effective batch size: {effective_batch_size}")
print(f"Gradient Accumulation steps: {gradient_accumulation_steps}")
print(f"Max docs per item: {max_docs_per_item}")
print(f"Num Epochs: {num_epochs}")
print(f"Initial Learning Rate: {initial_lr}")
print(f"Weight Decay: {optimizer.param_groups[0]['weight_decay']}")
print(f"Train Dataloader size (batches): {len(train_dataloader)}")
print(f"Optimizer steps per epoch: {num_update_steps_per_epoch}")
print(f"Total optimizer steps: {total_training_steps}")
print(f"Warmup optimizer steps: {num_warmup_steps}")
print(f"--------------------")


# --- Training Loop with Gradient Accumulation, FiD, and Warmup ---
model.zero_grad()

for epoch in range(num_epochs):
    model.train()
    total_loss_accumulated = 0.0
    processed_batches_count = 0

    for i, batch in enumerate(train_dataloader):

        input_ids_batch = batch['input_ids'].to(device)
        attention_mask_batch = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        if input_ids_batch.numel() == 0:
            print(f"Warning: Skipping empty batch at step {i}")
            continue

        bsz, n_docs, seq_len = input_ids_batch.shape
        target_seq_len = labels.shape[-1]

        if labels.dim() == 3 and labels.shape[1] == 1:
            labels = labels.squeeze(1)
        elif labels.dim() != 2:
            print(f"Unexpected labels shape: {labels.shape} at step {i}. Skipping batch.")
            continue

        # --- FiD Forward Pass ---
        # 1. Prepare and Run Encoder
        input_ids_enc = input_ids_batch.view(bsz * n_docs, seq_len)
        attention_mask_enc = attention_mask_batch.view(bsz * n_docs, seq_len)

        encoder_outputs = model.encoder(
            input_ids=input_ids_enc,
            attention_mask=attention_mask_enc,
            return_dict=True
        )

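        # 2. Fuse the passages: concatenate each example's n_docs encodings along the
        #    sequence axis so the decoder cross-attends over all of them at once (FiD)
        encoder_hidden = encoder_outputs.last_hidden_state.reshape(bsz, n_docs * seq_len, -1)
        cross_attention_mask = attention_mask_batch.view(bsz, n_docs * seq_len)

        # 3. Run Decoder (T5 accepts precomputed encoder outputs passed as a tuple)
        outputs = model(
            labels=labels,
            encoder_outputs=(encoder_hidden,),
            attention_mask=cross_attention_mask
        )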
        loss = outputs.loss

        if loss is None:
            print(f"Warning: Loss is None at step {i}. Skipping backward/step.")
            continue

        # --- Scale Loss for Gradient Accumulation ---
        # Normalize loss to average gradients correctly
        scaled_loss = loss / gradient_accumulation_steps
        total_loss_accumulated += loss.item()
        processed_batches_count += 1

        # --- Backward Pass (Accumulate Gradients) ---
        scaled_loss.backward()

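        # --- Optimizer Step ---
        # Update after every full accumulation group, and flush the last partial group
        # of the epoch so its gradients are not carried over into the next epoch
        is_last_batch = (i + 1) == len(train_dataloader)
        if (i + 1) % gradient_accumulation_steps == 0 or is_last_batch: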

            # --- Optional: Gradient Clipping (Common practice) ---
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # --- Apply Linear Warmup ---
            if global_optimizer_step < num_warmup_steps:
                lr_scale = float(global_optimizer_step) / float(max(1, num_warmup_steps))
            else:
                lr_scale = 1.0

            for param_group in optimizer.param_groups:
                param_group['lr'] = initial_lr * lr_scale
            # --- End Linear Warmup ---

            # --- Perform Optimizer Step ---
            optimizer.step()

            # --- Zero Gradients for the next accumulation cycle ---
            optimizer.zero_grad()

            # --- Increment global OPTIMIZER step counter ---
            global_optimizer_step += 1

            if global_optimizer_step % 10 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                recent_loss = total_loss_accumulated / processed_batches_count if processed_batches_count > 0 else 0.0
                print(f"Epoch: {epoch}, Opt Step: {global_optimizer_step}/{total_training_steps}, LR: {current_lr:.2e}, Avg Epoch Loss So Far: {recent_loss:.4f}")

    # --- End of Epoch ---
    avg_epoch_loss = total_loss_accumulated / processed_batches_count if processed_batches_count > 0 else 0
    print(f"Epoch {epoch} Finished - Average Loss: {avg_epoch_loss:.4f}")

    # --- Save Model Checkpoint (the same directory is overwritten at the end of every epoch) ---
    output_dir = "finetuned_t5_model_fid"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model checkpoint saved to {output_dir}")

print("Training finished.")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


--- Configuration ---
Using device: cuda
Per-device batch size: 1
Effective batch size: 64
Gradient Accumulation steps: 64
Max docs per item: 40
Num Epochs: 1
Initial Learning Rate: 5e-05
Weight Decay: 0.01
Train Dataloader size (batches): 2269
Optimizer steps per epoch: 36
Total optimizer steps: 36
Warmup optimizer steps: 1
--------------------


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch: 0, Opt Step: 10/36, LR: 5.00e-05, Avg Epoch Loss So Far: 4.4274
Epoch: 0, Opt Step: 20/36, LR: 5.00e-05, Avg Epoch Loss So Far: 4.2880
Epoch: 0, Opt Step: 30/36, LR: 5.00e-05, Avg Epoch Loss So Far: 4.2540
Epoch 0 Finished - Average Loss: 4.2101
Model checkpoint saved to finetuned_t5_model_fid
Training finished.
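The configuration printout reports 2269 batches per epoch with 64 gradient-accumulation steps. The helpers below (hypothetical names, not part of the notebook) reproduce that schedule arithmetic: updating only when `(i + 1) % gradient_accumulation_steps == 0` yields floor(2269 / 64) = 35 updates, while also flushing the final partial group yields the 36 reported as "Optimizer steps per epoch". The warmup function mirrors the loop's linear ramp; assuming `global_optimizer_step` starts at 0, the scale is computed before the counter is incremented, so with a single warmup step the very first update is applied with a learning rate of 0.

```python
import math


def optimizer_steps_per_epoch(num_batches: int, accumulation_steps: int,
                              flush_remainder: bool = True) -> int:
    """Number of optimizer updates per epoch under gradient accumulation."""
    if flush_remainder:
        # The last, partial accumulation group also triggers an update
        return math.ceil(num_batches / accumulation_steps)
    # Only complete groups trigger an update: (i + 1) % accumulation_steps == 0
    return num_batches // accumulation_steps


def linear_warmup_lr(step: int, num_warmup_steps: int, initial_lr: float) -> float:
    """Learning rate used at a given optimizer step: linear ramp, then constant."""
    if step < num_warmup_steps:
        return initial_lr * step / max(1, num_warmup_steps)
    return initial_lr


# Values taken from the configuration printout above
print(optimizer_steps_per_epoch(2269, 64))                         # 36
print(optimizer_steps_per_epoch(2269, 64, flush_remainder=False))  # 35
print([linear_warmup_lr(s, 1, 5e-5) for s in range(3)])            # [0.0, 5e-05, 5e-05]
```

Shifting the ramp to `(step + 1) / num_warmup_steps` is one way to avoid spending the first update at a learning rate of 0.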


### eRAG Evaluation and End-to-End Evaluation

In [None]:
!pip install erag

import json
import re
import string
import time
import erag
import os
import pickle
import scipy.stats as stats

Collecting erag
  Downloading erag-0.0.1.tar.gz (7.8 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pytrec_eval==0.5 (from erag)
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: erag, pytrec_eval
  Building wheel for erag (pyproject.toml) ... [?25l[?25hdone
  Created wheel for erag: filename=erag-0.0.1-py3-none-any.whl size=7176 sha256=5e018bb9bda0e422fc905eb20efda02a21688d9de1c69f90b7fcf1bca85c6c41
  Stored in directory: /root/.cache/pip/wheels/36/db/db/48902dfbd5390e49a5c925b591d60a6b004e4e00cf8f7b48cc
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
  Created wheel for pytrec_eval: filename=pytrec_eval-0.5-cp311-cp311-linux_x86_64.whl size=308655 sha256=08266c467ab4975807db6bfa0659fa51b3cd04191131be9636ffbd59a575fb23
  Stored in directory: /

T5 text generator

In [None]:
import time

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.modeling_outputs import BaseModelOutput

def T5_text_generator(
    queries_and_documents: dict,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    device: torch.device,
    max_input_len: int = 512,
    max_output_len: int = 128,
    num_beams: int = 4,
    **generate_kwargs):
    """
    Generates answers for multiple queries using a fine-tuned T5 FiD model.

    Args:
        queries_and_documents: Dictionary where keys are query strings and
                               values are lists of retrieved document strings.
        model: The loaded fine-tuned T5ForConditionalGeneration model (on device).
        tokenizer: The loaded corresponding T5Tokenizer.
        device: The torch.device where the model is located.
        max_input_len: Max sequence length for each (query + doc) input.
        max_output_len: Max sequence length for the generated answer.
        num_beams: Number of beams for beam search generation.
        **generate_kwargs: Additional keyword arguments passed to model.generate().

    Returns:
        Dictionary where keys are the input query strings and values are the
        corresponding generated answer strings (or an error message if no docs provided).
    """
    model.eval()
    results = {}
    pad_token_id = tokenizer.pad_token_id

    print(f"Generating answers for {len(queries_and_documents)} queries...")
    start_time_total = time.time()

    # Iterate through each query and its associated documents
    for i, (query, retrieved_docs) in enumerate(queries_and_documents.items()):
        start_time_query = time.time()
        print(f"  Processing query {i+1}/{len(queries_and_documents)}: \"{query[:50]}...\"")

        if not retrieved_docs:
            print(f"    Warning: No documents found for query {i+1}. Skipping.")
            results[query] = "Error: No documents provided for this query."
            continue

        all_input_ids = []
        all_attention_masks = []

        # --- Core FiD Generation Logic (applied per query) ---

        # 1. Preprocess and Tokenize each document for the CURRENT query
        for doc in retrieved_docs:
            input_text = f"question: {query} context: {doc}"
            encoding = tokenizer(
                input_text,
                truncation=True,
                max_length=max_input_len,
                padding="max_length",
                return_attention_mask=True,
                add_special_tokens=True
            )
            all_input_ids.append(torch.tensor(encoding['input_ids']))
            all_attention_masks.append(torch.tensor(encoding['attention_mask']))

        # 2. Stack inputs and move to device for CURRENT query
        input_ids_stacked = torch.stack(all_input_ids).to(device)
        attention_mask_stacked = torch.stack(all_attention_masks).to(device)
        num_docs, seq_len = input_ids_stacked.shape

        # Use no_grad context for efficiency during inference
        with torch.no_grad():
            # 3. Encoder Pass for CURRENT query
            raw_encoder_outputs = model.encoder(
                input_ids=input_ids_stacked,
                attention_mask=attention_mask_stacked,
                return_dict=True
            )

            # 4. Reshape Encoder Outputs for Generate
            encoder_hidden_states_reshaped = raw_encoder_outputs.last_hidden_state.view(
                1, num_docs * seq_len, raw_encoder_outputs.last_hidden_state.size(-1)
            )
            encoder_outputs_for_generate = BaseModelOutput(
                last_hidden_state=encoder_hidden_states_reshaped
            )

            # 5. Prepare Cross-Attention Mask
            cross_attention_mask_reshaped = attention_mask_stacked.view(1, num_docs * seq_len)

            # 6. Generation (Decoder) using model.generate() for CURRENT query
            generated_ids = model.generate(
                encoder_outputs=encoder_outputs_for_generate, # Reshaped encoder outputs
                attention_mask=cross_attention_mask_reshaped, # Mask for cross-attention
                max_length=max_output_len,
                num_beams=num_beams,
                early_stopping=True,
                **generate_kwargs
            )

        # 7. Decode the generated token IDs back to text
        generated_text = tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True
        )

        # Store the generated answer in the results dictionary
        results[query] = generated_text.strip()
        end_time_query = time.time()
        # print(f"    Generated answer in {end_time_query - start_time_query:.2f} seconds.")

    end_time_total = time.time()
    print(f"Finished generating all answers in {end_time_total - start_time_total:.2f} seconds.")

    # Return the dictionary containing {query: answer} pairs
    return results
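Steps 4 and 5 above rely on a row-major reshape keeping each document's encoder states aligned with its attention-mask entries after the documents are concatenated into one long sequence. The toy sketch below checks that alignment with NumPy's `reshape`, used here as a stand-in for `torch.Tensor.view` (both are row-major on contiguous data); `fuse_encoder_outputs` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def fuse_encoder_outputs(hidden: np.ndarray, mask: np.ndarray):
    """Concatenate per-document encoder states along the sequence axis (FiD-style).

    hidden: (n_docs, seq_len, d_model) states, one row per (question + doc) input
    mask:   (n_docs, seq_len) attention mask for the same inputs
    Returns (1, n_docs * seq_len, d_model) states and a (1, n_docs * seq_len) mask.
    """
    n_docs, seq_len, d_model = hidden.shape
    fused_hidden = hidden.reshape(1, n_docs * seq_len, d_model)
    fused_mask = mask.reshape(1, n_docs * seq_len)
    return fused_hidden, fused_mask


# Toy example: 3 documents, 4 tokens each, hidden size 2
hidden = np.arange(3 * 4 * 2, dtype=np.float32).reshape(3, 4, 2)
mask = np.array([[1, 1, 0, 0],
                 [1, 1, 1, 0],
                 [1, 0, 0, 0]])
fused_hidden, fused_mask = fuse_encoder_outputs(hidden, mask)
print(fused_hidden.shape, fused_mask.shape)  # (1, 12, 2) (1, 12)

# Token 0 of document 1 now sits at position seq_len = 4 in both tensors
print(np.array_equal(fused_hidden[0, 4], hidden[1, 0]), fused_mask[0, 4] == mask[1, 0])  # True True
```

Because padding is kept per document (`padding="max_length"`), padded positions stay masked out in the fused mask, so the decoder's cross-attention ignores them.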

### T5 Model Loading

In [None]:
from functools import partial

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_path = "finetuned_t5_model_fid"
max_input_len = 512
max_output_len = 128
num_beams_eval = 4

# Load the fine-tuned model and tokenizer
print(f"Loading model from: {model_path}")
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

print(f"Model loaded on device: {device}")

Loading model from: finetuned_t5_model_fid
Model loaded on device: cuda


In [None]:
# Create a partial function that has model, tokenizer, device, etc. pre-filled
t5_generator_for_eval = partial(
    T5_text_generator,
    model=model,
    tokenizer=tokenizer,
    device=device,
    max_input_len=max_input_len,
    max_output_len=max_output_len,
    num_beams=num_beams_eval
)
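Before launching the full (slow) evaluation, it can help to smoke-test the plumbing with a cheap generator that has the same call shape as `t5_generator_for_eval`: a dict mapping each query to its list of document strings goes in, and a dict mapping each query to a generated string comes out. The function below is a hypothetical baseline written for illustration; it is not part of `erag` or the paper, and it assumes the documents are plain strings.

```python
def first_doc_baseline_generator(queries_and_documents: dict, num_words: int = 5) -> dict:
    """Hypothetical stand-in for T5_text_generator: answers each query with the
    first few words of its top-ranked document. No model or GPU required."""
    results = {}
    for query, docs in queries_and_documents.items():
        if not docs:
            # Mirror the T5 generator's behaviour of still returning an entry per query
            results[query] = ""
            continue
        results[query] = " ".join(docs[0].split()[:num_words])
    return results


toy_input = {
    "who wrote hamlet": ["William Shakespeare wrote Hamlet around 1600.", "Hamlet is a tragedy."],
    "empty query": [],
}
print(first_doc_baseline_generator(toy_input, num_words=2))
# {'who wrote hamlet': 'William Shakespeare', 'empty query': ''}
```

Passing such a function as `text_generator` lets the logging, checkpointing, and correlation code run end to end in seconds before the T5 model is plugged back in.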

Loading the test-set queries

In [None]:
test_queries = set(item['query'] for item in test_data)
print(f"\nExtracted {len(test_queries)} unique queries for the test set.")

# Create the test split dictionaries using the test queries
test_expected_outputs = {
    query: answers
    for query, answers in expected_outputs.items()
    if query in test_queries
}

test_retrieval_results = {
    query: docs
    for query, docs in retrieval_results.items()
    if query in test_queries
}

print("\nVerification:")
print(f"Size of test_expected_outputs: {len(test_expected_outputs)}")
print(f"Size of test_retrieval_results: {len(test_retrieval_results)}")


Extracted 568 unique queries for the test set.

Verification:
Size of test_expected_outputs: 568 (should match test_data size)
Size of test_retrieval_results: 568 (should match test_data size)
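Matching sizes show that both dictionaries hold 568 entries, but not whether every test query actually has both gold answers and retrieved documents. A set difference names any missing queries directly. The helper names below are hypothetical, and the toy dictionaries only stand in for `expected_outputs` and `retrieval_results`.

```python
def restrict_to_queries(data: dict, queries: set) -> dict:
    """Keep only the entries of `data` whose key is one of `queries`."""
    return {query: value for query, value in data.items() if query in queries}


def missing_queries(data: dict, queries: set) -> set:
    """Queries from `queries` that have no entry in `data`."""
    return set(queries) - set(data)


# Hypothetical toy data standing in for expected_outputs / retrieval_results
toy_expected = {"q1": ["a"], "q2": ["b"], "q3": ["c"]}
toy_retrieval = {"q1": ["doc A"], "q3": ["doc C"]}
toy_test_queries = {"q1", "q2"}

print(sorted(restrict_to_queries(toy_expected, toy_test_queries)))   # ['q1', 'q2']
print(sorted(restrict_to_queries(toy_retrieval, toy_test_queries)))  # ['q1']
print(missing_queries(toy_retrieval, toy_test_queries))              # {'q2'}
```

In the notebook, `missing_queries(test_retrieval_results, test_queries)` and `missing_queries(test_expected_outputs, test_queries)` should both come back empty.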


Defining the evaluation function for eRAG

In [None]:
# Define the Exact Match evaluation function
def normalize_answer(s):
    """Converts text to lowercase, removes punctuation and extra spaces."""
    return ' '.join(''.join(ch for ch in s.lower() if ch not in string.punctuation).split())

def exact_match_metric(generated_outputs, expected_outputs):
    """Computes if the generated text (normalized) exactly matches one of the gold answers (normalized).
    Returns a dict {query: score}, where score is 1 or 0."""
    return {query: 1 if any(normalize_answer(gen) == normalize_answer(gold) for gold in expected_outputs.get(query, [])) else 0 for query, gen in generated_outputs.items()}
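`normalize_answer` above lowercases and strips punctuation and extra whitespace but keeps articles, so a prediction of "the Beatles" scores 0 against the gold answer "Beatles" (the `re` import in this notebook is currently unused). The normalization in the SQuAD evaluation script, commonly reused for NQ-style exact match, also removes "a", "an" and "the". The sketch below shows that variant under hypothetical names; switching to it would change the reported scores, so it is an option to consider rather than the method used for the results here.

```python
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize_answer_squad(s: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation and the articles a/an/the,
    then collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCTUATION)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def exact_match_squad(prediction: str, gold_answers: list) -> int:
    """1 if the normalized prediction equals any normalized gold answer, else 0."""
    normalized_prediction = normalize_answer_squad(prediction)
    return int(any(normalized_prediction == normalize_answer_squad(gold) for gold in gold_answers))


print(normalize_answer_squad("The Eiffel Tower!"))  # eiffel tower
print(exact_match_squad("the Beatles", ["Beatles"]))  # 1
```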

Evaluation Loop
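In eRAG, each retrieved document is fed to the generator on its own, the output is scored against the gold answers with the downstream metric (here `exact_match_metric`), and that score becomes the document's relevance label. Standard ranking metrics are then computed over the ranked list; `P_{k}` and `success_{k}` in the loop below are pytrec_eval measure names. The sketch recomputes those two definitions by hand from a hypothetical list of per-document labels. It illustrates the metrics, not `erag`'s internal code, and it treats any label above 0 as relevant.

```python
def precision_at_k(labels: list, k: int) -> float:
    """labels: per-document downstream scores (e.g. 0/1 exact match) in rank order.
    Like trec_eval's P@k, divides by k even if fewer than k documents were retrieved."""
    return sum(1 for score in labels[:k] if score > 0) / k


def success_at_k(labels: list, k: int) -> float:
    """1.0 if at least one of the top-k documents is relevant, else 0.0."""
    return 1.0 if any(score > 0 for score in labels[:k]) else 0.0


# Hypothetical labels: EM of the generator's answer when given each document alone
doc_labels = [0, 1, 0, 0, 1]
print(precision_at_k(doc_labels, 5))  # 0.4
print(success_at_k(doc_labels, 2))    # 1.0
print(success_at_k([0, 0, 0], 3))     # 0.0
```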

In [None]:
# Create the log directory if it doesn't exist
LOG_DIR = "logs"
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)

# Define checkpoint file
CHECKPOINT_FILE = "experiment_checkpoint.pkl"

# Load checkpoint if it exists, otherwise initialize an empty dictionary
if os.path.exists(CHECKPOINT_FILE):
    with open(CHECKPOINT_FILE, "rb") as f:
        checkpoint = pickle.load(f)
    print("Checkpoint loaded.")
else:
    checkpoint = {}
    print("No checkpoint found, starting from scratch.")



# Define values for K (number of retrieved documents) and retrieval methods
k_values = [40]
retriever_methods = ['BM25'] # , 'dense'

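# Load existing correlations from the checkpoint; setdefault guarantees each method has a
# dict for new results (avoiding a KeyError on a fresh run), and `correlations` shares those dicts
for method in retriever_methods:
    checkpoint.setdefault(method, {})
correlations = {method: checkpoint[method] for method in retriever_methods}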

# Iterate over retrieval methods and values of K
for method in retriever_methods:
    for k in k_values:
        if k in checkpoint.get(method, {}):
            print(f"Skipping {method} with K={k} (already processed).")
            continue

        print(f"\n--- Processing {method} with K = {k} ---")
        try:
            # Define retrieval metrics based on k
            retrieval_metrics = {f'P_{k}', f'success_{k}'}

            # Keep only the top-k documents per query so both evaluations use the same K
            retrieval_results_k = {query: docs[:k] for query, docs in test_retrieval_results.items()}

            # Evaluate retrieval and generation
            erag_results = erag.eval(
                retrieval_results=retrieval_results_k,
                expected_outputs=test_expected_outputs,
                text_generator=t5_generator_for_eval,
                downstream_metric=exact_match_metric,
                retrieval_metrics=retrieval_metrics
            )

            # Save per-query results
            per_input_file = os.path.join(LOG_DIR, f"per_input_{method}_K{k}.json")
            with open(per_input_file, "w", encoding="utf-8") as f:
                json.dump(erag_results['per_input'], f, ensure_ascii=False, indent=2)
            print(f"Saved per-input results in {per_input_file}.")

            # Save aggregated results
            aggregated_file = os.path.join(LOG_DIR, f"aggregated_{method}_K{k}.json")
            with open(aggregated_file, "w", encoding="utf-8") as f:
                json.dump(erag_results['aggregated'], f, ensure_ascii=False, indent=2)
            print(f"Saved aggregated results in {aggregated_file}.")

            # Generate end-to-end responses from the same top-k documents
            end_to_end_generated = t5_generator_for_eval(retrieval_results_k)
            e2e_scores_dict = exact_match_metric(end_to_end_generated, test_expected_outputs)

            # Save end-to-end scores
            e2e_file = os.path.join(LOG_DIR, f"end_to_end_{method}_K{k}.json")
            with open(e2e_file, "w", encoding="utf-8") as f:
                json.dump(e2e_scores_dict, f, ensure_ascii=False, indent=2)
            print(f"Saved end-to-end scores in {e2e_file}.")

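            # --- Compute correlation between eRAG and end-to-end scores ---
            # Fix one query order so both score lists stay aligned element by element
            query_order = sorted(test_queries)

            local_corr = {}
            for metric_key in retrieval_metrics:
                # Pair each query's eRAG score with its end-to-end score, dropping a query
                # from BOTH lists when its eRAG score is missing
                paired = [
                    (erag_results['per_input'].get(query, {}).get(metric_key), e2e_scores_dict.get(query, 0))
                    for query in query_order
                ]
                paired = [(erag_score, e2e_score) for erag_score, e2e_score in paired if erag_score is not None]
                eRAG_scores = [erag_score for erag_score, _ in paired]
                end_to_end_scores = [e2e_score for _, e2e_score in paired]

                if len(paired) < len(query_order):
                    print(f"Warning: {len(query_order) - len(paired)} queries have no {metric_key} score ({method}, K={k})")

                # Correlation is undefined when either list is constant
                if len(set(eRAG_scores)) < 2 or len(set(end_to_end_scores)) < 2:
                    print(f"Skipping {metric_key} ({method}, K={k}): constant scores, correlation undefined.")
                    continue

                spearman_corr, spearman_p = stats.spearmanr(eRAG_scores, end_to_end_scores)
                kendall_corr, kendall_p = stats.kendalltau(eRAG_scores, end_to_end_scores)
                local_corr[metric_key] = {
                    'spearman': spearman_corr,
                    'kendall': kendall_corr,
                    'num_queries': len(eRAG_scores)
                }
                print(f"\nFor metric {metric_key} ({method}, K={k}):")
                print(f"  Spearman correlation: {spearman_corr:.3f} (p={spearman_p:.3f})")
                print(f"  Kendall correlation:   {kendall_corr:.3f} (p={kendall_p:.3f})")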

            # Update checkpoint
            checkpoint[method][k] = local_corr
            with open(CHECKPOINT_FILE, "wb") as f:
                pickle.dump(checkpoint, f)
            print(f"Checkpoint updated for {method} with K={k}.")

        except Exception as e:
            print(f"Error for {method} with K={k}: {e}")
            time.sleep(10)
            continue

print("\n--- Final correlation summary ---")
print(correlations)

[1;30;43mOutput streaming troncato alle ultime 5000 righe.[0m
  Processing query 145/568: "when did the the regulatory reform (fire safety) o..."
  Processing query 146/568: "who want to be a millionaire calls his dad..."
  Processing query 147/568: "two examples where low voltage transformers are us..."
  Processing query 148/568: "what causes cracked skin at the corners of your mo..."
  Processing query 149/568: "who had the longest tenure as moderator on meet th..."
  Processing query 150/568: "industrial city in germany on the rhine herne cana..."
  Processing query 151/568: "who sings the song it ain't me..."
  Processing query 152/568: "where does the optic nerve cross the midline ​..."
  Processing query 153/568: "who did the central powers defeat on the eastern f..."
  Processing query 154/568: "who was the captain of the mayflower when it took ..."
  Processing query 155/568: "when does the cannes film festival take place..."
  Processing query 156/568: "where was held the f