# mT5 fine-tuned for generative question answering

In [None]:
import polars as pl
import os
import gc

from transformers import (
    MT5Tokenizer,
    MT5ForQuestionAnswering,
    TrainingArguments,
    Trainer,
)
from datasets import Dataset, load_dataset
import torch

In [47]:
# Hugging Face checkpoint name shared by the tokenizer and model loads below.
model_checkpoint = "google/mt5-small"
# Tokenizer for the checkpoint; use_fast=False requests the slow (non-Rust) implementation.
tokenizer = MT5Tokenizer.from_pretrained(model_checkpoint, use_fast=False)
# Model with a question-answering head. Per the load warning printed below, the
# `qa_outputs` weights are newly initialised, so it must be fine-tuned before use.
model = MT5ForQuestionAnswering.from_pretrained(model_checkpoint)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.mt5.tokenization_mt5.MT5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Some weights of MT5ForQuestionAnswering were not initialized from the model checkpoint at google/mt5-small and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use

In [48]:
# Download the coastalcph/tydi_xor_rc dataset and convert both splits to Polars frames
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train, df_val = (dataset[split].to_polars() for split in ("train", "validation"))

# Keep only the Telugu ("te") rows that carry an in-language answer
df_te_train = df_train.filter((pl.col("lang") == "te") & pl.col("answer_inlang").is_not_null())
df_te_val = df_val.filter((pl.col("lang") == "te") & pl.col("answer_inlang").is_not_null())
df_te_train.head()

question,context,lang,answerable,answer_start,answer,answer_inlang
str,str,str,bool,i64,str,str
"""1990 నాటికి ఆఫ్రికాలో అతిపెద్ద…","""various archipelagos. It conta…","""te""",False,-1,"""Nigeria""","""నైజీరియా"""
"""2010 నాటికీ వ్యవసాయ రంగంలో చైన…","""A country with In [[2010]] Chi…","""te""",False,-1,"""the first""","""ప్రధమ"""
"""2011 నాటికి గొరిగపూడి గ్రామ జన…","""Gorigapudi is a village belong…","""te""",True,306,"""2229""","""2229"""
"""2011 నాటికి పెద యాచవరం గ్రామ జ…","""Peda Yachavaram is a village i…","""te""",True,247,"""4610""","""4610"""
"""ఆంధ్రప్రదేశ్ లో మొదటగా ఏ ఇంజనీ…","""Andhra University College of E…","""te""",False,-1,"""Velagapudi Ramakrishna Siddhar…","""వెలగపుడి రామకృష్ణ సిద్ధార్థ ఇం…"


In [49]:
# Prepare your data
def prepare_data(df: pl.DataFrame) -> Dataset:
    """Build a Hugging Face ``Dataset`` from the QA columns of a Polars frame.

    Args:
        df: Frame providing the "question", "context" and "answer_inlang" columns.

    Returns:
        A ``Dataset`` with "question", "context" and "answers" columns, where
        "answers" holds the values of the frame's "answer_inlang" column.
    """
    # Output column name -> source column name in the Polars frame
    column_sources = {
        "question": "question",
        "context": "context",
        "answers": "answer_inlang",
    }
    return Dataset.from_dict(
        {target: df[source].to_list() for target, source in column_sources.items()}
    )


