In [27]:
!pip install transformers -q

In [28]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import BertTokenizer, BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [31]:
df = pd.read_csv("news_summary.csv", encoding="latin-1")
df = df[["text", "ctext"]]
df.ctext = "summarize: " + df.ctext
df.head(3)

Unnamed: 0,text,ctext
0,The Administration of Union Territory Daman an...,summarize: The Daman and Diu administration on...
1,Malaika Arora slammed an Instagram user who tr...,summarize: From her special numbers to TV?appe...
2,The Indira Gandhi Institute of Medical Science...,summarize: The Indira Gandhi Institute of Medi...


In [32]:
MAX_LEN = 512
SUMMARY_LEN = 150 
TRAIN_BATCH_SIZE = 2
VALID_BATCH_SIZE = 2
TRAIN_EPOCHS = 1
VAL_EPOCHS = 1 
LEARNING_RATE = 1e-4
SEED = 42
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  

In [33]:
class CustomDataset(Dataset):
    def __init__(
        self, 
        dataframe, 
        tokenizer, 
        source_len, 
        summ_len
    ):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.summ_len = summ_len
        self.text = self.data.text
        self.ctext = self.data.ctext
        
    def __len__(self):
        return len(self.text)
        
    def __getitem__(self, index):
        ctext = str(self.ctext[index])
        ctext = ' '.join(ctext.split())

        text = str(self.text[index])
        text = ' '.join(text.split())

        source = self.tokenizer.batch_encode_plus(
            [ctext], 
            max_length=self.source_len, 
            pad_to_max_length=True,
            return_tensors="pt",
            truncation=True,
        )
        target = self.tokenizer.batch_encode_plus(
            [text], 
            max_length=self.summ_len, 
            pad_to_max_length=True,
            return_tensors="pt",
            truncation=True,
        )

        source_ids = source["input_ids"].squeeze()
        source_mask = source["attention_mask"].squeeze()
        target_ids = target["input_ids"].squeeze()
        target_mask = target["attention_mask"].squeeze()

        return {
            "source_ids": source_ids.to(dtype=torch.long), 
            "source_mask": source_mask.to(dtype=torch.long), 
            "target_ids": target_ids.to(dtype=torch.long),
            "target_ids_y": target_ids.to(dtype=torch.long)
        }

In [35]:
encoder = BertGenerationEncoder.from_pretrained("bert-base-uncased")
decoder = BertGenerationDecoder.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
model = model.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationEncoder: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.embeddings.token_type_embeddings.weight']
- This IS expected if you are initializing BertGenerationEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertGenerationEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
If you want to use `BertGenerationDecoder` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationDecoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls

In [36]:
optimizer = torch.optim.Adam(
    params=model.parameters(), 
    lr=LEARNING_RATE,
)

In [37]:
for epoch in range(TRAIN_EPOCHS):
    model.train()
    for _,data in enumerate(training_loader, 0):
        y = data["target_ids"].to(device, dtype = torch.long)
        y_ids = y[:, :-1].contiguous()
        lm_labels = y[:, 1:].clone().detach()
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100
        ids = data["source_ids"].to(device, dtype=torch.long)
        mask = data["source_mask"].to(device, dtype=torch.long)

        outputs = model(
            input_ids=ids, 
            attention_mask=mask, 
            decoder_input_ids=y_ids, 
            labels=lm_labels,
        )
        loss = outputs[0]
        
        if _%500==0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()



Epoch: 0, Loss:  9.863866806030273
Epoch: 0, Loss:  6.91909122467041
Epoch: 0, Loss:  6.715019226074219
Epoch: 0, Loss:  6.2675981521606445
