# MT5 Fine-tuning for Question Answering

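This notebook fine-tunes the MT5 model (`google/mt5-small`) for question answering on the Telugu (`te`) subset of the multilingual `coastalcph/tydi_xor_rc` dataset, then spot-checks the fine-tuned model on a validation example.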

In [1]:
import os

import evaluate
import numpy as np
import polars as pl
import torch
from datasets import Dataset, load_dataset
from transformers import MT5ForConditionalGeneration, T5Tokenizer
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer

In [2]:
# Fetch both splits of the dataset and view them as polars DataFrames.
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train, df_val = (dataset[split].to_polars() for split in ("train", "validation"))

# Keep only rows tagged "te" that carry a non-null in-language answer.
_te_with_answer = (pl.col("lang") == "te") & pl.col("answer_inlang").is_not_null()
df_te_train = df_train.filter(_te_with_answer)
df_te_val = df_val.filter(_te_with_answer)
df_te_train.head()

question,context,lang,answerable,answer_start,answer,answer_inlang
str,str,str,bool,i64,str,str
"""1990 నాటికి ఆఫ్రికాలో అతిపెద్ద…","""various archipelagos. It conta…","""te""",False,-1,"""Nigeria""","""నైజీరియా"""
"""2010 నాటికీ వ్యవసాయ రంగంలో చైన…","""A country with In [[2010]] Chi…","""te""",False,-1,"""the first""","""ప్రధమ"""
"""2011 నాటికి గొరిగపూడి గ్రామ జన…","""Gorigapudi is a village belong…","""te""",True,306,"""2229""","""2229"""
"""2011 నాటికి పెద యాచవరం గ్రామ జ…","""Peda Yachavaram is a village i…","""te""",True,247,"""4610""","""4610"""
"""ఆంధ్రప్రదేశ్ లో మొదటగా ఏ ఇంజనీ…","""Andhra University College of E…","""te""",False,-1,"""Velagapudi Ramakrishna Siddhar…","""వెలగపుడి రామకృష్ణ సిద్ధార్థ ఇం…"


In [3]:
# Choose the compute device: CUDA wins over Apple MPS, which wins over CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print(f'Using device: {device}')

Using device: cuda


In [5]:
# Resume from the locally saved fine-tuned checkpoint when it exists;
# otherwise start from the pretrained "google/mt5-small" weights.
mt5_telugu_path = os.path.join("results", "mt5-telugu-qa")
mt5_fine_tuned_path = os.path.join(mt5_telugu_path, "fine_tuned")

if os.path.exists(mt5_fine_tuned_path):
    print("Loading model from disk")
    mt5_source = mt5_fine_tuned_path
else:
    print("Loading model from Huggingface")
    mt5_source = "google/mt5-small"

# Tokenizer and model always come from the same source so they stay in sync.
mt5_tokenizer = T5Tokenizer.from_pretrained(mt5_source)
# `.to(device)` is chained into the assignment so the cell does not echo the model repr.
mt5_model = MT5ForConditionalGeneration.from_pretrained(mt5_source).to(device)

Loading model from disk


In [6]:
#from transformers import MT5ForConditionalGeneration, T5Tokenizer
#mt5_tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
#mt5_model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
#mt5_model.to(device)

In [7]:
#from transformers import T5ForConditionalGeneration, T5Tokenizer
#t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
#t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")

In [8]:
#from transformers import MBartForConditionalGeneration, AutoTokenizer
#mbart_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
#mbart_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

In [9]:
#from transformers import AutoTokenizer, AutoModelForCausalLM
#gpt2_tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
#gpt2_model = AutoModelForCausalLM.from_pretrained('distilgpt2')

In [10]:
def answer_question(promt: str, model, tokenizer, **kwargs) -> str:
    """Generate an answer for a single prompt with a seq2seq model.

    Args:
        promt: Prompt text fed to the tokenizer. The (misspelled) parameter
            name is kept so existing keyword callers keep working.
        model: Model exposing ``generate(**inputs)`` and a ``device`` attribute.
        tokenizer: Callable tokenizer whose result supports ``.to(device)``
            and can be unpacked as keyword arguments; must provide ``decode``.
        **kwargs: Extra options forwarded to ``model.generate``.
            ``max_new_tokens`` defaults to 64 and may be overridden here.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Move the inputs to the model's own device rather than a notebook-level
    # global, so the helper works for any model regardless of cell state.
    input_tokens = tokenizer(
        promt, return_tensors="pt", truncation=True, max_length=512
    ).to(model.device)

    # Merge the default first so a caller-supplied max_new_tokens wins instead
    # of raising a duplicate-keyword TypeError.
    generation_kwargs = {"max_new_tokens": 64, **kwargs}
    generated_tokens = model.generate(**input_tokens, **generation_kwargs)

    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [11]:
# Toy English question/context pair used for a quick sanity check.
question, context = (
    "What is the best way to fall asleep?",
    "Do not drink a white monster right before going to bed",
)

In [12]:
# Run the toy example through the loaded MT5 model; the returned string is the cell output.
answer_question(
    f"question: {question}  context: {context}",
    mt5_model,
    mt5_tokenizer,
    early_stopping=False,
)

'<extra_id_0>?'

In [13]:
#answer_question(f"question: {question}  context: {context}", t5_model, t5_tokenizer, early_stopping=True)

In [14]:
#question = "What is the best way to fall asleep?"
#context = "Do not drink a white monster right before going to bed"
#answer_question(f"Answer the following question based on the given context. \n Context: {context} \n Question: {question}", mbart_model, mbart_tokenizer, 
#    forced_bos_token_id=mbart_tokenizer.lang_code_to_id["en_XX"],
#    early_stopping=True,
#    num_beams=4,
#    length_penalty=1.2,
#    temperature=0.7,
#    top_p=0.92,
#    top_k=50,
#    repetition_penalty=1.3,
#    )

In [15]:
#answer_question(f"Answer the following question based on the given context. \n Context: {context} \n Question: {question}", gpt2_model, gpt2_tokenizer, 
#    early_stopping=True,
#    temperature=0.7,
#    top_p=0.92,  
#    top_k=50,
#    repetition_penalty=1.2,
#    no_repeat_ngram_size=2
#    )

In [16]:
def generate_prompts(df: pl.DataFrame):
    """Return ``df`` with an extra ``prompt`` column that copies ``question``.

    Only the question text is used as the prompt; the ``context`` column is
    not included. (A disabled earlier variant built
    ``"Question: ...\\n Context: ..."`` prompts instead.)
    """
    prompt_expr = pl.col("question").alias("prompt")
    return df.with_columns(prompt_expr)

def tokenize_to_dataset(
    df: pl.DataFrame,
    tokenizer,
    question_col: str = "prompt",
    answer_col: str = "answer_inlang",
    max_input_length: int = 512,
    max_target_length: int = 64,
):
    """Tokenize prompt/answer columns of ``df`` into a Hugging Face ``Dataset``.

    Args:
        df: Frame holding the text columns.
        tokenizer: Tokenizer applied to both columns; must expose ``pad_token_id``.
        question_col: Column with the model inputs.
        answer_col: Column with the target answers.
        max_input_length: Inputs are truncated/padded to this many tokens.
        max_target_length: Targets are truncated/padded to this many tokens.

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``,
        where padding positions in ``labels`` are set to -100.
    """
    inputs = tokenizer(
        df[question_col].to_list(),
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_input_length,
    )
    targets = tokenizer(
        df[answer_col].to_list(),
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_target_length,
    )

    # Mark padded target positions with -100 (the value later treated as
    # "not decodable" in compute_metrics) instead of the pad token id.
    labels = targets["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100

    tensors = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }

    # Dataset.from_dict is fed numpy arrays rather than torch tensors.
    return Dataset.from_dict({name: tensor.numpy() for name, tensor in tensors.items()})

In [17]:
# Add the prompt column to both Telugu splits, then tokenize each split
# with the MT5 tokenizer.
df_te_train_prompt, df_te_val_prompt = (
    generate_prompts(df_te_train),
    generate_prompts(df_te_val),
)

train_dataset, val_dataset = (
    tokenize_to_dataset(prompt_df, mt5_tokenizer)
    for prompt_df in (df_te_train_prompt, df_te_val_prompt)
)

In [None]:
# https://huggingface.co/learn/llm-course/chapter7/4?fw=pt

# BLEU scorer used by compute_metrics.
metric = evaluate.load("sacrebleu")
epochs = 16

training_args = Seq2SeqTrainingArguments(
    output_dir=mt5_telugu_path,
    overwrite_output_dir=True,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=6,
    weight_decay=0.01,
    num_train_epochs=epochs,
    predict_with_generate=True,
    fp16=False,

    # Checkpointing. `save_strategy` must be passed exactly once (repeating a
    # keyword is a SyntaxError). "best" is kept because it pairs with
    # load_best_model_at_end; the metric that defines "best" is spelled out
    # explicitly rather than left to the library default.
    save_total_limit=5,
    save_strategy="best",
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,

    logging_strategy="epoch",
    eval_strategy="steps",
    # NOTE: the unit here is optimizer steps, not epochs (16 * 5 = every 80 steps).
    eval_steps=epochs * 5,
    log_level="info",
    report_to=[],
    logging_dir=None,
)

# Batch collator built from the MT5 tokenizer and model; handed to the trainer.
data_collator = DataCollatorForSeq2Seq(
    mt5_tokenizer,
    model=mt5_model,
)

def compute_metrics(eval_preds, tokenizer=mt5_tokenizer):
    """Compute the sacreBLEU score for a batch of generated predictions.

    Args:
        eval_preds: Pair ``(predictions, labels)`` of token-id arrays.
            ``predictions`` may itself be a tuple, in which case its first
            element is used.
        tokenizer: Tokenizer used to decode ids back to text.

    Returns:
        ``{"bleu": <score>}`` taken from the ``score`` field of the metric result.
    """
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 cannot be decoded. Labels carry it for padded positions, and the
    # generated ids can carry it too when batches of different lengths are
    # padded together, so both arrays are mapped back to the pad token.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing; each prediction gets a list of references.
    predictions = [pred.strip() for pred in decoded_preds]
    references = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=predictions, references=references)
    return {"bleu": result["score"]}



# Wire the model, arguments, datasets, collator and metric into the trainer.
trainer = Seq2SeqTrainer(
    model=mt5_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # `tokenizer=` is deprecated in favour of `processing_class=`
    # (assumes transformers >= 4.46 — confirm against the installed version).
    processing_class=mt5_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

  trainer = Seq2SeqTrainer(


In [19]:
# Fine-tune, then persist the result under <output dir>/fine_tuned, the
# location the model-loading cell checks for on the next run.
trainer.train()
fine_tuned_dir = os.path.join(mt5_telugu_path, "fine_tuned")
trainer.save_model(fine_tuned_dir)

***** Running training *****
  Num examples = 50
  Num Epochs = 16
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 208
  Number of trainable parameters = 300,176,768


Step,Training Loss,Validation Loss,Bleu
80,17.4574,14.886859,0.0
160,17.3459,14.460106,0.0


Saving model checkpoint to results\mt5-telugu-qa\checkpoint-13
Configuration saved in results\mt5-telugu-qa\checkpoint-13\config.json
Configuration saved in results\mt5-telugu-qa\checkpoint-13\generation_config.json
Model weights saved in results\mt5-telugu-qa\checkpoint-13\model.safetensors
tokenizer config file saved in results\mt5-telugu-qa\checkpoint-13\tokenizer_config.json
Special tokens file saved in results\mt5-telugu-qa\checkpoint-13\special_tokens_map.json
Deleting older checkpoint [results\mt5-telugu-qa\checkpoint-195] due to args.save_total_limit
Saving model checkpoint to results\mt5-telugu-qa\checkpoint-26
Configuration saved in results\mt5-telugu-qa\checkpoint-26\config.json
Configuration saved in results\mt5-telugu-qa\checkpoint-26\generation_config.json
Model weights saved in results\mt5-telugu-qa\checkpoint-26\model.safetensors
tokenizer config file saved in results\mt5-telugu-qa\checkpoint-26\tokenizer_config.json
Special tokens file saved in results\mt5-telugu-qa\ch

In [None]:
# Reload the fine-tuned checkpoint from disk and spot-check it on the first
# validation example.
fine_tuned_path = os.path.join(mt5_telugu_path, "fine_tuned")
tokenizer = T5Tokenizer.from_pretrained(fine_tuned_path)
model = MT5ForConditionalGeneration.from_pretrained(fine_tuned_path).to(device)

question = df_te_val_prompt["prompt"][0]
answer = df_te_val_prompt["answer_inlang"][0]

# Reuse the shared helper instead of repeating tokenize -> generate -> decode;
# it applies the same truncation, max_new_tokens and decoding settings.
gen_answer = answer_question(question, model, tokenizer, early_stopping=True)
print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"Generated Answer: {gen_answer}")

Question: మలేరియా వ్యాధి కి మందు కనిపెట్టిన శాస్త్రవేత్త ఎవరు?
Answer: హన్స్ ఆండర్సాగ్
Generated Answer: <extra_id_0>.
