## Text-To-Text Transfer Transformer (T5)

In [None]:
import json
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from torch. utils.data import Dataset, DataLoader
# PyTorch Lightning is built on top of ordinary (vanilla) PyTorch. Its purpose is to
# provide a research framework for fast experimentation and scalability, which it
# achieves through an object-oriented design that removes boilerplate and
# hardware-specific code.
from pytorch_lightning import Trainer
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from termcolor import colored
import textwrap

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer
)
from tqdm.auto import tqdm

In [None]:
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

# Render matplotlib figures inline in the notebook, at retina resolution.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Global plot styling: seaborn white-grid theme and a 16x10 default figure size.
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
rcParams['figure.figsize'] = 16, 10

In [None]:
# Fix the global random seed so later random steps (e.g. the train/test split)
# repeat across runs.
pl.seed_everything(seed=42)

In [None]:
# Data from Kaggle news_summary.csv
# Data from Kaggle news_summary.csv, decoded as latin-1.
news_csv = Path("data") / "news_summary.csv"
df = pd.read_csv(news_csv, encoding="latin-1")
df.head()

In [None]:
# Keep only the two columns used below: "text" (renamed to the summary next)
# and "ctext" (renamed to the article text next).
df = df.loc[:, ["text", "ctext"]]
df.head()

In [None]:
# "text" holds the short summary and "ctext" the full article; rename them
# accordingly (the mapping is applied simultaneously), then drop rows with a
# missing value in either column.
df = df.rename(columns={"text": "summary", "ctext": "text"}).dropna()
df.head()

In [None]:
# (rows, columns) remaining after dropping incomplete rows.
df.shape

In [None]:
# Hold out 10% of the rows; this split also serves as the validation set
# (NewsDataModule.val_dataloader reuses the test dataset).
train_df, test_df = train_test_split(df, test_size=0.1)
print(
    f"Shape of the Train Set: {train_df.shape}\n"
    f"Shape of the Test Set: {test_df.shape}"
)

