<a href="https://colab.research.google.com/github/alessioborgi/NSIO/blob/main/Source/InfoRetrieval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

==================================================

**Project Name:** Neural inverted index for fast and effective information retrieval\
**Course:** Deep Learning\
**University:** Sapienza Università di Roma

**Authors:**
  - [Alessio Borgi] (<tt>1952442</tt>)
  - [Eugenio Bugli] (<tt>1934824</tt>)
  - [Damiano Imola] (<tt>2109063</tt>)

**Date:** [November 2024 - Completion Date]


**Implementations**
*   Differentiable Search Index architecture
*   DSI-Multi dataset generation


**Novelties**
*   Dynamic pruning
*   Semantic and Stopwords Augmentation
*   Part Of Speech Masked Language Model (POS-MLM) Augmentation




==================================================

# 0: INSTALL & IMPORT LIBRARIES

In [1]:
%%capture
!pip install -q --upgrade pip
!pip install -q pyserini==0.12.0
!pip install -q pytorch-lightning datasets torch wandb
!pip install bitsandbytes -U -q
!pip install accelerate peft -q
!pip install -U gdown transformers -q

In [2]:
# base
import json
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# for pyserini stuffs
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

# cool plots
import wandb

# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, StochasticWeightAveraging, DeviceStatsMonitor, ModelPruning


# HF and similar
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer, AutoTokenizer, AutoModelForSequenceClassification, T5Tokenizer, T5ForConditionalGeneration, EncoderDecoderCache, BitsAndBytesConfig

# sklearn
from sklearn.preprocessing import normalize
from sklearn.cluster import AgglomerativeClustering, KMeans

# pyserini
# import faiss
from pyserini.index import IndexReader
from pyserini.search import SimpleSearcher
# from pyserini.search.lucene import LuceneSearcher

# NLTK
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize, RegexpTokenizer

# PEFT imports
from peft import LoraConfig, get_peft_model, TaskType, PeftModel, AdaLoraConfig

# bits and bytes
from bitsandbytes.optim import AdamW8bit

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Currently using {device}")

Currently using cuda


In [4]:
USERS = {
    "EUGENIO": "551e67be2c716a42ea3230a7c4fc639fc985f98f",
    "ALESSIO": "b3bce19a09c51bdf8a19eb3dc58f7c44de929e13",
    "DAMIANO": "6d550e12a1b8f716ebe580082f495c01ed2adf6c"}

