# MT5 Fine-tuning for Question Answering

This notebook demonstrates fine-tuning the mT5 model (`google/mt5-small`) for question answering on the Telugu subset of the multilingual `coastalcph/tydi_xor_rc` dataset.

In [None]:
import polars as pl
from datasets import Dataset, load_dataset
from transformers import Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, Seq2SeqTrainer
import torch
import evaluate
import numpy as np
import os

In [None]:
# Fetch the TyDi XOR RC dataset and convert both splits to polars frames
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train, df_val = (dataset[split].to_polars() for split in ("train", "validation"))

# Keep only Telugu rows that carry an in-language answer
telugu_with_answer = (pl.col("lang") == "te") & pl.col("answer_inlang").is_not_null()
df_te_train = df_train.filter(telugu_with_answer)
df_te_val = df_val.filter(telugu_with_answer)
df_te_train.head()

In [None]:
# Pick the compute device; priority is CUDA, then Apple MPS, then CPU.
# Both backends are probed up front, in the same order as before.
mps_ready = torch.backends.mps.is_available()
cuda_ready = torch.cuda.is_available()

if cuda_ready:
    device = torch.device("cuda")
elif mps_ready:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
from transformers import MT5ForConditionalGeneration, T5Tokenizer

# Pretrained checkpoint that serves as the starting point for fine-tuning;
# tokenizer and model are loaded from the same identifier.
mt5_checkpoint = "google/mt5-small"
mt5_tokenizer = T5Tokenizer.from_pretrained(mt5_checkpoint)
mt5_model = MT5ForConditionalGeneration.from_pretrained(mt5_checkpoint)
#mt5_model.to(device)

In [None]:
#from transformers import T5ForConditionalGeneration, T5Tokenizer
#t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
#t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")

In [None]:
#from transformers import MBartForConditionalGeneration, AutoTokenizer
#mbart_tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
#mbart_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")

In [None]:
#from transformers import AutoTokenizer, AutoModelForCausalLM
#gpt2_tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
#gpt2_model = AutoModelForCausalLM.from_pretrained('distilgpt2')

