In [1]:
import evaluate
import accelerate
import numpy as np
import pandas as pd
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

MODEL_REPO = "google/mT5-small"
PREFIX = "translate English to Spanish: "

In [2]:
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
ter = evaluate.load("ter")
METRICS = [
    ("BLEU", bleu),
    ("ROUGE", rouge),
    ("METEOR", meteor),
    ("TER", ter),
]

[nltk_data] Downloading package wordnet to /home/midge/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/midge/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/midge/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [3]:
data = pd.read_csv("./data/combined.data")
train = data.loc[data["split"] != "test"]
test = data.loc[data["split"] == "test"]

In [4]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [5]:
def preprocess_text(sample):
    input = PREFIX + str(sample[0])
    target = str(sample[1])
    return tokenizer(input, text_target=target, max_length=128, truncation=True)


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

In [6]:
tokenized_train_data = [
    preprocess_text((row["en"], row["es"])) for _, row in train.iterrows()
]
tokenized_test_data = [
    preprocess_text((row["en"], row["es"])) for _, row in test.iterrows()
]

In [7]:
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer, model=MODEL_REPO, return_tensors="pt"
)

In [8]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = {}
    for name, metric in METRICS:
        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        result = {name: result["score"]}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}

    return result

In [9]:
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_REPO)

training_args = Seq2SeqTrainingArguments(
    output_dir="mt5",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_test_data,
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [10]:
trainer.train()

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss


ZeroDivisionError: float division by zero

In [11]:
trainer.save_model("mt5")

In [None]:
# import torch
# import gc

# gc.collect()
# torch.cuda.empty_cache()