In [1]:
import time
import torch
import transformers

import numpy as np
import pandas as pd

from tqdm import tqdm
from torch import cuda
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer

In [2]:
device = 'cuda' if cuda.is_available() else 'cpu'

### Check out the Dataset

In [3]:
df = pd.read_csv('news-expand.csv')

In [4]:
df.head()

Unnamed: 0.1,Unnamed: 0,summary,text
0,0,Australian batsman Steve Smith reveals he near...,australia batsman steve smith has revealed tha...
1,1,milind soman recently trekked to the highest p...,milind soman needs no introduction and nor doe...
2,2,singer Aditya Narayan shared a picture of hims...,ahead of his wedding (reportedly on december 1...
3,3,Rubina Dilaik will reveal one of the deepest s...,tv star rubina dilaik will reveal one of the d...
4,4,"the couple, who got married exactly a month ag...",it is a happy day for kajal aggarwal and her h...


### Create CustomDataset class for training

In [4]:
class CustomDataset(Dataset):

    def __init__(self, dataframe, tokenizer, source_len, summ_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.source_len = source_len
        self.summ_len = summ_len
        self.text = self.data.summary
        self.ctext = self.data.text

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        ctext = str(self.ctext[index])
        ctext = ' '.join(ctext.split())

        text = str(self.text[index])
        text = ' '.join(text.split())

        source = self.tokenizer.batch_encode_plus([ctext], max_length = self.source_len, padding = 'max_length',\
                                                  return_tensors = 'pt', truncation = True)
        target = self.tokenizer.batch_encode_plus([text], max_length = self.summ_len, padding = 'max_length',\
                                                  return_tensors = 'pt', truncation = True)

        source_ids = source['input_ids'].squeeze().to(dtype = torch.long)
        source_mask = source['attention_mask'].squeeze().to(dtype = torch.long)
        target_ids = target['input_ids'].squeeze().to(dtype = torch.long)
        
        y_ids = target_ids[:-1].contiguous() # make y_ids contiguous 
        lm_labels = target_ids[1:].clone().detach() # make fast copy
        lm_labels[target_ids[1:] == tokenizer.pad_token_id] = -100 # replace pad tokens

        return {
            'input_ids': source_ids, 
            'attention_mask': source_mask, 
            'decoder_input_ids': y_ids,
            'labels': lm_labels
        }

In [7]:
TRAIN_BATCH_SIZE = 2
VALID_BATCH_SIZE = 2 
TRAIN_EPOCHS = 25      
VAL_EPOCHS = 1 
LEARNING_RATE = 1e-4    
SEED = 42               
MAX_LEN = 512
SUMMARY_LEN = 128

torch.manual_seed(SEED) 
np.random.seed(SEED) 
torch.backends.cudnn.deterministic = True

tokenizer = T5Tokenizer.from_pretrained("t5-base")

df = df[['text','summary']]
df.summary = 'summarize: ' + df.summary 

train_size = 0.8
train_dataset = df.sample(frac=train_size,random_state = SEED)
val_dataset = df.drop(train_dataset.index).reset_index(drop = True)
train_dataset = train_dataset.reset_index(drop = True)

print("FULL Dataset: {}".format(df.shape))
print("TRAIN Dataset: {}".format(train_dataset.shape))
print("TEST Dataset: {}".format(val_dataset.shape))

training_set = CustomDataset(train_dataset, tokenizer, MAX_LEN, SUMMARY_LEN)
val_set = CustomDataset(val_dataset, tokenizer, MAX_LEN, SUMMARY_LEN)

train_params = {
    'batch_size': TRAIN_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0
    }

val_params = {
    'batch_size': VALID_BATCH_SIZE,
    'shuffle': False,
    'num_workers': 0
    }

training_loader = DataLoader(training_set, **train_params)
val_loader = DataLoader(val_set, **val_params)

model = T5ForConditionalGeneration.from_pretrained("t5-tt-trainer/")
model = model.to(device)

optimizer = torch.optim.Adam(params =  model.parameters(), lr = LEARNING_RATE)

FULL Dataset: (8104, 2)
TRAIN Dataset: (6483, 2)
TEST Dataset: (1621, 2)


In [8]:
args = TrainingArguments(output_dir="t5-tt-trainer/",
                         seed=42,
                         num_train_epochs=25,
                         per_device_train_batch_size=2,  
                         # max batch size without OOM exception, because of the large max token length
                         per_device_eval_batch_size=2,
                         logging_steps=2500,
                         save_steps=0,
                        )

In [9]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=training_set,
    eval_dataset=val_set,
)

In [10]:
trainer.train()

Step,Training Loss
2500,1.078
5000,0.8709
7500,0.7875
10000,0.7386
12500,0.6693
15000,0.6267
17500,0.5929
20000,0.5593
22500,0.5243
25000,0.4852


TrainOutput(global_step=81050, training_loss=0.4333953317104483, metrics={'train_runtime': 11297.955, 'train_samples_per_second': 7.174, 'total_flos': 138511275291993600, 'epoch': 25.0})

In [6]:
def validate(epoch, tokenizer, model, device, loader):
    model.eval()
    predictions = []
    actuals = []
    texts = []
    with torch.no_grad():
        for _, data in tqdm(enumerate(loader, 0)):
            y = data['decoder_input_ids'].to(device, dtype = torch.long)
            ids = data['input_ids'].to(device, dtype = torch.long)
            mask = data['attention_mask'].to(device, dtype = torch.long)

            generated_ids = model.generate(
                input_ids = ids,
                attention_mask = mask, 
                max_length = 100, 
                num_beams = 2,
                repetition_penalty = 2.5, 
                length_penalty = 1.0, 
                early_stopping = True
                )
            preds = [tokenizer.decode(g, skip_special_tokens = True, clean_up_tokenization_spaces = True)\
                     for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens = True, clean_up_tokenization_spaces = True)\
                      for t in y]
            text = [tokenizer.decode(i, skip_special_tokens = True, clean_up_tokenization_spaces = True)\
                      for i in ids]
            if _%2500==0:
                print(f'Completed {_}')

            predictions.extend(preds)
            actuals.extend(target)
            texts.extend(text)
    return predictions, actuals, texts

