<a href="https://colab.research.google.com/github/SZAftabi/UseRQE/blob/main/(Step2)TagGeneration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<center> <font size='6'> 💟 <b> UseRQE </b> 💟 </font> <br> </center>
<center>Recognizing Question Entailment with User Background-knowledge Modeling</center> <center> <font size='4' color='red'> <b> Step (2) </b> Tag generation </font> </center>


# 😎 **1. Mount Google Drive**

In [None]:
# Mount Google Drive at /content/drive; later cells read the base model and data from,
# and write adapters/outputs to, paths under /content/drive/MyDrive.
# force_remount=True remounts even if the drive is already mounted in this runtime.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Root of the mounted Drive. Keep the trailing "/": paths below are built as
# f"{Drive_path}UseRQE/..." with no separator in between.
Drive_path = "/content/drive/MyDrive/"

# 😎 **3. Libraries**

In [None]:
# NOTE(review): the trailing "# ==x.y.z" comments appear to record the versions this notebook
# was last run with, but nothing enforces them: `-U` and the unpinned installs pull the latest
# releases. Pin them (e.g. `transformers==4.31.0`) if exact reproduction matters.
!pip install -q -U transformers                                                 # ==4.31.0
!pip install -q torchmetrics
!pip install -q pytorch_lightning
!pip install -q bitsandbytes
!pip install -q -U peft                                                         # ==0.4.0
!pip install -q accelerate                                                      # ==0.21.0
!pip install -q trl
!pip install -q tensorboard
!pip install -q datasets
!pip install -q rouge
!pip install -q bert-score

In [None]:
import os
import gc
import re
import torch
import warnings
import nltk
import json
import time
import requests
nltk.download('punkt')

import numpy as np
import pandas as pd
import bitsandbytes as bnb
import pytorch_lightning as pl
import matplotlib.pyplot as plt

In [None]:
# !pip install --upgrade huggingface-hub
# !pip install --upgrade transformers

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Callback
from tensorboard import notebook

from torchmetrics import MetricCollection
from torchmetrics.text.bert import BERTScore
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryPrecision,
    BinaryRecall,
    BinaryF1Score
    )

from peft import (
    TaskType,
    PeftModel,
    PeftConfig,
    LoraConfig,
    get_peft_model,
    AutoPeftModelForCausalLM,
    prepare_model_for_kbit_training,
    )

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    )

from dataclasses import dataclass, field
from nltk.tokenize import word_tokenize
from typing import Optional
from tqdm import tqdm
from bert_score import BERTScorer
from rouge import Rouge
from statistics import mean
from sklearn.model_selection import train_test_split

# Register `.progress_apply` on pandas objects, then silence every Python warning.
tqdm.pandas()
warnings.filterwarnings(action='ignore')

# Record which transformers release this run actually used (the pip cell does not pin it).
import transformers
print(transformers.__version__)

# 😎 **4. Helper Functions**

In [None]:
# Delimiters used by get_tg_prompt to build Llama-2-chat style prompts:
# <<SYS>> ... <</SYS>> wraps the system message, [INST] ... [/INST] wraps the user turn.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
B_INST = "[INST]"
E_INST = "[/INST]"

In [None]:
def get_tg_prompt(_question, _tags = None):
  """Build a Llama-2-chat tag-generation prompt for one question.

  Args:
    _question: the question text inserted under "### Question:".
    _tags: optional gold tags. When truthy they are appended after the closing
      `[/INST]` (followed by "</s>"), producing a complete training example.

  Returns:
    The prompt string.
  """
  system_msg = ('You are a Tag Generator. '
                'Respond only with a list of tags; do not include any additional text or explanations.')
  instruction = (
      "Please generate at least 5 tags for the provided question. Tags can include "
      "multi-word phrases if appropriate and should help hierarchically categorize "
      "the question's topics.\n"
  )
  user_msg = instruction + f"### Question:\n{_question}\n### Tags:\n"

  pieces = [B_INST, " ", B_SYS, system_msg, E_SYS, user_msg, " ", E_INST, "\n\n"]
  prompt = "".join(pieces)
  return f"{prompt}{_tags}</s>" if _tags else prompt

