<a href="https://colab.research.google.com/github/alessioborgi/DL_Project/blob/main/Source/pepotto.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

==================================================

**Project Name:** Neural inverted index for fast and effective information retrieval\
**Course:** Deep Learning\
**University:** Sapienza Università di Roma

**Authors:**
  - [Alessio Borgi] (<tt>1952442</tt>)
  - [Eugenio Bugli] (<tt>1934824</tt>)
  - [Damiano Imola] (<tt>2109063</tt>)

**Date:** [November 2024 - Completion Date]


**Implementations**
*   Differentiable Search Index architecture
*   DSI-Multi dataset generation


**Novelties**
*   Dynamic pruning
*   Semantic and Stopwords Augmentation
*   Part Of Speech Masked Language Model (POS-MLM) Augmentation




==================================================

# 0: INSTALL & IMPORT LIBRARIES

In [None]:
%%capture
# Upgrade pip, then quietly install the notebook's dependencies; only pyserini is version-pinned (0.12.0).
# NOTE(review): the other packages are unpinned, so a fresh runtime may resolve different versions — consider pinning.
!pip install -q --upgrade pip
!pip install -q pyserini==0.12.0 pytorch-lightning transformers datasets torch wandb

In [None]:
import os
import json
import datetime

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

import wandb
import numpy as np
import matplotlib as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, StochasticWeightAveraging, DeviceStatsMonitor, ModelPruning

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer, AutoTokenizer, AutoModelForSequenceClassification, T5Tokenizer, T5ForConditionalGeneration, EncoderDecoderCache

from sklearn.preprocessing import normalize
from sklearn.cluster import AgglomerativeClustering, KMeans

from pyserini.index import IndexReader
from pyserini.search import SimpleSearcher

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize, RegexpTokenizer

from peft import LoraConfig, get_peft_model, TaskType

from google.colab import drive

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Currently using {device}")

# wandb keys
# API keys are secrets and must not live in the notebook: each one is read from
# the environment variable WANDB_API_KEY_<NAME> (e.g. WANDB_API_KEY_EUGENIO).
# A missing variable yields None, in which case wandb.login(key=None) falls back
# to wandb's own credential lookup (netrc / WANDB_API_KEY / interactive prompt).
# NOTE(review): the keys that used to be hard-coded here remain in the repository
# history and should be revoked.
USERS = {
    name: os.environ.get(f"WANDB_API_KEY_{name}")
    for name in ("EUGENIO", "ALESSIO", "DAMIANO")}

In [None]:
# Authenticate once with the selected user's key. A preceding bare wandb.login()
# is unnecessary and can block on an interactive prompt before the key is used.
wandb.login(key=USERS["EUGENIO"])
# resume="allow" lets wandb continue an existing run when one is found.
wandb.init(project="IR_DSI", resume="allow")

# Mount Google Drive so the tokenized datasets and checkpoints are reachable.
drive.mount('/content/drive')

# 1: LOAD TOKENIZED DATASET

In [None]:
# Copy the three tokenized splits (train / validation / test) from the mounted
# Google Drive folder to the local /content directory of the runtime.
!cp '/content/drive/MyDrive/deep-learning-files/train_data_tokenized.pt' '/content/train_data_tokenized.pt'
!cp '/content/drive/MyDrive/deep-learning-files/validation_data_tokenized.pt' '/content/validation_data_tokenized.pt'
!cp '/content/drive/MyDrive/deep-learning-files/test_data_tokenized.pt' '/content/test_data_tokenized.pt'

In [None]:
class DatasetLoader(torch.utils.data.Dataset):
    """Map-style dataset backed by a sliceable object stored in a ``torch.save`` file.

    Args:
        file_name: Path of the file passed to ``torch.load``.
        up_to_k: Upper bound of the slice kept in memory (``loaded[:up_to_k]``).

    Note:
        ``torch.load`` relies on pickle; only load files from a trusted source.
    """

    def __init__(self, file_name, up_to_k=5000):
        loaded = torch.load(file_name)
        # Only the leading slice is retained to bound memory and run time.
        self.data = loaded[:up_to_k]

    def __len__(self):
        """Return the number of retained samples."""
        retained = self.data
        return len(retained)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

