In [1]:
from metrics import WordLength, WordExact

import os

from datasets import Dataset, DatasetDict
import pandas as pd
import torch
from torchmetrics import MeanMetric
from torchmetrics.text import EditDistance
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import T5ForConditionalGeneration, T5Tokenizer, DataCollatorWithPadding
from transformers import get_linear_schedule_with_warmup
from tqdm.auto import tqdm

In [2]:
batch_size = 32
num_epochs = 3
eval_steps = 200  # run evaluate() (and possibly checkpoint) every `eval_steps` training steps
learning_rate = 1e-4
checkpoint = "t5-small"  # model id passed to from_pretrained for both tokenizer and model
run_name = "test"
run_dir = os.path.join("runs", run_name)  # TensorBoard logs and checkpoints are written under here
seed = 42

# Seed torch's global RNG so DataLoader shuffling and train-mode dropout are reproducible
# across Restart & Run All.
torch.manual_seed(seed)

In [3]:
# Pretrained tokenizer/model pair named by `checkpoint`. Building the T5Tokenizer prints the
# informational "legacy behaviour" notice shown in the cell output.
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint)
# NOTE(review): not passed to the DataLoaders (no collate_fn); inputs are already padded to
# max_length in tokenize_function, so this appears unused.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Run output layout: <run_dir>/checkpoints for save_checkpoint, <run_dir>/logs for TensorBoard.
os.makedirs(os.path.join(run_dir, "checkpoints"), exist_ok=True)
writer = SummaryWriter(log_dir=os.path.join(run_dir, "logs"))

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [4]:
def tokenize_function(example):
    """Tokenize a batch of clue/answer pairs for T5 seq2seq training.

    Args:
        example: Batched mapping (as passed by ``Dataset.map(batched=True)``) with
            "clue" and "answer" columns (presumably lists of strings).

    Returns:
        dict with "input_ids" and "attention_mask" from the clues (truncated/padded to
        96 tokens) and "labels" from the answers' input ids (truncated/padded to 32 tokens).
    """
    source = tokenizer(example["clue"], truncation=True, max_length=96, padding="max_length")
    target = tokenizer(example["answer"], truncation=True, max_length=32, padding="max_length")

    # NOTE(review): label padding keeps the tokenizer's pad id rather than -100, so padded
    # answer positions are presumably counted in the model's loss. Before switching to -100,
    # make sure decode_output can decode -100 (e.g. by mapping it back to the pad id).
    return dict(
        input_ids=source.input_ids,
        attention_mask=source.attention_mask,
        labels=target.input_ids,
    )

In [5]:
# Train split is the full CSV; eval is limited to the first 1000 rows of eval.csv.
raw_datasets = DatasetDict(
    train=Dataset.from_pandas(pd.read_csv("data/train.csv")),
    eval=Dataset.from_pandas(pd.read_csv("data/eval.csv").iloc[:1000]),
)

# Tokenize in batches and drop the raw text columns in the same pass, so only the tensors
# the model's forward() consumes (input_ids, attention_mask, labels) remain.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=["clue", "answer"],
)
tokenized_datasets.set_format("torch")

Map:   0%|          | 0/433821 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [6]:
# Only the training split is shuffled; evaluation iterates in a fixed order.
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(dataset=tokenized_datasets["eval"], batch_size=batch_size)

