# MT5 Fine-tuning for Question Answering

This notebook demonstrates fine-tuning the mT5 model (`google/mt5-small`) for question answering on the Telugu subset of the multilingual `coastalcph/tydi_xor_rc` dataset.

In [None]:
import polars as pl
from datasets import Dataset, load_dataset
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer
import torch
import evaluate
import numpy as np
import os

In [None]:
# Global vars
# TRAIN gates the training cell further down: when True, the trainer is run
# and the resulting model is saved under "<output dir>/fine_tuned"; when
# False, that cell does nothing.
TRAIN = False

In [None]:
# Download the dataset and convert the train/validation splits to polars frames.
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train, df_val = (
    dataset[split_name].to_polars() for split_name in ("train", "validation")
)

# Keep only Telugu rows ("te") that have an in-language answer.
df_te_train, df_te_val = (
    frame.filter(pl.col("lang") == "te", pl.col("answer_inlang").is_not_null())
    for frame in (df_train, df_val)
)
df_te_train.head()

In [None]:
# Pick the compute device, preferring CUDA, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
from transformers import MT5ForConditionalGeneration, T5Tokenizer
# Root directory for training output and the saved fine-tuned model.
mt5_telugu_path = os.path.join("results", "mt5-telugu-qa")


In [None]:


# Prefer a previously saved fine-tuned model; otherwise start from the
# pretrained checkpoint on the Hugging Face hub.
fine_tuned_dir = os.path.join(mt5_telugu_path, "fine_tuned")
if os.path.exists(fine_tuned_dir):
    print("Loading model from disk")
    model_source = fine_tuned_dir
else:
    print("Loading model from Huggingface")
    model_source = "google/mt5-small"

mt5_tokenizer = T5Tokenizer.from_pretrained(model_source)
mt5_model = MT5ForConditionalGeneration.from_pretrained(model_source)
# Assigning the result keeps the notebook from echoing the model repr.
mt5_model = mt5_model.to(device)

In [None]:
#from transformers import MT5ForConditionalGeneration, T5Tokenizer
#mt5_tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
#mt5_model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
#mt5_model.to(device)

In [None]:
#from transformers import T5ForConditionalGeneration, T5Tokenizer
#t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
#t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")

In [None]:
#from transformers import MBartForConditionalGeneration, AutoTokenizer
#mbart_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
#mbart_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

In [None]:
#from transformers import AutoTokenizer, AutoModelForCausalLM
#gpt2_tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
#gpt2_model = AutoModelForCausalLM.from_pretrained('distilgpt2')

