# Language models (LM) for the XOR-TyDi QA dataset

In [None]:
import polars as pl
from transformers import AutoModel, AutoTokenizer
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import os

from data.const import ARB_CACHE, KOR_CACHE, TELU_CACHE
from nlm.models import BiLSTMLanguageModel
from nlm.train_utils import train
from nlm.probs import sentence_log_probability, perplexity

In [None]:
# Multilingual BERT (uncased). Its tokenizer is reused by every LM below, and its
# input-embedding matrix is handed to BiLSTMLanguageModel as `pretrained_embeddings`.
_MBERT_CHECKPOINT = "bert-base-multilingual-uncased"
mbert_tokenizer = AutoTokenizer.from_pretrained(_MBERT_CHECKPOINT)
mbert_model = AutoModel.from_pretrained(_MBERT_CHECKPOINT)
pretrained_embeddings = (
    mbert_model
    .get_input_embeddings()
    .weight
    .data
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

In [None]:

def dataloader_generator(
    dataset: list,
    tokenizer,
    device,
    test_split: float = 0.2,
    batch_size: int = 8,
    max_length: int = 65,
) -> tuple[DataLoader, DataLoader]:
    """Tokenize *dataset* and build train/validation DataLoaders for next-token training.

    Every text is truncated/padded to ``max_length`` tokens. The target sequence
    is the input shifted left by one position; the last position, which has no
    successor, is filled with the tokenizer's pad token id.

    Args:
        dataset: List of raw text strings.
        tokenizer: A Hugging Face tokenizer (called with ``return_tensors='pt'``).
        device: Device the token tensors are moved to.
        test_split: Fraction of examples held out for validation.
        batch_size: Batch size of both loaders.
        max_length: Token length every example is truncated/padded to.

    Returns:
        ``(train_dl, val_dl)``; each batch is a tuple
        ``(input_ids, input_lens, targets)``.
    """
    tokens = tokenizer(
        dataset,
        truncation=True,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt',
    ).to(device)

    input_ids = tokens['input_ids']
    # Number of non-padding tokens in each row.
    # NOTE(review): if the model packs sequences, recent PyTorch wants these
    # lengths on the CPU — confirm how BiLSTMLanguageModel consumes them.
    input_lens = tokens['attention_mask'].sum(dim=1)

    # Next-token targets: position i predicts token i + 1.
    targets = torch.full_like(input_ids, tokenizer.pad_token_id)
    targets[:, :-1] = input_ids[:, 1:]

    # Fixed seed so the split is reproducible across runs.
    train_idx, val_idx = train_test_split(
        range(input_ids.size(0)), test_size=test_split, random_state=42
    )

    def _subset(idx):
        # Gather the rows selected by idx into a TensorDataset.
        return TensorDataset(input_ids[idx], input_lens[idx], targets[idx])

    train_dl = DataLoader(_subset(train_idx), batch_size=batch_size, shuffle=True)
    # Validation needs no shuffling; a fixed order keeps evaluation reproducible.
    val_dl = DataLoader(_subset(val_idx), batch_size=batch_size, shuffle=False)

    return train_dl, val_dl

In [None]:
def ml_model_loader(dataset: list, device, model_cache_path: str, epochs: int, model_lstm_dim: int = 100) -> BiLSTMLanguageModel:
    """Return a BiLSTM language model, loading cached weights or training new ones.

    If ``model_cache_path`` exists, its state dict is loaded into a freshly
    built model. Otherwise the model is trained on *dataset* with ``train``
    (Adam, lr=1e-3, ``save_path=model_cache_path``). The checkpoint it wrote is
    then reloaded, so fresh and cached runs return the same weights.

    Args:
        dataset: List of raw text strings (only used when training).
        device: Device the model is placed on and the checkpoint mapped to.
        model_cache_path: File path of the cached state dict.
        epochs: Number of training epochs if no cache exists.
        model_lstm_dim: Hidden size passed to BiLSTMLanguageModel as ``lstm_dim``.

    Returns:
        The BiLSTMLanguageModel on ``device``.
    """
    model = BiLSTMLanguageModel(
        pretrained_embeddings=pretrained_embeddings.float(),
        lstm_dim=model_lstm_dim,
    ).to(device)

    if os.path.exists(model_cache_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(model_cache_path, map_location=device))
        return model

    print("No cached model found. Training a new model.")
    # Make sure the checkpoint directory exists before train() tries to save there.
    cache_dir = os.path.dirname(model_cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    train_dl, val_dl = dataloader_generator(dataset, mbert_tokenizer, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    _losses, best_acc = train(model, train_dl, val_dl, optimizer, n_epochs=epochs, device=device, save_path=model_cache_path)
    print('Training complete. Best validation accuracy:', best_acc)

    # Assumes train() writes a state dict to save_path; the cache-hit branch
    # above already relies on that. Reloading keeps both paths consistent.
    if os.path.exists(model_cache_path):
        model.load_state_dict(torch.load(model_cache_path, map_location=device))

    return model


In [None]:
# Arabic dataset: BiLSTM LM over the Arabic `question` column.
arabic_model_path = "cached_data/bilstm_lm_arabic"
df_ar = pl.read_parquet(ARB_CACHE)
df_arabic = df_ar.get_column("question").to_list()

arabic_model = ml_model_loader(
    df_arabic,
    device,
    arabic_model_path,
    epochs=20,
)

In [None]:
# Korean dataset: BiLSTM LM over the Korean `question` column.
korean_model_path = "cached_data/bilstm_lm_korean"
df_ko = pl.read_parquet(KOR_CACHE)
df_korean = df_ko.get_column("question").to_list()

korean_model = ml_model_loader(
    df_korean,
    device,
    korean_model_path,
    epochs=20,
)

In [None]:
# Telugu dataset: BiLSTM LM over the Telugu `question` column.
telugu_model_path = "cached_data/bilstm_lm_telugu"
df_telu = pl.read_parquet(TELU_CACHE)
df_telugu = df_telu.get_column("question").to_list()

telugu_model = ml_model_loader(
    df_telugu,
    device,
    telugu_model_path,
    epochs=20,
)

In [None]:
# Context dataset: stack the Arabic, Korean and Telugu frames and train one LM
# on their `context` column.
df_arkote = pl.concat([df_ar, df_ko, df_telu])

context_model_path = "cached_data/bilstm_lm_context"
df_context = df_arkote.get_column("context").to_list()
# NOTE(review): dataloader_generator truncates every text to 65 tokens by
# default, so long contexts are cut — confirm this is intended.
context_model = ml_model_loader(
    df_context,
    device,
    context_model_path,
    epochs=3,
)


In [None]:
# Sanity check: score a short English sentence with the context LM
# (the notebook displays the returned value).
tst_stc = "I am Sam"
sentence_log_probability(
    context_model,
    device,
    mbert_tokenizer,
    tst_stc,
)

In [None]:
# Perplexity of each LM on its own text list.
# NOTE(review): these lists also contain the texts the models were trained on
# (dataloader_generator uses 80% of them for training), so the numbers are
# likely optimistic — consider evaluating on the held-out split instead.
perplex_korean = perplexity(korean_model, device, mbert_tokenizer, df_korean)
print(f"Perplexity of the Korean text: {perplex_korean}")

perplex_telugu = perplexity(telugu_model, device, mbert_tokenizer, df_telugu)
print(f"Perplexity of the Telugu text: {perplex_telugu}")

perplex_arabic = perplexity(arabic_model, device, mbert_tokenizer, df_arabic)
print(f"Perplexity of the Arabic text: {perplex_arabic}")

perplex_context = perplexity(context_model, device, mbert_tokenizer, df_context)
print(f"Perplexity of the Context text: {perplex_context}")