In [2]:
!pip install -q sumeval==0.2.2
!pip install transformers

[0m

# Decription 
### Fine-Tune a T5 model on a summarisation dataset using Pytorch and HugingFace.

In [37]:
import gc
import random
import warnings
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

#import nlpaug.augmenter.word as naw
from sumeval.metrics.rouge import RougeCalculator

import torch
from transformers import AutoTokenizer
import transformers
from transformers import AutoModelForSeq2SeqLM

print('Pytorch version: %s'  % torch.__version__)

Pytorch version: 1.11.0


In [38]:
warnings.simplefilter('ignore')
pd.set_option('display.max_colwidth', 100)
cuda =  torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")

## Reading Data

In [39]:
df = pd.read_csv('../input/newssummery/news_summary.csv', encoding='ISO-8859-1').dropna().reset_index(drop=True)
more_df = pd.read_csv('../input/newssummery/news_summary_more.csv', encoding='ISO-8859-1')

In [40]:
more_df = more_df
df = df

## Data Pre-processing

In [41]:
df['headlines_length'] = [len(df['headlines'][i]) for i in range(len(df))]
df['text_length'] = [len(df['text'][i]) for i in range(len(df))]
more_df['headlines_length'] = [len(more_df['headlines'][i]) for i in range(len(more_df))]
more_df['text_length'] = [len(more_df['text'][i]) for i in range(len(more_df))]

print('df headlines length:\n', df['headlines_length'].describe())
print('more_df headlines length:\n', more_df['headlines_length'].describe())

df headlines length:
 count    4396.000000
mean       55.976342
std         4.580106
min        31.000000
25%        54.000000
50%        58.000000
75%        59.000000
max        62.000000
Name: headlines_length, dtype: float64
more_df headlines length:
 count    98401.000000
mean        57.643337
std          4.878594
min          9.000000
25%         56.000000
50%         59.000000
75%         60.000000
max         86.000000
Name: headlines_length, dtype: float64


In [42]:
print('df text length:\n', df['text_length'].describe())
print('more_df text length:\n', more_df['text_length'].describe())

df text length:
 count    4396.000000
mean      354.820746
std        23.956240
min       282.000000
25%       339.000000
50%       356.000000
75%       372.000000
max       400.000000
Name: text_length, dtype: float64
more_df text length:
 count    98401.000000
mean       357.544161
std         24.647988
min          4.000000
25%        341.000000
50%        358.000000
75%        376.000000
max        513.000000
Name: text_length, dtype: float64


In [43]:
df = df.drop(['author', 'date', 'read_more', 'ctext',
              'headlines_length', 'text_length'], axis=1)
more_df = more_df.drop(['headlines_length', 'text_length'], axis=1)
df = pd.concat([df, more_df]).reset_index(drop=True)
df = df.rename(columns={'text': 'text', 'headlines': 'summary'}).reindex(columns=['text', 'summary'])
df.head()

Unnamed: 0,text,summary
0,The Administration of Union Territory Daman and Diu has revoked its order that made it compulsor...,Daman & Diu revokes mandatory Rakshabandhan in offices order
1,"Malaika Arora slammed an Instagram user who trolled her for ""divorcing a rich man"" and ""having f...",Malaika slams user who trolled her for 'divorcing rich man'
2,The Indira Gandhi Institute of Medical Sciences (IGIMS) in Patna on Thursday made corrections in...,'Virgin' now corrected to 'Unmarried' in IGIMS' form
3,"Lashkar-e-Taiba's Kashmir commander Abu Dujana, who was killed by security forces, said ""Kabhi h...",Aaj aapne pakad liya: LeT man Dujana before being killed
4,"Hotels in Maharashtra will train their staff to spot signs of sex trafficking, including frequen...",Hotel staff to get training to spot signs of sex trafficking


In [44]:
from datasets import Dataset
dataset = Dataset.from_pandas(df)
split = dataset.train_test_split(test_size=0.2, shuffle=True, seed=20)

train, test = split["train"], split["test"]

split = train.train_test_split(test_size=0.2, shuffle=True, seed=20)
train, valid = split["train"], split["test"]

print(len(train["text"]), len(valid["text"]), len(test["text"]))

65789 16448 20560


## Modeling

In [45]:
# Loading tokenizer of t5 model
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [46]:
# prompting the model to do summarisation
prefix = "summarize: "


def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    model_inputs["input_ids"] = model_inputs["input_ids"]

    labels = tokenizer(text=examples["summary"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [47]:
tokenized_train = train.map(preprocess_function, batched=True)
tokenized_valid = valid.map(preprocess_function, batched=True)
tokenized_test = test.map(preprocess_function, batched=True)

  0%|          | 0/66 [00:00<?, ?ba/s]

  0%|          | 0/17 [00:00<?, ?ba/s]

  0%|          | 0/21 [00:00<?, ?ba/s]

In [48]:
tokenized_train = tokenized_train.remove_columns(["text"]).remove_columns(["summary"]).remove_columns(["attention_mask"])
tokenized_valid = tokenized_valid.remove_columns(["text"]).remove_columns(["summary"]).remove_columns(["attention_mask"])
tokenized_test = tokenized_test.remove_columns(["text"]).remove_columns(["summary"]).remove_columns(["attention_mask"])

tokenized_train.set_format("torch")
tokenized_valid.set_format("torch")
tokenized_test.set_format("torch")

In [None]:
batch

## Loading the model

In [49]:
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

In [81]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import default_collate

def pad_collate(batch):
    xx = [x["input_ids"] for x in batch]
    yy = [x["labels"] for x in batch]
    x_lens = [len(x) for x in xx]
    y_lens = [len(y) for y in yy]

    xx_pad = pad_sequence(xx, batch_first=True, padding_value=0)
    yy_pad = pad_sequence(yy, batch_first=True, padding_value=0)
    
    data = [{"input_ids": x, "labels": y} for x,y in zip(xx_pad, yy_pad)]
    
    return default_collate(data)

train_dataloader = DataLoader(tokenized_train, shuffle=True, batch_size=8, collate_fn=pad_collate)
eval_dataloader = DataLoader(tokenized_valid, batch_size=8, collate_fn=pad_collate)

In [71]:
test_dataloader = DataLoader(tokenized_test, batch_size=8, collate_fn=pad_collate)

In [51]:
from transformers import get_scheduler
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
print(num_training_steps)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

24672


In [52]:
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

model.to(device)
model.train()

for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

  0%|          | 0/24672 [00:00<?, ?it/s]

In [82]:
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    break

In [89]:
for batch_test in test_dataloader:
    batch_test = {k: v.to(device) for k, v in batch.items()}
    break

In [111]:
outputs = model.generate(batch['input_ids'])
#type(outputs)
print(tokenizer.decode(outputs[5], skip_special_tokens=True))
print(tokenizer.decode(batch['labels'][5], skip_special_tokens=True))

Boycott names Dream XI cricket team which didn't feature greats
No Indians in Geoffrey Boycott's Dream XI


In [118]:
outputs[5]

tensor([    0,  7508, 10405,  3056,  7099,     3,     4,   196, 18096,   372,
           84,   737,    31,    17,  1451,   248,     7,     1,     0],
       device='cuda:0')

In [98]:
outputs_test = model.generate(batch['input_ids'])
#type(outputs)
print(tokenizer.decode(outputs_test[0], skip_special_tokens=True))

57-yr-old man who married 8 women over 8 years arrested in Tamil Na


In [112]:
print(type(batch['input_ids'][0]))
tokenizer.decode(batch['input_ids'][5], skip_special_tokens=True)

<class 'torch.Tensor'>


'summarize: Ex-England cricketer Geoffrey Boycott named his Dream XI cricket team which did not feature any of the Indian greats. Boycott said he did not pick Gavaskar in the side as batsmen from yesteryears like WG Grace and Jack Hobbs played under higher degree of challenges. "The Dream XI selected by the ICC\'s online readers insults...achievements of the greats," added Boycott.'

In [99]:
rouge = RougeCalculator(stopwords=True, lang="en")

def rouge_calc(preds, targets):
    rouge_1 = [rouge.rouge_n(summary=preds[i],references=targets[i],n=1) for i in range(len(preds))]
    rouge_2 = [rouge.rouge_n(summary=preds[i],references=targets[i],n=2) for i in range(len(preds))]
    rouge_l = [rouge.rouge_l(summary=preds[i],references=targets[i]) for i in range(len(preds))]

    return {"Rouge_1": np.array(rouge_1).mean(),
            "Rouge_2": np.array(rouge_2).mean(),
            "Rouge_L": np.array(rouge_l).mean()}

In [104]:
prediction = []
ground_truth = []
for i in range(0,8):
    prediction.append(tokenizer.decode(outputs[i], skip_special_tokens=True))
    ground_truth.append(tokenizer.decode(batch['labels'][i], skip_special_tokens=True))

In [105]:
rouge_calc(prediction , ground_truth)

{'Rouge_1': 0.4218406593406593,
 'Rouge_2': 0.17316017316017315,
 'Rouge_L': 0.4026098901098901}

In [119]:
prediction

['Man who allegedly duped 8 women of Rs of rupees arrested',
 'RBI to start printing 1100 notes around April: Reports',
 'Pakistan forced to conduct nuclear tests in self-defence: PM',
 'Nakorean missile flew over hotel in Pyongyang: Team',
 "violin bearing 'Made for the Worlds' Profesior's",
 "Boycott names his team which didn't feature Indian greats",
 'Woman turns 110-yr-old tree into a library in US',
 "UK Defence Secretary Gavin Williamson calls Russia 'out of jail'"]