<a href="https://colab.research.google.com/github/SZAftabi/UseRQE/blob/main/5_RecognizingQuestionEntailment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<center> <font size='6'> 💟 <b> UseRQE </b> 💟 </font> <br> </center>
<center>Recognizing Question Entailment with User Background-knowledge Modeling <br> </center> <center> <font size='4' color='red'> <b> Step (5) </b> User-aware/User-agnostic question entailment recognition </font> </center>


# 😎 **Mount the drive**

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; force_remount refreshes a stale mount.
drive.mount(mountpoint='/content/drive', force_remount=True)

In [None]:
# Root of the mounted Google Drive. Keep the trailing slash: later f-strings such
# as f"{Drive_path}llama-2-7b-chat-hf" append names directly to it.
Drive_path = "/content/drive/MyDrive/"

# 😎 **1. Libraries**

In [None]:
!pip install -q -U transformers                                                 # ==4.31.0
!pip install -q torchmetrics
!pip install -q pytorch_lightning
!pip install -q bitsandbytes
!pip install -q -U peft                                                         # ==0.4.0
!pip install -q accelerate                                                      # ==0.21.0
!pip install -q trl
!pip install -q tensorboard
!pip install -q datasets
!pip install -q rouge
!pip install -q bert-score

In [None]:
import os
import re
import torch
import warnings
import nltk
import json
import time
import requests
import sklearn
import gc
nltk.download('punkt')

import numpy as np
import pandas as pd
import bitsandbytes as bnb
import pytorch_lightning as pl
import matplotlib.pyplot as plt

In [None]:
# !pip install --upgrade -q huggingface-hub
# !pip install --upgrade -q transformers

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from tensorboard import notebook

from torchmetrics import MetricCollection
from torchmetrics.text.bert import BERTScore
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryPrecision,
    BinaryRecall,
    BinaryF1Score
    )

from peft import (
    TaskType,
    PeftModel,
    PeftConfig,
    LoraConfig,
    get_peft_model,
    AutoPeftModelForCausalLM,
    prepare_model_for_kbit_training,
    )

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    HfArgumentParser,
    TrainingArguments,
    )

from dataclasses import dataclass, field
from nltk.tokenize import word_tokenize
from typing import Optional
from tqdm import tqdm
from bert_score import BERTScorer
from rouge import Rouge
from statistics import mean
from sklearn.model_selection import train_test_split
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# Enable tqdm progress bars for pandas operations (e.g. `progress_apply`).
tqdm.pandas()
# NOTE(review): this silences *all* warnings, including deprecation notices from
# transformers / peft / lightning; consider narrowing it while debugging.
warnings.filterwarnings('ignore')
import transformers
# Record the installed transformers version: the pip cell installs it with -U and no pin.
print(transformers.__version__)

# 😎 **2. Helper Functions**

In [None]:
# Llama-2 chat prompt delimiters: system-prompt markers and instruction markers.
# get_rqe_prompt() assembles them as "[INST] <<SYS>>\n...\n<</SYS>>\n\n...[/INST]".
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
B_INST, E_INST = "[INST]", "[/INST]"

In [None]:
# ========== User-agnostic ==========
# def get_rqe_prompt(_q1, _q2, _entailment=None):
#     system_prompt = "Given two questions, Q1 and Q2, determine if Q1 entails Q2 or not."
#     user_prompt = f'''Entailment means every answer to Q2 must fully or partially answer Q1.
# Respond with "Entailed" or "Not-entailed" only.
# Q1: {_q1}
# Q2: {_q2}
# ### Answer:
# '''
#     prompt = f"{B_INST} {B_SYS}{system_prompt}{E_SYS}{user_prompt}{E_INST}\n\n"
#     if _entailment: prompt += f"{_entailment}</s>"
#     return prompt


