# Language models for the TyDi XOR QA dataset

In [None]:
import polars as pl
from transformers import AutoModel, AutoTokenizer
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import os
from datasets import load_dataset
from nlm.models import BiLSTMLanguageModel
from nlm.train_utils import train_lm as train
from nlm.probs import sentence_log_probability, perplexity

In [None]:
# Fetch the TyDi XOR reading-comprehension corpus and convert both splits
# to polars frames.
dataset = load_dataset("coastalcph/tydi_xor_rc")
df_train = pl.from_pandas(dataset["train"].to_pandas())
df_val = pl.from_pandas(dataset["validation"].to_pandas())

# Training split: one frame per language, plus the three languages combined.
df_ar_train, df_ko_train, df_te_train = (
    df_train.filter(pl.col("lang") == code) for code in ("ar", "ko", "te")
)
df_arkote_train = df_train.filter(pl.col("lang").is_in(["ar", "ko", "te"]))

# Validation split: same partitioning as the training split.
df_ar_val, df_ko_val, df_te_val = (
    df_val.filter(pl.col("lang") == code) for code in ("ar", "ko", "te")
)
df_arkote_val = df_val.filter(pl.col("lang").is_in(["ar", "ko", "te"]))

In [None]:
# Load the multilingual BERT tokenizer and model from the same checkpoint,
# then keep the model's input-embedding weight matrix (it is passed to the
# BiLSTM language models as their pretrained embeddings).
MBERT_CHECKPOINT = "bert-base-multilingual-uncased"
mbert_tokenizer = AutoTokenizer.from_pretrained(MBERT_CHECKPOINT)
mbert_model = AutoModel.from_pretrained(MBERT_CHECKPOINT)
pretrained_embeddings = mbert_model.get_input_embeddings().weight.data

In [None]:
# Pick the compute device: CUDA wins if present, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using device: {device}')

In [None]:
def _build_lm_dataset(texts: list, tokenizer, device, max_length: int) -> TensorDataset:
    """Tokenize *texts* and pack (input_ids, lengths, next-token targets) into a TensorDataset."""
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt',
    ).to(device)

    input_ids = encoded['input_ids']
    # Per-sequence count of positions flagged in the attention mask.
    input_lens = encoded['attention_mask'].sum(dim=1)

    # Next-token targets: every position holds the id that follows it; the
    # final position has no successor and is filled with the pad id.
    targets = input_ids.clone()
    targets[:, :-1] = input_ids[:, 1:]
    targets[:, -1] = tokenizer.pad_token_id

    return TensorDataset(input_ids, input_lens, targets)


def dataloader_generator(
    train_dataset: list,
    val_dataset: list,
    tokenizer,
    device,
    batch_size: int = 8,
    max_length: int = 65,
) -> tuple[DataLoader, DataLoader]:
    """
    Generate DataLoader objects for training and validation datasets for use in PyTorch models.

    Each batch yielded by either loader is a triple
    ``(input_ids, input_lens, targets)`` where ``targets`` is ``input_ids``
    shifted left by one position and padded with ``tokenizer.pad_token_id``.

    Args:
        train_dataset: Raw training texts.
        val_dataset: Raw validation texts.
        tokenizer: Callable tokenizer that accepts ``truncation``,
            ``max_length``, ``padding`` and ``return_tensors`` keywords,
            returns an object with ``.to(device)`` exposing ``'input_ids'``
            and ``'attention_mask'``, and has a ``pad_token_id`` attribute.
        device: Device the tokenized tensors are moved to.
        batch_size: Number of sequences per batch.
        max_length: Length every sequence is truncated/padded to.

    Returns:
        ``(train_dl, val_dl)``. The training loader is shuffled; the
        validation loader keeps dataset order so evaluation is reproducible.
    """
    train_ds = _build_lm_dataset(train_dataset, tokenizer, device, max_length)
    val_ds = _build_lm_dataset(val_dataset, tokenizer, device, max_length)

    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # Shuffling validation data adds nothing and makes runs non-deterministic.
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    return train_dl, val_dl