wandb.login()
wandb.login(key=USERS["EUGENIO"])
wandb.init(project="IR_DSI", resume="allow")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: [32m[41mERROR[0m API key must be 40 characters long, yours was 1364


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33meugeniobugli15[0m ([33madavit[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# 1: LOAD TOKENIZED DATASET

In [6]:
TRAIN_DATA_PATH = '/content/drive/MyDrive/deep-learning-files/train_data_tokenized.pt'
VALIDATION_DATA_PATH = '/content/drive/MyDrive/deep-learning-files/validation_data_tokenized.pt'
TEST_DATA_PATH = '/content/drive/MyDrive/deep-learning-files/test_data_tokenized.pt'

In [7]:
class DatasetLoader(torch.utils.data.Dataset):
    def __init__(self, file_name, up_to_k=5000):
        self.data = torch.load(file_name)[:up_to_k]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [8]:
train_dataset = DatasetLoader(TRAIN_DATA_PATH, up_to_k=800)
val_dataset = DatasetLoader(VALIDATION_DATA_PATH, up_to_k=100)
test_dataset = DatasetLoader(TEST_DATA_PATH, up_to_k=100)

  self.data = torch.load(file_name)[:up_to_k]


In [9]:
train_dataloader = DataLoader(train_dataset, batch_size=3, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=3)
test_dataloader = DataLoader(test_dataset, batch_size=3)

In [12]:
print(len(train_dataset), len(val_dataset), len(test_dataset))
print(len(train_dataset) + len(val_dataset) + len(test_dataset))

800 100 100
1000


# 2: MODEL

In [10]:
max_encoder_squence_len = 1797 # 48
max_decoder_squence_len = 4
max_decoder_squence_len_1000 = max_decoder_squence_len * 1000

In [11]:
class DSIT5Model(pl.LightningModule):
    def __init__(self, model_name="t5-small", learning_rate=5e-5, max_decoder_squence_len=max_decoder_squence_len, max_decoder_squence_len_1000=max_decoder_squence_len_1000):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name) # transformer
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.learning_rate = learning_rate

        self.val_losses = []
        self.val_input_ids = []
        self.val_decoder_input_ids = []
        self.val_decoder_1000_input_ids = []
        self.val_queries = []

    def training_step(self, batch, batch_idx):
        queries = torch.Tensor(batch['query']).squeeze(1).to(device)
        input_ids = torch.Tensor(batch['input_ids']).squeeze(1)
        decoder_input_ids = torch.Tensor(batch['decoder_input_ids']).squeeze(1)
        decoder_1000_input_ids = torch.Tensor(batch['decoder_ranked_input_ids']).squeeze(1)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        # metrics = self.compute_metrics(input_ids, decoder_input_ids, decoder_1000_input_ids)

        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        queries = torch.Tensor(batch['query']).squeeze(1).to(device)
        input_ids = torch.Tensor(batch['input_ids']).squeeze(1)
        decoder_input_ids = torch.Tensor(batch['decoder_input_ids']).squeeze(1)
        decoder_1000_input_ids = torch.Tensor(batch['decoder_ranked_input_ids']).squeeze(1)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        # metrics = self.compute_metrics(input_ids, decoder_input_ids, decoder_1000_input_ids)

        self.val_losses.append(loss.detach().cpu())
        self.val_input_ids.append(input_ids.cpu())
        self.val_decoder_input_ids.append(decoder_input_ids.cpu())
        self.val_decoder_1000_input_ids.append(decoder_1000_input_ids.cpu())
        self.val_queries.append(queries.cpu())

        self.log("validation_loss", loss, on_epoch=True)

    def test_step(self, batch, batch_idx):
        queries = torch.Tensor(batch['query']).squeeze(1).to(device)
        input_ids = torch.Tensor(batch['input_ids']).squeeze(1)
        decoder_input_ids = torch.Tensor(batch['decoder_input_ids']).squeeze(1)
        decoder_1000_input_ids = torch.Tensor(batch['decoder_ranked_input_ids']).squeeze(1)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)

        self.log("test_loss", loss, on_epoch=True)
        return loss

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        predicted = self.model.generate(input_ids, max_length=max_decoder_squence_len_1000)
        decoded_batch = self.tokenizer.batch_decode(predicted, skip_special_tokens=True)

        # Decoded string into a list of docids
        predicted_1000_docids = [text.split() for text in decoded_batch]
        target_1000_text = self.tokenizer.batch_decode(decoder_1000_input_ids, skip_special_tokens=True)
        target_1000_docids = [text.split() for text in target_1000_text]

        target_text = self.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
        target_docids = [t.split() for t in target_text]

        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        mean_ap = self.compute_map(predicted_1000_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": mean_ap
        }

    def compute_recall_at_1000(self, predicted_1000_docids, target_1000_docids):
        recalls = []
        for pred_list, target_list in zip(predicted_1000_docids, target_1000_docids):
            target_set = set(target_list)
            if not target_set:
                # Edge case: if there's no “true” docID, recall is 0 by definition
                recalls.append(0.0)
                continue

            pred_set = set(pred_list)
            intersection_size = len(pred_set.intersection(target_set))
            recall = intersection_size / len(target_set)
            recalls.append(recall)

        return float(np.mean(recalls)) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        all_aps = []
        for pred_list, target_list in zip(predicted_docids, target_docids):
            target_set = set(target_list)
            if not target_set:
                all_aps.append(0.0) # No g.truth docs
                continue

            num_hits = 0
            precision_accum = []
            for i, doc_id in enumerate(pred_list):
                if doc_id in target_set:
                    num_hits += 1
                    precision_at_i = num_hits / (i + 1)  # i+1 => rank index
                    precision_accum.append(precision_at_i)

            if precision_accum:
                ap = sum(precision_accum) / len(target_set)
            else:
                ap = 0.0
            all_aps.append(ap)
        return float(np.mean(all_aps)) if all_aps else 0.0

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def on_validation_start(self) -> None:
        self.val_losses.clear()
        self.val_input_ids.clear()
        self.val_decoder_input_ids.clear()
        self.val_decoder_1000_input_ids.clear()
        self.val_queries.clear()

    def on_validation_epoch_end(self):
        val_losses = torch.stack(self.val_losses, dim=0)
        all_input_ids = torch.cat(self.val_input_ids, dim=0).to(self.device)
        all_decoder_input_ids = torch.cat(self.val_decoder_input_ids, dim=0).to(self.device)
        all_decoder_1000_input_ids = torch.cat(self.val_decoder_input_ids, dim=0).to(self.device)
        all_queries = torch.cat(self.val_queries, dim=0).to(self.device)
        avg_loss = val_losses.mean()

        metrics = self.compute_metrics(
            input_ids=all_input_ids,
            decoder_input_ids=all_decoder_input_ids,
            decoder_1000_input_ids=all_decoder_1000_input_ids
        )

        self.log("validation_loss", avg_loss, on_epoch=True, prog_bar=True)
        self.log("recall_at_1000", metrics["recall_at_1000"], on_epoch=True, prog_bar=True)
        self.log("map", metrics["map"], on_epoch=True, prog_bar=True)

# 3: TRAINING

In [12]:
model = DSIT5Model()

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
checkpoint_callback = ModelCheckpoint(
    monitor='validation_loss',
    dirpath='checkpoints/',
    filename='dsi-t5-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min'
)

early_stopping_callback = EarlyStopping(monitor='validation_loss', patience=5, mode='min')

# ===== DYNAMIC PRUNING ====
# removes individual weights based on magnitude or importance
# removes the ones with smallest L1 norm
# amount removes 20% of the smallest magnitude weights
pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=10,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback, pruning_callback],
    log_every_n_steps=1,
    accelerator="auto",

    # gradient accumulation
    accumulate_grad_batches=4, # gradient is computed after 4 batches

    # mixed precision
    precision='16-mixed', # 16 bit precision of my model
)

# Train the model
trainer.fit(model, train_dataloader, val_dataloaders=val_dataloader)

wandb.finish()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▄▄▄▄▅▅▅▅▅▅▅▅▅▇▇▇▇▇▇▇█████
map,▁▁▁▁▁▁
recall_at_1000,▁▁▁▁▁▁
train_loss_epoch,▁█▇▇▇▇
train_loss_step,▄▃▂▂▂▁▁▇███████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
trainer/global_step,▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█
validation_loss,▁█████

0,1
epoch,5.0
map,0.0
recall_at_1000,0.0
train_loss_epoch,20.75
train_loss_step,20.75
trainer/global_step,401.0
validation_loss,20.75


In [14]:
name_run = "sandy"
torch.save(model, f'/content/drive/MyDrive/deep-learning-files/model_{name_run}.pth')
torch.save(model.state_dict(), f'/content/drive/MyDrive/deep-learning-files/model_{name_run}_state_dict.pth')

# PEFT (Parameter Efficient Finetuning) with LORA (adapter)

In [None]:
class DSIT5ModelLORA(pl.LightningModule):
    def __init__(self, model_name="t5-small", learning_rate=5e-5, max_decoder_squence_len=max_decoder_squence_len, max_decoder_squence_len_1000=max_decoder_squence_len_1000, lora_r=8, lora_alpha=32, lora_dropout=0.1):
        super().__init__()

        self.T5 = T5ForConditionalGeneration.from_pretrained(model_name) # transformer

        # PEFT sets requires_grad=False on all original T5 layers
        # LoRA parameters (the low-rank adapters) are injected into the T5 attention layers and do requires_grad=True
        # during training, only these small LoRA adapters get updated, leaving the main T5 weights untouched
        self.peft_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            inference_mode=False,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout
        )

        self.model = get_peft_model(self.T5, self.peft_config)

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.learning_rate = learning_rate


    def training_step(self, batch, batch_idx):
        queries = torch.Tensor(batch['query']).squeeze(1).to(device)
        input_ids = torch.Tensor(batch['input_ids']).squeeze(1)
        decoder_input_ids = torch.Tensor(batch['decoder_input_ids']).squeeze(1)
        decoder_1000_input_ids = torch.Tensor(batch['decoder_ranked_input_ids']).squeeze(1)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        # metrics = self.compute_metrics(input_ids, decoder_input_ids, decoder_1000_input_ids)

        self.log("train_loss", loss, on_epoch=True)
        return loss


    def validation_step(self, batch, batch_idx):
        queries = torch.Tensor(batch['query']).squeeze(1).to(device)
        input_ids = torch.Tensor(batch['input_ids']).squeeze(1)
        decoder_input_ids = torch.Tensor(batch['decoder_input_ids']).squeeze(1)
        decoder_1000_input_ids = torch.Tensor(batch['decoder_ranked_input_ids']).squeeze(1)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        metrics = self.compute_metrics(input_ids, decoder_input_ids, decoder_1000_input_ids)

        self.log("validation_loss", loss, on_epoch=True)
        return loss



    def test_step(self, batch, batch_idx):
        queries = torch.Tensor(batch['query']).squeeze(1).to(device)
        input_ids = torch.Tensor(batch['input_ids']).squeeze(1)
        decoder_input_ids = torch.Tensor(batch['decoder_input_ids']).squeeze(1)
        decoder_1000_input_ids = torch.Tensor(batch['decoder_ranked_input_ids']).squeeze(1)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)

        self.log("test_loss", loss, on_epoch=True)
        return loss

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        predicted = self.model.get_base_model().generate(input_ids, max_length=max_decoder_squence_len_1000)
        decoded_batch = self.tokenizer.batch_decode(predicted, skip_special_tokens=True)

        # Decoded string into a list of docids
        predicted_1000_docids = [text.split() for text in decoded_batch]
        target_1000_text = self.tokenizer.batch_decode(decoder_1000_input_ids, skip_special_tokens=True)
        target_1000_docids = [text.split() for text in target_1000_text]

        target_text = self.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
        target_docids = [t.split() for t in target_text]

        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        mean_ap = self.compute_map(predicted_1000_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": mean_ap
        }

    def compute_recall_at_1000(self, predicted_1000_docids, target_1000_docids):
        recalls = []
        for pred_list, target_list in zip(predicted_1000_docids, target_1000_docids):
            target_set = set(target_list)
            if not target_set:
                # Edge case: if there's no “true” docID, recall is 0 by definition
                recalls.append(0.0)
                continue

            pred_set = set(pred_list)
            intersection_size = len(pred_set.intersection(target_set))
            recall = intersection_size / len(target_set)
            recalls.append(recall)

        return float(np.mean(recalls)) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        all_aps = []
        for pred_list, target_list in zip(predicted_docids, target_docids):
            target_set = set(target_list)
            if not target_set:
                all_aps.append(0.0) # No g.truth docs
                continue

            num_hits = 0
            precision_accum = []
            for i, doc_id in enumerate(pred_list):
                if doc_id in target_set:
                    num_hits += 1
                    precision_at_i = num_hits / (i + 1)  # i+1 => rank index
                    precision_accum.append(precision_at_i)

            if precision_accum:
                ap = sum(precision_accum) / len(target_set)
            else:
                ap = 0.0
            all_aps.append(ap)
        return float(np.mean(all_aps)) if all_aps else 0.0

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [None]:
model = DSIT5ModelLORA(
    lora_r=8,
    lora_alpha=32,
    lora_dropout=0.1
)

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
checkpoint_callback = ModelCheckpoint(
    monitor='validation_loss',
    dirpath='checkpoints/',
    filename='dsi-t5-lora-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min'
)

early_stopping_callback = EarlyStopping(monitor='validation_loss', patience=3, mode='min')

# ===== DYNAMIC PRUNING ====
# removes individual weights based on magnitude or importance
# removes the ones with smallest L1 norm
# amount removes 20% of the smallest magnitude weights
pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback],#pruning_callback],

    accelerator="auto",

    # gradient accumulation
    accumulate_grad_batches=4, # gradient is computed after 4 batches

    # mixed precision
    precision='16-mixed', # 16 bit precision of my model
)

