<a href="https://colab.research.google.com/github/SZAftabi/ReQuEST/blob/main/ReQuEST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<p align="center"><font size='5'><b>ReQuEST: </b> </font> <font size='5'><b>Re</b></font>cognizing <font size='5'><b>Qu</b></font>estion <font size='5'><b>E</b></font>ntailment, tag-focused question <font size='5'><b>S</b></font>ummarization, and <font size='5'><b>T</b></font>ag generation
</p>

📄 <b>Paper:</b> <br>
>S. Z. Aftabi, S. M. Seyyedi, M. Maleki and S. Farzi, "<i><b>ReQuEST: A Small-Scale Multi-Task Model for Community Question-Answering Systems</b></i>," in IEEE Access, vol. 12, pp. 17137-17151, 2024, [doi: 10.1109/ACCESS.2024.3358287](https://ieeexplore.ieee.org/abstract/document/10413543).

## 🌞 **Access to Google Drive**

In [None]:
# Mount Google Drive at /content/drive (re-mounting if it is already mounted);
# the checkpoint directory used further down lives under this mount point.
from google.colab import drive
drive.mount(mountpoint='/content/drive', force_remount=True)

## 🌞 **Prerequisites**

In [6]:
# Install the third-party packages imported below (-q keeps pip output quiet).
# NOTE(review): versions are unpinned, so a fresh runtime may install releases
# newer than the ones this notebook was developed with; consider pinning them.
!pip install -q transformers torchsummary torchviz torchmetrics rouge pytorch_lightning torchvision tensorboard

In [7]:
import nltk
import re
import json
import time
import sklearn
import sys
import datetime
import copy
import string
import transformers
import torch
import warnings
import logging
import math

import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [8]:
from sklearn import metrics
from typing import Tuple
from datetime import timedelta
from statistics import mean
from IPython.display import clear_output

from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

from transformers import AutoTokenizer, BartForConditionalGeneration
from transformers import BartModel, BartPretrainedModel, BartConfig
from transformers import get_linear_schedule_with_warmup
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers import logging

from torch import autograd
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset, SequentialSampler, RandomSampler
from torchsummary import summary
from torchviz import make_dot

from torchmetrics import MetricCollection
from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.classification import Accuracy, F1Score, Precision, Recall
from torchmetrics.text import BERTScore
from rouge import Rouge

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import Callback
from tensorboard import notebook

from matplotlib.rcsetup import validate_backend
from os import truncate

In [None]:
# NLTK resources: the 'punkt' tokenizer models and the 'stopwords' corpus.
nltk.download('punkt')
nltk.download('stopwords')

# `logging` is bound to transformers.logging by the import cell above
# (`from transformers import logging`), so this only affects transformers' logger.
logging.set_verbosity_error()

# ========================= Torch AND GPU =========================
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

# Silence all Python warnings for the remainder of the session.
warnings.filterwarnings('ignore')

# 🌞 **Proposed Model**

>## 🔧 **RQE head**
A Neural Network Head for Recognizing Question Entailment

In [10]:
class NN_Model_RQE(nn.Module):
    def __init__(self, embed_size, dimensions, do_r):
        super(NN_Model_RQE, self).__init__()
        layers = []
        prev_size = embed_size
        for size in dimensions:
            layers.extend([
                nn.Linear(prev_size, size),
                nn.LayerNorm(size),
                nn.GLU(),
                nn.Dropout1d(do_r)
            ])
            prev_size = size // 2
        layers.append(nn.Linear(prev_size, 1))
        self.network = nn.Sequential(*layers)
        self.network.apply(self.init_weights)

    def init_weights(self, m):
      if isinstance(m, nn.Linear):
          torch.nn.init.normal_(m.weight, mean=0, std=0.1, generator=torch.Generator().manual_seed(42))
          m.bias.data.fill_(0.01)

    def forward(self, decoder_last_embd):
        return self.network(decoder_last_embd)

>## 🔧 **Main Framework**
1. **One shared BART encoder**,
2. **Two partially shared BART decoders**, and
3. **Three neural-network heads**: one for Recognizing Question Entailment, one for Query-focused Question Summarization, and one for tag generation.

In [11]:
class Main_Architecture(BartPretrainedModel):
  """ReQuEST multi-task network assembled from ``facebook/bart-base``.

  * ``self.SUM`` -- ``BartForConditionalGeneration`` for question summarization.
  * ``self.TG``  -- ``BartForConditionalGeneration`` for tag generation. It re-uses
    SUM's encoder, shared embedding, decoder token/positional embeddings, decoder
    embedding LayerNorm and the first three decoder layers; its remaining decoder
    layers are deep copies of SUM's, so they are trained independently.
  * ``self.RQE`` -- ``ModuleDict`` holding SUM's encoder plus an ``NN_Model_RQE``
    head that maps a mean-pooled encoder output to one entailment logit.

  Args:
    config: ``BartConfig`` forwarded to the ``BartPretrainedModel`` base class.
    hparams: mapping providing ``lr_Encoder``, ``lr_Decoder``, ``lr_RQE``, ``lr_SUM``,
      ``lr_TG`` (stored as attributes) and ``embed_size``, ``Dimensions``, ``DO_r``
      (used to build the RQE head).
  """

  def __init__(self, config: BartConfig, hparams):
    super(Main_Architecture, self).__init__(config)

    # Per-component learning rates; only stored here.
    self.learning_rate_Encoder = hparams['lr_Encoder']
    self.learning_rate_Decoder = hparams['lr_Decoder']
    self.learning_rate_RQE = hparams['lr_RQE']
    self.learning_rate_SUM = hparams['lr_SUM']
    self.learning_rate_TG = hparams['lr_TG']

    # Give each seq2seq model a standalone lm_head (presumably tied to the shared
    # embedding by HF weight tying after loading). A deep copy taken right after
    # loading holds the same pretrained values as the lm_head of a second fresh
    # load, without paying for two extra full from_pretrained() calls.
    self.SUM = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    self.SUM.lm_head = copy.deepcopy(self.SUM.lm_head)

    self.TG = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    self.TG.lm_head = copy.deepcopy(self.TG.lm_head)

    sum_bart, tg_bart = self.SUM.model, self.TG.model
    tg_bart.encoder = sum_bart.encoder
    tg_bart.shared = sum_bart.shared
    tg_bart.decoder.embed_positions = sum_bart.decoder.embed_positions
    tg_bart.decoder.layernorm_embedding = sum_bart.decoder.layernorm_embedding
    tg_bart.decoder.embed_tokens = sum_bart.decoder.embed_tokens
    # Decoder layers 0-2 are shared with SUM; the later layers are private copies.
    tg_bart.decoder.layers = torch.nn.ModuleList(
        [layer if i < 3 else copy.deepcopy(layer)
         for i, layer in enumerate(sum_bart.decoder.layers)])

    self.RQE = torch.nn.ModuleDict({
        "encoder": sum_bart.encoder,
        "lm_head": NN_Model_RQE(hparams['embed_size'], hparams['Dimensions'], hparams['DO_r']),
    })

    self.register_buffer("final_logits_bias",
                         torch.zeros((1, sum_bart.shared.num_embeddings)))

  @staticmethod
  def _masked_mean_pool(hidden_states, attention_mask):
    """Average ``hidden_states`` over dim 1, counting only positions where
    ``attention_mask`` is non-zero (all positions when the mask is None).

    The per-row count is clamped to 1e-9 so fully-masked rows do not divide by zero.
    """
    if attention_mask is None:
      attention_mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device)
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    summed = torch.sum(hidden_states * mask, 1)
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts

  def forward(self,
              EncoderRQE_input_ids = None,
              EncoderRQE_attention = None,
              EncoderSUM_input_ids = None,
              EncoderSUM_attention_mask = None,
              EncoderTG_input_ids = None,
              EncoderTG_attention_mask = None,
              DecoderSUM_input_ids = None,
              DecoderSUM_attention_mask = None,
              DecoderTG_input_ids = None,
              DecoderTG_attention_mask = None,
              output_attentions = None,
              output_hidden_states = None,
              encoder_outputs = None,
              SUM_labels = None,
              TG_labels = None,
              past_key_values = None,
              head_mask = None,
              decoder_head_mask = None,
              cross_attn_head_mask = None,
              inputs_embeds = None,
              decoder_inputs_embeds = None,
              use_cache = None,
              return_dict = None,
              decoder_task = None
              ):
    """Run every task whose inputs are supplied.

    The SUM (TG) branch runs when any of its encoder/decoder ids or masks is given;
    the RQE branch runs when ``EncoderRQE_input_ids`` is given. ``encoder_outputs``,
    ``past_key_values``, the head masks, embeddings, ``use_cache`` and
    ``output_attentions`` are forwarded unchanged to both seq2seq branches.
    ``output_hidden_states``, ``return_dict`` and ``decoder_task`` are accepted for
    interface compatibility but ignored (hidden states are always requested).

    Returns:
      * the SUM logits, the TG logits, or the RQE logits (one per example) when
        exactly one task ran;
      * ``(SUM logits, TG logits, RQE logits, SUM loss, TG loss)`` when all three ran.

    Raises:
      ValueError: if no task, or exactly two tasks, received inputs.
    """
    seq2seq_kwargs = dict(encoder_outputs = encoder_outputs,
                          head_mask = head_mask,
                          decoder_head_mask = decoder_head_mask,
                          cross_attn_head_mask = cross_attn_head_mask,
                          past_key_values = past_key_values,
                          inputs_embeds = inputs_embeds,
                          decoder_inputs_embeds = decoder_inputs_embeds,
                          use_cache = use_cache,
                          output_attentions = output_attentions,
                          output_hidden_states = True,
                          return_dict = True)

    # ****************** Summarization ******************
    BARTOutputs_SUM = None
    if any(t is not None for t in (EncoderSUM_input_ids, DecoderSUM_input_ids,
                                   EncoderSUM_attention_mask, DecoderSUM_attention_mask)):
      BARTOutputs_SUM = self.SUM(input_ids = EncoderSUM_input_ids,
                                 attention_mask = EncoderSUM_attention_mask,
                                 decoder_input_ids = DecoderSUM_input_ids,
                                 decoder_attention_mask = DecoderSUM_attention_mask,
                                 labels = SUM_labels,
                                 **seq2seq_kwargs)

    # ***************** Tag Generation ******************
    BARTOutputs_TG = None
    if any(t is not None for t in (EncoderTG_input_ids, DecoderTG_input_ids,
                                   EncoderTG_attention_mask, DecoderTG_attention_mask)):
      BARTOutputs_TG = self.TG(input_ids = EncoderTG_input_ids,
                               attention_mask = EncoderTG_attention_mask,
                               decoder_input_ids = DecoderTG_input_ids,
                               decoder_attention_mask = DecoderTG_attention_mask,
                               labels = TG_labels,
                               **seq2seq_kwargs)

    # ********************** RQE ************************
    Predicted_label = None
    if EncoderRQE_input_ids is not None:
      BARTOutputs_RQE = self.RQE.encoder(input_ids = EncoderRQE_input_ids,
                                         attention_mask = EncoderRQE_attention,
                                         output_hidden_states = True,
                                         return_dict = True)
      pooled = self._masked_mean_pool(BARTOutputs_RQE.last_hidden_state, EncoderRQE_attention)
      Predicted_label = self.RQE.lm_head(pooled)

    ran = {name: out for name, out in (("SUM", BARTOutputs_SUM),
                                       ("TG", BARTOutputs_TG),
                                       ("RQE", Predicted_label)) if out is not None}
    if len(ran) == 3:
      return BARTOutputs_SUM.logits, BARTOutputs_TG.logits, Predicted_label,\
       BARTOutputs_SUM.loss, BARTOutputs_TG.loss
    if len(ran) == 1:
      name, out = next(iter(ran.items()))
      return out if name == "RQE" else out.logits
    raise ValueError(
        "Main_Architecture.forward expects inputs for exactly one task or for all "
        f"three (SUM, TG, RQE); got inputs for: {sorted(ran) or 'none'}.")

  def _beam_generate(self, bart, tokenizer, input_ids, attention_mask,
                     min_length, max_length, num_beams, no_repeat_ngram_size):
    """Beam-search ``bart`` (switched to eval mode and left there) and return the
    generated sequences decoded to strings with special tokens removed."""
    bart.eval()
    mask_kwargs = {} if attention_mask is None else {"attention_mask": attention_mask}
    generated_ids = bart.generate(input_ids,
                                  max_length = max_length,
                                  min_length = min_length,
                                  num_beams = num_beams,
                                  no_repeat_ngram_size = no_repeat_ngram_size,
                                  **mask_kwargs)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

  def SUM_generate2(self, tokenizer, input_ids, _min_length=5, _max_length=10, _num_beams=3,
                    _no_repeat_ngram_size=3, attention_mask=None):
    """Generate summaries for ``input_ids`` with ``self.SUM`` (beam search).

    ``attention_mask`` is optional and passed to ``generate`` only when given.
    Puts ``self.SUM`` (including the shared encoder) in eval mode.
    """
    return self._beam_generate(self.SUM, tokenizer, input_ids, attention_mask,
                               _min_length, _max_length, _num_beams, _no_repeat_ngram_size)

  def TG_generate2(self, tokenizer, input_ids, _min_length=5, _max_length=10, _num_beams=3,
                   _no_repeat_ngram_size=3, attention_mask=None):
    """Generate tag strings for ``input_ids`` with ``self.TG`` (beam search).

    ``attention_mask`` is optional and passed to ``generate`` only when given.
    Puts ``self.TG`` (including the shared encoder) in eval mode.
    """
    return self._beam_generate(self.TG, tokenizer, input_ids, attention_mask,
                               _min_length, _max_length, _num_beams, _no_repeat_ngram_size)

  @torch.no_grad()
  def RQE_predict(self, input_ids, attention_masks):
    """Return hard 0/1 entailment predictions: ``round(sigmoid(logit))`` per example.

    Puts ``self.RQE`` (and therefore the shared encoder) in eval mode.
    """
    self.RQE.eval()
    out = self.RQE.encoder(input_ids = input_ids, attention_mask = attention_masks, return_dict = True)
    logits = self.RQE.lm_head(self._masked_mean_pool(out.last_hidden_state, attention_masks))
    return torch.round(torch.sigmoid(logits))

  def Freeze_Parameters(self, FoN_Dec, FoN_Enc):
    """Update ``requires_grad`` from per-layer flags.

    The RQE head is always made trainable. For each index ``i``, a flag equal to 0
    in ``FoN_Dec`` (``FoN_Enc``) freezes decoder (encoder) layer ``i``; other values
    leave that layer untouched. Encoder positional embeddings and the encoder
    embedding LayerNorm are always frozen.
    """
    def set_trainable(module, flag):
      for param in module.parameters():
        param.requires_grad = flag

    set_trainable(self.RQE.lm_head, True)

    # =========== BART Decoder Layers ===========
    for i, FON in enumerate(FoN_Dec):
      if FON == 0:
        set_trainable(self.TG.model.decoder.layers[i], False)
        set_trainable(self.SUM.model.decoder.layers[i], False)

    # =========== BART Encoder Layers ===========
    set_trainable(self.SUM.model.encoder.embed_positions, False)
    set_trainable(self.SUM.model.encoder.layernorm_embedding, False)
    # TG and RQE currently point at SUM's encoder object; they are listed
    # explicitly so the freeze still holds if the encoders are ever un-shared.
    encoders = (self.SUM.model.encoder, self.TG.model.encoder, self.RQE.encoder)
    for i, FON in enumerate(FoN_Enc):
      if FON == 0:
        for encoder in encoders:
          set_trainable(encoder.layers[i], False)

  def shift_tokens_right(self, input_ids, pad_token_id, decoder_start_token_id):
    """Shift ``input_ids`` one position right, put ``decoder_start_token_id`` at
    position 0 and replace any -100 with ``pad_token_id``.

    Raises:
      ValueError: if ``pad_token_id`` is None.
    """
    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids

