<a href="https://colab.research.google.com/github/SZAftabi/UseRQE/blob/main/(Step3)UserModeling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<center> <font size='6'> 💟 <b> UseRQE </b> 💟 </font> <br> </center>
<center>Recognizing Question Entailment with User Background-knowledge Modeling <br> </center> <center> <font size='4' color='red'> <b> Step (3) </b> User background-knowledge modeling </font> </center>


# 😎 **Mount the drive**

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
# force_remount=True is passed so the mount is redone even if one already exists.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Root folder of the mounted Drive; data, model and output paths below are built from it.
Drive_path = "/content/drive/MyDrive/"

# 😎 **1. Libraries**

In [None]:
!pip install -q -U transformers                                                 # ==4.31.0
!pip install -q torchmetrics
!pip install -q pytorch_lightning
!pip install -q bitsandbytes
!pip install -q -U peft                                                         # ==0.4.0
!pip install -q accelerate                                                      # ==0.21.0
!pip install -q trl
!pip install -q tensorboard
!pip install -q datasets
!pip install -q rouge
!pip install -q bert-score

In [None]:
import os
import re
import torch
import warnings
import nltk
import json
import time
import requests
nltk.download('punkt')

import numpy as np
import pandas as pd
import bitsandbytes as bnb
import pytorch_lightning as pl
import matplotlib.pyplot as plt

In [None]:
# !pip install --upgrade huggingface-hub
# !pip install --upgrade transformers

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Callback
from tensorboard import notebook

from torchmetrics import MetricCollection
from torchmetrics.text.bert import BERTScore
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryPrecision,
    BinaryRecall,
    BinaryF1Score
    )

from peft import (
    TaskType,
    PeftModel,
    PeftConfig,
    LoraConfig,
    get_peft_model,
    AutoPeftModelForCausalLM,
    prepare_model_for_kbit_training,
    )

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    HfArgumentParser,
    TrainingArguments,
    )

from dataclasses import dataclass, field
from nltk.tokenize import word_tokenize
from typing import Optional
from tqdm import tqdm
from bert_score import BERTScorer
from rouge import Rouge
from statistics import mean
from sklearn.model_selection import train_test_split
from collections import Counter

tqdm.pandas()
warnings.filterwarnings('ignore')
import transformers
print(transformers.__version__)

# 😎 **2. Helper Functions**

In [None]:
# Prompt delimiters used by get_tg_prompt: B_SYS/E_SYS wrap the system message,
# B_INST/E_INST wrap the whole instruction turn.
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
B_INST, E_INST = "[INST]", "[/INST]"

In [None]:
def get_tg_prompt(_question, _tags = None):
  """Assemble the tag-generation prompt for one question.

  Args:
    _question: text of the question to tag.
    _tags: optional reference tags. When truthy they are appended after the
      instruction block, followed by the end-of-sequence marker, which gives
      a full training example; when omitted the prompt stops after the
      instruction block so the model can complete it.

  Returns:
    The prompt string, wrapped in the B_INST/E_INST and B_SYS/E_SYS delimiters.
  """
  system_msg = 'You are a Tag Generator. Respond only with a list of tags; do not include any additional text or explanations.'
  # One entry per line of the user message; the trailing empty entry yields
  # the final newline after "### Tags:".
  user_lines = [
      "Please generate at least 5 tags for the provided question. Tags can include multi-word phrases if appropriate and should help hierarchically categorize the question's topics.",
      "### Question:",
      f"{_question}",
      "### Tags:",
      "",
  ]
  user_msg = "\n".join(user_lines)

  pieces = [B_INST, " ", B_SYS, system_msg, E_SYS, user_msg, " ", E_INST, "\n\n"]
  if _tags:
    pieces.append(f'{_tags}</s>')
  return "".join(pieces)

