<a href="https://colab.research.google.com/github/alessioborgi/DL_Project/blob/main/Source/InfoRetrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

==================================================

**Project Name:** Neural inverted index for fast and effective information retrieval\
**Course:** Deep Learning\
**University:** Sapienza Università di Roma

**Authors:**
  - [Alessio Borgi] (<tt>1952442</tt>)
  - [Eugenio Bugli] (<tt>1934824</tt>)
  - [Damiano Imola] (<tt>2109063</tt>)

**Date:** [November 2024 - Completion Date]

==================================================

## 0: INSTALL & IMPORT LIBRARIES

In [None]:
# Colab bootstrap: upgrade pip, then install the retrieval and training dependencies.
!pip install -q --upgrade pip
# NOTE(review): pyserini is pinned to 0.12.0 here, but later cells of this notebook install it
# again (once unpinned and once as 0.21.0) — confirm which version is actually intended.
!pip install -q pyserini==0.12.0
!pip install -q pytorch-lightning transformers datasets torch wandb

In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from collections import Counter
from datasets import load_dataset
from typing import List, Tuple
import wandb


import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, StochasticWeightAveraging, DeviceStatsMonitor

from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer, AutoTokenizer, AutoModelForSequenceClassification

from sklearn.preprocessing import normalize
from sklearn.cluster import AgglomerativeClustering, KMeans

In [None]:
# Authenticate with Weights & Biases without passing a key in code.
# SECURITY: personal API keys used to be pasted here as commented-out code. They have been
# removed from the notebook and must be considered leaked — revoke/rotate them in the W&B
# account settings and never commit keys again.
wandb.login()
# Open a run in the "IR_DSI" project (created with resume="allow").
wandb.init(project="IR_DSI", resume="allow")

## 1: DOWNLOADING DATASET




In [47]:
# PyTorch Dataset class
class MSMARCODataset(Dataset):
    """Map-style dataset that tokenizes one (query, first passage) pair per MS MARCO record."""

    def __init__(self, data, tokenizer, max_length=128, is_test=False):
        """
        Keep a reference to the split and to the tokenization settings.

        Args:
            data: The dataset split (train, validation, or test); must support len() and integer indexing.
            tokenizer: The tokenizer instance, called as tokenizer(query, passage, ...).
            max_length: Maximum token length for inputs (also used as the padding length).
            is_test: Flag marking a test split. It is only stored; nothing in this class reads it.
        """
        self.data, self.tokenizer = data, tokenizer
        self.max_length, self.is_test = max_length, is_test

    def __len__(self):
        """Number of records in the wrapped split."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Tokenize the query of record ``idx`` together with its first passage.

        Returns:
            dict with "input_ids" and "attention_mask", each with the leading batch axis squeezed out.
            No label is returned: label extraction is currently disabled.
        """
        record = self.data[idx]
        query_text = record["query"]
        # TODO (D): this obviously has to change, especially because the paper's authors also skip
        # the first K chunks of a document to check how much the model overfits on the first chunks
        # (i.e. whether it manages to capture the semantics of the WHOLE DOCUMENT).
        first_passage = record["passages"]["passage_text"][0]

        # TODO (D): we have to check whether these tokenizer settings are correct.
        encoded = self.tokenizer(
            query_text,
            first_passage,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )

        # QUESTION (D): do we also need the attention mask?
        # return_tensors="pt" yields a leading batch axis; squeeze(0) drops it for every field.
        return {
            field: encoded[field].squeeze(0)
            for field in ("input_ids", "attention_mask")
        }

