In [None]:
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install rouge-score

In [None]:
import json
import pandas as pd
import numpy as np 
import torch 
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from termcolor import colored
import textwrap

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer
)
from tqdm.auto import tqdm
from rouge_score import rouge_scorer

In [None]:
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

# Show matplotlib figures inline in the notebook output.
%matplotlib inline
# Ask the inline backend for the 'retina' figure format.
%config InlineBackend.figure_format = 'retina'
# Seaborn defaults: white grid, muted palette, fonts scaled by 1.2.
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
# Default figure size as (width, height).
rcParams['figure.figsize'] = 16, 10

In [None]:
# Fix the random seed through Lightning's helper so repeated runs are comparable.
pl.seed_everything(42)

In [None]:
def clean_data(df):
    """Return a cleaned copy of ``df`` ready for tokenisation.

    Steps, in order:
      1. remove URLs (``http...`` and ``www...`` tokens) from the ``text`` column;
      2. cast every column to ``str`` and drop all non-ASCII characters
         (emojis, smileys, accented letters, ...);
      3. remove the ``#`` and ``@`` symbols from ``text`` while keeping the
         hashtag / mention words themselves.

    Args:
        df: DataFrame with at least a ``text`` column.

    Returns:
        A new DataFrame whose columns all hold strings (missing values become
        the literal string ``'nan'`` because of the ``astype(str)`` cast).
        The input frame is not modified.
    """
    # Work on a copy: callers typically pass a slice of a larger frame, and
    # assigning into it would mutate their data (and trigger
    # SettingWithCopyWarning).
    cleaned = df.copy()

    cleaned['text'] = (
        cleaned['text']
        .str.replace(r'http\S+', '', regex=True)
        .str.replace(r'www\S+', '', regex=True)
    )

    # Encoding to ASCII with errors='ignore' silently drops non-ASCII characters.
    cleaned = cleaned.astype(str).apply(
        lambda col: col.str.encode('ascii', 'ignore').str.decode('ascii')
    )

    # `.str.replace` is required here: plain `Series.replace('@', '')` only
    # matches cells that are exactly '@', so mention symbols were never removed.
    cleaned['text'] = cleaned['text'].str.replace(r'[#@]', '', regex=True)
    return cleaned

In [None]:
%cd drive/My Drive/PROJECT/T5

In [None]:
df = pd.read_csv('dataset.csv')

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Hold out 10% of the rows for evaluation. `random_state` is passed explicitly
# so that re-running this cell always produces the same split.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
# NOTE(review): only the training split is cleaned; test_df stays raw, so
# evaluation inputs may still contain URLs / emojis -- confirm this is intended.
train_df = clean_data(train_df) # cleaning the training dataset
(train_df.shape, test_df.shape)

In [None]:
class TweetSummaryDataset(Dataset):
    """Torch dataset that tokenises (text, summary) rows of a DataFrame on access.

    Every item is padded / truncated to fixed lengths, so the default collate
    function of ``DataLoader`` can stack items into batches.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: T5Tokenizer,
        text_max_token_len: int = 512,
        summary_max_token_len: int = 128):
        """
        Args:
            data: Frame with ``text`` and ``summary`` columns.
            tokenizer: Tokenizer used for both the text and the summary.
            text_max_token_len: Fixed token length of the encoded text.
            summary_max_token_len: Fixed token length of the encoded summary.
        """
        self.tokenizer = tokenizer
        self.data = data
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.data)

    def _encode(self, text, max_length: int):
        """Tokenise ``text`` into ``(1, max_length)`` tensors with padding and truncation."""
        # Use the tokenizer handed to this dataset, not a notebook-level global:
        # the dataset must keep working (e.g. inside DataLoader workers or when
        # imported elsewhere) without relying on a `tokenizer` variable existing.
        return self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def __getitem__(self, index: int):
        """Return the encoded example at position ``index``.

        Returns:
            dict with the raw ``text`` and ``summary`` plus 1-D tensors
            ``text_input_ids``, ``text_attention_mask``, ``labels`` and
            ``labels_attention_mask``.
        """
        data_row = self.data.iloc[index]
        text = data_row["text"]
        summary = data_row["summary"]

        text_encoding = self._encode(text, self.text_max_token_len)
        summary_encoding = self._encode(summary, self.summary_max_token_len)

        # Replace padding positions in the labels with -100, the label value
        # Hugging Face models skip when computing the loss. The pad id is read
        # from the tokenizer instead of assuming it is 0.
        labels = summary_encoding["input_ids"]
        labels[labels == self.tokenizer.pad_token_id] = -100

        return dict(
            text=text,
            summary=summary,
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            labels=labels.flatten(),
            labels_attention_mask=summary_encoding["attention_mask"].flatten(),
        )
    

In [None]:
class TweetSummaryDataModule(pl.LightningDataModule):
    """Lightning data module that wraps the train / test frames in dataloaders.

    NOTE(review): ``val_dataloader`` and ``test_dataloader`` both read from
    ``test_df``, so validation metrics (and anything selected on them, such as
    the best checkpoint) are computed on the test split.
    """

    def __init__(self,
                 train_df: pd.DataFrame,
                 test_df: pd.DataFrame,
                 tokenizer: T5Tokenizer,
                 batch_size: int = 8,
                 text_max_token_len: int = 512,
                 summary_max_token_len: int = 128,
                 num_workers: int = 2):
        """
        Args:
            train_df: Rows used for training.
            test_df: Rows used for validation and testing.
            tokenizer: Tokenizer forwarded to ``TweetSummaryDataset``.
            batch_size: Batch size of every dataloader.
            text_max_token_len: Fixed token length of the encoded text.
            summary_max_token_len: Fixed token length of the encoded summary.
            num_workers: Worker processes per dataloader (previously fixed at 2).
        """
        super().__init__()

        self.train_df = train_df
        self.test_df = test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len
        self.num_workers = num_workers

    def _make_dataset(self, frame: pd.DataFrame):
        """Wrap ``frame`` in a ``TweetSummaryDataset`` using this module's settings."""
        return TweetSummaryDataset(
            frame,
            self.tokenizer,
            self.text_max_token_len,
            self.summary_max_token_len
        )

    def _make_loader(self, dataset, shuffle: bool = False):
        """Build a ``DataLoader`` over ``dataset`` with the shared batch / worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers
        )

    def setup(self, stage= None):
        """Create the train and test datasets; ``stage`` is accepted but not used."""
        self.train_dataset = self._make_dataset(self.train_df)
        self.test_dataset = self._make_dataset(self.test_df)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the test dataset (used as validation data)."""
        return self._make_loader(self.test_dataset)

    def test_dataloader(self):
        """Unshuffled loader over the test dataset."""
        return self._make_loader(self.test_dataset)

