# Week 38 mBERT Binary 'Answerability' Classifier

In [None]:
import os
import polars as pl
import torch

from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt

from bert_utils import (
    predict_binary,
    prepare_data,
    tokenize_function,
    train_mbert,
)

# Huggingface imports
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)
from datasets import load_dataset

In [None]:
# Pick the compute device: CUDA wins over Apple MPS, and CPU is the fallback
mps_ready = torch.backends.mps.is_available()
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    backend = "cuda"
elif mps_ready:
    backend = "mps"
else:
    backend = "cpu"
device = torch.device(backend)

print(f'Using device: {device}')

In [None]:
# Read the custom (personal) test set from disk into a polars DataFrame
test_set_file = "test.json"
test_data = pl.read_json(test_set_file)

In [None]:
# Report the size of the custom test set and its share of answerable questions
n_examples = len(test_data)
n_answerable = test_data['answerable'].sum()
print(f"Custom test set size: {n_examples} with a total of {n_answerable} answerable questions.")
answerable_pct = n_answerable / n_examples * 100
print(f"This gives a distribution of {answerable_pct:.2f}% answerable questions.")

### Load in the fine-tuned model

In [None]:
# Directory holding the fine-tuned classifier checkpoint and its tokenizer
path = os.path.join("mbert_classifiers", "telugu_mbert_answerable_classifier")

# Fail early with a pointer to the training notebook when no checkpoint exists
if not os.path.exists(path):
    raise ValueError("Model not trained. Go to bert-class.ipynb notebook to train the model first.")

model = AutoModelForSequenceClassification.from_pretrained(path)
tokenizer = AutoTokenizer.from_pretrained(path)

## Results

In [None]:
    # Show the fine-tuned model's predictions on the first few test examples.
    # NOTE(review): this cell is indented at top level. IPython strips a common
    # leading indent from a cell, but a plain Python interpreter would reject it.
    print("\n" + "=" * 50)
    print("AFTER FINE-TUNING")
    print("=" * 50)
    # head(3) yields at most three rows, so a test set with fewer than three
    # examples is handled instead of raising on an out-of-range row index.
    # iter_rows(named=True) gives each row as a dict keyed by column name.
    for example_no, example in enumerate(test_data.head(3).iter_rows(named=True), start=1):
        result = predict_binary(example['question'], example['context'], model, tokenizer)

        print(f"\nExample {example_no}:")
        # Only the first 100 characters of the question are shown
        print(f"Question: {example['question'][:100]}...")
        print(f"Ground Truth: {'Answerable' if example['answerable'] else 'Not Answerable'}")
        # A prediction of 1 is treated as the "Answerable" class
        print(f"Prediction: {'Answerable' if result['prediction'] == 1 else 'Not Answerable'}")
        print(f"Confidence: {result['confidence']:.3f}")

In [None]:
# Compute the overall accuracy of the fine-tuned model on the custom test set
total = len(test_data)
print(f"\nCalculating accuracy for test set of size {total}...")
# Walk the rows with iter_rows(named=True): polars advises against calling
# DataFrame.row(i) in a loop. Each row arrives as a dict keyed by column name.
correct = sum(
    1
    for example in test_data.iter_rows(named=True)
    if predict_binary(example['question'], example['context'], model, tokenizer)['prediction']
    == example['answerable']
)
# An empty test set has no meaningful accuracy; report NaN rather than
# failing with a ZeroDivisionError.
accuracy = correct / total * 100 if total else float("nan")
print(f"Accuracy: {accuracy:.2f}% ({correct}/{total})")

In [None]:
# Plot a confusion matrix and print a classification report for the test set
# Display names in label order: index 0 is "Not Answerable", index 1 is "Answerable"
class_names = ['Not Answerable', 'Answerable']
y_true = []
y_pred = []
total = len(test_data)
print(f"\nCalculating confusion matrix for test set of size {total}...")
# Walk the rows with iter_rows(named=True): polars advises against calling
# DataFrame.row(i) in a loop. Each row arrives as a dict keyed by column name.
for example in test_data.iter_rows(named=True):
    result = predict_binary(example['question'], example['context'], model, tokenizer)
    y_true.append(example['answerable'])
    y_pred.append(result['prediction'])
# normalize='true' scales each row (true label) so its entries sum to 1
cm = confusion_matrix(y_true, y_pred, normalize='true')
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion Matrix for Test Set")

plt.colorbar()
tick_marks = range(len(class_names))
plt.xticks(tick_marks, class_names)
plt.yticks(tick_marks, class_names)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
# Write each cell's value on the plot; use white text on dark cells so it stays readable
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, f"{cm[i, j]:.2f}",
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
plt.show()
report = classification_report(y_true, y_pred, target_names=class_names)
print(f"Classification Report for test set:\n{report}")