# LM for QA on the TyDi XOR dataset

In [None]:
import os
import polars as pl
import torch

from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt

from bert_utils import (
    predict_binary,
    prepare_data,
    tokenize_function,
    train_mbert,
)

# Huggingface imports
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from datasets import load_dataset

In [None]:
# Choose the compute device. Priority: CUDA, then Apple MPS, then CPU.
# Both probes run unconditionally, MPS first, and the result is picked afterwards.
mps_ready = torch.backends.mps.is_available()
cuda_ready = torch.cuda.is_available()

if cuda_ready:
    device = torch.device("cuda")
elif mps_ready:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
# Fetch the dataset and convert its train/validation splits to polars frames.
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train = dataset["train"].to_polars()
df_val = dataset["validation"].to_polars()

# Row filters for the three languages of interest: Arabic, Telugu and Korean.
is_arabic = pl.col("lang") == "ar"
is_telugu = pl.col("lang") == "te"
is_korean = pl.col("lang") == "ko"

df_ar_train, df_ar_val = df_train.filter(is_arabic), df_val.filter(is_arabic)
df_te_train, df_te_val = df_train.filter(is_telugu), df_val.filter(is_telugu)
df_ko_train, df_ko_val = df_train.filter(is_korean), df_val.filter(is_korean)

# Lookup table: language name -> split name -> frame.
data = dict(
    arabic=dict(train=df_ar_train, val=df_ar_val),
    telugu=dict(train=df_te_train, val=df_te_val),
    korean=dict(train=df_ko_train, val=df_ko_val),
)

In [None]:
# Arabic class balance: size and share of answerable questions, training split first.
ar_train_size = len(df_ar_train)
ar_train_answerable = df_ar_train['answerable'].sum()
print(f"Arabic TRAINING set size: {ar_train_size} with a total of {ar_train_answerable} answerable questions.")
print(f"This gives a distribution of {ar_train_answerable / ar_train_size * 100:.2f}% answerable questions.")

# Same figures for the validation split.
ar_val_size = len(df_ar_val)
ar_val_answerable = df_ar_val['answerable'].sum()
print(f"Arabic VALIDATION set size: {ar_val_size} with a total of {ar_val_answerable} answerable questions.")
print(f"This gives a distribution of {ar_val_answerable / ar_val_size * 100:.2f}% answerable questions.")

In [None]:
# Telugu class balance: size and share of answerable questions, training split first.
te_train_size = len(df_te_train)
te_train_answerable = df_te_train['answerable'].sum()
print(f"Telegu TRAINING set size: {te_train_size} with a total of {te_train_answerable} answerable questions.")
print(f"This gives a distribution of {te_train_answerable / te_train_size * 100:.2f}% answerable questions.")

# Same figures for the validation split.
te_val_size = len(df_te_val)
te_val_answerable = df_te_val['answerable'].sum()
print(f"Telegu VALIDATION set size: {te_val_size} with a total of {te_val_answerable} answerable questions.")
print(f"This gives a distribution of {te_val_answerable / te_val_size * 100:.2f}% answerable questions.")


In [None]:
# Korean class balance: size and share of answerable questions, training split first.
ko_train_size = len(df_ko_train)
ko_train_answerable = df_ko_train['answerable'].sum()
print(f"Korean TRAINING set size: {ko_train_size} with a total of {ko_train_answerable} answerable questions.")
print(f"This gives a distribution of {ko_train_answerable / ko_train_size * 100:.2f}% answerable questions.")

# Same figures for the validation split.
ko_val_size = len(df_ko_val)
ko_val_answerable = df_ko_val['answerable'].sum()
print(f"Korean VALIDATION set size: {ko_val_size} with a total of {ko_val_answerable} answerable questions.")
print(f"This gives a distribution of {ko_val_answerable / ko_val_size * 100:.2f}% answerable questions.")

## Fine-tune multilingual BERT for binary classification

In [None]:
# Checkpoint identifier handed to AutoTokenizer.from_pretrained on the next line.
mbert_checkpoint = "bert-base-multilingual-uncased"
# Tokenizer loaded from that checkpoint.
mbert_tokenizer = AutoTokenizer.from_pretrained(mbert_checkpoint)

