In [None]:
import os
# Set before any tokenizer is created. Presumably this silences the Hugging Face
# tokenizers parallelism/fork warning - TODO confirm against the tokenizers docs.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm


# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression: the cell output shows which device was selected.
device

In [None]:
import wandb
# Log in to Weights & Biases up front; the run itself is created later with
# wandb.init and the training loop reports the loss through wandb.log.
wandb.login()

In [None]:
from datasets import load_dataset, load_metric

# WMT16 German-English configuration; the splits are accessed by name below.
raw_datasets = load_dataset("wmt16", "de-en")
# Keep only the first 1% of the training split (by index); other splits are untouched.
raw_datasets['train'] = raw_datasets['train'].select(range(int(len(raw_datasets['train']) * 0.01)))
# sacreBLEU metric object consumed by `compute_metrics`.
# NOTE(review): `load_metric` has been deprecated upstream in favour of the separate
# `evaluate` package - confirm against the pinned `datasets` version before upgrading.
metric = load_metric("sacrebleu")

In [None]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration, T5Config

# Tokenizer of the `t5-small` checkpoint; the same instance is used for the source
# sentences, the target sentences and for decoding generated ids.
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

In [None]:
# Token budgets: `preprocess_function` truncates longer source / target sentences.
max_input_length = 128
max_target_length = 128
# Keys into each example's "translation" dict: translate English -> German.
source_lang = "en"
target_lang = "de"

def preprocess_function(examples):
    """Tokenize one batch of translation pairs.

    Args:
        examples: batch with a "translation" entry, a sequence of dicts that
            each hold the sentence under `source_lang` and `target_lang`.

    Returns:
        The tokenizer output for the source sentences, extended with a
        "labels" entry containing the token ids of the target sentences.
        Both sides are truncated to their respective maximum lengths.
    """
    pairs = examples["translation"]
    source_texts = [pair[source_lang] for pair in pairs]
    target_texts = [pair[target_lang] for pair in pairs]

    encoded = tokenizer(source_texts, max_length=max_input_length, truncation=True)

    # Targets are tokenized inside the tokenizer's target-side context.
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(target_texts, max_length=max_target_length, truncation=True)

    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

In [None]:
# Apply `preprocess_function` batch-wise to every split; each example gains the
# tokenizer's output fields plus `labels`.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

In [None]:
class Translation(nn.Module):
    """Training wrapper around a seq2seq model that returns a loss when given labels.

    Args:
        model: encoder-decoder model (T5 in this notebook). It is called as
            `model(input_ids, labels=..., attention_mask=...)` and must return
            an object exposing a `.loss` attribute.
    """

    def __init__(self, model):
        super().__init__()

        self.model = model
        # Created lazily in `train_one_epoch`, once the device of the parameters
        # is known, and kept so the loss scale carries over between epochs.
        self._grad_scaler = None

    def translate(self, batch):
        """Run a teacher-forced forward pass and return the model's loss.

        Args:
            batch: mapping with 'input_ids', 'labels' and 'attention_mask'.

        Returns:
            The `.loss` of the wrapped model's output.
        """
        outputs = self.model(
            batch['input_ids'],
            labels=batch['labels'],
            attention_mask=batch['attention_mask'],
        )

        return outputs.loss

    def train_one_epoch(self, dataloader, optimizer):
        """Train for one pass over `dataloader`, logging the loss of every step to wandb.

        Mixed precision (fp16 autocast) is used only when the parameters live on
        a CUDA device, and it is paired with a gradient scaler: without loss
        scaling small fp16 gradients can underflow to zero, and steps whose
        gradients contain inf/NaN would otherwise be applied to the weights.

        Args:
            dataloader: iterable of batches (mappings of tensors) as expected by
                `translate`.
            optimizer: optimizer over this module's parameters.
        """
        self.train()

        # Follow the module's own parameters instead of a notebook-level global.
        model_device = next(self.parameters()).device
        use_amp = model_device.type == 'cuda'

        scaler = getattr(self, '_grad_scaler', None)
        if scaler is None or scaler.is_enabled() != use_amp:
            # A disabled scaler is a transparent pass-through, so the loop below
            # is identical for CPU and GPU runs.
            scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
            self._grad_scaler = scaler

        for batch in tqdm(dataloader):
            batch = {key: value.to(model_device) for key, value in batch.items()}

            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                loss = self.translate(batch)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            # Skips the optimizer step when the unscaled gradients are not finite.
            scaler.step(optimizer)
            scaler.update()

            wandb.log({
                'loss': loss.item()
            })