In [None]:
# Load a capped number of samples per split: 5000 for training, 1000 each for validation and test.
train_dataset = DatasetLoader(file_name='/content/train_data_tokenized.pt', up_to_k=5000)
val_dataset = DatasetLoader(file_name='/content/validation_data_tokenized.pt', up_to_k=1000)
test_dataset = DatasetLoader(file_name='/content/test_data_tokenized.pt', up_to_k=1000)

  self.data = torch.load(file_name)[:up_to_k]


In [None]:
# Batches of 3 samples for every split; only the training split is reshuffled each epoch.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=3, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=3, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=3, shuffle=False)

In [None]:
# Report the size of each split on one line, then the total number of samples.
split_sizes = [len(split) for split in (train_dataset, val_dataset, test_dataset)]
print(*split_sizes)
print(sum(split_sizes))

5000 1000 1000
7000


# 2: MODEL

In [None]:
# Sequence-length limits, in tokens, shared by the model definitions below.
MAX_ENCODER_SEQUENCE_LENGTH = 1797  # limit for the encoder input
MAX_DECODER_SEQUENCE_LENGTH = 4  # limit for one decoded docid
# Decoder budget when a ranked list of 1000 docids is generated in one sequence.
MAX_DECODER_SEQUENCE_LENGTH_1000 = 1000 * MAX_DECODER_SEQUENCE_LENGTH

