<a href="https://colab.research.google.com/github/RoyElkabetz/Text-Summarization-with-Deep-Learning/blob/main/T5_Summarizer_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive (Colab only); the dataset and checkpoint paths used in this notebook live under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

Sat Jun 26 13:18:47 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
!pip install --quiet transformers==4.5.0
!pip install --quiet pytorch-lightning==1.2.7

[K     |████████████████████████████████| 2.2MB 10.4MB/s 
[K     |████████████████████████████████| 901kB 35.6MB/s 
[K     |████████████████████████████████| 3.3MB 37.0MB/s 
[K     |████████████████████████████████| 839kB 6.6MB/s 
[K     |████████████████████████████████| 276kB 18.5MB/s 
[K     |████████████████████████████████| 276kB 21.3MB/s 
[K     |████████████████████████████████| 829kB 21.1MB/s 
[K     |████████████████████████████████| 122kB 25.3MB/s 
[K     |████████████████████████████████| 1.3MB 28.0MB/s 
[K     |████████████████████████████████| 296kB 44.9MB/s 
[K     |████████████████████████████████| 143kB 43.1MB/s 
[?25h  Building wheel for PyYAML (setup.py) ... [?25l[?25hdone
  Building wheel for future (setup.py) ... [?25l[?25hdone


In [4]:
import json
import time
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from termcolor import colored



# import torch
# from torchtext.datasets import IMDB as the_dataset
from torchtext.datasets import AG_NEWS, IMDB 
# import torchtext.data as data
# from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# from torchtext.data.functional import to_map_style_dataset
# from torch.utils.data import DataLoader
# from torch.utils.data.dataset import random_split
# from torch import nn

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer
)

from tqdm.auto import tqdm

In [5]:
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

# Render figures inline in the notebook, using the 'retina' figure format.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Shared seaborn styling and default figure size for every plot in the notebook.
sns.set(style='whitegrid', palette='muted', font_scale=1.2)
rcParams['figure.figsize'] = 16, 10

# Seed the random number generators through Lightning so runs are repeatable.
pl.seed_everything(216)

Global seed set to 216


216

In [6]:
# Locations on the mounted Google Drive.
DATASET_PATH = '/content/gdrive/MyDrive/Datasets/Text/news_summary.csv'
CHECKPOINTS_PATH = '/content/gdrive/MyDrive/Checkpoints'
# File-name stem used for this project's checkpoint files.
MY_MODEL_NAME = 'Text_Summarizer_T5'
# Pretrained checkpoint name used for both the tokenizer and the model.
MODEL_NAME = 't5-base'

In [7]:
# Load the tokenizer for MODEL_NAME (T5Tokenizer is the fast tokenizer, see the imports).
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1389353.0, style=ProgressStyle(descript…




In [8]:
class NewsSummaryModel(pl.LightningModule):
  """Lightning module wrapping a pretrained T5 model for text summarization.

  The wrapped network is loaded from ``MODEL_NAME``. Every step method expects
  a batch dict with the keys ``text_input_ids``, ``text_attention_mask``,
  ``labels`` and ``labels_attention_mask``.

  Args:
    learning_rate: Learning rate handed to the AdamW optimizer. Defaults to
      the value that was previously hard-coded (0.0001).
  """

  def __init__(self, learning_rate=0.0001):
    super().__init__()
    self.learning_rate = learning_rate
    self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)

  def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
    """Run the wrapped T5 model and return ``(loss, logits)`` from its output."""
    output = self.model(
        input_ids,
        attention_mask=attention_mask,
        labels=labels,
        decoder_attention_mask=decoder_attention_mask
    )

    return output.loss, output.logits

  def _shared_step(self, batch, metric_name):
    """Compute the loss for one batch and log it under ``metric_name``."""
    # Train, validation and test steps differ only in the logged metric name,
    # so they all funnel through this single implementation.
    loss, _ = self(
        input_ids=batch['text_input_ids'],
        attention_mask=batch['text_attention_mask'],
        decoder_attention_mask=batch['labels_attention_mask'],
        labels=batch['labels']
    )

    self.log(metric_name, loss, prog_bar=True, logger=True)
    return loss

  def training_step(self, batch, batch_idx):
    """Return the training loss for ``batch`` and log it as ``train_loss``."""
    return self._shared_step(batch, 'train_loss')

  def validation_step(self, batch, batch_idx):
    """Return the validation loss for ``batch`` and log it as ``valid_loss``."""
    return self._shared_step(batch, 'valid_loss')

  def test_step(self, batch, batch_idx):
    """Return the test loss for ``batch`` and log it as ``test_loss``."""
    return self._shared_step(batch, 'test_loss')

  def configure_optimizers(self):
    """Create the AdamW optimizer over all parameters of this module."""
    return AdamW(self.parameters(), lr=self.learning_rate)

In [9]:
# Build the Lightning module; its constructor loads the pretrained MODEL_NAME weights.
model = NewsSummaryModel()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1199.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=891691430.0, style=ProgressStyle(descri…




In [11]:
# Keep only the single best checkpoint (lowest 'valid_loss') under CHECKPOINTS_PATH.
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINTS_PATH,
    filename=MY_MODEL_NAME,
    save_top_k=1,
    verbose=True,
    monitor='valid_loss',
    mode='min'
)

# Register the checkpoint callback through `callbacks`: passing a ModelCheckpoint
# instance via `checkpoint_callback=` is deprecated in this Lightning version.
# `trainer.checkpoint_callback` still resolves to this callback afterwards.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=1,
    gpus=1,
    progress_bar_refresh_rate=30
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [12]:
# Point the callback at a named checkpoint file under CHECKPOINTS_PATH,
# then rebuild the model from that file.
trainer.checkpoint_callback.best_model_path = f'{CHECKPOINTS_PATH}/{MY_MODEL_NAME}-v1.ckpt'
trained_model = NewsSummaryModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

# The restored model is only used for inference from here on.
trained_model.freeze()

In [36]:
def summarizer(text, max_input_length=512, max_summary_length=50, num_beams=2,
               repetition_penalty=2.5, length_penalty=1.0):
  """Generate a summary of ``text`` with the restored ``trained_model``.

  Args:
    text: Text to summarize. It is passed straight to the tokenizer; if a
      batch of texts is given, the decoded summaries are concatenated with
      no separator.
    max_input_length: Token length the input is padded / truncated to.
    max_summary_length: Upper bound on the generated sequence length.
    num_beams: Beam width used for generation.
    repetition_penalty: Repetition penalty forwarded to ``generate``.
    length_penalty: Length penalty forwarded to ``generate``.

  Returns:
    The decoded summary as a single string.
  """
  text_encoding = tokenizer(
      text,
      max_length=max_input_length,
      padding='max_length',
      truncation=True,
      return_attention_mask=True,
      add_special_tokens=True,
      return_tensors='pt'
  )

  # The tokenizer returns CPU tensors; move them to wherever the model lives
  # so generation also works after the model has been moved to a GPU.
  device = trained_model.model.device

  generated_ids = trained_model.model.generate(
      input_ids=text_encoding['input_ids'].to(device),
      attention_mask=text_encoding['attention_mask'].to(device),
      max_length=max_summary_length,
      num_beams=num_beams,
      repetition_penalty=repetition_penalty,
      length_penalty=length_penalty,
      early_stopping=True
  )

  preds = [
   tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
   for gen_id in generated_ids
  ]

  return ''.join(preds)

In [30]:
# Peek at the first six training samples and print how many fields each one holds.
train_iter = AG_NEWS(split='train')
for i, sample in zip(range(6), train_iter):
  print(len(sample))

2
2
2
2
2
2


In [21]:
# Collect the AG_NEWS training split into two parallel lists: labels and texts.
train_iter = AG_NEWS(split='train')
labels, texts = [], []
for label, text in train_iter:
  labels.append(label)
  texts.append(text)


train.csv: 29.5MB [00:00, 78.8MB/s]


In [22]:
# Assemble the collected labels and texts into a two-column frame and preview it.
train_df = pd.DataFrame({'label': labels, 'text': texts})
train_df.head()

Unnamed: 0,label,text
0,3,Wall St. Bears Claw Back Into the Black (Reute...
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...
4,3,"Oil prices soar to all-time record, posing new..."


In [40]:
# Take the article at row position 4 and summarize it; gradient tracking is
# switched off around the model call.
sample_row = train_df.iloc[4]
text = sample_row['text']
with torch.no_grad():
  model_summary = summarizer(text)

In [41]:
# Display the source article text that was passed to the summarizer.
text

'Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.'

In [42]:
# Display the summary the model generated for the sampled article.
model_summary

'Oil prices soar to all-time record, posing new menace to US economy, reports said. The price of crude oil is set to hit its highest level in the last three years, and it has fallen to an all-time'