In [None]:
from functools import partial


def collate_batch(pad_id, batch):
    """Pad a list of tokenized samples into rectangular batch tensors.

    Args:
        pad_id: token id used to right-pad `input_ids`; positions holding this
            id are marked as masked-out in the attention mask.
        batch: samples, each providing 'input_ids' and 'labels' as sequences
            of ints.

    Returns:
        Dict with 'input_ids' (long, padded with `pad_id`), 'labels' (long,
        padded with -100) and 'attention_mask' (bool, True where the input id
        differs from `pad_id`).
    """
    samples = list(batch)

    def _column(key):
        # One 1-D long tensor per sample for the requested field.
        return [torch.tensor(sample[key], dtype=torch.long) for sample in samples]

    padded_inputs = pad_sequence(_column('input_ids'), batch_first=True, padding_value=pad_id)
    padded_labels = pad_sequence(_column('labels'), batch_first=True, padding_value=-100)

    return {
        'input_ids': padded_inputs,
        'labels': padded_labels,
        'attention_mask': padded_inputs.ne(pad_id),
    }


# Shared by both loaders so the two definitions cannot drift apart.
BATCH_SIZE = 8
SHUFFLE_SEED = 0
pad_collate = partial(collate_batch, tokenizer.pad_token_id)

# Training batches are reshuffled every epoch; otherwise each epoch would replay
# the corpus in its stored order. The seeded generator keeps the shuffling
# reproducible across Restart & Run All.
translation_train_loader = torch.utils.data.DataLoader(
    tokenized_datasets['train'],
    collate_fn=pad_collate,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Validation keeps the dataset order; shuffling would not change the metric.
translation_val_loader = torch.utils.data.DataLoader(
    tokenized_datasets['validation'],
    collate_fn=pad_collate,
    batch_size=BATCH_SIZE,
)

In [None]:
# Build the architecture from the `t5-small` configuration only (no pretrained
# weights are fetched here); the weights come from the local checkpoint below.
model = T5ForConditionalGeneration(T5Config.from_pretrained('t5-small'))
# map_location='cpu': a checkpoint saved from a GPU would otherwise fail to load
# on a CPU-only machine. The model is moved to `device` later, when it is wrapped.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load('your/pretrained/model.pt', map_location='cpu'))

In [None]:
# Wrap the model for training and move it to the selected device.
translation = Translation(model).to(device)
# The wrapper holds no parameters besides the wrapped model's, so this optimizes the model itself.
optimizer = torch.optim.AdamW(translation.parameters(), lr=2e-5, weight_decay=0.01)

In [None]:
# Start the W&B run that receives the per-step losses logged during training.
# 'project' / 'name' look like placeholders - replace them before a real run.
wandb.init(project='project', name='name')

In [None]:
# Fine-tune for two full passes over the training loader; the loss of every
# step is logged to wandb from inside `train_one_epoch`.
for epoch in range(2):
    translation.train_one_epoch(translation_train_loader, optimizer)

In [None]:
import numpy as np


def postprocess_text(preds, labels):
    """Strip surrounding whitespace and shape the texts for the BLEU metric.

    Args:
        preds: decoded hypothesis strings.
        labels: decoded reference strings, one per hypothesis.

    Returns:
        Tuple `(preds, labels)`: the stripped hypotheses as a flat list, and
        the stripped references with each one wrapped in its own
        single-element list.
    """
    cleaned_preds = []
    for hypothesis in preds:
        cleaned_preds.append(hypothesis.strip())

    wrapped_refs = []
    for reference in labels:
        wrapped_refs.append([reference.strip()])

    return cleaned_preds, wrapped_refs

def compute_metrics(eval_preds):
    """Decode predictions and labels and score them with the BLEU metric.

    Args:
        eval_preds: pair `(preds, labels)` of token-id arrays. `preds` may
            also be a tuple, in which case its first element is used. Label
            positions equal to -100 are treated as padding.

    Returns:
        Dict with "bleu" (the metric's score) and "gen_len" (mean number of
        non-pad tokens per prediction), both rounded to 4 decimals.
    """
    generated, references = eval_preds
    generated = generated[0] if isinstance(generated, tuple) else generated

    hypothesis_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)

    # -100 cannot be decoded, so swap it for the pad token before decoding.
    references = np.where(references != -100, references, tokenizer.pad_token_id)
    reference_texts = tokenizer.batch_decode(references, skip_special_tokens=True)

    hypothesis_texts, reference_texts = postprocess_text(hypothesis_texts, reference_texts)

    bleu = metric.compute(predictions=hypothesis_texts, references=reference_texts)
    scores = {"bleu": bleu["score"]}

    # Length of a prediction = number of tokens that are not padding.
    lengths = [np.count_nonzero(sequence != tokenizer.pad_token_id) for sequence in generated]
    scores["gen_len"] = np.mean(lengths)

    return {name: round(value, 4) for name, value in scores.items()}