>## 🔧 **LitReQuEST**
ReQuEST wrapped in a PyTorch Lightning module, including its training, validation, and test steps.

In [13]:
class OverrideEpochStepCallback(Callback):
    """At the end of every train, validation and test epoch, log a metric named
    ``"step"`` whose value is the 1-based epoch number."""

    def _log_step_as_current_epoch(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        # trainer.current_epoch is 0-based; report it 1-based.
        epoch_number = trainer.current_epoch + 1
        pl_module.log("step", epoch_number)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_step_as_current_epoch(trainer, pl_module)

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_step_as_current_epoch(trainer, pl_module)

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self._log_step_as_current_epoch(trainer, pl_module)


# Write a checkpoint to Google Drive every 5 epochs; save_top_k=-1 keeps every
# saved checkpoint instead of pruning to the best k.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=-1,
    every_n_epochs=5,
    filename='{epoch:03d}',
    dirpath='/content/drive/MyDrive/ReQuEST/Results/Checkpoints/',
)

In [14]:
class LitReQuEST(pl.LightningModule):
    def __init__(self, hparams):
      """Wrap the ReQuEST network for Lightning training with manual optimisation.

      Args:
        hparams: dict with ``Model`` (the multi-task network), ``tokenizer``,
          ``FreezeLayers`` (pair: encoder flags, decoder flags), ``Coefficient``
          (per-loss weights), ``max_epochs``, ``warmup``, ``weight_decay``,
          ``num_train_batches`` and ``epsilon``. Optional keys: ``Micro_iter``
          (default 3) and ``max_norm`` (gradient-clipping norm, default 0.1).
      """
      super(LitReQuEST, self).__init__()
      self.save_hyperparameters()
      self.Model = hparams['Model']
      self.tokenizer = hparams['tokenizer']
      self.RQE_Loss_Func = nn.BCEWithLogitsLoss()
      self.FreezeEnc = hparams['FreezeLayers'][0]
      self.FreezeDec = hparams['FreezeLayers'][1]
      self.Coefficient = hparams['Coefficient']
      self.max_epochs = hparams['max_epochs']
      self.warmup = hparams['warmup']
      self.weight_decay = hparams['weight_decay']
      self.num_train_batches = hparams['num_train_batches']
      self.opt_epsilon = hparams['epsilon']
      # Optional overrides; the defaults are the previously hard-coded values.
      self.Micro_iter = hparams.get('Micro_iter', 3)   # micro-iterations per batch (1 = one joint update)
      self.max_norm = hparams.get('max_norm', 0.1)     # max norm for gradient clipping
      self.generated_summaries = []
      self.generated_tags = []
      self.predicted_labels = []
      self.sig = nn.Sigmoid()
      self.automatic_optimization = False
      # opt1: Optimizer_RQE, opt2: Optimizer_Encoder, opt3: Optimizer_Decoder,
      # opt4: Optimizer_SUM, opt5: Optimizer_TG
      self.opt1, self.opt2, self.opt3, self.opt4,\
       self.opt5 = self.configure_optimizers()
      self.scheduler_Encoder = self.configure_scheduler(self.opt2)
      self.scheduler_Decoder = self.configure_scheduler(self.opt3)

      # BERTScore runs its own scoring model; fall back to CPU when no GPU is
      # present instead of hard-coding "cuda".
      bertscore_device = "cuda" if torch.cuda.is_available() else "cpu"

      rqe_metrics = MetricCollection([
            Accuracy(task="binary", num_classes=2),
            F1Score(task="binary", num_classes=2),
            Precision(task="binary", num_classes=2),
            Recall(task="binary", num_classes=2),
        ])
      rouge_keys = ("rouge1", "rouge2", "rougeL")
      sum_metrics = MetricCollection([
            ROUGEScore(rouge_keys=rouge_keys),
            BERTScore(device=bertscore_device)
        ])
      tg_metrics = MetricCollection([
            ROUGEScore(rouge_keys=rouge_keys),
            BERTScore(device=bertscore_device)
        ])

      self.test_rqe_metrics = rqe_metrics.clone(prefix='test_')
      self.test_sum_metrics = sum_metrics.clone(prefix='test_')
      self.test_tg_metrics = tg_metrics.clone(prefix='test_')
      self.train_rqe_metrics = rqe_metrics.clone(prefix='train_')
      self.train_sum_metrics = sum_metrics.clone(prefix='train_')
      self.train_tg_metrics = tg_metrics.clone(prefix='train_')
      self.val_rqe_metrics = rqe_metrics.clone(prefix='val_')
      self.val_sum_metrics = sum_metrics.clone(prefix='val_')
      self.val_tg_metrics = tg_metrics.clone(prefix='val_')


    def forward(self, EncoderRQE_input_ids = None,
                EncoderRQE_attention = None,
                EncoderSUM_input_ids = None,
                EncoderSUM_attention_mask = None,
                EncoderTG_input_ids = None,
                EncoderTG_attention_mask = None,
                DecoderSUM_input_ids = None,
                DecoderSUM_attention_mask = None,
                DecoderTG_input_ids = None,
                DecoderTG_attention_mask = None,
                output_attentions = None,
                output_hidden_states = None,
                encoder_outputs = None,
                SUM_labels = None,
                TG_labels = None,
                past_key_values = None,
                head_mask = None,
                decoder_head_mask = None,
                cross_attn_head_mask = None,
                inputs_embeds = None,
                decoder_inputs_embeds = None,
                use_cache = None,
                return_dict = None,
                decoder_task = None):
        """Delegate to ``self.Model``, forwarding every argument by name."""
        return self.Model(
            EncoderRQE_input_ids = EncoderRQE_input_ids,
            EncoderRQE_attention = EncoderRQE_attention,
            EncoderSUM_input_ids = EncoderSUM_input_ids,
            EncoderSUM_attention_mask = EncoderSUM_attention_mask,
            EncoderTG_input_ids = EncoderTG_input_ids,
            EncoderTG_attention_mask = EncoderTG_attention_mask,
            DecoderSUM_input_ids = DecoderSUM_input_ids,
            DecoderSUM_attention_mask = DecoderSUM_attention_mask,
            DecoderTG_input_ids = DecoderTG_input_ids,
            DecoderTG_attention_mask = DecoderTG_attention_mask,
            output_attentions = output_attentions,
            output_hidden_states = output_hidden_states,
            encoder_outputs = encoder_outputs,
            SUM_labels = SUM_labels,
            TG_labels = TG_labels,
            past_key_values = past_key_values,
            head_mask = head_mask,
            decoder_head_mask = decoder_head_mask,
            cross_attn_head_mask = cross_attn_head_mask,
            inputs_embeds = inputs_embeds,
            decoder_inputs_embeds = decoder_inputs_embeds,
            use_cache = use_cache,
            return_dict = return_dict,
            decoder_task = decoder_task,
        )


    def training_step(self, batch, batch_idx):
      """One manually optimised training step split into ``self.Micro_iter`` passes.

      Every pass zeroes gradients and re-runs the model on the whole batch.
      With ``Micro_iter == 1`` a single backward on ``SUM + RQE + TG`` steps all
      five optimisers. Otherwise:
        pass 0 - opt1 (RQE), opt4 (SUM) and opt5 (TG), each after a backward on its
                 own loss scaled by ``Coefficient['RQE3'|'SUM3'|'TG3']``;
        pass 1 - opt3 (decoder) on ``SUM2*SUM + TG2*TG``;
        pass 2 - opt2 (encoder) on ``SUM*SUM + RQE*RQE + TG*TG``.
      Gradients are clipped to ``self.max_norm`` after every backward.
      Unscaled losses (comparable with the ``val/*`` losses) and train metrics are
      recorded once per batch, from pass 0.
      NOTE: ``scheduler_Encoder`` / ``scheduler_Decoder`` are not stepped here.

      Returns:
        dict with the last pass's ``Model_Output`` and its unscaled losses.
      """
      Q1_input_ids, Q1_attention, Q1Q2_input_ids, Q1Q2_attention, Q1Tags_input_ids, Q1Tags_attention,\
        TG_decoder_input_ids, TG_decoder_attention_mask, SUM_decoder_input_ids, SUM_decoder_attention_mask,\
          GoldSummary_input_ids, GoldTags_input_ids, Pair_Labels = batch

      self.Model.train()
      self.Model.Freeze_Parameters(self.FreezeDec, self.FreezeEnc)
      pad_token_id = self.Model.SUM.config.pad_token_id

      def decode_ids(ids):
          # Decode a copy with -100 mapped to <pad>. The label tensors must keep
          # their -100 entries: later passes feed them to the model again as
          # SUM_labels / TG_labels, where -100 marks positions ignored by the loss.
          return self.tokenizer.batch_decode(
              sequences = torch.where(ids != -100, ids, pad_token_id),
              skip_special_tokens = True,
              clean_up_tokenization_spaces = False)

      def backward_and_clip(loss):
          self.manual_backward(loss, retain_graph = True)
          torch.nn.utils.clip_grad_norm_(self.Model.parameters(), self.max_norm)

      for micro_step in range(self.Micro_iter):
          self.Model.zero_grad()
          for opt in (self.opt1, self.opt2, self.opt3, self.opt4, self.opt5):
              opt.zero_grad()

          Model_Output = self.Model(EncoderRQE_input_ids = Q1Q2_input_ids,
                                    EncoderRQE_attention = Q1Q2_attention,
                                    EncoderSUM_input_ids = Q1Tags_input_ids,
                                    EncoderSUM_attention_mask = Q1Tags_attention,
                                    EncoderTG_input_ids = Q1_input_ids,
                                    EncoderTG_attention_mask = Q1_attention,
                                    DecoderSUM_attention_mask = SUM_decoder_attention_mask,
                                    DecoderTG_attention_mask = TG_decoder_attention_mask,
                                    TG_labels = GoldTags_input_ids,
                                    SUM_labels = GoldSummary_input_ids,
                                    return_dict = False)

          Summary, Tags, PLabel, SUM_Loss, TG_Loss = Model_Output
          RQE_Loss = self.RQE_Loss_Func((PLabel.view(-1, PLabel.shape[-1])).squeeze(1), Pair_Labels.float())

          if self.Micro_iter == 1:
              backward_and_clip(SUM_Loss + RQE_Loss + TG_Loss)
              for opt in (self.opt2, self.opt3, self.opt1, self.opt4, self.opt5):
                  opt.step()
          elif micro_step == 0:
              backward_and_clip(self.Coefficient['RQE3'] * RQE_Loss)
              self.opt1.step()
              backward_and_clip(self.Coefficient['SUM3'] * SUM_Loss)
              self.opt4.step()
              backward_and_clip(self.Coefficient['TG3'] * TG_Loss)
              self.opt5.step()
          elif micro_step == 1:
              backward_and_clip(self.Coefficient['SUM2'] * SUM_Loss +
                                self.Coefficient['TG2'] * TG_Loss)
              self.opt3.step()
          elif micro_step == 2:
              backward_and_clip(self.Coefficient['SUM'] * SUM_Loss +
                                self.Coefficient['RQE'] * RQE_Loss +
                                self.Coefficient['TG'] * TG_Loss)
              self.opt2.step()

          if micro_step == 0:
              self.log('train/rqe_loss', RQE_Loss.item(), on_step=True, on_epoch=True)
              self.log('train/sum_loss', SUM_Loss.item(), on_step=True, on_epoch=True)
              self.log('train/tg_loss', TG_Loss.item(), on_step=True, on_epoch=True)

              GoldSummaries = decode_ids(GoldSummary_input_ids)
              GoldTags = decode_ids(GoldTags_input_ids)
              labels_predicted = torch.round(self.sig(PLabel))
              # Teacher-forced greedy read-out of the logits (not beam search).
              summary_decoded = decode_ids(torch.argmax(Summary, dim=-1))
              tg_decoded = decode_ids(torch.argmax(Tags, dim=-1))

              self.train_sum_metrics.update(summary_decoded, GoldSummaries)
              self.train_tg_metrics.update(tg_decoded, GoldTags)
              self.train_rqe_metrics.update(labels_predicted.squeeze(1), Pair_Labels)

      output =  {
          'Model_Output': Model_Output,
          'RQE_Loss': RQE_Loss,
          'SUM_Loss': SUM_Loss,
          'TG_Loss': TG_Loss
          }
      return output


    def test_step(self, batch, batch_idx):
      """Beam-search summaries and tags, predict RQE labels, and update test metrics.

      The per-batch predictions are also appended to ``self.generated_summaries``,
      ``self.generated_tags`` and ``self.predicted_labels``.
      NOTE(review): generation is called without attention masks.
      """
      Q1_input_ids, Q1_attention, Q1Q2_input_ids, Q1Q2_attention, Q1Tags_input_ids, Q1Tags_attention,\
        TG_decoder_input_ids, TG_decoder_attention_mask, SUM_decoder_input_ids, SUM_decoder_attention_mask,\
          GoldSummary_input_ids, GoldTags_input_ids, Pair_Labels = batch

      summary_decoded = self.Model.SUM_generate2(
        tokenizer = self.tokenizer,
        input_ids = Q1Tags_input_ids,
        _max_length = 140,
        _min_length = 20,
        _num_beams = 4,
        _no_repeat_ngram_size = 3
        )
      tg_decoded = self.Model.TG_generate2(
        tokenizer = self.tokenizer,
        input_ids = Q1_input_ids,
        _max_length = 30,
        _min_length = 7,
        _num_beams = 4,
        _no_repeat_ngram_size = 3
        )
      labels_predicted = self.Model.RQE_predict(
          input_ids = Q1Q2_input_ids,
          attention_masks = Q1Q2_attention,
          )

      pad_token_id = self.Model.SUM.config.pad_token_id

      def decode_ids(ids):
          # -100 marks ignored label positions; map it to <pad> before decoding.
          return self.tokenizer.batch_decode(
              sequences = torch.where(ids != -100, ids, pad_token_id),
              skip_special_tokens = True,
              clean_up_tokenization_spaces = False)

      GoldSummaries = decode_ids(GoldSummary_input_ids)
      GoldTags = decode_ids(GoldTags_input_ids)
      predicted = labels_predicted.squeeze(1)

      self.test_sum_metrics.update(summary_decoded, GoldSummaries)
      self.test_tg_metrics.update(tg_decoded, GoldTags)
      self.test_rqe_metrics.update(predicted, Pair_Labels)

      self.generated_summaries.append(summary_decoded)
      self.generated_tags.append(tg_decoded)
      self.predicted_labels.append(predicted.tolist())

    def validation_step(self, batch, batch_idx):
      """Teacher-forced forward pass on all three tasks without gradients.

      Logs the per-task losses and updates the validation metrics, using an argmax
      read-out of the SUM/TG logits and ``round(sigmoid(.))`` of the RQE logits.

      Returns:
        dict with ``Model_Output`` and the three losses.
      """
      Q1_input_ids, Q1_attention, Q1Q2_input_ids, Q1Q2_attention, Q1Tags_input_ids, Q1Tags_attention,\
        TG_decoder_input_ids, TG_decoder_attention_mask, SUM_decoder_input_ids, SUM_decoder_attention_mask,\
          GoldSummary_input_ids, GoldTags_input_ids, Pair_Labels = batch

      self.Model.eval()
      with torch.no_grad():
        Model_Output = self.Model(EncoderRQE_input_ids = Q1Q2_input_ids,
                                  EncoderRQE_attention = Q1Q2_attention,
                                  EncoderSUM_input_ids = Q1Tags_input_ids,
                                  EncoderSUM_attention_mask = Q1Tags_attention,
                                  EncoderTG_input_ids = Q1_input_ids,
                                  EncoderTG_attention_mask = Q1_attention,
                                  DecoderSUM_attention_mask = SUM_decoder_attention_mask,
                                  DecoderTG_attention_mask = TG_decoder_attention_mask,
                                  TG_labels = GoldTags_input_ids,
                                  SUM_labels = GoldSummary_input_ids,
                                  return_dict = False)

        Summary, Tags, PLabel, SUM_Loss, TG_Loss = Model_Output
        RQE_Loss = self.RQE_Loss_Func((PLabel.view(-1, PLabel.shape[-1])).squeeze(1), Pair_Labels.float())

      self.log('val/rqe_loss', RQE_Loss.item(), on_step=False, on_epoch=True)
      self.log('val/sum_loss', SUM_Loss.item(), on_step=False, on_epoch=True)
      self.log('val/tg_loss', TG_Loss.item(), on_step=False, on_epoch=True)

      pad_token_id = self.Model.SUM.config.pad_token_id

      def decode_ids(ids):
          # -100 marks ignored label positions; map it to <pad> before decoding.
          return self.tokenizer.batch_decode(
              sequences = torch.where(ids != -100, ids, pad_token_id),
              skip_special_tokens = True,
              clean_up_tokenization_spaces = False)

      GoldSummaries = decode_ids(GoldSummary_input_ids)
      GoldTags = decode_ids(GoldTags_input_ids)
      labels_predicted = torch.round(self.sig(PLabel))
      summary_decoded = decode_ids(torch.argmax(Summary, dim=-1))
      tg_decoded = decode_ids(torch.argmax(Tags, dim=-1))

      self.val_sum_metrics.update(summary_decoded, GoldSummaries)
      self.val_tg_metrics.update(tg_decoded, GoldTags)
      self.val_rqe_metrics.update(labels_predicted.squeeze(1), Pair_Labels)

      output =  {
          'Model_Output': Model_Output,
          'RQE_Loss': RQE_Loss,
          'SUM_Loss': SUM_Loss,
          'TG_Loss': TG_Loss
          }
      return output


    def on_validation_epoch_end(self):
      """Log the accumulated validation metrics for the epoch, then reset them.

      The RQE classification metrics are logged as a single dict. For each
      generation task -- summarisation (``sum``) and tag generation (``tg``) --
      ROUGE-1/2/L f-measure, recall and precision are logged, followed by the
      per-sample BERTScore recall / f1 / precision reduced with ``.mean()``.
      """
      rqe_scores = self.val_rqe_metrics.compute()
      task_scores = (
          ('sum', self.val_sum_metrics.compute()),
          ('tg', self.val_tg_metrics.compute()),
          )

      self.log_dict(rqe_scores, prog_bar=False)

      # ROUGE: per task, all f-measures, then recalls, then precisions.
      for task, scores in task_scores:
        for tag, stat in (('f', 'fmeasure'), ('r', 'recall'), ('p', 'precision')):
          for rouge in ('rouge1', 'rouge2', 'rougeL'):
            self.log(f'val/val_{rouge.lower()}_{task}({tag})', scores[f'val_{rouge}_{stat}'])

      # BERTScore entries hold one value per sample, hence the mean.
      for task, scores in task_scores:
        for tag, stat in (('r', 'recall'), ('f', 'f1'), ('p', 'precision')):
          self.log(f'val/val_BS_score_{task}({tag})', scores[f'val_{stat}'].mean())

      for collection in (self.val_rqe_metrics, self.val_sum_metrics, self.val_tg_metrics):
        collection.reset()


    def on_train_epoch_end(self):
      """Log the accumulated training metrics for the epoch, then reset them.

      Mirrors the validation hook with the ``train`` prefix: RQE metrics as a
      dict; ROUGE-1/2/L (f, r, p) for the ``sum`` and ``tg`` tasks; then the
      mean BERTScore recall / f1 / precision for both tasks.
      """
      rqe_scores = self.train_rqe_metrics.compute()
      generation = [
          ('sum', self.train_sum_metrics.compute()),
          ('tg', self.train_tg_metrics.compute()),
          ]

      self.log_dict(rqe_scores, prog_bar=False)

      rouge_stats = (('f', 'fmeasure'), ('r', 'recall'), ('p', 'precision'))
      for task, scores in generation:
        for tag, stat in rouge_stats:
          for rouge in ('rouge1', 'rouge2', 'rougeL'):
            self.log(f'train/train_{rouge.lower()}_{task}({tag})', scores[f'train_{rouge}_{stat}'])

      # BERTScore is per-sample; log its mean.
      bertscore_stats = (('r', 'recall'), ('f', 'f1'), ('p', 'precision'))
      for task, scores in generation:
        for tag, stat in bertscore_stats:
          self.log(f'train/train_BS_score_{task}({tag})', scores[f'train_{stat}'].mean())

      for collection in (self.train_rqe_metrics, self.train_sum_metrics, self.train_tg_metrics):
        collection.reset()


    def on_test_epoch_end(self):
      """Log the accumulated test metrics for the epoch, then reset them.

      Same layout as the validation/training hooks, with the ``test`` prefix:
      RQE metrics as a dict, ROUGE-1/2/L (f, r, p) per generation task, then
      mean BERTScore recall / f1 / precision per task.
      """
      rqe_scores = self.test_rqe_metrics.compute()
      sum_scores = self.test_sum_metrics.compute()
      tg_scores = self.test_tg_metrics.compute()
      by_task = {'sum': sum_scores, 'tg': tg_scores}

      self.log_dict(rqe_scores, prog_bar=False)

      for task in ('sum', 'tg'):
        for tag, stat in (('f', 'fmeasure'), ('r', 'recall'), ('p', 'precision')):
          for rouge in ('rouge1', 'rouge2', 'rougeL'):
            self.log(f'test/test_{rouge.lower()}_{task}({tag})', by_task[task][f'test_{rouge}_{stat}'])

      # BERTScore entries are per-sample tensors; log their mean.
      for task in ('sum', 'tg'):
        for tag, stat in (('r', 'recall'), ('f', 'f1'), ('p', 'precision')):
          self.log(f'test/test_BS_score_{task}({tag})', by_task[task][f'test_{stat}'].mean())

      for collection in (self.test_rqe_metrics, self.test_sum_metrics, self.test_tg_metrics):
        collection.reset()


    def configure_optimizers(self):
      """Build the five AdamW optimizers of the multi-task model.

      Returns ``(opt_RQE, opt_BART_Encoder, opt_BART_Decoder, opt_SUM, opt_TG)``.
      The optimizers are built in this order, and each tensor is owned by the
      first one that claims it:

        1. BART encoder layers ``3:``                       (learning_rate_Encoder)
        2. shared decoder layers ``0:3`` + decoder position
           embeddings, embedding layer-norm, token embeddings (learning_rate_Decoder)
        3. RQE head                                          (learning_rate_RQE)
        4. SUM ``lm_head`` + SUM decoder layers ``3:``       (learning_rate_SUM)
        5. TG ``lm_head`` + TG decoder layers ``3:``         (learning_rate_TG)

      Why de-duplicate: with the default tied word embeddings of
      ``BartForConditionalGeneration``, ``lm_head.weight`` is the same
      Parameter as ``decoder.embed_tokens.weight``. Listing it in groups 2 and
      4 (and possibly 5) would step it more than once per iteration, with
      separate Adam states. This happens if every optimizer is stepped each
      iteration, which presumably the training step does (not visible here).
      If de-duplication would leave a group empty, that group keeps its
      original tensors, because AdamW rejects an empty parameter list.

      ``Freeze_Parameters`` is applied after all optimizers are created.
      """
      claimed = set()   # id() of every tensor already owned by an optimizer

      def _adamw(params, lr):
        """AdamW over the tensors in ``params`` not already claimed."""
        params = list(params)
        owned = []
        for p in params:
          if id(p) not in claimed:
            claimed.add(id(p))
            owned.append(p)
        return torch.optim.AdamW(
            owned or params,
            lr = lr,
            eps = self.opt_epsilon,
            weight_decay = self.weight_decay,
            )

      sum_model = self.Model.SUM.model

      #  1.  Optimizer for BART Encoder module only
      opt_BART_Encoder = _adamw(
          sum_model.encoder.layers[3:].parameters(),
          self.Model.learning_rate_Encoder,
          )

      #  2.  Optimizer for BART shared Decoder layers only
      opt_BART_Decoder = _adamw(
          list(sum_model.decoder.layers[0:3].parameters()) +
          list(sum_model.decoder.embed_positions.parameters()) +
          list(sum_model.decoder.layernorm_embedding.parameters()) +
          list(sum_model.decoder.embed_tokens.parameters()),
          self.Model.learning_rate_Decoder,
          )

      #  3.  Optimizer for RQE module only
      opt_RQE = _adamw(
          self.Model.RQE.lm_head.parameters(),
          self.Model.learning_rate_RQE,
          )

      #  4.  Optimizer for Summarization module only
      opt_SUM = _adamw(
          list(self.Model.SUM.lm_head.parameters()) +
          list(sum_model.decoder.layers[3:].parameters()),
          self.Model.learning_rate_SUM,
          )

      #  5.  Optimizer for Tag generation module only
      opt_TG = _adamw(
          list(self.Model.TG.lm_head.parameters()) +
          list(self.Model.TG.model.decoder.layers[3:].parameters()),
          self.Model.learning_rate_TG,
          )

      self.Model.Freeze_Parameters(self.FreezeDec, self.FreezeEnc)
      return opt_RQE, opt_BART_Encoder, opt_BART_Decoder, opt_SUM, opt_TG


    def configure_scheduler(self, optimizer):
      """Return a linear warmup-then-decay LR schedule for ``optimizer``.

      The schedule spans ``num_train_batches * max_epochs`` steps.
      ``self.warmup`` is passed straight through as ``num_warmup_steps``, so
      it is a step count, not a fraction of training.
      """
      return get_linear_schedule_with_warmup(
          optimizer,
          num_warmup_steps = self.warmup,
          num_training_steps = self.num_train_batches * self.max_epochs,
          )