In [None]:
def get_response_index(_input_ids, _task, marker_id=29937):
  """Return the index at which the response starts in a tokenized prompt.

  The prompt templates contain section-marker tokens with id `marker_id` (29937 by
  default; presumably the Llama tokenizer's '#' piece, per the original variable name
  `hashtags_indexes` — confirm against the tokenizer). For each task, the response is
  taken to begin a fixed number of tokens after a given marker occurrence:

    task    marker occurrence (0-based)   tokens skipped
    'RQE'   2                             10
    'SUM'   1                             11
    'TG'    1                             10

  Args:
    _input_ids: sequence of token ids (a list or a 1-D tensor).
    _task: one of 'RQE', 'SUM', 'TG'.
    marker_id: token id of the section marker.

  Returns:
    The response start index. If there are too few markers, returns 0 for 'RQE'
    and -1 for the other tasks.

  Raises:
    ValueError: if `_task` is not one of the known tasks.
  """
  layouts = {'RQE': (2, 10), 'SUM': (1, 11), 'TG': (1, 10)}
  if _task not in layouts:
    raise ValueError(f"Unknown task {_task!r}; expected one of {sorted(layouts)}")
  occurrence, skip = layouts[_task]

  # Collect marker positions once (the original scanned the sequence twice).
  positions = [i for i, tok in enumerate(_input_ids) if tok == marker_id]
  if len(positions) > occurrence:
    return positions[occurrence] + skip
  return 0 if _task == 'RQE' else -1

In [None]:
def generate_prompt(data, is_eval):
  """Build the tag-generation prompt for one data row.

  Args:
    data: a row (e.g. a pandas Series) with a 'text' field and, for training
      rows, a 'tags' field.
    is_eval: if True, only the question is included, so the prompt ends at the
      `[/INST]` delimiter and the model must generate the tags. If False, the gold
      tags (when non-empty) and "</s>" are appended for teacher-forced training.

  Returns:
    The prompt string produced by `get_tg_prompt`.
  """
  # The original also assigned an unused, misspelled local `promp`; it is dropped here.
  if is_eval:
    return get_tg_prompt(data['text'])
  return get_tg_prompt(data['text'], data['tags'])

# 😎 **5. LLama2-TG**

## 🌻 **5.1. hyper-parameters**

