# mT5 fine-tuned for extractive question answering (start/end span prediction with `MT5ForQuestionAnswering`)

In [None]:
import polars as pl
import os

from transformers import (
    MT5Tokenizer,
    MT5ForQuestionAnswering,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset, load_dataset
import torch

In [None]:
# Pretrained multilingual T5 checkpoint used throughout the notebook.
model_checkpoint = "google/mt5-small"

# MT5Tokenizer is the slow (SentencePiece) tokenizer class; `use_fast=False`
# appears redundant for it but is kept unchanged.
tokenizer = MT5Tokenizer.from_pretrained(
    model_checkpoint,
    use_fast=False,
)

# mT5 with a span-prediction head: its forward pass returns `start_logits` and
# `end_logits` over the input tokens.
model = MT5ForQuestionAnswering.from_pretrained(
    model_checkpoint,
)

In [None]:
# Smoke test of the (not yet fine-tuned) QA head on a toy question/context pair.
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"

inputs = tokenizer(question, text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Greedy span decoding: the most likely start and end token positions.
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()

predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
# Print explicitly: only the last expression of a notebook cell is echoed, so a
# bare decode() here would be discarded.
print(tokenizer.decode(predict_answer_tokens, skip_special_tokens=True))

# target is "nice puppet". Locate its token span in the encoded pair rather than
# hard-coding indices: token positions depend on the tokenizer, and mT5's
# SentencePiece vocabulary has no [CLS]/[SEP] tokens.
target_ids = tokenizer("nice puppet", add_special_tokens=False).input_ids
sequence_ids = inputs.input_ids[0].tolist()
span_len = len(target_ids)
span_starts = [
    i
    for i in range(len(sequence_ids) - span_len + 1)
    if sequence_ids[i : i + span_len] == target_ids
]
if not span_starts:
    raise ValueError("Target answer tokens were not found in the encoded input")
# Take the last match: the answer belongs to the context, which follows the question.
target_start_index = torch.tensor([span_starts[-1]])
target_end_index = torch.tensor([span_starts[-1] + span_len - 1])

outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
loss = outputs.loss
round(loss.item(), 2)

In [None]:
# Load the coastalcph/tydi_xor_rc dataset and convert both splits to Polars.
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train = dataset["train"].to_polars()
df_val = dataset["validation"].to_polars()


def filter_language(df: pl.DataFrame, lang: str) -> pl.DataFrame:
    """Return the rows of ``df`` in language *lang* that have an in-language answer.

    Both predicates passed to ``DataFrame.filter`` must hold (they are ANDed):
    ``lang == lang`` and ``answer_inlang`` is not null.
    """
    return df.filter(pl.col("lang") == lang, pl.col("answer_inlang").is_not_null())


# Only Telugu ("te") is selected; pass another language code to
# `filter_language` to build a different subset.
# NOTE(review): this subset was previously described as "Arabic, Telugu and
# Korean", but only Telugu is filtered — confirm the intended languages.
df_te_train = filter_language(df_train, "te")
df_te_val = filter_language(df_val, "te")
df_te_train.head()

In [None]:
# Data preparation: Polars -> Hugging Face Dataset
def prepare_data(df: pl.DataFrame) -> Dataset:
    """Build a Hugging Face ``Dataset`` from a Polars frame.

    The output has three columns, in this order: ``question`` and ``context``
    copied unchanged, and ``answers`` taken from the frame's ``answer_inlang``
    column. Only raw column values are copied; no answer positions are computed.
    """
    # Output column name -> source Polars column name.
    source_columns = {
        "question": "question",
        "context": "context",
        "answers": "answer_inlang",
    }
    return Dataset.from_dict(
        {target: df[source].to_list() for target, source in source_columns.items()}
    )


# Tokenization function
def tokenize_function(
    examples: dict[str, list[str]],
    # Quoted so the def does not need the tokenizer class at runtime; the
    # previous annotation named AutoTokenizer, which is never imported and
    # raised NameError when this cell ran.
    tokenizer: "MT5Tokenizer",
    max_length: int = 512,
):
    """Tokenize a batch of (question, context) pairs for ``Dataset.map(batched=True)``.

    Args:
        examples: Batch mapping column names to lists of values; must provide
            ``"question"`` and ``"context"``.
        tokenizer: Callable tokenizer applied to the pair batch (the notebook
            passes an ``MT5Tokenizer``).
        max_length: Length every sequence is truncated/padded to (default 512).

    Returns:
        The tokenizer's output for the pair batch, requested as PyTorch tensors.

    NOTE(review): no ``start_positions``/``end_positions`` are produced, so the
    span-prediction head has no labels to compute a training loss from.
    Confirm how labels are meant to be added before calling ``Trainer.train()``.
    """
    # mT5 has no [CLS]/[SEP] tokens; for a text pair its tokenizer appends the
    # EOS token </s> after each segment.
    return tokenizer(
        examples["question"],
        examples["context"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )

In [None]:
def train_qa_mt5(
    tokenized_train: Dataset,
    tokenized_val: Dataset,
    model_checkpoint: str = "google/mt5-small",
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> tuple[MT5ForQuestionAnswering, MT5Tokenizer]:
    """Fine-tune an mT5 span-prediction QA model with the Hugging Face ``Trainer``.

    Args:
        tokenized_train: Tokenized training split.
        tokenized_val: Tokenized validation split, evaluated once per epoch.
        model_checkpoint: Checkpoint the model and tokenizer are loaded from.
        device: Device the model is moved to before training. fp16 mixed
            precision is enabled only when this is a CUDA device, because fp16
            training is not available on CPU.

    Returns:
        ``(model, tokenizer)``: the trained model object (presumably holding
        the best epoch's weights, since ``load_best_model_at_end=True``) and the
        tokenizer loaded from the same checkpoint.

    NOTE(review): the datasets must carry ``start_positions``/``end_positions``
    columns. Without them the QA head returns no loss and ``Trainer.train()``
    fails. Checkpoints are written to ``./results`` every epoch.
    """
    # The CUDA caching allocator reads this variable when it initializes, so it
    # must be set before the first CUDA allocation (the `.to(device)` below).
    # It has no effect if CUDA was already used earlier in this process.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    print(f"Environment variable set: {os.environ['PYTORCH_CUDA_ALLOC_CONF']}")

    use_cuda = device.type == "cuda"

    # Load model
    qa_generator = MT5ForQuestionAnswering.from_pretrained(model_checkpoint).to(device)
    # Load tokenizer (mostly for saving complete model later)
    tokenizer = MT5Tokenizer.from_pretrained(model_checkpoint)

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",
        learning_rate=2e-5,
        num_train_epochs=3,
        # Regularization
        weight_decay=0.01,
        # Memory settings
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        # fp16 requires a GPU; enabling it unconditionally breaks CPU runs.
        # NOTE(review): mT5 is reported to produce NaN losses under fp16;
        # consider bf16 on GPUs that support it.
        fp16=use_cuda,
        # Evaluation
        per_device_eval_batch_size=8,
        # Save and eval strategies must match for load_best_model_at_end.
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    # Trainer
    trainer = Trainer(
        model=qa_generator,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
    )
    # Release cached CUDA blocks before training (no-op if CUDA is unused).
    torch.cuda.empty_cache()
    # Train; checkpoints are saved by the Trainer according to save_strategy.
    print("Training mT5 question-answering model...")
    trainer.train()

    return qa_generator, tokenizer

In [None]:
# Build Hugging Face datasets from the Telugu splits and tokenize them in
# batches; `fn_kwargs` forwards the notebook tokenizer to `tokenize_function`.
train_dataset = prepare_data(df_te_train)
val_dataset = prepare_data(df_te_val)
tokenized_train = train_dataset.map(
    tokenize_function, batched=True, fn_kwargs={"tokenizer": tokenizer}
)
tokenized_val = val_dataset.map(
    tokenize_function, batched=True, fn_kwargs={"tokenizer": tokenizer}
)

# Fine-tune; this rebinds the global `model` and `tokenizer` to the returned pair.
model, tokenizer = train_qa_mt5(tokenized_train, tokenized_val)