# LM for QA on the XOR-TyDi dataset (Arabic, Korean, Telugu answerability)

In [None]:
import polars as pl
from transformers import AutoModel, AutoTokenizer
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import os
import matplotlib.pyplot as plt

from data.const import ARB_CACHE, KOR_CACHE, TELU_CACHE
from nlm.models import BiLSTMClassifierModel
from nlm.train_utils import train_classifier as train

In [None]:
# Multilingual BERT (uncased) supplies both the tokenizer and the initial
# word-embedding matrix for the BiLSTM classifiers trained below.
_mbert_checkpoint = "bert-base-multilingual-uncased"
mbert_tokenizer = AutoTokenizer.from_pretrained(_mbert_checkpoint)
mbert_model = AutoModel.from_pretrained(_mbert_checkpoint)
# Raw tensor behind mBERT's input-embedding layer (accessed via .data, so no autograd tracking).
pretrained_embeddings = mbert_model.get_input_embeddings().weight.data

In [None]:
# Choose the compute backend: CUDA takes priority, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
def dataloader_generator(
    dataset: list,
    labels: list,
    tokenizer,
    device,
    test_split: float = 0.2,
    batch_size: int = 8,
    max_length: int = 65,
    random_state: int = 42,
) -> tuple[DataLoader, DataLoader]:
    """Tokenize texts and split them into training and validation DataLoaders.

    Each text is truncated/padded to exactly ``max_length`` tokens. Every batch
    yielded by either loader is ``(input_ids, input_lens, labels)``, where
    ``input_lens`` is the number of non-padding tokens per row (the sum of the
    attention mask). All tensors are moved to ``device`` up front.

    Args:
        dataset: Input texts, one per example.
        labels: Integer class label for each text, aligned with ``dataset``.
        tokenizer: A Hugging Face tokenizer (called with ``return_tensors='pt'``).
        device: Target torch device for all tensors.
        test_split: Fraction of examples held out for validation.
        batch_size: Batch size for both loaders.
        max_length: Token length every text is truncated/padded to.
        random_state: Seed for the train/validation split, so it is reproducible.

    Returns:
        ``(train_dl, val_dl)``. Training batches are shuffled; validation
        batches keep a fixed order.

    Raises:
        ValueError: If ``dataset`` and ``labels`` have different lengths.
    """
    if len(dataset) != len(labels):
        raise ValueError(
            f"dataset and labels must have the same length, got {len(dataset)} and {len(labels)}"
        )

    encoded = tokenizer(
        dataset,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt',
    ).to(device)
    label_tensor = torch.tensor(labels).to(device)

    input_ids = encoded['input_ids']
    # Count of real (non-padding) tokens per row; passed to the model with the ids.
    input_lens = encoded['attention_mask'].sum(dim=1)

    # Split row indices into train and validation sets.
    train_idx, val_idx = train_test_split(
        range(input_ids.size(0)), test_size=test_split, random_state=random_state
    )

    train_dataset = TensorDataset(
        input_ids[train_idx], input_lens[train_idx], label_tensor[train_idx]
    )
    val_dataset = TensorDataset(
        input_ids[val_idx], input_lens[val_idx], label_tensor[val_idx]
    )
    train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Shuffling does not change validation metrics; a fixed order keeps runs reproducible.
    val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_dl, val_dl

