In [None]:
import os
# Set up front, before the Hugging Face imports further down the notebook.
# NOTE(review): presumably turns off the tokenizers library's internal parallelism so that
# forking (DataLoader workers / datasets.map) does not trigger its warning — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm


# Choose the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Last expression of the cell, so the chosen device is displayed.
device

In [None]:
import wandb
# Authenticate with Weights & Biases before the run is started later in the notebook.
wandb.login()

In [None]:
from datasets import load_dataset, load_metric

# Dataset registered under the id "xsum"; later cells read its "document" and "summary"
# columns and its 'train' / 'validation' splits.
raw_datasets = load_dataset("xsum")
# ROUGE scorer consumed by compute_metrics.
# NOTE(review): `load_metric` is deprecated in recent `datasets` releases (metrics moved to the
# separate `evaluate` package) — confirm the installed version still provides it.
metric = load_metric("rouge")

In [None]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration, T5Config

# Fast T5 tokenizer from the "t5-small" checkpoint; shared by preprocessing, batching
# (its pad id) and metric decoding in the cells below.
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

In [None]:
# Truncation limits (in tokens) for source documents and target summaries.
max_input_length = 512
max_target_length = 128


def preprocess_function(examples):
    """Tokenize a batch of examples for sequence-to-sequence training.

    Args:
        examples: mapping with a "document" entry (source texts) and a
            "summary" entry (target texts).

    Returns:
        The tokenizer output for the documents, truncated to
        ``max_input_length``, with an added "labels" entry holding the
        input ids of the summaries truncated to ``max_target_length``.
    """
    encoded = tokenizer(
        examples["document"],
        truncation=True,
        max_length=max_input_length,
    )

    # Summaries are tokenized with the tokenizer switched to target mode.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(
            examples["summary"],
            truncation=True,
            max_length=max_target_length,
        )

    encoded["labels"] = target_encoding["input_ids"]
    return encoded

In [None]:
# Apply preprocess_function to the dataset in batches (batched=True passes lists of examples);
# the loaders below read the resulting 'input_ids' and 'labels' of each sample.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

In [None]:
class Summarization(nn.Module):
    """Training wrapper around an encoder-decoder model for summarization.

    Holds the wrapped model and provides a one-epoch training loop that logs
    the per-step loss to Weights & Biases.
    """

    def __init__(self, model):
        """Store the wrapped model.

        Args:
            model: T5 with encoder and decoder. It is called as
                ``model(input_ids, labels=..., attention_mask=...)`` and its
                output must expose a ``loss`` attribute.
        """
        super().__init__()

        self.model = model
        # Created lazily by train_one_epoch and kept between calls so the
        # dynamic loss scale is not reset at the start of every epoch.
        self._grad_scaler = None

    def summarize(self, batch):
        """Run a forward pass with labels and return the model's loss.

        Args:
            batch: mapping with 'input_ids', 'labels' and 'attention_mask'.

        Returns:
            The ``loss`` attribute of the model output.
        """
        outputs = self.model(
            batch['input_ids'],
            labels=batch['labels'],
            attention_mask=batch['attention_mask']
        )
        return outputs.loss

    @staticmethod
    def _make_grad_scaler(enabled):
        """Build a gradient scaler, preferring the newer ``torch.amp`` API when it exists."""
        amp = getattr(torch, 'amp', None)
        if amp is not None and hasattr(amp, 'GradScaler'):
            return amp.GradScaler('cuda', enabled=enabled)
        return torch.cuda.amp.GradScaler(enabled=enabled)

    def train_one_epoch(self, dataloader, optimizer):
        """Train for one pass over ``dataloader`` and log each step's loss to wandb.

        Mixed precision (float16 autocast) is used only when the module lives
        on a CUDA device, and the loss is scaled before ``backward`` so that
        small float16 gradients do not underflow to zero. On CPU the loop runs
        in full precision and the scaler is a pass-through.

        Args:
            dataloader: iterable of batches; each batch is a mapping of
                tensors accepted by ``summarize``.
            optimizer: optimizer over this module's parameters.
        """
        self.train()

        # Follow the module's own placement instead of relying on a notebook global.
        target_device = next(self.parameters()).device
        use_amp = target_device.type == 'cuda'

        scaler = self._grad_scaler
        if scaler is None or scaler.is_enabled() != use_amp:
            scaler = self._make_grad_scaler(use_amp)
            self._grad_scaler = scaler

        for batch in tqdm(dataloader):
            batch = {k: v.to(target_device) for k, v in batch.items()}

            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                loss = self.summarize(batch)

            optimizer.zero_grad()
            # When the scaler is disabled these three calls reduce to
            # loss.backward() followed by optimizer.step().
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            wandb.log({
                'loss': loss.item()
            })

In [None]:
from functools import partial


def collate_batch(pad_id, batch):
    """Pad a list of tokenized samples into one batch of tensors.

    Args:
        pad_id: value used to pad 'input_ids'; positions equal to it are
            masked out in the attention mask.
        batch: iterable of samples, each with 'input_ids' and 'labels'
            sequences of ints.

    Returns:
        dict with 'input_ids' and 'labels' as batch-first long tensors
        (labels padded with -100) and 'attention_mask', a boolean tensor that
        is True wherever 'input_ids' differs from ``pad_id``.
    """
    samples = list(batch)

    def as_long(values):
        return torch.tensor(values, dtype=torch.long)

    padded_inputs = pad_sequence(
        [as_long(sample['input_ids']) for sample in samples],
        batch_first=True,
        padding_value=pad_id,
    )
    padded_labels = pad_sequence(
        [as_long(sample['labels']) for sample in samples],
        batch_first=True,
        padding_value=-100,
    )

    return {
        'input_ids': padded_inputs,
        'labels': padded_labels,
        'attention_mask': padded_inputs.ne(pad_id),
    }