# ========== User-aware ==========
def get_rqe_prompt(_q1, _q2, _BN, _entailment=None):
    """Build the user-aware Llama-2 chat prompt for question entailment.

    Args:
        _q1: text of question Q1.
        _q2: text of question Q2.
        _BN: user background knowledge, inserted verbatim in parentheses.
        _entailment: optional gold answer. When truthy it is appended after the
            instruction block followed by "</s>" (training prompts).

    Returns:
        The complete prompt string.
    """
    instruction = "Given two questions, Q1 and Q2, determine if Q1 entails Q2 or not."
    user_prompt = f'''Entailment means every answer to Q2 must fully or partially answer Q1.
Note that, Q2 must align with the user's topics of interest: ({_BN}).
Respond with "Entailed" or "Not-entailed" only.
Q1: {_q1}
Q2: {_q2}
### Answer:
'''
    header = f"{B_INST} {B_SYS}{instruction}{E_SYS}{user_prompt}{E_INST}\n\n"
    # Training prompts carry the target answer terminated by the EOS marker.
    completion = f"{_entailment}</s>" if _entailment else ""
    return header + completion

In [None]:
def get_response_index(_input_ids, _task, marker_token_id=29937):
  """Return the position where the answer starts in a token-id sequence.

  The position is found by locating a given occurrence of ``marker_token_id``
  and skipping a task-specific number of tokens. The default id, 29937, is
  presumably the Llama tokenizer id of '#' from the "### Answer:" marker
  (TODO confirm).

  Args:
    _input_ids: 1-D sequence of token ids (list or 1-D tensor).
    _task: one of 'RQE', 'SUM', 'TG'.
    marker_token_id: token id to search for.

  Returns:
    The answer start index. When the required marker occurrence is missing,
    returns 0 for 'RQE' and -1 for the other tasks.

  Raises:
    ValueError: if ``_task`` is not one of the known tasks.
  """
  # task -> (which marker occurrence to use, number of tokens to skip after it)
  offsets = {'RQE': (0, 10), 'SUM': (1, 11), 'TG': (1, 10)}
  if _task not in offsets:
    raise ValueError(f"Unknown task {_task!r}; expected one of {sorted(offsets)}")
  occurrence, skip = offsets[_task]
  marker_positions = [i for i, tok in enumerate(_input_ids) if tok == marker_token_id]
  if len(marker_positions) > occurrence:
    return marker_positions[occurrence] + skip
  return 0 if _task == 'RQE' else -1

In [None]:
# ========== User-agnostic ==========
# def generate_prompt_rqe(data, is_eval):
#   promp = None
#   if is_eval: prompt = get_rqe_prompt(data['q1'], data['q2'])
#   else: prompt = get_rqe_prompt(data['q1'], data['q2'], data['entailment'])
#   return prompt


# ========== User-aware ==========
def generate_prompt_rqe(data, is_eval):
  """Build the user-aware RQE prompt for one data row.

  Args:
    data: a row (e.g. a pandas Series) with 'q1', 'q2', 'U_Background_kn' and,
      when ``is_eval`` is False, 'entailment'.
    is_eval: when True the gold label is left out of the prompt, so the model
      has to generate it; 'entailment' is not read in that case.

  Returns:
    The prompt string produced by get_rqe_prompt.
  """
  label = None if is_eval else data['entailment']
  return get_rqe_prompt(data['q1'], data['q2'], data['U_Background_kn'], label)

# 😎 **3. LLama2-RQE**

## 🌻 **3.1. Hyper-parameters**

