In [16]:
import json
import sys
from importlib import reload
from pathlib import Path

import polars as pl
import torch
import torch.nn as nn
from datatools import plotting as dtplot
from datatools import tabular as dttab
from sklearn import metrics
from torch.utils.data import DataLoader, Dataset

sys.path.append("..")

import plotting
import util
from src import models_torch, text_process

# Re-import the local modules so edits to their source take effect without
# restarting the kernel (note: text_process is not reloaded here).
reload(util)
reload(plotting)
reload(models_torch)
# apply the datatools plotly template to all following figures
dtplot.set_plotly_template()

## load data


In [17]:
tag_map = util.MAP_TAGS


def detected_tags(tokens):
    """Tags produced by the deterministic tagger for the re-joined source text."""
    source_text = "".join(tokens)
    return text_process.process(source_text)[1]


# Languages can be restricted by passing a list as second argument,
# e.g. util.load_examples(tag_map, ["python", "pseudo", "rust"]).
loaded = util.load_examples(tag_map)
examples = loaded.sort("length")

# one list of deterministic tags per example, aligned with its tokens
det_tags_column = pl.col("tokens").map_elements(
    detected_tags, return_dtype=pl.List(pl.String)
)
examples = examples.with_columns(tags_det=det_tags_column)

lang_counts = dttab.value_counts(examples["lang"], verbose=True, as_dict=True)


21 unique (lang):  'python', 'matlab', 'pseudo', 'php', 'rust' ,...


In [18]:
# train-val split: 0.3 of the examples go to validation (printed output: 89 train / 39 val).
# NOTE(review): the split is shuffled (see util.data_split's printed message) and no seed
# is passed here, so the split may differ between runs — confirm whether data_split seeds.
train_df, val_df = util.data_split(examples, 0.3)
print(f"split: {len(train_df)} training, {len(val_df)} val")


splitted 89 & 39 (shuffled)
split: 89 training, 39 val


## most common tokens


In [19]:
# Frequency tables (as dicts) of every token occurrence and every tag occurrence
# across all examples; the vocab cell builds its vocabularies from these.
token_counts, tag_counts = (
    dttab.value_counts(examples[column].explode(), verbose=True, as_dict=True)
    for column in ("tokens", "tags")
)


682 unique (tokens):  ' ', '\n', ',', ')', '(' ,...
14 unique (tags):  'ws', 'va', 'sy', 'nl', 'brop' ,...


## Build the vocabularies

- add a padding symbol (`<pad>`) and an unknown symbol (`<unk>` for tokens, `uk` for tags) to both vocabularies
- convert tokens and tags to integer indices


In [20]:
# Token vocabulary: padding, unknown, then the N_TOKEN_VOCAB most frequent tokens.
# Every other token is mapped to "<unk>".
N_TOKEN_VOCAB = 10
vocab = ["<pad>", "<unk>"] + list(token_counts.keys())[:N_TOKEN_VOCAB]
token2idx = {t: i for i, t in enumerate(vocab)}

# Tag vocabulary: padding, the deterministic tagger's "uk" (unknown) tag, then every
# tag in the data. The counts are recomputed over *all* examples here instead of
# reading the global `tag_counts`: the class-weights cell overwrites that name with
# train-only counts (sort_by="value"), so re-running this cell afterwards would
# otherwise silently change the tag order and therefore every tag index.
all_tag_counts = dttab.value_counts(examples["tags"].explode(), verbose=False, as_dict=True)
tag_vocab = ["<pad>", "uk"] + list(all_tag_counts.keys())
tag2idx = {t: i for i, t in enumerate(tag_vocab)}

print("vocab (tokens):", vocab)
print("vocab (tags)  :", tag_vocab)


def seqs_to_indices(seqs, mapping: dict, default: int | None = None) -> list[list[int]]:
    """Convert each sequence of symbols into a list of integer indices.

    Args:
        seqs: iterable of sequences of symbols (one sequence per example).
        mapping: symbol -> index.
        default: index used for symbols missing from `mapping`. With None, a
            missing symbol raises KeyError (used for gold tags, which must all
            be in the vocabulary).

    Returns:
        A list of lists of ints, one inner list per input sequence.
    """
    if default is None:
        return [[mapping[s] for s in seq] for seq in seqs]
    return [[mapping.get(s, default) for s in seq] for seq in seqs]


# index of the unknown symbol in each vocabulary (position 1 by construction)
unk_token_idx = vocab.index("<unk>")
unk_tag_idx = tag_vocab.index("uk")

# training data (lists of lists of ints)
train_token_idx = seqs_to_indices(train_df["tokens"], token2idx, unk_token_idx)
train_tag_true_idx = seqs_to_indices(train_df["tags"], tag2idx)
train_tag_det_idx = seqs_to_indices(train_df["tags_det"], tag2idx, unk_tag_idx)
print(f"\ntraining examples of length: {[len(e) for e in train_token_idx]}")

# validation data
val_token_idx = seqs_to_indices(val_df["tokens"], token2idx, unk_token_idx)
val_tag_true_idx = seqs_to_indices(val_df["tags"], tag2idx)
val_tag_det_idx = seqs_to_indices(val_df["tags_det"], tag2idx, unk_tag_idx)
print(f"validation examples of length: {[len(e) for e in val_token_idx]}")


vocab (tokens): ['<pad>', '<unk>', ' ', '\n', ',', ')', '(', '=', '    ', '.', ';', '0']
vocab (tags)  : ['<pad>', 'uk', 'ws', 'va', 'sy', 'nl', 'brop', 'brcl', 'op', 'nu', 'kw', 'fn', 'st', 'cl', 'li', 'co']

training examples of length: [22, 9, 36, 26, 58, 15, 33, 73, 32, 25, 3, 39, 15, 15, 18, 5, 14, 19, 38, 44, 25, 13, 28, 59, 48, 8, 25, 16, 20, 9, 15, 11, 35, 10, 38, 28, 20, 29, 8, 8, 46, 22, 30, 3, 28, 13, 11, 52, 24, 29, 21, 26, 32, 17, 62, 60, 155, 26, 71, 12, 25, 45, 8, 9, 314, 29, 44, 42, 39, 22, 18, 17, 47, 44, 3, 30, 3, 14, 85, 8, 39, 47, 59, 63, 10, 51, 31, 98, 25]
validation examples of length: [31, 15, 26, 43, 85, 76, 25, 21, 92, 9, 11, 48, 25, 40, 17, 25, 32, 31, 53, 21, 19, 13, 37, 63, 91, 12, 32, 29, 38, 111, 17, 13, 14, 4, 33, 20, 25, 65, 95]


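### Class weights

Inverse-frequency class weights (plus a smoothing constant) for the cross-entropy loss, computed on the training split.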


In [21]:
# re-import datatools.tabular so edits to it take effect without a kernel restart
reload(dttab)


# Tag frequencies on the training split only (the weights must not see validation data).
# NOTE(review): this rebinds the global `tag_counts` (previously: counts over all
# examples, default ordering) to train-only counts ordered with sort_by="value".
# Any cell that reads `tag_counts` sees different values depending on whether it runs
# before or after this one; `objective` in the parameter search also reads this name.
all_tags = train_df["tags"].explode()

tag_counts = dttab.value_counts(all_tags, sort_by="value", as_dict=True)


def class_weights(tag_counts: dict, tag_vocab: list[str], smoothing=0.1):
    """Per-class loss weights: inverse frequency plus smoothing, normalised to sum to 1.

    Args:
        tag_counts: mapping tag -> number of occurrences. Tags that are missing
            (or have a count <= 0) contribute only the smoothing term.
        tag_vocab: tags in index order; the returned tensor is aligned with it.
        smoothing: constant added to every inverse frequency; larger values
            push the weights towards uniform.

    Returns:
        1-D float tensor of length len(tag_vocab) whose entries sum to 1.

    Raises:
        ValueError: if every raw weight is zero (e.g. smoothing=0 and no tag in
            tag_vocab has a positive count), since normalising would yield NaN.
    """
    raw = []
    for tag in tag_vocab:
        count = tag_counts.get(tag, 0)
        # unseen / zero-count tags would otherwise divide by zero
        inverse_freq = 1.0 / count if count > 0 else 0.0
        raw.append(inverse_freq + smoothing)

    weights = torch.tensor(raw)
    total = weights.sum()
    if total <= 0:
        raise ValueError("class weights sum to zero; increase `smoothing`")
    return weights / total


# One weight per entry of tag_vocab, in the same order, so they line up with the
# model's output classes; a large smoothing keeps the weights close to uniform.
tag_weights = class_weights(tag_counts, tag_vocab, smoothing=30.0)

### Prepare data for model


In [22]:
# Pad every list-of-lists into a rectangular (n_examples, max_len) tensor.
print("train:")
train_token_tensors, train_tag_true_tensors, train_tag_det_tensors = (
    util.seqs2padded_tensor(seqs)
    for seqs in (train_token_idx, train_tag_true_idx, train_tag_det_idx)
)
print("val:")
val_token_tensors, val_tag_true_tensors, val_tag_det_tensors = (
    util.seqs2padded_tensor(seqs)
    for seqs in (val_token_idx, val_tag_true_idx, val_tag_det_idx)
)

# within a split, tokens, gold tags and deterministic tags must share one shape
for split_tensors in (
    (train_token_tensors, train_tag_true_tensors, train_tag_det_tensors),
    (val_token_tensors, val_tag_true_tensors, val_tag_det_tensors),
):
    assert len({t.shape for t in split_tensors}) == 1


train:
padded tensor: (89, 314)
padded tensor: (89, 314)
padded tensor: (89, 314)
val:
padded tensor: (39, 111)
padded tensor: (39, 111)
padded tensor: (39, 111)


In [23]:
class SequenceDataset(Dataset):
    """Map-style dataset over three parallel, indexable collections.

    Item ``i`` is ``(tokens[i], labels_det[i], labels_true[i])``: the token ids,
    the deterministic tagger's tag ids and the gold tag ids of one example.
    """

    def __init__(self, tokens, labels_det, labels_true):
        lengths = (len(tokens), len(labels_det), len(labels_true))
        print(*lengths)
        # A mismatch would otherwise surface later as an IndexError (or silently
        # drop examples) when the DataLoader indexes past the shorter input.
        if len(set(lengths)) != 1:
            raise ValueError(f"inputs must have equal length, got {lengths}")
        self.tokens = tokens
        self.labels_det = labels_det
        self.labels_true = labels_true

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        return self.tokens[idx], self.labels_det[idx], self.labels_true[idx]


# Wrap the padded tensors in datasets and batch them; only training is shuffled.
train_dataset = SequenceDataset(
    train_token_tensors, train_tag_det_tensors, train_tag_true_tensors
)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

val_dataset = SequenceDataset(
    val_token_tensors, val_tag_det_tensors, val_tag_true_tensors
)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


89 89 89
39 39 39


## model


In [24]:
# Parameters
epochs = 40
lr = 0.005
seed = 0  # fixes weight init and DataLoader shuffling so runs are reproducible

print("vocab lengths", len(vocab), len(tag_vocab))
reload(models_torch)

torch.manual_seed(seed)

# Model with default params
model = models_torch.LSTMTagger(len(vocab), len(tag_vocab))

# NOTE(review): no ignore_index is set, so padded positions (presumably padded with
# index 0 = "<pad>") are part of the loss unless run_epoch masks them — confirm.
loss_function = nn.CrossEntropyLoss(weight=tag_weights)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training loop: one loss value per epoch for each split
losses_train = []
losses_val = []

for epoch in range(epochs):
    # TRAINING: dropout etc. active
    model.train()
    losses_train.append(
        models_torch.run_epoch(model, train_loader, loss_function, optimizer)
    )

    # VALIDATION: switch off dropout and gradient tracking so the loss reflects
    # the deterministic model (harmless if run_epoch also sets the mode itself)
    model.eval()
    with torch.no_grad():
        losses_val.append(models_torch.run_epoch(model, val_loader, loss_function))

plotting.scatter(y=[losses_train, losses_val]).show()

print(
    "final loss:\n", f"  train: {losses_train[-1]:.4f}", f"  val : {losses_val[-1]:.4f}"
)


vocab lengths 12 16


final loss:
   train: 0.0136   val : 0.0731


In [26]:
# Save the trained weights plus the vocabularies / tag map needed to rebuild the
# index mappings when the model is loaded again.
model_dir = Path("../models")
model_dir.mkdir(parents=True, exist_ok=True)  # torch.save/open fail if the dir is missing

torch.save(model.state_dict(), model_dir / "lstmTagger_state.pth")

metadata = {"vocab": vocab, "tag_vocab": tag_vocab, "tag_map": tag_map}
with open(model_dir / "lstmTagger_vocabs.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f)

## evaluate


In [27]:
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    # most likely tag index per position; shape (batch_size, seq_len)
    predictions = tag_scores.argmax(dim=-1)

# Flatten over all validation examples, dropping the padded tail of each row.
true_tags = [tag_vocab[i] for gold in val_tag_true_idx for i in gold]
pred_tags = [
    tag_vocab[i]
    for row, gold in zip(predictions.tolist(), val_tag_true_idx)
    for i in row[: len(gold)]
]

print(len(true_tags), len(pred_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)
f1_macro = metrics.f1_score(true_tags, pred_tags, average="macro")
print("F1_macro", f1_macro)

confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=tag_vocab)

dtplot.heatmap(
    confmat,
    tag_vocab,
    log_scale=True,
    pseudo_count=10,
    size=400,
).show()

1457 1457
accuracy 0.938229238160604
F1_macro 0.7256484603769511


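### Evaluate only on tokens the deterministic tagger could not tag

Restrict the metrics to positions where the deterministic tagger returned the unknown tag `uk`.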


In [28]:
# (gold, predicted) tag pairs at every position the deterministic tagger left as "uk"
gold_pred_pairs = [
    (tag_vocab[gold_i], tag_vocab[pred_i])
    for row, gold, det in zip(
        predictions.tolist(), val_tag_true_idx, val_tag_det_idx, strict=True
    )
    for pred_i, gold_i, det_i in zip(row, gold, det)
    if tag_vocab[det_i] == "uk"
]
true_tags = [gold for gold, _ in gold_pred_pairs]
pred_tags = [pred for _, pred in gold_pred_pairs]
print(len(pred_tags), len(true_tags))

# only the labels that actually occur in this subset
labels_left = sorted(set(pred_tags + true_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)

confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=labels_left)

dtplot.heatmap(
    confmat,
    labels_left,
    log_scale=True,
    pseudo_count=10,
    size=400,
).show()

698 698
accuracy 0.8739255014326648


## save output


In [13]:
# Predict tags for every validation example and write them to JSON, keyed by name.
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    predictions = torch.argmax(tag_scores, dim=-1)  # Shape: (batch_size, seq_len)

print("predictions", predictions.size())

outputs = {}
for ex, pred in zip(val_df.iter_rows(named=True), predictions.tolist(), strict=True):
    n_tokens = len(ex["tokens"])
    # Cut each row at the example's real length. Stopping at the first predicted
    # "<pad>" (index 0) is unreliable: nothing forces the model to emit "<pad>"
    # exactly at the padding positions, so that could stop early inside the
    # sequence or run on into the padding.
    example_pred_tags = [tag_vocab[i] for i in pred[:n_tokens]]
    print("".join(ex["tokens"]).replace("\n", "\\\\"))
    print(ex["tags"])
    print(example_pred_tags)
    print()
    outputs[ex["name"]] = {"tokens": ex["tokens"], "tags": example_pred_tags}

output_path = Path("../output/pred_output.json")
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(outputs, f)

predictions torch.Size([39, 47])
int main() {\\    std::cout << "Hello";\\    return 0;\\}
['cl', 'ws', 'fn', 'brop', 'brcl', 'ws', 'brop', 'nl', 'ws', 'va', 'sy', 'va', 'ws', 'op', 'ws', 'st', 'sy', 'nl', 'ws', 'kw', 'ws', 'nu', 'sy', 'nl', 'brcl']
['kw', 'ws', 'fn', 'brop', 'brcl', 'ws', 'brop', 'nl', 'ws', 'va', 'sy', 'va', 'ws', 'op', 'ws', 'st', 'sy', 'nl', 'ws', 'kw', 'ws', 'nu', 'sy', 'nl', 'brcl']

examples = util.load_examples()
['va', 'ws', 'op', 'ws', 'va', 'sy', 'fn', 'brop', 'brcl']
['va', 'ws', 'op', 'ws', 'va', 'sy', 'fn', 'brop', 'brcl']

prod = 1\\for i = 1,2,3\\    prod = prod * i\\return prod
['va', 'ws', 'op', 'ws', 'nu', 'nl', 'kw', 'ws', 'va', 'ws', 'op', 'ws', 'nu', 'sy', 'nu', 'sy', 'nu', 'nl', 'ws', 'va', 'ws', 'op', 'ws', 'va', 'ws', 'op', 'ws', 'va', 'nl', 'kw', 'ws', 'va']
['va', 'ws', 'op', 'ws', 'nu', 'nl', 'kw', 'ws', 'va', 'ws', 'op', 'ws', 'nu', 'sy', 'nu', 'sy', 'nu', 'nl', 'ws', 'va', 'ws', 'op', 'ws', 'va', 'ws', 'kw', 'ws', 'va', 'nl', 'kw', 'ws', '

## Parameter search


In [13]:
# hyper-parameter search helper used below
from coolsearch import search

# re-import so edits to the module take effect without restarting the kernel
reload(search)


def objective(
    embedding_dim,
    hidden_dim,
    n_lstm_layers,
    dropout_lstm,
    epochs,
    class_weight_smoothing,
    bidi,
):
    """Train one LSTMTagger configuration and score it on the validation split.

    Reads notebook globals: vocab, tag_vocab, tag_counts (train-only counts at this
    point), train_loader, val_loader, val_token_tensors, val_tag_det_tensors and
    val_tag_true_idx.

    Returns:
        dict with "_acc" and "_f1_macro" (token-level, on validation, padding
        excluded) and "_train_loss" / "_val_loss" (one eval pass after training).
    """
    net = models_torch.LSTMTagger(
        len(vocab),
        len(tag_vocab),
        embedding_dim,
        hidden_dim,
        n_lstm_layers,
        dropout_lstm,
        bidi,
    )
    criterion = nn.CrossEntropyLoss(
        weight=class_weights(tag_counts, tag_vocab, class_weight_smoothing)
    )
    opt = torch.optim.Adam(net.parameters(), lr=0.005)

    # training only; per-epoch losses are not needed for the search
    for _ in range(epochs):
        models_torch.run_epoch(net, train_loader, criterion, opt)

    # final evaluation with dropout disabled
    net.eval()
    with torch.no_grad():
        train_loss = models_torch.run_epoch(net, train_loader, criterion)
        val_loss = models_torch.run_epoch(net, val_loader, criterion)
        best_idx = net(val_token_tensors, val_tag_det_tensors).argmax(dim=-1)

    gold, guessed = [], []
    for row, seq in zip(best_idx.tolist(), val_tag_true_idx):
        gold.extend(tag_vocab[i] for i in seq)
        guessed.extend(tag_vocab[i] for i in row[: len(seq)])

    return {
        "_acc": metrics.accuracy_score(gold, guessed),
        "_f1_macro": metrics.f1_score(gold, guessed, average="macro"),
        "_train_loss": train_loss,
        "_val_loss": val_loss,
    }


# Search space for `objective`. Presumably list values are candidates and scalars are
# held fixed (3 * 2 * 2 = 12 points, matching the "12 new parameter points" log).
params = dict(
    embedding_dim=[6, 8, 10],
    hidden_dim=[64],
    n_lstm_layers=[2],
    dropout_lstm=[0.3, 0.5],
    epochs=30,
    class_weight_smoothing=[10.0, 30.0],
    bidi=True,
)

# NOTE(review): n_jobs=1 presumably runs evaluations sequentially and samples_file
# presumably persists evaluated points — confirm against coolsearch.CoolSearch.
cs = search.CoolSearch(objective, params, n_jobs=1, samples_file="../search/_.csv")


In [14]:
# Run the grid search over `params`.
# NOTE(review): the meaning of the argument 3 is not visible here — confirm
# against coolsearch.CoolSearch.grid_search.
cs.grid_search(3)
# evaluated configurations, highest validation macro-F1 first
display(cs.samples.sort(-pl.col("_f1_macro")))


Searching 12 new parameter points


4it [00:21,  5.44s/it]


KeyboardInterrupt: 

In [None]:
# Validation loss vs. validation accuracy for every evaluated configuration.
# Sample columns are named after the keys returned by `objective` ("_val_loss",
# "_acc", ...); "val_loss" / "val_acc" do not exist and raised a column-not-found error.
plotting.scatter(cs.samples["_val_loss"], [cs.samples["_acc"]]).update_traces(
    mode="markers"
).update_layout(width=200, height=100)

In [None]:
# Marginal max/mean of validation accuracy as a function of one searched parameter.
# The metric name must be a key returned by `objective` ("_acc", not "val_acc").
marg = cs.marginals("_acc")

pars = list(marg.keys())

print(pars)
k = pars[1]  # second parameter listed by marginals; change the index to inspect another
print(k)
plotting.scatter(x=marg[k][k], y=[marg[k]["max"], marg[k]["mean"]])
