In [1]:
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import BartTokenizer, BartForConditionalGeneration


import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn

from tqdm import tqdm
import numpy as np
import os

# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class CNNDMDataset(Dataset):
    def __init__(self, type:str, max_len:int=1024, data_len:int=1000):
        #input type [article, highlights, id]
        super().__init__()
        self.data = load_dataset('cnn_dailymail', '3.0.0')[type][:data_len]
        self.tok = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
        self.max_len = max_len
        self.data_len = data_len

    def __len__(self):
        return self.data_len
    
    def __getitem__(self, idx):
        
        src = self.tok.encode_plus([self.data['article'][idx]], max_length=self.max_len, return_tensors='pt', truncation=True, pad_to_max_length=False)
        tgt = self.tok.encode_plus([self.data['highlights'][idx]], max_length=self.max_len, return_tensors='pt', truncation=True, pad_to_max_length=False)
        src_input_ids = src['input_ids'].squeeze()
        tgt_input_ids = tgt['input_ids'].squeeze()
        result = {
            'src_input_ids': src_input_ids,
            'tgt_input_ids': tgt_input_ids
        }
        return result

In [3]:
def cnndm_collate_fn(batch):
    def pad(X, max_len=-1):
        #[input_ids, attention_mask]
        if max_len < 0:
            max_len = max(x.size(0) for x in X)
        result = torch.ones(len(X), max_len, dtype=X[0].dtype) * BartTokenizer.from_pretrained('facebook/bart-large-cnn').pad_token_id
        for (i, x) in enumerate(X):
            result[0, :x.size(0)] = x
        return result
    
    src_input_ids = pad([x['src_input_ids'] for x in batch])
    tgt_input_ids = pad([x['tgt_input_ids'] for x in batch])
    result = {
        'src_input_ids': src_input_ids,
        'tgt_input_ids': tgt_input_ids
    }
    return result

In [4]:
train_dataset = CNNDMDataset('train', data_len=1000)
val_dataset = CNNDMDataset('validation', data_len=1000)

In [5]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=cnndm_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=cnndm_collate_fn)

In [6]:
# Load the pre-trained model
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

In [7]:
"""
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
for i, batch in enumerate(train_loader):
    src_input_ids = batch['src_input_ids'].to(device)
    tgt_input_ids = batch['tgt_input_ids'].to(device)
    outputs = model(input_ids=src_input_ids, decoder_input_ids=tgt_input_ids)
    outputs = outputs[0] #[bsz, seq_len, vocab_size]
    outputs = F.log_softmax(outputs, dim=-1)
    loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_input_ids.view(-1))
    print("Loss: ", loss)
    del src_input_ids, tgt_input_ids, outputs, loss
    break
"""

'\nmodel.to(device)\ncriterion = torch.nn.CrossEntropyLoss()\nfor i, batch in enumerate(train_loader):\n    src_input_ids = batch[\'src_input_ids\'].to(device)\n    tgt_input_ids = batch[\'tgt_input_ids\'].to(device)\n    outputs = model(input_ids=src_input_ids, decoder_input_ids=tgt_input_ids)\n    outputs = outputs[0] #[bsz, seq_len, vocab_size]\n    outputs = F.log_softmax(outputs, dim=-1)\n    loss = criterion(outputs.view(-1, outputs.size(-1)), tgt_input_ids.view(-1))\n    print("Loss: ", loss)\n    del src_input_ids, tgt_input_ids, outputs, loss\n    break\n'

In [8]:
def run(is_accumulate=False, accumulate_step=10):
    epochs = 3
    accumulate_cnt = 0
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    model.to(device)
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        for i, batch in enumerate(train_loader):
            
            input_ids = batch['src_input_ids'].to(device)
            tgt_ids = batch['tgt_input_ids'].to(device)
            outputs = model(input_ids=input_ids, decoder_input_ids=tgt_ids)
            outputs = outputs[0]
            outputs = F.log_softmax(outputs, dim=-1)
            outputs = outputs.view(-1, outputs.size(-1))
            loss = loss_fn(outputs, tgt_ids.view(-1))
            loss.backward()

            if is_accumulate:
                accumulate_cnt += 1
                if accumulate_cnt % accumulate_step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    accumulate_cnt = 0
            else:
                optimizer.step()
                optimizer.zero_grad()

            if i % 50 == 0:
                print(f'Epoch: {epoch}, Loss:  {loss.item()}')
            del input_ids, tgt_ids, outputs, loss
            torch.cuda.empty_cache()

        model.eval()
        with torch.no_grad():
            LOSS = 0
            for (i, batch) in enumerate(val_loader):
                val_input_ids = batch['src_input_ids'].to(device)
                val_tgt_ids = batch['tgt_input_ids'].to(device)
                outputs = model(val_input_ids, decoder_input_ids=val_tgt_ids)
                outputs = outputs[0]
                outputs = F.log_softmax(outputs, dim=-1)
                outputs = outputs.view(-1, outputs.size(-1))
                loss = loss_fn(outputs, val_tgt_ids.view(-1))
                LOSS += loss.item()
                del val_input_ids, val_tgt_ids, outputs, loss
                torch.cuda.empty_cache()

            print(f'Epoch: {epoch}, Val Loss:  {LOSS/len(val_loader)}')

In [9]:
run()

Epoch: 0, Loss:  12.131403923034668
Epoch: 0, Val Loss:  5.922112941741943
Epoch: 1, Loss:  6.441112995147705
Epoch: 1, Val Loss:  1.1850165128707886
Epoch: 2, Loss:  1.9356474876403809
Epoch: 2, Val Loss:  0.2954336404800415


In [11]:
dataset = load_dataset('cnn_dailymail', '3.0.0')
test_data = dataset['test']['article'][0]
tok = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
test_data = tok(test_data, return_tensors='pt', max_length=1024, truncation=True).to(device)
summary = model.generate(test_data["input_ids"], max_length=100, num_beams=4, early_stopping=True)

In [13]:
summary.to('cpu')
summary = tok.decode(summary[0], skip_special_tokens=True)
summary

" Palestinian Foreign Minister Riad al-Malki says move is move toward greater justice. The ICC opened a preliminary examination into the situation in Palestinian territories in January. The inquiry will include alleged war crimes committed since June. Israel and the United States opposed the Palestinians' efforts to join the body."