In [7]:
def save_checkpoint(step, model, optimizer, scheduler, best_val_loss, checkpoint_dir="checkpoints", filename="checkpoint.pth"):
    """Serialize training state to ``<run_dir>/<checkpoint_dir>/<filename>``.

    Args:
        step: Global training step at which the checkpoint is taken.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved so the schedule can resume.
        best_val_loss: Best validation loss seen so far.
        checkpoint_dir: Sub-directory of the module-level ``run_dir``; created if missing.
        filename: Checkpoint file name; a file already at that path is replaced.
    """
    checkpoint_folder = os.path.join(run_dir, checkpoint_dir)
    # Only <run_dir>/checkpoints is pre-created by the setup cell; create any other
    # directory on demand instead of failing inside torch.save.
    os.makedirs(checkpoint_folder, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_folder, filename)
    torch.save({
        'step': step,
        # Legacy key: it has always held the global *step*, not an epoch number.
        # Kept so existing loaders that read it continue to work.
        'epoch': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss
    }, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")

In [8]:
def decode_output(logits, labels):
    """Decode per-position argmax predictions and reference labels into strings.

    This is a teacher-forced decode: it takes the argmax of the logits produced for the
    given labels at every position, not an autoregressive ``generate()`` output.

    Args:
        logits: Decoder logits; argmax is taken over the last dimension.
        labels: Target token ids. Positions equal to -100 (the usual Hugging Face
            loss-ignore index) are treated as padding when decoding.

    Returns:
        ``(predictions, labels)``: two lists of decoded strings with special tokens skipped.
    """
    predicted_ids = torch.argmax(logits, dim=-1)
    # -100 is not a valid vocabulary id and cannot be decoded. Map it to the pad id, which
    # skip_special_tokens then drops. masked_fill returns a new tensor, so the caller's
    # labels (still used for the loss) are not modified.
    label_ids = labels.masked_fill(labels == -100, tokenizer.pad_token_id)
    predictions = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    references = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    return predictions, references

In [9]:
# One metric instance per split, so train and eval accumulators never share state.
# MeanMetric tracks the (batch-weighted) loss. WordExact/WordLength come from the local
# `metrics` module, and EditDistance from torchmetrics compares decoded strings.
train_loss_metric, eval_loss_metric = MeanMetric(), MeanMetric()
train_exact_metric, eval_exact_metric = WordExact(), WordExact()
train_length_metric, eval_length_metric = WordLength(), WordLength()
train_edit_metric, eval_edit_metric = EditDistance(), EditDistance()

In [10]:
def evaluate():
    """Run one full pass over ``eval_dataloader`` and return aggregated eval metrics.

    Resets the eval metric accumulators and switches ``model`` to eval mode. The model is
    left in eval mode, so callers must call ``model.train()`` before training again.

    Returns:
        dict mapping "eval/loss", "eval/length", "eval/exact" and "eval/edit" to the values
        computed over the whole eval set. The loss is averaged per example, weighted by
        batch size.
    """
    for metric in (eval_loss_metric, eval_length_metric, eval_exact_metric, eval_edit_metric):
        metric.reset()

    model.eval()
    # Evaluation needs no gradients. Without no_grad, every forward pass (which computes a
    # loss because labels are passed) would build and retain an autograd graph.
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            # Mixed precision only on CUDA. On mps/cpu this is a plain fp32 forward, which is
            # what the deprecated torch.cuda.amp.autocast() effectively gave there.
            with torch.autocast(device_type="cuda", enabled=device.type == "cuda"):
                outputs = model(**batch)
            loss = outputs.loss

            predictions, labels = decode_output(outputs.logits, batch["labels"])

            eval_loss_metric.update(loss.item(), batch["input_ids"].size(0))
            eval_length_metric.update(predictions, labels)
            eval_exact_metric.update(predictions, labels)
            eval_edit_metric.update(predictions, labels)

    return {
        "eval/loss": eval_loss_metric.compute(),
        "eval/length": eval_length_metric.compute(),
        "eval/exact": eval_exact_metric.compute(),
        "eval/edit": eval_edit_metric.compute(),
    }

In [11]:
# Prefer CUDA, then Apple MPS, then CPU. The previous selection never chose CUDA, even
# though the training code is written around CUDA mixed precision (autocast + GradScaler).
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)

num_training_steps = num_epochs * len(train_dataloader)

# Loss scaling only matters for CUDA fp16 autocast. Disabling it elsewhere makes
# scale()/step()/update() plain pass-throughs, without the "CUDA is not available" warning.
# NOTE(review): T5 checkpoints are reported to be prone to fp16 overflow. If CUDA runs
# produce NaN losses, consider bfloat16 or disabling AMP.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-8)
# Linear warmup over the first 10% of all training steps, then linear decay to 0.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(num_training_steps * 0.1),
    num_training_steps=num_training_steps,
)



In [12]:
step = 0
best_val_loss = float("inf")
try:
    for epoch in tqdm(range(num_epochs)):
        for batch in tqdm(train_dataloader, leave=False):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Periodic evaluation, run before this step's update. Keep the checkpoint with
            # the lowest eval loss seen so far.
            if (step + 1) % eval_steps == 0:
                eval_metrics = evaluate()

                for key, value in eval_metrics.items():
                    writer.add_scalar(key, value, step)

                if eval_metrics["eval/loss"] < best_val_loss:
                    best_val_loss = eval_metrics["eval/loss"]
                    save_checkpoint(step, model, optimizer, scheduler, best_val_loss)

            # evaluate() leaves the model in eval mode, so switch back every step.
            model.train()
            # Mixed precision only on CUDA. On mps/cpu this is a plain fp32 forward, which is
            # what the deprecated torch.cuda.amp.autocast() effectively gave there.
            with torch.autocast(device_type="cuda", enabled=device.type == "cuda"):
                outputs = model(**batch)
                loss = outputs.loss
            loss_value = loss.item()

            # Per-step training metrics: reset each step so every logged value covers only this batch.
            train_loss_metric.reset()
            train_length_metric.reset()
            train_edit_metric.reset()
            train_exact_metric.reset()

            predictions, labels = decode_output(outputs.logits, batch["labels"])
            train_loss_metric.update(loss_value, batch["input_ids"].size(0))
            train_length_metric.update(predictions, labels)
            train_edit_metric.update(predictions, labels)
            train_exact_metric.update(predictions, labels)

            writer.add_scalar("train/loss", loss_value, step)
            writer.add_scalar("train/length", train_length_metric.compute(), step)
            writer.add_scalar("train/exact", train_exact_metric.compute(), step)
            writer.add_scalar("train/edit", train_edit_metric.compute(), step)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Log the LR actually used for this step, before the scheduler advances it.
            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar("learning_rate", current_lr, step)

            scheduler.step()
            step += 1
finally:
    # Flush pending events and close the event file even if training is interrupted
    # (e.g. KeyboardInterrupt). Previously close() was only reached on normal completion.
    writer.close()

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/13557 [00:00<?, ?it/s]

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


KeyboardInterrupt: 