In [48]:
class MSMarcoDataModule(pl.LightningDataModule):
    """Lightning data module that wraps the three MS MARCO splits in ``MSMARCODataset`` objects."""

    def __init__(self, train_data, validation_data, test_data, tokenizer, batch_size=32):
        """
        Remember the raw splits and the loader settings.

        Args:
            train_data: Training dataset split.
            validation_data: Validation dataset split.
            test_data: Test dataset split.
            tokenizer: The tokenizer instance shared by all three datasets.
            batch_size: Batch size used by every data loader.
        """
        super().__init__()
        self.train_data, self.validation_data, self.test_data = train_data, validation_data, test_data
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Build one dataset per split; ``stage`` is accepted but not used."""
        shared_tokenizer = self.tokenizer
        self.train_dataset = MSMARCODataset(self.train_data, shared_tokenizer)
        self.val_dataset = MSMARCODataset(self.validation_data, shared_tokenizer)
        self.test_dataset = MSMARCODataset(self.test_data, shared_tokenizer, is_test=True)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        return loader

    def val_dataloader(self):
        """Loader over the validation dataset (no shuffle argument passed)."""
        loader = DataLoader(self.val_dataset, batch_size=self.batch_size)
        return loader

    def test_dataloader(self):
        """Loader over the test dataset (no shuffle argument passed)."""
        loader = DataLoader(self.test_dataset, batch_size=self.batch_size)
        return loader

In [49]:
# Load MS MARCO splits (configuration "v1.1"), in the order train -> validation -> test
ms_marco_train, ms_marco_validation, ms_marco_test = (
    load_dataset("microsoft/ms_marco", "v1.1", split=split_name)
    for split_name in ("train", "validation", "test")
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/9.48k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/175M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/10047 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/82326 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9650 [00:00<?, ? examples/s]

## 2: DATASET EXPLORATION

In [None]:
# Function to print dataset characteristics
def print_dataset_info(name, dataset):
    """
    Print a short report about a dataset split.

    The report contains the split name, the number of samples, the feature names and the
    first record, feature by feature. For the 'passages' feature every passage text of
    that first record is listed as well.

    Args:
        name: Label printed in the report header.
        dataset: Object exposing ``len()``, integer indexing and a ``features`` mapping.
    """
    print(f"\nDataset: {name}")
    print("-" * 40)
    print(f"Number of samples: {len(dataset)}")
    print(f"Features: {dataset.features.keys()}")
    print("\nExample:")
    for feature in dataset.features.keys():
        value = dataset[0][feature]
        print('\t', f'{feature}: ', value)
        if feature != 'passages':
            continue
        passage_texts = value['passage_text']
        print('\t\t', "Number of passages:", len(passage_texts))
        for position, text in enumerate(passage_texts):
            print('\t\t', f'Passage {position}: ', text)
    print('\n\n')

# Print information for each split
print_dataset_info("Train", ms_marco_train)
print_dataset_info("Validation", ms_marco_validation)
print_dataset_info("Test", ms_marco_test)


Dataset: Train
----------------------------------------
Number of samples: 82326
Features: dict_keys(['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'])

Example:
	 answers:  ['Results-Based Accountability is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole.']
	 passages:  {'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 'passage_text': ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Aust

In [None]:
# Analyze specific features
def analyze_passages(dataset):
    """
    Print character-length statistics of the passages in a split.

    Only the FIRST passage of every record is measured, and the "passages per query"
    figure is read from the first record alone.

    Args:
        dataset: Object where ``dataset["passages"]`` iterates over the per-record passage
            dicts and ``dataset[0]`` returns the first record.
    """
    print("\n--- Passage Analysis ---")
    first_passage_lengths = []
    for record_passages in dataset["passages"]:
        first_passage_lengths.append(len(record_passages["passage_text"][0]))
    passages_in_first_record = len(dataset[0]['passages']['passage_text'])
    print(f"Number of passages per query: {passages_in_first_record}")
    print(f"Average passage length: {sum(first_passage_lengths) / len(first_passage_lengths):.2f} characters")
    longest, shortest = max(first_passage_lengths), min(first_passage_lengths)
    print(f"Max passage length: {longest} characters")
    print(f"Min passage length: {shortest} characters")

# Analyze passages in the train set
analyze_passages(ms_marco_train)


--- Passage Analysis ---
Number of passages per query: 10
Average passage length: 414.24 characters
Max passage length: 1167 characters
Min passage length: 43 characters


In [None]:
# Initialize tokenizer (BERT base, uncased).
# NOTE: the name `tokenizer` is rebound to T5 tokenizers in later cells, so this cell has to be
# re-run before the data module is rebuilt.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Create data module wrapping the three MS MARCO splits loaded above
data_module = MSMarcoDataModule(
    train_data=ms_marco_train,
    validation_data=ms_marco_validation,
    test_data=ms_marco_test,
    tokenizer=tokenizer,
    batch_size=32
)

# Prepare datasets (setup() is called by hand because no Trainer is involved here)
data_module.setup()

# Access dataloaders
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [None]:
# Iterate through the training loader and report the tensor shapes of the first batch only.
for batch in train_loader:
    input_ids = batch["input_ids"]  # Tokenized input IDs
    attention_mask = batch["attention_mask"]  # Attention mask
    print("Batch input_ids shape:", input_ids.shape)
    print("Batch attention_mask shape:", attention_mask.shape)
    # MSMARCODataset currently returns no labels (label extraction is commented out there),
    # so indexing batch["labels"] unconditionally raised a KeyError. Report labels only if present.
    labels = batch.get("labels")
    if labels is not None:
        print("Batch labels shape:", labels.shape)
    break  # Stop after printing one batch

## 3: (VANILLA) PREPROCESSING - ON VANILLA DATASET

The core idea is that the model must create associations between `queries` and `docids`.\
So let's start with a simple example without models and fluffy strange stuffs.\
Here, I'm going to use pairs of `(query, vanilla_tokenized_docid)`, but for our DSI we need `(pseudo-query, densly_semantically_tokenized_docid)`.

<h2>What do we have to do?</h2>

As the (second) paper mentions, we need to construct a dataset $\mathcal{T}'$.\
To do so, we start from the generation of $$\mathcal{U} = \mathcal{O} \cup \mathcal{P}\qquad
\begin{cases}
\mathcal{O} = \bigcup_i \mathcal{O}_i = \bigcup_i \{d^1_i, d^2_i, \dots, d^m_i\}\\[10pt]
\mathcal{P} = \bigcup_i \mathcal{P}_i = \bigcup_i \{pq^1_i, pq^2_i, \dots, pq^m_i\}
\end{cases}$$
with
*   $d_i^j$ the $j$-th segment of the $i$-th document that belongs to the set $\mathcal{D}$ and
*   $pq_i^j$ the pseudo-query generated by [docT5query](https://github.com/castorini/docTTTTTquery) against the document segment $d_i^j$

Then, once $\mathcal{U}$ has been computed, we need to filter it using a so-called *'dense model'* named $M$. By their nature, dense retrieval models effectively preserve textual information in their representations (i.e. in their latent space), so they are well suited to this task.

> We need to find an open-source dense retrieval model



<h2>How to filter?</h2>

We feed each individual fragment $t$ of $\mathcal{U}$, originating from the document with identifier $id_t$, into our dense retrieval model $M$, so that $t$ now behaves like a query:
$$t\in d_t = doc(id_t)\in \mathcal{U}$$
this process outputs a ranked list of $k$ elements
$$R_k(t, M) = (id^1, id^2,\dots, id^k)$$
If $id_t$ belongs to $R_k(t, M)$, this means that *'the fragment $t$ possessed key information relevant to the original document'*, and so it can be included in the training corpus $\mathcal{T}'$.\
In other words, in this way we assess that $t$ has enough information to represent the document identified by $id_t$, and so it is a suitable query.

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load a pretrained transformer.
# DSI prefers generative transformers, so T5 will be good enough (also mentioned by the paper).
# NOTE: this rebinds the global `tokenizer` (previously the BERT tokenizer) and defines `model`.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

In [None]:
# Toy corpus: example document ids (docids), documents and queries, aligned by position
doc_ids = ["doc_001", "doc_002", "doc_003"]

# Note: here each document is made of a single segment;
# in general several segments can belong to the same document
documents = [
    "Deep learning is a subset of machine learning that uses neural networks.",
    "Reinforcement learning is a type of machine learning for decision-making.",
    "Transfer learning reuses pre-trained models to solve new tasks."
]

queries = [
    "What is deep learning?",
    "Explain reinforcement learning.",
    "How does transfer learning work?"
]

In [None]:
# tokenize docids (vanilla tokenization, just for understanding)
def tokenize_doc_id(doc_id):
    """
    Spell a docid out character by character, separated by single spaces.

    Underscores are first padded with a space on each side, and those padding spaces are
    themselves treated as characters, so "doc_001" becomes "d o c   _   0 0 1".
    """
    padded = doc_id.replace("_", " _ ")
    return " ".join(padded)

tokenized_doc_ids = [tokenize_doc_id(doc_id) for doc_id in doc_ids]
print("Tokenized Document IDs:", tokenized_doc_ids)

In [None]:
# (query, tokenized docid) pairs, matched by position.
# Note: the real pipeline needs pseudo-queries here instead of the hand-written queries.
training_data = list(zip(queries, tokenized_doc_ids))

In [None]:
training_data

After building the dataset $\mathcal{T'}$, we need to encode both the `docids` (in a semantically meaningful way) and the `pseudo-queries` in order to feed them in the DSI model; recall this is a vanilla code and we are going to encode (`query`, `vanilla_docid`) instead of (`pseudo-query`, `semantic_docid`).

To do so, we can use a simple tokenizer (ideally also the google T5)

In [None]:
def prepare_inputs(query, tokenized_doc_id, tokenizer, max_length=512):
    """
    Encode a (query, docid) training pair with the same tokenizer settings.

    Args:
        query: Text fed to the model as input.
        tokenized_doc_id: Docid string used as the generation target.
        tokenizer: Callable tokenizer accepting max_length/truncation/padding/return_tensors.
        max_length: Length both encodings are truncated and padded to.

    Returns:
        (input_encodings, target_encodings), exactly as returned by the tokenizer.
    """
    # Both sides share one configuration; "pt" asks for PyTorch tensors.
    encode_options = dict(
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # input side: the query
    input_encodings = tokenizer(query, **encode_options)
    # output side: the tokenized docid
    target_encodings = tokenizer(tokenized_doc_id, **encode_options)

    return input_encodings, target_encodings

query, tokenized_doc_id = training_data[0]
input_encodings, target_encodings = prepare_inputs(query, tokenized_doc_id, tokenizer)

# print("Input Encodings:", input_encodings)
# print("Target Encodings:", target_encodings)

In [None]:
# Common practice: set the decoder's start token ID to T5's default value
# (the PAD token, for now). Uses whichever `model` / `tokenizer` are currently bound.
model.config.decoder_start_token_id = tokenizer.pad_token_id

### via GPT

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the docT5Query model and tokenizer.
# NOTE: this rebinds the global `tokenizer` and `model` (previously the t5-small pair).
tokenizer = T5Tokenizer.from_pretrained("castorini/doc2query-t5-base-msmarco")
model = T5ForConditionalGeneration.from_pretrained("castorini/doc2query-t5-base-msmarco")

# Input passage or document
document = "The capital of France is Paris, known for its art, culture, and history."

# Tokenize input
input_ids = tokenizer(document, return_tensors="pt").input_ids

# Generate a synthetic query with beam search (4 beams) and keep the first returned sequence.
# NOTE: `query` is overwritten here with the generated text.
outputs = model.generate(input_ids, max_length=64, num_beams=4, length_penalty=1.0)
query = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Generated Query:", query)

## 3: (REAL) PREPROCESSING - VANILLA DATASET AND TESTS ON REAL

We'll proceed step by step, applying the real preprocessing mentioned earlier to the vanilla dataset we created; in order:


1.   pseudo-query generation using [docT5query](https://github.com/castorini/docTTTTTquery)
2.   real tokenization of documents
3.   semantically structured docids



In [3]:
# Toy corpus: example document ids (docids), documents and queries, aligned by position
doc_ids = ["doc_001", "doc_002", "doc_003"]

# Note: here each document is made of a single segment;
# in general several segments can belong to the same document
documents = [
    "Deep learning is a subset of machine learning that uses neural networks.",
    "Reinforcement learning is a type of machine learning for decision-making.",
    "Transfer learning reuses pre-trained models to solve new tasks."
]

queries = [
    "What is deep learning?",
    "Explain reinforcement learning.",
    "How does transfer learning work?"
]

### (VANILLA) SEMANTICALLY STRUCTURED DOCIDS

In [44]:
def generate_semantic_ids(document_embeddings, n_clusters = 10, max_docs_per_cluster = 100, depth = 0) -> List[str]:
    """
    Assign a hierarchical, cluster-based identifier to every document embedding.

    The embeddings are clustered recursively with k-means until a cluster holds at most
    ``max_docs_per_cluster`` documents. Inside such a leaf, documents are numbered
    0..n-1; every level above prepends ``"<cluster_id>-"``.

    Args:
        document_embeddings: 2-D array with one row per document.
        n_clusters: Number of k-means clusters per level (capped at the number of
            documents being clustered, since k-means cannot produce more clusters than samples).
        max_docs_per_cluster: Largest group that is numbered directly instead of being split.
        depth: Current recursion depth; only forwarded to recursive calls.

    Returns:
        A list of identifier strings where element ``i`` is the identifier of row ``i`` of
        ``document_embeddings``.
    """
    n_docs, _ = document_embeddings.shape

    # base case: a cluster with at most 'c' documents -> number them by position
    if n_docs <= max_docs_per_cluster:
        return [f"{i}" for i in range(n_docs)]

    # k-means clustering
    effective_clusters = min(n_clusters, n_docs)
    kmeans = KMeans(n_clusters=effective_clusters, random_state=0).fit(document_embeddings)
    cluster_labels = kmeans.labels_

    # One slot per input row, so that ids stay aligned with the input order.
    # (Appending cluster after cluster, as before, returned the ids grouped by cluster and
    # therefore attached to the wrong documents.)
    structured_ids = [""] * n_docs
    for cluster_id in range(effective_clusters):
        cluster_indices = np.where(cluster_labels == cluster_id)[0]  # row positions of this cluster

        if len(cluster_indices) == n_docs:
            # k-means could not split the group (e.g. identical embeddings): recursing would
            # never terminate, so number the documents directly.
            sub_ids = [f"{i}" for i in range(n_docs)]
        else:
            # docids for the sub-cluster
            sub_ids = generate_semantic_ids(
                document_embeddings[cluster_indices], n_clusters, max_docs_per_cluster, depth + 1
            )

        # combine the parent prefix with the sub-cluster identifiers, writing each id back
        # to the position of the document it belongs to
        for doc_index, sub_id in zip(cluster_indices, sub_ids):
            structured_ids[doc_index] = f"{cluster_id}-{sub_id}"

    return structured_ids

In [None]:
# Example Usage
# Simulated embeddings (e.g., 100 documents with 768-dimensional embeddings)
np.random.seed(42)
document_embeddings = np.random.rand(100, 768)

# Generate structured identifiers
structured_ids = generate_semantic_ids(document_embeddings, n_clusters=10, max_docs_per_cluster=10)

for i, doc_id in enumerate(structured_ids[:20]):
    print(f'docid for document {i} = {doc_id}')

### (REAL) SEMANTICALLY STRUCTURED DOCIDS

In [None]:
def generate_semantically_structured_docids(dataset, document_embedder, n_clusters=10, max_docs_per_cluster=10, stop_at_first_100=False):
    """
    Embed every document of ``dataset`` and derive semantically structured docids.

    A document is the list of passages of one record; its embedding is the mean of the
    passage embeddings returned by ``document_embedder.encode``.

    Args:
        dataset: Indexable collection of records with ``record['passages']['passage_text']``.
        document_embedder: Object with an ``encode(list_of_texts)`` method returning one
            vector per text.
        n_clusters: Forwarded to ``generate_semantic_ids``.
        max_docs_per_cluster: Forwarded to ``generate_semantic_ids``.
        stop_at_first_100: When True, only the first 100 documents are processed.

    Returns:
        The list produced by ``generate_semantic_ids``, or an empty list when there is
        nothing to embed.
    """
    # Decide the number of documents up front: the previous `i == 100` check ran after the
    # append and therefore let 101 documents through.
    n_documents = len(dataset)
    if stop_at_first_100:
        n_documents = min(n_documents, 100)

    document_embeddings = []
    for i in range(n_documents):
        # retrieve all passages in a document
        passages = dataset[i]['passages']['passage_text']

        # compute embeddings for each passage
        passages_embeddings = document_embedder.encode(passages)

        # average passages embeddings to maintain semantic meaning -> one vector per document
        document_embeddings.append(np.mean(passages_embeddings, axis=0))

    # An empty list would become a 1-D array that the clustering routine cannot unpack.
    if not document_embeddings:
        return []

    return generate_semantic_ids(np.array(document_embeddings), n_clusters=n_clusters, max_docs_per_cluster=max_docs_per_cluster)


# NOTE: `model` is rebound here to a sentence embedder; other cells use the same global name
# for T5 generation models, so the cell execution order matters.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Positional arguments: n_clusters=10, max_docs_per_cluster=10, stop_at_first_100=True
ssd = generate_semantically_structured_docids(ms_marco_train, model, 10, 10, True)


for i, doc_id in enumerate(ssd):
    print(f'docid for document {i} = {doc_id}')

### (VANILLA) DOCUMENT SPLITTING AND PSEUDO QUERIES


In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# docT5query checkpoint; rebinds the global `tokenizer` and `model` and moves the model to `device`
tokenizer = T5Tokenizer.from_pretrained('castorini/doc2query-t5-base-msmarco')
model = T5ForConditionalGeneration.from_pretrained('castorini/doc2query-t5-base-msmarco').to(device)

In [None]:
# the one on deep learning
doc_text = documents[0]

# tokenize the document
input_ids = tokenizer.encode(doc_text, return_tensors='pt').to(device)

# docT5query: sample 3 pseudo-queries (top-k sampling with k=10, at most 64 tokens each).
# Sampling is stochastic and no seed is set here, so the outputs can change between runs.
outputs = model.generate(
    input_ids=input_ids,
    max_length=64,
    do_sample=True,
    top_k=10,
    num_return_sequences=3
)

# The range must match num_return_sequences above
for i in range(3):
    print(f'sample {i + 1}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}') # if need special tokens try with False

sample 1: what is deep learning
sample 2: what is deep learning machine learning
sample 3: what is deep learning


### (REAL) DOCUMENT SPLITTING AND PSEUDO QUERIES

In [None]:
def generate_pseudo_query(dataset, max_pseudo_query_len=64, top_k=1, stop_at_first=False):
    """
    Generate and print one pseudo-query for every passage of every record.

    Relies on the module-level ``tokenizer``, ``model`` and ``device`` (the docT5query
    checkpoint loaded in an earlier cell).

    Args:
        dataset: Indexable collection of records exposing ``'query_id'`` and
            ``['passages']['passage_text']``.
        max_pseudo_query_len: ``max_length`` passed to ``model.generate``.
        top_k: ``top_k`` sampling parameter passed to ``model.generate``.
        stop_at_first: When True, stop after the first record.
    """
    for i in range(len(dataset)):
        # Fetch the record once instead of re-indexing the dataset for every field access.
        record = dataset[i]
        id_t = record['query_id']

        for j, t in enumerate(record['passages']['passage_text']):
            # tokenize the document
            input_ids = tokenizer.encode(t, return_tensors='pt').to(device)

            # doct5query: sample a single pseudo-query (only one is used, so only one is generated)
            outputs = model.generate(
                input_ids=input_ids,
                max_length=max_pseudo_query_len,
                do_sample=True,
                top_k=top_k,
                num_return_sequences=1
            )

            # `outputs` is indexed by generated sequence, not by dataset record: the previous
            # outputs[i] picked an arbitrary sequence and raised IndexError from the 4th record on.
            decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

            print(f'SAMPLE {j + 1} DOCID {id_t}: \n passage: {t}\n pseudo-query: {decoded}')
            print('-'*50)

        if stop_at_first: break

generate_pseudo_query(ms_marco_train, 64, 1, True)

SAMPLE 1 DOCID 19699: 
 passage: Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.
 pseudo-query: what is the rba's assets
--------------------------------------------------
SAMPLE 2 DOCID 19699: 
 passage: The Reserve Bank of Australia (RBA) came into being on 14 January 1960 as Australia 's central bank and banknote issuing authority, when the Reserve Bank Act 1959 removed the central banking functions from the Commonwealth Bank. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth

## 3: (REAL) PREPROCESSING - REAL DATASET

In [None]:
%%capture output
# Environment setup; everything this cell prints is captured into `output`.
import os

# Install a headless Java 21 runtime and point JAVA_HOME at it
!apt-get -q install openjdk-21-jre-headless -qq > /dev/null
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"

# Make Java 21 the default `java` and print its version
!update-alternatives --set java /usr/lib/jvm/java-21-openjdk-amd64/bin/java
!java -version

# NOTE(review): pyserini is installed unpinned here, while other cells pin 0.12.0 and 0.21.0 —
# confirm which version the notebook is meant to run with.
!pip install pyserini

In [None]:
# FAISS: the CPU build is installed; the GPU builds are left here as disabled alternatives
# !pip install -q faiss-gpu
# !pip install -q faiss-gpu-cu12
!pip install -q faiss-cpu

# stable pyserini version
!pip install -q pyserini==0.21.0

# install openjdk 11
# NOTE(review): the previous cell installs Java 21 and sets JAVA_HOME to it, while this one
# installs the Java 11 JDK — confirm which Java version is required and keep only one.
!apt-get -qq update
!apt-get -qq install -y openjdk-11-jdk
!java -version

# explicit verbose fault handler (PYTHONFAULTHANDLER is set for this process and its children)
import os
os.environ['PYTHONFAULTHANDLER'] = '1'

We are going to use MSMARCO-100k, so we'll apply each of the following steps to it.

1.   For the pseudo-queries generation, we can use the [docT5query](https://github.com/castorini/docTTTTTquery) from castorini
2.   As a *'dense retrieval model'* for ranking documents we can use the one proposed by castorini: [FAISS](https://github.com/castorini/pyserini/blob/master/docs/usage-search.md#faiss)
3.   For the segments extraction we can rely on a manually crafted approach

To better assess the improvement brought by a dense retrieval model over a sparse one, we'll also use an approach based on Lucene (a sparse retrieval model) and a hybrid one that combines FAISS and Lucene.




### 3.1 DOCUMENT & QUERY EMBEDDING GENERATOR



The hybrid retrieval pipeline we propose combines **dense retrieval** (using FAISS) and **sparse retrieval** (using Lucene with BM25). The objective of this step is to rank documents based on both semantic similarity and lexical matching to achieve robust results.

---
As a first, step, we need to use a `SentenceTransformer` model to encode documents into dense vector embeddings, which are then indexed by FAISS.

- For a document $d_i$, the embedding is represented as:
  $$
  \mathbf{e}_{d_i} = f_{\text{dense}}(d_i)
  $$
  where $f_{\text{dense}}$ is the embedding function provided by the `SentenceTransformer` model.

- The embeddings of all documents are stored in a matrix:
  $$
  \mathbf{E} = [\mathbf{e}_{d_1}, \mathbf{e}_{d_2}, \dots, \mathbf{e}_{d_N}] \in \mathbb{R}^{N \times D}
  $$
  Here, $N$ is the number of documents, and $D$ is the embedding dimensionality.

Similarly, the query $q$ is encoded into a dense vector:
$$
\mathbf{e}_q = f_{\text{dense}}(q)
$$

In [None]:
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from pyserini.search import SimpleSearcher

# Step 1: Generate Embeddings with SentenceTransformer
# (the toy corpus is redefined here so that this section does not depend on earlier cells)
documents = [
    "Deep learning is a subset of machine learning that uses neural networks.",
    "Reinforcement learning is a type of machine learning for decision-making.",
    "Transfer learning reuses pre-trained models to solve new tasks."
]

queries = [
    "What is deep learning?",
    "Explain reinforcement learning.",
    "How does transfer learning work?"
]

# Load the SentenceTransformer model
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Generate document embeddings as a NumPy array (one row per document)
document_embeddings = embedding_model.encode(documents, convert_to_numpy=True)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

### 3.2 FAISS DENSE RETRIEVAL MODEL






- **Index Creation**: The normalized document embeddings are added to a FAISS `IndexFlatIP` index.
- **Querying**: Given the query embedding $\mathbf{e}_q$, FAISS retrieves the top-$k$ documents based on inner product:
$$
\text{score}_{\text{FAISS}}(q, d_i) = \mathbf{e}_q \cdot \mathbf{e}_{d_i}
$$

The FAISS search returns:
- `indices`: The IDs of the top-$k$ documents.
- `distances`: The corresponding similarity scores.


#### 3.2.1 NORMALIZATION FOR COSINE SIMILARITY

The `IndexFlatIP` index scores documents by **inner product**. To make these scores equivalent to cosine similarity, we apply a simple trick: we L2-normalize both the document embeddings and the query embedding:
$$
\mathbf{e}' = \frac{\mathbf{e}}{\|\mathbf{e}\|}
$$
Where $\|\mathbf{e}\|$ is the L2 norm of the embedding.

Normalized embeddings allow the inner product to behave like cosine similarity:
$$
\text{cosine_similarity}(\mathbf{e}_q, \mathbf{e}_{d_i}) = \mathbf{e}_q \cdot \mathbf{e}_{d_i}
$$

In [None]:
# Initialize FAISS Index
dim = document_embeddings.shape[1]  # Embedding dimensionality
faiss_index = faiss.IndexFlatIP(dim)  # Inner product for cosine similarity

# Normalize and add document embeddings to FAISS index.
# normalize_L2 works in place: its return value is not used and `document_embeddings` itself
# is what gets added to the index.
faiss.normalize_L2(document_embeddings)
faiss_index.add(document_embeddings)

# Step 2: Generate Query Embedding (encoded as a one-element batch, then normalized the same way)
query = "What is deep learning?"
query_embedding = embedding_model.encode([query], convert_to_numpy=True)
faiss.normalize_L2(query_embedding)

# Perform FAISS search
k = 3  # Number of results to retrieve
faiss_distances, faiss_indices = faiss_index.search(query_embedding, k)

# Print FAISS results.
# The value printed as "Distance" is the inner-product score of the IndexFlatIP index,
# so a higher value means a closer match.
print("FAISS Results:")
for i, idx in enumerate(faiss_indices.squeeze()):
    print(f"Rank {i+1}: Document: {documents[idx]}, Distance: {faiss_distances.squeeze()[i]}")

FAISS Results:
Rank 1: Document: Deep learning is a subset of machine learning that uses neural networks., Distance: 0.8303160667419434
Rank 2: Document: Reinforcement learning is a type of machine learning for decision-making., Distance: 0.3908756971359253
Rank 3: Document: Transfer learning reuses pre-trained models to solve new tasks., Distance: 0.33839863538742065


In [None]:
# from pyserini.encode import TctColBertQueryEncoder
# from pyserini.search.faiss import FaissSearcher

# encoder = TctColBertQueryEncoder('castorini/tct_colbert-v2-hnp-msmarco')
# faiss_searcher = FaissSearcher.from_prebuilt_index('msmarco-v1-passage', None)
# hits = faiss_searcher.search('what is a lobster roll')

# for i in range(0, 10):
#     print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}')

### 3.3 LUCENE SPARSE RETRIEVAL MODEL


Lucene uses the **BM25** algorithm for sparse retrieval. For a query $q$ and document $d_i$, the BM25 relevance score is given by:
$$
\text{BM25}(q, d_i) = \sum_{t \in q} \text{IDF}(t) \cdot \frac{f(t, d_i) \cdot (k_1 + 1)}{f(t, d_i) + k_1 \cdot (1 - b + b \cdot \frac{|d_i|}{\text{avgdl}})}
$$
Where:
- $t$: A term in the query.
- $f(t, d_i)$: The frequency of term $t$ in document $d_i$.
- $|d_i|$: The length of the document.
- $\text{avgdl}$: The average document length in the corpus.
- $k_1$: Controls term frequency saturation (default $1.2$).
- $b$: Controls length normalization (default $0.75$).
- $\text{IDF}(t)$: The inverse document frequency of term $t$:
  $$
  \text{IDF}(t) = \log\left(\frac{N - n_t + 0.5}{n_t + 0.5} + 1\right)
  $$
  Where $N$ is the total number of documents and $n_t$ is the number of documents containing term $t$.




Pyserini retrieves the top-$k$ documents using BM25 scoring. The results include:
- `docid`: Document IDs.
- `BM25_score`: BM25 relevance scores.

---


In [None]:
#@title OLD (prebuilt index)
import json

# Step 3: Perform Sparse Retrieval with Pyserini (BM25 over the prebuilt MS MARCO passage index)
lucene_searcher = SimpleSearcher.from_prebuilt_index("msmarco-passage")
lucene_searcher.set_bm25(k1=0.9, b=0.4)

# Sparse retrieval results for the current `query`
lucene_hits = lucene_searcher.search(query, k=5)
lucene_scores = {hit.docid: hit.score for hit in lucene_hits}

# Print Sparse Results
print("\nSparse Results (Lucene):")
for i, hit in enumerate(lucene_hits):
    # The stored raw document is a JSON object with "id" and "contents". Parse it instead of
    # splitting on ":" — the split cut every passage at its first colon and removed all quotes.
    try:
        doc = json.loads(hit.raw)["contents"]
    except (ValueError, KeyError, TypeError):
        # Unexpected raw format: show it unchanged rather than failing
        doc = hit.raw
    # The value labelled "Distance" is the BM25 score of the hit
    print(f"Rank {i+1}: Distance: {hit.score}, Document: {doc}")

Attempting to initialize pre-built index msmarco-passage.
/root/.cache/pyserini/indexes/index-msmarco-passage-20201117-f87c94.1efad4f1ae6a77e235042eff4be1612d already exists, skipping download.
Initializing msmarco-passage...


In [None]:
import os
import json
from pyserini.search import SimpleSearcher
from pyserini.search.lucene import LuceneSearcher
from pyserini.index import IndexReader


import shutil
import subprocess
import sys
from pathlib import Path

# One id per document ("doc1", "doc2", ...), derived from `documents` so the
# two sequences can never silently get out of sync.
doc_ids = [f"doc{position}" for position in range(1, len(documents) + 1)]

# Write the collection into its own directory: the indexer is pointed at a
# directory, and pointing it at "." would also pick up unrelated files that
# happen to live in the working directory.
collection_dir = Path("lucene-collection")
shutil.rmtree(collection_dir, ignore_errors=True)
collection_dir.mkdir(parents=True)
jsonl_path = str(collection_dir / "documents.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
    for doc_id, content in zip(doc_ids, documents):
        f.write(json.dumps({"id": doc_id, "contents": content}) + "\n")

# Step 2: Index the Documents with Pyserini
index_path = "lucene-index"
shutil.rmtree(index_path, ignore_errors=True)  # Clean previous index if exists

# Build the Lucene index. `check=True` makes a failed build raise here instead
# of surfacing later as a confusing error when the index is opened.
subprocess.run(
    [
        sys.executable, "-m", "pyserini.index",
        "-collection", "JsonCollection",
        "-generator", "DefaultLuceneDocumentGenerator",
        "-threads", "1",
        "-input", str(collection_dir),
        "-index", index_path,
        "-storeRaw",
    ],
    check=True,
)

# Step 3: Perform Sparse Retrieval with LuceneSearcher
searcher = LuceneSearcher(index_path)

# Set BM25 parameters
searcher.set_bm25(k1=0.9, b=0.4)

# Search and display the top 3 results for the example query
lucene_hits = searcher.search("What is deep learning?", k=3)

for rank, hit in enumerate(lucene_hits):
    print(f"Rank {rank + 1}:")
    print(f"  DocID: {hit.docid}")
    print(f"  Score: {hit.score}")
    print(f"  Content: {hit.raw}")

Rank 1:
  DocID: doc1
  Score: 0.608299970626831
  Content: {
  "id" : "doc1",
  "contents" : "Deep learning is a subset of machine learning that uses neural networks."
}
Rank 2:
  DocID: doc2
  Score: 0.09350000321865082
  Content: {
  "id" : "doc2",
  "contents" : "Reinforcement learning is a type of machine learning for decision-making."
}
Rank 3:
  DocID: doc3
  Score: 0.06870000064373016
  Content: {
  "id" : "doc3",
  "contents" : "Transfer learning reuses pre-trained models to solve new tasks."
}


### 3.4 HYBRID RETRIEVAL MODEL


To leverage the strengths of both retrieval approaches, we combine the scores from FAISS and Lucene.

The combined score for each document is calculated through a **Linear Score Fusion**, i.e., as a weighted sum:
$$
\text{Combined_Score}(q, d_i) = \alpha \cdot \text{score}_{\text{FAISS}}(q, d_i) + (1 - \alpha) \cdot \text{BM25}(q, d_i)
$$
Where:
- $\alpha \in [0, 1]$: Weight parameter controlling the contribution of FAISS vs. BM25 scores.

In [None]:
# Step 4: Combine FAISS and Lucene Scores
alpha = 0.5
combined_scores = {}

# BM25 scores keyed by Lucene document id, so each FAISS hit is fused with the
# BM25 score of the *same* document rather than with whatever Lucene hit
# happens to share its rank.
bm25_by_docid = {hit.docid: hit.score for hit in lucene_hits}

# Flatten the FAISS outputs; reshape(-1) also works when only one neighbour
# was requested (squeeze() would yield a 0-d array that cannot be iterated).
faiss_positions = faiss_indices.reshape(-1)
faiss_scores = faiss_distances.reshape(-1)

for position, dense_score in zip(faiss_positions, faiss_scores):
    position = int(position)
    if position < 0:
        # FAISS reports -1 when it returns fewer neighbours than requested.
        continue
    # `doc_ids` is positionally aligned with `documents`, so it maps a FAISS
    # row index to the id the document was indexed under in Lucene.
    # Documents BM25 did not retrieve contribute a sparse score of 0.
    sparse_score = bm25_by_docid.get(doc_ids[position], 0.0)
    combined_scores[position] = alpha * float(dense_score) + (1 - alpha) * sparse_score


print(combined_scores.items())


# Sort by combined scores (dimension 1 contains scores)
ranked_results = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

# Display Hybrid Ranked Results
print("\nHybrid Ranked Results:")
for rank, (doc_id, score) in enumerate(ranked_results, start=1):
    print(f"Rank {rank}: DocID: {doc_id}, Combined Score: {score:.4f}, Document: {documents[doc_id]}")

dict_items([(0, 0.7193080186843872), (1, 0.24218785017728806), (2, 0.2035493180155754)])

Hybrid Ranked Results:
Rank 1: DocID: 0, Combined Score: 0.7193, Document: Deep learning is a subset of machine learning that uses neural networks.
Rank 2: DocID: 1, Combined Score: 0.2422, Document: Reinforcement learning is a type of machine learning for decision-making.
Rank 3: DocID: 2, Combined Score: 0.2035, Document: Transfer learning reuses pre-trained models to solve new tasks.


### TODO Hierarchical Navigable Small Worlds (HNSW)


Due to the large number of documents we need to process, we opt for the faster approximate nearest neighbor search offered by FAISS.\
Since the basic FAISS implementation uses the brute-force approach, processing a huge number of documents can result in prohibitive processing time; approximate nearest neighbor search allows us to speed up this operation without losing too much retrieval quality.

## STEPS SUMMARY



The documents are sorted in descending order of their combined scores to produce the final hybrid ranking.


1. **FAISS Dense Retrieval**:
   - Create embeddings with `SentenceTransformer`.
   - Normalize embeddings.
   - Search FAISS for top-$k$ results.

2. **Lucene Sparse Retrieval**:
   - Use Pyserini to query a Lucene index.
   - Retrieve BM25 scores for top-$k$ results.

3. **Score Fusion**:
   - Combine FAISS and BM25 scores with a weighted sum.
   - Sort documents by combined scores.



## Advantages of the Hybrid Approach

1. **Semantic Understanding**:
   - FAISS captures semantic relationships between queries and documents.
   - Example: "What is deep learning?" matches "Neural networks power deep learning."

2. **Lexical Precision**:
   - BM25 ensures term-matching precision.
   - Example: "What is deep learning?" prefers documents explicitly mentioning "deep learning."

3. **Robustness**:
   - The hybrid approach balances semantic and lexical relevance, reducing the risk of missing relevant documents.

## Mathematical Representation of the Pipeline


Given:
- $q$: Query.
- $\mathbf{E}$: Document embeddings.
- $\mathbf{e}_q$: Query embedding.

1. **FAISS Score**:
   $$
   \text{score}_{\text{FAISS}}(q, d_i) = \mathbf{e}_q \cdot \mathbf{e}_{d_i}
   $$

2. **BM25 Score**:
   $$
   \text{BM25}(q, d_i) = \sum_{t \in q} \text{IDF}(t) \cdot \frac{f(t, d_i) \cdot (k_1 + 1)}{f(t, d_i) + k_1 \cdot (1 - b + b \cdot \frac{|d_i|}{\text{avgdl}})}
   $$

3. **Hybrid Score**:
   $$
   \text{Combined_Score}(q, d_i) = \alpha \cdot \text{score}_{\text{FAISS}}(q, d_i) + (1 - \alpha) \cdot \text{BM25}(q, d_i)
   $$

## OLD 3.3 HYBRID APPROACH (FAISS + LUCENE)

In [None]:
from typing import List, Dict

from pyserini.search.faiss import FaissSearcher, DenseSearchResult
from pyserini.search.lucene import LuceneSearcher


class HybridSearcher:
    """Hybrid searcher that fuses dense and sparse retrieval results.

        Parameters
        ----------
        dense_searcher : FaissSearcher
            Object whose ``search`` / ``batch_search`` return hits exposing
            ``docid`` and ``score``.
        sparse_searcher : LuceneSearcher
            Sparse counterpart with the same interface.
    """

    # Return annotations are quoted so that they are not evaluated when the
    # class body is executed.

    def __init__(self, dense_searcher, sparse_searcher):
        self.dense_searcher = dense_searcher
        self.sparse_searcher = sparse_searcher

    def search(self, query: str, k0: int = 10, k: int = 10, alpha: float = 0.1, normalization: bool = False, weight_on_dense: bool = False) -> "List[DenseSearchResult]":
        """Retrieve ``k0`` hits from each searcher and return the top ``k`` fused hits."""
        dense_hits = self.dense_searcher.search(query, k0)
        sparse_hits = self.sparse_searcher.search(query, k0)
        return self._hybrid_results(dense_hits, sparse_hits, alpha, k, normalization, weight_on_dense)

    def batch_search(self, queries: List[str], q_ids: List[str], k0: int = 10, k: int = 10, threads: int = 1,
            alpha: float = 0.1, normalization: bool = False, weight_on_dense: bool = False) \
            -> "Dict[str, List[DenseSearchResult]]":
        """Batched variant of :meth:`search`; returns fused hits keyed by query id."""
        dense_result = self.dense_searcher.batch_search(queries, q_ids, k0, threads)
        sparse_result = self.sparse_searcher.batch_search(queries, q_ids, k0, threads)
        hybrid_result = {
            key: self._hybrid_results(dense_result[key], sparse_result[key], alpha, k, normalization, weight_on_dense)
            for key in dense_result
        }
        return hybrid_result

    @staticmethod
    def _hybrid_results(dense_results, sparse_results, alpha, k, normalization=False, weight_on_dense=False):
        """Fuse two hit lists into one ranking of at most ``k`` results.

        A document missing from one list is assigned that list's minimum
        score. With ``normalization`` each score is centred on the midpoint of
        its list's range and divided by the range width. ``alpha`` weights the
        sparse score, or the dense score when ``weight_on_dense`` is true.
        """
        dense_hits = {hit.docid: hit.score for hit in dense_results}
        sparse_hits = {hit.docid: hit.score for hit in sparse_results}
        min_dense_score = min(dense_hits.values()) if dense_hits else 0
        max_dense_score = max(dense_hits.values()) if dense_hits else 1
        min_sparse_score = min(sparse_hits.values()) if sparse_hits else 0
        max_sparse_score = max(sparse_hits.values()) if sparse_hits else 1

        def _center_and_scale(score, low, high):
            # A zero-width range (single hit, or all scores equal) would divide
            # by zero; every score then sits on the midpoint, i.e. at 0.
            span = high - low
            if span == 0:
                return 0.0
            return (score - (low + high) / 2) / span

        hybrid_result = []
        for doc in dense_hits.keys() | sparse_hits.keys():
            sparse_score = sparse_hits.get(doc, min_sparse_score)
            dense_score = dense_hits.get(doc, min_dense_score)

            if normalization:
                sparse_score = _center_and_scale(sparse_score, min_sparse_score, max_sparse_score)
                dense_score = _center_and_scale(dense_score, min_dense_score, max_dense_score)

            if weight_on_dense:
                score = sparse_score + alpha * dense_score
            else:
                score = alpha * sparse_score + dense_score

            hybrid_result.append(DenseSearchResult(doc, score))
        return sorted(hybrid_result, key=lambda x: x.score, reverse=True)[:k]

ModuleNotFoundError: No module named 'pyserini.search.faiss'

## 4: MODEL

In [None]:
class MSMarcoClassifier(pl.LightningModule):
    """Lightning wrapper around a two-label sequence-classification model."""

    def __init__(self, model_name="bert-base-uncased", learning_rate=2e-5):
        super().__init__()
        self.learning_rate = learning_rate
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model without labels and return its output."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def _labelled_loss(self, batch):
        # Passing the labels makes the wrapped model compute the loss itself.
        return self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        ).loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        train_loss = self._labelled_loss(batch)
        self.log("train_loss", train_loss, prog_bar=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        val_loss = self._labelled_loss(batch)
        self.log("val_loss", val_loss, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        """AdamW over all parameters at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