# Both loaders pad with the tokenizer's pad id, so the collate function is built once.
pad_collate = partial(collate_batch, tokenizer.pad_token_id)

sum_train_loader = torch.utils.data.DataLoader(
    tokenized_datasets['train'],
    collate_fn=pad_collate,
    batch_size=16,
    # Reshuffle every epoch so training batches are not always drawn in the
    # dataset's fixed storage order.
    shuffle=True,
)

# Validation keeps the dataset order.
sum_val_loader = torch.utils.data.DataLoader(
    tokenized_datasets['validation'],
    collate_fn=pad_collate,
    batch_size=16,
)

In [None]:
# Instantiate the architecture from the "t5-small" config only; the weights come from
# the local checkpoint loaded below.
model = T5ForConditionalGeneration(T5Config.from_pretrained('t5-small'))
# map_location='cpu' lets a checkpoint saved from a GPU be read on a machine without one;
# the model is moved to the target device afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state_dict = torch.load('your/pretrained/model.pt', map_location='cpu')
model.load_state_dict(state_dict)

In [None]:
# Wrap the loaded model in the training helper, then place it on the selected device.
summarization = Summarization(model)
summarization = summarization.to(device)

# AdamW over every parameter registered on the wrapper.
optimizer = torch.optim.AdamW(
    summarization.parameters(),
    lr=2e-5,
    weight_decay=0.01,
)

In [None]:
# Start the W&B run; Summarization.train_one_epoch reports each step's 'loss' via wandb.log.
# NOTE(review): 'project' / 'name' look like placeholders — replace with real identifiers.
wandb.init(project='project', name='name')

In [None]:
# Fine-tune for a fixed number of passes over the training loader.
NUM_EPOCHS = 2
for epoch in range(NUM_EPOCHS):
    summarization.train_one_epoch(sum_train_loader, optimizer)

In [None]:
import nltk
import numpy as np


def compute_metrics(eval_pred):
    """Decode predictions and labels, then score them with the ROUGE metric.

    Args:
        eval_pred: pair ``(predictions, labels)`` of token-id arrays with one
            row per example; label positions equal to -100 are replaced by
            the pad id before decoding.

    Returns:
        dict mapping every key returned by the metric to its mid F-measure
        scaled by 100, plus 'gen_len' (mean number of non-pad tokens per
        prediction); all values rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    pad_id = tokenizer.pad_token_id

    def one_sentence_per_line(texts):
        # Rouge expects a newline after each sentence
        return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]

    pred_texts = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # -100 cannot be decoded, so swap it for the pad id first.
    decodable_labels = np.where(labels != -100, labels, pad_id)
    label_texts = tokenizer.batch_decode(decodable_labels, skip_special_tokens=True)

    scores = metric.compute(
        predictions=one_sentence_per_line(pred_texts),
        references=one_sentence_per_line(label_texts),
        use_stemmer=True,
    )
    # Keep only the mid F-measure of each score, as a percentage.
    result = {name: score.mid.fmeasure * 100 for name, score in scores.items()}

    # Mean generated length, counted in non-pad tokens.
    result["gen_len"] = np.mean(
        [np.count_nonzero(row != pad_id) for row in predictions]
    )

    return {name: round(value, 4) for name, value in result.items()}


@torch.no_grad()
def evaluate(model, dataloader, max_length=128):
    """Generate summaries for every batch and return the ROUGE-1 score.

    Args:
        model: module exposing ``eval()``, ``parameters()`` and
            ``generate(input_ids=..., attention_mask=..., ...)``.
        dataloader: iterable of batches, each a mapping of tensors with
            'input_ids', 'attention_mask' and 'labels'.
        max_length: upper bound on the generated sequence length passed to
            ``generate``. Without it generation stops at the library default
            (20 tokens unless the model config overrides it), which cuts
            summaries short; 128 matches the target truncation length used
            during preprocessing.

    Returns:
        ROUGE-1 (as produced by ``compute_metrics``) averaged over batches,
        each batch weighted by its number of examples so that a shorter last
        batch does not count as much as a full one. NaN if the dataloader
        yields no batches.
    """
    model.eval()

    # Follow the model's own placement instead of relying on a notebook global.
    model_device = next(model.parameters()).device
    use_amp = model_device.type == 'cuda'

    batch_scores = []
    batch_sizes = []
    for batch in tqdm(dataloader):
        batch = {k: v.to(model_device) for k, v in batch.items()}

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            output_sequences = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                do_sample=False,  # disable sampling to test if batching affects output
                max_length=max_length,
            )

        scores = compute_metrics((output_sequences.cpu(), batch['labels'].cpu()))
        batch_scores.append(scores['rouge1'])
        batch_sizes.append(batch['input_ids'].shape[0])

    if not batch_scores:
        return np.nan

    return np.average(batch_scores, weights=batch_sizes)

In [None]:
# Score the wrapped model on the validation loader; the returned ROUGE-1 value is the cell output.
evaluate(summarization.model, sum_val_loader)