In [1]:
import json
import sys
from importlib import reload

import polars as pl
import torch
from datatools import plotting as dtplot
from datatools import tabular as dttab
from sklearn import metrics
from torch.utils.data import DataLoader, Dataset

sys.path.append("..")

import plotting
from src import torch_util, text_process, util

reload(plotting)
reload(torch_util)
dtplot.set_plotly_template()

dev = torch_util.get_dev()
print("device:", dev)

device: cuda


## load data


In [2]:
reload(util)
split_idx = "0119"
tag_map = None


data = util.load_examples_json(
    tag_map=tag_map,
    # ["python", "pseudo", "rust"],
    split_idx_id=split_idx,
)


for k, df in data.items():
    print(k)
    print(end="\t")
    dttab.value_counts(df["lang"], verbose=True, as_dict=True)


Loaded 212 examples
    val: 53
    test: 39
    train: 120
val
	16 unique (lang):  'python', 'php', 'matlab', 'pseudo', 'js' ,...
test
	14 unique (lang):  'python', 'php', 'dart', 'pseudo', 'matlab' ,...
train
	16 unique (lang):  'python', 'php', 'matlab', 'dart', 'rust' ,...


## most common tokens


In [None]:
vocab, token2idx = util.make_vocab(data["train"])
leftover = (
    data["train"]
    .select(pl.col("tokens").explode())
    .filter(pl.col("tokens").is_in(vocab).not_())
    .group_by("tokens")
    .agg(pl.len().alias("count"))
    .sort("count", "tokens", descending=True)
)
leftover = leftover["tokens"].to_list()

print(f"VOCAB ({len(vocab)}):", vocab)
print(f"LEFT  ({len(leftover)}):", leftover)