In [None]:
################################################################################
class DSIT5Model(pl.LightningModule):
    """Differentiable Search Index (DSI) on top of a pretrained T5 model.

    Every step optimizes the sum of two sequence-to-sequence losses that share
    the same docid targets:

    * indexing:  document tokens -> docid tokens
    * retrieval: query tokens    -> docid tokens

    Batches are dicts with the keys ``query``, ``input_ids``,
    ``decoder_input_ids`` and ``decoder_ranked_input_ids``.

    Args:
        model_name: Hugging Face checkpoint used for both model and tokenizer.
        learning_rate: Learning rate passed to AdamW.
        max_decoder_sequence_len: Number of generated tokens decoded for the
            top-1 docid.
        max_decoder_squence_len_1000: ``max_length`` used when generating the
            ranked docid list (parameter spelling kept for compatibility).
    """

    def __init__(self, model_name="t5-small", learning_rate=5e-5, max_decoder_sequence_len=MAX_DECODER_SEQUENCE_LENGTH, max_decoder_squence_len_1000=MAX_DECODER_SEQUENCE_LENGTH_1000):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name) # transformer
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.learning_rate = learning_rate
        # Kept on the instance: compute_metrics needs them and they are not
        # available as module-level names.
        self.max_decoder_sequence_len = max_decoder_sequence_len
        self.max_decoder_sequence_len_1000 = max_decoder_squence_len_1000

    def _unpack_batch(self, batch):
        """Return (queries, input_ids, decoder_input_ids, decoder_1000_input_ids) as long tensors.

        Token ids must be integer tensors on the module's device; building them
        with ``torch.Tensor(...)`` would produce float32 tensors instead.
        """
        def as_token_ids(key):
            # assumes each entry has a singleton dim 1 to drop — TODO confirm against the tokenization step
            return torch.as_tensor(batch[key], dtype=torch.long, device=self.device).squeeze(1)

        return (as_token_ids('query'),
                as_token_ids('input_ids'),
                as_token_ids('decoder_input_ids'),
                as_token_ids('decoder_ranked_input_ids'))

    def _joint_loss(self, queries, input_ids, decoder_input_ids):
        """Return indexing loss + retrieval loss for one batch."""
        # NOTE(review): no attention_mask is passed and labels are used as-is; if the
        # tensors are padded, pad positions take part in attention and loss — confirm.
        # indexing task: document -> docid
        index_loss = self.model(input_ids=input_ids, labels=decoder_input_ids).loss
        # retrieval task: query -> docid (the query is the input, not the label)
        retrieval_loss = self.model(input_ids=queries, labels=decoder_input_ids).loss
        return index_loss + retrieval_loss

    def _shared_step(self, batch, stage, with_metrics):
        """Compute and log ``<stage>_loss``; optionally compute and log retrieval metrics."""
        queries, input_ids, decoder_input_ids, decoder_1000_input_ids = self._unpack_batch(batch)
        loss = self._joint_loss(queries, input_ids, decoder_input_ids)
        self.log(f"{stage}_loss", loss, on_epoch=True)

        if with_metrics:
            # Retrieval quality is measured by generating docids from the queries.
            metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)
            self.log_dict({f"{stage}_{name}": float(value) for name, value in metrics.items()}, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the joint loss for a training batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train", with_metrics=False)

    def validation_step(self, batch, batch_idx):
        """Return the joint loss for a validation batch (logged as ``validation_loss``) and log metrics."""
        return self._shared_step(batch, "validation", with_metrics=True)

    def test_step(self, batch, batch_idx):
        """Return the joint loss for a test batch (logged as ``test_loss``) and log metrics."""
        return self._shared_step(batch, "test", with_metrics=True)

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        """Generate docids for ``input_ids`` and score them against the targets.

        Args:
            input_ids: Token ids fed to ``generate``.
            decoder_input_ids: Token ids of the target docid of each sample.
            decoder_1000_input_ids: Token ids of the target ranked docid list.

        Returns:
            dict with the keys ``recall_at_1000`` and ``map``.
        """
        # infer top 1000 documents
        predicted_docids_tokenized = self.model.generate(input_ids, max_length=self.max_decoder_sequence_len_1000)

        predicted_docids = []
        predicted_1000_docids = []
        for prediction in predicted_docids_tokenized:
            # retrieve top 1 (an empty decode yields "" instead of raising IndexError)
            top1_tokens = self.tokenizer.decode(prediction[:self.max_decoder_sequence_len], skip_special_tokens=True).split()
            predicted_docids.append(top1_tokens[0] if top1_tokens else "")

            # retrieve top 1000
            decoded_1000 = self.tokenizer.decode(prediction[:self.max_decoder_sequence_len_1000], skip_special_tokens=True)
            predicted_1000_docids.append(decoded_1000.split())

        # ground truth, decoded to text so it is comparable with the predictions
        target_docids = [self.tokenizer.decode(ground_truth, skip_special_tokens=True)
                         for ground_truth in decoder_input_ids]
        target_1000_docids = [self.tokenizer.decode(ranked, skip_special_tokens=True).split()
                              for ranked in decoder_1000_input_ids]

        # compute metrics
        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        map_score = self.compute_map(predicted_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": map_score
        }

    def compute_recall_at_1000(self, predicted_docids, target_docids):
        """Mean over samples of |top-1000 predictions ∩ targets| / |targets|.

        Both arguments are sequences with one collection of docids per sample;
        a sample with no targets contributes 0. Returns 0.0 for an empty batch.
        """
        recalls = []

        for predicted, target in zip(predicted_docids, target_docids):
            predicted_set = set(predicted[:1000])
            target_set = set(target)

            if not target_set:
                recalls.append(0)
                continue

            recalls.append(len(predicted_set & target_set) / len(target_set))

        return np.mean(recalls) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        """Mean average precision over the batch.

        Args:
            predicted_docids: Per sample, a ranked list of docids or a single
                docid string (treated as a one-element ranking).
            target_docids: Per sample, a whitespace-separated string of
                relevant docids.

        Returns:
            Mean of the per-sample average precision; 0.0 for an empty batch.
        """
        aps = []
        for predicted, target in zip(predicted_docids, target_docids):
            target_set = set(target.split())

            if not target_set:
                aps.append(0)
                continue

            # Rank this sample's own predictions, not the whole batch.
            ranking = [predicted] if isinstance(predicted, str) else list(predicted)

            precision_at_k = []
            num_hits = 0
            for rank, doc in enumerate(ranking, start=1):
                if doc in target_set:
                    num_hits += 1
                    precision_at_k.append(num_hits / rank)

            # Average Precision for this query
            aps.append(np.mean(precision_at_k) if precision_at_k else 0)

        return np.mean(aps) if aps else 0.0

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

# 3: TRAINING

In [None]:
model = DSIT5Model()

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
# Keep only the checkpoint with the lowest validation loss. The filename
# placeholder must use the logged metric name ("validation_loss"); an unknown
# name is formatted as 0 by Lightning.
checkpoint_callback = ModelCheckpoint(monitor='validation_loss',
                                      dirpath='checkpoints/',
                                      filename='dsi-t5-{epoch:02d}-{validation_loss:.2f}',
                                      save_top_k=1,
                                      mode='min')

# Stop when the validation loss has not improved for 3 consecutive checks.
early_stopping_callback = EarlyStopping(monitor='validation_loss',
                                        patience=3,
                                        mode='min')

# ===== DYNAMIC PRUNING ====
# unstructured pruning: removes individual weights, those with the smallest L1 magnitude
# amount=0.5 removes 50% of the weights
# NOTE(review): confirm 0.5 is the intended ratio for this experiment.
pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback, pruning_callback],
    accelerator="auto",
    accumulate_grad_batches=4, # gradients are accumulated over 4 batches before each optimizer step
    precision='16-mixed') # 16-bit automatic mixed precision

