In [1]:
import sys
from importlib import reload

import numpy as np
import polars as pl
import torch
import torch.nn as nn
from sklearn import metrics
from torch.utils.data import DataLoader, Dataset

sys.path.append("..")
from datatools import tabular as dttab, plotting as dtplot


import plotting
import util

reload(util)
reload(plotting)

dtplot.set_plotly_template()

## load data


In [2]:
examples = util.load_examples().sort("length")

lang_counts = dttab.value_counts(examples["lang"], verbose=True, as_dict=True)
print(lang_counts)
# display(examples)

21 unique (lang):  'python', 'matlab', 'pseudo', 'php', 'rust' ,...
{'python': 37, 'matlab': 18, 'pseudo': 13, 'php': 12, 'rust': 10, 'csharp': 5, 'ts': 4, 'cpp': 4, 'lua': 3, 'dart': 3, 'c': 3, 'sql': 2, 'r': 2, 'kotlin': 2, 'js': 2, 'go': 2, 'ruby': 1, 'natural': 1, 'json': 1, 'java': 1, 'bash': 1}


In [3]:
# train-val split
train_df, val_df = util.data_split(examples, 0.3)
print(f"split: {len(train_df)} training, {len(val_df)} val")


splitted 88 & 39 (shuffled)
split: 88 training, 39 val


## most common tokens


In [4]:
token_counts = dttab.value_counts(
    examples["tokens"].explode(), verbose=True, as_dict=True
)
tag_counts = dttab.value_counts(examples["tags"].explode(), verbose=True, as_dict=True)


671 unique (tokens):  ' ', '\n', ',', ')', '(' ,...
33 unique (tags):  'ws', 'va', 'pu', 'nl', 'brop' ,...


## make a vocab!

- add padding to both tokens and tags
- also, convert tokens and tags to integers


In [5]:
# vocab for tokens
vocab = ["<pad>", "<unk>"] + list(token_counts.keys())[:10]
token2idx = {t: i for i, t in enumerate(vocab)}

# tags
tag_vocab = ["<pad>"] + list(tag_counts.keys())
tag2idx = {t: i for i, t in enumerate(tag_vocab)}

print("vocab (tokens):", vocab)
print("vocab (tags)  :", tag_vocab)

# Convert tokens and labels to indices
# these are lists of lists!
train_token_idx = [[token2idx.get(t, 1) for t in seq] for seq in train_df["tokens"]]
train_tag_idx = [[tag2idx[t] for t in seq] for seq in train_df["tags"]]
# print("\nlists of lists:")
# print(train_token_idx)
# print(train_tag_idx)
print(f"\ntraining examples of length: {[len(e) for e in train_token_idx]}")

# validation data
val_token_idx = [[token2idx.get(t, 1) for t in seq] for seq in val_df["tokens"]]
val_tag_idx = [[tag2idx[t] for t in seq] for seq in val_df["tags"]]
print(f"validation examples of length: {[len(e) for e in val_token_idx]}")


vocab (tokens): ['<pad>', '<unk>', ' ', '\n', ',', ')', '(', '=', '.', ';', '    ', '0']
vocab (tags)  : ['<pad>', 'ws', 'va', 'pu', 'nl', 'brop', 'brcl', 'nu', 'sy', 'opas', 'id', 'mo', 'fnfr', 'st', 'opbi', 'kwfl', 'kwim', 'fnas', 'pa', 'cl', 'fnme', 'kwty', 'at', 'kwop', 'opun', 'opcm', 'kwva', 'bo', 'kwio', 'kwmo', 'kwfn', 'kwde', 'cofl', 'li']

training examples of length: [15, 12, 12, 44, 91, 47, 40, 25, 33, 44, 73, 38, 35, 25, 13, 36, 14, 37, 29, 92, 63, 4, 15, 60, 52, 26, 13, 17, 76, 32, 33, 31, 30, 10, 8, 39, 46, 111, 44, 10, 48, 42, 26, 14, 62, 26, 15, 9, 39, 38, 28, 3, 45, 28, 47, 18, 63, 29, 5, 29, 95, 21, 13, 28, 3, 22, 25, 9, 14, 25, 85, 18, 8, 65, 8, 22, 20, 31, 32, 32, 51, 17, 22, 8, 8, 13, 32, 53]
validation examples of length: [11, 17, 3, 11, 11, 24, 20, 39, 25, 31, 3, 25, 25, 20, 21, 314, 19, 48, 58, 30, 155, 17, 25, 9, 59, 19, 15, 26, 38, 85, 15, 71, 9, 16, 98, 21, 43, 29, 25]


### Prepare data for model


In [6]:
train_token_tensors = util.seqs2padded_tensor(train_token_idx)
train_tag_tensors = util.seqs2padded_tensor(train_tag_idx)
val_token_tensors = util.seqs2padded_tensor(val_token_idx)
val_tag_tensors = util.seqs2padded_tensor(val_tag_idx)

print(f"token tensor (train): {train_token_tensors.shape}")
print(f"tag tensor   (train): {train_tag_tensors.shape}")
print(f"token tensor (val): {val_token_tensors.shape}")
print(f"tag tensor   (val): {val_tag_tensors.shape}")

padded tensor: (88, 111)
padded tensor: (88, 111)
padded tensor: (39, 314)
padded tensor: (39, 314)
token tensor (train): torch.Size([88, 111])
tag tensor   (train): torch.Size([88, 111])
token tensor (val): torch.Size([39, 314])
tag tensor   (val): torch.Size([39, 314])


In [7]:
class SequenceDataset(Dataset):
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return self.tokens[idx], self.labels[idx]


# Create dataset and dataloader
train_dataset = SequenceDataset(train_token_tensors, train_tag_tensors)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)