VOCAB (130): ['<pad>', '<unk>', ' ', '\n', ')', '(', '=', '.', ',', ';', '    ', ':', ']', '[', '{', '  ', '\n\n', '}', 'return', 'if', '        ', '*', 'import', 'for', 'in', '-', 'int', '==', '::', '+', '>', '->', '<', 'while', 'let', 'function', 'from', 'as', '...', 'var', 'echo', 'const', '\\', '+=', '++', '**', '            ', '|', 'true', 'end', 'else', 'def', '<-', '/', '%', 'use', 'public', 'pub', 'class', 'True', 'void', 'str', 'null', 'new', 'mut', 'mod', 'float', 'extends', '^', '>>>', '>=', '<=', '-=', '&', '!=', 'xy', 'tight', 'synchronized', 'false', 'bool', 'axis', 'False', '.^', '      ', '   ', '~/', 'with', 'where', 'typeof', 'tuple', 'throws', 'then', 'r', 'private', 'package', 'on', 'not', 'next', 'local', 'is', 'instanceof', 'hold', 'func', 'foreach', 'fn', 'double', 'done', 'do', 'crate', 'continue', 'colorbar', 'char', 'break', 'async', '_', 'SELECT', 'FROM', '?', '=>', '===', '/=', '..', '--', "'", '!', '                ', '             ', '       ', '     ', '\

In [4]:
token_counts = dttab.value_counts(
    data["train"]["tokens"].explode(), verbose=True, as_dict=True
)

# tag counts for all data, closed tag set
tag_counts = dttab.value_counts(
    pl.concat([data["train"], data["val"], data["test"]])["tags"].explode(),
    verbose=True,
    as_dict=True,
)


643 unique (tokens):  ' ', '\n', ')', '(', '=' ,...
35 unique (tags):  'ws', 'va', 'brop', 'brcl', 'nl' ,...


In [5]:
# tags
tag_vocab = ["<pad>", "uk"] + list(tag_counts.keys())
tag2idx = {t: i for i, t in enumerate(tag_vocab)}

print(f"vocab ({len(vocab)} tokens):", vocab)
print(f"vocab ({len(tag_vocab)} tags)  :", tag_vocab)


vocab (130 tokens): ['<pad>', '<unk>', ' ', '\n', ')', '(', '=', '.', ',', ';', '    ', ':', ']', '[', '{', '  ', '\n\n', '}', 'return', 'if', '        ', '*', 'import', 'for', 'in', '-', 'int', '==', '::', '+', '>', '->', '<', 'while', 'let', 'function', 'from', 'as', '...', 'var', 'echo', 'const', '\\', '+=', '++', '**', '            ', '|', 'true', 'end', 'else', 'def', '<-', '/', '%', 'use', 'public', 'pub', 'class', 'True', 'void', 'str', 'null', 'new', 'mut', 'mod', 'float', 'extends', '^', '>>>', '>=', '<=', '-=', '&', '!=', 'xy', 'tight', 'synchronized', 'false', 'bool', 'axis', 'False', '.^', '      ', '   ', '~/', 'with', 'where', 'typeof', 'tuple', 'throws', 'then', 'r', 'private', 'package', 'on', 'not', 'next', 'local', 'is', 'instanceof', 'hold', 'func', 'foreach', 'fn', 'double', 'done', 'do', 'crate', 'continue', 'colorbar', 'char', 'break', 'async', '_', 'SELECT', 'FROM', '?', '=>', '===', '/=', '..', '--', "'", '!', '                ', '             ', '       ', '   

### class weights?


In [6]:
reload(dttab)


all_tags = data["train"]["tags"].explode()

tag_counts = dttab.value_counts(all_tags, sort_by="value", as_dict=True)


### Prepare data for model


In [7]:
_, token2idx = util.make_vocab(df)
tag_vocab = ["<pad>", "uk"] + df["tags"].explode().value_counts(sort=True)[
    "tags"
].to_list()
tag2idx = {t: i for i, t in enumerate(tag_vocab)}


# Create dataloaders

train_dl = torch_util.data2torch(data["train"], 4, token2idx, tag2idx, dev)
val_dl = torch_util.data2torch(data["val"], 8, token2idx, tag2idx, dev)


padded tensor: (120, 155) cuda:0
padded tensor: (120, 155) cuda:0
padded tensor: (120, 155) cuda:0
padded tensor: (53, 85) cuda:0
padded tensor: (53, 85) cuda:0
padded tensor: (53, 85) cuda:0


## model


In [39]:
from timeit import default_timer


reload(torch_util)

epochs = 2000
train_time = 5
modelname = f"lstm_{split_idx}"

if tag_map is not None:
    modelname += "_mapped"

# LOSS PARAMS
tag_weights = torch_util.class_weights(tag_counts, tag_vocab, 10.0)
label_smoothing = 0.1

constructor_params = {
    "token_vocab_size": len(vocab),
    "label_vocab_size": len(tag_vocab),
    "embedding_dim": 16,
    "hidden_dim": 64,
    "n_lstm_layers": 2,
    "dropout_lstm": 0.7,
    "bidi": True,
}

# Model with default params
model = torch_util.LSTMTagger(**constructor_params).to(dev)
loss_function = torch.nn.CrossEntropyLoss(
    weight=tag_weights,
    label_smoothing=label_smoothing,
).to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
lr_s = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

# Training loop
losses_train, losses_val, val_accs = torch_util.train_loop(
    model,
    train_dl,
    val_dl,
    optimizer=optimizer,
    loss_function=loss_function,
    lr_s=lr_s,
    epochs=100,
    name=modelname,
    save_dir="../tmp/",
    print_interval=5,
    time_limit=8,
)


plotting.scatter(y=[losses_train, losses_val, val_accs]).show()

print(
    "final loss:\n", f"  train: {losses_train[-1]:.4f}", f"  val : {losses_val[-1]:.4f}"
)


   0/100, loss: 0.001645,                     acc: 90.23%
   5/100, loss: 0.001229, min-loss: 0.001229, acc: 96.51%
  10/100, loss: 0.001201, min-loss: 0.001201, acc: 96.98%
  15/100, loss: 0.001186, min-loss: 0.001179, acc: 97.25%
  20/100, loss: 0.001169, min-loss: 0.001169, acc: 97.56%
  25/100, loss: 0.001168, min-loss: 0.001165, acc: 97.56%
  30/100, loss: 0.001167,                     acc: 97.36%
  35/100, loss: 0.001164, min-loss: 0.001161, acc: 97.60%
  40/100, loss: 0.001157, min-loss: 0.001157, acc: 97.71%
  45/100, loss: 0.001157, min-loss: 0.001157, acc: 97.76%
  50/100, loss: 0.001156, min-loss: 0.001156, acc: 97.62%
  55/100, loss: 0.001159, min-loss: 0.001156, acc: 97.76%
  60/100, loss: 0.001159, min-loss: 0.001152, acc: 97.85%


final loss:
   train: 0.0011   val : 0.0012


In [9]:
# save model


# save metadata
metadata = {
    "vocab": vocab,
    "tag_vocab": tag_vocab,
    "tag_map": tag_map,
    "tag_weights": tag_weights,
    "split_idx": split_idx,
    "constructor_params": constructor_params,
}


# torch.save(model.state_dict(), f"../models/{modelname}_state.pth")
# with open(f"../models/{modelname}_meta.json", "w") as f:
#     json.dump(metadata, f)

In [30]:
torch_util.val_acc(model, val_dl.dataset)

0.978024423122406

## evaluate


In [None]:
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    predictions = torch.argmax(tag_scores, dim=-1)  # Shape: (batch_size, seq_len)

pred_tags = []
true_tags = []

for pred, true_t in zip(predictions, val_tag_true_idx):
    true_tags.extend([tag_vocab[t] for t in true_t])
    pred_tags.extend([tag_vocab[t] for t in pred[: len(true_t)]])

print(len(true_tags), len(pred_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)
f1_macro = metrics.f1_score(true_tags, pred_tags, average="macro")
print("F1_macro", f1_macro)

confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=tag_vocab)

dtplot.heatmap(
    confmat,
    tag_vocab,
    log_scale=True,
    pseudo_count=10,
    size=400,
).show()

NameError: name 'val_token_tensors' is not defined

### eval only on non-det


In [13]:
pred_tags = []
true_tags = []

for pred, true_t, det in zip(
    predictions, val_tag_true_idx, val_tag_det_idx, strict=True
):
    for p, t, d in zip(pred, true_t, det):
        if tag_vocab[d] == "uk":
            tp = tag_vocab[p.item()]
            tt = tag_vocab[t]
            true_tags.append(tt)
            pred_tags.append(tp)
print(len(pred_tags), len(true_tags))

labels_left = sorted(set(pred_tags + true_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)
f1_macro = metrics.f1_score(true_tags, pred_tags, average="macro")
print("F1_macro", f1_macro)

confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=labels_left)

dtplot.heatmap(
    confmat,
    labels_left,
    log_scale=True,
    pseudo_count=10,
    size=400,
).update_layout(title="Confusions (log(count + 10))").show()

474 474
accuracy 0.810126582278481
F1_macro 0.7211236431417442


## check output


In [14]:
# print(os.listdir())

model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    predictions = torch.argmax(tag_scores, dim=-1)  # Shape: (batch_size, seq_len)

print("predictions", predictions.size())

outputs = {}
for ex, pred in zip(data["val"].iter_rows(named=True), predictions, strict=True):
    pred_tags = []
    for p in pred:
        if p == 0:
            break
        pred_tags.append(tag_vocab[p])
    assert len(ex["tokens"]) == len(pred_tags), "wrong length"
    pred_tags
    print("".join(ex["tokens"]).replace("\n", "\\\\"))
    print(ex["tags"])
    print(pred_tags)
    print()
    outputs[ex["name"] + "_" + ex["lang"]] = {"tokens": ex["tokens"], "tags": pred_tags}


predictions torch.Size([53, 85])
x=1
['va', 'opas', 'nu']
['pa', 'opas', 'nu']

[2, 3, 4]
['brop', 'nu', 'pu', 'ws', 'nu', 'pu', 'ws', 'nu', 'brcl']
['brop', 'nu', 'pu', 'ws', 'nu', 'pu', 'ws', 'nu', 'brcl']

say "Hello world"
['kwio', 'ws', 'st']
['cl', 'ws', 'st']

if x == 5:\\    print("five")
['kwfl', 'ws', 'va', 'ws', 'opcm', 'ws', 'nu', 'sy', 'nl', 'id', 'fnfr', 'brop', 'st', 'brcl']
['kwfl', 'ws', 'va', 'ws', 'opcm', 'ws', 'nu', 'sy', 'nl', 'id', 'fnfr', 'brop', 'st', 'brcl']

val num = reader.nextInt()
['kwva', 'ws', 'va', 'ws', 'opas', 'ws', 'va', 'sy', 'fnme', 'brop', 'brcl']
['cl', 'ws', 'va', 'ws', 'opas', 'ws', 'va', 'sy', 'fnas', 'brop', 'brcl']

examples = util.load_examples()
['va', 'ws', 'opas', 'ws', 'mo', 'sy', 'fnas', 'brop', 'brcl']
['va', 'ws', 'opas', 'ws', 'va', 'sy', 'fnas', 'brop', 'brcl']

val res = if (num % 2 == 0) "yes" else "no"
['kwva', 'ws', 'va', 'ws', 'opas', 'ws', 'kwfl', 'ws', 'brop', 'va', 'ws', 'opbi', 'ws', 'nu', 'ws', 'opcm', 'ws', 'nu', 'brcl',

AssertionError: wrong length

# parameter search


In [None]:
from coolsearch import search

reload(search)


def objective(
    embedding_dim,
    hidden_dim,
    n_lstm_layers,
    dropout_lstm,
    train_time,
    class_weight_smoothing,
    label_smoothing,
    bidi,
    lr_start,
    lr_gamma,
):
    """Train the model, with some hyperparameters, and evaluate a few metrics"""
    model = torch_util.LSTMTagger(
        len(vocab),
        len(tag_vocab),
        embedding_dim,
        hidden_dim,
        n_lstm_layers,
        dropout_lstm,
        bidi,
    ).to(dev)
    tag_weights = class_weights(tag_counts, tag_vocab, class_weight_smoothing)

    loss_function = torch.nn.CrossEntropyLoss(
        weight=tag_weights,
        label_smoothing=label_smoothing,
    ).to(dev)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_start)
    lr_s = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_gamma)

    # Training loop

    t_start = default_timer()
    while default_timer() < t_start + train_time:
        # TRAINING
        torch_util.run_epoch(model, train_dl, loss_function, optimizer)

        lr_s.step()

    model.eval()
    with torch.no_grad():
        train_loss = torch_util.run_epoch(model, train_dl, loss_function)
        val_loss = torch_util.run_epoch(model, val_dl, loss_function)
        tag_scores = model(val_token_tensors, val_tag_det_tensors)
        predictions = torch.argmax(tag_scores, dim=-1)

    pred_tags = []
    true_tags = []

    for pred, true_t in zip(predictions, val_tag_true_idx):
        true_tags.extend([tag_vocab[t] for t in true_t])
        pred_tags.extend([tag_vocab[t] for t in pred[: len(true_t)]])

    return {
        "_acc": metrics.accuracy_score(true_tags, pred_tags),
        "_f1_macro": metrics.f1_score(true_tags, pred_tags, average="macro"),
        "_train_loss": train_loss,
        "_val_loss": val_loss,
    }


params = {
    "embedding_dim": [16, 18],
    "hidden_dim": [64, 48],
    "n_lstm_layers": [1],
    "dropout_lstm": [0.0],
    "train_time": 10,
    "class_weight_smoothing": [10.0],
    "label_smoothing": [0.1],
    "bidi": True,
    "lr_start": [0.05],
    "lr_gamma": [0.99],
}


cs = search.CoolSearch(
    objective,
    params,
    n_jobs=1,
    samples_file=f"../search/{modelname}.csv",
)
cs.grid_search(3)
res = cs.samples.sort(-pl.col("_acc")).head(10)
display(res)
display(res.row(0, named=True))

In [None]:
marg = cs.marginals("_acc")

pars = list(marg.keys())

print(pars)

for k in pars:
    fig = plotting.scatter(x=marg[k][k], y=[marg[k]["max"], marg[k]["mean"]])
    fig.update_layout(width=400, height=200, title=k)
    fig.show()