In [None]:
### EXAMPLE OF DATA PROCESS PIPELINE
# Prepare datasets
# train_dataset = prepare_data(data["telugu"]["train"])
# # Tokenize datasets - fix the function call
# tokenized_train = train_dataset.map(lambda examples: tokenize_function(examples, mbert_tokenizer), batched=True)
# tokenized_train.features

In [None]:
# Build one answerability classifier per language. A model saved by a previous
# run is loaded from disk when present; otherwise mBERT is fine-tuned and saved.
#
# Results:
#   all_classifiers[lang] -> sequence-classification model for that language
#   all_tokenizers[lang]  -> tokenizer for that language
all_classifiers = {}
all_tokenizers = {}  # they're all the same; the tokenizer is not trained
mbert_checkpoint = "bert-base-multilingual-uncased"
classifiers_dir = "./mbert_classifiers"

# The output folder is shared by every language, so it is checked once up front
# instead of on every loop iteration.
if not os.path.exists(classifiers_dir):
    print(f"No classifiers folder found, creating {classifiers_dir}...")
    os.makedirs(classifiers_dir)

for lang in ["arabic", "telugu", "korean"]:
    cap_lang = lang.capitalize()
    print(f"\n--- Processing language: {cap_lang} ---")
    save_path = f"{lang}_mbert_answerable_classifier"
    full_save_path = os.path.join(classifiers_dir, save_path)

    if os.path.exists(full_save_path):
        # A saved model exists for this language: load it and skip training.
        print(f"Found existing model for {cap_lang}, loading...")
        all_classifiers[lang] = AutoModelForSequenceClassification.from_pretrained(full_save_path)
        all_tokenizers[lang] = AutoTokenizer.from_pretrained(full_save_path)
        print(f"Model for {cap_lang} loaded.")
        continue

    print("Model not found, training new mBERT model...")
    mbert_tokenizer = AutoTokenizer.from_pretrained(mbert_checkpoint)
    all_tokenizers[lang] = mbert_tokenizer

    # Prepare datasets
    train_dataset = prepare_data(data[lang]["train"])
    val_dataset = prepare_data(data[lang]["val"])

    # Tokenize datasets in batches with the shared tokenizer
    tokenized_train = train_dataset.map(lambda examples: tokenize_function(examples, mbert_tokenizer), batched=True)
    tokenized_val = val_dataset.map(lambda examples: tokenize_function(examples, mbert_tokenizer), batched=True)

    # Train. No model is instantiated here: train_mbert is given the checkpoint
    # name and hands back the classifier.
    # NOTE(review): presumably train_mbert builds the model from model_checkpoint
    # itself - confirm in bert_utils.
    classifier, tokenizer = train_mbert(
        tokenized_train,
        tokenized_val,
        model_checkpoint=mbert_checkpoint,
        device=device,
    )  # type: ignore
    print("Saving model...")
    classifier.save_pretrained(full_save_path)  # type: ignore
    tokenizer.save_pretrained(full_save_path)  # type: ignore
    print(f"Model trained and saved to {full_save_path}.")
    # Reload from disk so the stored model is obtained the same way as on the load path
    all_classifiers[lang] = AutoModelForSequenceClassification.from_pretrained(full_save_path)

## Compare pre-trained vs fine-tuned results

In [None]:
# Baseline: load the checkpoint without any fine-tuning, asking for a 2-label head
# (the banner printed below describes that head as random).
pt_mbert = AutoModelForSequenceClassification.from_pretrained(mbert_checkpoint, num_labels=2)
mbert_tokenizer = AutoTokenizer.from_pretrained(mbert_checkpoint)

separator = "=" * 50
for lang in ["arabic", "telugu", "korean"]:
    # Banner announcing which language's baseline predictions follow.
    print(separator)
    print(f"{lang.upper()} BEFORE FINE-TUNING (Random Classification Head)")
    print(separator)

    # Run the baseline on the first three validation rows of this language.
    val_frame = data[lang]["val"]
    for number in range(1, 4):
        row = val_frame.row(number - 1, named=True)
        outcome = predict_binary(row["question"], row["context"], pt_mbert, mbert_tokenizer)

        truth_label = 'Answerable' if row['answerable'] else 'Not Answerable'
        predicted_label = 'Answerable' if outcome['prediction'] == 1 else 'Not Answerable'

        print(f"\nExample {number}:")
        print(f"Question: {row['question'][:100]}...")
        print(f"Ground Truth: {truth_label}")
        print(f"Prediction: {predicted_label}")
        print(f"Confidence: {outcome['confidence']:.3f}")
        print(f"Probs: [Not Answerable: {outcome['prob_class_0']:.3f}, Answerable: {outcome['prob_class_1']:.3f}]")