In [None]:
def answer_question(promt: str, model, tokenizer, **kwargs) -> str:
    """Generate an answer for a single prompt and return it as plain text.

    Args:
        promt: Prompt text fed to the model. The (misspelled) parameter name
            is kept so existing keyword callers continue to work.
        model: Model exposing ``generate(**inputs, **kwargs)`` and a
            ``device`` attribute (as Hugging Face models do).
        tokenizer: Tokenizer matching ``model``; it is called on the prompt
            and used to decode the generated ids.
        **kwargs: Extra generation options forwarded to ``model.generate``.
            ``max_new_tokens`` defaults to 64 but may be overridden here.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Merge the default first so a caller-supplied max_new_tokens overrides it
    # instead of colliding with a hard-coded keyword argument.
    generation_kwargs = {"max_new_tokens": 64, **kwargs}

    # Place the inputs on the same device as the model that was passed in,
    # rather than on whatever a module-level variable happens to say.
    input_tokens = tokenizer(
        promt,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(model.device)

    generated_tokens = model.generate(**input_tokens, **generation_kwargs)
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [None]:
# Toy English question/context pair used for a quick manual check of the
# loaded model in the next cell.
question = "What is the best way to fall asleep?"
context = "Do not drink a white monster right before going to bed"

In [None]:
# Quick manual check: ask the loaded mT5 model the toy question above.
answer_question(
    f"question: {question}  context: {context}",
    mt5_model,
    mt5_tokenizer,
    early_stopping=False,
)

In [None]:
#answer_question(f"question: {question}  context: {context}", t5_model, t5_tokenizer, early_stopping=True)

In [None]:
#question = "What is the best way to fall asleep?"
#context = "Do not drink a white monster right before going to bed"
#answer_question(f"Answer the following question based on the given context. \n Context: {context} \n Question: {question}", mbart_model, mbart_tokenizer, 
#    forced_bos_token_id=mbart_tokenizer.lang_code_to_id["en_XX"],
#    early_stopping=True,
#    num_beams=4,
#    length_penalty=1.2,
#    temperature=0.7,
#    top_p=0.92,
#    top_k=50,
#    repetition_penalty=1.3,
#    )

In [None]:
#answer_question(f"Answer the following question based on the given context. \n Context: {context} \n Question: {question}", gpt2_model, gpt2_tokenizer, 
#    early_stopping=True,
#    temperature=0.7,
#    top_p=0.92,  
#    top_k=50,
#    repetition_penalty=1.2,
#    no_repeat_ngram_size=2
#    )

In [None]:
def generate_prompts(df: pl.DataFrame):
    """Return ``df`` with an added ``prompt`` column.

    The prompt is currently a plain copy of the ``question`` column; the
    context is not included.
    """
    # Disabled alternative that also embeds the context in the prompt:
    #df = df.with_columns([
    #    (pl.lit("Question: ") + pl.col("question") + pl.lit("\n Context: ") + pl.col("context")).alias("prompt")
    #])
    prompt_column = pl.col("question").alias("prompt")
    return df.with_columns([prompt_column])

def tokenize_to_dataset(df: pl.DataFrame, tokenizer, question_col: str = "prompt", answer_col: str = "answer_inlang"):
    """Tokenize prompt/answer columns and wrap them in a Hugging Face ``Dataset``.

    Args:
        df: Frame holding the text columns named by ``question_col`` and
            ``answer_col``.
        tokenizer: Tokenizer applied to both columns; its ``pad_token_id`` is
            used to locate padding in the labels.
        question_col: Column with the model inputs (padded/truncated to 512 tokens).
        answer_col: Column with the target answers (padded/truncated to 64 tokens).

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """
    # Options shared by the input and the target tokenization.
    common_options = {
        "return_tensors": "pt",
        "truncation": True,
        "padding": "max_length",
    }

    encoded_inputs = tokenizer(
        df[question_col].to_list(), max_length=512, **common_options
    )
    target_ids = tokenizer(
        df[answer_col].to_list(), max_length=64, **common_options
    )["input_ids"]

    # Mark padded label positions with -100 (presumably so the loss skips them).
    target_ids[target_ids == tokenizer.pad_token_id] = -100

    columns = {
        "input_ids": encoded_inputs["input_ids"],
        "attention_mask": encoded_inputs["attention_mask"],
        "labels": target_ids,
    }

    # Dataset.from_dict is fed numpy arrays rather than torch tensors.
    return Dataset.from_dict(
        {name: tensor.numpy() for name, tensor in columns.items()}
    )

In [None]:
# Add the prompt column to both Telugu splits.
df_te_train_prompt = generate_prompts(df_te_train)
df_te_val_prompt = generate_prompts(df_te_val)

# Tokenize both splits with the mT5 tokenizer.
train_dataset = tokenize_to_dataset(df_te_train_prompt, mt5_tokenizer)
val_dataset = tokenize_to_dataset(df_te_val_prompt, mt5_tokenizer)

# Show a summary of each tokenized split.
print("Training dataset:", train_dataset, sep="\n")
print("Validation dataset:", val_dataset, sep="\n")

In [None]:
# Reference: https://huggingface.co/learn/llm-course/chapter7/4?fw=pt

# sacreBLEU is the score reported by compute_metrics.
metric = evaluate.load("sacrebleu")
epochs = 500

training_args = Seq2SeqTrainingArguments(
    # Output location
    output_dir=mt5_telugu_path,
    overwrite_output_dir=True,
    # Optimisation
    num_train_epochs=epochs,
    learning_rate=2e-5,
    weight_decay=0.01,
    auto_find_batch_size=True,
    fp16=False,
    # Evaluation
    eval_strategy="epoch",
    predict_with_generate=True,
    generation_max_length=128,
    # Checkpointing
    save_strategy="best",
    save_total_limit=5,
    load_best_model_at_end=True,
    # Logging
    logging_strategy="epoch",
    log_level="info",
    report_to=[],
    logging_dir=None,
)

data_collator = DataCollatorForSeq2Seq(mt5_tokenizer, model=mt5_model)

def compute_metrics(eval_preds, tokenizer=mt5_tokenizer):
    """Decode predictions and labels and score them with the loaded BLEU metric.

    Args:
        eval_preds: ``(predictions, labels)`` pair. ``predictions`` may be a
            tuple (its first element is used) and may hold 3-D logits, which
            are reduced to token ids with an argmax over the last axis.
        tokenizer: Tokenizer used for decoding; it must define an unknown token.

    Returns:
        ``{"bleu": score}`` where ``score`` is the metric's ``"score"`` entry.

    Raises:
        ValueError: If ``tokenizer.unk_token_id`` is ``None``.
    """
    predictions, references = eval_preds

    # Unwrap tuple outputs, keeping the first element only.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # A 3-D array is treated as (batch, seq_len, vocab_size) logits.
    if predictions.ndim == 3:
        predictions = np.argmax(predictions, axis=-1)

    # -100 cannot be decoded, so swap it for the pad id in the labels.
    references = np.where(references != -100, references, tokenizer.pad_token_id)

    unknown_id = tokenizer.unk_token_id
    if unknown_id is None:
        raise ValueError("Tokenizer does not define an <unk> token.")

    # Any id outside [0, vocab_size) is replaced by the <unk> id.
    out_of_range = ~((predictions >= 0) & (predictions < tokenizer.vocab_size))
    if out_of_range.any():
        num_invalid = out_of_range.sum()
        print(f"Replacing {num_invalid} invalid token IDs with <unk> ({unknown_id}).")
        predictions = np.where(out_of_range, unknown_id, predictions)

    # Decode and trim; each prediction gets a one-element list of references.
    hypothesis_texts = [
        text.strip()
        for text in tokenizer.batch_decode(predictions, skip_special_tokens=True)
    ]
    reference_texts = [
        [text.strip()]
        for text in tokenizer.batch_decode(references, skip_special_tokens=True)
    ]

    bleu = metric.compute(predictions=hypothesis_texts, references=reference_texts)
    return {"bleu": bleu["score"]}



# Build the trainer: fit on the training split and evaluate on the held-out
# validation split.
trainer = Seq2SeqTrainer(
    model=mt5_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=mt5_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [None]:
# Fine-tune and persist the result only when the TRAIN flag is switched on.
if TRAIN:
    trainer.train()
    fine_tuned_dir = os.path.join(mt5_telugu_path, "fine_tuned")
    trainer.save_model(fine_tuned_dir)

In [None]:
import gc

# Collect unreferenced Python objects, then release cached CUDA memory.
# No torch.no_grad() wrapper: emptying the cache does not involve autograd.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Reload the saved fine-tuned model when it exists. Without a saved model
# (e.g. TRAIN is False on a fresh checkout) reuse the model and tokenizer
# that are already in memory instead of failing on a missing directory.
fine_tuned_dir = os.path.join(mt5_telugu_path, "fine_tuned")
if os.path.exists(fine_tuned_dir):
    tokenizer = T5Tokenizer.from_pretrained(fine_tuned_dir)
    model = MT5ForConditionalGeneration.from_pretrained(fine_tuned_dir)
else:
    print("No fine-tuned model found on disk; reusing the model loaded earlier")
    tokenizer, model = mt5_tokenizer, mt5_model


In [None]:
# Move the reloaded model to the selected device; the inputs built for
# generation below are moved to the same device.
model.to(device)

In [None]:
# Spot-check the model on one randomly chosen row of the Telugu training split.
random_num = np.random.randint(0, len(df_te_train_prompt))

question = df_te_train_prompt["prompt"][random_num]
answer = df_te_train_prompt["answer_inlang"][random_num]

# Reuse answer_question instead of repeating the tokenize/generate/decode
# steps inline. It generates up to 64 new tokens, matching the 64-token limit
# applied to the labels, rather than relying on the model's default
# generation length.
gen_answer = answer_question(question, model, tokenizer)

print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"Generated Answer: {gen_answer}")