In [None]:
class NewsDataset(Dataset):
    """Map-style dataset of (news article, summary) pairs tokenized for T5.

    Every item is tokenized on the fly and padded/truncated to a fixed length,
    so the default DataLoader collate function can stack items into batches.
    """

    def __init__(self, data, tokenizer, text_max_token_len=512, summary_max_token_len=128):
        """
        A dataset that represents news articles and their respective summaries.

        Args:
        - data (pd.DataFrame): Rows with a "text" (article) and a "summary" column.
        - tokenizer (transformers.tokenization_*): Tokenizer used for both the text and the summary.
        - text_max_token_len (int, optional): Token length the text is padded/truncated to. Defaults to 512.
        - summary_max_token_len (int, optional): Token length the summary is padded/truncated to. Defaults to 128.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        """
        Returns:
        - The number of samples in the dataset.
        """
        return len(self.data)

    def _encode(self, value, max_length):
        """Tokenize one string, padded/truncated to exactly `max_length` tokens (batch of 1)."""
        return self.tokenizer(
            value,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def __getitem__(self, index):
        """
        Get a sample from the dataset.

        Args:
        - index (int): Positional index of the sample (uses DataFrame.iloc).

        Returns:
        - A dictionary that contains the following:
            - text (str): The original text of the news article.
            - summary (str): The summary of the news article.
            - text_input_ids (torch.Tensor): 1-D input IDs of the text.
            - text_attention_mask (torch.Tensor): 1-D attention mask of the text.
            - labels (torch.Tensor): 1-D input IDs of the summary, with padding replaced by -100.
            - labels_attention_mask (torch.Tensor): 1-D attention mask of the summary.
        """
        data_row = self.data.iloc[index]
        text = data_row["text"]
        summary = data_row["summary"]

        text_encoding = self._encode(text, self.text_max_token_len)
        summary_encoding = self._encode(summary, self.summary_max_token_len)

        # Replace padding positions in the labels with -100, the label value the
        # Hugging Face seq2seq loss ignores. Use the tokenizer's own pad id
        # instead of hard-coding 0 (0 is only T5's pad id); fall back to 0 when
        # the tokenizer does not expose one.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = 0
        labels = summary_encoding["input_ids"].clone()
        labels[labels == pad_token_id] = -100

        return {
            'text': text,
            'summary': summary,
            'text_input_ids': text_encoding['input_ids'].flatten(),
            'text_attention_mask': text_encoding['attention_mask'].flatten(),
            'labels': labels.flatten(),
            'labels_attention_mask': summary_encoding["attention_mask"].flatten()
        }

In [None]:
class NewsDataModule(pl.LightningDataModule):
    """LightningDataModule that wraps the train/test DataFrames as NewsDatasets.

    The held-out test split is served for both validation and testing.

    Args:
    - train_df (pandas.DataFrame): rows used for training
    - test_df (pandas.DataFrame): rows used for validation and testing
    - tokenizer (transformers.PreTrainedTokenizer): tokenizer shared by both splits
    - batch_size (int): samples per batch for every DataLoader
    - text_max_token_len (int): token length articles are padded/truncated to.
      NOTE(review): the 152 default differs from NewsDataset's 512 default and
      from the 512 used by encode_text at inference; possibly a typo — confirm.
    - summary_max_token_len (int): token length summaries are padded/truncated to
    """

    def __init__(self,
                 train_df,
                 test_df,
                 tokenizer,
                 batch_size=16,
                 text_max_token_len=152,
                 summary_max_token_len=128):
        super().__init__()
        self.train_df, self.test_df = train_df, test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def _build_dataset(self, frame):
        """Wrap one DataFrame in a NewsDataset using this module's tokenizer and lengths."""
        return NewsDataset(
            frame,
            self.tokenizer,
            self.text_max_token_len,
            self.summary_max_token_len,
        )

    def _build_loader(self, dataset, shuffle):
        """Create a DataLoader over `dataset` with this module's batch size."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def setup(self, stage=None):
        """Build the train and test datasets (the same for every stage)."""
        self.train_dataset = self._build_dataset(self.train_df)
        self.test_dataset = self._build_dataset(self.test_df)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def test_dataloader(self):
        """Ordered loader over the test split."""
        return self._build_loader(self.test_dataset, shuffle=False)

    def val_dataloader(self):
        """Ordered loader over the test split, reused as the validation set."""
        return self._build_loader(self.test_dataset, shuffle=False)

In [None]:
# Pretrained checkpoint name, shared by the tokenizer here and SummaryModel.
MODEL_NAME = "t5-base"
# `T5Tokenizer` in this notebook is an alias for T5TokenizerFast (see imports).
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Token count of every training article and summary (special tokens included,
# no truncation), used to pick sensible max token lengths. Iterating the
# columns directly avoids DataFrame.iterrows(), which builds a Series for every
# row and was run over the whole frame twice.
text_token_counts = [len(tokenizer.encode(article)) for article in train_df["text"]]
summary_token_counts = [len(tokenizer.encode(summary)) for summary in train_df["summary"]]

In [None]:
# Side-by-side histograms of article vs. summary token counts.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
histogram_panels = [
    (ax1, text_token_counts, 'blue', "Distribution of Text Token Counts"),
    (ax2, summary_token_counts, 'green', "Distribution of Summary Token Counts"),
]
for panel_ax, counts, colour, panel_title in histogram_panels:
    sns.histplot(counts, ax=panel_ax, color=colour, alpha=0.7)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.set_xlabel("Number of Tokens", fontsize=12)
    panel_ax.set_ylabel("Frequency", fontsize=12)
    panel_ax.grid(axis='y', alpha=0.5)

plt.suptitle("Token Count Distributions", fontsize=16, fontweight='bold')
# Reserve space at the top of the figure for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

In [None]:
N_EPOCHS = 3
BATCH_SIZE = 16
# Articles are encoded to 512 tokens at inference time (encode_text) and
# NewsDataset defaults to 512 as well. Train on the same length rather than
# NewsDataModule's 152-token default, which truncates articles far more
# aggressively than the model will see them at inference.
TEXT_MAX_TOKEN_LEN = 512

data_module = NewsDataModule(
    train_df,
    test_df,
    tokenizer,
    batch_size=BATCH_SIZE,
    text_max_token_len=TEXT_MAX_TOKEN_LEN,
)

In [None]:
class SummaryModel(pl.LightningModule):
    """Lightning wrapper around a pretrained T5 model for abstractive summarization.

    The training, validation and test steps all compute the model's built-in
    seq2seq loss on a NewsDataset batch and log it as ``{stage}_loss``
    (e.g. ``val_loss``).
    """

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run T5 and return ``(loss, logits)``.

        ``loss`` is only computed by the model when ``labels`` are given.
        """
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        )
        return output.loss, output.logits

    def shared_step(self, batch, batch_idx, stage):
        """Compute the loss for one batch and log it as ``f"{stage}_loss"``."""
        loss, _ = self(
            input_ids=batch['text_input_ids'],
            attention_mask=batch["text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"]
        )

        self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, batch_idx, 'test')

    def configure_optimizers(self):
        """AdamW over all parameters with lr=1e-4.

        ``transformers.AdamW`` is deprecated in favour of ``torch.optim.AdamW``
        and dropped from recent releases. The two have different defaults
        (transformers: eps=1e-6, weight_decay=0.0; torch: eps=1e-8,
        weight_decay=0.01), so the previous values are passed explicitly to
        keep the same hyper-parameters.
        """
        return torch.optim.AdamW(self.parameters(), lr=0.0001, eps=1e-6, weight_decay=0.0)

In [None]:
# Fresh model initialised from the pretrained MODEL_NAME weights.
model_1 = SummaryModel()

In [None]:
%load_ext tensorboard
%tensorboard --logdir lighting_logs/

