# MT5 Fine-tuning for Question Answering

This notebook demonstrates fine-tuning the MT5 model on multilingual question answering data.

In [1]:
%load_ext autoreload
%autoreload 2

In [39]:
import polars as pl
from transformers import MT5ForConditionalGeneration, T5Tokenizer
import torch
import os

from mt5_utils import generate_prompt_wo_context, tokenize_to_dataset, trainer_generator, get_extra_id_0_proportion

In [3]:
mt5_telugu_w_context_save_path = os.path.join(os.path.join("results", "mt5-telugu-qa-wo-context"))

In [4]:
df_test_json = pl.read_json("test.json")
df_test_json.head()

question,context,lang,answerable,answer_start,answer,answer_inlang,translation
str,str,str,bool,i64,str,str,str
"""Which company published the vi…","""Hell's Kitchen: The Game is a …","""en""",True,164,"""Ubisoft""","""Ubisoft""","""Which company published the vi…"
"""八一大楼是什么时候完全建成的？""","""The first PLA main office in B…","""zh""",True,584,"""July 1999""","""1999年7月""","""When was the August 1st buildi…"
"""1650ల చివరలో 'Hareskoven'ను ఏ …","""Hareskov was severely damaged …","""te""",True,37,"""the Swedish""","""స్వీడిష్""","""Which country damaged 'Haresko…"
"""ఎముక పురుగులు దేని నుండి ఆహారం…","""Osedax is a genus of deep-sea …","""te""",True,208,"""whale carcasses""","""తిమింగల కళేబరాలు""","""What do boneworms feed on?"""
"""దోసకాయలు ఏ ఖండం నుండి ఉద్భవించ…","""Considered an annual plant, th…","""te""",True,180,"""Asia""","""ఆసియా""","""Which continent did cucumbers …"


In [5]:
# Select device for training
device = torch.device("cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")
if torch.cuda.is_available():
  device = torch.device("cuda")

print(f'Using device: {device}')

Using device: mps


In [6]:
print("Loading model from disk")
mt5_w_context_tokenizer = T5Tokenizer.from_pretrained(os.path.join(mt5_telugu_w_context_save_path, "fine_tuned"))
mt5_w_context_model = MT5ForConditionalGeneration.from_pretrained(os.path.join(mt5_telugu_w_context_save_path, "fine_tuned"))
mt5_w_context_model.to(device)

Loading model from disk


