In [None]:
!pip install -q sumeval==0.2.2
!pip install -q nlpaug==1.1.3
!pip install -q simpletransformers==0.60.9
!pip install transformers

# Description
### Fine-tune a T5 model on a summarisation dataset using PyTorch and Hugging Face.

In [92]:
import gc
import random
import warnings
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import nlpaug.augmenter.word as naw
from sumeval.metrics.rouge import RougeCalculator

import torch
from simpletransformers.t5 import T5Model, T5Args
from transformers import AutoTokenizer
import transformers
from transformers import AutoModelForSeq2SeqLM

# Report which PyTorch build the kernel is running.
print(f'Pytorch version: {torch.__version__}')

Pytorch version: 1.11.0


In [93]:
# Global notebook configuration: warning filter, pandas display width, RNG seeds, compute device.
SEED = 20  # same value the dataset split cell passes as `seed`

# Silence all Python warnings for the rest of the session.
warnings.simplefilter('ignore')
# Show up to 100 characters per column when displaying DataFrames.
pd.set_option('display.max_colwidth', 100)

# Seed every RNG in use so that shuffling (e.g. DataLoader(shuffle=True)) is repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

## Reading Data

In [94]:
# Load both CSV files of the dataset. Rows with missing values are dropped from the
# first file only; the second file is read as-is.
df = (
    pd.read_csv(
        '/kaggle/input/newssummary/news_summary_.csv',
        encoding='ISO-8859-1',
    )
    .dropna()
    .reset_index(drop=True)
)
more_df = pd.read_csv(
    '/kaggle/input/newssummary/news_summary_more.csv',
    encoding='ISO-8859-1',
)

In [95]:
# Preview the first rows of the second file (columns: headlines, text).
more_df.head()

Unnamed: 0,headlines,text
0,upGrad learner switches to career in ML & Al with 90% salary hike,"Saurav Kant, an alumnus of upGrad and IIIT-B's PG Program in Machine learning and Artificial Int..."
1,Delhi techie wins free food from Swiggy for one year on CRED,"Kunal Shah's credit card bill payment platform, CRED, gave users a chance to win free food from ..."
2,New Zealand end Rohit Sharma-led India's 12-match winning streak,New Zealand defeated India by 8 wickets in the fourth ODI at Hamilton on Thursday to win their f...
3,Aegon life iTerm insurance plan helps customers save tax,"With Aegon Life iTerm Insurance plan, customers can enjoy tax benefits on your premiums paid and..."
4,"Have known Hirani for yrs, what if MeToo claims are not true: Sonam","Speaking about the sexual harassment allegations against Rajkumar Hirani, Sonam Kapoor said, ""I'..."


## Data Pre-processing

In [96]:
# Character-length statistics for the headlines and article texts of both frames.
# `.str.len()` is vectorised and does not depend on the frame having a default
# 0..n-1 index, unlike looking rows up one at a time with `frame[col][i]`.
for frame in (df, more_df):
    frame['headlines_length'] = frame['headlines'].str.len()
    frame['text_length'] = frame['text'].str.len()

print('df headlines length:\n', df['headlines_length'].describe())
print('more_df headlines length:\n', more_df['headlines_length'].describe())

df headlines length:
 count    4396.000000
mean       55.976342
std         4.580106
min        31.000000
25%        54.000000
50%        58.000000
75%        59.000000
max        62.000000
Name: headlines_length, dtype: float64
more_df headlines length:
 count    98401.000000
mean        57.643337
std          4.878594
min          9.000000
25%         56.000000
50%         59.000000
75%         60.000000
max         86.000000
Name: headlines_length, dtype: float64


In [97]:
# Summary statistics of article length (in characters) for each frame.
df_text_stats = df['text_length'].describe()
more_df_text_stats = more_df['text_length'].describe()

print('df text length:\n', df_text_stats)
print('more_df text length:\n', more_df_text_stats)

df text length:
 count    4396.000000
mean      354.820746
std        23.956240
min       282.000000
25%       339.000000
50%       356.000000
75%       372.000000
max       400.000000
Name: text_length, dtype: float64
more_df text length:
 count    98401.000000
mean       357.544161
std         24.647988
min          4.000000
25%        341.000000
50%        358.000000
75%        376.000000
max        513.000000
Name: text_length, dtype: float64


In [98]:
# Drop the metadata and length helper columns, stack both frames into one, and keep
# only `text` and `summary` (the renamed `headlines` column), in that order.
df = df.drop(columns=['author', 'date', 'read_more', 'ctext',
                      'headlines_length', 'text_length'])
more_df = more_df.drop(columns=['headlines_length', 'text_length'])

df = pd.concat([df, more_df], ignore_index=True)
df = df.rename(columns={'headlines': 'summary'})
df = df.reindex(columns=['text', 'summary'])
df.head()