In [None]:
def answer_question(promt: str, model, tokenizer, **kwargs) -> str:
    """Generate a text answer for a single prompt.

    Args:
        promt: Prompt text fed to the model (the spelling is kept so existing
            keyword callers keep working).
        model: Model exposing ``generate(**inputs, **generation_kwargs)``.
        tokenizer: Callable tokenizer that also provides ``decode``.
        **kwargs: Extra generation options forwarded to ``model.generate``.
            ``max_new_tokens`` defaults to 64 and may be overridden here.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    input_tokens = tokenizer(promt, return_tensors="pt", truncation=True, max_length=512)

    # Inputs must live on the same device as the model, otherwise generation
    # fails as soon as the model has been moved to CUDA/MPS.
    model_device = getattr(model, "device", None)
    if model_device is not None and hasattr(input_tokens, "to"):
        input_tokens = input_tokens.to(model_device)

    # Merge the default with caller options so passing max_new_tokens
    # overrides it instead of raising a duplicate-keyword TypeError.
    generation_kwargs = {"max_new_tokens": 64, **kwargs}
    generated_tokens = model.generate(**input_tokens, **generation_kwargs)

    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

In [None]:
# Toy English question/context pair used to try the model on a hand-written prompt
question = "What is the best way to fall asleep?"
context = "Do not drink a white monster right before going to bed"

In [None]:
#answer_question(f"question: {question}  context: {context}", t5_model, t5_tokenizer, early_stopping=True)

In [None]:
# Try the pretrained mT5 model on the toy prompt (before any fine-tuning)
zero_shot_prompt = f"question: {question}  context: {context}"
answer_question(zero_shot_prompt, mt5_model, mt5_tokenizer, early_stopping=True)

In [None]:
#question = "What is the best way to fall asleep?"
#context = "Do not drink a white monster right before going to bed"
#answer_question(f"Answer the following question based on the given context. \n Context: {context} \n Question: {question}", mbart_model, mbart_tokenizer, 
#    forced_bos_token_id=mbart_tokenizer.lang_code_to_id["en_XX"],
#    early_stopping=True,
#    num_beams=4,
#    length_penalty=1.2,
#    temperature=0.7,
#    top_p=0.92,
#    top_k=50,
#    repetition_penalty=1.3,
#    )

In [None]:
#answer_question(f"Answer the following question based on the given context. \n Context: {context} \n Question: {question}", gpt2_model, gpt2_tokenizer, 
#    early_stopping=True,
#    temperature=0.7,
#    top_p=0.92,  
#    top_k=50,
#    repetition_penalty=1.2,
#    no_repeat_ngram_size=2
#    )

In [None]:
def generate_prompts(df: pl.DataFrame):
    """Return ``df`` with an added ``prompt`` column.

    The prompt is the literal ``"Question: "``, the ``question`` column, the
    literal ``"\\n Context: "`` and the ``context`` column, concatenated in
    that order.
    """
    prompt_expr = (
        pl.lit("Question: ")
        + pl.col("question")
        + pl.lit("\n Context: ")
        + pl.col("context")
    )
    return df.with_columns([prompt_expr.alias("prompt")])

def tokenize_to_dataset(
    df: pl.DataFrame,
    tokenizer,
    question_col: str = "prompt",
    answer_col: str = "answer_inlang",
    *,
    max_input_length: int = 512,
    max_target_length: int = 64,
):
    """Tokenize prompt/answer columns into a Hugging Face ``Dataset``.

    Args:
        df: Frame holding the prompt and answer text columns.
        tokenizer: Tokenizer used for both prompts and answers; must expose
            ``pad_token_id``.
        question_col: Name of the column containing the model input text.
        answer_col: Name of the column containing the target text.
        max_input_length: Length prompts are truncated/padded to
            (previously hard-coded to 512).
        max_target_length: Length answers are truncated/padded to
            (previously hard-coded to 64).

    Returns:
        A ``Dataset`` with ``input_ids``, ``attention_mask`` and ``labels``
        columns, where padding positions in ``labels`` are set to -100.
    """

    def _encode(texts, max_length):
        # Shared tokenizer settings: fixed-length, truncated, PyTorch tensors.
        return tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )

    inputs = _encode(df[question_col].to_list(), max_input_length)
    targets = _encode(df[answer_col].to_list(), max_target_length)

    # Mark padded label positions with -100, the placeholder compute_metrics
    # later maps back to the pad token before decoding.
    labels = targets["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100

    tensors = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }

    # Convert to Hugging Face Dataset
    return Dataset.from_dict({name: tensor.numpy() for name, tensor in tensors.items()})

In [None]:
# Training split: build prompts, then tokenize
df_te_train_prompt = generate_prompts(df_te_train)
train_dataset = tokenize_to_dataset(df_te_train_prompt, mt5_tokenizer)

# Validation split: same pipeline
df_te_val_prompt = generate_prompts(df_te_val)
val_dataset = tokenize_to_dataset(df_te_val_prompt, mt5_tokenizer)

In [None]:
# Training setup; see the Hugging Face LLM course chapter:
# https://huggingface.co/learn/llm-course/chapter7/4?fw=pt

model_save_dir = os.path.join("results", "mt5-telugu-qa")

# SacreBLEU scorer used by compute_metrics
metric = evaluate.load("sacrebleu")

training_args = Seq2SeqTrainingArguments(
    # Checkpointing
    output_dir=model_save_dir,
    overwrite_output_dir=True,
    save_strategy="epoch",
    save_total_limit=2,
    # Optimisation
    num_train_epochs=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=6,
    fp16=False,
    # Evaluation
    eval_strategy="epoch",
    predict_with_generate=True,
    # Logging
    logging_strategy="epoch",
    log_level="info",
    report_to=[],
    logging_dir=None,
)

data_collator = DataCollatorForSeq2Seq(mt5_tokenizer, model=mt5_model)

def compute_metrics(eval_preds, tokenizer=mt5_tokenizer):
    """Compute the BLEU score for a batch of generated predictions.

    Args:
        eval_preds: ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may itself be a tuple, in which case its first
            element is used.
        tokenizer: Tokenizer used to decode token ids back to text.

    Returns:
        ``{"bleu": score}`` where ``score`` is the ``"score"`` entry of the
        module-level ``metric`` result.
    """
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is a placeholder, not a real token id, and cannot be decoded.
    # Labels carry it for padded positions; generated predictions can carry
    # it too when the evaluation loop pads batches of different lengths.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing: the metric is given one list of
    # references per prediction.
    predictions = [pred.strip() for pred in decoded_preds]
    references = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=predictions, references=references)
    return {"bleu": result["score"]}



# Wire the model, datasets, collator and metric function into the trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=mt5_model,
    tokenizer=mt5_tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning, then save the final model to the directory the reload cell reads from
trainer.train()
fine_tuned_dir = "results/mt5-telugu-qa/fine_tuned"
trainer.save_model(fine_tuned_dir)

In [None]:
# Reload tokenizer and model from the saved directory and move the model to the selected device
fine_tuned_path = "results/mt5-telugu-qa/fine_tuned"
tokenizer = T5Tokenizer.from_pretrained(fine_tuned_path)
model = MT5ForConditionalGeneration.from_pretrained(fine_tuned_path)
model.to(device)

In [None]:
# Qualitative check on the first validation example: compare the reference
# answer with what the reloaded model generates.
question = df_te_val_prompt["prompt"][0]
answer = df_te_val_prompt["answer_inlang"][0]

# Tokenize the prompt and place the tensors on the model's device
inputs = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
inputs = inputs.to(device)

outputs = model.generate(max_new_tokens=64, early_stopping=True, **inputs)
gen_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"Generated Answer: {gen_answer}")