# Train the model
trainer.fit(model, train_dataloader, val_dataloaders=val_dataloader)

wandb.finish()

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params | Mode 
-------------------------------------------------------------
0 | T5    | T5ForConditionalGeneration | 60.8 M | eval 
1 | model | PeftModelForSeq2SeqLM      | 60.8 M | train
-------------------------------------------------------------
294 K     Trainable params
60.5 M    Non-trainable params
60.8 M    Total params
243.206   Total estimated model params size (MB)
362       Modules in train mode
277       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▁▁▁▁▁▁▃▃▆▆▁▁▁▃▃▆▆██
train_loss_epoch,▁▁▁▁█▆▁▁▁▁
trainer/global_step,▁▁▁▁▁▁▁▁▃▃▆▆▁▁▁▃▃▆▆██
validation_loss,▁▁▁▁█▇▁▁▁▁▁

0,1
epoch,3.0
train_loss_epoch,12.58074
trainer/global_step,3.0
validation_loss,13.29458


# QLORA

In [None]:
class DSIT5ModelQLORA(pl.LightningModule):
    def __init__(self, model_name="t5-small", learning_rate=5e-5, max_decoder_squence_len=max_decoder_squence_len, max_decoder_squence_len_1000=max_decoder_squence_len_1000, lora_r=8, lora_alpha=32, lora_dropout=0.1):
        super().__init__()

        # 4-bit quantization
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16)

        # load T5 in 4-bit precision
        self.base_t5 = T5ForConditionalGeneration.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")

        self.peft_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            inference_mode=False,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout
        )

        self.model = get_peft_model(self.base_t5, self.peft_config)

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

        self.learning_rate = learning_rate
        self.max_decoder_squence_len = max_decoder_squence_len
        self.max_decoder_squence_len_1000 = max_decoder_squence_len_1000

    def training_step(self, batch, batch_idx):
        # Expecting your batch to have these keys
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        metrics = self.compute_metrics(input_ids, decoder_input_ids, decoder_1000_input_ids)

        self.log("validation_loss", loss, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)

        self.log("test_loss", loss, on_epoch=True)
        return loss

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        predicted = self.model.get_base_model().generate(input_ids, max_length=max_decoder_squence_len_1000)
        decoded_batch = self.tokenizer.batch_decode(predicted, skip_special_tokens=True)

        # Decoded string into a list of docids
        predicted_1000_docids = [text.split() for text in decoded_batch]
        target_1000_text = self.tokenizer.batch_decode(decoder_1000_input_ids, skip_special_tokens=True)
        target_1000_docids = [text.split() for text in target_1000_text]

        target_text = self.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
        target_docids = [t.split() for t in target_text]

        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        mean_ap = self.compute_map(predicted_1000_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": mean_ap
        }

    def compute_recall_at_1000(self, predicted_1000_docids, target_1000_docids):
        recalls = []
        for pred_list, target_list in zip(predicted_1000_docids, target_1000_docids):
            target_set = set(target_list)
            if not target_set:
                # Edge case: if there's no “true” docID, recall is 0 by definition
                recalls.append(0.0)
                continue

            pred_set = set(pred_list)
            intersection_size = len(pred_set.intersection(target_set))
            recall = intersection_size / len(target_set)
            recalls.append(recall)

        return float(np.mean(recalls)) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        all_aps = []
        for pred_list, target_list in zip(predicted_docids, target_docids):
            target_set = set(target_list)
            if not target_set:
                all_aps.append(0.0) # No g.truth docs
                continue

            num_hits = 0
            precision_accum = []
            for i, doc_id in enumerate(pred_list):
                if doc_id in target_set:
                    num_hits += 1
                    precision_at_i = num_hits / (i + 1)  # i+1 => rank index
                    precision_accum.append(precision_at_i)

            if precision_accum:
                ap = sum(precision_accum) / len(target_set)
            else:
                ap = 0.0
            all_aps.append(ap)
        return float(np.mean(all_aps)) if all_aps else 0.0

    def configure_optimizers(self):
        optimizer = AdamW8bit(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
model = DSIT5ModelQLORA(
    lora_r=8,
    lora_alpha=32,
    lora_dropout=0.1
)

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
checkpoint_callback = ModelCheckpoint(
    monitor='validation_loss',
    dirpath='checkpoints/',
    filename='dsi-t5-lora-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min'
)

early_stopping_callback = EarlyStopping(monitor='validation_loss', patience=3, mode='min')

# ===== DYNAMIC PRUNING ====
# removes individual weights based on magnitude or importance
# removes the ones with smallest L1 norm
# amount removes 20% of the smallest magnitude weights
pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback], #pruning_callback],

    accelerator="auto",

    # gradient accumulation
    accumulate_grad_batches=4, # gradient is computed after 4 batches

    # mixed precision
    precision='16-mixed', # 16 bit precision of my model
)