In [None]:
@dataclass
class ScriptArguments:
    """All tunable settings of this notebook, parsed by HfArgumentParser.

    NOTE(review): `save_on_train_epoch_end`, `total_num_samples`,
    `gradient_checkpointing`, `trust_remote_code` and `precision` are not read
    anywhere else in this notebook as shown. Only `split_ratio[1]` (the test
    fraction) affects the data split.
    """
    # ##########################################################################
    #                             Configuration
    # ##########################################################################
    model_name: Optional[str] = field(
        default = f"{Drive_path}llama-2-7b-chat-hf",
        metadata = {"help": "The model that you want to train from the Hugging Face hub."}
      )
    adapter_name: Optional[str] = field(
        default = "LLama-RQE",
        metadata = {"help": "The adapter name saved in the HuggingFace hub."}
      )
    save_to: Optional[str] = field(
        default = "Drive",                                                       # Save to "Hub", or "Drive", or "Both"
        metadata = {"help": "Determine where to save Adapters"}
      )
    # ##########################################################################
    #                         Logs and Checkpoints
    # ##########################################################################
    logging_steps: Optional[int] = field(
        default = 1,
        metadata = {"help": "log every X update steps"}
      )
    output_dir: Optional[str] = field(
        default = "/content/LLama",
        metadata = {"help": "the output directory"}
      )
    every_n_epochs : Optional[int] = field(
        default = 1,
        metadata = {"help": "Save checkpoints every X epochs"}
      )
    save_on_train_epoch_end: Optional[bool] = field(
        default = None,
        metadata = {"help": "Whether to run checkpointing at the end of training epochs or validation"}
      )
    total_num_samples: Optional[str] = field(
        default = 'All',
        metadata = {"help": "Number of samples to be selected from the whole dataset"}
      )
    # ##########################################################################
    #                             Hyper-parameters
    # ##########################################################################
    max_epochs: Optional[int] = field(
        default = 10,
        metadata = {"help": "maximum number of training epochs."}
      )
    learning_rate: Optional[float] = field(
        default = 1e-4,
        metadata = {"help": "the learning rate"}
      )
    gradient_accumulation_steps: Optional[int] = field(
        default = 2,
        metadata = {"help": "the number of gradient accumulation steps"}
      )
    gradient_checkpointing: Optional[bool] = field(
        default = True,
        metadata = {"help": "Enables gradient checkpointing."}
      )
    per_device_train_batch_size: Optional[int] = field(
        default = 8,
        metadata = {"help": "batch_size of training (per device)"}
      )
    per_device_eval_batch_size: Optional[int] = field(
        default = 8,
        metadata = {"help": "batch_size of validation (per device)"}
      )
    max_seq_length: Optional[int] = field(
        default = 512,
        metadata = {"help": "maximum input sequence length"}
      )
    trust_remote_code: Optional[bool] = field(
        default = True,
        metadata = {"help": '''Enable `trust_remote_code` so that it
        will execute code present on the Hub on your local machine'''}
      )
    # NOTE(review): annotated as float but holds a (train, test, validation)
    # tuple, so it cannot be overridden from the command line as declared.
    split_ratio: Optional[float] = field(
        default = (0.8, 0.2, 0),
        metadata = {"help": "train/test/validation splits"}
      )
    precision: Optional[int] = field(
        default = 16,
        metadata = {"help": "train with 16/32/bf16 precision."}
      )
    # Lightning expects an integer batch count here.
    num_sanity_val_steps: Optional[int] = field(
        default = 0,
        metadata = {"help": "number of validation batches before the first training epoch"}
      )
    max_new_tokens: Optional[int] = field(
        default = 5,
        metadata = {"help": "the maximum number of new tokens in the generated sequences (test step)"}
      )
    # ##########################################################################
    #                             Lora Configuration
    # ##########################################################################
    use_peft: Optional[bool] = field(
        default = True,
        metadata = {"help": "Wether to use PEFT or not to train adapters"}
      )
    lora_r: Optional[int] = field(
        default = 64,
        metadata = {"help": "the r parameter of the LoRA adapters"}
      )
    lora_alpha: Optional[int] = field(
        default = 64,
        metadata = {"help": "the alpha parameter of the LoRA adapters"}
      )
    # A probability: typed float so CLI overrides such as 0.05 parse.
    lora_dropout: Optional[float] = field(
        default = 0.1,
        metadata = {"help": "the dropout rate of the LoRA adapters"}
      )
    # ##########################################################################
    #                                 BitsAndBytes
    # ##########################################################################
    load_in_8bit: Optional[bool] = field(
        default = False,
        metadata = {"help": "load the model in 8 bits precision"}
      )
    load_in_4bit: Optional[bool] = field(
        default = False,
        metadata = {"help": "load the model in 4 bits precision"}
      )
    use_nested_quant: Optional[bool] = field(
        default = False,
        metadata = {"help": "Activate nested quantization for 4bit base models"}
      )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default = "float16",
        metadata = {"help": "Compute dtype for 4bit base models"}
      )
    bnb_4bit_quant_type: Optional[str] = field(
        default = "nf4",
        metadata = {"help": "Quantization type fp4 or nf4"}
      )

# Parse command-line overrides into a ScriptArguments instance.
# return_remaining_strings=True keeps argv entries that are not ScriptArguments
# fields (presumably the flags Jupyter/Colab passes to the kernel) from raising;
# element [0] is the ScriptArguments object.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
# Global seed for reproducible runs.
pl.seed_everything(42)

