<a href="https://colab.research.google.com/github/ambideXtrous9/Knowledge-Distillation-using-FlanT5-Teacher-Student-Method/blob/main/Knowledge_Distillation_using_T5_Online_Teacher.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install --quiet lightning

In [2]:
import math
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchmetrics.functional import accuracy
from torchmetrics.classification import Accuracy
from torch.optim import AdamW
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM)
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.callbacks import RichProgressBar,ModelCheckpoint
import torch.nn.functional as F

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
path = '/content/drive/MyDrive/MTP CODE/NewsQA_SPAN.feather'


In [5]:
df = pd.read_feather(path)
df.head(3)

Unnamed: 0,question,answer,ans_pos,paragraph,answer_start,answer_end
0,Who is the managing director of Synergee Capital?,Vikram Dalal,"[133, 145]","""Investors can use a combination of governmen...",133,145
1,What is the yield of 30- and 40-year governmen...,7%,"[565, 567]","""Investors can use a combination of governmen...",565,567
2,What is the name of the ETF 2027 that a conser...,SDL,"[209, 212]","According to financial planners, an example o...",209,212


In [6]:
df = df.iloc[:5000]

In [7]:
MODEL_NAME = 'google/flan-t5-small'

In [8]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [9]:
start_token = tokenizer.cls_token
end_token = tokenizer.sep_token

print("Start Token:", start_token)
print("End Token:", end_token)

Start Token: None
End Token: None


In [10]:
vocab_size = tokenizer.vocab_size
print(vocab_size)

32100


In [11]:
class NQADataset(Dataset):
    def __init__(self,data ,tokenizer ,source_max_token_len : int = 180,target_max_token_len : int = 5):

        self.tokenizer = tokenizer
        self.data = data
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self,index : int):
        data_row = self.data.iloc[index]

        source_encoding = self.tokenizer(
            data_row['question'],
            data_row['paragraph'],
            max_length = self.source_max_token_len,
            padding = "max_length",
            truncation = "only_second",
            return_attention_mask = True,
            add_special_tokens = True,
            return_tensors = "pt")

        target_encoding = self.tokenizer(
            data_row['answer'],
            max_length = self.target_max_token_len,
            padding = "max_length",
            truncation = True,
            return_attention_mask = True,
            add_special_tokens = True,
            return_tensors = "pt")

        labels = target_encoding["input_ids"]
        # labels[labels == 0] = -100

        return dict(
            input_ids = source_encoding['input_ids'].flatten(),
            attention_mask = source_encoding['attention_mask'].flatten(),
            labels = labels.flatten())


In [12]:

class NQADataModule(pl.LightningDataModule):
  def __init__(self,train_df,val_df,tokenizer,batch_size : int = 8,source_max_token_len : int = 180,target_max_token_len : int = 5):
    super().__init__()
    self.batch_size = batch_size
    self.train_df = train_df
    self.val_df = val_df
    self.MODEL_NAME = MODEL_NAME
    self.tokenizer = tokenizer
    self.source_max_token_len = source_max_token_len
    self.target_max_token_len = target_max_token_len

  def setup(self,stage=None):
    self.train_dataset = NQADataset(self.train_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)
    self.val_dataset = NQADataset(self.val_df,self.tokenizer,self.source_max_token_len,self.target_max_token_len)


  def train_dataloader(self):
    return DataLoader(self.train_dataset,batch_size = self.batch_size,shuffle=True,num_workers=2)

  def val_dataloader(self):
    return DataLoader(self.val_dataset,batch_size = self.batch_size,num_workers=2)


In [13]:
train_df, val_df = train_test_split(df,test_size=0.1)

In [14]:
print(train_df.shape)
print(val_df.shape)

(4500, 6)
(500, 6)


In [15]:
data_module = NQADataModule(train_df,val_df,tokenizer,batch_size = 40)
data_module.setup()

In [16]:
sample_batch = next(iter(data_module.val_dataloader()))
for key, value in sample_batch.items():
    print(f"{key}: {value.shape}")

input_ids: torch.Size([40, 180])
attention_mask: torch.Size([40, 180])
labels: torch.Size([40, 5])


## Model