## 5: TRAINING

In [None]:
# Model Checkpointing Callback.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",  # Metric to monitor
    dirpath="checkpoints/",  # Directory to save checkpoints
    filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",  # Checkpoint name format
    save_top_k=1,  # Save only the best model
    mode="min"  # Minimize the monitored metric
)

# Early Stopping Callback.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",  # Metric to monitor
    patience=3,  # Number of epochs without improvement to wait
    mode="min"  # Minimize the monitored metric
)

# Learning Rate Monitoring Callback.
lr_monitor = LearningRateMonitor(logging_interval="step")

# StochasticWeightAveraging Callback.
# `swa_lrs` is a required argument in current Lightning releases; the value
# matches MSMarcoClassifier's default learning rate.
swa_callback = StochasticWeightAveraging(swa_lrs=2e-5)


# Device Statistics Callback
device_stats_callback = DeviceStatsMonitor()

# 16-bit mixed precision is a GPU feature; fall back to full precision on CPU.
use_gpu = torch.cuda.is_available()

# Trainer implementation.
trainer = Trainer(
    max_epochs=3,
    accelerator="gpu" if use_gpu else "cpu",
    devices=1,  # a single device on either accelerator; `None` is rejected on CPU by Lightning 2.x
    enable_progress_bar=True,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor, swa_callback, device_stats_callback],
    gradient_clip_val=1.0,  # Clip gradients to this value
    precision=16 if use_gpu else 32,  # AMP (Automatic Mixed Precision) on GPU: faster training, lower memory usage
)