In [None]:
# Show the fine-tuned classifiers' predictions on the first three validation rows
# of each language.
separator = "=" * 50
for lang in ["arabic", "telugu", "korean"]:
    print("\n" + separator)
    print(f"{lang.upper()} AFTER FINE-TUNING")
    print(separator)

    val_frame = data[lang]["val"]
    tuned_model = all_classifiers[lang]
    for number in range(1, 4):
        row = val_frame.row(number - 1, named=True)
        outcome = predict_binary(row['question'], row['context'], tuned_model, mbert_tokenizer)

        truth_label = 'Answerable' if row['answerable'] else 'Not Answerable'
        predicted_label = 'Answerable' if outcome['prediction'] == 1 else 'Not Answerable'

        print(f"\nExample {number}:")
        print(f"Question: {row['question'][:100]}...")
        print(f"Ground Truth: {truth_label}")
        print(f"Prediction: {predicted_label}")
        print(f"Confidence: {outcome['confidence']:.3f}")

In [None]:
# Overall validation accuracy of each language's fine-tuned classifier.
for lang in ["arabic", "telugu", "korean"]:
    val_frame = data[lang]["val"]
    total = len(val_frame)
    print(f"\nCalculating accuracy for {lang} on validation set of size {total}...")

    correct = 0
    # iter_rows yields each row as a dict, avoiding a per-index row lookup.
    for example in val_frame.iter_rows(named=True):
        result = predict_binary(example['question'], example['context'], all_classifiers[lang], mbert_tokenizer)
        if result['prediction'] == example['answerable']:
            correct += 1

    # An empty validation split reports 0% instead of raising ZeroDivisionError.
    accuracy = correct / total * 100 if total else 0.0
    print(f"Accuracy for {lang} on validation set: {accuracy:.2f}% ({correct}/{total})")

In [None]:
# For each language: plot a confusion matrix normalised over the true labels and
# print a classification report for the fine-tuned classifier on the validation set.
class_names = ['Not Answerable', 'Answerable']
# Class ids in the same order as class_names (a prediction of 1 means "Answerable").
class_ids = [0, 1]

for lang in ["arabic", "telugu", "korean"]:
    val_frame = data[lang]["val"]
    total = len(val_frame)
    print(f"\nCalculating confusion matrix for {lang} on validation set of size {total}...")

    y_true = []
    y_pred = []
    for example in val_frame.iter_rows(named=True):
        result = predict_binary(example['question'], example['context'], all_classifiers[lang], mbert_tokenizer)
        # Cast both sides to int so ground truth and prediction share one label space.
        # NOTE(review): assumes 'answerable' is a boolean/0-1 column - confirm.
        y_true.append(int(example['answerable']))
        y_pred.append(int(result['prediction']))

    # Pinning labels keeps the matrix 2x2 even when a class never occurs.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids, normalize='true')

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(f"Confusion Matrix for {lang.capitalize()}")

    plt.colorbar()
    tick_marks = range(len(class_names))
    plt.xticks(tick_marks, class_names)
    plt.yticks(tick_marks, class_names)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")

    # Write each cell's value on the plot; white text on dark cells, black on light ones.
    thresh = cm.max() / 2.
    for row_idx in range(cm.shape[0]):
        for col_idx in range(cm.shape[1]):
            plt.text(col_idx, row_idx, f"{cm[row_idx, col_idx]:.2f}",
                     horizontalalignment="center",
                     color="white" if cm[row_idx, col_idx] > thresh else "black")
    plt.show()

    # labels matches target_names in length even if a class is absent;
    # zero_division=0 reports 0 for undefined metrics without a warning.
    report = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,
    )
    print(f"Classification Report for {lang.capitalize()}:\n{report}")