trainer.fit(
    model,
    train_dataloader,
    val_dataloader)

wandb.finish()

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params | Mode
------------------------------------------------------------
0 | model | T5ForConditionalGeneration | 60.5 M | eval
------------------------------------------------------------
60.5 M    Trainable params
0         Non-trainable params
60.5 M    Total params
242.026   Total estimated model params size (MB)
0         Modules in train mode
277       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Early stopping conditioned on metric `val_loss` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: `train_loss`, `train_loss_step`, `validation_loss`, `train_loss_epoch`

In [None]:
# save locally
torch.save(model, '/content/checkpoints/<CHECKPOINT_NAME>')
# save to GDrive
!cp '/content/model.pth' '/content/drive/MyDrive/deep-learning-files/checkpoint_<EPOCH>.pth'


# save locally
torch.save(model, '/content/model.pth')
# save to GDrive
!cp '/content/model.pth' '/content/drive/MyDrive/deep-learning-files/model_<EPOCH>.pth'


# save locally
torch.save(model.state_dict(), '/content/model_state_dict.pth')
# save to GDrive
!cp '/content/model_state_dict.pth' '/content/drive/MyDrive/deep-learning-files/model_state_dict_<EPOCH>.pth'

# PEFT (Parameter-Efficient Fine-Tuning) with LoRA adapters