## 🌻 **3.2. Proposed model**

In [None]:
class OverrideEpochStepCallback(Callback):
    """Log the 1-based epoch number under the metric name "step".

    The same helper runs at the end of every train, validation and test epoch.
    NOTE(review): presumably this makes Lightning's logger plot against epochs
    instead of global steps; confirm against the installed Lightning version.
    """

    def __init__(self) -> None:
        super().__init__()

    def _log_step_as_current_epoch(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        epoch_number = trainer.current_epoch + 1
        pl_module.log("step", epoch_number)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_step_as_current_epoch(trainer, pl_module)

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_step_as_current_epoch(trainer, pl_module)

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_step_as_current_epoch(trainer, pl_module)


# Write a Lightning checkpoint every `script_args.every_n_epochs` epochs.
checkpoint_callback = ModelCheckpoint(every_n_epochs=script_args.every_n_epochs)

In [None]:
class RQEModel(pl.LightningModule):
    """Lightning module fine-tuning a causal LM (Llama-2 by default) for RQE.

    Depending on ``script_args``, the base model is loaded in bfloat16, 8-bit or
    4-bit and is optionally wrapped with a LoRA adapter. It is trained with the
    model's own LM loss, and its weights are saved to Drive after every
    training epoch.

    Args:
        script_args: a ScriptArguments instance.
    """

    def __init__(self, script_args):
        super(RQEModel, self).__init__()
        self.save_hyperparameters()
        self.Setup(script_args)
        self.rouge = ROUGEScore()
        self.adapter_name = script_args.adapter_name
        # Read by configure_optimizers, so the module does not depend on the
        # notebook-global `script_args`.
        self.learning_rate = script_args.learning_rate
        # 1-based epoch counter used to suffix the per-epoch adapter folders.
        self.epoch_n = 1

    def Setup(self, script_args):
        """Load the (optionally quantized and LoRA-wrapped) model and its tokenizer.

        Raises:
            ValueError: if both 4-bit and 8-bit loading are requested.
        """
        # Not part of the notebook's top-level `transformers` import.
        from transformers import BitsAndBytesConfig

        if script_args.load_in_4bit and script_args.load_in_8bit:
            raise ValueError(
                "You can't load the model in 8 bits and 4 bits at the same time"
            )
        elif script_args.load_in_4bit:
            compute_dtype = getattr(torch, script_args.bnb_4bit_compute_dtype)
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=script_args.load_in_4bit,
                bnb_4bit_quant_type=script_args.bnb_4bit_quant_type,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=script_args.use_nested_quant,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                script_args.model_name,
                quantization_config=bnb_config,
                device_map={"": 0},
            )
            # Quantized weights need the same k-bit preparation as the 8-bit path.
            self.model = prepare_model_for_kbit_training(self.model)
        elif script_args.load_in_8bit:
            # quantization_config is the supported way to request 8-bit loading;
            # the bare `load_in_8bit=` kwarg is deprecated in recent transformers.
            self.model = AutoModelForCausalLM.from_pretrained(
                script_args.model_name,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                torch_dtype=torch.float16,
                device_map={"": 0},
            )
            self.model = prepare_model_for_kbit_training(self.model)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                script_args.model_name,
                torch_dtype=torch.bfloat16,
                device_map={"": 0},
            )
        # NOTE(review): script_args.gradient_checkpointing is not applied here.

        if script_args.use_peft:
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=script_args.lora_r,
                lora_alpha=script_args.lora_alpha,
                lora_dropout=script_args.lora_dropout,
                bias="none",
            )
            self.model = get_peft_model(self.model, lora_config)
            self.model.print_trainable_parameters()

        # The KV cache is not used during teacher-forced training.
        self.model.config.use_cache = False

        self.tokenizer = AutoTokenizer.from_pretrained(
            script_args.model_name,
            padding_side='left'
        )
        # Pad with id 0 (on the left), and mirror it in the model config.
        self.tokenizer.pad_token_id = 0
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model; returns ``(loss, logits)`` (loss is None without labels)."""
        output = self.model(input_ids,
                            attention_mask=attention_mask,
                            labels=labels
                            )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        """Compute and log the LM loss for one batch."""
        loss, _ = self.forward(batch['input_ids'], batch['attention_mask'], batch['labels'])
        self.log('train_loss', loss.item(), on_epoch=True, on_step=True)
        return loss

    def on_train_epoch_end(self):
        """Save the current (adapter) weights to Drive as <adapter_name><epoch>."""
        # ========== User-aware ==========
        out_dir = os.path.join(Drive_path, "LLama", "LLAMA-RQE-UM")

        # ========== User-agnostic ==========
        # out_dir = os.path.join(Drive_path, "LLama", "LLAMA-RQE-WoUM")

        self.model.save_pretrained(
            os.path.join(out_dir, self.adapter_name + str(self.epoch_n))
        )
        self.epoch_n += 1

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's ``generate``."""
        return self.model.generate(*args, **kwargs)

    def configure_optimizers(self):
        """AdamW over the trainable parameters only (just the LoRA weights when PEFT is on)."""
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable, lr=self.learning_rate)

