# LM for QA Tidy_XOR dataset

In [None]:
import polars as pl
from transformers import AutoModel, AutoTokenizer
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
import os

from data.const import ARB_CACHE, KOR_CACHE, TELU_CACHE
from nlm.models import BiLSTMLanguageModel
from nlm.train_utils import train
from nlm.probs import sentence_log_probability, perplexity

In [None]:
# Stack the three cached parquet files (ARB, KOR, TELU) into one dataframe,
# in that order.
df_arkote = pl.concat(
    [pl.read_parquet(cache_path) for cache_path in (ARB_CACHE, KOR_CACHE, TELU_CACHE)]
)

In [None]:
# Tokenizer and encoder are loaded from the same checkpoint id so that token
# ids index the matching rows of the embedding matrix.
MBERT_CHECKPOINT = "bert-base-multilingual-uncased"
mbert_tokenizer = AutoTokenizer.from_pretrained(MBERT_CHECKPOINT)
mbert_model = AutoModel.from_pretrained(MBERT_CHECKPOINT)

# Raw weight tensor of the encoder's input-embedding layer.
input_embedding_layer = mbert_model.get_input_embeddings()
pretrained_embeddings = input_embedding_layer.weight.data
print(pretrained_embeddings.shape)

In [None]:
# Report CUDA availability, then select the GPU when present and fall back
# to the CPU otherwise.
cuda_available = torch.cuda.is_available()
print(cuda_available)
device = torch.device("cuda" if cuda_available else "cpu")
print(f'Using device: {device}')

In [None]:
# Value handed to the model as its `lstm_dim` argument.
lstm_dim = 100

# Pass mBERT's input-embedding matrix as `pretrained_embeddings`, then place
# the constructed model on the selected device.
embedding_init = torch.FloatTensor(pretrained_embeddings)
model = BiLSTMLanguageModel(
    pretrained_embeddings=embedding_init,
    lstm_dim=lstm_dim,
)
model = model.to(device)


In [None]:
# TODO: Factor out data loading into DataLoader
# Training text: the "context" column of the combined dataframe (only this
# column is used here).
context = df_arkote["context"].to_list()

# Tokenize the whole list in one call. Every row is truncated/padded to
# exactly 64 token ids, and the resulting tensors are moved to `device`.
tokens = mbert_tokenizer(
    context,
    truncation=True,
    max_length=64,
    padding='max_length',
    return_tensors='pt',
).to(device)

input_ids = tokens['input_ids']
attention_mask = tokens['attention_mask']
# Row-wise sum of the attention mask, used as the per-row sequence length.
input_lens = attention_mask.sum(dim=1)

# Next-token targets: each row of `input_ids` shifted one position to the
# left, with the pad token id filling the vacated last column.
pad_column = input_ids.new_full(
    (input_ids.size(0), 1), mbert_tokenizer.pad_token_id
)
targets = torch.cat((input_ids[:, 1:], pad_column), dim=1)

# 80/20 train/validation split over row indices, with a fixed seed.
n_examples = input_ids.size(0)
train_idx, val_idx = train_test_split(
    range(n_examples), test_size=0.2, random_state=42
)


def _rows_as_dataset(row_idx):
    """Bundle (input ids, lengths, targets) for the rows selected by row_idx."""
    return TensorDataset(
        input_ids[row_idx], input_lens[row_idx], targets[row_idx]
    )


train_dataset = _rows_as_dataset(train_idx)
val_dataset = _rows_as_dataset(val_idx)

# Only the training loader shuffles; validation keeps a fixed order.
train_dl = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=8, shuffle=False)

In [None]:
# Warm-start from a saved state dict when one is present on disk.
checkpoint_path = "cached_models/bilstm_lm"
if os.path.exists(checkpoint_path):
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a CUDA machine still loads on a CPU-only one.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Train for 5 epochs with Adam (lr=1e-3); `train` returns the losses and the
# best score, which is reported below.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses, best_acc = train(
    model,
    train_dl,
    val_dl,
    optimizer,
    n_epochs=5,
    device=device,
)
print('Training complete. Best validation accuracy:', best_acc)

In [None]:
# Score one short English sentence with `model`.
tst_stc = "I am Sam"
# Left as a bare trailing expression so the notebook displays the returned value.
sentence_log_probability(model, device, mbert_tokenizer, tst_stc)

In [None]:
# `context` is the full list of passages taken from the combined dataframe,
# so the value printed below is computed from that whole list, not from a
# single string.
perplex = perplexity(model, device, mbert_tokenizer, context)
print(f"Perplexity of the sentence: {perplex}")