![](https://media.arxiv-vanity.com/render-output/3715543/Figures/ModalNet-21.png)

In [17]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [18]:
class Transformer(nn.Module):
    """
    Classic Transformer that both encodes and decodes.
    Prediction-time inference is done greedily.
    NOTE: start token is hard-coded to be 0, end token to be 1. If changing, update predict() accordingly.
    """

    def __init__(self, num_classes: int, max_output_length: int, dim: int = 512):
        super().__init__()

        # Parameters
        self.dim = dim
        self.start_token = '[CLS]'
        self.max_output_length = max_output_length
        self.log_softmax = nn.LogSoftmax()
        nhead = 8
        num_layers = 8
        dim_feedforward = dim

        # Encoder part
        self.embedding = nn.Embedding(num_classes, dim)
        self.pos_encoder = PositionalEncoding(d_model=self.dim)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=self.dim,
                                                     nhead=nhead,
                                                     dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )

        # Decoder part
        self.y_mask = self.generate_square_subsequent_mask(self.max_output_length)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(d_model=self.dim,
                                                     nhead=nhead,
                                                     dim_feedforward=dim_feedforward),
            num_layers=num_layers
        )
        self.fc = nn.Linear(self.dim, num_classes)

        # It is empirically important to initialize weights properly
        self.init_weights()

    def generate_square_subsequent_mask(self,size: int):
        """Generate a triangular (size, size) mask. From PyTorch docs."""
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)

    def forward(self, x: torch.Tensor, y: torch.Tensor) :
        """
        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        encoded_x = self.encode(x)  # (Sx, B, E)
        output = self.decode(y, encoded_x)  # (Sy, B, C)
        return output.permute(1, 2, 0)  # (B, C, Sy)

    def encode(self, x: torch.Tensor) :
        """
        Input => x: (B, Sx) with elements in (0, C) where C is num_classes
        Output => (Sx, B, E) embedding
        """
        x = x.permute(1, 0)  # (Sx, B, E)
        x = self.embedding(x) * math.sqrt(self.dim)  # (Sx, B, E)
        x = self.pos_encoder(x)  # (Sx, B, E)
        x = self.transformer_encoder(x)  # (Sx, B, E)
        return x

    def decode(self, y: torch.Tensor, encoded_x: torch.Tensor):
        """
        Input
            encoded_x: (Sx, B, E)
            y: (B, Sy) with elements in (0, C) where C is num_classes
        Output
            (Sy, B, C) logits
        """
        y = y.permute(1, 0)  # (Sy, B)
        y = self.embedding(y) * math.sqrt(self.dim)  # (Sy, B, E)
        y = self.pos_encoder(y)  # (Sy, B, E)
        Sy = y.shape[0]
        y_mask = self.y_mask[:Sy, :Sy].type_as(encoded_x)  # (Sy, Sy)
        output = self.transformer_decoder(y, encoded_x, y_mask)  # (Sy, B, E)
        output = self.fc(output)  # (Sy, B, C)
        output = self.log_softmax(output)
        return output

    def predict(self, x: torch.Tensor) :
        """
        Method to use at inference time. Predict y from x one token at a time. This method is greedy
        decoding. Beam search can be used instead for a potential accuracy boost.

        Input
            x: (B, Sx) with elements in (0, C) where C is num_classes
        Output
            (B, C, Sy) logits
        """
        encoded_x = self.encode(x)

        output_tokens = (torch.ones((x.shape[0], self.max_output_length))).type_as(x).long() # (B, max_length)
        output_tokens[:, 0] = 0  # Set start token
        for Sy in range(1, self.max_output_length):
            y = output_tokens[:, :Sy]  # (B, Sy)
            output = self.decode(y, encoded_x)  # (Sy, B, C)
            output = torch.argmax(output, dim=-1)  # (Sy, B)
            output_tokens[:, Sy] = output[-1:]  # Set the last output token
        return output_tokens


In [19]:
class LitModel(pl.LightningModule):
    """Simple PyTorch-Lightning model to train our Transformer."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.t5model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME,return_dict=True)
        self.loss = nn.CrossEntropyLoss(ignore_index=0)
        self.criterion = nn.NLLLoss(ignore_index=0)
        self.KLD = nn.KLDivLoss(reduction='batchmean')
        self.softmax = nn.Softmax(dim=-1)

        self.valloss = []
        self.valacc = []

        # for param in self.t5model.parameters():
        #     param.requires_grad = False


    def training_step(self, batch, batch_ind):
        x = batch['input_ids']
        y = batch['labels']
        attention_mask = batch['attention_mask']


        t5encoder = self.t5model(x,attention_mask,decoder_input_ids=y).encoder_last_hidden_state

        encod = self.model.encode(x).permute(1,0,2)

        encod_probs = self.softmax(encod).to("cuda")
        logits_prob = self.softmax(t5encoder).to("cuda")

        kld_loss = self.KLD(encod_probs.log(), logits_prob)

        # Teacher forcing: model gets input up to the last character,
        # while ground truth (loss is calculated) is from the second character onward.
        logits = self.model(x, y[:, :-1])
        ce_loss = self.criterion(logits, y[:, 1:])

        loss = kld_loss + ce_loss

        self.log_dict({"KLD_loss" : kld_loss,
                      "DEC_loss" : ce_loss,
                      "total_loss" : loss},prog_bar=True,logger=True)
        return loss

    def validation_step(self, batch, batch_ind):
        x = batch['input_ids']
        y = batch['labels']
        # Teacher forcing: model gets input up to the last character,
        # while ground truth (loss is calculated) is from the second character onward.
        logits = self.model(x, y[:, :-1])
        loss = self.criterion(logits, y[:, 1:])
        pred = self.model.predict(x)

        correct_predictions = (y == pred).float()
        row_accuracies = torch.mean(correct_predictions, dim=1)
        accuracy = torch.mean(row_accuracies).item()

        self.log_dict({"val_acc" : accuracy,
                       "val_loss" : loss,
                       },on_step=False, on_epoch=True, prog_bar=True)

        self.valloss.append(loss.item())
        self.valacc.append(accuracy)

    def on_validation_epoch_end(self):

        avg_val_loss = torch.mean(torch.tensor(self.valloss))
        avg_val_acc = torch.mean(torch.tensor(self.valacc))

        self.valacc.clear()
        self.valloss.clear()

        self.log_dict({"avg_val_acc" : avg_val_acc,
                       "avg_val_loss" : avg_val_loss,
                       },on_step=False, on_epoch=True, prog_bar=True)

        print('-' * 90)
        print(f'|  Epoch = {self.current_epoch + 1:3d} | '
            f'Average Validation Loss: {avg_val_loss.item():5.2f} | '
            f'Average Validation Accuracy: {avg_val_acc.item():5.2f} |')
        print('-' * 90)



    def configure_optimizers(self):
        return AdamW(self.parameters(),lr = 0.0001)