# 🌞 **Hyper-Parameters**

In [None]:
# Pretrained BART-base backbone and its tokenizer.
BARTtokenizer = AutoTokenizer.from_pretrained("facebook/bart-base", lowercase = False)
BARTmodel = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Architecture hyper-parameters (every module currently uses lr = 2e-5).
hparams_Main = dict(
    embed_size = BARTmodel.config.d_model,
    DO_r = 0.1,
    lr_RQE = 2e-5,
    lr_SUM = 2e-5,
    lr_TG = 2e-5,
    lr_Encoder = 2e-5,
    lr_Decoder = 2e-5,
    Dimensions = [48],                                                          # [n1, n2, n3, ....]
    )
Main_Arch_Obj = Main_Architecture(BARTmodel.config, hparams_Main)

max_epochs, batch_size = 10, 16

# Training hyper-parameters for the Lightning module.
hparams_MyModel = dict(
    Model = Main_Arch_Obj,
    tokenizer = BARTtokenizer,
    label_smoothing = 0.1,
    # One flag per layer (0: Freeze, 1: Unfreeze); row 0 = encoder, row 1 = decoder.
    FreezeLayers = [[0, 0, 0, 1, 1, 1],
                    [1, 1, 1, 1, 1, 1]],
    Coefficient = {'RQE': 0.7, 'SUM': 0.2, 'TG': 0.1,
                   'SUM2': 1, 'TG2': 1,
                   'RQE3': 1, 'SUM3': 1, 'TG3': 1},
    max_epochs = max_epochs,
    warmup = 0.0,
    epsilon = 1e-8,
    weight_decay = 0.01,
    )

