<a href="https://colab.research.google.com/github/RoyElkabetz/Text-Summarization-with-Deep-Learning/blob/main/T5_Summarizer_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/gdrive (Colab only); the dataset and
# checkpoint paths configured in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Report the GPU (if any) attached to this runtime before doing any work.
!nvidia-smi

In [None]:
!pip install --quiet transformers==4.5.0
!pip install --quiet pytorch-lightning==1.2.7

In [None]:
import json
import time
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from termcolor import colored



# import torch
# from torchtext.datasets import IMDB as the_dataset
from torchtext.datasets import AG_NEWS 
# import torchtext.data as data
# from torchtext.data.utils import get_tokenizer
# from torchtext.vocab import build_vocab_from_iterator
# from torchtext.data.functional import to_map_style_dataset
# from torch.utils.data import DataLoader
# from torch.utils.data.dataset import random_split
# from torch import nn

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer
)

from tqdm.auto import tqdm

In [None]:
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

# Render matplotlib figures inline, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Global seaborn theme; kept ahead of the figure-size override on the next line.
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
# Default (width, height) for every figure created from here on.
rcParams['figure.figsize'] = 16, 10

# Fix the random seed through PyTorch Lightning so runs are repeatable.
pl.seed_everything(216)

In [None]:
# Locations on the mounted Google Drive (/content/gdrive must be mounted first).
# DATASET_PATH: news-summary CSV file; NOTE(review): no visible cell reads it — confirm it is still needed.
DATASET_PATH = '/content/gdrive/MyDrive/Datasets/Text/news_summary.csv'
# CHECKPOINTS_PATH: directory that holds the saved model checkpoints.
CHECKPOINTS_PATH = '/content/gdrive/MyDrive/Checkpoints'
# MY_MODEL_NAME: base file name of the checkpoint inside CHECKPOINTS_PATH.
MY_MODEL_NAME = 'Text_Summarizer_T5'

In [None]:
# Pretrained checkpoint name; used for the tokenizer here and for the model
# weights loaded inside NewsSummaryModel.
MODEL_NAME = 't5-base'

# T5Tokenizer is the import alias of T5TokenizerFast in this notebook.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Peek at a few randomly chosen AG_NEWS training samples and print basic
# dataset statistics (number of distinct labels, number of samples).
train_iter = AG_NEWS(split='train')
n_samples = len(train_iter)
# torch.randint excludes its upper bound, so n_samples keeps the last index drawable.
random_list = torch.randint(0, n_samples, (4, ))
# Plain-int set: membership on a tensor runs an element-wise comparison per iteration.
sample_indices = set(random_list.tolist())
labels = []
for i, (label, text) in enumerate(train_iter):
    labels.append(label)
    if i in sample_indices:
        # Only names defined in earlier cells are used: the raw label and the
        # T5 tokenizer (sub-word pieces for "Split", vocabulary ids for "Tokens").
        print(f'Label: {label}')
        print(f'Text: {text}')
        print(f'Split: {tokenizer.tokenize(text)}')
        print(f'Tokens: {tokenizer.encode(text)}\n')
print('Number of classes: {}'.format(len(set(labels))))
print('Number of samples: {}'.format(n_samples))

In [None]:
# Token lengths, under the T5 tokenizer, of the 'text' and 'summary' columns.
# NOTE(review): train_df is not defined in the visible cells — confirm it is created upstream.
text_token_counts = [
  len(tokenizer.encode(document)) for document in train_df['text']
]
summary_token_counts = [
  len(tokenizer.encode(reference)) for reference in train_df['summary']
]

In [None]:
# Side-by-side histograms of the token counts collected for texts and summaries.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2)

for axis, counts in ((ax1, text_token_counts), (ax2, summary_token_counts)):
  sns.histplot(counts, ax=axis)

ax1.set_title('full text token counts')
ax2.set_title('summary text token counts')