## 🌻 **3.3. Model compile**

In [None]:
# Build the Lightning module (loads the base model and LoRA adapter) and a TensorBoard logger.
MyModel = RQEModel(script_args)
logger = TensorBoardLogger(save_dir=script_args.output_dir + 'logs', name="RQE")

print(MyModel)
print("#" * 60, "\n\t\t\t Model Configuration\n", "#" * 60)
print(MyModel.model.config)

## 🌻 **3.4. Data preparation**

In [None]:
# Question pairs with gold entailment labels, plus (presumably row-aligned)
# per-pair user background knowledge from a second pickle.
# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
data_path_LLama = f"{Drive_path}RQE_Data.pkl"
MyData_LLama = pd.read_pickle(data_path_LLama)

MyData2 = pd.read_pickle(f"{Drive_path}RQE_Data_T20_UK.pkl")

# Hand-written replacements for empty question bodies, as
# (column to fill, userid_Q2 of the row, replacement text), applied in order.
# NOTE(review): the body_Q1 fixes are also matched on `userid_Q2` (not
# `userid_Q1`), and userid 65001 appears in both a body_Q2 and a body_Q1 fix.
# Confirm this is intended.
EMPTY_BODY_FIXES = [
    ('body_Q2', '65001', 'So when I launch Minecraft, before it finishes loading, it crashes. I do not understand what is going on. Could someone help me? Here is my crash report:'),
    ('body_Q2', '36896', 'How do I type the infinity symbol in MacTex'),
    ('body_Q2', '3031', 'Run time error for GP objects'),
    ('body_Q1', '65001', 'Misplaced allignment tab character line 53'),
    ('body_Q1', '16188', 'How to Export this animation as a gif file for powerpoint presentation'),
    ('body_Q1', '24829', 'why does rotation style work on actual coordinates and not variables in tikz 3d plot'),
    ('body_Q2', '50615', 'How set a table in margin'),
    ('body_Q2', '23835', 'Latex equation positioning problem'),
    ('body_Q2', '14524', 'Chapter comment with regulation'),
    ('body_Q2', '50823', 'minipage goes beyond right margin'),
]
for fix_col, fix_uid, fix_text in EMPTY_BODY_FIXES:
    fix_mask = (MyData_LLama[fix_col] == '') & (MyData_LLama['userid_Q2'] == fix_uid)
    MyData_LLama.loc[fix_mask, fix_col] = fix_text

# ========== User-aware ==========
# pd.concat(axis=1) aligns rows by index label. If the two pickles' indices
# differ, it silently adds NaN rows or pairs questions with missing background
# knowledge, so fail loudly instead.
if not MyData_LLama.index.sort_values().equals(MyData2.index.sort_values()):
    raise ValueError(
        "RQE_Data.pkl and RQE_Data_T20_UK.pkl have different row indices; "
        "background knowledge would be misaligned with the question pairs"
    )
MyData = pd.concat([
    MyData_LLama[['body_Q1', 'body_Q2', 'entailment']],
    MyData2['U_Background_kn']
    ], axis=1)

# ========== User-agnostic ==========
# MyData = MyData_LLama[['body_Q1', 'body_Q2', 'entailment']]