In [None]:
def get_response_index(_input_ids, _task, _marker_token_id = 29937):
  """Locate the position in a token sequence where the model's response starts.

  The prompt templates put a '###'-style header before the response. This
  function finds the n-th occurrence of the marker token (n depends on the
  task) and skips a fixed number of tokens after it.

  Args:
    _input_ids: iterable of token ids (list of ints or a 1-D tensor).
    _task: one of 'RQE', 'SUM' or 'TG'.
    _marker_token_id: token id to look for. The default 29937 is presumably
      the Llama tokenizer id of '#' -- TODO confirm against the tokenizer.

  Returns:
    The index of the first response token. When the sequence has too few
    marker tokens, returns 0 for 'RQE' and -1 for the other tasks.

  Raises:
    ValueError: if _task is not a known task name.
  """
  # task -> (which marker occurrence anchors the response, tokens to skip after it)
  _task_layout = {
      'RQE': (2, 10),
      'SUM': (1, 11),
      'TG': (1, 10),
  }
  if _task not in _task_layout:
    raise ValueError(
        f"Unknown task {_task!r}; expected one of {sorted(_task_layout)}"
    )
  _index, _skip_tokens = _task_layout[_task]

  # Scan the sequence once and reuse the result.
  hashtags_indexes = [i for i, n in enumerate(_input_ids) if n == _marker_token_id]
  if len(hashtags_indexes) > _index:
    return hashtags_indexes[_index] + _skip_tokens
  return 0 if _task == 'RQE' else -1

In [None]:
def generate_prompt(data, is_eval):
  """Build the tag-generation prompt for one data row.

  Args:
    data: mapping-like row with a 'text' entry (the question) and, for
      training rows, a 'tags' entry (the reference tags).
    is_eval: when truthy the prompt is built without the reference tags, so
      the model has to generate them; otherwise the tags are appended.

  Returns:
    The prompt string produced by get_tg_prompt.
  """
  if is_eval:
    return get_tg_prompt(data['text'])
  return get_tg_prompt(data['text'], data['tags'])

# 😎 **3. LLama2-TG**

## 🌻 **3.1. hyper-parameters**

In [None]:
@dataclass
class ScriptArguments:
    """Run-time settings for the tag generator, parsed by HfArgumentParser."""

    # Upper bound on the tokenised input length.
    max_seq_length: Optional[int] = field(
        metadata = dict(help = "maximum input sequence length"),
        default = 512,
    )
    # Upper bound on the number of tokens generated per sequence at test time.
    max_new_tokens: Optional[int] = field(
        metadata = dict(
            help = "the maximum number of new tokens in the generated sequences (test step)"
        ),
        default = 30,
    )

# Populate ScriptArguments through the Hugging Face argument parser.
# With return_remaining_strings=True the result also carries the argv strings
# the parser did not consume, so only element [0] (the dataclass) is kept.
# NOTE(review): presumably this is what lets parsing succeed inside a notebook
# kernel, whose argv holds unrelated launcher flags -- confirm.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
# Fix the random seed so that sampling-based generation is repeatable.
pl.seed_everything(42)

## 🌻 **3.2. data preparation**

In [None]:
class TGDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns question rows into tokenised tag-generation prompts.

    Args:
        data: DataFrame-like object (rows are read with ``.iloc``) whose rows
            provide a 'text' entry and a 'tags' entry.
        tokenizer: tokenizer used to encode prompts and reference tags.
        max_len: prompts are padded/truncated to exactly this many tokens.
        is_eval: if truthy, the prompt omits the reference tags and ``labels``
            holds the separately tokenised reference tags; otherwise the
            prompt contains the tags and ``labels`` is the prompt itself with
            every token before the response masked out with -100.
    """

    def __init__(self, data, tokenizer, max_len, is_eval):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.is_eval = is_eval

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return a dict with 'input_ids', 'attention_mask' and 'labels' for one row.

        Raises:
            ValueError: in training mode, if the start of the response cannot
                be located inside the (possibly truncated) prompt.
        """
        row_data = self.data.iloc[index]
        prompt = generate_prompt(row_data, self.is_eval)
        prompt_encoding = self.tokenizer(
            prompt,
            max_length = self.max_len,
            padding = 'max_length',
            truncation = True,
            add_special_tokens = True,
            return_tensors = 'pt',
        )
        input_ids = prompt_encoding['input_ids'].squeeze()
        attention_mask = prompt_encoding['attention_mask'].squeeze()

        if not self.is_eval:
            response_index = get_response_index(input_ids, 'TG')
            # get_response_index signals "not found" with -1 for this task, and
            # an index past the end means the response was truncated away.
            # Either way no valid label tensor can be built, so fail loudly
            # instead of crashing later on a negative size or an unbound variable.
            if response_index is None or not 0 < response_index <= input_ids.shape[0]:
                raise ValueError(
                    f"response_index not found for row {index} "
                    f"(got {response_index}, sequence length {input_ids.shape[0]})"
                )
            # Only the response tokens contribute to the loss; -100 is the
            # ignore value written over everything before the response.
            labels = input_ids.clone()
            labels[:response_index] = -100
        else:
            # Reference tags plus end-of-sequence marker, unpadded.
            labels = self.tokenizer(
                row_data['tags'] + '</s>',
                add_special_tokens = False,
                return_tensors='pt',
            )
            labels = labels['input_ids'].squeeze()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