In [None]:
def ml_model_loader(train_dataset: list, val_dataset: list, device, model_cache_path: str, epochs: int, model_lstm_dim: int = 100) -> BiLSTMLanguageModel:
    """
    Return a BiLSTM language model, loaded from cache if present, otherwise trained.

    The model is always constructed from the module-level
    ``pretrained_embeddings``. If ``model_cache_path`` exists its state dict
    is loaded into the model; otherwise the model is trained on the given
    texts (tokenized with the module-level ``mbert_tokenizer``) and
    ``model_cache_path`` is handed to the training routine as ``save_path``.

    Args:
        train_dataset: Raw training texts.
        val_dataset: Raw validation texts.
        device: Device the model and data are placed on.
        model_cache_path: Path of the cached state dict.
        epochs: Number of training epochs when no cache exists.
        model_lstm_dim: ``lstm_dim`` passed to ``BiLSTMLanguageModel``.

    Returns:
        The loaded or freshly trained model.
    """
    model = BiLSTMLanguageModel(
        pretrained_embeddings=torch.FloatTensor(pretrained_embeddings).to(device),
        lstm_dim=model_lstm_dim,
    ).to(device)

    if os.path.exists(model_cache_path):
        print("Loading cached model from", model_cache_path)
        # map_location lets a checkpoint saved on one device (e.g. CUDA)
        # load on whichever device is selected now (e.g. CPU).
        model.load_state_dict(torch.load(model_cache_path, map_location=device))
        return model

    print("No cached model found. Training a new model.")
    # Make sure the target directory exists before training is asked to save
    # into it; dirname is empty for a bare filename, which needs no mkdir.
    cache_dir = os.path.dirname(model_cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    train_dl, val_dl = dataloader_generator(train_dataset, val_dataset, mbert_tokenizer, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    _losses, best_acc = train(
        model,
        train_dl,
        val_dl,
        optimizer,
        n_epochs=epochs,
        device=device,
        save_path=model_cache_path,
    )
    print('Training complete. Best validation accuracy:', best_acc)

    return model

In [None]:
# Arabic: question texts from both splits, then load or train the model.
arabic_model_path = "cached_data/bilstm_lm_arabic"
df_arabic_train_questions = df_ar_train.get_column("question").to_list()
df_arabic_val_questions = df_ar_val.get_column("question").to_list()

arabic_model = ml_model_loader(
    train_dataset=df_arabic_train_questions,
    val_dataset=df_arabic_val_questions,
    device=device,
    model_cache_path=arabic_model_path,
    epochs=3,
)

In [None]:
# Korean: question texts from both splits, then load or train the model.
korean_model_path = "cached_data/bilstm_lm_korean"
df_korean_train_questions = df_ko_train.get_column("question").to_list()
df_korean_val_questions = df_ko_val.get_column("question").to_list()

korean_model = ml_model_loader(
    train_dataset=df_korean_train_questions,
    val_dataset=df_korean_val_questions,
    device=device,
    model_cache_path=korean_model_path,
    epochs=3,
)

In [None]:
# Telugu: question texts from both splits, then load or train the model.
telugu_model_path = "cached_data/bilstm_lm_telugu"
df_telugu_train_questions = df_te_train.get_column("question").to_list()
df_telugu_val_questions = df_te_val.get_column("question").to_list()

telugu_model = ml_model_loader(
    train_dataset=df_telugu_train_questions,
    val_dataset=df_telugu_val_questions,
    device=device,
    model_cache_path=telugu_model_path,
    epochs=3,
)

In [None]:
# Context passages: taken from the combined Arabic/Korean/Telugu frames,
# then load or train a single model on them.
context_model_path = "cached_data/bilstm_lm_context"
df_context_train = df_arkote_train.get_column("context").to_list()
df_context_val = df_arkote_val.get_column("context").to_list()

context_model = ml_model_loader(
    train_dataset=df_context_train,
    val_dataset=df_context_val,
    device=device,
    model_cache_path=context_model_path,
    epochs=3,
)

In [None]:
# Sanity check: score one short English sentence with the context model.
tst_stc = "I am Sam"
sentence_log_probability(
    context_model,
    device,
    mbert_tokenizer,
    tst_stc,
)

In [None]:
# Report perplexity of each model on its own training and validation texts.
# Every pair is evaluated train-first, then printed, one model at a time.

# Korean questions
perplex_korean_train, perplex_korean_val = (
    perplexity(korean_model, device, mbert_tokenizer, texts)
    for texts in (df_korean_train_questions, df_korean_val_questions)
)
print(f"Perplexity of the Korean training text: {perplex_korean_train} and validation text: {perplex_korean_val}")

# Telugu questions
perplex_telugu_train, perplex_telugu_val = (
    perplexity(telugu_model, device, mbert_tokenizer, texts)
    for texts in (df_telugu_train_questions, df_telugu_val_questions)
)
print(f"Perplexity of the Telugu training text: {perplex_telugu_train} and validation text: {perplex_telugu_val}")

# Arabic questions
perplex_arabic_train, perplex_arabic_val = (
    perplexity(arabic_model, device, mbert_tokenizer, texts)
    for texts in (df_arabic_train_questions, df_arabic_val_questions)
)
print(f"Perplexity of the Arabic training text: {perplex_arabic_train} and validation text: {perplex_arabic_val}")

# Context passages (all three languages)
perplex_context_train, perplex_context_val = (
    perplexity(context_model, device, mbert_tokenizer, texts)
    for texts in (df_context_train, df_context_val)
)
print(f"Perplexity of the Context training text: {perplex_context_train} and validation text: {perplex_context_val}")