==================================================

**Project Name:** Neural inverted index for fast and effective information retrieval\
**Course:** Deep Learning\
**University:** Sapienza Università di Roma

**Authors:**
  - [Alessio Borgi] (<tt>1952442</tt>)
  - [Eugenio Bugli] (<tt>1934824</tt>)
  - [Damiano Imola] (<tt>2109063</tt>)

**Date:** [November 2024 - Completion Date]

==================================================

## 0: INSTALL & IMPORT LIBRARIES

In [1]:
!pip install pyserini==0.12.0
!pip install pytorch-lightning transformers datasets torch wandb

Collecting pyserini==0.12.0
  Downloading pyserini-0.12.0-py3-none-any.whl.metadata (2.4 kB)
Collecting pyjnius>=1.2.1 (from pyserini==0.12.0)
  Downloading pyjnius-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Downloading pyserini-0.12.0-py3-none-any.whl (67.5 MB)
[2K   [91m━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.3/67.5 MB[0m [31m1.4 MB/s[0m eta [36m0:00:43[0m
[?25h[31mERROR: Exception:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/urllib3/response.py", line 438, in _error_catcher
    yield
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/urllib3/response.py", line 561, in read
    data = self._fp_read(amt) if not fp_closed else b""
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/urllib3/response.py", line 527, in _fp_read
    return self._fp.read(amt) if amt is not None else self._fp.read()
  File "/usr/local/lib/python3.10/dist-packages/p

In [2]:
import torch
import wandb
import numpy as np
from datetime import datetime
import pytorch_lightning as pl
from collections import Counter
import matplotlib.pyplot as plt
import torch.nn.functional as F
from datasets import load_dataset
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning.loggers import WandbLogger
from sklearn.cluster import AgglomerativeClustering
from transformers import AutoModel, AutoTokenizer, AutoTokenizer, AutoModelForSequenceClassification
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, StochasticWeightAveraging

wandb.login()
#wandb.login(key="b3bce19a09c51bdf8a19eb3dc58f7c44de929e13")
wandb.init(project="IR_DSI", resume="allow")

## 1: DOWNLOADING DATASET




In [3]:
# PyTorch Dataset class
class MSMARCODataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128, is_test=False):
        """
        Initialize the dataset for MS MARCO.

        Args:
            data: The dataset split (train, validation, or test).
            tokenizer: The tokenizer instance.
            max_length: Maximum token length for inputs.
            is_test: Flag to indicate if the dataset is a test set (no labels).
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_test = is_test

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        query = item["query"]
        passage = item["passages"]["passage_text"][0]  # Use the first passage

        # If not test set, fetch the label
        label = None if self.is_test else 1 if item["passages"]["is_selected"][0] else 0

        # Tokenize input
        inputs = self.tokenizer(
            query,
            passage,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )

        result = {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0)
        }

        # Add label only if it's not the test set
        if not self.is_test:
            result["label"] = torch.tensor(label, dtype=torch.long)

        return result

In [4]:
# PyTorch Lightning Data Module
class MSMarcoDataModule(pl.LightningDataModule):
    def __init__(self, train_data, validation_data, test_data, tokenizer, batch_size=32):
        """
        Data module for handling MS MARCO datasets.

        Args:
            train_data: Training dataset split.
            validation_data: Validation dataset split.
            test_data: Test dataset split.
            tokenizer: The tokenizer instance.
            batch_size: Batch size for data loaders.
        """
        super().__init__()
        self.train_data = train_data
        self.validation_data = validation_data
        self.test_data = test_data
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = MSMARCODataset(self.train_data, self.tokenizer)
        self.val_dataset = MSMARCODataset(self.validation_data, self.tokenizer)
        self.test_dataset = MSMARCODataset(self.test_data, self.tokenizer, is_test=True)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

In [5]:
# Load MS MARCO splits
ms_marco_train = load_dataset("microsoft/ms_marco", "v1.1", split="train")
ms_marco_validation = load_dataset("microsoft/ms_marco", "v1.1", split="validation")
ms_marco_test = load_dataset("microsoft/ms_marco", "v1.1", split="test")



README.md:   0%|          | 0.00/9.48k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/175M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/10047 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/82326 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9650 [00:00<?, ? examples/s]

## 2: DATASET EXPLORATION

In [6]:
# Function to print dataset characteristics
def print_dataset_info(name, dataset):
    print(f"\nDataset: {name}")
    print("-" * 40)
    print(f"Number of samples: {len(dataset)}")
    print(f"Features: {dataset.features.keys()}")
    print("\nExample:")
    print(dataset[0])

# Print information for each split
print_dataset_info("Train", ms_marco_train)
print_dataset_info("Validation", ms_marco_validation)
print_dataset_info("Test", ms_marco_test)

# Analyze specific features
def analyze_passages(dataset):
    print("\n--- Passage Analysis ---")
    passage_lengths = [len(p["passage_text"][0]) for p in dataset["passages"]]
    print(f"Number of passages per query: {len(dataset[0]['passages']['passage_text'])}")
    print(f"Average passage length: {sum(passage_lengths) / len(passage_lengths):.2f} characters")
    print(f"Max passage length: {max(passage_lengths)} characters")
    print(f"Min passage length: {min(passage_lengths)} characters")

# Analyze passages in the train set
analyze_passages(ms_marco_train)


Dataset: Train
----------------------------------------
Number of samples: 82326
Features: dict_keys(['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'])

Example:
{'answers': ['Results-Based Accountability is a disciplined way of thinking and taking action that communities can use to improve the lives of children, youth, families, adults and the community as a whole.'], 'passages': {'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 'passage_text': ["Since 2007, the RBA's outstanding reputation has been affected by the 'Securency' or NPA scandal. These RBA subsidiaries were involved in bribing overseas officials so that Australia might win lucrative note-printing contracts. The assets of the bank include the gold and foreign exchange reserves of Australia, which is estimated to have a net worth of A$101 billion. Nearly 94% of the RBA's employees work at its headquarters in Sydney, New South Wales and at the Business Resumption Site.", "The Reserve Bank of Aust

In [None]:

# Passage Length Distribution
def plot_passage_length_distribution(dataset, split_name):
    passage_lengths = [len(p["passage_text"][0]) for p in dataset["passages"]]
    plt.hist(passage_lengths, bins=50, alpha=0.7, color="blue")
    plt.title(f"Passage Length Distribution ({split_name})")
    plt.xlabel("Passage Length (characters)")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()

# Plot distribution for train, validation, and test splits
plot_passage_length_distribution(ms_marco_train, "Train")
plot_passage_length_distribution(ms_marco_validation, "Validation")
plot_passage_length_distribution(ms_marco_test, "Test")

In [None]:
# Query Length Distribution
def plot_query_length_distribution(dataset, split_name):
    query_lengths = [len(q) for q in dataset["query"]]
    plt.hist(query_lengths, bins=30, alpha=0.7, color="green")
    plt.title(f"Query Length Distribution ({split_name})")
    plt.xlabel("Query Length (characters)")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()

# Plot distribution for train, validation, and test splits
plot_query_length_distribution(ms_marco_train, "Train")
plot_query_length_distribution(ms_marco_validation, "Validation")
plot_query_length_distribution(ms_marco_test, "Test")

In [None]:
# Label Distribution
def plot_label_distribution(dataset, split_name):
    labels = [1 if p["is_selected"][0] else 0 for p in dataset["passages"]]
    label_counts = {0: labels.count(0), 1: labels.count(1)}

    plt.bar(label_counts.keys(), label_counts.values(), alpha=0.7, color=["red", "green"])
    plt.xticks([0, 1], ["Not Selected", "Selected"])
    plt.title(f"Label Distribution ({split_name})")
    plt.xlabel("Label")
    plt.ylabel("Count")
    plt.show()

    print(f"Label Counts ({split_name}): {label_counts}")

# Plot label distribution for train and validation splits
plot_label_distribution(ms_marco_train, "Train")
plot_label_distribution(ms_marco_validation, "Validation")

In [None]:
# Unique Queries and Passages
def print_unique_counts(dataset, split_name):
    unique_queries = len(set(dataset["query"]))
    unique_passages = len(set([p["passage_text"][0] for p in dataset["passages"]]))
    print(f"Unique Queries in {split_name}: {unique_queries}")
    print(f"Unique Passages in {split_name}: {unique_passages}")

# Print unique counts for train, validation, and test splits
print_unique_counts(ms_marco_train, "Train")
print_unique_counts(ms_marco_validation, "Validation")
print_unique_counts(ms_marco_test, "Test")

In [None]:
# Top-N Most Frequent Words in Queries
def plot_top_words_in_queries(dataset, split_name, top_n=20):
    all_queries = " ".join(dataset["query"])
    word_counts = Counter(all_queries.split())
    most_common_words = word_counts.most_common(top_n)

    words, counts = zip(*most_common_words)
    plt.barh(words, counts, color="purple")
    plt.gca().invert_yaxis()
    plt.title(f"Top-{top_n} Most Frequent Words in Queries ({split_name})")
    plt.xlabel("Frequency")
    plt.ylabel("Words")
    plt.show()

# Plot top-20 words for train, validation, and test splits
plot_top_words_in_queries(ms_marco_train, "Train")
plot_top_words_in_queries(ms_marco_validation, "Validation")
plot_top_words_in_queries(ms_marco_test, "Test")

In [None]:
# Top-N Most Frequent Words in Passages
def plot_top_words_in_passages(dataset, split_name, top_n=20):
    all_passages = " ".join([p["passage_text"][0] for p in dataset["passages"]])
    word_counts = Counter(all_passages.split())
    most_common_words = word_counts.most_common(top_n)

    words, counts = zip(*most_common_words)
    plt.barh(words, counts, color="orange")
    plt.gca().invert_yaxis()
    plt.title(f"Top-{top_n} Most Frequent Words in Passages ({split_name})")
    plt.xlabel("Frequency")
    plt.ylabel("Words")
    plt.show()

# Plot top-20 words for train, validation, and test splits
plot_top_words_in_passages(ms_marco_train, "Train")
plot_top_words_in_passages(ms_marco_validation, "Validation")
plot_top_words_in_passages(ms_marco_test, "Test")

In [None]:
# Number of Passages per Query
def analyze_passages_per_query(dataset, split_name):
    num_passages = [len(p["passage_text"]) for p in dataset["passages"]]
    plt.hist(num_passages, bins=range(1, max(num_passages) + 2), alpha=0.7, color="cyan")
    plt.title(f"Number of Passages per Query ({split_name})")
    plt.xlabel("Number of Passages")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()

# Analyze for train, validation, and test splits
analyze_passages_per_query(ms_marco_train, "Train")
analyze_passages_per_query(ms_marco_validation, "Validation")
analyze_passages_per_query(ms_marco_test, "Test")

In [None]:
# Average Passage Length per Query
def average_passage_length_per_query(dataset, split_name):
    avg_lengths = [np.mean([len(passage) for passage in p["passage_text"]]) for p in dataset["passages"]]
    plt.hist(avg_lengths, bins=50, alpha=0.7, color="magenta")
    plt.title(f"Average Passage Length per Query ({split_name})")
    plt.xlabel("Average Passage Length")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()

# Analyze for train, validation, and test splits
average_passage_length_per_query(ms_marco_train, "Train")
average_passage_length_per_query(ms_marco_validation, "Validation")
average_passage_length_per_query(ms_marco_test, "Test")

In [None]:
# Relevance Analysis: Number of Relevant Passages per Query
def analyze_relevance_distribution(dataset, split_name):
    relevant_counts = [sum(p["is_selected"]) for p in dataset["passages"]]
    plt.hist(relevant_counts, bins=range(0, max(relevant_counts) + 2), alpha=0.7, color="lime")
    plt.title(f"Number of Relevant Passages per Query ({split_name})")
    plt.xlabel("Number of Relevant Passages")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()

# Analyze for train and validation splits (not test, as it may lack labels)
analyze_relevance_distribution(ms_marco_train, "Train")
analyze_relevance_distribution(ms_marco_validation, "Validation")

In [None]:
# Lexical Diversity: Calculate the ratio of unique words to total words as a measure of diversity.
def calculate_lexical_diversity(dataset, split_name):
    query_text = " ".join(dataset["query"])
    passage_text = " ".join([p["passage_text"][0] for p in dataset["passages"]])

    query_words = query_text.split()
    passage_words = passage_text.split()

    query_diversity = len(set(query_words)) / len(query_words)
    passage_diversity = len(set(passage_words)) / len(passage_words)

    print(f"{split_name} - Lexical Diversity (Queries): {query_diversity:.4f}")
    print(f"{split_name} - Lexical Diversity (Passages): {passage_diversity:.4f}")

# Calculate for train, validation, and test splits
calculate_lexical_diversity(ms_marco_train, "Train")
calculate_lexical_diversity(ms_marco_validation, "Validation")
calculate_lexical_diversity(ms_marco_test, "Test")

In [7]:
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Create data module
data_module = MSMarcoDataModule(
    train_data=ms_marco_train,
    validation_data=ms_marco_validation,
    test_data=ms_marco_test,
    tokenizer=tokenizer,
    batch_size=32
)

# Prepare datasets
data_module.setup()

# Access dataloaders
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [8]:
# Iterate through the training loader
for batch in train_loader:
    input_ids = batch["input_ids"]  # Tokenized input IDs
    attention_mask = batch["attention_mask"]  # Attention mask
    labels = batch["label"]  # Labels for the batch
    print("Batch input_ids shape:", input_ids.shape)
    print("Batch attention_mask shape:", attention_mask.shape)
    print("Batch labels shape:", labels.shape)
    break  # Stop after printing one batch

Batch input_ids shape: torch.Size([32, 128])
Batch attention_mask shape: torch.Size([32, 128])
Batch labels shape: torch.Size([32])


## 4: MODEL

In [9]:
class MSMarcoClassifier(pl.LightningModule):
    def __init__(self, model_name="bert-base-uncased", learning_rate=2e-5):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
        self.learning_rate = learning_rate

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def training_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"]
        )
        loss = outputs.loss
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"]
        )
        loss = outputs.loss
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


## 5: TRAINING

In [None]:
# Model Checkpointing Callback.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",  # Metric to monitor
    dirpath="checkpoints/",  # Directory to save checkpoints
    filename="best-checkpoint-{epoch:02d}-{val_loss:.2f}",  # Checkpoint name format
    save_top_k=1,  # Save only the best model
    mode="min"  # Minimize the monitored metric
)

# Early Stopping Callback.
early_stopping_callback = EarlyStopping(
    monitor="val_loss",  # Metric to monitor
    patience=3,  # Number of epochs without improvement to wait
    mode="min"  # Minimize the monitored metric
)

# Learning Rate Monitoring Callback.
lr_monitor = LearningRateMonitor(logging_interval="step")

# StochasticWeightAveraging Callback.
swa_callback = StochasticWeightAveraging()

from pytorch_lightning.callbacks import DeviceStatsMonitor

# Device Statistics Callback
device_stats_callback = DeviceStatsMonitor()

# Trainer implementation.
trainer = Trainer(
    max_epochs=3,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1 if torch.cuda.is_available() else None,
    enable_progress_bar=True,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor, swa_callback, device_stats_callback],
    gradient_clip_val=1.0,  # Clip gradients to this value
    precision=16,  # Enable 16-bit precision (AMP, Automatic Mixed Precision. Speed-Up Training and reduce Memory Usage)
)

In [None]:
# Generate a timestamp for the run name
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Initialize WandB logger with the timestamp as the run name
wandb_logger = WandbLogger(
    project="IR_DSI",         # Shared project name
    name=f"run_{current_time}",     # Unique name based on the current time
    log_model=True                  # Log model artifacts
)

In [None]:
# Initialize the model
model = MSMarcoClassifier()

# Initialize the Trainer.
trainer = Trainer(
    max_epochs=3,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1 if torch.cuda.is_available() else None,
    enable_progress_bar=True,
    logger=wandb_logger,
)
# Train the model.
trainer.fit(model, data_module)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                          | Params | Mode
---------------------------------------------------------------
0 | model | BertForSequenceClassification | 109 M  | eval
---------------------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.93

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

## 6: TESTING

In [None]:
model.eval()  # Set model to evaluation mode
predictions = []
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to("cuda" if torch.cuda.is_available() else "cpu")
        attention_mask = batch["attention_mask"].to("cuda" if torch.cuda.is_available() else "cpu")

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)  # Predicted labels
        predictions.extend(preds.cpu().tolist())

print("Test Predictions:", predictions[:10])

# TRY DSI IMPLEMENTATION

In [None]:
class MSMARCODataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        query = item["query"]
        passage = item["passages"]["passage_text"][0]
        doc_id = int(item["query_id"])  # Convert query ID to an integer
        label = 1 if item["passages"]["is_selected"][0] else 0

        inputs = self.tokenizer(
            query,
            passage,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "doc_ids": torch.tensor(doc_id, dtype=torch.long),
            "label": torch.tensor(label, dtype=torch.long),
        }

In [None]:
class MultiTaskDSIWithoutDistillation(pl.LightningModule):
    def __init__(self, model_name="t5-base", learning_rate=5e-4):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.learning_rate = learning_rate

        # Loss weights for multi-task learning
        self.indexing_loss_weight = 0.5
        self.retrieval_loss_weight = 0.5

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def compute_indexing_loss(self, outputs, doc_ids):
        """
        Indexing Task: Predict document IDs from passage text.
        """
        loss = F.cross_entropy(outputs.logits, doc_ids)
        return loss

    def compute_retrieval_loss(self, outputs, relevance_labels):
        """
        Retrieval Task: Rank passages based on relevance labels.
        """
        loss = F.cross_entropy(outputs.logits, relevance_labels)
        return loss

    def training_step(self, batch, batch_idx):
        # Forward pass
        outputs = self(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Compute losses for indexing and retrieval
        indexing_loss = self.compute_indexing_loss(outputs, batch["doc_ids"])
        retrieval_loss = self.compute_retrieval_loss(outputs, batch["label"])

        # Combine losses with weights
        total_loss = (
            self.indexing_loss_weight * indexing_loss
            + self.retrieval_loss_weight * retrieval_loss
        )

        # Log losses
        self.log("train_indexing_loss", indexing_loss, prog_bar=True)
        self.log("train_retrieval_loss", retrieval_loss, prog_bar=True)
        self.log("train_total_loss", total_loss, prog_bar=True)

        return total_loss

    def validation_step(self, batch, batch_idx):
        # Forward pass
        outputs = self(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # Compute losses for indexing and retrieval
        indexing_loss = self.compute_indexing_loss(outputs, batch["doc_ids"])
        retrieval_loss = self.compute_retrieval_loss(outputs, batch["label"])

        # Combine losses
        total_loss = (
            self.indexing_loss_weight * indexing_loss
            + self.retrieval_loss_weight * retrieval_loss
        )

        # Log losses
        self.log("val_indexing_loss", indexing_loss, prog_bar=True)
        self.log("val_retrieval_loss", retrieval_loss, prog_bar=True)
        self.log("val_total_loss", total_loss, prog_bar=True)

        return total_loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForSeq2SeqLM

class MultiTaskDSIWithDistillation(pl.LightningModule):
    def __init__(self, student_model_name="t5-base", teacher_model=None, learning_rate=5e-4):
        super().__init__()
        self.save_hyperparameters()

        # Student model (docT5query or similar)
        self.student = AutoModelForSeq2SeqLM.from_pretrained(student_model_name)

        # Teacher model (dense retriever like ColBERT or BM25)
        self.teacher = teacher_model  # Pre-trained model used for distillation

        self.learning_rate = learning_rate

        # Separate heads for multi-task learning
        self.indexing_head = torch.nn.Linear(self.student.config.hidden_size, 10000)  # 10,000 doc IDs
        self.retrieval_head = torch.nn.Linear(self.student.config.hidden_size, 2)  # Binary classification

        # Loss weights for tasks
        self.indexing_loss_weight = 0.5
        self.retrieval_loss_weight = 0.3
        self.distillation_loss_weight = 0.2

    def forward(self, input_ids, attention_mask):
        encoder_outputs = self.student.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        indexing_logits = self.indexing_head(encoder_outputs[:, 0, :])  # Use [CLS] token
        retrieval_logits = self.retrieval_head(encoder_outputs[:, 0, :])  # Use [CLS] token
        return indexing_logits, retrieval_logits

    def compute_indexing_loss(self, logits, doc_ids):
        return F.cross_entropy(logits, doc_ids)

    def compute_retrieval_loss(self, logits, relevance_labels):
        return F.cross_entropy(logits, relevance_labels)

    def compute_distillation_loss(self, student_logits, teacher_logits):
        """
        Knowledge distillation loss: KL divergence between student and teacher logits.
        """
        student_probs = F.log_softmax(student_logits, dim=-1)
        teacher_probs = F.softmax(teacher_logits, dim=-1)
        return F.kl_div(student_probs, teacher_probs, reduction="batchmean")

    def training_step(self, batch, batch_idx):
        # Forward pass through student model
        indexing_logits, retrieval_logits = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

        # Compute task-specific losses
        indexing_loss = self.compute_indexing_loss(indexing_logits, batch["doc_ids"])
        retrieval_loss = self.compute_retrieval_loss(retrieval_logits, batch["label"])

        # Compute distillation loss (if teacher model is provided)
        if self.teacher:
            with torch.no_grad():
                teacher_logits = self.teacher(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            distillation_loss = self.compute_distillation_loss(retrieval_logits, teacher_logits)
        else:
            distillation_loss = 0.0

        # Combine losses
        total_loss = (
            self.indexing_loss_weight * indexing_loss
            + self.retrieval_loss_weight * retrieval_loss
            + self.distillation_loss_weight * distillation_loss
        )

        # Log losses
        self.log("train_indexing_loss", indexing_loss, prog_bar=True)
        self.log("train_retrieval_loss", retrieval_loss, prog_bar=True)
        self.log("train_distillation_loss", distillation_loss, prog_bar=True)
        self.log("train_total_loss", total_loss, prog_bar=True)

        return total_loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

# OLD

## I-PYSERINI INSPECTION

In [None]:
from pyserini.search import get_topics

topics = get_topics('msmarco-passage-dev-subset')
print(f'{len(topics)} queries total')

In [None]:
from pyserini.search import SimpleSearcher

searcher = SimpleSearcher.from_prebuilt_index('msmarco-passage')

# Search the index for a query
hits = searcher.search('What is machine learning?')

# Display the top-ranked results
for i, hit in enumerate(hits):
    print(f"Rank {i+1}: {hit.docid} - {hit.score}")
    print(hit.raw)

## 2: BERT EMBEDDING

In [None]:

# Load a pre-trained model for embeddings
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

# Generate embeddings for documents
def embed_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

## Model Implementation

### T5 Transformer

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, TrainerCallback

model_name = "t5-base"

tokenizer = T5Tokenizer.from_pretrained(model_name, cache_dir='cache')
model = T5ForConditionalGeneration.from_pretrained(model_name, cache_dir='cache')

### Bert (12 layers)
For docids embedding generation

In [None]:
!pip install transformers

In [None]:
import torch
from transformers import BertTokenizer, BertModel

# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

In [None]:
# Set model to evaluation mode
model.eval()

text = "Transformers are powerful models for NLP tasks."
inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)

# Display tokenized input
print(inputs)

### Inputs2Target

In [None]:
class IndexingTrainDataset(Dataset):
    def _init_(self, path_to_data, max_length, cache_dir, tokenizer):
        super()._init_()

        self.train_data = datasets.load_dataset(
            'json',
            data_files=path_to_data,
            ignore_verifications=False,
            cache_dir=cache_dir
        )['train']

        self.max_length = max_length
        self.tokenizer = tokenizer
        self.total_len = len(self.train_data)


    def _getitem_(self, idx):
        # Retrieve document data
        doc = self.data[idx]
        doc_text = doc['text']
        docid = doc['docid']

        # Tokenize input (document text)
        # BertTokenizer.from_pretrained('bert-base-uncased')
        source = self.tokenizer(
            doc_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize target (docid)
        target = self.tokenizer(
            docid,
            max_length=10,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Prepare input-output pair
        return {
            'input_ids': source['input_ids'].squeeze(),
            'attention_mask': source['attention_mask'].squeeze(),
            'labels': target['input_ids'].squeeze()
        }

### Training

In [None]:
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=0.0005,
    warmup_steps=10000,
    # weight_decay=0.01,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    evaluation_strategy='steps',
    eval_steps=1000,
    max_steps=1000000,
    dataloader_drop_last=False,  # necessary
    report_to='wandb',
    logging_steps=50,
    save_strategy='no',
    # fp16=True,  # gives 0/nan loss at some point during training, seems this is a transformers bug.
    dataloader_num_workers=10,
    # gradient_accumulation_steps=2
)

trainer = IndexingTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=IndexingCollator(
        tokenizer,
        padding='longest',
    ),
    compute_metrics=compute_metrics,
    callbacks=[QueryEvalCallback(test_dataset, wandb, restrict_decode_vocab, training_args, tokenizer)],
    restrict_decode_vocab=restrict_decode_vocab
)

trainer.train()

### Training (from GPT)

In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import wandb

# Initialize Weights & Biases (W&B) for logging
wandb.init(project="DSI-Training")

# 1. Load the Pre-trained T5 Model and Tokenizer
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# 2. Prepare the Dataset
class IndexingTrainDataset(torch.utils.data.Dataset):
    def _init_(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def _len_(self):
        return len(self.data)

    def _getitem_(self, idx):
        item = self.data[idx]
        doc_text = item['text']
        docid = item['docid']

        # Tokenize the document text (input)
        source = self.tokenizer(
            doc_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Tokenize the document ID (target)
        target = self.tokenizer(
            docid,
            max_length=10,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Prepare input-output pair
        return {
            'input_ids': source['input_ids'].squeeze(),
            'attention_mask': source['attention_mask'].squeeze(),
            'labels': target['input_ids'].squeeze()
        }

# Load your dataset (e.g., Natural Questions)
dataset = load_dataset("path/to/your/dataset")
train_data = IndexingTrainDataset(dataset['train'], tokenizer)
eval_data = IndexingTrainDataset(dataset['validation'], tokenizer)

# 3. Define Training Arguments
training_args = TrainingArguments(
    output_dir="./dsi_checkpoints",
    evaluation_strategy="steps",
    eval_steps=500,
    logging_dir="./logs",
    logging_steps=100,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    save_steps=1000,
    save_total_limit=2,
    report_to="wandb"  # Enable logging to W&B
)

# 4. Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    tokenizer=tokenizer
)

# 5. Start Training
trainer.train()

# 6. Save the Fine-tuned Model
model.save_pretrained("./fine_tuned_dsi")
tokenizer.save_pretrained("./fine_tuned_dsi")

# 7. End Logging with W&B
wandb.finish()