<a href="https://colab.research.google.com/github/Karthick47v2/question-generator/blob/main/model_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


### Install third-party libraries


In [None]:
!pip3 install pytorch-lightning == 1.7.0
# newwer version not works with FastT5 (for ONNX conversion)
!pip3 install transformers == 4.1.1
!pip3 install tokenizers == 0.9.4
!pip3 install sentencepiece == 0.1.94


### Import libraries

> You **_may_** need to restart the runtime after installing the Python packages (for example, if importing `pytorch_lightning` throws an error).


In [None]:
import os
import pandas as pd
import torch
import pytorch_lightning as pl

from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, T5ForConditionalGeneration, T5Tokenizer
from sklearn.model_selection import train_test_split

# Seed Python, NumPy and torch. workers=True also derives per-worker seeds for the
# multi-process DataLoaders used below (num_workers > 0), so augmentation/shuffling
# inside workers is reproducible as well.
SEED = 42
pl.seed_everything(SEED, workers=True)


### Load and split dataset


In [None]:
# Mount Google Drive; the preprocessed dataset CSVs are read from it below.
from google.colab import drive

MOUNT_POINT = '/content/gdrive'
drive.mount(MOUNT_POINT)


In [4]:
dataset = 'sciq'  # 'squad' or 'sciq'

# Dataset key -> file-name prefix of its preprocessed CSV on the mounted Drive.
DATASET_PREFIXES = {'squad': 'SQuAD', 'sciq': 'SciQ'}
if dataset not in DATASET_PREFIXES:
    # Fail loudly on a typo instead of silently loading the SciQ file.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of {sorted(DATASET_PREFIXES)}")

# Absolute path under the '/content/gdrive' mount point, so loading does not depend
# on the current working directory.
DATA_DIR = '/content/gdrive/MyDrive/mcq-gen'
df = pd.read_csv(f"{DATA_DIR}/{DATASET_PREFIXES[dataset]}-processed.csv")


In [None]:
# Hold out 10% of the rows, then split that hold-out 60/40 into validation and test
# sets (roughly 90% / 6% / 4% overall). Explicit random_state values make this cell
# produce the same split on every re-run, independent of the global RNG state, so
# re-running it cannot leak training rows into validation/test.
train_df, holdout_df = train_test_split(df, test_size=0.1, shuffle=True, random_state=42)
validation_df, test_df = train_test_split(holdout_df, test_size=0.4, shuffle=True,
                                          random_state=42)
train_df.shape, validation_df.shape, test_df.shape


### Load the base tokenizer


In [None]:
BASE_CHECKPOINT = 't5-base'  # same checkpoint name QAModel loads its weights from
t5_tokenizer = T5Tokenizer.from_pretrained(BASE_CHECKPOINT)


### Dataset code