In [None]:
class DSIT5ModelLORA(pl.LightningModule):
    """DSI model whose T5 backbone is fine-tuned through LoRA adapters (PEFT).

    Every step optimizes the sum of the indexing loss (document -> docid) and
    the retrieval loss (query -> docid). Batches are dicts with the keys
    ``query``, ``input_ids``, ``decoder_input_ids`` and
    ``decoder_ranked_input_ids``.

    Args:
        model_name: Hugging Face checkpoint used for both model and tokenizer.
        learning_rate: Learning rate passed to AdamW.
        max_decoder_squence_len: Number of generated tokens decoded for the
            top-1 docid.
        max_decoder_squence_len_1000: ``max_length`` used when generating the
            ranked docid list.
        lora_r: ``r`` passed to ``LoraConfig``.
        lora_alpha: ``lora_alpha`` passed to ``LoraConfig``.
        lora_dropout: ``lora_dropout`` passed to ``LoraConfig``.
    """

    # The defaults use the notebook's MAX_DECODER_SEQUENCE_LENGTH* constants; the
    # lower-case names are not defined at module level and would raise NameError
    # as soon as the class statement is executed.
    def __init__(self, model_name="t5-small", learning_rate=5e-5, max_decoder_squence_len=MAX_DECODER_SEQUENCE_LENGTH, max_decoder_squence_len_1000=MAX_DECODER_SEQUENCE_LENGTH_1000, lora_r=8, lora_alpha=32, lora_dropout=0.1):
        super().__init__()

        self.T5 = T5ForConditionalGeneration.from_pretrained(model_name) # transformer

        # PEFT sets requires_grad=False on all original T5 layers
        # LoRA parameters (the low-rank adapters) are injected into the T5 attention layers and have requires_grad=True
        # during training, only these small LoRA adapters get updated, leaving the main T5 weights untouched
        self.peft_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            inference_mode=False,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout
        )

        self.model = get_peft_model(self.T5, self.peft_config)

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.learning_rate = learning_rate
        # Kept on the instance so compute_metrics does not depend on globals.
        self.max_decoder_squence_len = max_decoder_squence_len
        self.max_decoder_squence_len_1000 = max_decoder_squence_len_1000

    def _unpack_batch(self, batch):
        """Return (queries, input_ids, decoder_input_ids, decoder_1000_input_ids) as long tensors.

        Token ids must be integer tensors on the module's device; building them
        with ``torch.Tensor(...)`` would produce float32 tensors instead.
        """
        def as_token_ids(key):
            # assumes each entry has a singleton dim 1 to drop — TODO confirm against the tokenization step
            return torch.as_tensor(batch[key], dtype=torch.long, device=self.device).squeeze(1)

        return (as_token_ids('query'),
                as_token_ids('input_ids'),
                as_token_ids('decoder_input_ids'),
                as_token_ids('decoder_ranked_input_ids'))

    def _joint_loss(self, queries, input_ids, decoder_input_ids):
        """Return indexing loss + retrieval loss for one batch."""
        # NOTE(review): no attention_mask is passed and labels are used as-is; if the
        # tensors are padded, pad positions take part in attention and loss — confirm.
        # indexing task: document -> docid
        index_loss = self.model(input_ids=input_ids, labels=decoder_input_ids).loss
        # retrieval task: query -> docid (the query is the input, not the label)
        retrieval_loss = self.model(input_ids=queries, labels=decoder_input_ids).loss
        return index_loss + retrieval_loss

    def _shared_step(self, batch, stage, with_metrics):
        """Compute and log ``<stage>_loss``; optionally compute and log retrieval metrics."""
        queries, input_ids, decoder_input_ids, decoder_1000_input_ids = self._unpack_batch(batch)
        loss = self._joint_loss(queries, input_ids, decoder_input_ids)
        self.log(f"{stage}_loss", loss, on_epoch=True)

        if with_metrics:
            # Retrieval quality is measured by generating docids from the queries.
            metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)
            self.log_dict({f"{stage}_{name}": float(value) for name, value in metrics.items()}, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the joint loss for a training batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train", with_metrics=False)

    def validation_step(self, batch, batch_idx):
        """Return the joint loss for a validation batch (logged as ``validation_loss``) and log metrics."""
        return self._shared_step(batch, "validation", with_metrics=True)

    def test_step(self, batch, batch_idx):
        """Return the joint loss for a test batch (logged as ``test_loss``) and log metrics."""
        return self._shared_step(batch, "test", with_metrics=True)

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        """Generate docids for ``input_ids`` and score them against the targets.

        Args:
            input_ids: Token ids fed to ``generate``.
            decoder_input_ids: Token ids of the target docid of each sample.
            decoder_1000_input_ids: Token ids of the target ranked docid list.

        Returns:
            dict with the keys ``recall_at_1000`` and ``map``.
        """
        # get_base_model().generate(...) uses the LoRA adapters in the forward pass
        # We don't lose the LoRA modifications by calling the base model, because LoRA modifies the internals of T5's attention layers

        # infer top 1000 documents
        predicted_docids_tokenized = self.model.get_base_model().generate(input_ids, max_length=self.max_decoder_squence_len_1000)

        predicted_docids = []
        predicted_1000_docids = []
        for prediction in predicted_docids_tokenized:
            # retrieve top 1 (an empty decode yields "" instead of raising IndexError)
            top1_tokens = self.tokenizer.decode(prediction[:self.max_decoder_squence_len], skip_special_tokens=True).split()
            predicted_docids.append(top1_tokens[0] if top1_tokens else "")

            # retrieve top 1000
            decoded_1000 = self.tokenizer.decode(prediction[:self.max_decoder_squence_len_1000], skip_special_tokens=True)
            predicted_1000_docids.append(decoded_1000.split())

        # ground truth, decoded to text so it is comparable with the predictions
        target_docids = [self.tokenizer.decode(ground_truth, skip_special_tokens=True)
                         for ground_truth in decoder_input_ids]
        target_1000_docids = [self.tokenizer.decode(ranked, skip_special_tokens=True).split()
                              for ranked in decoder_1000_input_ids]

        # compute metrics
        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        map_score = self.compute_map(predicted_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": map_score
        }

    def compute_recall_at_1000(self, predicted_docids, target_docids):
        """Mean over samples of |top-1000 predictions ∩ targets| / |targets|.

        Both arguments are sequences with one collection of docids per sample;
        a sample with no targets contributes 0. Returns 0.0 for an empty batch.
        """
        recalls = []

        for predicted, target in zip(predicted_docids, target_docids):
            predicted_set = set(predicted[:1000])
            target_set = set(target)

            if not target_set:
                recalls.append(0)
                continue

            recalls.append(len(predicted_set & target_set) / len(target_set))

        return np.mean(recalls) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        """Mean average precision over the batch.

        Args:
            predicted_docids: Per sample, a ranked list of docids or a single
                docid string (treated as a one-element ranking).
            target_docids: Per sample, a whitespace-separated string of
                relevant docids.

        Returns:
            Mean of the per-sample average precision; 0.0 for an empty batch.
        """
        aps = []
        for predicted, target in zip(predicted_docids, target_docids):
            target_set = set(target.split())

            if not target_set:
                aps.append(0)
                continue

            # Rank this sample's own predictions, not the whole batch.
            ranking = [predicted] if isinstance(predicted, str) else list(predicted)

            precision_at_k = []
            num_hits = 0
            for rank, doc in enumerate(ranking, start=1):
                if doc in target_set:
                    num_hits += 1
                    precision_at_k.append(num_hits / rank)

            # Average Precision for this query
            aps.append(np.mean(precision_at_k) if precision_at_k else 0)

        return np.mean(aps) if aps else 0.0

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [None]:
model = DSIT5ModelLORA(
    lora_r=8,
    lora_alpha=32,
    lora_dropout=0.1
)

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
# Keep only the checkpoint with the lowest validation loss. The filename
# placeholder must use the logged metric name ("validation_loss"); an unknown
# name is formatted as 0 by Lightning.
checkpoint_callback = ModelCheckpoint(
    monitor='validation_loss',
    dirpath='checkpoints/',
    filename='dsi-t5-lora-{epoch:02d}-{validation_loss:.2f}',
    save_top_k=1,
    mode='min'
)

# Stop when the validation loss has not improved for 3 consecutive checks.
early_stopping_callback = EarlyStopping(monitor='validation_loss', patience=3, mode='min')

# ===== DYNAMIC PRUNING ====
# unstructured pruning: removes individual weights, those with the smallest L1 magnitude
# amount=0.5 removes 50% of the weights
# NOTE(review): confirm 0.5 is the intended ratio for this experiment.
pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback, pruning_callback],

    accelerator="auto",

    # gradient accumulation
    accumulate_grad_batches=4, # gradients are accumulated over 4 batches before each optimizer step

    # mixed precision
    precision='16-mixed', # 16-bit automatic mixed precision
)