In [None]:
@dataclass
class ScriptArguments:
    """All tunable settings for the TG run: paths, logging, training, LoRA and quantization.

    Parsed with `HfArgumentParser`, so the annotation of each field determines how a
    command-line override is converted.
    """
    # ##########################################################################
    #                             Configuration
    # ##########################################################################
    model_name: Optional[str] = field(
        default = f"{Drive_path}llama-2-7b-chat-hf",
        metadata = {"help": "The model that you want to train from the Hugging Face hub."}
      )
    adapter_name: Optional[str] = field(
        default = "LLama-TG",
        metadata = {"help": "The adapter name saved in the HuggingFace hub."}
      )
    save_to: Optional[str] = field(
        default = "Drive",                                                      # Save to "Hub", or "Drive", or "Both"
        metadata = {"help": "Determine where to save Adapters"}
      )
    # ##########################################################################
    #                         Logs and Checkpoints
    # ##########################################################################
    logging_steps: Optional[int] = field(
        default = 1,
        metadata = {"help": "log every X update steps"}
      )
    # NOTE: used by plain string concatenation (output_dir + 'logs' / + 'Checkpoints'),
    # so the default yields "/content/UseRQElogs" and "/content/UseRQECheckpoints".
    output_dir: Optional[str] = field(
        default = "/content/UseRQE",
        metadata = {"help": "the output directory for both logs and checkpoints"}
      )
    every_n_epochs : Optional[int] = field(
        default = 1,
        metadata = {"help": "Save checkpoints every X epochs"}
      )
    save_on_train_epoch_end: Optional[bool] = field(
        default = None,
        metadata = {"help": "Whether to run checkpointing at the end of training epochs or validation"}
      )
    total_num_samples: Optional[str] = field(
        default = 'All',                                                        # Use {your desired number of samples} or 'All'
        metadata = {"help": "Number of samples to be selected from the whole dataset"}
      )
    # ##########################################################################
    #                             Hyper-parameters
    # ##########################################################################
    max_epochs: Optional[int] = field(
        default = 10,
        metadata = {"help": "maximum number of training epochs."}
      )
    learning_rate: Optional[float] = field(
        default = 1e-4,
        metadata = {"help": "the learning rate"}
      )
    gradient_accumulation_steps: Optional[int] = field(
        default = 8,
        metadata = {"help": "the number of gradient accumulation steps"}
      )
    # NOTE(review): not read by the model or trainer cells of this notebook.
    gradient_checkpointing: Optional[bool] = field(
        default = True,
        metadata = {"help": "Enables gradient checkpointing."}
      )
    per_device_train_batch_size: Optional[int] = field(
        default = 4,
        metadata = {"help": "batch_size of training (per device)"}
      )
    per_device_eval_batch_size: Optional[int] = field(
        default = 1,
        metadata = {"help": "batch_size of validation (per device)"}
      )
    max_seq_length: Optional[int] = field(
        default = 512,
        metadata = {"help": "maximum input sequence length"}
      )
    # NOTE(review): not forwarded to `from_pretrained` by the model cells of this notebook.
    trust_remote_code: Optional[bool] = field(
        default = True,
        metadata = {"help": '''Enable `trust_remote_code` so that it
        will execute code present on the Hub on your local machine'''}
      )
    # Holds a (train, test, validation) tuple despite the float annotation; only the
    # default is meant to be used (a CLI override would arrive as a single float).
    split_ratio: Optional[float] = field(
        default = (0.8, 0.2, 0),
        metadata = {"help": "train/test/validation splits"}
      )
    # NOTE(review): not passed to pl.Trainer in this notebook.
    precision: Optional[int] = field(
        default = 16,
        metadata = {"help": "train with 16/32/bf16 precision."}
      )
    # A batch count, so annotated int (was float) to match pl.Trainer's argument.
    num_sanity_val_steps: Optional[int] = field(
        default = 0,
        metadata = {"help": "number of validation batches before the first training epoch"}
      )
    max_new_tokens: Optional[int] = field(
        default = 30,
        metadata = {"help": "the maximum number of new tokens in the generated sequences (test step)"}
      )
    # ##########################################################################
    #                             Lora Configuration
    # ##########################################################################
    use_peft: Optional[bool] = field(
        default = True,
        metadata = {"help": "Wether to use PEFT or not to train adapters"}
      )
    lora_r: Optional[int] = field(
        default = 64,
        metadata = {"help": "the r parameter of the LoRA adapters"}
      )
    lora_alpha: Optional[int] = field(
        default = 64,
        metadata = {"help": "the alpha parameter of the LoRA adapters"}
      )
    # A probability, so annotated float (was int, which made `--lora_dropout 0.05` unparsable).
    lora_dropout: Optional[float] = field(
        default = 0.1,
        metadata = {"help": "the dropout rate of the LoRA adapters"}
      )
    # ##########################################################################
    #                                 BitsAndBytes
    # ##########################################################################
    load_in_8bit: Optional[bool] = field(
        default = False,
        metadata = {"help": "load the model in 8 bits precision"}
      )
    load_in_4bit: Optional[bool] = field(
        default = False,
        metadata = {"help": "load the model in 4 bits precision"}
      )
    use_nested_quant: Optional[bool] = field(
        default = False,
        metadata = {"help": "Activate nested quantization for 4bit base models"}
      )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default = "float16",
        metadata = {"help": "Compute dtype for 4bit base models"}
      )
    bnb_4bit_quant_type: Optional[str] = field(
        default = "nf4",
        metadata = {"help": "Quantization type fp4 or nf4"}
      )

# Build ScriptArguments from the command line. Unrecognised argv strings are handed back
# instead of raising (presumably because the notebook kernel's own flags sit in sys.argv)
# and are ignored.
parser = HfArgumentParser(ScriptArguments)
script_args, _remaining_cli_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
pl.seed_everything(seed=42)