## 🌻 **3.3. load model**

In [None]:
# Load the base chat model from Drive onto GPU 0.
base_model_path = f"{Drive_path}llama-2-7b-chat-hf"
BaseModel= AutoModelForCausalLM.from_pretrained(
    base_model_path,
    device_map={"": 0},
    offload_folder="offload",
    offload_state_dict = True,
    # load_in_8bit = True,
    )

# Fine-tuned tag-generation adapter; built from Drive_path like the other paths.
address = f"{Drive_path}UseRQE/TG/TG-Adapters/LLama-TG10"
print("\n Loading model from ", address, "\n")
config = PeftConfig.from_pretrained(address)
fModel= PeftModel.from_pretrained(BaseModel,address,device_map={"": 0})
# Fold the adapter weights into the base model so that generation runs on a
# plain causal-LM without the PEFT wrapper.
fModel = fModel.merge_and_unload()

print("\n Model successfully loded from ", address, "\n")
print(fModel)
print(fModel.config)

# ScriptArguments declares no `model_name` field, so reading
# script_args.model_name directly raises AttributeError. Use it when it
# exists and otherwise fall back to the base-model folder.
# NOTE(review): assumes the tokenizer files live next to the base weights -- confirm.
tokenizer = AutoTokenizer.from_pretrained(
    getattr(script_args, "model_name", base_model_path),
    padding_side='left'
    )

# Use token id 0 for padding and tell the model about it.
tokenizer.pad_token_id = 0
fModel.config.pad_token_id = tokenizer.pad_token_id

## 🌻 **3.4. test**

In [None]:
def test_step(test_dl):
  """Generate tags for every prompt yielded by a data loader.

  Args:
    test_dl: iterable of batches, each a dict with 'input_ids' and
      'attention_mask' tensors of shape (batch, seq_len).

  Returns:
    A list with one decoded string per input sequence: the text generated
    after the prompt, cut at the end-of-sequence marker.
  """
  testOutputs = []

  for batch in test_dl:
    input_ids = batch['input_ids'].cuda()
    attention_mask = batch['attention_mask'].cuda()

    # Prompts are padded, so the mask must be passed explicitly; otherwise
    # the model may attend to the padding tokens.
    generated_txts_ids = fModel.generate(
        input_ids = input_ids,
        attention_mask = attention_mask,
        max_new_tokens = script_args.max_new_tokens,
        do_sample=True,
        temperature=0.97
        )

    # Decode row by row so the function is not limited to batch size 1.
    for sequence_ids in generated_txts_ids:
      # Scan a plain Python list: comparing GPU tensor elements one at a time
      # forces a device synchronisation per token.
      response_start = get_response_index(sequence_ids.tolist(), 'TG')
      generated_txt = tokenizer.decode(
          sequence_ids[response_start:],
          skip_special_tokens = False,
          clean_up_tokenization_spaces = True
          )
      # Keep only the text before the end-of-sequence marker. Unlike a blind
      # [:-4], this keeps the full text when generation stopped at
      # max_new_tokens without emitting '</s>'.
      testOutputs.append(generated_txt.split('</s>', 1)[0])

  return testOutputs

# 😎 **4. Modeling user knowledge**

## 🌻 **4.1. Find users' history**


In [None]:
# Load the question-pair table from Drive.
# NOTE(review): pd.read_pickle can execute arbitrary code -- only load trusted files.
data_path_LLama = f"{Drive_path}RQE_Data_With_Both_uesrid.pkl"
MyData_LLama = pd.read_pickle(data_path_LLama)

In [None]:
# Hand-written replacement bodies for questions whose body is the empty string.
# Each entry: (body column to patch, userid_Q2 value selecting the row, replacement text).
# NOTE(review): the body_Q1 entries also select rows through userid_Q2 rather
# than userid_Q1 -- confirm this is intended.
_empty_body_patches = [
    ('body_Q2', '65001', 'So when I launch Minecraft, before it finishes loading, it crashes. I do not understand what is going on. Could someone help me? Here is my crash report:'),
    ('body_Q2', '36896', 'How do I type the infinity symbol in MacTex'),
    ('body_Q2', '3031', 'Run time error for GP objects'),
    ('body_Q1', '65001', 'Misplaced allignment tab character line 53'),
    ('body_Q1', '16188', 'How to Export this animation as a gif file for powerpoint presentation'),
    ('body_Q1', '24829', 'why does rotation style work on actual coordinates and not variables in tikz 3d plot'),
    ('body_Q2', '50615', 'How set a table in margin'),
    ('body_Q2', '23835', 'Latex equation positioning problem'),
    ('body_Q2', '14524', 'Chapter comment with regulation'),
    ('body_Q2', '50823', 'minipage goes beyond right margin'),
]