In [7]:
class QADataset(Dataset):
    """Pre-tokenized (context, answer) -> question examples.

    Every row of ``data`` is tokenized once, in ``__init__``:

    * source: the ``(context, answer)`` pair, padded to ``max_in_len``; only the
      context is truncated;
    * target: the ``question``, padded/truncated to ``max_out_len``.

    Args:
        tokenizer: callable tokenizer (e.g. ``T5Tokenizer``) returning a mapping
            with ``input_ids`` and ``attention_mask`` tensors.
        data: DataFrame with ``context``, ``answer`` and ``question`` columns;
            items are addressed positionally (``iloc``), so any index works.
        max_out_len: max token length of the target question.
        max_in_len: max token length of the source pair (default 512).
    """

    def __init__(self, tokenizer, data, max_out_len, max_in_len=512):
        self.data = data
        self.max_in_len = max_in_len
        self.max_out_len = max_out_len
        self.tokenizer = tokenizer
        self.inputs = []
        self.targets = []
        self.__tokenize()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the raw texts plus flattened input/label tensors for item ``index``."""
        row = self.data.iloc[index]
        source = self.inputs[index]
        target = self.targets[index]

        # Padding positions in the labels become -100 so the loss ignores them.
        # masked_fill returns a new tensor, leaving the cached encoding untouched
        # (the previous in-place assignment mutated self.targets).
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0  # fallback: the pad id this class used to hard-code
        target_ids = target["input_ids"]
        labels = target_ids.masked_fill(target_ids == pad_id, -100)

        return {'context': row['context'],
                'answer': row['answer'],
                'question': row['question'],
                'input_ids': source["input_ids"].flatten(),
                'attention_mask': source["attention_mask"].flatten(),
                'labels': labels.flatten(),
                'labels_attention_mask': target["attention_mask"].flatten()
                }

    def __tokenize(self):
        """Encode every row once and cache the results in ``self.inputs`` / ``self.targets``."""
        for _, row in self.data.iterrows():
            context, answer, question = row['context'], row['answer'], row['question']

            source_encoding = self.tokenizer(
                context, answer,
                max_length=self.max_in_len,
                padding='max_length',
                truncation='only_first',  # truncate the context, keep the answer intact
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors='pt'
            )

            target_encoding = self.tokenizer(
                question,
                max_length=self.max_out_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors='pt'
            )

            self.inputs.append(source_encoding)
            self.targets.append(target_encoding)


In [8]:
class QADataModule(pl.LightningDataModule):
    """Builds ``QADataset`` train/validation/test splits and their DataLoaders.

    Args:
        train_df, validation_df, test_df: DataFrames with ``context``, ``answer``
            and ``question`` columns, passed straight to ``QADataset``.
        tokenizer: tokenizer forwarded to ``QADataset``.
        batch_size: training batch size.
        max_out_len: max token length of the target question.
        max_in_len: max token length of the (context, answer) input (default 512).
        eval_batch_size: batch size of the validation/test loaders (default 2,
            the value previously hard-coded).
        num_workers: DataLoader worker processes; ``None`` (default) means
            ``os.cpu_count()``, or 0 when that count is unavailable.
    """

    def __init__(self, train_df, validation_df, test_df, tokenizer, batch_size, max_out_len,
                 max_in_len=512, eval_batch_size=2, num_workers=None):
        super().__init__()
        self.train_df = train_df
        self.validation_df = validation_df
        self.test_df = test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_in_len = max_in_len
        self.max_out_len = max_out_len
        self.eval_batch_size = eval_batch_size
        # os.cpu_count() returns None when the count cannot be determined;
        # DataLoader needs an int.
        self.num_workers = num_workers if num_workers is not None else (os.cpu_count() or 0)
        # Built lazily by setup(); None means "not tokenized yet".
        self.train_dataset = None
        self.validation_dataset = None
        self.test_dataset = None

    def _build(self, df):
        """Tokenize one split into a QADataset."""
        return QADataset(self.tokenizer, df, self.max_out_len, self.max_in_len)

    def setup(self, stage=None):
        """Tokenize every split that has not been built yet (``stage`` is ignored).

        Tokenizing is the expensive part, and Lightning may call ``setup`` again
        after the manual ``setup()`` calls in this notebook, so each split is
        built at most once per instance.
        """
        if self.train_dataset is None:
            self.train_dataset = self._build(self.train_df)
        if self.validation_dataset is None:
            self.validation_dataset = self._build(self.validation_df)
        if self.test_dataset is None:
            self.test_dataset = self._build(self.test_df)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.validation_dataset, batch_size=self.eval_batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.eval_batch_size,
                          num_workers=self.num_workers)


### Model training code


In [9]:
class QAModel(pl.LightningModule):
    """LightningModule wrapping a pretrained ``T5ForConditionalGeneration``.

    In this notebook it is trained on ``QADataset`` batches, mapping a
    (context, answer) input to the question tokens given as ``labels``.

    Args:
        learning_rate: AdamW learning rate, stored as ``self.lr``. It must not be
            ``None`` by the time ``configure_optimizers`` runs.
        model_name: pretrained checkpoint to load (default ``'t5-base'``).
    """

    def __init__(self, learning_rate=None, model_name='t5-base'):
        super().__init__()
        # Store the constructor arguments in checkpoints so that
        # QAModel.load_from_checkpoint can rebuild the module on its own.
        self.save_hyperparameters()
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_name, return_dict=True)
        self.lr = learning_rate

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run the wrapped T5 model and return ``(outputs.loss, outputs.logits)``."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        return outputs.loss, outputs.logits

    def step(self, batch, step):
        """Shared train/val/test step.

        Args:
            batch: mapping with ``input_ids``, ``attention_mask``,
                ``labels_attention_mask`` and ``labels`` tensors.
            step: stage prefix (``'train'``, ``'val'`` or ``'test'``); the loss is
                logged as ``f"{step}_loss"``. The checkpoint callback in this notebook
                monitors ``val_loss``.

        Returns:
            The loss tensor.
        """
        loss, _ = self(batch['input_ids'],
                       batch['attention_mask'],
                       batch['labels_attention_mask'],
                       batch['labels'])

        self.log(f"{step}_loss", loss, prog_bar=True, logger=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self.step(batch, 'test')

    def configure_optimizers(self):
        """AdamW over all parameters at ``self.lr``."""
        if self.lr is None:
            # Fail with a clear message instead of an opaque TypeError inside AdamW.
            raise ValueError(
                "QAModel.learning_rate is None; pass learning_rate (e.g. the LR-finder suggestion).")
        return AdamW(self.parameters(), lr=self.lr, eps=1e-8)


In [10]:
# Training hyperparameters.
BATCH_SIZE, N_EPOCHS = 8, 3
# Upper bound of the LR-finder sweep; also the initial LR of the finder's model.
MAX_LR = 1e-2


### Find a good learning rate (LR finder)


In [11]:
# Max target (question) length in tokens: 48 for SQuAD, 72 for SciQ.
MAX_OUT_LEN = 48 if dataset == 'squad' else 72
data_module = QADataModule(train_df, validation_df, test_df, t5_tokenizer,
                           BATCH_SIZE, MAX_OUT_LEN)
data_module.setup()


In [None]:
# Model/trainer pair used only by the LR finder; training below builds fresh ones.
model = QAModel(learning_rate=MAX_LR)
trainer = pl.Trainer(
    max_epochs=20,
    devices=1,
    accelerator='gpu',
)


In [None]:
# Sweep learning rates up to MAX_LR; the datamodule is passed by keyword for clarity.
lr_finder = trainer.tuner.lr_find(model, datamodule=data_module, max_lr=MAX_LR)


In [None]:
# Loss vs. learning rate, with the suggested LR marked. show=True lets the finder
# display the figure through pyplot; Figure.show() only warns on Jupyter's
# non-GUI inline backend.
fig = lr_finder.plot(suggest=True, show=True)


In [None]:
lr = lr_finder.suggestion()
if lr is None:
    # Otherwise QAModel(learning_rate=None) would only fail later, inside the optimizer.
    raise RuntimeError(
        "The LR finder could not suggest a learning rate; inspect the plot above and set `lr` manually.")
lr


### Train model


In [None]:
import gc
gc.collect()

%load_ext tensorboard
%tensorboard - -logdir ./lightning_logs


In [11]:
# Keep only the checkpoint with the lowest validation loss (logged as 'val_loss').
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    dirpath="checkpoints",
    filename="model-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    verbose=True,
    mode="min",
)

# Name the TensorBoard run after the dataset being trained on (was always 'SciQ-T5').
logger = pl.loggers.TensorBoardLogger(
    'lightning_logs', name=f"{'SQuAD' if dataset == 'squad' else 'SciQ'}-T5")


In [None]:
# Max target (question) length in tokens: 48 for SQuAD, 72 for SciQ.
MAX_OUT_LEN = 48 if dataset == 'squad' else 72
data_module = QADataModule(train_df, validation_df, test_df, t5_tokenizer,
                           BATCH_SIZE, MAX_OUT_LEN)
data_module.setup()

# Fresh model, trained at the learning rate suggested by the LR finder.
model = QAModel(learning_rate=lr)

trainer = pl.Trainer(callbacks=[checkpoint_callback],
                     max_epochs=N_EPOCHS,
                     accelerator='gpu',
                     devices=1,
                     enable_progress_bar=True,
                     logger=logger,
                     precision=32)


In [None]:
# Train; the datamodule supplies the train and validation loaders.
trainer.fit(model=model, datamodule=data_module)
