# MT5 Fine-tuning for Question Answering

This notebook demonstrates fine-tuning the MT5 model on multilingual question answering data.

In [1]:
import polars as pl
from datasets import load_dataset
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import torch
import evaluate
import numpy as np
import os

from mt5_utils import generate_prompt_with_context, generate_prompt_wo_context, tokenize_to_dataset, trainer_generator, get_extra_id_0_proportion

In [2]:
# Global vars
TRAIN = False
mt5_telugu_w_context_save_path = os.path.join(os.path.join("results", "mt5-telugu-qa-w-context"))
mt5_telugu_wo_context_save_path = os.path.join(os.path.join("results", "mt5-telugu-qa-wo-context"))

In [3]:
# Load dataset
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train = dataset["train"].to_polars()
df_val = dataset["validation"].to_polars()
df_te_train = df_train.filter(pl.col("lang") == "te", pl.col("answer_inlang").is_not_null())
df_te_val = df_val.filter(pl.col("lang") == "te", pl.col("answer_inlang").is_not_null())

print(f"Train size: {df_te_train.shape}, Val size: {df_te_val.shape}")

Train size: (50, 7), Val size: (100, 7)


load model

In [4]:
# Select device for training
device = torch.device("cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")
if torch.cuda.is_available():
  device = torch.device("cuda")

print(f'Using device: {device}')

Using device: cuda


In [5]:
if os.path.exists(os.path.join(mt5_telugu_w_context_save_path, "fine_tuned")):
    print("Loading model from disk")
    mt5_w_context_tokenizer = T5Tokenizer.from_pretrained(os.path.join(mt5_telugu_w_context_save_path, "fine_tuned"))
    mt5_w_context_model = MT5ForConditionalGeneration.from_pretrained(os.path.join(mt5_telugu_w_context_save_path, "fine_tuned"))
    mt5_w_context_model.to(device)
else:
    print("Loading model from Huggingface")
    mt5_w_context_tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    mt5_w_context_model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    mt5_w_context_model.to(device)

Loading model from disk


prepare dataset fo training (with context)

In [6]:
df_te_train_prompt_w_context = generate_prompt_with_context(df_te_val) # Flip train and val since val is bigger
df_te_val_prompt_w_context = generate_prompt_with_context(df_te_train)
train_dataset_w_context = tokenize_to_dataset(df_te_train_prompt_w_context, mt5_w_context_tokenizer)
val_dataset_w_context = tokenize_to_dataset(df_te_val_prompt_w_context, mt5_w_context_tokenizer)

In [7]:
trainer = trainer_generator(
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=train_dataset_w_context,
    eval_dataset=val_dataset_w_context,
    output_dir=mt5_telugu_w_context_save_path,
    epochs=25
)
if TRAIN:
    trainer.train()
    trainer.save_model(os.path.join(mt5_telugu_w_context_save_path, "fine_tuned"))