# Apply the patches in order; a row is touched only while its body is still empty.
for _body_col, _user_id, _replacement in _empty_body_patches:
    _needs_patch = (MyData_LLama[_body_col] == '') & (MyData_LLama['userid_Q2'] == _user_id)
    MyData_LLama.loc[_needs_patch, _body_col] = _replacement

In [None]:
# For every row of MyData_LLama, collect the history of the Q1 author inside
# the same forum: every distinct question that user asked, whether it appears
# on the Q1 side or the Q2 side of a pair.
#
# Output: one row per input row, in input order, with a fresh 0..n-1 index.
#   userid, forum   - the (user, forum) pair the history belongs to
#   historyCount    - number of distinct questions
#   historyIDs      - question ids joined with ', '
#   history         - question bodies joined with ', '
#   historyList     - the same bodies as a list; bodies may themselves contain
#                     ', ', so the joined string cannot be split back reliably
_history_columns = ['userid', 'historyCount', 'historyIDs', 'history', 'forum', 'historyList']

# Many rows share the same (user, forum) pair, so compute each history once.
_history_by_user_forum = {}
_history_records = []

for index, row in MyData_LLama.iterrows():
    user_id = row['userid_Q1']
    forum_name = row['forum_x']
    cache_key = (user_id, forum_name)

    if cache_key not in _history_by_user_forum:
        asked_as_q1 = MyData_LLama.loc[
            (MyData_LLama['userid_Q1'] == user_id) & (MyData_LLama['forum_x'] == forum_name),
            ['id_Q1', 'body_Q1']
            ].rename(columns={'id_Q1': 'QuestionID', 'body_Q1': 'QuestionBody'})
        asked_as_q2 = MyData_LLama.loc[
            (MyData_LLama['userid_Q2'] == user_id) & (MyData_LLama['forum_y'] == forum_name),
            ['id_Q2', 'body_Q2']
            ].rename(columns={'id_Q2': 'QuestionID', 'body_Q2': 'QuestionBody'})

        user_history = (
            pd.concat([asked_as_q1, asked_as_q2])
            .drop_duplicates(subset=['QuestionID'])
        )
        question_bodies = user_history['QuestionBody'].astype(str).tolist()

        _history_by_user_forum[cache_key] = {
            'historyCount': len(user_history),
            'historyIDs': ', '.join(user_history['QuestionID'].astype(str)),
            'history': ', '.join(question_bodies),
            'historyList': question_bodies,
        }

    _history_records.append(
        {'userid': user_id, 'forum': forum_name, **_history_by_user_forum[cache_key]}
    )

# Build the frame once instead of growing it with pd.concat on every
# iteration, which copies the whole frame each time.
user_history_df = pd.DataFrame(_history_records, columns=_history_columns)

user_history_df_filepath = f"{Drive_path}UseRQE/TG/user_history_df.pkl"
user_history_df.to_pickle(user_history_df_filepath)
display(user_history_df)

## 🌻 **4.2. Modeling User knowledge**
Generate tags (using LLama-TG) for every question in each user's history.

In [None]:
# Reload the per-row history table and keep one row per (user, forum) pair,
# so that tags are generated only once for each distinct history.
# NOTE(review): pd.read_pickle can execute arbitrary code -- only load trusted files.
user_history_df_filepath = f"{Drive_path}UseRQE/TG/user_history_df.pkl"
user_history_df = pd.read_pickle(user_history_df_filepath)
# .copy() makes the de-duplicated frame independent of user_history_df;
# columns are added to it later, which is ambiguous on a derived slice.
user_history_df2 = user_history_df.drop_duplicates(subset=["userid", 'forum'], keep='first').copy()
display(user_history_df2)

In [None]:
# Generate tags for every question in each distinct (user, forum) history and
# store them, joined with ' -- ', in the 'generated_tags' column.
start_time = time.time()
fModel.eval()
user_history_df2['generated_tags'] = None

