<a href="https://colab.research.google.com/github/ambideXtrous9/600-DSA-Problems/blob/main/Teacher_Student_Flant5_Seq2Seq_with_AttentionQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install --quiet tokenizers

In [215]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
torch.set_float32_matmul_precision('high')
import random
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer as Tokenizer,AutoModelForSeq2SeqLM
from torch.optim import AdamW

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
MODEL_NAME = 'google/flan-t5-small'

In [5]:
tokenizer = Tokenizer.from_pretrained(MODEL_NAME)

In [6]:
path = '/content/drive/MyDrive/MTP CODE/NewsQA_SPAN.feather'

In [7]:
df = pd.read_feather(path)

In [8]:
df = df.iloc[:5000]

In [9]:
train_df, val_df = train_test_split(df,test_size=0.2)
val_df, test_df = train_test_split(val_df,test_size=0.5)

In [216]:
class QADataset(Dataset):
  def __init__(self,data : pd.DataFrame,tokenizer : Tokenizer,source_max_token_len : int = 200,target_max_token_len : int = 15):

    self.tokenizer = tokenizer
    self.data = data
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def __len__(self):
    return len(self.data)
  
  def __getitem__(self,index : int):
    data_row = self.data.iloc[index]

    source_encoding = tokenizer(
        data_row['question'],
        data_row['paragraph'],
        max_length = self.source_max_token_len,
        padding = "max_length",
        truncation = "only_second",
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt")
    
    target_encoding = tokenizer(
        data_row['answer'],
        max_length = self.target_max_token_len,
        padding = "max_length",
        truncation = True,
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt")
    
    labels = target_encoding["input_ids"]
    #labels[labels == 0] = -100

    return dict(
        answer = data_row['answer'],
        input_ids = source_encoding['input_ids'].flatten(),
        attention_mask = source_encoding['attention_mask'].flatten(),
        labels = labels.flatten())


In [217]:
sample_dataset = QADataset(df,tokenizer)

In [218]:
tokenizer.vocab_size

32100

In [219]:
for data in sample_dataset:
  print(type(data['input_ids']))
  print(data['input_ids'])
  break

<class 'torch.Tensor'>
tensor([ 2645,    19,     8,  5037,  2090,    13,  8951,    49,   397,    15,
         5826,    58,     1,    96, 13898,   127,     7,    54,   169,     3,
            9,  2711,    13,   789, 13237,    11,   731,  8225, 14410,  2814,
         2731,    12,   918,     3,     9, 15812,    21,    70,  4833,   976,
          845,  1813,   157,  2375, 10729,   138,     6,  5037,  2090,     6,
         8951,    49,   397,    15,  5826,     5,  5421, 13015,     7,   243,
         4367,   228,  2153,   710,    18,  1987,     6,  2768,    18,  1987,
           11,   307,    18,  1987,   789, 13237,    28,   128,  8543,  3069,
          494,     5,     3, 17229,     6,     3,     9,   386,    18,  7393,
        20792,  3259,   133,   428,  4367,     3,     9,  6339,    13,  1877,
         2712,     6,     3,     9,    80,    18,  1201,   332,  2876,   133,
         6339,     3, 19708,  4704,    11,     3,     9,  9445,  1201,   789,
         1034,  6339,     7,     3, 27865

In [220]:
class QADataModule(pl.LightningDataModule):
  def __init__(self,train_df , val_df, test_df,tokenizer : Tokenizer,batch_size : int = 8,source_max_token_len : int = 200,target_max_token_len : int = 15):
    super().__init__()
    self.batch_size = batch_size
    self.train_df = train_df
    self.test_df = test_df
    self.val_df = val_df
    self.tokenizer = tokenizer
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def setup(self,stage=None):
    self.train_dataset = QADataset(self.train_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)
    self.val_dataset = QADataset(self.val_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)
    self.test_dataset = QADataset(self.test_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)
    

  def train_dataloader(self):
    return DataLoader(self.train_dataset,batch_size = self.batch_size,shuffle=True,num_workers=2)

  def val_dataloader(self):
    return DataLoader(self.val_dataset,batch_size = self.batch_size,num_workers=2)

  def test_dataloader(self):
    return DataLoader(self.test_dataset,batch_size = self.batch_size,num_workers=2)   

In [238]:
class NQAModel(pl.LightningModule):
  def __init__(self):
    super().__init__()

    self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME,return_dict=True)


  def forward(self,input_ids,attention_mask,labels=None):
    output = self.model(
        input_ids = input_ids,
        attention_mask = attention_mask,
        labels = labels)
    
    return output.loss, output.encoder_last_hidden_state, output.logits

  def training_step(self,batch,batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    loss, hidden ,logits = self(input_ids,attention_mask,labels)
    self.log("train_loss",loss,prog_bar=True,logger=True)
    return loss

  def validation_step(self,batch,batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    answer = batch['answer']
    loss, hidden ,logits = self(input_ids,attention_mask,labels)

    
    self.log_dict({"val_loss" : loss,},prog_bar=True,logger=True)

    return loss

  def test_step(self,batch,batch_idx):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']
    answer = batch['answer']
    loss, hidden ,logits = self(input_ids,attention_mask,labels)

   
    self.log_dict({"test_loss" : loss},prog_bar=True,logger=True)
    return loss

  def configure_optimizers(self):
    return AdamW(self.parameters(),lr = 0.0001)

In [239]:
BATCH_SIZE = 64
N_EPOCHS = 1

data_module = QADataModule(train_df,val_df,test_df,tokenizer,batch_size = BATCH_SIZE)
data_module.setup()

In [240]:
model = NQAModel()


In [241]:
checkpoint_callback = ModelCheckpoint(
    dirpath = 'checkpoints',
    filename = 'FlanT5',
    save_top_k = 1,
    verbose = True,
    monitor = 'val_loss',
    mode = 'min'
)

In [242]:
trainer = pl.Trainer(devices=-1, accelerator="gpu",
    callbacks=[checkpoint_callback],
    max_epochs = N_EPOCHS
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [243]:
trainer.fit(model,data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 77.0 M
-----------------------------------------------------
77.0 M    Trainable params
0         Non-trainable params
77.0 M    Total params
307.845   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 63: 'val_loss' reached 3.29579 (best 3.29579), saving model to '/content/checkpoints/FlanT5.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


In [244]:
cppath = '/content/checkpoints/FlanT5.ckpt'
trained_model = NQAModel.load_from_checkpoint(cppath)
trained_model.freeze()

In [246]:
def generate_ans(question):
    trained_model.to(device)
    source_encoding = tokenizer(
        question['question'],
        question['paragraph'],
        max_length = 200,
        padding = "max_length",
        truncation = "only_second",
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt").to(device)
    
    target_encoding = tokenizer(
        question['answer'],
        max_length = 15,
        padding = "max_length",
        truncation = True,
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt")
    
    labels = target_encoding["input_ids"].to(device)
    
    loss, hidden, logits = trained_model(
        input_ids = source_encoding['input_ids'],
        attention_mask = source_encoding['attention_mask'],
        labels = labels)
    
    print(logits.shape)

    print(hidden.shape)
    
    return hidden
    

In [247]:
sample_question = test_df.iloc[1]


In [248]:
generate_ans(sample_question)


torch.Size([1, 15, 32128])
torch.Size([1, 200, 512])


tensor([[[-0.2213, -0.0071,  0.1169,  ...,  0.0460, -0.0676, -0.0669],
         [-0.3960,  0.1469,  0.2473,  ..., -0.0052,  0.1503, -0.0187],
         [-0.1478, -0.1415,  0.1866,  ..., -0.0546, -0.0712, -0.1204],
         ...,
         [ 0.0972, -0.3107,  0.2445,  ..., -0.3174, -0.0139,  0.0035],
         [ 0.0902, -0.3170,  0.2631,  ..., -0.2967, -0.0116,  0.0048],
         [ 0.0892, -0.3240,  0.2671,  ..., -0.2863, -0.0034,  0.0088]]],
       device='cuda:0')

In [281]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)

    def forward(self, x):
        embedded = self.dropout(self.embedding(x))
        outputs, hidden = self.gru(embedded)
        return outputs, hidden




In [282]:
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        max_len = encoder_outputs.size(1)
        h = hidden.repeat(max_len, 1, 1).transpose(0, 1)
        encoder_outputs = encoder_outputs.transpose(1, 2)
        hidden_linear = hidden[-1].unsqueeze(0).unsqueeze(2)  # add extra dim for bmm
        hidden_linear = self.attn(hidden_linear.transpose(1, 2)).transpose(1, 2)  # transpose for correct shape
        attn_energies = self.v(torch.tanh(hidden_linear + encoder_outputs))
        attn_weights = F.softmax(attn_energies, dim=1)
        context = torch.bmm(attn_weights.transpose(1,2), encoder_outputs)
        return context


In [283]:
class AttentionDecoder(nn.Module):
    def __init__(self, output_size, hidden_size, num_layers, dropout):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, output_size)
        self.attention = Attention(hidden_size)

    def forward(self, x, hidden, encoder_outputs):
        embedded = self.dropout(self.embedding(x))
        context, _ = self.attention(hidden, encoder_outputs)
        rnn_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.gru(rnn_input, hidden)
        output = self.fc_out(output)
        return output, hidden


In [284]:
class Decoder(nn.Module):
    def __init__(self, output_size, hidden_size, num_layers, dropout):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, output_size)
        self.attention = Attention(hidden_size)

    def forward(self, x, hidden, encoder_outputs):
        embedded = self.dropout(self.embedding(x))
        #context = self.attention(hidden[-1], encoder_outputs)
        #rnn_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.gru(embedded, hidden)
        output = self.fc_out(output)
        return output, hidden

In [334]:
class Seq2Seq(pl.LightningModule):
    def __init__(self, encoder, decoder, pad_idx):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.trained_model = NQAModel.load_from_checkpoint(cppath)
        self.trained_model.freeze()

    def forward(self, src, attn,trg):
        batch_size = src.size(0)
        max_len = trg.size(1)
        trg_vocab_size = self.decoder.fc_out.out_features

        encoder_outputs, hidden = self.encoder(src)


        with torch.no_grad():
             loss, hidden_t, logits_t = trained_model(input_ids = src,
                                                  attention_mask = attn,
                                                  labels = trg)
                            

        logits_t = logits_t[:, :, :trg_vocab_size]
        hidden_t = hidden_t[:,0,:]



        outputs = torch.zeros(batch_size, max_len, trg_vocab_size).to(self.device)
        output = trg[:, 0]
        for t in range(1, max_len):
            output, hidden = self.decoder(output.unsqueeze(1), hidden, encoder_outputs)
            outputs[:, t, :] = output.squeeze(1)
            top1 = output.argmax(2)
            output = top1.squeeze(1)
        
        return outputs, hidden, hidden_t, logits_t

    def training_step(self, batch, batch_idx):
        
        src = batch['input_ids']
        attn = batch['attention_mask']
        trg = batch['labels']

        trg_input = trg
        trg_output = trg

        output, hidden, hidden_t, logits_t = self(src, attn,trg_input)

        hidden = hidden.squeeze(0)

        
        output_log_probs = F.log_softmax(output, dim=2)
        target_log_probs = F.log_softmax(logits_t, dim=2)

        # Calculate KL divergence loss
        decoder_kl_loss = F.kl_div(output_log_probs, target_log_probs, reduction='batchmean')

        hidden_t_probs = F.log_softmax(hidden_t, dim=1)
        hidden_probs = F.log_softmax(hidden, dim=1)

        # Calculate KL divergence loss
        encoder_kl_loss = F.kl_div(hidden_t_probs, hidden_probs, reduction='batchmean')



        output = output.reshape(-1, output.shape[-1])
        trg_output = trg_output.reshape(-1)

        train_loss = F.cross_entropy(output, trg_output, ignore_index=self.pad_idx)

        self.log_dict({"train_loss" : train_loss,
                       "train_encoder_kl_loss" : encoder_kl_loss,
                       "train_decoder_kl_loss" : decoder_kl_loss,
                       },prog_bar=True,logger=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        src = batch['input_ids']
        attn = batch['attention_mask']
        trg = batch['labels']

        trg_input = trg
        trg_output = trg

        output, hidden, hidden_t, logits_t = self(src, attn,trg_input)


        hidden = hidden.squeeze(0)

        
        output_log_probs = F.log_softmax(output, dim=2)
        target_log_probs = F.log_softmax(logits_t, dim=2)

        # Calculate KL divergence loss
        decoder_kl_loss = F.kl_div(output_log_probs, target_log_probs, reduction='batchmean')

        hidden_t_probs = F.log_softmax(hidden_t, dim=1)
        hidden_probs = F.log_softmax(hidden, dim=1)

        # Calculate KL divergence loss
        encoder_kl_loss = F.kl_div(hidden_t_probs, hidden_probs, reduction='batchmean')

        output = output.reshape(-1, output.shape[-1])
        trg_output = trg_output.reshape(-1)

        val_loss = F.cross_entropy(output, trg_output, ignore_index=self.pad_idx)

        self.log_dict({"val_loss" : val_loss,
                       "val_encoder_kl_loss" : encoder_kl_loss,
                       "val_decoder_kl_loss" : decoder_kl_loss,
                       },prog_bar=True,logger=True)


        return val_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


In [335]:
checkpoint_callback = ModelCheckpoint(
    dirpath = 'checkpoints',
    filename = 'Seq2Seq',
    save_top_k = 1,
    verbose = True,
    monitor = 'val_loss',
    mode = 'min'
)

In [336]:
encoder = Encoder(input_size=tokenizer.vocab_size, hidden_size=512, num_layers=1, dropout=0.2)

In [337]:
decoder = Decoder(output_size=tokenizer.vocab_size, hidden_size=512, num_layers=1, dropout=0.2)

In [338]:
model = Seq2Seq(encoder,decoder,pad_idx=0)

In [341]:
trainer = pl.Trainer(devices=-1, accelerator="gpu",
    callbacks=[checkpoint_callback],
    max_epochs = 20
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [342]:
trainer.fit(model,data_module)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type     | Params
-------------------------------------------
0 | encoder       | Encoder  | 18.0 M
1 | decoder       | Decoder  | 35.0 M
2 | trained_model | NQAModel | 77.0 M
-------------------------------------------
53.0 M    Trainable params
77.0 M    Non-trainable params
129 M     Total params
519.904   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 126: 'val_loss' reached 5.28144 (best 5.28144), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 189: 'val_loss' reached 4.91932 (best 4.91932), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 3, global step 252: 'val_loss' reached 4.66983 (best 4.66983), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 4, global step 315: 'val_loss' reached 4.44454 (best 4.44454), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 5, global step 378: 'val_loss' reached 4.31738 (best 4.31738), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 6, global step 441: 'val_loss' reached 4.18945 (best 4.18945), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 7, global step 504: 'val_loss' reached 4.09658 (best 4.09658), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 8, global step 567: 'val_loss' reached 4.06471 (best 4.06471), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 9, global step 630: 'val_loss' reached 3.99800 (best 3.99800), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 10, global step 693: 'val_loss' reached 3.95112 (best 3.95112), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 11, global step 756: 'val_loss' reached 3.91752 (best 3.91752), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 12, global step 819: 'val_loss' reached 3.89153 (best 3.89153), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 13, global step 882: 'val_loss' reached 3.86978 (best 3.86978), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 14, global step 945: 'val_loss' reached 3.85437 (best 3.85437), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 15, global step 1008: 'val_loss' reached 3.84908 (best 3.84908), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 16, global step 1071: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 17, global step 1134: 'val_loss' reached 3.84771 (best 3.84771), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 18, global step 1197: 'val_loss' reached 3.84756 (best 3.84756), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 19, global step 1260: 'val_loss' reached 3.83673 (best 3.83673), saving model to '/content/checkpoints/Seq2Seq-v1.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


In [343]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [361]:
def predict(question):
    model.eval()

    # Tokenize the source text
    source_tokens = tokenizer(
        question['question'],
        question['paragraph'],
        max_length=200,
        padding="max_length",
        truncation="only_second",
        add_special_tokens=True,
        return_tensors="pt")['input_ids'].flatten().to(device)

    # Reshape the source tokens to match the expected input shape of the encoder
    source_tokens = source_tokens.unsqueeze(0).to(device)

    model.to(device)

    # Encode the source text
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(source_tokens)

    max_length = 15
    # Initialize the predicted sentence
    outputs = [0]

    # Generate the output sequence token by token
    for _ in range(max_length):
        previous_word = torch.LongTensor([outputs[-1]]).unsqueeze(0).to(device)

        # Decode the next token
        with torch.no_grad():
            output, hidden = model.decoder(previous_word, hidden, encoder_outputs)
            best_guess = output.argmax(2).item()

        # Add the predicted token to the predicted sentence
        outputs.append(best_guess)

        # If the predicted token is the end-of-sequence token, stop generating further tokens
        if best_guess == tokenizer.sep_token_id:
            break

    # Convert the predicted sentence back to text
    
    print(outputs)
    predicted_text = tokenizer.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    return predicted_text


In [382]:
sample_question = test_df.iloc[10]


In [383]:
sample_question['question']


'What is another name for Makkal Needhi Maiam?'

In [384]:
sample_question['answer']


'MNM'

In [385]:
predict(sample_question)

[0, 16568, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


'NM'