print("Maximum position embeddings: ", BARTmodel.config.max_position_embeddings)
print("Size of embeddings: " , BARTmodel.config.d_model)

# 🌞 **Data preparation**

>## ✨ **Preprocess Data**

In [16]:
# Built once instead of on every call.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)
_URL_PATTERN = re.compile(r'http\S+')

def Preprocess_Text (txt, txttype='S'):
    """Normalise one field of a question record.

    Args:
        txt: a list of tag strings when ``txttype == 'T'``, otherwise a string.
        txttype: ``'T'`` for tag lists; any other value means free text.

    Returns:
        Tags: joined with ``", "``, with hyphens replaced by spaces.
        Text: newlines, hyphens and curly double quotes replaced by spaces,
        ``http...`` URLs removed, and all ASCII punctuation deleted.
    """
    if txttype == 'T':
        return ", ".join(txt).replace('-', ' ')

    for char in ('\n', '“', '”'):
        txt = txt.replace(char, ' ')
    # Remove URLs BEFORE splitting on hyphens. Otherwise "http://my-site.com"
    # is cut at the hyphen, only "http://my" is removed, and "site.com"
    # survives as the junk token "sitecom".
    txt = _URL_PATTERN.sub('', txt)
    txt = txt.replace('-', ' ')
    return txt.translate(_PUNCTUATION_TABLE)