In [20]:
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger("tb_logs", name="my_model")

In [21]:
checkpoint_callback = ModelCheckpoint(
    dirpath = 'checkpoints',
    filename = 'Transformer_best_cp',
    save_top_k = 1,
    verbose = True,
    monitor = 'avg_val_loss',
    mode = 'min')

In [22]:
rpb = RichProgressBar()

In [23]:
trainer = pl.Trainer(accelerator="gpu",
                     devices=-1,
                     logger=logger,
                     callbacks=[checkpoint_callback,rpb],
                     max_epochs = 20)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [24]:
model = Transformer(num_classes=vocab_size, max_output_length=5)
lit_model = LitModel(model)



In [25]:
trainer.fit(lit_model, data_module)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory /content/checkpoints exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

INFO: Epoch 0, global step 113: 'avg_val_loss' reached 0.39626 (best 0.39626), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 0, global step 113: 'avg_val_loss' reached 0.39626 (best 0.39626), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1


INFO: Epoch 1, global step 226: 'avg_val_loss' reached 0.17291 (best 0.17291), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 1, global step 226: 'avg_val_loss' reached 0.17291 (best 0.17291), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1


INFO: Epoch 2, global step 339: 'avg_val_loss' reached 0.09859 (best 0.09859), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 2, global step 339: 'avg_val_loss' reached 0.09859 (best 0.09859), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1


INFO: Epoch 3, global step 452: 'avg_val_loss' reached 0.07047 (best 0.07047), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 3, global step 452: 'avg_val_loss' reached 0.07047 (best 0.07047), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1


INFO: Epoch 4, global step 565: 'avg_val_loss' reached 0.06312 (best 0.06312), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 4, global step 565: 'avg_val_loss' reached 0.06312 (best 0.06312), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1


INFO: Epoch 5, global step 678: 'avg_val_loss' reached 0.06180 (best 0.06180), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 5, global step 678: 'avg_val_loss' reached 0.06180 (best 0.06180), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1


INFO: Epoch 6, global step 791: 'avg_val_loss' reached 0.06174 (best 0.06174), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1
INFO:lightning.pytorch.utilities.rank_zero:Epoch 6, global step 791: 'avg_val_loss' reached 0.06174 (best 0.06174), saving model to '/content/checkpoints/Transformer_best_cp-v1.ckpt' as top 1


In [30]:
cppath = 'checkpoints/Transformer_best_cp-v1.ckpt'
trained_model = LitModel.load_from_checkpoint(cppath,
                                              model=model)
trained_model.freeze()

In [None]:
val_df.iloc[12]

In [40]:
source_encoding = tokenizer(
            val_df.iloc[12]['question'],
            val_df.iloc[12]['paragraph'],
            max_length = 180,
            padding = "max_length",
            truncation = "only_second",
            return_attention_mask = True,
            add_special_tokens = True,
            return_tensors = "pt")

input_ids = source_encoding['input_ids'].flatten().to("cuda")

In [None]:
# We can see that the decoding works correctly
pred = trained_model.model.predict(input_ids.unsqueeze(0))
print('Pred:')
tk = tokenizer.decode(pred[0],skip_special_tokens=True)
print(tk)