# MT5 Fine-tuning for Question Answering

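This notebook fine-tunes the mT5 model (`google/mt5-small`) on the Telugu subset of the `coastalcph/tydi_xor_rc` question answering dataset, once with the context passage included in the prompt and once without, and then compares results on answerable and unanswerable questions.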

In [None]:
import polars as pl
from datasets import load_dataset
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import torch
import evaluate
import numpy as np
import os

from mt5_utils import generate_prompt_with_context, generate_prompt_wo_context, tokenize_to_dataset, trainer_generator

In [None]:
# Global configuration
# TRAIN gates the train()/save_model() calls in the trainer cells; evaluation runs either way.
TRAIN = False
# Output directories for the two fine-tuned variants (with / without context in the prompt).
mt5_telugu_w_context_save_path = os.path.join("results", "mt5-telugu-qa-w-context")
mt5_telugu_wo_context_save_path = os.path.join("results", "mt5-telugu-qa-wo-context")

In [None]:
# Load the dataset and convert both splits to polars DataFrames
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train = dataset["train"].to_polars()
df_val = dataset["validation"].to_polars()

# Keep only Telugu rows ("te") that have a non-null in-language answer.
telugu_with_answer = (pl.col("lang") == "te") & pl.col("answer_inlang").is_not_null()
df_te_train = df_train.filter(telugu_with_answer)
df_te_val = df_val.filter(telugu_with_answer)

print(f"Train size: {df_te_train.shape}, Val size: {df_te_val.shape}")

In [None]:
# Pick the compute device: CUDA takes priority, then Apple MPS, with CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
# Load the with-context model: reuse a fine-tuned checkpoint from disk when one
# exists, otherwise start from the pretrained base model.
w_context_checkpoint = os.path.join(mt5_telugu_w_context_save_path, "fine_tuned")
if os.path.exists(w_context_checkpoint):
    print("Loading model from disk")
    w_context_source = w_context_checkpoint
else:
    print("Loading model from Huggingface")
    w_context_source = "google/mt5-small"

mt5_w_context_tokenizer = T5Tokenizer.from_pretrained(w_context_source)
mt5_w_context_model = MT5ForConditionalGeneration.from_pretrained(w_context_source)
# Assign the result so the cell does not echo the model repr as its output.
mt5_w_context_model = mt5_w_context_model.to(device)

In [None]:
# Build prompts that include the context passage, then tokenize them.
# NOTE: the splits are deliberately swapped -- the validation split is larger,
# so it is used for training and the train split is used for evaluation.
df_te_train_prompt_w_context, df_te_val_prompt_w_context = (
    generate_prompt_with_context(split_frame)
    for split_frame in (df_te_val, df_te_train)
)
train_dataset_w_context, val_dataset_w_context = (
    tokenize_to_dataset(prompt_frame, mt5_w_context_tokenizer)
    for prompt_frame in (df_te_train_prompt_w_context, df_te_val_prompt_w_context)
)

In [None]:
# Trainer for the with-context model. Training and checkpoint saving only run
# when TRAIN is enabled; evaluation always runs.
trainer = trainer_generator(
    output_dir=mt5_telugu_w_context_save_path,
    epochs=25,
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=train_dataset_w_context,
    eval_dataset=val_dataset_w_context,
)

if TRAIN:
    trainer.train()
    # Saved under "fine_tuned" so the loading cell can pick it up on the next run.
    w_context_fine_tuned_dir = os.path.join(mt5_telugu_w_context_save_path, "fine_tuned")
    trainer.save_model(w_context_fine_tuned_dir)

results = trainer.evaluate()
print(results)

In [None]:
# Load the without-context model: reuse a fine-tuned checkpoint from disk when
# one exists, otherwise start from the pretrained base model.
wo_context_checkpoint = os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned")
if os.path.exists(wo_context_checkpoint):
    print("Loading model from disk")
    wo_context_source = wo_context_checkpoint
else:
    print("Loading model from Huggingface")
    wo_context_source = "google/mt5-small"

mt5_wo_context_tokenizer = T5Tokenizer.from_pretrained(wo_context_source)
mt5_wo_context_model = MT5ForConditionalGeneration.from_pretrained(wo_context_source)
# Assign the result so the cell does not echo the model repr as its output.
mt5_wo_context_model = mt5_wo_context_model.to(device)

In [None]:
# Build prompts without the context passage, then tokenize them.
# NOTE: as in the with-context cell, the validation split is used for training
# and the train split for evaluation.
df_te_train_prompt_wo_context, df_te_val_prompt_wo_context = (
    generate_prompt_wo_context(split_frame)
    for split_frame in (df_te_val, df_te_train)
)
train_dataset_wo_context, val_dataset_wo_context = (
    tokenize_to_dataset(prompt_frame, mt5_wo_context_tokenizer)
    for prompt_frame in (df_te_train_prompt_wo_context, df_te_val_prompt_wo_context)
)

In [None]:
# Trainer for the without-context model. Training and checkpoint saving only
# run when TRAIN is enabled; evaluation always runs.
trainer = trainer_generator(
    output_dir=mt5_telugu_wo_context_save_path,
    epochs=25,
    model=mt5_wo_context_model,
    tokenizer=mt5_wo_context_tokenizer,
    train_dataset=train_dataset_wo_context,
    eval_dataset=val_dataset_wo_context,
)

if TRAIN:
    trainer.train()
    # Saved under "fine_tuned" so the loading cell can pick it up on the next run.
    wo_context_fine_tuned_dir = os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned")
    trainer.save_model(wo_context_fine_tuned_dir)

results = trainer.evaluate()
print(results)

In [None]:
#import gc
#gc.collect()
#with torch.no_grad():
#    torch.cuda.empty_cache()

In [None]:
# Qualitative spot check: draw one random evaluation prompt and compare the
# reference answer with what the with-context model generates.
random_num = np.random.randint(len(df_te_val_prompt_w_context))

question = df_te_val_prompt_w_context["prompt"][random_num]
answer = df_te_val_prompt_w_context["answer_inlang"][random_num]

# Tokenize the prompt (truncated to 512 tokens) and move the tensors to the model's device.
inputs = mt5_w_context_tokenizer(
    question,
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
inputs = inputs.to(device)

outputs = mt5_w_context_model.generate(**inputs)
first_sequence = outputs[0]
gen_answer = mt5_w_context_tokenizer.decode(first_sequence, skip_special_tokens=True)

for line in (
    f"Question: {question}",
    f"Answer: {answer}",
    f"Generated Answer: {gen_answer}",
):
    print(line)

# Evaluate the difference between answerable and unanswerable questions

In [None]:
# Split the Telugu frames by the `answerable` flag.
# `answerable == True` selects the answerable questions and `answerable == False`
# the unanswerable ones, so each frame's name matches its contents.
# The explicit comparisons are kept (rather than a bare column / `~`) so the
# filter does not depend on the column's dtype being strictly boolean.
is_answerable = pl.col("answerable") == True  # noqa: E712 -- polars expression, not a Python bool test
is_unanswerable = pl.col("answerable") == False  # noqa: E712

df_te_train_only_answerable = df_te_train.filter(is_answerable)
df_te_val_only_answerable = df_te_val.filter(is_answerable)
df_te_train_only_unanswerable = df_te_train.filter(is_unanswerable)
df_te_val_only_unanswerable = df_te_val.filter(is_unanswerable)

print(f"Train answerable size: {df_te_train_only_answerable.shape}, Val answerable size: {df_te_val_only_answerable.shape}")
print(f"Train unanswerable size: {df_te_train_only_unanswerable.shape}, Val unanswerable size: {df_te_val_only_unanswerable.shape}")

In [None]:
# Evaluation prompts and tokenized datasets for the four
# (with/without context) x (answerable/unanswerable) combinations.
# The "val" names are built from the *train* frames because the splits are
# swapped in this notebook (the train split serves as the evaluation set).

# --- answerable, with context ---
df_te_val_prompt_w_context_answerable = generate_prompt_with_context(
    df_te_train_only_answerable
)
val_dataset_w_context_answerable = tokenize_to_dataset(
    df_te_val_prompt_w_context_answerable,
    mt5_w_context_tokenizer,
)

# --- answerable, without context ---
df_te_val_prompt_wo_context_answerable = generate_prompt_wo_context(
    df_te_train_only_answerable
)
val_dataset_wo_context_answerable = tokenize_to_dataset(
    df_te_val_prompt_wo_context_answerable,
    mt5_wo_context_tokenizer,
)

# --- unanswerable, with context ---
df_te_val_prompt_w_context_unanswerable = generate_prompt_with_context(
    df_te_train_only_unanswerable
)
val_dataset_w_context_unanswerable = tokenize_to_dataset(
    df_te_val_prompt_w_context_unanswerable,
    mt5_w_context_tokenizer,
)

# --- unanswerable, without context ---
df_te_val_prompt_wo_context_unanswerable = generate_prompt_wo_context(
    df_te_train_only_unanswerable
)
val_dataset_wo_context_unanswerable = tokenize_to_dataset(
    df_te_val_prompt_wo_context_unanswerable,
    mt5_wo_context_tokenizer,
)

In [None]:
# Evaluate the with-context model on the answerable subset.
# The same dataset is passed as train_dataset; only evaluate() is invoked on
# this trainer, and "tmp" is used as a scratch output directory.
trainer = trainer_generator(
    output_dir="tmp",
    epochs=1,
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    eval_dataset=val_dataset_w_context_answerable,
    train_dataset=val_dataset_w_context_answerable,
)

results = trainer.evaluate()
print(results)

In [None]:
# Evaluate the without-context model on the answerable subset.
# The same dataset is passed as train_dataset; only evaluate() is invoked on
# this trainer, and "tmp" is used as a scratch output directory.
trainer = trainer_generator(
    output_dir="tmp",
    epochs=1,
    model=mt5_wo_context_model,
    tokenizer=mt5_wo_context_tokenizer,
    eval_dataset=val_dataset_wo_context_answerable,
    train_dataset=val_dataset_wo_context_answerable,
)

results = trainer.evaluate()
print(results)

In [None]:
# Evaluate the with-context model on the unanswerable subset.
# The same dataset is passed as train_dataset; only evaluate() is invoked on
# this trainer, and "tmp" is used as a scratch output directory.
trainer = trainer_generator(
    output_dir="tmp",
    epochs=1,
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    eval_dataset=val_dataset_w_context_unanswerable,
    train_dataset=val_dataset_w_context_unanswerable,
)

results = trainer.evaluate()
print(results)

In [None]:
# Evaluate the without-context model on the unanswerable subset.
# The same dataset is passed as train_dataset; only evaluate() is invoked on
# this trainer, and "tmp" is used as a scratch output directory.
trainer = trainer_generator(
    output_dir="tmp",
    epochs=1,
    model=mt5_wo_context_model,
    tokenizer=mt5_wo_context_tokenizer,
    eval_dataset=val_dataset_wo_context_unanswerable,
    train_dataset=val_dataset_wo_context_unanswerable,
)

results = trainer.evaluate()
print(results)