def Preprocess_Data(Data):
    """Clean the question text and tag columns of ``Data`` in place and return it.

    ``long_text``, ``short_text`` and ``long_text_title`` are cleaned as free
    text. ``long_text_tags`` and ``short_text_tags`` must hold tag lists and
    become comma-separated strings.
    """
    column_kinds = {
        'long_text': 'S',
        'short_text': 'S',
        'long_text_title': 'S',
        'long_text_tags': 'T',
        'short_text_tags': 'T',
        }
    # Column-wise assignment instead of a slow iterrows() + .at loop.
    for column, kind in column_kinds.items():
        Data[column] = [Preprocess_Text(value, kind) for value in Data[column]]
    return Data

In [None]:
data_tr = "/content/drive/MyDrive/ReQuEST/Data/TrainData.pkl"
data_te = "/content/drive/MyDrive/ReQuEST/Data/TestData.pkl"

def _load_split(pickle_path):
    """Read one pickled split and binarise ``is_duplicate`` ('Entailed' -> 1, anything else -> 0)."""
    frame = pd.read_pickle(pickle_path)
    frame['is_duplicate'] = [1 if value == 'Entailed' else 0 for value in frame['is_duplicate']]
    return frame

MyData_tr = Preprocess_Data(_load_split(data_tr))
MyData_te = Preprocess_Data(_load_split(data_te))