MyData = MyData.rename(columns={'body_Q1': 'q1', 'body_Q2': 'q2'})
display(MyData)

In [None]:
class RQEDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns RQE rows into fixed-length tokenized prompts.

    Args:
        data: DataFrame with 'q1', 'q2', 'U_Background_kn' and 'entailment'.
        tokenizer: HF tokenizer; it is expected to pad on the left (as
            configured in RQEModel.Setup). TODO confirm for other tokenizers.
        max_len: every prompt is padded/truncated to this many tokens.
        is_eval: False means the prompt includes the gold answer and labels are
            the causal-LM targets. True means the prompt stops before the
            answer and labels are the tokenized gold answer.
    """

    def __init__(self, data, tokenizer, max_len, is_eval):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.is_eval = is_eval

    def __len__(self):
        return len(self.data)

    def _training_labels(self, input_ids):
        """Copy ``input_ids`` and set every position before the first BOS to -100."""
        bos_id = getattr(self.tokenizer, 'bos_token_id', None)
        if bos_id is None:
            bos_id = 1  # Llama's BOS id
        bos_positions = (input_ids == bos_id).nonzero(as_tuple=True)[0]
        # Without a BOS (should not happen with add_special_tokens=True) mask nothing.
        start = int(bos_positions[0]) if len(bos_positions) else 0
        labels = input_ids.clone()
        labels[:start] = -100
        return labels

    def __getitem__(self, index):
        row_data = self.data.iloc[index]
        prompt = generate_prompt_rqe(row_data, self.is_eval)
        prompt_encoding = self.tokenizer(prompt,
                                         max_length = self.max_len,
                                         padding = 'max_length',
                                         truncation = True,
                                         add_special_tokens = True,
                                         return_tensors = 'pt',
                                         )
        input_ids = prompt_encoding['input_ids'].squeeze()
        attention_mask = prompt_encoding['attention_mask'].squeeze()

        if not self.is_eval:
            # Only the left padding is masked; prompt tokens remain targets
            # (full-sequence LM loss).
            if not get_response_index(input_ids, 'RQE'):
                # e.g. the answer marker was truncated away; still return labels
                # instead of failing with an unbound `labels`.
                print('response_index not found')
            labels = self._training_labels(input_ids)
        else:
            answer_encoding = self.tokenizer(row_data['entailment'] + '</s>',
                                             add_special_tokens = False,
                                             truncation = True,
                                             max_length = 5,
                                             padding = 'max_length',
                                             return_tensors='pt',
                                             )
            labels = answer_encoding['input_ids'].squeeze()
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
            }

In [None]:
class RQEDataModule(pl.LightningDataModule):
    """Lightning data module holding a train/test split of the RQE data.

    The split is made eagerly in __init__ and again whenever Lightning calls
    setup(); random_state=42 makes both calls produce the same split.

    Args:
        data: DataFrame consumed by RQEDataset.
        tokenizer: tokenizer forwarded to RQEDataset.
        script_args: ScriptArguments supplying the batch sizes,
            max_seq_length and split_ratio.
    """

    def __init__(self, data, tokenizer, script_args):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.per_device_train_batch_size = script_args.per_device_train_batch_size
        self.per_device_eval_batch_size = script_args.per_device_eval_batch_size
        self.max_len = script_args.max_seq_length
        # Stored so setup() does not read the notebook-global `script_args`.
        self.split_ratio = script_args.split_ratio
        self.setup()

    def setup(self, stage=None):
        """Split self.data into train/test RQEDataset objects.

        Only split_ratio[1] (the test fraction) sizes the split; everything else
        goes to training, and split_ratio[2] (validation) is not used.
        """
        test_size = int(self.split_ratio[1] * self.data.shape[0])
        train_data, test_data = train_test_split(self.data,
                                                 test_size=test_size,
                                                 random_state=42)

        train_data = train_data.reset_index(drop=True)
        test_data = test_data.reset_index(drop=True)

        self.train_data = RQEDataset(train_data,
                                     self.tokenizer,
                                     self.max_len,
                                     is_eval=False)
        self.test_data = RQEDataset(test_data,
                                    self.tokenizer,
                                    self.max_len,
                                    is_eval=True)

    def train_dataloader(self):
        """Shuffled training loader."""
        return torch.utils.data.DataLoader(
            self.train_data,
            batch_size=self.per_device_train_batch_size,
            shuffle=True,
            num_workers=4,
        )

    def test_dataloader(self):
        """Test loader in dataset order, so results line up with the test split."""
        return torch.utils.data.DataLoader(
            self.test_data,
            sampler = torch.utils.data.SequentialSampler(self.test_data,),
            batch_size= self.per_device_eval_batch_size,
            num_workers=2
        )

In [None]:
# Tokenize with the training model's tokenizer so prompts match what it was configured with.
RQE_DataModule = RQEDataModule(
    data=MyData,
    tokenizer=MyModel.tokenizer,
    script_args=script_args,
)
for _label, _loader in (("num train batches", RQE_DataModule.train_dataloader()),
                        ("num test batches", RQE_DataModule.test_dataloader())):
    print(_label, len(_loader))

## 🌻 **3.5. Training**

In [None]:
# Lightning trainer: TensorBoard logging, gradient accumulation, the
# epoch-as-"step" callback and periodic checkpoints.
# NOTE(review): script_args.precision is not passed to the Trainer.
trainer = pl.Trainer(
    max_epochs=script_args.max_epochs,
    accumulate_grad_batches=script_args.gradient_accumulation_steps,
    num_sanity_val_steps=script_args.num_sanity_val_steps,
    logger=logger,
    log_every_n_steps=script_args.logging_steps,
    callbacks=[OverrideEpochStepCallback(), checkpoint_callback],
    default_root_dir=script_args.output_dir + 'Checkpoints',
)

In [None]:
%reload_ext tensorboard
# The logger writes to script_args.output_dir + 'logs' under name "RQE", i.e.
# "/content/LLama" + "logs" -> /content/LLamalogs/RQE (no separator is added).
%tensorboard --logdir /content/LLamalogs/RQE --samples_per_plugin scalars=6000

trainer.fit(
    MyModel,
    datamodule=RQE_DataModule,
)

# Copy the TensorBoard logs to Drive so they outlive the Colab session.
# ========== User-aware ==========
!cp -r /content/LLamalogs/RQE /content/drive/MyDrive/LLama/LLama_UM

# ========== User-agnostic ==========
# !cp -r /content/LLamalogs/RQE /content/drive/MyDrive/LLama/LLama_WoUM

## 🌻 **3.6. Save adapters**

Save the trained adapters to:<br>
1.    **a local (Drive) directory** 📁   
 or   <br>
2.   **the HuggingFace 🤗 Hub**:



In [None]:
# Persist the final adapter: "Drive" -> Google Drive, "Hub" -> Hugging Face Hub, "Both" -> both.
adapter_dir = f"{Drive_path}LLama/{script_args.adapter_name}"

if script_args.save_to not in ("Both", "Drive", "Hub"):
  raise ValueError(f"save_to must be 'Both', 'Drive' or 'Hub', got {script_args.save_to!r}")

if script_args.save_to in ("Both", "Drive"):
  MyModel.model.save_pretrained(adapter_dir)
  # Report the directory that was actually written.
  print("Model successfully saved in ", adapter_dir)

if script_args.save_to in ("Both", "Hub"):
  MyModel.model.push_to_hub(script_args.adapter_name)
  print("Model successfully saved in ", script_args.adapter_name)

## 🌻 **3.7. Test**

In [None]:
# Drop every reference to the training-time objects so their (GPU) memory can be
# reclaimed before the merged model is loaded.
tokenizer = trainer = MyModel = MyModel2 = fModel = BaseModel = None
gc.collect()
torch.cuda.empty_cache()

In [None]:
base_model_path = f"{Drive_path}llama-2-7b-chat-hf"

# NOTE(review): no torch_dtype is given, so the base weights load in the library
# default dtype, while the default training path used bfloat16. Consider
# torch_dtype=torch.bfloat16 to reduce memory.
BaseModel = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    device_map={"": 0},
    offload_folder="offload",
    offload_state_dict=True,
    # load_in_8bit = True
    )

# Adapter written by RQEModel.on_train_epoch_end after epoch 10 (the max_epochs default).
# ========== User-agnostic ==========
# address = f"{Drive_path}LLama/LLAMA-RQE-WoUM/LLama-RQE10"

# ========== User-aware ==========
address = f"{Drive_path}LLama/LLAMA-RQE-UM/LLama-RQE10"

print("\n Loading model from ", address, "\n")
config = PeftConfig.from_pretrained(address)
fModel = PeftModel.from_pretrained(BaseModel, address, device_map={"": 0})
# Fold the LoRA weights into the base model so generation runs on a plain model.
fModel = fModel.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(
    base_model_path,
    padding_side='left'
    )
# Same padding setup as during training (RQEModel.Setup).
tokenizer.pad_token_id = 0

fModel.config.pad_token_id = tokenizer.pad_token_id
# NOTE(review): presumably None for the Llama tokenizer (no mask token); confirm.
fModel.config.mask_token_id = tokenizer.mask_token_id
print(fModel)
print(fModel.config)
print("\n Model successfully loaded from ", address, "\n")

In [None]:
def test_step(test_dl):
    results = []

    for batch in test_dl:
        input_ids = batch['input_ids'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        labels = batch['labels'].cuda()

        generated_txts_ids = fModel.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=script_args.max_new_tokens,
            do_sample=True,
            temperature=0.97
        )

        for i in range(input_ids.size(0)):
            single_generated_ids = generated_txts_ids[i]

            response_start_idx = get_response_index(
                single_generated_ids, 'RQE'
                )
            single_generated_txt = tokenizer.decode(
                single_generated_ids[response_start_idx:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            single_label_ids = labels[i]
            single_label_ids = torch.where(
                single_label_ids != -100,
                single_label_ids,
                tokenizer.pad_token_id
            )
            single_target_txt = tokenizer.decode(
                single_label_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            results.append([single_generated_txt, single_target_txt])

    return results


In [None]:
# Rebuild the same (random_state=42) split with the freshly loaded tokenizer.
Data_RQE = RQEDataModule(
    data=MyData,
    tokenizer=tokenizer,
    script_args=script_args,
    )
for _label, _loader in (("num train batches", Data_RQE.train_dataloader()),
                        ("num test batches", Data_RQE.test_dataloader())):
    print(_label, len(_loader))

In [None]:
# Inference mode (disables dropout), then time generation over the whole test split.
fModel.eval()
start_time = time.time()
test_results = test_step(Data_RQE.test_dataloader())
elapsed = time.time() - start_time
print("--- %s seconds ---" % elapsed)

test_results_df = pd.DataFrame(test_results, columns=['predicted_label', 'real_label'])

# ========== User-agnostic ==========
# test_results_df.to_pickle(f"{Drive_path}LLama/Llama_RQE_WoUM_results.pkl")

# ========== User-aware ==========
test_results_df.to_pickle(f"{Drive_path}LLama/Llama_RQE_UM_results.pkl")

test_results_df

In [None]:
test_results_df.columns = ['generated_label', 'real_label']


def label_to_int(text):
    """Map an answer string to 0 (Not-entailed) or 1 (Entailed).

    Surrounding whitespace and case are ignored, and any answer starting with
    "not" counts as Not-entailed, so variants like " Not-entailed\n" or
    "not entailed" are not scored as Entailed. Any other text, including
    unparseable generations, maps to 1.
    """
    return 0 if str(text).strip().lower().startswith('not') else 1


predicted_labels = test_results_df['generated_label'].apply(label_to_int)
real_labels = test_results_df['real_label'].apply(label_to_int)

display(predicted_labels)
display(real_labels)

In [None]:
# Binary classification metrics with Entailed (1) as the positive class.
_metric_table = [
    ("F1-score: ", lambda y_true, y_pred: sklearn.metrics.f1_score(y_true, y_pred)),
    ("Precision: ", lambda y_true, y_pred: sklearn.metrics.precision_score(y_true, y_pred, average='binary')),
    ("Recall: ", lambda y_true, y_pred: sklearn.metrics.recall_score(y_true, y_pred, average='binary')),
    ("Accuracy: ", lambda y_true, y_pred: sklearn.metrics.accuracy_score(y_true, y_pred)),
]
for _metric_name, _metric_fn in _metric_table:
    print(_metric_name, _metric_fn(real_labels, predicted_labels))