<a href="https://colab.research.google.com/github/ambideXtrous9/Seq2Seq-Attention-QA/blob/main/Seq2SeqAttentionQATeacher-Forcing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install --quiet tokenizers

In [2]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
torch.set_float32_matmul_precision('high')
import random
import numpy as np
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer as Tokenizer,AutoModelForSeq2SeqLM
from torch.optim import AdamW

In [3]:
pl.seed_everything (42)

INFO:lightning_fabric.utilities.seed:Global seed set to 42


42

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
MODEL_NAME = 'google/flan-t5-base'

In [6]:
tokenizer = Tokenizer.from_pretrained(MODEL_NAME)

In [7]:
path = '/content/drive/MyDrive/MTP CODE/NewsQA_SPAN.feather'

In [8]:
df = pd.read_feather(path)

In [9]:
df = df.iloc[:5000]

In [10]:
train_df, val_df = train_test_split(df,test_size=0.2)
val_df, test_df = train_test_split(val_df,test_size=0.5)

In [11]:
class QADataset(Dataset):
  def __init__(self,data : pd.DataFrame,tokenizer : Tokenizer,source_max_token_len : int = 200,
               target_max_token_len : int = 20):

    self.tokenizer = tokenizer
    self.data = data
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def __len__(self):
    return len(self.data)
  
  def __getitem__(self,index : int):
    data_row = self.data.iloc[index]

    source_encoding = tokenizer(
        data_row['question'],
        data_row['paragraph'],
        max_length = self.source_max_token_len,
        padding = "max_length",
        truncation = "only_second",
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt")
    
    target_encoding = tokenizer(
        data_row['answer'],
        max_length = self.target_max_token_len,
        padding = "max_length",
        truncation = True,
        return_attention_mask = True,
        add_special_tokens = True,
        return_tensors = "pt")
    
    labels = target_encoding["input_ids"]

    return dict(
        answer = data_row['answer'],
        input_ids = source_encoding['input_ids'].flatten(),
        attention_mask = source_encoding['attention_mask'].flatten(),
        labels = labels.flatten())


In [12]:
sample_dataset = QADataset(df,tokenizer)

In [13]:
tokenizer.vocab_size

32100

In [14]:
for data in sample_dataset:
  print(type(data['input_ids']))
  print(data['input_ids'])
  print(data['labels'])
  break

<class 'torch.Tensor'>
tensor([ 2645,    19,     8,  5037,  2090,    13,  8951,    49,   397,    15,
         5826,    58,     1,    96, 13898,   127,     7,    54,   169,     3,
            9,  2711,    13,   789, 13237,    11,   731,  8225, 14410,  2814,
         2731,    12,   918,     3,     9, 15812,    21,    70,  4833,   976,
          845,  1813,   157,  2375, 10729,   138,     6,  5037,  2090,     6,
         8951,    49,   397,    15,  5826,     5,  5421, 13015,     7,   243,
         4367,   228,  2153,   710,    18,  1987,     6,  2768,    18,  1987,
           11,   307,    18,  1987,   789, 13237,    28,   128,  8543,  3069,
          494,     5,     3, 17229,     6,     3,     9,   386,    18,  7393,
        20792,  3259,   133,   428,  4367,     3,     9,  6339,    13,  1877,
         2712,     6,     3,     9,    80,    18,  1201,   332,  2876,   133,
         6339,     3, 19708,  4704,    11,     3,     9,  9445,  1201,   789,
         1034,  6339,     7,     3, 27865

In [15]:
class QADataModule(pl.LightningDataModule):
  def __init__(self,train_df , val_df, test_df,tokenizer : Tokenizer,batch_size : int = 8,
               source_max_token_len : int = 200,target_max_token_len : int = 20):
    super().__init__()
    self.batch_size = batch_size
    self.train_df = train_df
    self.test_df = test_df
    self.val_df = val_df
    self.tokenizer = tokenizer
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def setup(self,stage=None):
    self.train_dataset = QADataset(self.train_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)
    self.val_dataset = QADataset(self.val_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)
    self.test_dataset = QADataset(self.test_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)
    

  def train_dataloader(self):
    return DataLoader(self.train_dataset,batch_size = self.batch_size,shuffle=True,num_workers=2)

  def val_dataloader(self):
    return DataLoader(self.val_dataset,batch_size = self.batch_size,num_workers=2)

  def test_dataloader(self):
    return DataLoader(self.test_dataset,batch_size = self.batch_size,num_workers=2)   

In [16]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)

    def forward(self, x):
        embedded = self.dropout(self.embedding(x))
        outputs, hidden = self.gru(embedded)
        return outputs, hidden

In [17]:
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden shape: (batch_size, hidden_size)
        # encoder_outputs shape: (batch_size, seq_len, hidden_size)
        
        # Calculate attention energies
        seq_len = encoder_outputs.size(1)
        hidden_expanded = hidden.unsqueeze(1).repeat(1, seq_len, 1)
        attn_inputs = torch.cat((hidden_expanded, encoder_outputs), dim=2)
        attn_energies = self.attn(attn_inputs)
        attn_energies = torch.tanh(attn_energies)
        
        # Calculate attention weights and context vector
        attn_weights = torch.softmax(self.v(attn_energies), dim=1)
        context = torch.bmm(attn_weights.transpose(1, 2), encoder_outputs)
        
        # Return context vector
        return context


In [18]:
class Decoder(nn.Module):
    def __init__(self, output_size, hidden_size, num_layers, dropout):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size*2, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, output_size)
        self.attention = Attention(hidden_size)

    def forward(self, x, hidden, encoder_outputs):
        embedded = self.dropout(self.embedding(x))
        context = self.attention(hidden[-1], encoder_outputs)
        rnn_input = torch.cat((embedded, context), dim=2)
        output, hidden = self.gru(rnn_input, hidden)
        output = self.fc_out(output)
        return output, hidden

In [19]:
class Seq2Seq(pl.LightningModule):
    def __init__(self, encoder, decoder, pad_idx):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
        

    def forward(self, src, attn,trg,teacher_forcing_ratio = 0.5):
        batch_size = src.size(0)
        max_len = trg.size(1)
        trg_vocab_size = self.decoder.fc_out.out_features

        encoder_outputs, hidden = self.encoder(src)


        teacher_forcing = True if random.random() < teacher_forcing_ratio else False



        outputs = torch.zeros(batch_size, max_len, trg_vocab_size).to(self.device)
        output = trg[:, 0]

        if teacher_forcing : 
           for t in range(1, max_len):
                
                output, hidden = self.decoder(output.unsqueeze(1), hidden, encoder_outputs)
                outputs[:, t, :] = output.squeeze(1)
                output = trg[:, t]
        
        else:
            
            for t in range(1, max_len):
                output, hidden = self.decoder(output.unsqueeze(1), hidden, encoder_outputs)
                outputs[:, t, :] = output.squeeze(1)
                top1 = output.argmax(2)
                output = top1.squeeze(1)
          
        
        return outputs, hidden

    def training_step(self, batch, batch_idx):
        
        src = batch['input_ids']
        attn = batch['attention_mask']
        trg = batch['labels']

        trg_input = trg
        trg_output = trg

        output, hidden = self(src, attn,trg_input)


        output = output.reshape(-1, output.shape[-1])
        trg_output = trg_output.reshape(-1)

        train_loss = self.loss_fn(output, trg_output)

        self.log_dict({"train_loss" : train_loss,
                       },prog_bar=True,logger=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        src = batch['input_ids']
        attn = batch['attention_mask']
        trg = batch['labels']

        trg_input = trg
        trg_output = trg

        output, hidden = self(src, attn,trg_input)


        

        output = output.reshape(-1, output.shape[-1])
        trg_output = trg_output.reshape(-1)

        val_loss = self.loss_fn(output, trg_output)

        self.log_dict({"val_loss" : val_loss
                       },prog_bar=True,logger=True)


        return val_loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


In [20]:
BATCH_SIZE = 64
N_EPOCHS = 50

data_module = QADataModule(train_df,val_df,test_df,tokenizer,batch_size = BATCH_SIZE)
data_module.setup()

In [21]:
checkpoint_callback = ModelCheckpoint(
    dirpath = 'checkpoints',
    filename = 'Seq2Seq',
    save_top_k = 1,
    verbose = True,
    monitor = 'val_loss',
    mode = 'min'
)

In [22]:
encoder = Encoder(input_size=tokenizer.vocab_size, hidden_size=512, num_layers=2, dropout=0.2)

In [23]:
decoder = Decoder(output_size=tokenizer.vocab_size, hidden_size=512, num_layers=2, dropout=0.2)

In [24]:
model = Seq2Seq(encoder,decoder,pad_idx=0)

In [25]:
trainer = pl.Trainer(devices=-1, accelerator="gpu",
    callbacks=[checkpoint_callback],
    max_epochs = N_EPOCHS
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [26]:
trainer.fit(model,data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Encoder          | 19.6 M
1 | decoder | Decoder          | 37.4 M
2 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
57.0 M    Trainable params
0         Non-trainable params
57.0 M    Total params
227.813   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 63: 'val_loss' reached 6.48547 (best 6.48547), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 126: 'val_loss' reached 6.22582 (best 6.22582), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 2, global step 189: 'val_loss' reached 5.89201 (best 5.89201), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 3, global step 252: 'val_loss' reached 5.54968 (best 5.54968), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 4, global step 315: 'val_loss' reached 5.13121 (best 5.13121), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 5, global step 378: 'val_loss' reached 4.70981 (best 4.70981), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 6, global step 441: 'val_loss' reached 4.35354 (best 4.35354), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 7, global step 504: 'val_loss' reached 4.14572 (best 4.14572), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 8, global step 567: 'val_loss' reached 3.92973 (best 3.92973), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 9, global step 630: 'val_loss' reached 3.87446 (best 3.87446), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 10, global step 693: 'val_loss' reached 3.86925 (best 3.86925), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 11, global step 756: 'val_loss' reached 3.80586 (best 3.80586), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 12, global step 819: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 13, global step 882: 'val_loss' reached 3.68031 (best 3.68031), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 14, global step 945: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 15, global step 1008: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 16, global step 1071: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 17, global step 1134: 'val_loss' reached 3.66159 (best 3.66159), saving model to '/content/checkpoints/Seq2Seq.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 18, global step 1197: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 19, global step 1260: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 20, global step 1323: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 21, global step 1386: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 22, global step 1449: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 23, global step 1512: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 24, global step 1575: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 25, global step 1638: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 26, global step 1701: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 27, global step 1764: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 28, global step 1827: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 29, global step 1890: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 30, global step 1953: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [28]:
cppath = '/content/checkpoints/Seq2Seq.ckpt'

In [29]:
trained_model = Seq2Seq.load_from_checkpoint(cppath,encoder=encoder,decoder=decoder,pad_idx=0)
trained_model.freeze()

In [30]:
def predict(question,model):
    
    model.eval()


    ques = question['question'],
    ans = question['answer']

    print("QUESTION : ",ques)
    print("Actual Ans : ",ans)



    # Tokenize the source text
    source_tokens = tokenizer(
        question['question'],
        question['paragraph'],
        max_length=200,
        padding="max_length",
        truncation="only_second",
        add_special_tokens=True,
        return_tensors="pt")['input_ids'].flatten().to(device)

    # Reshape the source tokens to match the expected input shape of the encoder
    source_tokens = source_tokens.unsqueeze(0).to(device)

    model.to(device)

    # Encode the source text
    with torch.no_grad():
        encoder_outputs, hidden = model.encoder(source_tokens)

    max_length = 20
    # Initialize the predicted sentence
    outputs = [0]

    # Generate the output sequence token by token
    for _ in range(max_length):
        previous_word = torch.LongTensor([outputs[-1]]).unsqueeze(0).to(device)

        # Decode the next token
        with torch.no_grad():
            output, hidden = model.decoder(previous_word, hidden, encoder_outputs)
            best_guess = output.argmax(2).item()

        # Add the predicted token to the predicted sentence
        outputs.append(best_guess)

        # If the predicted token is the end-of-sequence token, stop generating further tokens
        if best_guess == tokenizer.sep_token_id:
            break

    # Convert the predicted sentence back to text
    
    print(outputs)
    predicted_text = tokenizer.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    return predicted_text


In [41]:
sample_question = test_df.iloc[11]
predict(sample_question,trained_model)

QUESTION :  ('How much did Mudrick Capital Management buy new shares of AMC in June?',)
Actual Ans :  about $230 million
[0, 3771, 1458, 770, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


'$230 million'