In [None]:
# Timestamp (date and time down to the second) that makes each run name unique.
current_time = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"

# W&B logger for the shared "IR_DSI" project; model artifacts are logged too.
wandb_logger_options = {
    "project": "IR_DSI",
    "name": "run_" + current_time,
    "log_model": True,
}
wandb_logger = WandbLogger(**wandb_logger_options)

In [None]:
# Initialize the model
model = MSMarcoClassifier()

# Initialize the Trainer.
# NOTE: this rebinds `trainer`. Unlike the Trainer configured in the callbacks
# cell, this one attaches no callbacks, gradient clipping or mixed precision.
trainer = Trainer(
    max_epochs=3,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,  # a single device on either accelerator; `None` is rejected on CPU by Lightning 2.x
    enable_progress_bar=True,
    logger=wandb_logger,
)
# Train the model.
trainer.fit(model, data_module)

## 6: TESTING

In [None]:
# Run the model and its inputs on the same device. The original moved only the
# inputs, which fails with a device mismatch whenever the model is on the CPU
# while CUDA is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # Set model to evaluation mode

predictions = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)  # Predicted labels
        predictions.extend(preds.cpu().tolist())

print("Test Predictions:", predictions[:10])

# TRY DSI IMPLEMENTATION

In [None]:
class MSMARCODataset(Dataset):
    """Query / first-passage pairs tokenized for the DSI experiments.

    Each item carries the tokenized pair, the query id (as ``doc_ids``) and the
    relevance flag of the first passage (as ``label``).
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        query_text = record["query"]
        first_passage = record["passages"]["passage_text"][0]
        # The query id, cast to int, is what this dataset exposes as the document id.
        query_id = int(record["query_id"])
        relevance = int(bool(record["passages"]["is_selected"][0]))

        # Encode query and passage as one pair, padded/truncated to max_length.
        encoded = self.tokenizer(
            query_text,
            first_passage,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        item = {name: encoded[name].squeeze(0) for name in ("input_ids", "attention_mask")}
        item["doc_ids"] = torch.tensor(query_id, dtype=torch.long)
        item["label"] = torch.tensor(relevance, dtype=torch.long)
        return item

In [None]:
class MultiTaskDSIWithoutDistillation(pl.LightningModule):
    """Seq2seq model trained on a weighted sum of an indexing and a retrieval loss.

    NOTE(review): ``forward`` passes neither decoder inputs nor labels to the
    seq2seq model, and both losses apply ``F.cross_entropy`` to the same
    ``outputs.logits`` against per-example integer targets (``doc_ids`` and
    ``label``). Confirm that the logits shape and target ranges are compatible
    before training with this class.
    """

    def __init__(self, model_name="t5-base", learning_rate=5e-4):
        super().__init__()
        self.save_hyperparameters()
        # Imported locally: the notebook-level import of this name lives in a
        # later cell, so a top-to-bottom run would otherwise fail here.
        from transformers import AutoModelForSeq2SeqLM

        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.learning_rate = learning_rate

        # Loss weights for multi-task learning
        self.indexing_loss_weight = 0.5
        self.retrieval_loss_weight = 0.5

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model on the encoder inputs and return its output."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def compute_indexing_loss(self, outputs, doc_ids):
        """
        Indexing Task: Predict document IDs from passage text.
        """
        loss = F.cross_entropy(outputs.logits, doc_ids)
        return loss

    def compute_retrieval_loss(self, outputs, relevance_labels):
        """
        Retrieval Task: Rank passages based on relevance labels.
        """
        loss = F.cross_entropy(outputs.logits, relevance_labels)
        return loss

    def _multitask_step(self, batch, stage):
        """Forward one batch, log the three ``{stage}_*`` losses, return the total."""
        outputs = self(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Compute losses for indexing and retrieval
        indexing_loss = self.compute_indexing_loss(outputs, batch["doc_ids"])
        retrieval_loss = self.compute_retrieval_loss(outputs, batch["label"])

        # Combine losses with weights
        total_loss = (
            self.indexing_loss_weight * indexing_loss
            + self.retrieval_loss_weight * retrieval_loss
        )

        # Log losses
        self.log(f"{stage}_indexing_loss", indexing_loss, prog_bar=True)
        self.log(f"{stage}_retrieval_loss", retrieval_loss, prog_bar=True)
        self.log(f"{stage}_total_loss", total_loss, prog_bar=True)

        return total_loss

    def training_step(self, batch, batch_idx):
        """Weighted multi-task loss for one training batch (logged as ``train_*``)."""
        return self._multitask_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Weighted multi-task loss for one validation batch (logged as ``val_*``)."""
        return self._multitask_step(batch, "val")

    def configure_optimizers(self):
        """AdamW over all parameters at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM

class MultiTaskDSIWithDistillation(pl.LightningModule):
    """Multi-task student (indexing + retrieval heads) with optional teacher distillation.

    Args:
        student_model_name: Checkpoint name for the seq2seq student; only its
            encoder is used in ``forward``.
        teacher_model: Optional model called with ``input_ids`` and
            ``attention_mask`` whose output exposes ``.logits``.
            NOTE(review): those logits are compared with the 2-way retrieval
            logits, so the teacher presumably needs two output classes — confirm.
        learning_rate: AdamW learning rate.
        num_doc_ids: Size of the indexing head, i.e. the number of distinct
            document ids it can predict. Targets in ``batch["doc_ids"]`` must
            lie in ``[0, num_doc_ids)``.
    """

    def __init__(self, student_model_name="t5-base", teacher_model=None, learning_rate=5e-4, num_doc_ids=10000):
        super().__init__()
        # The teacher is a module object, not a hyperparameter, so it is kept
        # out of the saved hyperparameters.
        self.save_hyperparameters(ignore=["teacher_model"])

        # Student model (docT5query or similar)
        self.student = AutoModelForSeq2SeqLM.from_pretrained(student_model_name)

        # Teacher model used for distillation. It is only ever queried, so it
        # is frozen and put in eval mode when it is a torch module.
        self.teacher = teacher_model
        if isinstance(self.teacher, torch.nn.Module):
            self.teacher.requires_grad_(False)
            self.teacher.eval()

        self.learning_rate = learning_rate

        # Separate heads for multi-task learning
        self.indexing_head = torch.nn.Linear(self.student.config.hidden_size, num_doc_ids)
        self.retrieval_head = torch.nn.Linear(self.student.config.hidden_size, 2)  # Binary classification

        # Loss weights for tasks
        self.indexing_loss_weight = 0.5
        self.retrieval_loss_weight = 0.3
        self.distillation_loss_weight = 0.2

    def forward(self, input_ids, attention_mask):
        """Return ``(indexing_logits, retrieval_logits)`` for a batch."""
        encoder_outputs = self.student.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Both heads read the encoder state at position 0.
        # NOTE(review): T5-style inputs have no [CLS] token, so this is just the
        # first token's state; confirm it is an adequate sequence summary.
        first_position = encoder_outputs[:, 0, :]
        indexing_logits = self.indexing_head(first_position)
        retrieval_logits = self.retrieval_head(first_position)
        return indexing_logits, retrieval_logits

    def compute_indexing_loss(self, logits, doc_ids):
        """Cross-entropy between indexing logits and target document ids."""
        return F.cross_entropy(logits, doc_ids)

    def compute_retrieval_loss(self, logits, relevance_labels):
        """Cross-entropy between retrieval logits and relevance labels."""
        return F.cross_entropy(logits, relevance_labels)

    def compute_distillation_loss(self, student_logits, teacher_logits):
        """
        Knowledge distillation loss: KL divergence between student and teacher logits.
        """
        student_probs = F.log_softmax(student_logits, dim=-1)
        teacher_probs = F.softmax(teacher_logits, dim=-1)
        return F.kl_div(student_probs, teacher_probs, reduction="batchmean")

    def training_step(self, batch, batch_idx):
        """Weighted sum of the indexing, retrieval and (optional) distillation losses."""
        # Forward pass through student model
        indexing_logits, retrieval_logits = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

        # Compute task-specific losses
        indexing_loss = self.compute_indexing_loss(indexing_logits, batch["doc_ids"])
        retrieval_loss = self.compute_retrieval_loss(retrieval_logits, batch["label"])

        # Compute distillation loss (if teacher model is provided).
        # `is not None` avoids relying on the teacher object's truthiness.
        if self.teacher is not None:
            if isinstance(self.teacher, torch.nn.Module):
                # Switching this module to train mode also switches registered
                # submodules, so re-assert eval mode on the teacher.
                self.teacher.eval()
            with torch.no_grad():
                teacher_logits = self.teacher(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            distillation_loss = self.compute_distillation_loss(retrieval_logits, teacher_logits)
        else:
            distillation_loss = 0.0

        # Combine losses
        total_loss = (
            self.indexing_loss_weight * indexing_loss
            + self.retrieval_loss_weight * retrieval_loss
            + self.distillation_loss_weight * distillation_loss
        )

        # Log losses
        self.log("train_indexing_loss", indexing_loss, prog_bar=True)
        self.log("train_retrieval_loss", retrieval_loss, prog_bar=True)
        self.log("train_distillation_loss", distillation_loss, prog_bar=True)
        self.log("train_total_loss", total_loss, prog_bar=True)

        return total_loss

    def configure_optimizers(self):
        """AdamW over the trainable parameters (a frozen teacher is excluded)."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.learning_rate)