In [None]:
# Pretrained model identifier; used for the tokenizer here and for the model
# weights inside TweetSummaryModel.
MODEL_NAME = "t5-base"

# Tokenizer matching MODEL_NAME; shared by the data module and by `summarize`.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Training hyper-parameters: N_EPOCHS caps the Trainer's epochs, BATCH_SIZE is
# the batch size of every dataloader.
N_EPOCHS = 4
BATCH_SIZE = 4

# Data module over the train / test frames; token length limits keep their defaults.
data_module = TweetSummaryDataModule(train_df, test_df, tokenizer, batch_size= BATCH_SIZE)

In [None]:
class TweetSummaryModel(pl.LightningModule):
    """Lightning module that fine-tunes ``T5ForConditionalGeneration`` for summarisation.

    The train / validation / test steps all run the same forward pass and log
    the loss as ``train_loss`` / ``val_loss`` / ``test_loss`` respectively.
    """

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict= True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run the wrapped T5 model.

        Args:
            input_ids: Token ids of the source text.
            attention_mask: Attention mask of the source text.
            decoder_attention_mask: Attention mask of the target summary.
            labels: Optional target token ids; when ``None`` the returned
                loss is whatever the wrapped model reports without labels.

        Returns:
            Tuple ``(loss, logits)`` taken from the model output.
        """
        output = self.model(
            input_ids,
            attention_mask= attention_mask,
            labels= labels,
            decoder_attention_mask= decoder_attention_mask
        )

        return output.loss, output.logits

    def _shared_step(self, batch, stage):
        """Forward ``batch``, log the loss as ``"<stage>_loss"`` and return it."""
        loss, _ = self(
            input_ids=batch["text_input_ids"],
            attention_mask=batch["text_attention_mask"],
            decoder_attention_mask=batch["labels_attention_mask"],
            labels=batch["labels"]
        )

        self.log(f"{stage}_loss", loss, prog_bar= True, logger= True)
        return loss

    def training_step(self, batch, batch_idx):
        """Training hook: returns the loss to optimise."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Validation hook: logs ``val_loss`` (monitored for checkpointing)."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Test hook. Lightning looks this method up by the name ``test_step``."""
        return self._shared_step(batch, "test")

    # The hook used to be (mis)named `testing_step`, which Lightning never
    # calls; keep the old name as an alias so existing direct callers still work.
    testing_step = test_step

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at a fixed learning rate of 1e-4."""
        return AdamW(self.parameters(), lr= 0.0001)

In [None]:
# Restore the model from a previously saved Lightning checkpoint; this is the
# instance that is passed to `trainer.fit` and used by `summarize` below.
# NOTE(review): this reads from "checkpoint/" while the ModelCheckpoint callback
# below writes to "checkpoints/" -- confirm the path is the intended one.
load_model = TweetSummaryModel.load_from_checkpoint(checkpoint_path="checkpoint/epoch=40.ckpt")
#load_model.train()