results = trainer.evaluate()
print(results)

  trainer = Seq2SeqTrainer(

***** Running Evaluation *****
  Num examples = 50
  Batch size = 8


{'eval_loss': 11.063244819641113, 'eval_model_preparation_time': 0.003, 'eval_bleu': 0.1785, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 3.746, 'eval_samples_per_second': 13.348, 'eval_steps_per_second': 1.869}


In [8]:
if os.path.exists(os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned")):
    print("Loading model from disk")
    mt5_wo_context_tokenizer = T5Tokenizer.from_pretrained(os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned"))
    mt5_wo_context_model = MT5ForConditionalGeneration.from_pretrained(os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned"))
    mt5_wo_context_model.to(device)
else:
    print("Loading model from Huggingface")
    mt5_wo_context_tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    mt5_wo_context_model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    mt5_wo_context_model.to(device)

loading file spiece.model
loading file added_tokens.json
loading file special_tokens_map.json
loading file tokenizer_config.json
loading file tokenizer.json
loading file chat_template.jinja


Loading model from disk


loading configuration file results\mt5-telugu-qa-wo-context\fine_tuned\config.json
Model config MT5Config {
  "architectures": [
    "MT5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 1024,
  "d_kv": 64,
  "d_model": 512,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "dtype": "float32",
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "mt5",
  "num_decoder_layers": 8,
  "num_heads": 6,
  "num_layers": 8,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "tokenizer_class": "T5Tokenizer",
  "transformers_version": "4.57.0",
  "use_cache": true,
  "vocab_size": 250112
}

loading weights file results\mt5-telugu-qa-wo-context\fine_tuned\model.safetensors
Generate config GenerationConfig {
  "decoder_start_t

prepare dataset fo training (without context)

In [9]:
df_te_train_prompt_wo_context = generate_prompt_wo_context(df_te_val)
df_te_val_prompt_wo_context = generate_prompt_wo_context(df_te_train)
train_dataset_wo_context = tokenize_to_dataset(df_te_train_prompt_wo_context, mt5_wo_context_tokenizer)
val_dataset_wo_context = tokenize_to_dataset(df_te_val_prompt_wo_context, mt5_wo_context_tokenizer)

In [10]:
trainer = trainer_generator(
    model=mt5_wo_context_model,
    tokenizer=mt5_wo_context_tokenizer,
    train_dataset=train_dataset_wo_context,
    eval_dataset=val_dataset_wo_context,
    output_dir=mt5_telugu_wo_context_save_path,
    epochs=25
)
if TRAIN:
    trainer.train()
    trainer.save_model(os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned"))
results = trainer.evaluate()
print(results)

PyTorch: setting up devices

***** Running Evaluation *****
  Num examples = 50
  Batch size = 8


{'eval_loss': 11.946054458618164, 'eval_model_preparation_time': 0.003, 'eval_bleu': 0.1277, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 9.3689, 'eval_samples_per_second': 5.337, 'eval_steps_per_second': 0.747}


In [11]:
#import gc
#gc.collect()
#with torch.no_grad():
#    torch.cuda.empty_cache()

Generate a random answer

In [12]:
random_num = np.random.randint(0, len(df_te_val_prompt_w_context))

question = df_te_val_prompt_w_context["prompt"][random_num]
answer = df_te_val_prompt_w_context["answer_inlang"][random_num]

inputs = mt5_w_context_tokenizer(
    question, 
    return_tensors="pt", 
    truncation=True, 
    max_length=512
    ).to(device)    

outputs = mt5_w_context_model.generate(**inputs)
gen_answer = mt5_w_context_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"Generated Answer: {gen_answer}")

Question: ఆంధ్రప్రదేశ్ లో మొదటగా ఏ ఇంజనీరింగ్ కళాశాలలో సివిల్ ఇంజనీరింగ్ విభాగాన్ని ప్రారంభించారు?</s>The Indian Railways Institute of Mechanical and Electrical Engineering (IRIMEE), was founded in 1888 as a technical school and commenced training Mechanical Officers for Indian Railways in 1927. It is the oldest of the five Centralised Training Institutes (CTIs) for training officers for Indian Railways. IRIMEE is located at Jamalpur in the Munger district of Bihar, on the Patna-Bhagalpur rail route. IRIMEE provides theoretical and practical training for a four-year undergraduate degree in mechanical engineering as well as professional courses to officers and supervisors of Indian Railways. There are also courses for non-railway organizations and foreign railways. Traditionally a center
Answer: వెలగపుడి రామకృష్ణ సిద్ధార్థ ఇంజనీరింగ్ కళాశాల
Generated Answer: <extra_id_0>.


# Evaluate difference between answerable and not answerable questions

In [13]:
df_te_train_only_answerable = df_te_train.filter(pl.col("answerable") == False)
df_te_val_only_answerable = df_te_val.filter(pl.col("answerable") == False)
df_te_train_only_unanswerable = df_te_train.filter(pl.col("answerable") == True)
df_te_val_only_unanswerable = df_te_val.filter(pl.col("answerable") == True)

print(f"Train answerable size: {df_te_train_only_answerable.shape}, Val answerable size: {df_te_val_only_answerable.shape}")
print(f"Train unanswerable size: {df_te_train_only_unanswerable.shape}, Val unanswerable size: {df_te_val_only_unanswerable.shape}")

Train answerable size: (45, 7), Val answerable size: (93, 7)
Train unanswerable size: (5, 7), Val unanswerable size: (7, 7)


In [14]:
# with context and answerable
df_te_val_prompt_w_context_answerable = generate_prompt_with_context(df_te_train_only_answerable)
val_dataset_w_context_answerable = tokenize_to_dataset(df_te_val_prompt_w_context_answerable, mt5_w_context_tokenizer)

# without context and answerable
df_te_val_prompt_wo_context_answerable = generate_prompt_wo_context(df_te_train_only_answerable)
val_dataset_wo_context_answerable = tokenize_to_dataset(df_te_val_prompt_wo_context_answerable, mt5_wo_context_tokenizer)

# with context and unanswerable
df_te_val_prompt_w_context_unanswerable = generate_prompt_with_context(df_te_train_only_unanswerable)
val_dataset_w_context_unanswerable = tokenize_to_dataset(df_te_val_prompt_w_context_unanswerable, mt5_w_context_tokenizer)

# without context and unanswerable
df_te_val_prompt_wo_context_unanswerable = generate_prompt_wo_context(df_te_train_only_unanswerable)
val_dataset_wo_context_unanswerable = tokenize_to_dataset(df_te_val_prompt_wo_context_unanswerable, mt5_wo_context_tokenizer)

In [15]:
# with context and answerable
trainer = trainer_generator(
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=val_dataset_w_context_answerable,
    eval_dataset=val_dataset_w_context_answerable,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

print(get_extra_id_0_proportion(df_te_val_prompt_w_context_answerable, mt5_w_context_tokenizer, mt5_w_context_model, device))

PyTorch: setting up devices

***** Running Evaluation *****
  Num examples = 45
  Batch size = 8


{'eval_loss': 10.996822357177734, 'eval_model_preparation_time': 0.003, 'eval_bleu': 0.1982, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 6.8359, 'eval_samples_per_second': 6.583, 'eval_steps_per_second': 0.878}
0.8666666666666667


In [16]:
# without context and answerable
trainer = trainer_generator(
    model=mt5_wo_context_model,
    tokenizer=mt5_wo_context_tokenizer,
    train_dataset=val_dataset_wo_context_answerable,
    eval_dataset=val_dataset_wo_context_answerable,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

print(get_extra_id_0_proportion(df_te_val_prompt_wo_context_answerable, mt5_wo_context_tokenizer, mt5_wo_context_model, device))

PyTorch: setting up devices

***** Running Evaluation *****
  Num examples = 45
  Batch size = 8


{'eval_loss': 11.54616928100586, 'eval_model_preparation_time': 0.003, 'eval_bleu': 0.1424, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 7.4229, 'eval_samples_per_second': 6.062, 'eval_steps_per_second': 0.808}
0.4


In [17]:
# with context and unanswerable
trainer = trainer_generator(
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=val_dataset_w_context_unanswerable,
    eval_dataset=val_dataset_w_context_unanswerable,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

print(get_extra_id_0_proportion(df_te_val_prompt_w_context_unanswerable, mt5_w_context_tokenizer, mt5_w_context_model, device))

PyTorch: setting up devices

***** Running Evaluation *****
  Num examples = 5
  Batch size = 8


{'eval_loss': 14.556205749511719, 'eval_model_preparation_time': 0.003, 'eval_bleu': 0.0, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 0.5474, 'eval_samples_per_second': 9.134, 'eval_steps_per_second': 1.827}
1.0


In [18]:
# without context and unanswerable
trainer = trainer_generator(
    model=mt5_wo_context_model,
    tokenizer=mt5_wo_context_tokenizer,
    train_dataset=val_dataset_wo_context_unanswerable,
    eval_dataset=val_dataset_wo_context_unanswerable,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

print(get_extra_id_0_proportion(df_te_val_prompt_wo_context_unanswerable, mt5_wo_context_tokenizer, mt5_wo_context_model, device))

PyTorch: setting up devices

***** Running Evaluation *****
  Num examples = 5
  Batch size = 8


{'eval_loss': 20.817989349365234, 'eval_model_preparation_time': 0.003, 'eval_bleu': 0.0, 'eval_rouge1': 0.0, 'eval_rouge2': 0.0, 'eval_rougeL': 0.0, 'eval_rougeLsum': 0.0, 'eval_runtime': 0.6666, 'eval_samples_per_second': 7.501, 'eval_steps_per_second': 1.5}
0.2