# Tokenization function
def tokenize_function(examples: dict[str, list[str]], tokenizer: "MT5Tokenizer"):
    """Tokenize a batch of (question, context) pairs to fixed-length model inputs.

    Args:
        examples: Batch mapping with "question" and "context" entries, as handed
            over by ``Dataset.map(..., batched=True)``.
        tokenizer: Tokenizer (any callable with the Hugging Face tokenizer call
            signature) applied to the question/context pairs.

    Returns:
        Whatever ``tokenizer`` returns for the batch: every pair is truncated
        and padded to exactly 512 tokens and returned as PyTorch tensors.
    """
    # The annotation is a string on purpose: the previous `AutoTokenizer`
    # annotation referenced a name this notebook never imports, which raised
    # NameError when the cell was run on a fresh kernel.
    #
    # Question and context are passed as a sequence pair; the tokenizer inserts
    # whatever special tokens it defines between/after them (mT5 has no
    # BERT-style [CLS]/[SEP] tokens).
    #
    # NOTE(review): only model inputs are produced here — the "answers" column
    # is not turned into training targets. Confirm the model/Trainer receives
    # the labels it needs to compute a loss.
    return tokenizer(
        examples["question"],
        examples["context"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )

In [56]:
def train_qa_mt5(
    tokenized_train: Dataset,
    tokenized_val: Dataset,
    model_checkpoint: str = "google/mt5-small",
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> tuple[MT5ForQuestionAnswering, MT5Tokenizer]:
    """Fine-tune an mT5 question-answering model with the Hugging Face ``Trainer``.

    Args:
        tokenized_train: Tokenized training dataset passed to the ``Trainer``.
        tokenized_val: Tokenized dataset evaluated at the end of every epoch.
        model_checkpoint: Hugging Face checkpoint to load model and tokenizer from.
        device: Device the model is moved to after loading. Mixed-precision
            (fp16) training is only enabled when this is a CUDA device.

    Returns:
        ``(model, tokenizer)`` — the trained model object and the tokenizer
        loaded from ``model_checkpoint``.

    NOTE(review): ``MT5ForQuestionAnswering`` carries a ``qa_outputs`` head;
    confirm the datasets passed in contain the target columns this head needs
    (presumably start/end positions), otherwise no loss can be computed.
    """
    device = torch.device(device)
    use_cuda = device.type == "cuda"

    # Configure the CUDA caching allocator *before* the model is moved to the
    # GPU: the setting is read when the allocator initialises, so setting it
    # just before trainer.train() (after .to(device)) came too late.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    print(f"Environment variable set: {os.environ['PYTORCH_CUDA_ALLOC_CONF']}")

    # Load model
    qa_generator = MT5ForQuestionAnswering.from_pretrained(model_checkpoint).to(device)
    # Load tokenizer (returned alongside the model so both can be saved together later)
    tokenizer = MT5Tokenizer.from_pretrained(model_checkpoint)

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",
        learning_rate=2e-5,
        num_train_epochs=3,
        # Regularization
        weight_decay=0.01,
        # Memory settings
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        # fp16 mixed precision needs a GPU; requesting it unconditionally broke
        # the CPU fallback offered by the `device` default.
        # NOTE(review): T5-family models are reported to overflow in fp16 —
        # consider bf16 on hardware that supports it.
        fp16=use_cuda,
        # Evaluation
        per_device_eval_batch_size=8,
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    # Trainer
    trainer = Trainer(
        model=qa_generator,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
    )

    # Release unreferenced Python objects and cached GPU blocks before training
    gc.collect()
    if use_cuda:
        torch.cuda.empty_cache()

    # Train; checkpoints go to `output_dir` after every epoch (save_strategy="epoch")
    print("Training mT5 question-answering model...")
    trainer.train()

    return qa_generator, tokenizer

In [57]:
# Wrap the Telugu Polars frames as Hugging Face datasets
train_dataset = prepare_data(df_te_train)
val_dataset = prepare_data(df_te_val)

# Tokenize batch-wise; the tokenizer is forwarded to tokenize_function as a keyword argument
tokenized_train = train_dataset.map(
    tokenize_function, batched=True, fn_kwargs={"tokenizer": tokenizer}
)
tokenized_val = val_dataset.map(
    tokenize_function, batched=True, fn_kwargs={"tokenizer": tokenizer}
)

# Fine-tune; this rebinds `model` and `tokenizer` to the objects returned by training
model, tokenizer = train_qa_mt5(tokenized_train, tokenized_val)

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Some weights of MT5ForQuestionAnswering were not initialized from the model checkpoint at google/mt5-small and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.


NameError: name 'gc' is not defined