In [11]:
trainer.save_model("t5-tt-trainer/")

In [8]:
start_time = time.time()
for epoch in range(VAL_EPOCHS):
    predictions, actuals, text = validate(epoch, tokenizer, model, device, val_loader)
    final_df = pd.DataFrame({'Generated Text': predictions,'Actual Text': actuals, 'Text': text})
print("Validation took " + str(time.time() - start_time) + " seconds")

1it [00:01,  1.81s/it]

Completed 0


811it [23:47,  1.76s/it]

Validation took 1427.3927965164185 seconds





### Check out Generated vs Actual Text

In [9]:
final_df

Unnamed: 0,Generated Text,Actual Text,Text
0,: bigg boss 14's star rubina dilaik will revea...,summarize: Rubina Dilaik will reveal one of th...,tv star rubina dilaik will reveal one of the d...
1,Kajal Aggarwal and her husband Gautam Kitchlu ...,"summarize: the couple, who got married exactly...",it is a happy day for kajal aggarwal and her h...
2,amitabh Bachchan's Instagram account is a shee...,summarize: Amitabh Bachchan's latest post is a...,there is absolutely no denying the fact that a...
3,"a twitter user addressed Karan Johar as ""the f...",summarize: a twitter user addressed Karan Joha...,"karan johar, who is quite used to be being tro..."
4,actress ankita Lokhande is all set to perform ...,summarize: ankita Lokhande is all set to perfo...,"tv star ankita lokhande, who is all set to per..."
...,...,...,...
1616,": one day later de telegraaf, a daily amsterda...",summarize: but there are other positive uses o...,mobile picture power in your pocket how many t...
1617,: a delta spokesperson confirmed on wednesday ...,summarize: a delta spokesperson confirmed on w...,us blogger fired by her airline a us airline a...
1618,: clearswift said that fact that no viral code...,summarize: a windows virus called bofra is tur...,toxic web links help virus spread virus writer...
1619,: the system is not available commercially yet...,summarize: the system is not available commerc...,mobile networks seek turbo boost third-generat...


In [10]:
final_df.to_csv('predsvsactual.csv')

### Save Model in tensorflow to upload to huggingface model repository

In [None]:
import transformers

In [None]:
tf_model = transformers.TFT5ForConditionalGeneration.from_pretrained("t5-tt-trainer/", from_pt=True)

In [None]:
tf_model.save_pretrained("t5-tt-trainer/")

In [None]:
tokenizer = transformers.T5Tokenizer.from_pretrained("t5-base")

In [None]:
tokenizer.save_pretrained('./')