display(MyData_tr)
display(MyData_te)

>## ✨ **Custom Dataset**

In [18]:
class CustomDataset(Dataset):
    """Map-style dataset that turns one question-pair row into model tensors.

    Columns are read by POSITION:
      0 -> first question (Q1), 1 -> second question (Q2),
      2 -> tag-generation target; also the summarisation query for label 0,
      3 -> summarisation query for label 1,
      4 -> summarisation target for label 0 (named ``Q1Title`` in the original code),
      8 -> entailment label, which must be 0 or 1.
    """

    # Token budgets. The RQE input concatenates two questions, so it is longer.
    RQE_MAX_LEN = 256
    SEQ_MAX_LEN = 140

    def __init__(self, data, tokenizer, pad_token_id):
        self.data = data
        self.tokenizer = tokenizer
        self.pad_tID = pad_token_id
        self.labels = [0, 1]

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenize ``text`` to exactly ``max_length`` ids (pad or truncate)."""
        # truncation=True is required. With padding='max_length' alone the
        # tokenizer does not truncate, so an over-long text yields a longer
        # tensor and batch collation (torch.stack) fails.
        return self.tokenizer(text, max_length = max_length, padding = 'max_length', truncation = True)

    def _mask_padding(self, input_ids):
        """Replace padding ids with -100 so they are ignored as labels."""
        return [-100 if t == self.pad_tID else t for t in input_ids]

    def __getitem__(self, idx):
        # Use iloc, because DataLoader indices are positions 0..len-1. They only
        # match index labels (.loc) when the frame has a default RangeIndex.
        # row.iloc avoids pandas' deprecated positional fallback on row[k].
        row = self.data.iloc[idx]
        q1, q2, col2, col3, col4 = (row.iloc[k] for k in range(5))
        label = self.labels.index(row.iloc[8])

        # ================ Tokenization for RQE ==============
        Q1Q2_Tokenized = self._encode(f"Question1: {q1} </s></s> Question2: {q2}", self.RQE_MAX_LEN)

        # =========== Tokenization for Tag Generation ========
        Q1_Tokenized = self._encode(q1, self.SEQ_MAX_LEN)
        Q1_Tags_Tokenized = self._encode(col2, self.SEQ_MAX_LEN)
        Q1_Tags_input_ids = self._mask_padding(Q1_Tags_Tokenized["input_ids"])

        # =========== Tokenization for Summarization =========
        # label 0: query = column 2, target = column 4
        # label 1: query = column 3, target = Q2
        if label == 0:
            query, target_text = col2, col4
        else:
            query, target_text = col3, q2
        Q1Tags_Tokenized = self._encode(f"Question: {q1} </s></s> Query: {query}", self.SEQ_MAX_LEN)
        Target_Tokenized = self._encode(target_text, self.SEQ_MAX_LEN)
        Summary_input_ids = self._mask_padding(Target_Tokenized["input_ids"])

        return (
            torch.tensor(Q1_Tokenized['input_ids']),
            torch.tensor(Q1_Tokenized['attention_mask']),
            torch.tensor(Q1Q2_Tokenized['input_ids']),
            torch.tensor(Q1Q2_Tokenized['attention_mask']),
            torch.tensor(Q1Tags_Tokenized['input_ids']),
            torch.tensor(Q1Tags_Tokenized['attention_mask']),
            torch.tensor(Q1_Tags_Tokenized["input_ids"]),
            torch.tensor(Q1_Tags_Tokenized['attention_mask']),
            torch.tensor(Target_Tokenized["input_ids"]),         # SUM decoder input ids
            torch.tensor(Target_Tokenized["attention_mask"]),    # SUM decoder attention mask
            torch.tensor(Summary_input_ids),
            torch.tensor(Q1_Tags_input_ids),
            torch.tensor(label),
            )

>## ✨ **Custom Data Module**

In [19]:
class CustomDataModule(pl.LightningDataModule):
    """Lightning data module over the preprocessed train and test DataFrames.

    NOTE(review): ``split_ratio`` and ``stage`` are accepted but never used.
    The validation set is built from ``MyData_te``, so validation and test
    use the same data. If checkpoint selection monitors a validation metric,
    it is effectively selecting on the test set.
    """

    def __init__(self,
                 MyData_tr,
                 MyData_te,
                 tokenizer,
                 pad_token_id,
                 batch_size = 16,
                 split_ratio = [0.8, 0.1, 0.1],
                 stage=None):
        super().__init__()
        self.batch_size = batch_size
        # Filled in by setup().
        self.train_dataset = self.test_dataset = self.val_dataset = None
        self.MyData_tr, self.MyData_te = MyData_tr, MyData_te
        self.tokenizer = tokenizer
        self.pad_id = pad_token_id

    def _wrap(self, frame):
        """Wrap a DataFrame in a CustomDataset sharing this module's tokenizer."""
        return CustomDataset(frame, self.tokenizer, self.pad_id)

    def setup(self, stage=None):
        self.train_dataset = self._wrap(self.MyData_tr)
        self.test_dataset = self._wrap(self.MyData_te)
        self.val_dataset = self._wrap(self.MyData_te)

    def _loader(self, dataset, sampler):
        """Batch ``dataset`` in the order produced by ``sampler``."""
        return DataLoader(dataset, sampler = sampler, batch_size = self.batch_size)

    def train_dataloader(self):
        # Seeded generator makes the shuffling order reproducible.
        seeded = torch.Generator().manual_seed(42)
        return self._loader(self.train_dataset, RandomSampler(self.train_dataset, generator= seeded))

    def val_dataloader(self):
        return self._loader(self.val_dataset, SequentialSampler(self.val_dataset))

    def test_dataloader(self):
        return self._loader(self.test_dataset, SequentialSampler(self.test_dataset))