In [None]:
def class_model_loader(
    dataset: list,
    labels: list,
    device,
    model_cache_path: str,
    epochs: int,
    n_classes: int = 2,
    model_lstm_dim: int = 100,
    lr: float = 1e-3,
) -> BiLSTMClassifierModel:
    """Build a BiLSTM classifier, loading cached weights or training it from scratch.

    The embedding layer is initialised from the module-level mBERT
    ``pretrained_embeddings``. If ``model_cache_path`` exists, its state dict is
    loaded and ``dataset``/``labels`` are ignored. Otherwise the model is trained
    on ``dataset``/``labels`` with Adam and ``train`` is given
    ``model_cache_path`` as its ``save_path``.

    Args:
        dataset: Input texts.
        labels: Integer class label per text.
        device: Torch device the model lives on.
        model_cache_path: State-dict file to load from / save to.
        epochs: Number of training epochs (only used when training).
        n_classes: Number of output classes.
        model_lstm_dim: Hidden size of the BiLSTM.
        lr: Adam learning rate (only used when training).

    Returns:
        The classifier, placed on ``device``.
    """
    model = BiLSTMClassifierModel(
        # NOTE(review): on CPU, `.float()` and `.to(device)` are no-ops, so this
        # tensor shares storage with the global `pretrained_embeddings`. If
        # BiLSTMClassifierModel fine-tunes its embeddings in place, one training
        # run would leak into the next model. Confirm against nlm.models and add
        # `.clone()` if so.
        pretrained_embeddings=pretrained_embeddings.float().to(device),
        n_classes=n_classes,
        lstm_dim=model_lstm_dim,
    ).to(device)

    if os.path.exists(model_cache_path):
        print("Loading cached model from", model_cache_path)
        # map_location lets weights saved on CUDA/MPS load on whatever device is active now.
        model.load_state_dict(torch.load(model_cache_path, map_location=device))
    else:
        print("No cached model found. Training a new model.")
        train_dl, val_dl = dataloader_generator(dataset, labels, mbert_tokenizer, device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        _, best_acc = train(model, train_dl, val_dl, optimizer, n_epochs=epochs, device=device, save_path=model_cache_path)
        print('Training complete. Best validation accuracy:', best_acc)
        # The cache branch above treats this file as a state dict. Reloading it here
        # makes a fresh run return the same weights a later cached run would.
        # Presumably this is the best-validation checkpoint; verify in train_utils.
        if os.path.exists(model_cache_path):
            model.load_state_dict(torch.load(model_cache_path, map_location=device))

    return model


In [None]:
# Arabic dataset: predict answerability from the question text alone.
arabic_model_path = "cached_data/bilstm_class_arabic"
df_ar = pl.read_parquet(ARB_CACHE)
df_arabic = df_ar["question"].to_list()
df_arabic_answerable = list(map(int, df_ar["answerable"].to_list()))

arabic_model = class_model_loader(
    dataset=df_arabic,
    labels=df_arabic_answerable,
    device=device,
    model_cache_path=arabic_model_path,
    epochs=10,
    n_classes=2,
    model_lstm_dim=100,
)

In [None]:
# Arabic dataset with context: each input is "<question>[SEP]<context>".
# NOTE(review): this rebinds `df_arabic`, so afterwards it holds the joined
# strings instead of the question-only list built by the previous cell.
# NOTE(review): dataloader_generator truncates inputs to 65 tokens, so much of
# each context is likely cut off.
arabic_model_path = "cached_data/bilstm_class_arabic_w_context"
df_ar = pl.read_parquet(ARB_CACHE)
df_arabic = [
    "[SEP]".join(question_and_context)
    for question_and_context in zip(df_ar["question"].to_list(), df_ar["context"].to_list())
]
df_arabic_answerable = list(map(int, df_ar["answerable"].to_list()))

arabic_model_w_context = class_model_loader(
    dataset=df_arabic,
    labels=df_arabic_answerable,
    device=device,
    model_cache_path=arabic_model_path,
    epochs=10,
    n_classes=2,
    model_lstm_dim=100,
)

In [None]:
# Korean dataset: predict answerability from the question text alone.
korean_model_path = "cached_data/bilstm_class_korean"
df_ko = pl.read_parquet(KOR_CACHE)
df_korean = df_ko["question"].to_list()
df_korean_answerable = list(map(int, df_ko["answerable"].to_list()))

korean_model = class_model_loader(
    dataset=df_korean,
    labels=df_korean_answerable,
    device=device,
    model_cache_path=korean_model_path,
    epochs=10,
    n_classes=2,
    model_lstm_dim=100,
)

In [None]:
# Korean dataset with context: each input is "<question>[SEP]<context>".
# NOTE(review): this rebinds `df_korean`, so afterwards it holds the joined
# strings instead of the question-only list built by the previous cell.
# NOTE(review): dataloader_generator truncates inputs to 65 tokens, so much of
# each context is likely cut off.
korean_model_path = "cached_data/bilstm_class_korean_w_context"
df_ko = pl.read_parquet(KOR_CACHE)
df_korean = [
    "[SEP]".join(question_and_context)
    for question_and_context in zip(df_ko["question"].to_list(), df_ko["context"].to_list())
]
df_korean_answerable = list(map(int, df_ko["answerable"].to_list()))

korean_model_w_context = class_model_loader(
    dataset=df_korean,
    labels=df_korean_answerable,
    device=device,
    model_cache_path=korean_model_path,
    epochs=10,
    n_classes=2,
    model_lstm_dim=100,
)

In [None]:
# Telugu dataset: predict answerability from the question text alone.
telugu_model_path = "cached_data/bilstm_class_telugu"
df_telu = pl.read_parquet(TELU_CACHE)
df_telugu = df_telu["question"].to_list()
df_telugu_answerable = list(map(int, df_telu["answerable"].to_list()))

telugu_model = class_model_loader(
    dataset=df_telugu,
    labels=df_telugu_answerable,
    device=device,
    model_cache_path=telugu_model_path,
    epochs=10,
    n_classes=2,
    model_lstm_dim=100,
)

In [None]:
# Telugu dataset with context: each input is "<question>[SEP]<context>".
# NOTE(review): this rebinds `df_telugu`, so afterwards it holds the joined
# strings instead of the question-only list built by the previous cell.
# NOTE(review): dataloader_generator truncates inputs to 65 tokens, so much of
# each context is likely cut off.
telugu_model_path = "cached_data/bilstm_class_telugu_w_context"
df_telu = pl.read_parquet(TELU_CACHE)
df_telugu = [
    "[SEP]".join(question_and_context)
    for question_and_context in zip(df_telu["question"].to_list(), df_telu["context"].to_list())
]
df_telugu_answerable = list(map(int, df_telu["answerable"].to_list()))

telugu_model_w_context = class_model_loader(
    dataset=df_telugu,
    labels=df_telugu_answerable,
    device=device,
    model_cache_path=telugu_model_path,
    epochs=10,
    n_classes=2,
    model_lstm_dim=100,
)

In [None]:
# Context-only dataset: pool the Arabic, Korean and Telugu rows and predict
# answerability from the context passage alone (no question).
df_arkote = pl.concat([df_ar, df_ko, df_telu])
df_arkote_answerable = list(map(int, df_arkote["answerable"].to_list()))

context_model_path = "cached_data/bilstm_class_context"
df_context = df_arkote["context"].to_list()
context_model = class_model_loader(
    dataset=df_context,
    labels=df_arkote_answerable,
    device=device,
    model_cache_path=context_model_path,
    epochs=10,
    n_classes=2,
    model_lstm_dim=100,
)


In [None]:
def predict_answerable(
    model: "BiLSTMClassifierModel",
    texts: list[str],
    tokenizer,
    device,
    batch_size: int = 64,
    max_length: int = 65,
) -> list[int]:
    """Predict a class id for each text (argmax over the model's logits).

    Texts are tokenized the same way as in training: truncated/padded to
    ``max_length`` tokens, with the per-row count of non-padding tokens passed
    to the model as the second argument. The input is processed in chunks of
    ``batch_size`` so large evaluation sets never have to fit on the device at
    once. The model is left in eval mode.

    Args:
        model: Classifier called as ``model(input_ids, input_lens)`` and
            returning logits whose class dimension is dim 1.
        texts: Input texts.
        tokenizer: A Hugging Face tokenizer (called with ``return_tensors='pt'``).
        device: Torch device the model lives on.
        batch_size: Number of texts per forward pass.
        max_length: Token length every text is truncated/padded to.

    Returns:
        One predicted class id per text, in input order (``[]`` for no texts).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    predictions: list[int] = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            tokens = tokenizer(
                texts[start:start + batch_size],
                truncation=True,
                max_length=max_length,
                padding='max_length',
                return_tensors='pt',
            ).to(device)
            input_lens = tokens['attention_mask'].sum(dim=1)
            logits = model(tokens['input_ids'], input_lens)
            # Softmax is monotonic, so argmax over the raw logits picks the same class.
            predictions.extend(torch.argmax(logits, dim=1).cpu().tolist())

    return predictions

In [None]:
def show_performance(true: list, pred: list):
    """Print binary-classification metrics and plot a 2x2 confusion matrix.

    Label 1 is the positive ("answerable") class and 0 the negative class; the
    four confusion counts only consider those two values.

    Args:
        true: Gold labels.
        pred: Predicted labels, aligned index-by-index with ``true``.

    Raises:
        ValueError: If ``true`` and ``pred`` have different lengths. Without
            this check, ``zip`` would silently drop the extra items.
    """
    from collections import Counter

    if len(true) != len(pred):
        raise ValueError(
            f"true and pred must have the same length, got {len(true)} and {len(pred)}"
        )
    if not true:
        # Every ratio below divides by len(true); there is nothing to report.
        print("No samples to evaluate.")
        return

    # Proportion of answerable labels vs. answerable predictions
    print(f"Answerable proportion: {true.count(1) / len(true):.2f}")
    print(f"Predicted answerable proportion: {pred.count(1) / len(pred):.2f}")

    # Tally (true, pred) pairs once instead of scanning the data four times.
    pair_counts = Counter(zip(true, pred))
    true_positives = pair_counts[(1, 1)]
    false_positives = pair_counts[(0, 1)]
    true_negatives = pair_counts[(0, 0)]
    false_negatives = pair_counts[(1, 0)]
    print(f"True Positives: {true_positives} , False Positives: {false_positives}")
    print(f"True Negatives: {true_negatives} , False Negatives: {false_negatives}")

    # Accuracy, Precision, Recall, F1 Score (0 when a denominator is empty)
    accuracy = (true_positives + true_negatives) / len(true)
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    print(f"Accuracy: {accuracy:.2f}")
    print(f"Precision: {precision:.2f}")
    print(f"Recall: {recall:.2f}")
    print(f"F1 Score: {f1_score:.2f}")

    # Plot confusion matrix: rows are true labels, columns are predicted labels.
    confusion_matrix = [[true_positives, false_negatives],
                        [false_positives, true_negatives]]
    plt.imshow(confusion_matrix)
    plt.colorbar()
    plt.xticks([0, 1], ['P', 'N'])
    plt.yticks([0, 1], ['P', 'N'])
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.show()

In [None]:
# Score every classifier on the validation rows it was NOT trained on, feeding
# it the same kind of input it was trained with. Inputs are rebuilt from the
# dataframes here because the with-context cells rebind df_arabic / df_korean /
# df_telugu to the "[SEP]"-joined strings.


def _question_inputs(df: pl.DataFrame) -> list[str]:
    """Return question-only inputs, as used to train the *_model classifiers."""
    return df["question"].to_list()


def _question_context_inputs(df: pl.DataFrame) -> list[str]:
    """Return "<question>[SEP]<context>" inputs, as used to train the *_model_w_context classifiers."""
    return ["[SEP]".join(pair) for pair in zip(df["question"].to_list(), df["context"].to_list())]


def _heldout(texts: list, labels: list, test_split: float = 0.2, random_state: int = 42) -> tuple[list, list]:
    """Return the validation subset of (texts, labels).

    This repeats the split dataloader_generator makes when class_model_loader
    calls it: train_test_split over range(len(texts)) with test_size=0.2 and
    random_state=42. None of the returned rows were used for training. Keep
    these defaults in sync with that function.
    """
    _, val_idx = train_test_split(range(len(texts)), test_size=test_split, random_state=random_state)
    return [texts[i] for i in val_idx], [labels[i] for i in val_idx]


_evaluations = [
    ("Arabic model without context:", arabic_model, _question_inputs(df_ar), df_arabic_answerable),
    ("Arabic model with context:", arabic_model_w_context, _question_context_inputs(df_ar), df_arabic_answerable),
    ("Korean model without context:", korean_model, _question_inputs(df_ko), df_korean_answerable),
    ("Korean model with context:", korean_model_w_context, _question_context_inputs(df_ko), df_korean_answerable),
    ("Telugu model without context:", telugu_model, _question_inputs(df_telu), df_telugu_answerable),
    ("Telugu model with context:", telugu_model_w_context, _question_context_inputs(df_telu), df_telugu_answerable),
    ("Only context model:", context_model, df_context, df_arkote_answerable),
]

print("Metrics below are computed on each model's held-out validation split.")
for _title, _model, _texts, _labels in _evaluations:
    _val_texts, _val_labels = _heldout(_texts, _labels)
    print(_title)
    show_performance(_val_labels, predict_answerable(_model, _val_texts, mbert_tokenizer, device))