# Train the model
trainer.fit(model, train_dataloader, val_dataloaders=val_dataloader)

wandb.finish()

INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params | Mode 
-------------------------------------------------------------
0 | T5    | T5ForConditionalGeneration | 60.8 M | eval 
1 | model | PeftModelForSeq2SeqLM      | 60.8 M | train
--------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

# QLORA

In [None]:
# Print the installed transformers version.
import transformers as tra
print(tra.__version__)

4.47.1


In [None]:
# Print the installed peft version.
import peft as pf
print(pf.__version__)

0.14.0


In [None]:
# !pip uninstall transformers peft bitsandbytes -y

In [None]:
# Quietly install or upgrade bitsandbytes, which the 4-bit quantized model in this section depends on.
!pip install bitsandbytes -q -U

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m44.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sentence-transformers 3.3.1 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.0 which is incompatible.[0m[31m
[0m

In [None]:
import bitsandbytes as bnb

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from bitsandbytes.optim import AdamW8bit

class DSIT5ModelQLORA(pl.LightningModule):
    """DSI model with a 4-bit quantized T5 backbone fine-tuned through LoRA adapters.

    Every step optimizes the sum of the indexing loss (document -> docid) and
    the retrieval loss (query -> docid). Batches are dicts with the keys
    ``query``, ``input_ids``, ``decoder_input_ids`` and
    ``decoder_ranked_input_ids``.

    Args:
        model_name: Hugging Face checkpoint used for both model and tokenizer.
        learning_rate: Learning rate passed to the 8-bit AdamW optimizer.
        max_decoder_squence_len: Number of generated tokens decoded for the
            top-1 docid.
        max_decoder_squence_len_1000: ``max_length`` used when generating the
            ranked docid list.
        lora_r: ``r`` passed to ``LoraConfig``.
        lora_alpha: ``lora_alpha`` passed to ``LoraConfig``.
        lora_dropout: ``lora_dropout`` passed to ``LoraConfig``.
    """

    # The defaults use the notebook's MAX_DECODER_SEQUENCE_LENGTH* constants; the
    # lower-case names are not defined at module level and would raise NameError
    # as soon as the class statement is executed.
    def __init__(self, model_name="t5-small", learning_rate=5e-5, max_decoder_squence_len=MAX_DECODER_SEQUENCE_LENGTH, max_decoder_squence_len_1000=MAX_DECODER_SEQUENCE_LENGTH_1000, lora_r=8, lora_alpha=32, lora_dropout=0.1):
        super().__init__()

        # 4-bit quantization
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16)

        # load T5 in 4-bit precision
        self.base_t5 = T5ForConditionalGeneration.from_pretrained(model_name, quantization_config=bnb_config, device_map="auto")

        self.peft_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,
            inference_mode=False,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout
        )

        self.model = get_peft_model(self.base_t5, self.peft_config)

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

        self.learning_rate = learning_rate
        self.max_decoder_squence_len = max_decoder_squence_len
        self.max_decoder_squence_len_1000 = max_decoder_squence_len_1000

    def _unpack_batch(self, batch):
        """Return (queries, input_ids, decoder_input_ids, decoder_1000_input_ids) as long tensors.

        Tensors are placed on the module's own device rather than on a
        notebook-level global.
        """
        def as_token_ids(key):
            # assumes each entry has a singleton dim 1 to drop — TODO confirm against the tokenization step
            return torch.as_tensor(batch[key], dtype=torch.long, device=self.device).squeeze(1)

        return (as_token_ids('query'),
                as_token_ids('input_ids'),
                as_token_ids('decoder_input_ids'),
                as_token_ids('decoder_ranked_input_ids'))

    def _joint_loss(self, queries, input_ids, decoder_input_ids):
        """Return indexing loss + retrieval loss for one batch."""
        # NOTE(review): no attention_mask is passed and labels are used as-is; if the
        # tensors are padded, pad positions take part in attention and loss — confirm.
        # indexing task: document -> docid
        index_loss = self.model(input_ids=input_ids, labels=decoder_input_ids).loss
        # retrieval task: query -> docid (the query is the input, not the label)
        retrieval_loss = self.model(input_ids=queries, labels=decoder_input_ids).loss
        return index_loss + retrieval_loss

    def _shared_step(self, batch, stage, with_metrics):
        """Compute and log ``<stage>_loss``; optionally compute and log retrieval metrics."""
        queries, input_ids, decoder_input_ids, decoder_1000_input_ids = self._unpack_batch(batch)
        loss = self._joint_loss(queries, input_ids, decoder_input_ids)
        self.log(f"{stage}_loss", loss, on_epoch=True)

        if with_metrics:
            # Retrieval quality is measured by generating docids from the queries.
            metrics = self.compute_metrics(queries, decoder_input_ids, decoder_1000_input_ids)
            self.log_dict({f"{stage}_{name}": float(value) for name, value in metrics.items()}, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the joint loss for a training batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train", with_metrics=False)

    def validation_step(self, batch, batch_idx):
        """Return the joint loss for a validation batch (logged as ``validation_loss``) and log metrics."""
        return self._shared_step(batch, "validation", with_metrics=True)

    def test_step(self, batch, batch_idx):
        """Return the joint loss for a test batch (logged as ``test_loss``) and log metrics."""
        return self._shared_step(batch, "test", with_metrics=True)

    def compute_metrics(self, input_ids, decoder_input_ids, decoder_1000_input_ids):
        """Generate docids for ``input_ids`` and score them against the targets.

        Args:
            input_ids: Token ids fed to ``generate``.
            decoder_input_ids: Token ids of the target docid of each sample.
            decoder_1000_input_ids: Token ids of the target ranked docid list.

        Returns:
            dict with the keys ``recall_at_1000`` and ``map``.
        """
        # For generation, we call get_base_model() to access T5ForConditionalGeneration.generate(...)
        predicted_docids_tokenized = self.model.get_base_model().generate(input_ids=input_ids, max_length=self.max_decoder_squence_len_1000)

        predicted_docids = []
        predicted_1000_docids = []

        for prediction in predicted_docids_tokenized:
            # top 1
            decoded_top1 = self.tokenizer.decode(prediction[:self.max_decoder_squence_len], skip_special_tokens=True)
            splitted = decoded_top1.split()
            predicted_docids.append(splitted[0] if splitted else "")

            # top 1000
            decoded_1000 = self.tokenizer.decode(prediction[:self.max_decoder_squence_len_1000], skip_special_tokens=True)
            predicted_1000_docids.append(decoded_1000.split())

        # ground truth, decoded to text so it is comparable with the predictions
        target_docids = [self.tokenizer.decode(ground_truth, skip_special_tokens=True)
                         for ground_truth in decoder_input_ids]
        target_1000_docids = [self.tokenizer.decode(ranked, skip_special_tokens=True).split()
                              for ranked in decoder_1000_input_ids]

        # compute metrics
        recall_at_1000 = self.compute_recall_at_1000(predicted_1000_docids, target_1000_docids)
        map_val = self.compute_map(predicted_docids, target_docids)

        return {
            "recall_at_1000": recall_at_1000,
            "map": map_val
        }

    def compute_recall_at_1000(self, predicted_docids, target_docids):
        """Mean over samples of |top-1000 predictions ∩ targets| / |targets|.

        Both arguments are sequences with one collection of docids per sample;
        a sample with no targets contributes 0. Returns 0.0 for an empty batch.
        """
        recalls = []
        for predicted, target in zip(predicted_docids, target_docids):
            predicted_set = set(predicted[:1000])
            target_set = set(target)
            if not target_set:
                recalls.append(0)
                continue
            recalls.append(len(predicted_set & target_set) / len(target_set))
        return np.mean(recalls) if recalls else 0.0

    def compute_map(self, predicted_docids, target_docids):
        """Mean average precision over the batch.

        Args:
            predicted_docids: Per sample, a ranked list of docids or a single
                docid string (treated as a one-element ranking).
            target_docids: Per sample, a whitespace-separated string of
                relevant docids.

        Returns:
            Mean of the per-sample average precision; 0.0 for an empty batch.
        """
        aps = []
        for predicted, target in zip(predicted_docids, target_docids):
            target_set = set(target.split())
            if not target_set:
                aps.append(0)
                continue

            # A single docid string is one ranked item; iterating over it
            # directly would walk its characters.
            ranking = [predicted] if isinstance(predicted, str) else list(predicted)

            precision_at_k = []
            num_hits = 0
            for rank, doc in enumerate(ranking, start=1):
                if doc in target_set:
                    num_hits += 1
                    precision_at_k.append(num_hits / rank)

            avg_precision = np.mean(precision_at_k) if precision_at_k else 0
            aps.append(avg_precision)

        return np.mean(aps) if aps else 0.0

    def configure_optimizers(self):
        """Return a bitsandbytes 8-bit AdamW optimizer over all module parameters."""
        optimizer = AdamW8bit(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
model = DSIT5ModelQLORA(
    lora_r=8,
    lora_alpha=32,
    lora_dropout=0.1
)

logger = WandbLogger(project="IR_DSI_Project")

# ===== CALLBACKS =====
# Keep only the checkpoint with the lowest validation loss. The filename
# placeholder must use the logged metric name ("validation_loss"); an unknown
# name is formatted as 0 by Lightning. The "qlora" prefix keeps these files
# distinct from the LoRA run, which writes to the same directory.
checkpoint_callback = ModelCheckpoint(
    monitor='validation_loss',
    dirpath='checkpoints/',
    filename='dsi-t5-qlora-{epoch:02d}-{validation_loss:.2f}',
    save_top_k=1,
    mode='min'
)

# Stop when the validation loss has not improved for 3 consecutive checks.
early_stopping_callback = EarlyStopping(monitor='validation_loss', patience=3, mode='min')

# ===== DYNAMIC PRUNING ====
# unstructured pruning: removes individual weights, those with the smallest L1 magnitude
# amount=0.5 removes 50% of the weights
# NOTE(review): confirm 0.5 is the intended ratio, and that pruning is compatible with the 4-bit quantized backbone.
pruning_callback = ModelPruning("l1_unstructured", amount=0.5)

trainer = pl.Trainer(
    max_epochs=5,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback, pruning_callback],

    accelerator="auto",

    # gradient accumulation
    accumulate_grad_batches=4, # gradients are accumulated over 4 batches before each optimizer step

    # mixed precision
    precision='16-mixed', # 16-bit automatic mixed precision
)

# Train the model
trainer.fit(model, train_dataloader, val_dataloaders=val_dataloader)

wandb.finish()

ImportError: Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`

# EMPTY GPU RAM

In [None]:
# Free GPU memory: run Python's garbage collector, then release the CUDA memory PyTorch keeps cached.
# Author's note: the cell may need to be executed up to three times to take effect.
import gc
gc.collect()
# Optionally drop the references that keep the model alive (uncomment, then re-run the cell):
# del model, trainer
torch.cuda.empty_cache()

# 8: FURTHER IMPROVEMENTS AND TODOS

In [None]:
# TODO
# 1. Full refactoring (EUGENIO)
# 2. Plot the mean number of words per passage, with and without stopwords/punctuation
# 3. (DAMIANO) Adversarial Natural Language Inference
# 4. Generate a dataset of top-1000 results by using a ranker (Faiss)