## model


In [8]:
class LSTMTagger(nn.Module):
    def __init__(self, vocab_size, tagset_size, embedding_dim, hidden_dim):
        super(LSTMTagger, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentences):
        embeds = self.embedding(sentences)
        # Shape: (batch_size, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embeds)
        # Shape: (batch_size, seq_len, hidden_dim)
        tag_scores = self.hidden2tag(lstm_out)
        # Shape: (batch_size, seq_len, tagset_size)
        return tag_scores


In [9]:
# Parameters
embedding_dim = 32
hidden_dim = 32
vocab_size = len(vocab)
tagset_size = len(tag_vocab)

model = LSTMTagger(vocab_size, tagset_size, embedding_dim, hidden_dim)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Training loop
epochs = 40
losses = []
for epoch in range(epochs):
    for examples, labels in train_loader:
        model.zero_grad()

        # Forward pass
        tag_scores = model(examples)

        # Reshape for loss calculation
        loss = loss_function(tag_scores.view(-1, tagset_size), labels.view(-1))

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
plotting.scatter(y=[losses])

## evaluate


In [10]:
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors)
    predictions = torch.argmax(tag_scores, dim=-1)  # Shape: (batch_size, seq_len)

pred_tags = []
true_tags = []

for pred, true_t in zip(predictions, val_tag_idx):
    # print(pred.shape)
    # print((len(true_tags)))
    true_tags.extend([tag_vocab[t] for t in true_t])
    pred_tags.extend([tag_vocab[t] for t in pred[: len(true_t)]])

print(len(true_tags), len(pred_tags))


def eval(true_tags, pred_tags):
    acc = metrics.accuracy_score(true_tags, pred_tags)
    print("accuracy", acc)

    confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=tag_vocab)

    dtplot.heatmap(
        confmat,
        tag_vocab,
        log_scale=True,
        pseudo_count=10,
        size=500,
    ).show()


eval(true_tags, pred_tags)

1530 1530
accuracy 0.7215686274509804


In [11]:
from src import text_process

pred_tags_det = (
    val_df["tokens"]
    .map_elements(lambda tks: text_process.process("".join(tks))[1], pl.List(pl.String))
    .explode()
    .to_list()
)

# put all unknown into some class
# pred_tags_det = [t if t != "uk" else "va" for t in pred_tags_det]

print(len(true_tags), len(pred_tags_det))
print(true_tags)
print(pred_tags_det)


eval(true_tags, pred_tags_det)


1530 1530
['kwva', 'ws', 'va', 'ws', 'opas', 'ws', 'va', 'sy', 'fnme', 'brop', 'brcl', 'kwty', 'ws', 'va', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'kwty', 'ws', 'va', 'ws', 'opas', 'ws', 'nu', 'pu', 'va', 'opas', 'nu', 'sy', 'sy', 'kwde', 'ws', 'fnfr', 'brop', 'brcl', 'pu', 'ws', 'sy', 'sy', 'va', 'ws', 'opas', 'ws', 'va', 'brop', 'nu', 'brcl', 'brop', 'st', 'brcl', 'kwfl', 'ws', 'brop', 'va', 'ws', 'opcm', 'ws', 'va', 'brcl', 'ws', 'brop', 'nl', 'id', 'cl', 'sy', 'at', 'sy', 'fnme', 'brop', 'st', 'brcl', 'pu', 'nl', 'brcl', 'kwim', 'ws', 'kwim', 'sy', 'brop', 'mo', 'sy', 'fnas', 'pu', 'ws', 'mo', 'sy', 'fnas', 'pu', 'ws', 'mo', 'sy', 'fnas', 'brcl', 'pu', 'va', 'ws', 'opas', 'ws', 'fnfr', 'brop', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'nl', 'brcl', 'va', 'ws', 'opas', 'ws', 'brop', 'st', 'pu', 'ws', 'st', 'brcl', 'ws', 'opbi', 'ws', '

## simple fill


In [12]:
pred_filled = []
for p, pd in zip(pred_tags, pred_tags_det, strict=True):
    pred_filled.append(pd if pd != "uk" else p)

print(pred_filled)
eval(true_tags, pred_filled)

['va', 'ws', 'mo', 'ws', 'opas', 'ws', 'nu', 'sy', 'fnas', 'brop', 'brcl', 'va', 'ws', 'mo', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'va', 'ws', 'va', 'ws', 'opas', 'ws', 'nu', 'pu', 'va', 'opas', 'nu', 'va', 'sy', 'va', 'ws', 'mo', 'brop', 'brcl', 'pu', 'ws', 'va', 'brcl', 'va', 'ws', 'opas', 'ws', 'nu', 'brop', 'nu', 'brcl', 'brop', 'st', 'brcl', 'va', 'ws', 'brop', 'va', 'ws', 'va', 'ws', 'opcm', 'brcl', 'ws', 'brop', 'nl', 'id', 'kwfl', 'sy', 'fnme', 'sy', 'fnme', 'brop', 'st', 'brcl', 'pu', 'nl', 'brcl', 'va', 'ws', 'mo', 'sy', 'brop', 'sy', 'sy', 'brcl', 'pu', 'ws', 'st', 'sy', 'va', 'pu', 'ws', 'st', 'sy', 'va', 'brcl', 'pu', 'va', 'ws', 'opas', 'ws', 'kwim', 'brop', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'pu', 'nl', 'id', 'st', 'ws', 'opas', 'ws', 'st', 'nl', 'brcl', 'va', 'ws', 'opas', 'ws', 'brop', 'st', 'pu', 'ws', 'st', 'brcl', 'ws', 'brop', 'ws', 'va', 'brop', 'va', 'sy',