# Train the model
trainer.fit(model, train_dataloader, val_dataloaders=val_dataloader)

wandb.finish()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_z

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▃▃▆▆██
train_loss_epoch,█▁▆▁
trainer/global_step,▁▁▃▃▆▆██
validation_loss,▁▁▁▁

0,1
epoch,3.0
train_loss_epoch,13.39073
trainer/global_step,3.0
validation_loss,13.63606


# AdaLORA

In [None]:
class DSIT5ModelAdaLORA(pl.LightningModule):
    def __init__(
        self,
        model_name="t5-small",
        learning_rate=5e-5,
        max_decoder_squence_len=max_decoder_squence_len,
        max_decoder_squence_len_1000=max_decoder_squence_len_1000,
        init_r=12, # initial rank for lora adapters
        target_r=4, # final desired rank for lora adapters
        lora_alpha=32, lora_dropout=0.1,
        init_steps=20, target_steps=200, # steps after which rank update begins and ends
        delta_steps=10 # interval for each update
    ):
        super().__init__()

        self.base_t5 = T5ForConditionalGeneration.from_pretrained(model_name)
        self.adalora_config = AdaLoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            init_r = init_r,
            target_r = target_r,
            tinit = init_steps,
            tfinal = target_steps,
            deltaT = delta_steps,
            lora_alpha = lora_alpha,
            lora_dropout = lora_dropout
        )
        self.model = get_peft_model(self.base_t5, self.adalora_config)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.learning_rate = learning_rate
        self.max_decoder_squence_len = max_decoder_squence_len
        self.max_decoder_squence_len_1000 = max_decoder_squence_len_1000

    def training_step(self, batch, batch_idx):
        # 'query', 'input_ids', 'decoder_input_ids', 'decoder_ranked_input_ids'
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        # indexing task
        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        # retrieval task
        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss
        metrics = self.compute_metrics(input_ids, decoder_input_ids, decoder_1000_input_ids)

        self.log("validation_loss", loss, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        index_output = self.model(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        retrieval_output = self.model(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss
        metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)

        self.log("test_loss", loss, on_epoch=True)
        return loss

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        predicted = self.model.get_base_model().generate(input_ids, max_length=max_decoder_squence_len_1000)
        decoded_batch = self.tokenizer.batch_decode(predicted, skip_special_tokens=True)

        # Decoded string into a list of docids
        predicted_1000_docids = [text.split() for text in decoded_batch]
        target_1000_text = self.tokenizer.batch_decode(decoder_1000_input_ids, skip_special_tokens=True)
        target_1000_docids = [text.split() for text in target_1000_text]

        target_text = self.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
        target_docids = [t.split() for t in target_text]

        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        mean_ap = self.compute_map(predicted_1000_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": mean_ap
        }

    def compute_recall_at_1000(self, predicted_1000_docids, target_1000_docids):
        recalls = []
        for pred_list, target_list in zip(predicted_1000_docids, target_1000_docids):
            target_set = set(target_list)
            if not target_set:
                # Edge case: if there's no “true” docID, recall is 0 by definition
                recalls.append(0.0)
                continue

            pred_set = set(pred_list)
            intersection_size = len(pred_set.intersection(target_set))
            recall = intersection_size / len(target_set)
            recalls.append(recall)

        return float(np.mean(recalls)) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        all_aps = []
        for pred_list, target_list in zip(predicted_docids, target_docids):
            target_set = set(target_list)
            if not target_set:
                all_aps.append(0.0) # No g.truth docs
                continue

            num_hits = 0
            precision_accum = []
            for i, doc_id in enumerate(pred_list):
                if doc_id in target_set:
                    num_hits += 1
                    precision_at_i = num_hits / (i + 1)  # i+1 => rank index
                    precision_accum.append(precision_at_i)

            if precision_accum:
                ap = sum(precision_accum) / len(target_set)
            else:
                ap = 0.0
            all_aps.append(ap)
        return float(np.mean(all_aps)) if all_aps else 0.0

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [None]:
model = DSIT5ModelAdaLORA()

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
checkpoint_callback = ModelCheckpoint(
    monitor='validation_loss',
    dirpath='checkpoints/',
    filename='dsi-t5-lora-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min'
)

early_stopping_callback = EarlyStopping(monitor='validation_loss', patience=3, mode='min')

pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback], #pruning_callback],

    accelerator="auto",

    # gradient accumulation
    accumulate_grad_batches=4, # gradient is computed after 4 batches

    # mixed precision
    precision='16-mixed', # 16 bit precision of my model
)