Unnamed: 0,text,summary
0,The Administration of Union Territory Daman and Diu has revoked its order that made it compulsor...,Daman & Diu revokes mandatory Rakshabandhan in offices order
1,"Malaika Arora slammed an Instagram user who trolled her for ""divorcing a rich man"" and ""having f...",Malaika slams user who trolled her for 'divorcing rich man'
2,The Indira Gandhi Institute of Medical Sciences (IGIMS) in Patna on Thursday made corrections in...,'Virgin' now corrected to 'Unmarried' in IGIMS' form
3,"Lashkar-e-Taiba's Kashmir commander Abu Dujana, who was killed by security forces, said ""Kabhi h...",Aaj aapne pakad liya: LeT man Dujana before being killed
4,"Hotels in Maharashtra will train their staff to spot signs of sex trafficking, including frequen...",Hotel staff to get training to spot signs of sex trafficking


In [99]:
from datasets import Dataset
# Two-stage random split with identical settings at each stage: first hold out 20%
# of all rows as the test set, then hold out 20% of the remainder for validation.
split_kwargs = dict(test_size=0.2, shuffle=True, seed=20)

dataset = Dataset.from_pandas(df)

split = dataset.train_test_split(**split_kwargs)
train = split["train"]
test = split["test"]

split = train.train_test_split(**split_kwargs)
train = split["train"]
valid = split["test"]

# Row counts of the three subsets.
print(len(train), len(valid), len(test))

65789 16448 20560


## Modeling

In [100]:
# Load the tokenizer published with the pretrained "t5-small" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [101]:
# Task prefix prepended to every article so the model is prompted to summarise.
prefix = "summarize: "


def preprocess_function(examples):
    """Tokenize a batch of articles and their reference summaries.

    Args:
        examples: mapping with a "text" and a "summary" entry, each a list of
            strings (one element per example in the batch).

    Returns:
        The tokenizer output for the prefixed articles (truncated to at most
        1024 tokens), with an added "labels" entry holding the token ids of
        the summaries (truncated to at most 128 tokens).
    """
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

    # The summaries go through the same tokenizer; only their ids are kept.
    labels = tokenizer(text=examples["summary"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [102]:
# Tokenize the three subsets batch-wise, in train / valid / test order.
tokenized_train, tokenized_valid, tokenized_test = (
    subset.map(preprocess_function, batched=True)
    for subset in (train, valid, test)
)

  0%|          | 0/66 [00:00<?, ?ba/s]

  0%|          | 0/17 [00:00<?, ?ba/s]

  0%|          | 0/21 [00:00<?, ?ba/s]

In [103]:
# Drop the raw strings and the tokenizer's attention mask, so that only the
# remaining columns (token ids and labels) are served, as torch tensors.
dropped_columns = ["text", "summary", "attention_mask"]

tokenized_train = tokenized_train.remove_columns(dropped_columns)
tokenized_valid = tokenized_valid.remove_columns(dropped_columns)
tokenized_test = tokenized_test.remove_columns(dropped_columns)

for tokenized_subset in (tokenized_train, tokenized_valid, tokenized_test):
    tokenized_subset.set_format("torch")

## Loading the model

In [104]:
# Load the pretrained "t5-small" sequence-to-sequence model that is fine-tuned below.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

In [105]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import default_collate

def pad_collate(batch):
    """Collate variable-length examples into right-padded batch tensors.

    Args:
        batch: list of examples, each a mapping with 1-D tensors under
            "input_ids" and "labels".

    Returns:
        dict with
          - "input_ids": (batch, longest input) tensor, padded with 0,
          - "attention_mask": same shape as "input_ids"; 1 on real tokens and
            0 on padding, so the model does not attend to padded positions,
          - "labels": (batch, longest label) tensor, padded with 0.
    """
    input_seqs = [example["input_ids"] for example in batch]
    label_seqs = [example["labels"] for example in batch]

    input_ids = pad_sequence(input_seqs, batch_first=True, padding_value=0)
    labels = pad_sequence(label_seqs, batch_first=True, padding_value=0)

    # Build the mask from the true sequence lengths rather than from `input_ids != 0`,
    # so it stays correct even if a real token ever had id 0.
    attention_mask = pad_sequence(
        [torch.ones_like(seq) for seq in input_seqs],
        batch_first=True,
        padding_value=0,
    )

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Both loaders serve batches of 8 padded by `pad_collate`; only the training
# loader shuffles its examples.
loader_kwargs = dict(batch_size=8, collate_fn=pad_collate)

train_dataloader = DataLoader(tokenized_train, shuffle=True, **loader_kwargs)
eval_dataloader = DataLoader(tokenized_valid, **loader_kwargs)

In [106]:
from transformers import get_scheduler
from torch.optim import AdamW

# Training budget: one scheduler step per batch, for every epoch.
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
print(num_training_steps)

# AdamW over all model parameters, with a linear learning-rate schedule and no warm-up.
optimizer = AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

24672


In [None]:
from tqdm.auto import tqdm

# Fine-tuning loop: one optimizer step and one scheduler step per batch.
progress_bar = tqdm(range(num_training_steps))

model.to(device)
model.train()

# `pad_collate` pads `labels` with 0. Left as-is, the loss would also be computed on
# those padding positions; -100 is the label value the model's loss ignores.
label_pad_id = 0
ignored_label = -100

for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        # Exclude padded label positions from the loss.
        labels = batch["labels"]
        batch["labels"] = labels.masked_fill(labels == label_pad_id, ignored_label)

        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        # Clear gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()
        progress_bar.update(1)