In [None]:
# Keep only the single best checkpoint, ranked by the lowest logged "val_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath = "checkpoints",
    filename = "best-checkpoint",
    save_top_k = 1,
    verbose = True,
    monitor = "val_loss",
    mode = "min"
)

# TensorBoard event files go to lightning_logs/tweet-summary/.
logger = TensorBoardLogger("lightning_logs", name= "tweet-summary")

# Checkpoint callbacks are registered through `callbacks=[...]` and the GPU is
# requested with `accelerator` / `devices`: the former `checkpoint_callback=`
# (with a callback instance) and `gpus=` arguments are rejected by current
# Lightning releases, which is what the unpinned install at the top provides.
# `trainer.checkpoint_callback` still resolves to the callback registered here.
trainer = pl.Trainer(
    logger = logger,
    callbacks = [checkpoint_callback],
    max_epochs = N_EPOCHS,
    accelerator = "gpu",
    devices = 1
)

In [None]:
%%timeit -r 1 -n 1
# Fine-tune the checkpoint-restored model; `-r 1 -n 1` makes %%timeit execute
# the cell exactly once while still reporting the elapsed time.
trainer.fit(load_model, data_module)

In [None]:
# Reload the best checkpoint recorded by the trainer's checkpoint callback.
trained_model = TweetSummaryModel.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path
)

# Freeze the reloaded model for inference.
# NOTE(review): `summarize` below generates with `load_model` by default, not
# with this `trained_model` -- confirm which one is meant to be evaluated.
trained_model.freeze()

In [None]:
def summarize(text, model=None):
    """Generate a summary of ``text`` with beam search.

    Args:
        text: Source text to summarise.
        model: Optional ``TweetSummaryModel`` to generate with. Defaults to the
            notebook-level ``load_model``, which is what this function always
            used before the parameter existed.

    Returns:
        The decoded summary as a single string.
    """
    if model is None:
        model = load_model

    text_encoding = tokenizer(
        text,
        max_length= 1024,
        padding= "max_length",
        truncation= True,
        return_attention_mask= True,
        add_special_tokens = True,
        return_tensors = "pt"
    )

    model.eval()

    # The tokenizer returns CPU tensors; move them next to the model's weights
    # so generation also works when the model sits on a GPU.
    device = next(model.parameters()).device

    generated_ids = model.model.generate(
        input_ids = text_encoding["input_ids"].to(device),
        attention_mask= text_encoding["attention_mask"].to(device),
        max_length = 200,
        num_beams = 2,
        no_repeat_ngram_size = 3,
        repetition_penalty = 2.5,
        length_penalty = 1.0
    )

    preds = [
     tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
     for gen_id in generated_ids
    ]

    return "".join(preds)

In [None]:
# Take one row of the test split and summarise its text.
sample_row = test_df.iloc[1]
text = sample_row["text"]
model_summary = summarize(text)

In [None]:
# Source text of the sampled row.
sample_row['text']

In [None]:
# Summary generated by the model.
model_summary

In [None]:
# Reference summary from the dataset, for comparison.
sample_row["summary"]

In [None]:
%%timeit -r 1 -n 1
### Rouge-Metrics

scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

precision_1 = 0
precision_2 = 0
precision_L = 0
recall_1 = 0
recall_2 = 0
recall_L = 0
fmeasure_1 = 0
fmeasure_2 = 0
fmeasure_L = 0

test_len = len(test_df['text'])

for i in range(test_len):

    sample_row = test_df.iloc[i]
    text = sample_row["text"]
    model_summary = summarize(text)

    scores = scorer.score(model_summary,
                      sample_row["summary"])
    
    precision_1 += scores['rouge1'][0]
    precision_2 += scores['rouge2'][0]
    precision_L += scores['rougeL'][0]
    recall_1 += scores['rouge1'][1]
    recall_2 += scores['rouge2'][1]
    recall_L += scores['rougeL'][1]
    fmeasure_1 += scores['rouge1'][2]
    fmeasure_2 += scores['rouge2'][2]
    fmeasure_L += scores['rougeL'][2]

    if i%15 == 0:
        print(i//15, end=" ")

print("\n\t\tPrecision\tRecall\t\tF-Measure")
print(f"Rouge-1: \t {precision_1/test_len:.3f}\t\t {recall_1/test_len:.3f}\t\t {fmeasure_1/test_len:.3f}")
print(f"Rouge-2: \t {precision_2/test_len:.3f}\t\t {recall_2/test_len:.3f}\t\t {fmeasure_2/test_len:.3f}")
print(f"Rouge-L: \t {precision_L/test_len:.3f}\t\t {recall_L/test_len:.3f}\t\t {fmeasure_L/test_len:.3f}")