for index, row in user_history_df2.iterrows():
    # Prefer an explicit per-question list when the frame provides one.
    # Splitting the joined 'history' string on ', ' is only a fallback: it
    # breaks any question that itself contains ', ' into several fragments.
    if isinstance(row.get('historyList'), list):
        questions = row['historyList']
    else:
        questions = row['history'].split(', ')

    history_questions = pd.DataFrame(questions, columns=['text'])
    # TGDataset reads row['tags'] in eval mode, so supply an empty placeholder.
    history_questions['tags'] = ""

    data = TGDataset(history_questions, tokenizer, script_args.max_seq_length, is_eval=True)
    # A history holds only a handful of questions; starting worker processes
    # for every user costs far more than tokenising in the main process.
    DL = torch.utils.data.DataLoader(
            data, sampler = torch.utils.data.SequentialSampler(data),
            batch_size= 1, num_workers=0
        )
    historytags = test_step(DL)
    user_history_df2.at[index, 'generated_tags']= ' -- '.join(historytags)

print("--- %s seconds ---" % (time.time() - start_time))
user_history_df2_filepath = f"{Drive_path}UseRQE/TG/user_history_gen_tags.pkl"
user_history_df2.to_pickle(user_history_df2_filepath)
display(user_history_df2)

In [None]:
# Remove every '/' character from the generated tag strings.
# This changes the in-memory frame only; this cell does not write it back to disk.
user_history_df2['generated_tags'] = user_history_df2['generated_tags'].str.replace('/', '')

In [None]:
# Copy the tags generated for each distinct (user, forum) pair back onto every
# row of the per-row history table, turning the ' -- ' separators into ', '.
user_history_df['generated_tags'] = None

# user_history_df2 has one row per (userid, forum), so a dictionary lookup
# replaces a full boolean scan of that frame for every row.
_tags_by_user_forum = {
    (ui, fr): tags
    for ui, fr, tags in zip(
        user_history_df2['userid'],
        user_history_df2['forum'],
        user_history_df2['generated_tags'],
    )
}

for index, row in user_history_df.iterrows():
  tags = _tags_by_user_forum[(row['userid'], row['forum'])]
  user_history_df.at[index, 'generated_tags'] = ', '.join(tags.split(' -- '))

display(user_history_df)

Remove duplicate tags, sort the remaining tags by how often they occur, and keep the 20 most frequent ones.

In [None]:
def process_row(row):
    """Return the row's distinct tags ordered by frequency, limited to the top 20.

    Args:
        row: mapping-like object whose 'generated_tags' entry is a string of
            tags separated by ', '.

    Returns:
        A ', '-joined string with at most 20 distinct tags, most frequent first.
    """
    raw_tags = row['generated_tags'].split(', ')
    # value_counts() removes duplicates and orders by descending count.
    ranked_tags = pd.Series(raw_tags).explode().value_counts().index
    return ', '.join(ranked_tags.tolist()[:20])

# Replace each row's tag string with its de-duplicated, frequency-ranked top-20 version.
user_history_df['generated_tags'] = user_history_df.apply(process_row, axis=1)
# Persist the result under a new file name (user_history_df_filepath is rebound here).
user_history_df_filepath = f"{Drive_path}UseRQE/TG/user_history_T20_gen_tags.pkl"
user_history_df.to_pickle(user_history_df_filepath)
display(user_history_df)

In [None]:
# Attach each user's ranked tags to the question-pair table as the
# 'U_Background_kn' column and save the enriched table.
# NOTE(review): the table is re-read from disk, so any in-memory edits made to
# MyData_LLama earlier in the notebook are not part of the saved file --
# confirm this is intended.
# NOTE(review): pd.read_pickle can execute arbitrary code -- only load trusted files.
data_path_LLama = f"{Drive_path}RQE_Data_With_Both_uesrid.pkl"
MyData_LLama = pd.read_pickle(data_path_LLama)

# user_history_df holds one row per row of this table, in the same order, but
# under a fresh 0..n-1 index. Assign by position: an index-aligned assignment
# would silently misplace values (or leave NaN) if the pickle's index differs.
if len(user_history_df) != len(MyData_LLama):
    raise ValueError(
        f"user_history_df has {len(user_history_df)} rows but the data table has "
        f"{len(MyData_LLama)}; they must correspond row by row"
    )
MyData_LLama['U_Background_kn'] = user_history_df['generated_tags'].to_numpy()

MyData_LLama_filepath = f"{Drive_path}UseRQE/TG/RQE_Data_With_Both_uesrid_T20_UK.pkl"
MyData_LLama.to_pickle(MyData_LLama_filepath)
MyData_LLama