In [None]:
class NewsSummaryModel(pl.LightningModule):
  """Lightning wrapper around a pretrained T5 conditional-generation model.

  The Hugging Face model is loaded from the notebook-level ``MODEL_NAME`` and
  kept in ``self.model``. Every batch handed to the step methods must be a
  mapping with the keys ``text_input_ids``, ``text_attention_mask``,
  ``labels`` and ``labels_attention_mask``.

  Args:
    learning_rate: Learning rate passed to AdamW in ``configure_optimizers``.
      Defaults to the value that was previously hard-coded.
  """

  def __init__(self, learning_rate=0.0001):
    super().__init__()
    self.learning_rate = learning_rate
    self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

  def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
    """Run the wrapped T5 model and return ``(loss, logits)`` from its output."""
    output = self.model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        decoder_attention_mask=decoder_attention_mask
    )

    return output.loss, output.logits

  def _shared_step(self, batch, metric_name):
    """Compute the loss of one batch and log it under ``metric_name``.

    Training, validation and test differ only in the name of the logged
    metric, so all three step hooks delegate here.
    """
    loss, _ = self(
        input_ids=batch['text_input_ids'],
        attention_mask=batch['text_attention_mask'],
        decoder_attention_mask=batch['labels_attention_mask'],
        labels=batch['labels']
    )

    self.log(metric_name, loss, prog_bar=True, logger=True)
    return loss

  def training_step(self, batch, batch_idx):
    """Return the batch loss, logged as ``train_loss``."""
    return self._shared_step(batch, 'train_loss')

  def validation_step(self, batch, batch_idx):
    """Return the batch loss, logged as ``valid_loss``."""
    return self._shared_step(batch, 'valid_loss')

  def test_step(self, batch, batch_idx):
    """Return the batch loss, logged as ``test_loss``."""
    return self._shared_step(batch, 'test_loss')

  def configure_optimizers(self):
    """Optimize every parameter with AdamW at ``self.learning_rate``."""
    return AdamW(self.parameters(), lr=self.learning_rate)

In [None]:
# Build the Lightning module; its constructor loads the pretrained MODEL_NAME weights.
model = NewsSummaryModel()

In [None]:
# NOTE(review): logger, checkpoint_callback and N_EPOCHS are not defined in the
# visible cells — confirm they are created before this cell runs.
trainer = pl.Trainer(
    logger=logger,
    checkpoint_callback=checkpoint_callback,
    max_epochs=N_EPOCHS,
    # Use one GPU when CUDA is available; otherwise fall back to CPU instead of
    # failing on a hard-coded gpus=1.
    gpus=1 if torch.cuda.is_available() else 0,
    progress_bar_refresh_rate=30
)

In [None]:
# Point the trainer's checkpoint callback at the '-v1' checkpoint stored under
# CHECKPOINTS_PATH, then restore the model from that file.
checkpoint_file = f'{CHECKPOINTS_PATH}/{MY_MODEL_NAME}-v1.ckpt'
trainer.checkpoint_callback.best_model_path = checkpoint_file

trained_model = NewsSummaryModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

# The restored model is only used for generation below, so freeze it.
trained_model.freeze()

In [None]:
def summarizer(text, max_input_length=512, max_summary_length=150, num_beams=2):
  """Generate a summary of ``text`` with the restored model.

  Uses the notebook-level ``tokenizer`` and ``trained_model``.

  Args:
    text: Input text to summarize.
    max_input_length: Length the encoded input is truncated or padded to.
    max_summary_length: ``max_length`` passed to ``generate``.
    num_beams: Number of beams passed to ``generate``.

  Returns:
    The decoded generated sequences joined into a single string.
  """
  text_encoding = tokenizer(
      text,
      max_length=max_input_length,
      padding='max_length',
      truncation=True,
      return_attention_mask=True,
      add_special_tokens=True,
      return_tensors='pt'
  )

  # Tokenizer output lives on the CPU; move it to wherever the weights are so
  # generation keeps working after the model has been placed on a GPU.
  device = trained_model.model.device

  generated_ids = trained_model.model.generate(
      input_ids=text_encoding['input_ids'].to(device),
      attention_mask=text_encoding['attention_mask'].to(device),
      max_length=max_summary_length,
      num_beams=num_beams,
      repetition_penalty=2.5,
      length_penalty=1.0,
      early_stopping=True
  )

  preds = [
   tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
   for gen_id in generated_ids
  ]

  return ''.join(preds)

In [None]:
# Summarize the 'text' field of the first row of test_df.
# NOTE(review): test_df is not defined in the visible cells — confirm it is created upstream.
sample_row = test_df.iloc[0]
text = sample_row['text']
with torch.no_grad():
  model_summary = summarizer(text)