# Train the model
trainer.fit(model, train_dataloader, val_dataloaders=val_dataloader)

wandb.finish()

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/checkpoints exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type                       | Params | Mode 
---------------------------------------------------------------
0 | base_t5 | T5ForConditionalGeneration | 62.1 M | eval 
1 | model   | PeftModelForSeq2SeqLM      | 62.1 M | train
---------------------------------------------------------------
1.6 M     Trainable params
60.5 M    Non-trainable params
62.1 M    Total params
248.520   Total estimated model params size (MB)
962       Modules in train mode
277       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▃▃▆▆██
train_loss_epoch,▃▄▁█
trainer/global_step,▁▁▃▃▆▆██
validation_loss,▁▁▁▁

0,1
epoch,3.0
train_loss_epoch,15.07547
trainer/global_step,3.0
validation_loss,15.78931


# ConvoLORA

In [None]:
class ConvoAdapter(nn.Module):
    def __init__(
        self,
        hidden_dim,
        conv_kernel_size=3,
        rank=4,
        alpha=32,
        dropout=0.1
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout)

        self.conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=conv_kernel_size,
            padding=(conv_kernel_size - 1) // 2,
            groups=hidden_dim  # depthwise
        )

        # Low-rank linear mapping
        self.down_proj = nn.Linear(hidden_dim, rank, bias=False)
        self.up_proj   = nn.Linear(rank, hidden_dim, bias=False)

    def forward(self, hidden_states):
        # hidden_states shape: [batch_size, seq_len, hidden_dim]
        bsz, seq_len, dim = hidden_states.shape
        assert dim == self.hidden_dim, "Mismatched hidden dim in LoConAdapter"
        hs = hidden_states.transpose(1, 2)
        conv_out = self.conv(hs)
        conv_out = conv_out.transpose(1, 2)
        conv_out = self.dropout(conv_out)
        down = self.down_proj(conv_out)
        up   = self.up_proj(down)
        scaled = up * self.alpha
        return hidden_states + scaled