## 🌻 **5.2. proposed model**

In [None]:
class OverrideEpochStepCallback(Callback):
    """At the end of every train / validation / test epoch, log the 1-based epoch
    number under the metric name "step" (presumably so the logger plots these
    metrics against epochs instead of global steps — confirm in TensorBoard)."""

    def _log_epoch_as_step(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        epoch_number = trainer.current_epoch + 1
        pl_module.log("step", epoch_number)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_epoch_as_step(trainer, pl_module)

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_epoch_as_step(trainer, pl_module)

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_epoch_as_step(trainer, pl_module)

# Lightning checkpoint every `every_n_epochs` epochs (directory chosen by the trainer/logger).
checkpoint_callback = ModelCheckpoint(every_n_epochs=script_args.every_n_epochs)

In [None]:
class TGModel(pl.LightningModule):
    """Lightning wrapper around a causal LM fine-tuned (optionally with LoRA) for tag generation.

    Attributes set here:
      model         -- the (possibly quantized and/or PEFT-wrapped) causal LM
      tokenizer     -- its tokenizer, left-padded with pad id 0
      adapter_name  -- prefix of the per-epoch adapter snapshot names
      epoch_n       -- 1-based counter appended to those snapshot names
    """

    def __init__(self, script_args):
        super(TGModel, self).__init__()
        self.save_hyperparameters()
        self.Setup(script_args)
        self.rouge = ROUGEScore()
        self.adapter_name = script_args.adapter_name
        # Take the learning rate from the args actually passed in, not from the
        # notebook-global `script_args` (which is what configure_optimizers used to read).
        self.learning_rate = script_args.learning_rate
        # Per-epoch snapshot directory; Drive_path already ends with "/".
        self.adapter_dir = f"{Drive_path}UseRQE/TG/TG-Adapters/"
        self.epoch_n = 1

    def Setup(self, script_args):
        """Load the base model (4-bit, 8-bit or bfloat16), optionally add LoRA, and load the tokenizer.

        Raises:
            ValueError: if both `load_in_4bit` and `load_in_8bit` are requested.
        """
        if script_args.load_in_4bit and script_args.load_in_8bit:
            raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
        elif script_args.load_in_4bit:
            compute_dtype = getattr(torch, script_args.bnb_4bit_compute_dtype)

            bnb_config = BitsAndBytesConfig(
                load_in_4bit = script_args.load_in_4bit,
                bnb_4bit_quant_type = script_args.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype = compute_dtype,
                bnb_4bit_use_double_quant = script_args.use_nested_quant,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                script_args.model_name,
                quantization_config = bnb_config,
                device_map = {"": 0},
            )
            # Same k-bit training preparation the 8-bit branch applies (was missing here).
            self.model = prepare_model_for_kbit_training(self.model)
        elif script_args.load_in_8bit:
            self.model = AutoModelForCausalLM.from_pretrained(
                script_args.model_name,
                load_in_8bit = True,
                torch_dtype = torch.float16,
                device_map = {"": 0},
            )
            self.model = prepare_model_for_kbit_training(self.model)
        else:
            # Unquantized path: weights loaded directly in bfloat16 on GPU 0.
            self.model = AutoModelForCausalLM.from_pretrained(
                script_args.model_name,
                torch_dtype = torch.bfloat16,
                device_map = {"": 0},
            )

        if script_args.use_peft:
            lora_config = LoraConfig(
                task_type = TaskType.CAUSAL_LM,
                r = script_args.lora_r,
                lora_alpha = script_args.lora_alpha,
                lora_dropout = script_args.lora_dropout,
                bias = "none",
            )
            self.model = get_peft_model(self.model, lora_config)
            self.model.print_trainable_parameters()

        self.model.config.use_cache = False

        # Left padding keeps the prompt adjacent to the generated continuation.
        self.tokenizer = AutoTokenizer.from_pretrained(
            script_args.model_name,
            padding_side='left'
        )
        self.tokenizer.pad_token_id = 0
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the LM; returns `(loss, logits)` (loss is None when `labels` is None)."""
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels
            )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        """Compute and log the LM loss on one batch of `input_ids` / `attention_mask` / `labels`."""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        loss, _ = self.forward(input_ids, attention_mask, labels)
        self.log('train_loss', loss.item(), on_epoch=True, on_step=True)
        return loss

    def on_train_epoch_end(self):
        """Save the current adapter weights as `<adapter_name><epoch_n>` and advance the counter."""
        self.model.save_pretrained(self.adapter_dir + self.adapter_name + str(self.epoch_n))
        self.epoch_n += 1

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's `generate`."""
        return self.model.generate(*args, **kwargs)

    def configure_optimizers(self):
        """AdamW over the wrapped model's parameters at the configured learning rate."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate)

## 🌻 **5.3. model compile**

In [None]:
MyModel = TGModel(script_args)
# NOTE: plain concatenation gives "/content/UseRQElogs" (no path separator); the
# %tensorboard and !cp commands in section 5.5 use that same literal path.
logger = TensorBoardLogger(save_dir=script_args.output_dir + 'logs', name="TG")

separator = "#" * 60
print(MyModel)
print(separator, "\n\t\t\t Model Configuration\n", separator)
print(MyModel.model.config)

## 🌻 **5.4. data preparation**

In [None]:
# Load (question text, tags) pairs. `newtags` is presumably the re-clustered tag string
# (per the file name) and is exposed as `tags` for the prompt builder.
MyData = (
    pd.read_pickle(f"{Drive_path}UseRQE/TG/n2v_old_TG_Data_After_HieClustering.pkl")
    [['text', 'newtags']]
    .rename(columns = {'newtags': 'tags'})
)

# Optional subsampling for quick runs: keep only the first N rows.
if script_args.total_num_samples != 'All':
  MyData = MyData[:int(script_args.total_num_samples)]

print(MyData.shape)
display(MyData[0:10])

In [None]:
class TGDataset(torch.utils.data.Dataset):
    """Map-style dataset turning (question, tags) rows into tokenized TG prompts.

    Each item is a dict with 'input_ids', 'attention_mask' and 'labels':
      * training (`is_eval=False`): the prompt includes the gold tags. Labels copy
        `input_ids` from the response start onward and are -100 before it, so the
        loss only covers the response.
      * evaluation (`is_eval=True`): the prompt stops before the tags. Labels are the
        unpadded token ids of the gold tags followed by "</s>".
    Prompts are padded/truncated to `max_len` with the tokenizer's padding side.
    """

    def __init__(self, data, tokenizer, max_len, is_eval):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.is_eval = is_eval

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row_data = self.data.iloc[index]
        prompt = generate_prompt(row_data, self.is_eval)
        prompt_encoding = self.tokenizer(
            prompt,
            max_length = self.max_len,
            padding = 'max_length',
            truncation = True,
            add_special_tokens = True,
            return_tensors = 'pt',
        )
        input_ids = prompt_encoding['input_ids'].squeeze()
        attention_mask = prompt_encoding['attention_mask'].squeeze()

        if not self.is_eval:
            labels = self._training_labels(input_ids)
        else:
            labels = self.tokenizer(
                row_data['tags'] + '</s>',
                add_special_tokens = False,
                return_tensors='pt',
            )['input_ids'].squeeze()
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

    @staticmethod
    def _training_labels(input_ids):
        """Mask every token before the response start with -100."""
        response_index = get_response_index(input_ids, 'TG')
        # For 'TG', get_response_index returns -1 when the marker is missing (possibly cut
        # off by truncation). The old `if response_index:` treated -1 as found and crashed
        # in torch.full, and left `labels` unbound for 0/None. Fall back to a fully masked
        # sample so it contributes no loss.
        if response_index is None or response_index <= 0:
            print('response_index not found')
            return torch.full_like(input_ids, -100)
        # Clamp so the labels never become longer than input_ids.
        response_index = min(response_index, input_ids.size(0))
        prompt_mask = torch.full((response_index,), -100, dtype=input_ids.dtype)
        return torch.cat((prompt_mask, input_ids[response_index:]))

In [None]:
class TGDataModule(pl.LightningDataModule):
    """Splits the TG DataFrame into train/test `TGDataset`s and serves their DataLoaders.

    The test share comes from `script_args.split_ratio[1]`; the rest is used for training.
    The split is seeded (random_state=42), so repeated `setup()` calls give the same split.
    """

    def __init__(self, data, tokenizer, script_args):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.per_device_train_batch_size = script_args.per_device_train_batch_size
        self.per_device_eval_batch_size = script_args.per_device_eval_batch_size
        self.max_len = script_args.max_seq_length
        # Use the args passed in; setup() used to read the notebook-global `script_args`.
        self.split_ratio = script_args.split_ratio
        self.num_workers = 8
        self.setup()

    def setup(self, stage=None):
        n_test = int(self.split_ratio[1] * self.data.shape[0])
        if n_test > 0:
            train_data, test_data = train_test_split(self.data,
                                                     test_size=n_test,
                                                     random_state=42)
        else:
            # train_test_split rejects test_size=0; keep every row for training.
            train_data, test_data = self.data, self.data.iloc[0:0]
        train_data = train_data.reset_index(drop=True)
        test_data = test_data.reset_index(drop=True)

        self.train_data = TGDataset(train_data, self.tokenizer, self.max_len, is_eval=False)
        self.test_data = TGDataset(test_data, self.tokenizer, self.max_len, is_eval=True)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_data,
            batch_size=self.per_device_train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        # Sequential order keeps test outputs aligned with the test split rows.
        return torch.utils.data.DataLoader(
            self.test_data,
            sampler = torch.utils.data.SequentialSampler(self.test_data,),
            batch_size= self.per_device_eval_batch_size,
            num_workers=self.num_workers
        )

In [None]:
# Build the train/test split for training, tokenized with the model's tokenizer.
DataModule = TGDataModule(data=MyData, tokenizer=MyModel.tokenizer, script_args=script_args)
print("num train batches", len(DataModule.train_dataloader()))
print("num test batches", len(DataModule.test_dataloader()))

## 🌻 **5.5. training**

In [None]:
trainer = pl.Trainer(
    # run length / optimisation
    max_epochs = script_args.max_epochs,
    accumulate_grad_batches = script_args.gradient_accumulation_steps,
    num_sanity_val_steps = script_args.num_sanity_val_steps,
    # logging and checkpointing (output_dir + 'Checkpoints' -> "/content/UseRQECheckpoints")
    logger = logger,
    log_every_n_steps = script_args.logging_steps,
    callbacks = [OverrideEpochStepCallback(), checkpoint_callback],
    default_root_dir = script_args.output_dir + 'Checkpoints',
)

In [None]:
# Start TensorBoard on the directory `logger` writes to (output_dir + 'logs' == /content/UseRQElogs).
%reload_ext tensorboard
%tensorboard --logdir /content/UseRQElogs

# Train; TGModel.on_train_epoch_end also writes an adapter snapshot to Drive after every epoch.
trainer.fit(
    MyModel,
    datamodule=DataModule,
)

# Copy the TensorBoard logs from local disk to Drive.
!cp -r /content/UseRQElogs /content/drive/MyDrive/UseRQE/TG/UseRQElogs_TG

## 🌻 **5.6. save adapters**

Save the trained adapters to (selected by `script_args.save_to`):<br>
1.   **Google Drive** 📁 (`"Drive"`)<br>
2.   **HuggingFace 🤗 Hub** (`"Hub"`)<br>
or to both (`"Both"`).



In [None]:
# Persist the final adapter weights according to `script_args.save_to`.
if script_args.save_to not in ("Drive", "Hub", "Both"):
  raise ValueError(f"save_to must be 'Drive', 'Hub' or 'Both', got {script_args.save_to!r}")

if script_args.save_to in ("Drive", "Both"):
  adapter_save_path = f"{Drive_path}UseRQE/TG/TG-Adapters/{script_args.adapter_name}"
  MyModel.model.save_pretrained(adapter_save_path)
  # Report the directory actually written (output_dir + adapter_name is not where it goes).
  print("Model successfully saved in ", adapter_save_path)

if script_args.save_to in ("Hub", "Both"):
  MyModel.model.push_to_hub(script_args.adapter_name)
  print("Model successfully saved in ", script_args.adapter_name)

In [None]:
# Extra copy under the fixed name "LLama-TG10" (adapter_name + max_epochs with the defaults);
# section 5.7 reloads the adapter from exactly this directory.
MyModel.model.save_pretrained("/content/drive/MyDrive/UseRQE/TG/TG-Adapters/LLama-TG10")

## 🌻 **5.7. load model**

In [None]:
# Free GPU memory before reloading for inference. Names are removed only if they exist:
# on a fresh kernel `tokenizer`, `fModel` and `BaseModel` are first defined further down,
# so the original bare `del` statements raised NameError.
for _stale_name in ('tokenizer', 'trainer', 'MyModel', 'fModel', 'BaseModel'):
  globals().pop(_stale_name, None)
del _stale_name
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Reload the base Llama-2 chat model and merge the fine-tuned TG LoRA adapter into it.
# NOTE(review): no torch_dtype is given here, unlike training (bfloat16) — confirm memory fits.
BaseModel = AutoModelForCausalLM.from_pretrained(
    f"{Drive_path}llama-2-7b-chat-hf",
    device_map={"": 0},
    offload_folder="offload",
    offload_state_dict=True,
    # load_in_8bit=True,
)

address = "/content/drive/MyDrive/UseRQE/TG/TG-Adapters/LLama-TG10"
print("\n Loading model from ", address, "\n")
config = PeftConfig.from_pretrained(address)  # adapter config; not referenced by later cells
fModel = PeftModel.from_pretrained(BaseModel, address, device_map={"": 0}).merge_and_unload()
print(fModel)
print(fModel.config)
print("\n Model successfully loded from ", address, "\n")

# Same tokenizer set-up as in training: left padding with pad id 0.
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, padding_side='left')
tokenizer.pad_token_id = 0
fModel.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Rebuild the same (seeded) split with the reloaded tokenizer.
DataModule = TGDataModule(data=MyData, tokenizer=tokenizer, script_args=script_args)
print("num train batches", len(DataModule.train_dataloader()))
print("num test batches", len(DataModule.test_dataloader()))

## 🌻 **5.8. test**

In [None]:
def test_step(test_dl, model=None, tok=None):
  """Generate tags for every batch of `test_dl` and pair them with the gold tags.

  Args:
    test_dl: DataLoader over a `TGDataset(is_eval=True)`. Presumably batch size 1
      (`per_device_eval_batch_size`), since outputs are squeezed and decoded as a
      single sequence.
    model: generation model; defaults to the notebook-global `fModel`.
    tok: tokenizer; defaults to the notebook-global `tokenizer`.

  Returns:
    A list of `[generated_text, target_text]` pairs, each with a trailing "</s>" removed.
  """
  model = fModel if model is None else model
  tok = tokenizer if tok is None else tok
  eos_text = '</s>'

  def _drop_eos(text):
    # Strip only a real trailing "</s>". Generations cut off by max_new_tokens have none,
    # and the old `[:-4]` chopped four genuine characters from them.
    return text[:-len(eos_text)] if text.endswith(eos_text) else text

  testOutputs = []
  for batch in test_dl:
    input_ids = batch['input_ids'].cuda()
    attention_mask = batch['attention_mask'].cuda()
    labels = batch['labels'].cuda()

    generated_txts_ids = model.generate(
        input_ids = input_ids,
        # Pass the mask explicitly so the left padding is ignored rather than inferred.
        attention_mask = attention_mask,
        max_new_tokens = script_args.max_new_tokens,
        do_sample=True,
        temperature=0.97
        ).squeeze()

    generated_txts = tok.decode(
        generated_txts_ids[get_response_index(generated_txts_ids, 'TG'):],
        skip_special_tokens = False,
        clean_up_tokenization_spaces = True
        )

    # Replace any -100 positions with the pad id so they can be decoded.
    labels = torch.where(labels != -100, labels, tok.pad_token_id).squeeze()
    target_txts = tok.decode(
        labels,
        skip_special_tokens = False,
        clean_up_tokenization_spaces = True
        )

    testOutputs.append([_drop_eos(generated_txts), _drop_eos(target_txts)])

  return testOutputs

In [None]:
# Generate tags for the held-out split and cache the raw [generated, target] pairs on Drive.
testOutputs_file_name = f"{Drive_path}UseRQE/TG/TG_test_outputs.pkl"
fModel.eval()
testOutputs = test_step(DataModule.test_dataloader())
pd.DataFrame(testOutputs).to_pickle(testOutputs_file_name)

In [None]:
# Reload the cached raw test outputs (column 0 = generated text, column 1 = target text).
testOutputs_file_name = f"{Drive_path}UseRQE/TG/TG_test_outputs.pkl"
testOutputs = pd.read_pickle(filepath_or_buffer=testOutputs_file_name)
testOutputs

In [None]:
# Name the columns, drop blank-line separators, and keep only the text after "[/INST]"
# (rows without that marker become empty strings).
testOutputs.columns = ['generated_tags', 'target_tags']
substring = '[/INST]'
testOutputs['generated_tags'] = (
    testOutputs['generated_tags']
    .str.replace('\n\n', '')
    .map(lambda text: text[text.find(substring) + 7:] if text.find(substring) != -1 else "")
)
display(testOutputs)

## 🌻 **5.9. test evaluation**
Score the generated tags against the reference tags with ROUGE and BERTScore.

In [None]:
testOutputs_file_name = f"{Drive_path}UseRQE/TG/TG_test_outputs.pkl"
testOutputs = pd.read_pickle(testOutputs_file_name)

# The pickle holds the raw [generated, target] pairs (unnamed columns 0 and 1). The renaming
# and "[/INST]" cleanup in 5.8 only touch the in-memory frame and are never saved, so redo
# them here. Otherwise the scoring cell fails on `testOutputs['generated_tags']`.
if list(testOutputs.columns) == [0, 1]:
    marker = '[/INST]'
    testOutputs.columns = ['generated_tags', 'target_tags']
    testOutputs['generated_tags'] = (
        testOutputs['generated_tags']
        .str.replace('\n\n', '')
        .apply(lambda x: x[x.find(marker) + len(marker):] if marker in x else "")
    )
testOutputs

In [None]:
# Score generated tags against the gold tags with BERTScore and ROUGE.
generated = testOutputs['generated_tags'].to_list()
targets = testOutputs['target_tags'].to_list()

scorer = BERTScorer(lang="en", device="cuda")
P, R, F1 = scorer.score(generated, targets, verbose=False)
print(f"BERTScore Precision: {P}")
print(f"BERTScore Recall: {R}")
print(f"BERTScore F1: {F1}")
print(f"BERTScore Precision: {P.mean():.4f}, Recall: {R.mean():.4f}, F1: {F1.mean():.4f}")

# `rouge` raises ValueError on an empty hypothesis, and the cleanup step maps generations
# without "[/INST]" to "". Score such pairs as 0 instead of aborting the whole evaluation.
# The mean over per-pair scores should equal what `avg=True` reports for non-empty inputs
# (confirm against the installed `rouge` version).
rouge = Rouge()
ROUGE_METRICS = ('rouge-1', 'rouge-2', 'rouge-l')
ROUGE_STATS = ('r', 'p', 'f')
zero_scores = {m: {s: 0.0 for s in ROUGE_STATS} for m in ROUGE_METRICS}

def _pair_rouge(hyp, ref):
    """ROUGE scores for one pair, or all zeros when `rouge` rejects it as empty."""
    try:
        return rouge.get_scores(hyp, ref)[0]
    except ValueError:
        return zero_scores

pair_scores = [_pair_rouge(h, r) for h, r in zip(generated, targets)]
scores2 = {
    m: {s: sum(ps[m][s] for ps in pair_scores) / max(len(pair_scores), 1) for s in ROUGE_STATS}
    for m in ROUGE_METRICS
}
print("rouge-1:", scores2['rouge-1'])
print("rouge-2:",scores2['rouge-2'])
print("rouge-l:",scores2['rouge-l'])