In [1]:
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

from transformers import T5Tokenizer, T5ForConditionalGeneration


In [77]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [78]:
df = pd.read_csv("../input/news_summary.csv", encoding="latin-1")
df.head()

Unnamed: 0,author,date,headlines,read_more,text,ctext
0,Chhavi Tyagi,"03 Aug 2017,Thursday",Daman & Diu revokes mandatory Rakshabandhan in...,http://www.hindustantimes.com/india-news/raksh...,The Administration of Union Territory Daman an...,The Daman and Diu administration on Wednesday ...
1,Daisy Mowke,"03 Aug 2017,Thursday",Malaika slams user who trolled her for 'divorc...,http://www.hindustantimes.com/bollywood/malaik...,Malaika Arora slammed an Instagram user who tr...,"From her special numbers to TV?appearances, Bo..."
2,Arshiya Chopra,"03 Aug 2017,Thursday",'Virgin' now corrected to 'Unmarried' in IGIMS...,http://www.hindustantimes.com/patna/bihar-igim...,The Indira Gandhi Institute of Medical Science...,The Indira Gandhi Institute of Medical Science...
3,Sumedha Sehra,"03 Aug 2017,Thursday",Aaj aapne pakad liya: LeT man Dujana before be...,http://indiatoday.intoday.in/story/abu-dujana-...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...
4,Aarushi Maheshwari,"03 Aug 2017,Thursday",Hotel staff to get training to spot signs of s...,http://indiatoday.intoday.in/story/sex-traffic...,Hotels in Maharashtra will train their staff t...,Hotels in Mumbai and other Indian cities are t...


In [79]:
df = df[["text", "ctext"]]
df.ctext = "summarize: " + df.ctext
df.head()

Unnamed: 0,text,ctext
0,The Administration of Union Territory Daman an...,summarize: The Daman and Diu administration on...
1,Malaika Arora slammed an Instagram user who tr...,summarize: From her special numbers to TV?appe...
2,The Indira Gandhi Institute of Medical Science...,summarize: The Indira Gandhi Institute of Medi...
3,Lashkar-e-Taiba's Kashmir commander Abu Dujana...,summarize: Lashkar-e-Taiba's Kashmir commander...
4,Hotels in Maharashtra will train their staff t...,summarize: Hotels in Mumbai and other Indian c...


In [80]:
df["text"][0]

'The Administration of Union Territory Daman and Diu has revoked its order that made it compulsory for women to tie rakhis to their male colleagues on the occasion of Rakshabandhan on August 7. The administration was forced to withdraw the decision within 24 hours of issuing the circular after it received flak from employees and was slammed on social media.'

In [81]:
df["ctext"][0]

'summarize: The Daman and Diu administration on Wednesday withdrew a circular that asked women staff to tie rakhis on male colleagues after the order triggered a backlash from employees and was ripped apart on social media.The union territory?s administration was forced to retreat within 24 hours of issuing the circular that made it compulsory for its staff to celebrate Rakshabandhan at workplace.?It has been decided to celebrate the festival of Rakshabandhan on August 7. In this connection, all offices/ departments shall remain open and celebrate the festival collectively at a suitable time wherein all the lady staff shall tie rakhis to their colleagues,? the order, issued on August 1 by Gurpreet Singh, deputy secretary (personnel), had said.To ensure that no one skipped office, an attendance report was to be sent to the government the next evening.The two notifications ? one mandating the celebration of Rakshabandhan (left) and the other withdrawing the mandate (right) ? were issued 

In [82]:
class CustomDataset(Dataset):
    def __init__(
        self, 
        dataframe, 
        tokenizer, 
        source_len, 
        summ_len
    ):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.summ_len = summ_len
        self.text = self.data.text
        self.ctext = self.data.ctext
        
    def __len__(self):
        return len(self.text)
        
    def __getitem__(self, index):
        ctext = str(self.ctext[index])
        ctext = ' '.join(ctext.split())

        text = str(self.text[index])
        text = ' '.join(text.split())

        source = self.tokenizer.batch_encode_plus(
            [ctext], 
            max_length=self.source_len, 
            pad_to_max_length=True,
            return_tensors="pt",
            truncation=True,
        )
        target = self.tokenizer.batch_encode_plus(
            [text], 
            max_length=self.summ_len, 
            pad_to_max_length=True,
            return_tensors="pt",
            truncation=True,
        )

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()
        target_ids = target["input_ids"].squeeze()
        target_mask = target["attention_mask"].squeeze()

        return {
            "source_ids": source_ids.to(dtype=torch.long), 
            "source_mask": source_mask.to(dtype=torch.long), 
            "target_ids": target_ids.to(dtype=torch.long),
            "target_ids_y": target_ids.to(dtype=torch.long)
        }
        