In [None]:
class DSIT5ModelConvoLORA(pl.LightningModule):
    def __init__(
        self,
        model_name="t5-small",
        learning_rate=5e-5,
        max_decoder_squence_len=max_decoder_squence_len,
        max_decoder_squence_len_1000=max_decoder_squence_len_1000,
        conv_kernel_size=3,
        conv_lora_rank=4,
        conv_lora_alpha=32,
        conv_lora_dropout=0.1
    ):
        super().__init__()
        self.save_hyperparameters()
        self.base_t5 = T5ForConditionalGeneration.from_pretrained(model_name)
        for param in self.base_t5.parameters():
            param.requires_grad = False
        ff_sublayer = self.base_t5.encoder.block[-1].layer[1]
        ff_sublayer.lo_con = ConvoAdapter(
            hidden_dim=self.base_t5.config.d_model,
            conv_kernel_size=conv_kernel_size,
            rank=conv_lora_rank,
            alpha=conv_lora_alpha,
            dropout=conv_lora_dropout
        )
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.learning_rate = learning_rate
        self.max_decoder_squence_len = max_decoder_squence_len
        self.max_decoder_squence_len_1000 = max_decoder_squence_len_1000
        self.model = self.base_t5

    def forward(self, input_ids, attention_mask=None, labels=None):
        encoder_outputs = self.model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        last_hidden = encoder_outputs.last_hidden_state
        ff_sublayer = self.model.encoder.block[-1].layer[1]
        if hasattr(ff_sublayer, "lo_con"):
            last_hidden = ff_sublayer.lo_con(last_hidden)
        new_enc_out = [last_hidden]
        outputs = self.model(
            encoder_outputs=new_enc_out,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True
        )
        return outputs

    def training_step(self, batch, batch_idx):
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        index_output = self(
            input_ids=input_ids,
            labels=decoder_input_ids
        )
        index_loss = index_output.loss
        retrieval_output = self(
            input_ids=input_ids,
            labels=queries
        )
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)

        index_output = self(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss

        retrieval_output = self(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        metrics = self.compute_metrics(input_ids, decoder_input_ids, decoder_1000_input_ids)

        self.log("validation_loss", loss, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        queries = batch['query'].squeeze(1).to(device)
        input_ids = batch['input_ids'].squeeze(1).to(device)
        decoder_input_ids = batch['decoder_input_ids'].squeeze(1).to(device)
        decoder_1000_input_ids = batch['decoder_ranked_input_ids'].squeeze(1).to(device)
        index_output = self(input_ids=input_ids, labels=decoder_input_ids)
        index_loss = index_output.loss
        retrieval_output = self(input_ids=input_ids, labels=queries)
        retrieval_loss = retrieval_output.loss

        loss = index_loss + retrieval_loss

        metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)

        self.log("test_loss", loss, on_epoch=True)
        return loss

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        predicted = self.model.generate(input_ids, max_length=max_decoder_squence_len_1000)
        decoded_batch = self.tokenizer.batch_decode(predicted, skip_special_tokens=True)

        # Decoded string into a list of docids
        predicted_1000_docids = [text.split() for text in decoded_batch]
        target_1000_text = self.tokenizer.batch_decode(decoder_1000_input_ids, skip_special_tokens=True)
        target_1000_docids = [text.split() for text in target_1000_text]

        target_text = self.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
        target_docids = [t.split() for t in target_text]

        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        mean_ap = self.compute_map(predicted_1000_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": mean_ap
        }

    def compute_recall_at_1000(self, predicted_1000_docids, target_1000_docids):
        recalls = []
        for pred_list, target_list in zip(predicted_1000_docids, target_1000_docids):
            target_set = set(target_list)
            if not target_set:
                # Edge case: if there's no “true” docID, recall is 0 by definition
                recalls.append(0.0)
                continue

            pred_set = set(pred_list)
            intersection_size = len(pred_set.intersection(target_set))
            recall = intersection_size / len(target_set)
            recalls.append(recall)

        return float(np.mean(recalls)) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        all_aps = []
        for pred_list, target_list in zip(predicted_docids, target_docids):
            target_set = set(target_list)
            if not target_set:
                all_aps.append(0.0) # No g.truth docs
                continue

            num_hits = 0
            precision_accum = []
            for i, doc_id in enumerate(pred_list):
                if doc_id in target_set:
                    num_hits += 1
                    precision_at_i = num_hits / (i + 1)  # i+1 => rank index
                    precision_accum.append(precision_at_i)

            if precision_accum:
                ap = sum(precision_accum) / len(target_set)
            else:
                ap = 0.0
            all_aps.append(ap)
        return float(np.mean(all_aps)) if all_aps else 0.0

    def configure_optimizers(self):
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=self.learning_rate)
        return optimizer

In [None]:
model = DSIT5ModelConvoLORA()

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
checkpoint_callback = ModelCheckpoint(
    monitor='validation_loss',
    dirpath='checkpoints/',
    filename='dsi-t5-lora-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min'
)

early_stopping_callback = EarlyStopping(monitor='validation_loss', patience=3, mode='min')

pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback], #pruning_callback],

    accelerator="auto",

    # gradient accumulation
    accumulate_grad_batches=4, # gradient is computed after 4 batches

    # mixed precision
    precision='16-mixed', # 16 bit precision of my model
)

# Train the model
trainer.fit(model, train_dataloader, val_dataloaders=val_dataloader)

wandb.finish()

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/checkpoints exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type                       | Params | Mode
--------------------------------------------------------------
0 | base_t5 | T5ForConditionalGeneration | 60.5 M | eval
--------------------------------------------------------------
6.1 K     Trainable params
60.5 M    Non-trainable params
60.5 M    Total params
242.051   Total estimated model params size (MB)
5         Modules in train mode
277       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▃▃▆▆██
train_loss_epoch,▁█▃▆
trainer/global_step,▁▁▃▃▆▆██
validation_loss,▁▁▁▁

0,1
epoch,3.0
train_loss_epoch,38.40944
trainer/global_step,3.0
validation_loss,36.55384


# EMPTY GPU RAM

In [None]:
# dump everything in your GPU (call 3 times to work)
import gc
gc.collect()
# del model, trainer
torch.cuda.empty_cache()