@torch.no_grad()
def evaluate(model, dataloader, max_length=None):
    """Greedy-decode every batch of `dataloader` and return the corpus-level BLEU.

    BLEU is a corpus-level statistic, so all hypotheses and references are
    collected first and scored in a single `compute_metrics` call. Averaging
    per-batch scores would not give the corpus score and would weight a short
    final batch like a full one.

    Args:
        model: model exposing `generate(input_ids=..., attention_mask=...)`.
        dataloader: iterable of batches with 'input_ids', 'attention_mask'
            and 'labels' tensors (labels padded with -100).
        max_length: maximum length of the generated sequences. Defaults to the
            notebook-level `max_target_length`, i.e. the same budget the
            references were truncated to. Without an explicit limit `generate`
            falls back to the model's configured default (presumably far
            shorter than the references - confirm for the checkpoint in use).

    Returns:
        The "bleu" value reported by `compute_metrics` for the whole dataset,
        or NaN when the dataloader yields no batches.
    """
    model.eval()

    if max_length is None:
        max_length = max_target_length

    # Follow the model's own parameters instead of the notebook-level global.
    model_device = next(model.parameters()).device
    use_amp = model_device.type == 'cuda'
    pad_id = tokenizer.pad_token_id

    all_preds = []
    all_labels = []
    for batch in tqdm(dataloader):
        batch = {key: value.to(model_device) for key, value in batch.items()}

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            output_sequences = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                do_sample=False,  # greedy decoding keeps the evaluation deterministic
                max_length=max_length,
            )

        # Batches differ in width, so keep one 1-D tensor per sentence for now.
        all_preds.extend(output_sequences.cpu().unbind(0))
        all_labels.extend(batch['labels'].cpu().unbind(0))

    if not all_preds:
        return float('nan')

    # Re-pad to a common width using the values `compute_metrics` treats as padding.
    preds = pad_sequence(all_preds, batch_first=True, padding_value=pad_id)
    labels = pad_sequence(all_labels, batch_first=True, padding_value=-100)

    return compute_metrics((preds, labels))['bleu']

In [None]:
# Score the fine-tuned model on the validation loader; the cell output shows the
# BLEU value returned by `evaluate`.
evaluate(translation.model, translation_val_loader)