# LM for QA on the TyDi XOR dataset

In [None]:
import gc
import polars as pl
import torch

# Huggingface imports
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)
from datasets import load_dataset, Dataset

In [None]:
# Pick the training device: CUDA wins over Apple MPS, which wins over CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
# Fetch the coastalcph/tydi_xor_rc dataset and convert both splits to Polars
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train, df_val = (dataset[split].to_polars() for split in ("train", "validation"))

# Keep only the rows whose "lang" column is Arabic, for now
is_arabic = pl.col("lang") == "ar"
df_ar_train = df_train.filter(is_arabic)
df_ar_val = df_val.filter(is_arabic)

In [None]:
# Load the mBERT tokenizer together with a two-label sequence-classification model
model_checkpoint = "bert-base-multilingual-uncased"
mbert_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
mbert_classifier = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

# Display the first rows of the Arabic validation split
df_ar_val.head(5)

In [None]:
# mBERT with classification head.
# Reuse the tokenizer and classifier already loaded as `mbert_tokenizer` and
# `mbert_classifier` instead of loading a second copy of the same checkpoint.
# With two independent instances, the "before fine-tuning" demo, the Trainer
# and the final save each referred to a different object, and the extra copy
# of the weights stayed referenced (and so in memory) during training.
model_checkpoint = "bert-base-multilingual-uncased"
mbert_model = mbert_classifier  # binary classification (answerable or not)

# Prepare your data
def prepare_data(df):
    """Build a Hugging Face ``Dataset`` from a Polars frame.

    Reads the ``question``, ``context`` and ``answerable`` columns of ``df``.
    ``answerable`` is cast to int and exposed under the ``label`` key.
    """
    questions = df["question"].to_list()
    contexts = df["context"].to_list()
    labels = df["answerable"].cast(int).to_list()  # bool -> int
    return Dataset.from_dict(
        {"question": questions, "context": contexts, "label": labels}
    )

# Convert the Arabic train and validation frames, in that order
train_dataset, val_dataset = (
    prepare_data(frame) for frame in (df_ar_train, df_ar_val)
)

# Tokenization function
def tokenize_function(examples):
    """Tokenize a batch of question/context pairs with ``mbert_tokenizer``.

    The question and context are passed as a pair, so the tokenizer separates
    them with [SEP] and adds [CLS] automatically. Sequences are truncated and
    padded to a fixed length of 512.
    """
    tokenizer_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 512,
    }
    return mbert_tokenizer(
        examples["question"],
        examples["context"],
        **tokenizer_options,
    )

# Tokenize both splits in batched mode (train first, then validation)
tokenized_train, tokenized_val = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, val_dataset)
)

# Print example tokenized input
print(tokenized_train[0])

In [None]:
# Function to get predictions
def predict(question, context, model, tokenizer):
    """Get the model prediction for a single question/context example.

    Args:
        question: Question text.
        context: Context passage paired with the question.
        model: Classifier whose output exposes ``.logits`` with (at least)
            two classes along dim 1.
        tokenizer: Callable that encodes the pair and returns a mapping of
            tensors (called with ``return_tensors="pt"``).

    Returns:
        dict with keys ``'prediction'`` (argmax class index),
        ``'confidence'`` (softmax probability of the predicted class),
        ``'prob_class_0'`` and ``'prob_class_1'``.
    """
    inputs = tokenizer(
        question,
        context,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    # Move the model to GPU if available (same as before).
    if torch.cuda.is_available():
        model = model.cuda()

    # Put the inputs on whatever device the model actually lives on. Moving
    # inputs only when CUDA is available left them on CPU when the model had
    # been moved elsewhere (e.g. to MPS by the Trainer), which fails with a
    # device mismatch in the forward pass.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    inputs = {k: v.to(target_device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=1)
        prediction = torch.argmax(logits, dim=1).item()

    return {
        'prediction': prediction,  # 0 or 1
        'confidence': probs[0][prediction].item(),
        'prob_class_0': probs[0][0].item(),
        'prob_class_1': probs[0][1].item()
    }

# Sanity-check a few validation examples BEFORE training
banner = "=" * 50
print(banner)
print("BEFORE FINE-TUNING (Random Classification Head)")
print(banner)

# Use the first three rows of the Arabic validation split
for i in range(3):
    example = df_ar_val.row(i, named=True)
    question = example['question']

    result = predict(question, example['context'], mbert_model, mbert_tokenizer)

    truth_text = 'Answerable' if example['answerable'] else 'Not Answerable'
    predicted_text = 'Answerable' if result['prediction'] == 1 else 'Not Answerable'

    print(f"\nExample {i+1}:")
    print(f"Question: {question[:100]}...")
    print(f"Ground Truth: {truth_text}")
    print(f"Prediction: {predicted_text}")
    print(f"Confidence: {result['confidence']:.3f}")
    print(f"Probs: [Not Answerable: {result['prob_class_0']:.3f}, Answerable: {result['prob_class_1']:.3f}]")

In [None]:
# Clear torch cache before training
gc.collect()
torch.cuda.empty_cache()
# This notebook can also select the MPS device, so release the MPS allocator
# cache as well. The lookup is guarded because torch.mps.empty_cache is not
# present in older torch releases.
if torch.backends.mps.is_available():
    mps_empty_cache = getattr(getattr(torch, "mps", None), "empty_cache", None)
    if mps_empty_cache is not None:
        mps_empty_cache()

In [None]:


# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    # Regularization
    weight_decay=0.01,
    # Memory settings
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # effective train batch of 8 per device
    # fp16 mixed precision is only enabled when CUDA is present. This notebook
    # also supports CPU/MPS, where an unconditional fp16=True can raise when
    # the arguments are built or when training starts.
    fp16=torch.cuda.is_available(),
    # Evaluation
    per_device_eval_batch_size=8,
    # save_strategy matches eval_strategy, as load_best_model_at_end requires
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# Trainer: wires the classifier, the training arguments and the tokenized splits
trainer_kwargs = {
    "model": mbert_classifier,
    "args": training_args,
    "train_dataset": tokenized_train,
    "eval_dataset": tokenized_val,
}
trainer = Trainer(**trainer_kwargs)

In [None]:
# Train: run fine-tuning with the Trainer built above. Per training_args,
# evaluation and checkpoint saving happen once per epoch.
trainer.train()

In [None]:
# Re-run the same three validation examples AFTER training
banner = "=" * 50
print("\n" + banner)
print("AFTER FINE-TUNING")
print(banner)

for i in range(3):
    example = df_ar_val.row(i, named=True)
    question = example['question']

    result = predict(question, example['context'], mbert_classifier, mbert_tokenizer)

    truth_text = 'Answerable' if example['answerable'] else 'Not Answerable'
    predicted_text = 'Answerable' if result['prediction'] == 1 else 'Not Answerable'

    print(f"\nExample {i+1}:")
    print(f"Question: {question[:100]}...")
    print(f"Ground Truth: {truth_text}")
    print(f"Prediction: {predicted_text}")
    print(f"Confidence: {result['confidence']:.3f}")

In [None]:
# Save model
# Persist the fine-tuned classifier, i.e. the object that was handed to the
# Trainer. Saving `mbert_model` wrote out a separately loaded instance that
# was never trained.
save_dir = "./mbert_arabic_answerable_classifier"
mbert_classifier.save_pretrained(save_dir)
mbert_tokenizer.save_pretrained(save_dir)