# Language model (LM) for the XOR-TyDi QA dataset

In [None]:
import polars as pl
from transformers import AutoModel, AutoTokenizer
from data.const import ARB_CACHE, KOR_CACHE, TELU_CACHE
import numpy as np
import torch
import torch.nn as nn
import math

In [None]:
# Stack the Arabic, Korean and Telugu cached parquet files (in that order) into one frame.
_cache_paths = (ARB_CACHE, KOR_CACHE, TELU_CACHE)
df_arkote = pl.concat([pl.read_parquet(path) for path in _cache_paths])

In [None]:
# Load the NLLB tokenizer and model from one checkpoint and expose the model's
# input-embedding matrix (shared storage, no autograd history) for reuse.
_NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
nllb_tokenizer = AutoTokenizer.from_pretrained(_NLLB_CHECKPOINT)
nllb_model = AutoModel.from_pretrained(_NLLB_CHECKPOINT)
_input_embedding_layer = nllb_model.get_input_embeddings()
pretrained_embeddings = _input_embedding_layer.weight.data
print(pretrained_embeddings.shape)

In [None]:
# BiLSTM Language Model for sentence probability
import torch.nn as nn

class BiLSTMLanguageModel(nn.Module):
    def __init__(self, 
                 pretrained_embeddings: torch.tensor, 
                 lstm_dim: int, 
                 dropout_prob: float = 0.1
    ):
        """
        Initializer for a single-layer BiLSTM next-token model.

        :param pretrained_embeddings: (vocab_size x embed_dim) tensor of pretrained
            token embeddings; loaded via nn.Embedding.from_pretrained, which freezes them
            by default
        :param lstm_dim: Hidden size of each LSTM direction
        :param dropout_prob: Dropout probability applied to the BiLSTM outputs before
            the output layer
        """
        # First thing is to call the superclass initializer
        super().__init__()

        # Get vocab size and embedding dimension from the pretrained embeddings
        vocab_size = pretrained_embeddings.shape[0] # Size of the vocabulary
        embed_dim = pretrained_embeddings.shape[1] # Dimensionality of the embeddings

        # The components are a (frozen) embedding layer, a 1-layer BiLSTM, and a linear output layer.
        # NOTE(review): padding_idx=vocab_size - 1 is not necessarily the tokenizer's pad id;
        # since the embeddings are frozen it only affects gradient bookkeeping, not outputs.
        self.model = nn.ModuleDict({
            'embeddings': nn.Embedding.from_pretrained(pretrained_embeddings, padding_idx=vocab_size - 1),
            # Exactly one layer: in a stacked bidirectional nn.LSTM the upper forward direction
            # reads the lower backward states, which would leak future tokens into the prediction.
            # nn.LSTM's `dropout` only acts *between* layers, so it is not passed here (with one
            # layer it had no effect and only triggered a UserWarning).
            'bilstm': nn.LSTM(embed_dim, lstm_dim, 1, batch_first=True, bidirectional=True),
            'lm_head': nn.Linear(2 * lstm_dim, vocab_size)
        })
        self.n_classes = vocab_size
        self.dropout = nn.Dropout(p=dropout_prob)
        
        # Initialize the weights of the model
        self._init_weights()

    def _init_weights(self):
        """Xavier-normal init for BiLSTM / output weights, zeros for their biases."""
        all_params = list(self.model['bilstm'].named_parameters()) + \
                     list(self.model['lm_head'].named_parameters())
        for n, p in all_params:
            if 'weight' in n:
                nn.init.xavier_normal_(p)
            elif 'bias' in n:
                nn.init.zeros_(p)

    def forward(self, inputs, input_lens):
        """
        Compute next-token logits.

        The logits at position t are built from the forward LSTM state at t (tokens 0..t)
        and the backward LSTM state at t+2 (tokens t+2..len-1), so they never see token
        t+1 -- the token they are trained to predict. Because the right context is used,
        summed log-probs are a pseudo-log-likelihood, not an autoregressive probability.

        :param inputs: (b x sl) The IDs into the vocabulary of the input samples
        :param input_lens: (b) The number of real (non-padding) tokens in each row
        :return: (b x sl x vocab_size) logits; rows are zero-state past each length
        """
        # Get embeddings (b x sl x edim)
        embeds = self.model['embeddings'](inputs)

        # Pack so the LSTM skips padding (and the backward pass starts at each real last token)
        lstm_in = nn.utils.rnn.pack_padded_sequence(
            embeds, 
            input_lens.cpu(), 
            batch_first=True, 
            enforce_sorted=False
        )

        lstm_out, _ = self.model['bilstm'](lstm_in)

        # total_length keeps the time axis equal to inputs.size(1) even when every row in the
        # batch is shorter, so logits line up with fixed-width targets. Padded steps are 0.
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            lstm_out, batch_first=True, total_length=inputs.size(1)
        )

        # nn.LSTM concatenates [forward, backward] along the last dim.
        hidden = lstm_out.size(-1) // 2
        fwd_states = lstm_out[..., :hidden]
        bwd_states = lstm_out[..., hidden:]
        # Shift the backward states left by two so position t only sees tokens >= t+2;
        # the last two positions (and anything past a row's length) get a zero state.
        bwd_shifted = torch.zeros_like(bwd_states)
        bwd_shifted[:, :-2] = bwd_states[:, 2:]

        ff_in = self.dropout(torch.cat([fwd_states, bwd_shifted], dim=-1))

        logits = self.model['lm_head'](ff_in)

        return logits  # (batch, seq_len, vocab_size)

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
_cuda_available = torch.cuda.is_available()
print(_cuda_available)
device = torch.device("cuda" if _cuda_available else "cpu")
print(f'Using device: {device}')

