In [44]:
!pip install transformers -q

In [45]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel

In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
df = pd.read_csv("news_summary.csv", encoding="latin-1")
df = df[["text", "ctext"]]
df.ctext = "summarize: " + df.ctext
df.head(3)

Unnamed: 0,text,ctext
0,The Administration of Union Territory Daman an...,summarize: The Daman and Diu administration on...
1,Malaika Arora slammed an Instagram user who tr...,summarize: From her special numbers to TV?appe...
2,The Indira Gandhi Institute of Medical Science...,summarize: The Indira Gandhi Institute of Medi...


In [48]:
MAX_LEN = 512
SUMMARY_LEN = 150 
TRAIN_BATCH_SIZE = 2
VALID_BATCH_SIZE = 2
TRAIN_EPOCHS = 2
VAL_EPOCHS = 1 
LEARNING_RATE = 1e-4
SEED = 42

torch.manual_seed(SEED)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  

In [50]:
class CustomDataset(Dataset):
    def __init__(
        self, 
        dataframe, 
        tokenizer, 
        source_len, 
        summ_len
    ):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.summ_len = summ_len
        self.text = self.data.text
        self.ctext = self.data.ctext
        
    def __len__(self):
        return len(self.text)
        
    def __getitem__(self, index):
        ctext = str(self.ctext[index])
        ctext = ' '.join(ctext.split())

        text = str(self.text[index])
        text = ' '.join(text.split())

        source = self.tokenizer.batch_encode_plus(
            [ctext], 
            max_length=self.source_len, 
            pad_to_max_length=True,
            return_tensors="pt",
            truncation=True,
        )
        target = self.tokenizer.batch_encode_plus(
            [text], 
            max_length=self.summ_len, 
            pad_to_max_length=True,
            return_tensors="pt",
            truncation=True,
        )

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()
        target_ids = target["input_ids"].squeeze()
        target_mask = target["attention_mask"].squeeze()

        return {
            "source_ids": source_ids.to(dtype=torch.long), 
            "source_mask": source_mask.to(dtype=torch.long), 
            "target_ids": target_ids.to(dtype=torch.long),
            "target_ids_y": target_ids.to(dtype=torch.long)
        }

In [51]:
train_size = 0.8
train_dataset = df.sample(frac=train_size, random_state=SEED)
val_dataset = df.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

print(f"FULL Dataset: {df.shape}")
print(f"TRAIN Dataset: {train_dataset.shape}")
print(f"TEST Dataset: {val_dataset.shape}")

training_set = CustomDataset(
    train_dataset, 
    tokenizer, 
    MAX_LEN, 
    SUMMARY_LEN
)
    
val_set = CustomDataset(
    val_dataset, 
    tokenizer, 
    MAX_LEN, 
    SUMMARY_LEN
)

FULL Dataset: (4514, 2)
TRAIN Dataset: (3611, 2)
TEST Dataset: (903, 2)


In [52]:
train_params = {
    "batch_size": TRAIN_BATCH_SIZE,
    "shuffle": True,
    "num_workers": 0
}

val_params = {
        "batch_size": VALID_BATCH_SIZE,
        "shuffle": False,
        "num_workers": 0
}

training_loader = DataLoader(
    training_set, 
    **train_params
)
    
val_loader = DataLoader(
    val_set, 
    **val_params
)

In [53]:
encoder = BertGenerationEncoder.from_pretrained("bert-base-uncased")
decoder = BertGenerationDecoder.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model = model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationEncoder: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.embeddings.token_type_embeddings.weight']
- This IS expected if you are initializing BertGenerationEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertGenerationEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
If you want to use `BertGenerationDecoder` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationDecoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls

In [54]:
optimizer = torch.optim.Adam(
    params=model.parameters(), 
    lr=LEARNING_RATE,
)

In [55]:
for epoch in range(TRAIN_EPOCHS):
    model.train()
    for _,data in enumerate(training_loader, 0):
        y = data["target_ids"].to(device, dtype = torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        outputs = model(
            input_ids=ids, 
            attention_mask=mask, 
            decoder_input_ids=y_ids, 
            labels=lm_labels,
        )
        loss = outputs[0]
        
        if _%500==0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()



Epoch: 0, Loss:  9.649497985839844
Epoch: 0, Loss:  7.181732177734375
Epoch: 0, Loss:  7.296773433685303
Epoch: 0, Loss:  7.359935760498047
Epoch: 1, Loss:  6.571008205413818
Epoch: 1, Loss:  6.389883518218994
Epoch: 1, Loss:  6.820432186126709
Epoch: 1, Loss:  6.172169208526611


In [57]:
for epoch in range(VAL_EPOCHS):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(val_loader, 0):
            y = data["target_ids"].to(device, dtype=torch.long)
            ids = data["source_ids"].to(device, dtype=torch.long)
            mask = data["source_mask"].to(device, dtype=torch.long)

            generated_ids = model.generate(
                input_ids=ids,
                attention_mask=mask, 
                max_length=150, 
                num_beams=2,
                repetition_penalty=2.5, 
                length_penalty=1.0, 
                early_stopping=True
                )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)for t in y]
            
            if _%100==0:
                print(f"Completed {_}")

            predictions.extend(preds)
            actuals.extend(target)



Completed 0
Completed 100
Completed 200
Completed 300
Completed 400


In [58]:
final_df = pd.DataFrame(
    {
        "Generated Text": predictions,
        "Actual Text": actuals,
    }
)
final_df.to_csv("predictions.csv")
final_df.head()

Unnamed: 0,Generated Text,Actual Text
0,[unused1] supreme on said the court s has a of...,hotels in maharashtra will train their staff t...
1,[unused1] supreme on said the court s has a of...,the congress party has opened a bank called'st...
2,[unused1] supreme on said the court s has a of...,"tanveer hussain, a 24 - year - old indian athl..."
3,[unused1] supreme on said the court s has a of...,"the remains of a german hiker, who disappeared..."
4,[unused1] supreme on said the court s has a of...,"a uk - based doctor, manish shah, has been cha..."