In [None]:
# Keep only the single best checkpoint, judged by the lowest validation loss.
callbacks = ModelCheckpoint(
    dirpath="checkpoints",
    filename="base-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min",
)

# TensorBoard logs go under lightning_logs/news_summary/.
logger = TensorBoardLogger("lightning_logs", name="news_summary")

trainer = Trainer(
    logger=logger,
    callbacks=[callbacks],
    max_epochs=N_EPOCHS,
    # CPU training; use accelerator="gpu", devices=1 to train on a GPU.
    accelerator="cpu",
)

In [None]:
# Train model_1; validation uses the data module's val_dataloader (the test split).
trainer.fit(model_1, data_module)

In [None]:
# Reload the checkpoint with the lowest val_loss and freeze it for inference.
best_checkpoint_path = trainer.checkpoint_callback.best_model_path
best_model = SummaryModel.load_from_checkpoint(best_checkpoint_path)
best_model.freeze()

In [None]:
import pickle

# Persist the underlying Hugging Face model, then read it back into the global
# `model` used by the inference helpers. Both files are opened with `with`, so
# the write handle is closed (and its buffer flushed) before the file is read
# back, and the read handle is not leaked.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
MODEL_PICKLE_PATH = 'text_summarization_model.pkl'
with open(MODEL_PICKLE_PATH, 'wb') as model_file:
    pickle.dump(best_model.model, model_file)
with open(MODEL_PICKLE_PATH, 'rb') as model_file:
    model = pickle.load(model_file)

In [None]:
def encode_text(text, max_length=512):
    """Tokenize one article for generation.

    Args:
    - text (str): The article to summarize.
    - max_length (int, optional): Token length the article is padded/truncated
      to. Defaults to 512, matching NewsDataset's default text_max_token_len.

    Returns:
    - (input_ids, attention_mask): tensors with a batch dimension of 1.
    """
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    return encoding["input_ids"], encoding["attention_mask"]

def generate_summary(input_ids, attention_mask):
    """Run beam-search generation with the global `model` and return the generated token ids."""
    generation_settings = dict(
        max_length=150,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )
    return model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_settings,
    )

def decode_summary(generated_ids):
    """Decode every generated sequence to text and concatenate them with no separator."""
    decoded_texts = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return "".join(decoded_texts)

def summarize(text):
    """Summarize one article: tokenize it, generate with beam search, decode to a string."""
    return decode_summary(generate_summary(*encode_text(text)))

In [None]:
model_summary = summarize(text)

In [None]:
text = """Delhi Capitals’ head coach Ricky Ponting during a press conference in Delhi on Friday. | Photo Credit: PTI

Ricky Ponting knows a thing or two about cricket and spotlight and how together, the two can either be a recipe for unprecedented success or unmitigated disaster, depending on how one handles them.

In India, in particular, the pressure to manage both is a lot more than anywhere else and the IPL is at the pinnacle of fan attention. “Well it is a lot different in our country than it is here. The big thing about the IPL is seeing so many younger players getting an opportunity that they are not ready for. And I don’t mean the sport per se. They are ready for the cricket side of it but there are a lot of guys not ready, yet, for the many other things that come with cricket. There wasn’t as much spotlight on me back as a young player as on some of the young Indian players today,” Ponting admitted.

And, being the coach and legend that he is, he accepts the responsibility to try and guide them. “For me, it’s letting players understand how big what they are doing actually is, in the public’s eyes. As a player you want to play cricket, you want to represent your team and franchise and country, but sometimes you can’t see the bigger picture behind it than just you playing cricket. It’s also about how everyone sees you in the real world and the IPL, for a lot of these youngsters, is not the real world. There’s a lot of other stuff happening out there,” he cautioned.


Ponting predicts this year’s IPL will see the real Prithvi Shaw
His advice? Get your act together outside the field so you can perform inside. “My job is to make them better players but, at the end of the day, I want them to be better people. I think the better you are as a person, the easier it is to be a better player and if you haven’t got your life in order off the field, it’s really difficult to be a disciplined performer on it. That’s one of the things I try to teach because I have been there, done that,” the 48-year-old World Cup-winning captain explained.

And in a World Cup year, who better than a two-time winning captain to ask about the constant hype around Indian performers in IPL? “Ideally we would want them all to have that drive and passion to be the best they can be but the one thing I always stress with young guys is, not to start looking too far ahead and thinking about the World Cup. They need to stay in the present and think about the here and now and play their role in the team. My job is to train and get these guys ready to win games for us but after that, their selections for any other event or format are not in my hands,” he shrugged."""

In [None]:
# The custom article has no human-written summary to compare against, so take
# a held-out test row and show its reference summary next to the model output.
sample_row = test_df.iloc[0]
print("Reference summary:", sample_row["summary"])
print("Generated summary:", summarize(sample_row["text"]))

In [None]:
# Display the model's summary of the custom article.
model_summary