In [83]:
def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    for _,data in enumerate(loader, 0):
        y = data["target_ids"].to(device, dtype = torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        outputs = model(
            input_ids=ids, 
            attention_mask=mask, 
            decoder_input_ids=y_ids, 
            lm_labels=lm_labels
        )
        loss = outputs[0]
        
        if _%500==0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        

In [84]:
def validate(epoch, tokenizer, model, device, loader):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(loader, 0):
            y = data["target_ids"].to(device, dtype=torch.long)
            ids = data["source_ids"].to(device, dtype=torch.long)
            mask = data["source_mask"].to(device, dtype=torch.long)

            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask, 
                max_length=150, 
                num_beams=2,
                repetition_penalty=2.5, 
                length_penalty=1.0, 
                early_stopping=True
                )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)for t in y]
            
            if _%100==0:
                print(f"Completed {_}")

            predictions.extend(preds)
            actuals.extend(target)
            
    return predictions, actuals


In [85]:
def main(): 
    TRAIN_BATCH_SIZE = 2
    VALID_BATCH_SIZE = 2
    TRAIN_EPOCHS = 2
    VAL_EPOCHS = 1 
    LEARNING_RATE = 1e-4
    SEED = 42
    MAX_LEN = 512
    SUMMARY_LEN = 150 

    torch.manual_seed(SEED)
    np.random.seed(SEED)

    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    
    df = pd.read_csv("../input/news_summary.csv", encoding="latin-1")
    df = df[["text", "ctext"]]
    df.ctext = "summarize: " + df.ctext
    print(df.head(3))

    train_size = 0.8
    train_dataset = df.sample(frac=train_size,random_state = SEED)
    val_dataset = df.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    print(f"FULL Dataset: {df.shape}")
    print(f"TRAIN Dataset: {train_dataset.shape}")
    print(f"TEST Dataset: {val_dataset.shape}")

    training_set = CustomDataset(
        train_dataset, 
        tokenizer, 
        MAX_LEN, 
        SUMMARY_LEN
    )
    
    val_set = CustomDataset(
        val_dataset, 
        tokenizer, 
        MAX_LEN, 
        SUMMARY_LEN
    )

    train_params = {
        "batch_size": TRAIN_BATCH_SIZE,
        "shuffle": True,
        "num_workers": 0
        }

    val_params = {
        "batch_size": VALID_BATCH_SIZE,
        "shuffle": False,
        "num_workers": 0
        }

    training_loader = DataLoader(
        training_set, 
        **train_params
    )
    
    val_loader = DataLoader(
        val_set, 
        **val_params
    )

    model = T5ForConditionalGeneration.from_pretrained("t5-base")
    model = model.to(device)

    optimizer = torch.optim.Adam(
        params=model.parameters(), 
        lr=LEARNING_RATE
    )

    print("Initiating Fine-Tuning for the model on our dataset")

    for epoch in range(TRAIN_EPOCHS):
        train(
            epoch, 
            tokenizer, 
            model, 
            device, 
            training_loader, 
            optimizer
        )

    print("Now generating summaries on our fine tuned model for the validation dataset and saving it in a dataframe")
    for epoch in range(VAL_EPOCHS):
        predictions, actuals = validate(
            epoch, 
            tokenizer, 
            model, 
            device, 
            val_loader
        )
        
        final_df = pd.DataFrame(
            {
                "Generated Text": predictions,
                "Actual Text": actuals
            }
        )
        final_df.to_csv("../input/predictions.csv")
        
        print("Output Files generated for review")

if __name__ == "__main__":
    main()
    

                                                text  \
0  The Administration of Union Territory Daman an...   
1  Malaika Arora slammed an Instagram user who tr...   
2  The Indira Gandhi Institute of Medical Science...   

                                               ctext  
0  summarize: The Daman and Diu administration on...  
1  summarize: From her special numbers to TV?appe...  
2  summarize: The Indira Gandhi Institute of Medi...  
FULL Dataset: (4514, 2)
TRAIN Dataset: (3611, 2)
TEST Dataset: (903, 2)


Some weights of T5ForConditionalGeneration were not initialized from the model checkpoint at t5-base and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Initiating Fine-Tuning for the model on our dataset
Epoch: 0, Loss:  6.617685317993164
Epoch: 0, Loss:  1.8290600776672363
Epoch: 0, Loss:  1.5071122646331787
Epoch: 0, Loss:  1.9304471015930176
Epoch: 1, Loss:  2.048556089401245
Epoch: 1, Loss:  1.262363314628601
Epoch: 1, Loss:  0.8555755615234375
Epoch: 1, Loss:  1.6932355165481567
Now generating summaries on our fine tuned model for the validation dataset and saving it in a dataframe
Completed 0
Completed 100
Completed 200
Completed 300
Completed 400
Output Files generated for review


In [89]:
preds = pd.read_csv("../input/predictions.csv")
preds.head()


Unnamed: 0.1,Unnamed: 0,Generated Text,Actual Text
0,0,hotels in Mumbai and other Indian cities are t...,Hotels in Maharashtra will train their staff t...
1,1,UP Congress Party has opened a 'State Bank of ...,The Congress party has opened a bank called 'S...
2,2,a 24-year-old Indian athlete has been indicted...,"Tanveer Hussain, a 24-year-old Indian athlete ..."
3,3,the remains of a German hiker who disappeared ...,"The remains of a German hiker, who disappeared..."
4,4,"GP Manish Shah, who practised in east London, ...","A UK-based doctor, Manish Shah, has been charg..."
