# MT5 Fine-tuning for Question Answering

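This notebook fine-tunes the mT5 model (`google/mt5-small`) for question answering on the Telugu subset of the multilingual `coastalcph/tydi_xor_rc` dataset, once with prompts that include the context and once with prompts that omit it.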

In [None]:
import polars as pl
from datasets import load_dataset
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import torch
import evaluate
import numpy as np
import os

from mt5_utils import generate_prompt_with_context, generate_prompt_wo_context, tokenize_to_dataset, compute_metrics

In [None]:
# Notebook-wide configuration.
# TRAIN: set to True to run fine-tuning; defaults to False.
TRAIN = False

# Output directories for the two model variants (with / without context).
mt5_telugu_w_context_save_path = os.path.join("results", "mt5-telugu-qa-w-context")
mt5_telugu_wo_context_save_path = os.path.join("results", "mt5-telugu-qa-wo-context")

In [None]:
# Fetch the dataset and convert the train/validation splits to polars frames.
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train, df_val = (dataset[split].to_polars() for split in ("train", "validation"))

# Keep only rows tagged with language "te" that have a non-null in-language answer.
is_telugu = pl.col("lang") == "te"
has_inlang_answer = pl.col("answer_inlang").is_not_null()
df_te_train = df_train.filter(is_telugu, has_inlang_answer)
df_te_val = df_val.filter(is_telugu, has_inlang_answer)

In [None]:
# Pick the compute device: CUDA wins over MPS, and CPU is the fallback.
has_mps = torch.backends.mps.is_available()
has_cuda = torch.cuda.is_available()

if has_cuda:
    device = torch.device("cuda")
elif has_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
# Load the with-context model: reuse a fine-tuned checkpoint from disk when
# one exists, otherwise start from the pretrained "google/mt5-small" weights.
w_context_checkpoint = os.path.join(mt5_telugu_w_context_save_path, "fine_tuned")

if os.path.exists(w_context_checkpoint):
    print("Loading model from disk")
    w_context_source = w_context_checkpoint
else:
    print("Loading model from Huggingface")
    w_context_source = "google/mt5-small"

mt5_w_context_tokenizer = T5Tokenizer.from_pretrained(w_context_source)
# .to(device) is chained onto the assignment so the cell does not echo the model repr.
mt5_w_context_model = MT5ForConditionalGeneration.from_pretrained(w_context_source).to(device)

In [None]:
# Build prompts that include the context, then tokenize them.
# The splits are deliberately swapped: the validation frame is used for
# training (and the train frame for validation) since val is bigger.
df_te_train_prompt_w_context, df_te_val_prompt_w_context = (
    generate_prompt_with_context(frame) for frame in (df_te_val, df_te_train)
)
train_dataset_w_context, val_dataset_w_context = (
    tokenize_to_dataset(prompts, mt5_w_context_tokenizer)
    for prompts in (df_te_train_prompt_w_context, df_te_val_prompt_w_context)
)

In [None]:
# Training configuration for the with-context model.
# Based on https://huggingface.co/learn/llm-course/chapter7/4?fw=pt
epochs = 25

training_args = Seq2SeqTrainingArguments(
    # Precision / batch sizing
    fp16=False,
    auto_find_batch_size=True,

    # Output location and optimisation
    output_dir=mt5_telugu_w_context_save_path,
    overwrite_output_dir=True,
    learning_rate=2e-5,
    num_train_epochs=epochs,
    weight_decay=0.01,

    # Evaluate by generating text (up to 128 tokens) rather than teacher forcing
    predict_with_generate=True,
    generation_max_length=128,

    # Checkpointing: save on a new best metric, keep at most 3 checkpoints,
    # and reload the best one when training finishes
    save_total_limit=3,
    save_strategy="best",
    load_best_model_at_end=True,

    # Log and evaluate once per epoch; no external reporting integrations
    logging_strategy="epoch",
    eval_strategy="epoch",
    log_level="info",
    report_to=[],
    logging_dir=None,
)

# No explicit data_collator is passed, so the trainer's default collator is used.
# NOTE(review): if tokenize_to_dataset does not pad inputs/labels to a fixed
# length, pass
#   data_collator=DataCollatorForSeq2Seq(mt5_w_context_tokenizer, model=mt5_w_context_model)
# instead - confirm against mt5_utils.

trainer = Seq2SeqTrainer(
    model=mt5_w_context_model,
    args=training_args,
    train_dataset=train_dataset_w_context,
    eval_dataset=val_dataset_w_context,
    # `processing_class` supersedes the deprecated `tokenizer` argument.
    processing_class=mt5_w_context_tokenizer,
    compute_metrics=lambda eval_preds: compute_metrics(eval_preds, mt5_w_context_tokenizer),
)

In [None]:
# Fine-tune and persist the with-context model only when TRAIN is enabled;
# evaluation on the eval dataset always runs.
fine_tuned_w_context_dir = os.path.join(mt5_telugu_w_context_save_path, "fine_tuned")

if TRAIN:
    trainer.train()
    trainer.save_model(fine_tuned_w_context_dir)

results = trainer.evaluate()
print(results)

In [None]:
# Load the without-context model: reuse a fine-tuned checkpoint from disk when
# one exists, otherwise start from the pretrained "google/mt5-small" weights.
wo_context_checkpoint = os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned")

if os.path.exists(wo_context_checkpoint):
    print("Loading model from disk")
    wo_context_source = wo_context_checkpoint
else:
    print("Loading model from Huggingface")
    wo_context_source = "google/mt5-small"

mt5_wo_context_tokenizer = T5Tokenizer.from_pretrained(wo_context_source)
# .to(device) is chained onto the assignment so the cell does not echo the model repr.
mt5_wo_context_model = MT5ForConditionalGeneration.from_pretrained(wo_context_source).to(device)

In [None]:
# Build prompts without the context, then tokenize them.
# As in the with-context pipeline, the validation frame is used for training
# and the train frame for validation.
df_te_train_prompt_wo_context, df_te_val_prompt_wo_context = (
    generate_prompt_wo_context(frame) for frame in (df_te_val, df_te_train)
)
train_dataset_wo_context, val_dataset_wo_context = (
    tokenize_to_dataset(prompts, mt5_wo_context_tokenizer)
    for prompts in (df_te_train_prompt_wo_context, df_te_val_prompt_wo_context)
)

In [None]:
# Training configuration for the without-context model.
# Based on https://huggingface.co/learn/llm-course/chapter7/4?fw=pt
epochs = 25

training_args = Seq2SeqTrainingArguments(
    # Precision / batch sizing
    fp16=False,
    auto_find_batch_size=True,

    # Output location and optimisation
    output_dir=mt5_telugu_wo_context_save_path,
    overwrite_output_dir=True,
    learning_rate=2e-5,
    num_train_epochs=epochs,
    weight_decay=0.01,

    # Evaluate by generating text (up to 128 tokens) rather than teacher forcing
    predict_with_generate=True,
    generation_max_length=128,

    # Checkpointing: save on a new best metric, keep at most 3 checkpoints,
    # and reload the best one when training finishes
    save_total_limit=3,
    save_strategy="best",
    load_best_model_at_end=True,

    # Log and evaluate once per epoch; no external reporting integrations
    logging_strategy="epoch",
    eval_strategy="epoch",
    log_level="info",
    report_to=[],
    logging_dir=None,
)

# No explicit data_collator is passed, so the trainer's default collator is used.
# NOTE(review): if tokenize_to_dataset does not pad inputs/labels to a fixed
# length, pass
#   data_collator=DataCollatorForSeq2Seq(mt5_wo_context_tokenizer, model=mt5_wo_context_model)
# instead - confirm against mt5_utils.

trainer = Seq2SeqTrainer(
    model=mt5_wo_context_model,
    args=training_args,
    train_dataset=train_dataset_wo_context,
    eval_dataset=val_dataset_wo_context,
    # `processing_class` supersedes the deprecated `tokenizer` argument.
    processing_class=mt5_wo_context_tokenizer,
    compute_metrics=lambda eval_preds: compute_metrics(eval_preds, mt5_wo_context_tokenizer),
)

In [None]:
# Fine-tune and persist the without-context model only when TRAIN is enabled;
# evaluation on the eval dataset always runs.
fine_tuned_wo_context_dir = os.path.join(mt5_telugu_wo_context_save_path, "fine_tuned")

if TRAIN:
    trainer.train()
    trainer.save_model(fine_tuned_wo_context_dir)

results = trainer.evaluate()
print(results)

In [None]:
#import gc
#gc.collect()
#with torch.no_grad():
#    torch.cuda.empty_cache()

In [None]:
# Qualitative spot check: generate an answer for one randomly chosen
# with-context evaluation prompt and print it next to the reference answer.
random_num = np.random.randint(0, len(df_te_val_prompt_w_context))

question = df_te_val_prompt_w_context["prompt"][random_num]
answer = df_te_val_prompt_w_context["answer_inlang"][random_num]

# Tokenize the prompt (truncated to 512 tokens) and move the tensors to the model's device.
inputs = mt5_w_context_tokenizer(
    question,
    return_tensors="pt",
    truncation=True,
    max_length=512,
).to(device)

# Use the same generation length as evaluation (generation_max_length=128).
# Without it, generate() falls back to the model's default generation length
# (typically much shorter), so the printed answer could be truncated where the
# evaluated one is not.
outputs = mt5_w_context_model.generate(**inputs, max_length=128)
gen_answer = mt5_w_context_tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"Generated Answer: {gen_answer}")