MT5ForConditionalGeneration(
  (shared): Embedding(250112, 512)
  (encoder): MT5Stack(
    (embed_tokens): Embedding(250112, 512)
    (block): ModuleList(
      (0): MT5Block(
        (layer): ModuleList(
          (0): MT5LayerSelfAttention(
            (SelfAttention): MT5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): MT5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): MT5LayerFF(
            (DenseReluDense): MT5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
          

# Evaluate difference between answerable and not answerable questions

In [7]:
df_test_json_only_answerable = df_test_json.filter(pl.col("answerable") == False)
df_test_json_only_unanswerable = df_test_json.filter(pl.col("answerable") == True)
df_test_json_only_telugu = df_test_json.filter(pl.col("lang") == "te")

In [8]:
# with context and answerable
df_test_json_prompt_w_context_answerable = generate_prompt_wo_context(df_test_json_only_answerable)
test_json_dataset_w_context_answerable = tokenize_to_dataset(df_test_json_prompt_w_context_answerable, mt5_w_context_tokenizer)

# with context and unanswerable
df_test_json_prompt_w_context_unanswerable = generate_prompt_wo_context(df_test_json_only_unanswerable)
test_json_dataset_w_context_unanswerable = tokenize_to_dataset(df_test_json_prompt_w_context_unanswerable, mt5_w_context_tokenizer)

# with context and only telugu
df_test_json_prompt_w_context_telugu = generate_prompt_wo_context(df_test_json_only_telugu)
test_json_dataset_w_context_telugu = tokenize_to_dataset(df_test_json_prompt_w_context_telugu, mt5_w_context_tokenizer)

# both answerable and unanswerable
df_test_json_prompt_w_context = generate_prompt_wo_context(df_test_json)
test_json_dataset_w_context = tokenize_to_dataset(df_test_json_prompt_w_context, mt5_w_context_tokenizer)

In [18]:
# with context and answerable
trainer = trainer_generator(
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=test_json_dataset_w_context_answerable,
    eval_dataset=test_json_dataset_w_context_answerable,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

PyTorch: setting up devices
average_tokens_across_devices is True but world size is 1. Setting it to False automatically.

***** Running Evaluation *****
  Num examples = 3
  Batch size = 8


Raw Predictions: [[     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    259  37604  77007   7596      1      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099      1      0      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]]
Raw Labels: [[28125 89292  3813 22675  5259   259 75958  1660     1  -100  -100  -100
   -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100
   -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100
   -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100
   -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -100
   -100  -100  -100  -100]
 [ 7174  2479     1  -100  -100  -100  -100  -100  -100  -100  -100  -100
   -100  -100  -100  -100  -100  -100  -100  -100  -100  -100  -

In [10]:
# with context and unanswerable
trainer = trainer_generator(
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=test_json_dataset_w_context_unanswerable,
    eval_dataset=test_json_dataset_w_context_unanswerable,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

PyTorch: setting up devices
average_tokens_across_devices is True but world size is 1. Setting it to False automatically.

***** Running Evaluation *****
  Num examples = 16
  Batch size = 8


Predictions: [[     0 250099      1      0      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    306      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    259  29794   2742  99260   1660  89811 230175    291
       1      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    259   3687   1426      1      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0  

In [11]:
trainer = trainer_generator(
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=test_json_dataset_w_context,
    eval_dataset=test_json_dataset_w_context,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

PyTorch: setting up devices
average_tokens_across_devices is True but world size is 1. Setting it to False automatically.

***** Running Evaluation *****
  Num examples = 19
  Batch size = 8


Predictions: [[     0 250099      1      0      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    306      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    259  29794   2742  99260   1660  89811 230175    291
       1      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    259   3687   1426      1      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0  

In [21]:
trainer = trainer_generator(
    model=mt5_w_context_model,
    tokenizer=mt5_w_context_tokenizer,
    train_dataset=test_json_dataset_w_context_telugu,
    eval_dataset=test_json_dataset_w_context_telugu,
    output_dir="tmp",
    epochs=1
)
results = trainer.evaluate()
print(results)

PyTorch: setting up devices
average_tokens_across_devices is True but world size is 1. Setting it to False automatically.

***** Running Evaluation *****
  Num examples = 4
  Batch size = 8


Raw Predictions: [[     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    259  29794   2742  99260   1660  89811 230175    291
       1      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]
 [     0 250099    260      1      0      0      0      0      0      0
       0      0      0      0      0      0      0      0      0      0]]
Raw Labels: [[ 52809   8219  12135  46786      1   -100   -100   -100   -100   -100
    -100   -100   -100   -100   -100   -100   -100   -100   -100   -100
    -100   -100   -100   -100   -100   -100   -100   -100   -100   -100
    -100   -100   -100   -100   -100   -100   -100   -100   -100   -100
    -100   -100   -100   -100   -100   -100   -100   -100   -100   -100
    -100   -100   -100   -100 

In [31]:
question = df_test_json_prompt_w_context_telugu["prompt"][0]
answer = df_test_json_prompt_w_context_telugu["answer_inlang"][0]

inputs = mt5_w_context_tokenizer(
    question, 
    return_tensors="pt", 
    truncation=True, 
    max_length=512
    ).to(device)    

outputs = mt5_w_context_model.generate(**inputs)
gen_answer = mt5_w_context_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Question: {question}")
print(f"Answer: {answer}")
print(f"Generated Answer: {gen_answer}")

Question: 1650ల చివరలో 'Hareskoven'ను ఏ దేశం దెబ్బతీసింది?
Answer: స్వీడిష్
Generated Answer: <extra_id_0>.


In [40]:
get_extra_id_0_proportion(df_test_json_prompt_w_context_telugu, mt5_w_context_tokenizer, mt5_w_context_model, device)

0.75