# OLD

## I-PYSERINI INSPECTION

In [None]:
from pyserini.search import get_topics

# Load the MS MARCO passage dev-subset topics and report how many there are.
topics = get_topics('msmarco-passage-dev-subset')
topic_count = len(topics)
print(f'{topic_count} queries total')

In [None]:
from pyserini.search import SimpleSearcher

# Searcher over the prebuilt MS MARCO passage index.
searcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')

# Run one example query against the index.
hits = searcher.search('What is machine learning?')

# Show each hit's rank, id and score, followed by its stored raw document.
for rank, hit in enumerate(hits, start=1):
    print(f"Rank {rank}: {hit.docid} - {hit.score}")
    print(hit.raw)

## 2: BERT EMBEDDING

In [None]:

# Tokenizer and encoder of the same pre-trained checkpoint, used to embed text.
tokenizer, model = (
    AutoTokenizer.from_pretrained("bert-base-uncased"),
    AutoModel.from_pretrained("bert-base-uncased"),
)

# Generate embeddings for documents
def embed_text(text):
    """Embed ``text`` by mean-pooling the model's last hidden states.

    Args:
        text: A string, or a list of strings embedded as one padded batch.

    Returns:
        A NumPy array: one vector for a single string, one row per string for
        a list.

    Padding positions are excluded from the average via the attention mask.
    Without the mask, the embedding of a short text in a batch would depend on
    how much padding the longest text in that batch forced onto it.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden_states = outputs.last_hidden_state
    # (batch, seq, 1) mask of real tokens, in the hidden states' dtype.
    token_mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * token_mask).sum(dim=1)
    # clamp guards against a division by zero for an all-padding row.
    token_counts = token_mask.sum(dim=1).clamp(min=1)
    return (summed / token_counts).squeeze().numpy()

## Model Implementation

### T5 Transformer

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, TrainerCallback

# T5 checkpoint to load; both downloads share the same cache directory.
model_name = "t5-base"
t5_load_options = {"cache_dir": 'cache'}

tokenizer = T5Tokenizer.from_pretrained(model_name, **t5_load_options)
model = T5ForConditionalGeneration.from_pretrained(model_name, **t5_load_options)

### Bert (12 layers)
For docids embedding generation

In [None]:
!pip install transformers

In [None]:
import torch
from transformers import BertTokenizer, BertModel

# Pre-trained BERT tokenizer and model, both loaded from the same checkpoint.
bert_checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_checkpoint)
model = BertModel.from_pretrained(bert_checkpoint)

In [None]:
# Example sentence to tokenize.
text = "Transformers are powerful models for NLP tasks."

# Put the model in evaluation mode.
model.eval()

# Tokenize into PyTorch tensors, with truncation and padding enabled.
tokenizer_options = dict(return_tensors='pt', truncation=True, padding=True)
inputs = tokenizer(text, **tokenizer_options)

# Show the tokenized input.
print(inputs)

### Inputs2Target

In [None]:
class IndexingTrainDataset(Dataset):
    """Indexing dataset: maps document text (input) to its tokenized docid (target).

    Args:
        path_to_data: JSON data file(s) passed to ``load_dataset``; each record
            is read through its ``'text'`` and ``'docid'`` fields.
        max_length: Maximum token length of the tokenized document text.
        cache_dir: Cache directory handed to ``load_dataset``.
        tokenizer: Tokenizer used for both the text and the docid.
    """

    def __init__(self, path_to_data, max_length, cache_dir, tokenizer):
        super().__init__()

        # `load_dataset` is the function the notebook imports from `datasets`.
        # The former `ignore_verifications=False` only restated the default and
        # is no longer accepted by recent `datasets` releases, so it is omitted.
        self.train_data = load_dataset(
            'json',
            data_files=path_to_data,
            cache_dir=cache_dir
        )['train']

        self.max_length = max_length
        self.tokenizer = tokenizer
        self.total_len = len(self.train_data)

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        # Retrieve document data
        doc = self.train_data[idx]
        doc_text = doc['text']
        # The tokenizer expects text, so numeric docids are converted first.
        docid = str(doc['docid'])

        # Tokenize input (document text)
        source = self.tokenizer(
            doc_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize target (docid), padded/truncated to 10 tokens
        target = self.tokenizer(
            docid,
            max_length=10,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Prepare input-output pair; squeeze(0) drops only the batch dimension.
        return {
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
            'labels': target['input_ids'].squeeze(0)
        }

### Training

In [None]:
# Training configuration for the indexing run: step-based evaluation every
# 1000 steps, W&B reporting, and no checkpoint saving.
# NOTE(review): `evaluation_strategy` is spelled differently in newer
# transformers releases — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=0.0005,
    warmup_steps=10000,
    # weight_decay=0.01,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    evaluation_strategy='steps',
    eval_steps=1000,
    max_steps=1000000,
    dataloader_drop_last=False,  # necessary
    report_to='wandb',
    logging_steps=50,
    save_strategy='no',
    # fp16=True,  # gives 0/nan loss at some point during training, seems this is a transformers bug.
    dataloader_num_workers=10,
    # gradient_accumulation_steps=2
)

# NOTE(review): IndexingTrainer, IndexingCollator, QueryEvalCallback,
# compute_metrics, restrict_decode_vocab, train_dataset, eval_dataset and
# test_dataset are not defined in the visible part of this notebook — confirm
# they exist before running this cell.
trainer = IndexingTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Batches are padded to the longest sequence they contain.
    data_collator=IndexingCollator(
        tokenizer,
        padding='longest',
    ),
    compute_metrics=compute_metrics,
    callbacks=[QueryEvalCallback(test_dataset, wandb, restrict_decode_vocab, training_args, tokenizer)],
    restrict_decode_vocab=restrict_decode_vocab
)

# Start training.
trainer.train()

### Training (from GPT)

In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import wandb

# Start a W&B run in the "DSI-Training" project so training is logged.
wandb.init(project="DSI-Training")

# 1. Load the pre-trained T5 tokenizer and model from the same checkpoint.
model_name = "t5-base"
tokenizer, model = (
    T5Tokenizer.from_pretrained(model_name),
    T5ForConditionalGeneration.from_pretrained(model_name),
)

# 2. Prepare the Dataset
class IndexingTrainDataset(torch.utils.data.Dataset):
    """Indexing dataset: maps document text (input) to its tokenized docid (target).

    Args:
        data: Indexable collection of records with ``'text'`` and ``'docid'`` fields.
        tokenizer: Tokenizer used for both the text and the docid.
        max_length: Maximum token length of the tokenized document text.
    """

    # The special methods need double underscores; with single underscores
    # Python never calls them, so construction with arguments, len() and
    # indexing all fail.
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        doc_text = item['text']
        # The tokenizer expects text, so numeric docids are converted first.
        docid = str(item['docid'])

        # Tokenize the document text (input)
        source = self.tokenizer(
            doc_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize the document ID (target), padded/truncated to 10 tokens
        target = self.tokenizer(
            docid,
            max_length=10,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Prepare input-output pair; squeeze(0) drops only the batch dimension.
        return {
            'input_ids': source['input_ids'].squeeze(0),
            'attention_mask': source['attention_mask'].squeeze(0),
            'labels': target['input_ids'].squeeze(0)
        }

# Load your dataset (e.g., Natural Questions)
# NOTE(review): "path/to/your/dataset" is a placeholder and must be replaced
# with a real dataset that has 'train' and 'validation' splits.
dataset = load_dataset("path/to/your/dataset")
train_data = IndexingTrainDataset(dataset['train'], tokenizer)
eval_data = IndexingTrainDataset(dataset['validation'], tokenizer)

# 3. Define Training Arguments
# Evaluate every 500 steps, log every 100, checkpoint every 1000 steps and
# keep at most 2 checkpoints.
# NOTE(review): `evaluation_strategy` is spelled differently in newer
# transformers releases — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./dsi_checkpoints",
    evaluation_strategy="steps",
    eval_steps=500,
    logging_dir="./logs",
    logging_steps=100,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    save_steps=1000,
    save_total_limit=2,
    report_to="wandb"  # Enable logging to W&B
)

# 4. Initialize Trainer
# `Trainer` here is the Hugging Face class imported at the top of this cell;
# that import rebinds the name the notebook's first import cell took from
# pytorch_lightning.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    tokenizer=tokenizer
)

# 5. Start Training
trainer.train()

# 6. Save the Fine-tuned Model
# Model and tokenizer are written to the same directory.
model.save_pretrained("./fine_tuned_dsi")
tokenizer.save_pretrained("./fine_tuned_dsi")

# 7. End Logging with W&B
wandb.finish()