In [20]:
# Data module over the preprocessed splits; BART's pad id is used for label masking.
DataModule = CustomDataModule(
    MyData_tr = MyData_tr,
    MyData_te = MyData_te,
    tokenizer = BARTtokenizer,
    pad_token_id = BARTmodel.config.pad_token_id,
    batch_size = batch_size,
    )

In [21]:
# Build the datasets now so the number of training batches is known before the
# model is created (presumably feeds the scheduler's total step count).
DataModule.setup()
num_train_batches = hparams_MyModel['num_train_batches'] = len(DataModule.train_dataloader())

# 🌞 **Model Compile**


In [None]:
MyModel = LitReQuEST(hparams_MyModel)
print(MyModel)

# TensorBoard logs go to Drive so they survive Colab session resets.
logger = TensorBoardLogger(
    save_dir = "/content/drive/MyDrive/ReQuEST/Results/logs",
    name = "ReQuEST_Logs",
    )

# 🌞 **Training Phase**

In [None]:
# Trainer using the TensorBoard logger above; the callbacks are defined
# elsewhere in the notebook.
trainer = pl.Trainer(
    logger = logger,
    max_epochs = hparams_MyModel['max_epochs'],
    log_every_n_steps = 1,                                                     # log after every training step
    num_sanity_val_steps = 0,                                                  # skip the pre-training validation sanity check
    callbacks = [OverrideEpochStepCallback(), checkpoint_callback],
    default_root_dir="/content/drive/MyDrive/ReQuEST/Results/Checkpoints/",
    )

# IPython magics: show the TensorBoard dashboard inline.
%reload_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/ReQuEST/

# To resume from a saved checkpoint, uncomment ckpt_path below.
trainer.fit(
    MyModel,
    datamodule=DataModule,
    # ckpt_path="/content/drive/MyDrive/ReQuEST/Results/Checkpoints/MyModel_checkpoints10.ckpt",
    )

# Save the final-epoch state; the file name encodes max_epochs.
trainer.save_checkpoint(
    f"/content/drive/MyDrive/ReQuEST/Results/Checkpoints/MyModel_checkpoints{hparams_MyModel['max_epochs']}.ckpt"
    )

# 🌞 **Test Phase**

In [None]:
# Evaluate on the test split with the current in-memory weights.
result2 = trainer.test(MyModel, datamodule=DataModule)

# Dump the generations/predictions the model collected (presumably during test_step).
for attribute in ('generated_summaries', 'generated_tags', 'predicted_labels'):
    print(getattr(MyModel, attribute))