In [None]:
lstm_dim = 100

# Build the LM on top of the NLLB input-embedding matrix, then place it on the chosen device.
_embedding_matrix = torch.FloatTensor(pretrained_embeddings)
model = BiLSTMLanguageModel(pretrained_embeddings=_embedding_matrix, lstm_dim=lstm_dim)
model = model.to(device)


In [None]:
def accuracy(logits, targets, mask=None):
    """
    Token-level accuracy of argmax predictions.

    :param logits: (... x vocab_size) scores; the prediction is argmax over the last dim
    :param targets: (...) gold token ids, same leading shape as logits
    :param mask: optional boolean tensor shaped like targets; only True positions count
    :return: fraction of counted positions predicted correctly (0.0 if none are counted)
    """
    preds = logits.argmax(dim=-1)
    hits = preds == targets
    if mask is None:
        total = targets.numel()
    else:
        hits = hits & mask
        total = int(mask.sum().item())
    return hits.sum().item() / total if total > 0 else 0.0

In [None]:
def evaluate(model: nn.Module, valid_dl, device=None):
    """
    Next-token accuracy of `model` over `valid_dl`.

    Each batch is (inputs, input_lens, targets) with targets[:, t] being the token after
    inputs[:, t]. Only positions t < input_lens - 1 are scored: those are the positions
    whose target is a real token (the same positions sentence scoring uses), so padding
    never inflates the accuracy.

    :param model: callable as model(inputs, input_lens) -> (b x sl' x vocab) logits
    :param valid_dl: iterable of (inputs, input_lens, targets) tensor batches
    :param device: where to run; defaults to the device of the model's parameters
    :return: fraction of scored tokens predicted correctly (0.0 if nothing was scored)
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device('cpu'))
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in valid_dl:
            inputs, input_lens, targets = (t.to(device) for t in batch)
            logits = model(inputs, input_lens)
            preds = logits.argmax(dim=-1)
            # The model may return fewer time steps than the padded targets; align widths.
            seq_len = preds.size(1)
            targets = targets[:, :seq_len]
            positions = torch.arange(seq_len, device=preds.device)
            mask = positions[None, :] < (input_lens - 1)[:, None]
            correct += ((preds == targets) & mask).sum().item()
            total += mask.sum().item()

    return correct / total if total > 0 else 0.0

In [None]:
def train(
        model: nn.Module, 
        train_dl: torch.utils.data.DataLoader, 
        valid_dl: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        n_epochs: int, 
        device: torch.device,
        patience: int = 10
    ):
    '''
    Train a language model with early stopping on validation accuracy.

    Each batch is (inputs, input_lens, targets) with targets[:, t] being the token after
    inputs[:, t]. The loss covers positions t < input_lens - 1, i.e. every position whose
    target is a real token. The best state dict (by validation accuracy) is written to
    'best_model' and reloaded into `model` before returning.

    :return: (list of mean training loss per epoch, best validation accuracy)
    '''
    losses = []
    best_acc = 0.0
    saved_best = False
    pcounter = 0

    for ep in range(n_epochs):
        # evaluate() switches to eval mode, so re-enable training mode once per epoch.
        model.train()
        loss_epoch = []

        for batch in train_dl:
            optimizer.zero_grad()
            inputs, input_lens, targets = (t.to(device) for t in batch)
            logits = model(inputs, input_lens)
            # The model may return fewer time steps than the padded targets; align widths.
            seq_len = logits.size(1)
            targets = targets[:, :seq_len]
            # Position t predicts token t+1; the last real position would only predict padding.
            positions = torch.arange(seq_len, device=logits.device)
            mask = positions[None, :] < (input_lens - 1)[:, None]
            if not mask.any():
                # Nothing to score (all rows have < 2 tokens); avoid a NaN loss.
                continue
            loss = nn.functional.cross_entropy(logits[mask], targets[mask])
            loss_epoch.append(loss.item())
            loss.backward()
            optimizer.step()

        avg_train_loss = float(np.mean(loss_epoch)) if loss_epoch else float('nan')
        losses.append(avg_train_loss)

        acc = evaluate(model, valid_dl)
        print(f'Validation accuracy: {acc}, train loss: {avg_train_loss}')

        # Keep track of the best model; always save the first epoch so a checkpoint exists
        # even if accuracy never rises above 0.
        if not saved_best or acc > best_acc:
            torch.save(model.state_dict(), 'best_model')
            best_acc = acc
            saved_best = True
            pcounter = 0
        else:
            pcounter += 1
            if pcounter >= patience:
                break

    if saved_best:
        model.load_state_dict(torch.load('best_model', map_location=device))
    return losses, best_acc


In [None]:
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

# Pool the context passages and their translations into one list of texts and keep the
# first 1000 entries. NOTE(review): contexts come first, so if df_arkote has >= 1000 rows
# this slice contains only contexts -- confirm that is intended.
texts = [*df_arkote['context'].to_list(), *df_arkote['translation'].to_list()]
texts_small = texts[:1000]

# Tokenize every text to exactly 64 ids (truncating longer ones, padding shorter ones).
tokens = nllb_tokenizer(
    texts_small,
    return_tensors='pt',
    padding='max_length',
    max_length=64,
    truncation=True,
)
input_ids, attention_mask = tokens['input_ids'], tokens['attention_mask']
# Number of non-padding ids in each row.
input_lens = attention_mask.sum(dim=1)

# Next-token targets: column t holds input id t+1; the final column has no successor
# and is filled with the pad id.
targets = torch.roll(input_ids, shifts=-1, dims=1)
targets[:, -1] = nllb_tokenizer.pad_token_id

# 80/20 train/validation split over row indices.
train_idx, val_idx = train_test_split(
    range(input_ids.size(0)), test_size=0.2, random_state=42
)


def _subset(idx):
    """Bundle the (input_ids, input_lens, targets) rows selected by `idx`."""
    return TensorDataset(input_ids[idx], input_lens[idx], targets[idx])


train_dataset = _subset(train_idx)
val_dataset = _subset(val_idx)
train_dl = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=8, shuffle=False)

losses, best_acc = train(
    model,
    train_dl,
    val_dl,
    torch.optim.Adam(model.parameters(), lr=1e-3),
    n_epochs=5,
    device=device,
)
print('Training complete. Best validation accuracy:', best_acc)

In [None]:
def sentence_log_probability(model, sentence_ids, sentence_len):
    '''
    Compute log-probability of a sentence under the language model.

    The logits at position t score token t+1, so the first token is only conditioning and
    the tokens 1 .. sentence_len-1 are scored.

    sentence_ids: torch.LongTensor of shape (1, seq_len)
    sentence_len: torch.LongTensor of shape (1,)
    :return: summed natural-log probability (float) of the scored tokens in the batch
    '''
    # Run on the model's device so CPU-tokenized ids work with a CUDA model;
    # a parameterless model falls back to the ids' own device.
    device = next((p.device for p in model.parameters()), sentence_ids.device)
    sentence_ids = sentence_ids.to(device)
    sentence_len = sentence_len.to(device)
    model.eval()
    with torch.no_grad():
        logits = model(sentence_ids, sentence_len)
        log_probs = torch.log_softmax(logits, dim=-1)
        # Never read past the number of time steps the model actually returned.
        width = min(logits.size(1), sentence_ids.size(1))
        target = sentence_ids[:, 1:width]
        step_log_probs = log_probs[:, :width - 1, :]
        # Gather log-probs for the actual next tokens
        token_log_probs = step_log_probs.gather(2, target.unsqueeze(-1)).squeeze(-1)
        # Keep only predictions of real (non-padding) tokens
        positions = torch.arange(width - 1, device=device)
        mask = positions[None, :] < (sentence_len - 1)[:, None]
        total_log_prob = token_log_probs[mask].sum().item()
    return total_log_prob

In [None]:
model.eval()

example_sentence = "This is a test sentence."

# Tokenize once and move the ids to the model's device; CPU ids fed to a CUDA model
# raise a device-mismatch error. The length stays on CPU (forward() calls .cpu() on it).
sentence_ids = nllb_tokenizer(example_sentence, return_tensors='pt')['input_ids']
sentence_ids = sentence_ids.to(next(model.parameters()).device)
sentence_len = torch.LongTensor([sentence_ids.size(1)])

with torch.no_grad():
    logits = model(sentence_ids, sentence_len)
    probs = torch.log_softmax(logits, dim=-1)
    print(probs.shape)
    print(probs)

In [None]:
def perplexity(model, tokenizer, sentence_list):
    '''
    Compute corpus-level perplexity of `model` over `sentence_list`.

    Each sentence is tokenized on its own; its first token is only conditioning, so a
    sentence of n tokens contributes n - 1 predicted tokens. The result is
        exp(-(sum of token log-probs) / (number of predicted tokens)),
    computed in log space so long corpora do not underflow.

    :param model: language model accepted by sentence_log_probability
    :param tokenizer: HuggingFace-style tokenizer returning 'input_ids'
    :param sentence_list: iterable of sentence strings
    :return: perplexity (float, >= 1 for a normalized model)
    :raises ValueError: if no sentence yields at least one predicted token
    '''
    # Move ids to the model's device so CPU-tokenized inputs work with a CUDA model.
    device = next((p.device for p in model.parameters()), torch.device('cpu'))
    total_log_prob = 0.0
    total_tokens = 0
    for sentence in sentence_list:
        sentence_ids = tokenizer(sentence, return_tensors='pt')['input_ids'].to(device)
        sentence_len = torch.LongTensor([sentence_ids.size(1)])
        total_log_prob += sentence_log_probability(model, sentence_ids, sentence_len)
        total_tokens += sentence_ids.size(1) - 1
    if total_tokens == 0:
        raise ValueError('perplexity() needs at least one sentence with two or more tokens')
    return math.exp(-total_log_prob / total_tokens)

In [None]:
# Example usage of sentence_log_probability
example_sentence = "This is a test sentence."
# Tokenize once; put the ids on the model's device (CPU ids would fail against a CUDA model).
example_ids = nllb_tokenizer(example_sentence, return_tensors='pt')['input_ids']
example_ids = example_ids.to(next(model.parameters()).device)
example_len = torch.LongTensor([example_ids.size(1)])
log_prob = sentence_log_probability(model, example_ids, example_len)
print(f"Log probability of the sentence: {log_prob}")


In [None]:
# Perplexity over the held-out validation texts (the same split used for training).
# perplexity() expects a list of sentences, scored one by one, not a single joined string.
val_texts = [texts_small[i] for i in val_idx]
perplex = perplexity(model, nllb_tokenizer, val_texts)
print(f"Perplexity of the validation corpus: {perplex}")