In [1]:
import json
import sys
from importlib import reload
from pathlib import Path
from timeit import default_timer

import polars as pl
import torch
from datatools import plotting as dtplot
from datatools import tabular as dttab
from sklearn import metrics
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

sys.path.append("..")

import plotting
from src import models_torch, text_process, util

# Re-import the local modules so edits made on disk are picked up without
# restarting the kernel.
reload(plotting)
reload(models_torch)
dtplot.set_plotly_template()

# Use the GPU whenever torch can see one.
dev = "cuda" if torch.cuda.is_available() else "cpu"
print("device:", dev)

device: cuda


## load data


In [3]:
reload(util)
tag_map = None  # no tag remapping; also recorded in the saved model metadata
split_idx = "0119"
data = util.load_examples_json(tag_map=tag_map, split_idx_id=split_idx)

# Add `tags_det`: for each example, element [1] of what `text_process.process`
# returns for the re-joined source text (a list of tag strings).
data = {
    split: split_df.with_columns(
        tags_det=pl.col("tokens").map_elements(
            lambda tks: text_process.process("".join(tks))[1],
            return_dtype=pl.List(pl.String),
        )
    )
    for split, split_df in data.items()
}

# Language distribution per split.
for k, df in data.items():
    print(k)
    print(end="\t")
    dttab.value_counts(df["lang"], verbose=True, as_dict=True)


Loaded 212 examples
    val: 53
    test: 39
    train: 120
val
	16 unique (lang):  'python', 'php', 'matlab', 'pseudo', 'js' ,...
test
	14 unique (lang):  'python', 'php', 'dart', 'pseudo', 'matlab' ,...
train
	16 unique (lang):  'python', 'php', 'matlab', 'dart', 'rust' ,...


## token vocabulary from the most common tokens


In [4]:
# Tags whose tokens may enter the token vocabulary: `make_vocab` keeps only
# tokens carrying one of these tags. Every other token is encoded as "<unk>"
# (index 1) when sequences are converted with `token2idx.get(t, 1)`.
VOCAB_TAGS = [
    "kwfl",
    "kwty",
    "kwop",
    "kwmo",
    "kwva",
    "kwde",
    "kwfn",
    "kwim",
    "kwio",
    "id",
    "ws",
    "nl",
    "brop",
    "brcl",
    "sy",
    "pu",
    "bo",
    "li",
    "opcm",
    "opbi",
    "opun",
    "opas",
    "an",
    "uk",
]


def make_vocab(examples: pl.DataFrame, insert=("<pad>", "<unk>"), vocab_tags=None):
    """Build the token vocabulary and its inverse map from tagged examples.

    Only tokens whose tag is in ``vocab_tags`` are candidates. Candidates are
    ordered by descending frequency (ties: descending token string) and are
    appended after the special tokens in ``insert``.

    Args:
        examples: frame with list columns ``tokens`` and ``tags``, exploded
            together, so each row's two lists must have equal length.
        insert: special tokens placed first, in order (indices 0, 1, ...).
            Any sequence is accepted; the default is immutable on purpose.
        vocab_tags: tags whose tokens are admitted; ``None`` means ``VOCAB_TAGS``.

    Returns:
        ``(vocab, token2idx)``: the token list and a dict token -> index.
    """
    if vocab_tags is None:
        vocab_tags = VOCAB_TAGS
    specials = list(insert)
    vocab_cands = (
        examples.select(pl.col("tokens", "tags").explode())
        .filter(pl.col("tags").is_in(vocab_tags))
        # a data token spelled like a special token must not get a second index
        .filter(pl.col("tokens").is_in(specials).not_())
        .group_by("tokens")
        .agg(pl.len().alias("count"))
        .sort("count", "tokens", descending=True)
    )
    vocab = specials + vocab_cands["tokens"].to_list()
    token2idx = {t: i for i, t in enumerate(vocab)}

    return vocab, token2idx


vocab, token2idx = make_vocab(data["train"])

# Training tokens that did NOT make it into the vocabulary (they are encoded
# as "<unk>"), most frequent first.
leftover = (
    data["train"]
    .select(pl.col("tokens").explode())
    .filter(~pl.col("tokens").is_in(vocab))
    .group_by("tokens")
    .agg(pl.len().alias("count"))
    .sort("count", "tokens", descending=True)
    .get_column("tokens")
    .to_list()
)

print(f"VOCAB ({len(vocab)}):", vocab)
print(f"LEFT  ({len(leftover)}):", leftover)


VOCAB (130): ['<pad>', '<unk>', ' ', '\n', ')', '(', '=', '.', ',', ';', '    ', ':', ']', '[', '{', '  ', '\n\n', '}', 'return', 'if', '        ', '*', 'import', 'for', 'in', '-', 'int', '==', '::', '+', '>', '->', '<', 'while', 'let', 'function', 'from', 'as', '...', 'var', 'echo', 'const', '\\', '+=', '++', '**', '            ', '|', 'true', 'end', 'else', 'def', '<-', '/', '%', 'use', 'public', 'pub', 'class', 'True', 'void', 'str', 'null', 'new', 'mut', 'mod', 'float', 'extends', '^', '>>>', '>=', '<=', '-=', '&', '!=', 'xy', 'tight', 'synchronized', 'false', 'bool', 'axis', 'False', '.^', '      ', '   ', '~/', 'with', 'where', 'typeof', 'tuple', 'throws', 'then', 'r', 'private', 'package', 'on', 'not', 'next', 'local', 'is', 'instanceof', 'hold', 'func', 'foreach', 'fn', 'double', 'done', 'do', 'crate', 'continue', 'colorbar', 'char', 'break', 'async', '_', 'SELECT', 'FROM', '?', '=>', '===', '/=', '..', '--', "'", '!', '                ', '             ', '       ', '     ', '\

In [5]:
# Token frequencies, training split only.
token_counts = dttab.value_counts(
    data["train"]["tokens"].explode(),
    verbose=True,
    as_dict=True,
)

# Tag frequencies over all splits: the tag set is closed, so every gold tag
# in train/val/test must get an index in the tag vocabulary built from these.
tag_counts = dttab.value_counts(
    pl.concat([data[split] for split in ("train", "val", "test")])["tags"].explode(),
    as_dict=True,
    verbose=True,
)


643 unique (tokens):  ' ', '\n', ')', '(', '=' ,...
35 unique (tags):  'ws', 'va', 'brop', 'brcl', 'nl' ,...


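## tag vocabulary and integer encoding

- reserve `<pad>` at index 0 in both the token and the tag vocabulary
- convert every token and tag sequence to a list of integer indices (unknown tokens → `<unk>`, unknown detected tags → `uk`)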


In [6]:
# Tag vocabulary: "<pad>" (index 0) and "uk" (index 1) first, then every tag
# in `tag_counts` order. "uk" doubles as the fallback index for detected tags
# that are not in the vocabulary. The specials are filtered out of the counted
# tags so neither can appear twice (which would make `tag2idx` disagree with
# the hard-coded fallback index 1).
# NOTE: this reads `tag_counts` as left by the cell above (all splits). The
# class-weights cell below rebinds `tag_counts` to train-only counts, so
# re-running this cell after that one would change `tag_vocab`.
tag_vocab = ["<pad>", "uk"] + [t for t in tag_counts if t not in ("<pad>", "uk")]
tag2idx = {t: i for i, t in enumerate(tag_vocab)}

print(f"vocab ({len(vocab)} tokens):", vocab)
print(f"vocab ({len(tag_vocab)} tags)  :", tag_vocab)

# Convert tokens and tags to index lists: one unpadded list per example.
# Unknown tokens -> 1 ("<unk>"), unknown detected tags -> 1 ("uk"). Gold tags
# are indexed strictly: a KeyError means the tag vocabulary is stale.
train_token_idx = [
    [token2idx.get(t, 1) for t in seq] for seq in data["train"]["tokens"]
]
train_tag_true_idx = [[tag2idx[t] for t in seq] for seq in data["train"]["tags"]]
train_tag_det_idx = [
    [tag2idx.get(t, 1) for t in seq] for seq in data["train"]["tags_det"]
]


# validation data
val_token_idx = [[token2idx.get(t, 1) for t in seq] for seq in data["val"]["tokens"]]
val_tag_true_idx = [[tag2idx[t] for t in seq] for seq in data["val"]["tags"]]
val_tag_det_idx = [[tag2idx.get(t, 1) for t in seq] for seq in data["val"]["tags_det"]]


vocab (130 tokens): ['<pad>', '<unk>', ' ', '\n', ')', '(', '=', '.', ',', ';', '    ', ':', ']', '[', '{', '  ', '\n\n', '}', 'return', 'if', '        ', '*', 'import', 'for', 'in', '-', 'int', '==', '::', '+', '>', '->', '<', 'while', 'let', 'function', 'from', 'as', '...', 'var', 'echo', 'const', '\\', '+=', '++', '**', '            ', '|', 'true', 'end', 'else', 'def', '<-', '/', '%', 'use', 'public', 'pub', 'class', 'True', 'void', 'str', 'null', 'new', 'mut', 'mod', 'float', 'extends', '^', '>>>', '>=', '<=', '-=', '&', '!=', 'xy', 'tight', 'synchronized', 'false', 'bool', 'axis', 'False', '.^', '      ', '   ', '~/', 'with', 'where', 'typeof', 'tuple', 'throws', 'then', 'r', 'private', 'package', 'on', 'not', 'next', 'local', 'is', 'instanceof', 'hold', 'func', 'foreach', 'fn', 'double', 'done', 'do', 'crate', 'continue', 'colorbar', 'char', 'break', 'async', '_', 'SELECT', 'FROM', '?', '=>', '===', '/=', '..', '--', "'", '!', '                ', '             ', '       ', '   

### class weights (from train-only tag counts)


In [8]:
reload(dttab)

# Class weights are estimated from the training split only.
# NOTE: this rebinds the global `tag_counts`, which until now held the
# all-splits counts that `tag_vocab` was built from.
all_tags = data["train"]["tags"].explode()
tag_counts = dttab.value_counts(all_tags, as_dict=True, sort_by="value")


def class_weights(tag_counts: dict, tag_vocab: list[str], smoothing=0.1):
    """Per-tag loss weights: inverse frequency plus additive smoothing, normalized.

    For each tag of ``tag_vocab`` (in order) the raw weight is
    ``1 / count + smoothing``. Tags missing from ``tag_counts``, or with a zero
    count, get just ``smoothing``. A large ``smoothing`` pushes the weights
    towards uniform.

    Args:
        tag_counts: mapping tag -> number of occurrences.
        tag_vocab: tags in model-output order; the result is aligned to it.
        smoothing: constant added to every raw weight.

    Returns:
        1-D float tensor of length ``len(tag_vocab)`` that sums to 1. It is
        NaN when every raw weight is 0 (e.g. ``smoothing=0`` and no known tag).
    """
    raw = [
        (1.0 / count if count else 0.0) + smoothing
        for count in (tag_counts.get(tag, 0) for tag in tag_vocab)
    ]
    weights = torch.tensor(raw)
    return weights / weights.sum()


### prepare padded tensors and data loaders


In [9]:
reload(models_torch)

print("train:")
train_token_tensors = models_torch.seqs2padded_tensor(train_token_idx, device=dev)
train_tag_true_tensors = models_torch.seqs2padded_tensor(train_tag_true_idx, device=dev)
train_tag_det_tensors = models_torch.seqs2padded_tensor(train_tag_det_idx, device=dev)

print("val:")
val_token_tensors = models_torch.seqs2padded_tensor(val_token_idx, device=dev)
val_tag_true_tensors = models_torch.seqs2padded_tensor(val_tag_true_idx, device=dev)
val_tag_det_tensors = models_torch.seqs2padded_tensor(val_tag_det_idx, device=dev)

# Tokens, gold tags and detected tags must line up position by position.
assert (
    train_token_tensors.shape
    == train_tag_true_tensors.shape
    == train_tag_det_tensors.shape
)
assert val_token_tensors.shape == val_tag_true_tensors.shape == val_tag_det_tensors.shape


train:
padded tensor: (120, 155) cuda:0
padded tensor: (120, 155) cuda:0
padded tensor: (120, 155) cuda:0
val:
padded tensor: (53, 85) cuda:0
padded tensor: (53, 85) cuda:0
padded tensor: (53, 85) cuda:0


In [10]:
class SequenceDataset(Dataset):
    """Map-style dataset over three parallel, equally long collections.

    Item ``i`` is the triple ``(tokens[i], labels_det[i], labels_true[i])``.
    Raises ``ValueError`` if the three collections differ in length.
    """

    def __init__(self, tokens, labels_det, labels_true):
        distinct_sizes = {len(tokens), len(labels_det), len(labels_true)}
        if len(distinct_sizes) != 1:
            raise ValueError("inconsistent lengths")
        print("dataset", len(tokens))
        self.tokens, self.labels_det, self.labels_true = tokens, labels_det, labels_true

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        parts = (self.tokens, self.labels_det, self.labels_true)
        return tuple(part[idx] for part in parts)


# Create datasets and dataloaders.
# One batch size for every device (previously it was chosen per device with
# identical values, and left undefined for any device other than cpu/cuda).
# Validation batches are twice as large.
bs_train = 8

train_loader = DataLoader(
    SequenceDataset(train_token_tensors, train_tag_det_tensors, train_tag_true_tensors),
    batch_size=bs_train,
    shuffle=True,
)
val_loader = DataLoader(
    SequenceDataset(val_token_tensors, val_tag_det_tensors, val_tag_true_tensors),
    batch_size=2 * bs_train,
    shuffle=False,
)

dataset 120
dataset 53


## model


In [None]:
# Best configuration found by the parameter search at the end of this notebook
# (top row of `res`, sorted by `_acc`), kept for reference. This cell only
# evaluates a dict literal; nothing else reads it.
{
    "_acc": 0.9502427184466019,
    "_f1_macro": 0.8128383024255225,
    "_train_loss": 0.6775082588195801,
    "_val_loss": 0.7468067556619644,
    "bidi": True,
    "class_weight_smoothing": 10.0,
    "dropout_lstm": 0.7,
    "embedding_dim": 16,
    "hidden_dim": 64,
    "label_smoothing": 0.1,
    "lr_gamma": 0.99,
    "lr_start": 0.05,
    "n_lstm_layers": 2,
}

In [54]:
reload(models_torch)

# Training budget: stop after `epochs` epochs or `train_time` seconds,
# whichever comes first (train_time = None disables the time limit).
epochs = 2000
train_time = 5


print("vocab lengths", len(vocab), len(tag_vocab))

# LOSS PARAMS
tag_weights = class_weights(tag_counts, tag_vocab, 10.0)
label_smoothing = 0.1

# OPTIMIZER PARAMS (same values as the best search configuration above)
lr_start = 0.05
lr_gamma = 0.99

constructor_params = {
    "token_vocab_size": len(vocab),
    "label_vocab_size": len(tag_vocab),
    "embedding_dim": 16,
    "hidden_dim": 64,
    "n_lstm_layers": 2,
    "dropout_lstm": 0.7,
    "bidi": True,
}

model = models_torch.LSTMTagger(**constructor_params).to(dev)

# NOTE(review): no `ignore_index` is set, so unless `run_epoch` masks padding
# itself, "<pad>" targets (index 0) are scored like real tags. TODO confirm.
loss_function = torch.nn.CrossEntropyLoss(
    weight=tag_weights,
    label_smoothing=label_smoothing,
).to(dev)

optimizer = torch.optim.Adam(model.parameters(), lr=lr_start)
lr_s = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_gamma)

# Training loop

losses_train = []
losses_val = []

t_start = default_timer()
for epoch in tqdm(range(epochs)):
    # TRAINING: train mode, so dropout is active
    model.train()
    train_loss = models_torch.run_epoch(model, train_loader, loss_function, optimizer)
    losses_train.append(train_loss)

    # VALIDATION: eval mode (dropout off) and no autograd graph, so the
    # validation loss is not inflated by the 0.7 LSTM dropout
    model.eval()
    with torch.no_grad():
        val_loss = models_torch.run_epoch(model, val_loader, loss_function)
    losses_val.append(val_loss)

    lr_s.step()

    if train_time is not None and default_timer() > t_start + train_time:
        break

plotting.scatter(y=[losses_train, losses_val]).update_layout(
    title="Loss per epoch (train, val)", xaxis_title="epoch", yaxis_title="loss"
).show()

print(
    "final loss:\n", f"  train: {losses_train[-1]:.4f}", f"  val : {losses_val[-1]:.4f}"
)


vocab lengths 130 37


  3%|▎         | 69/2000 [00:05<02:20, 13.71it/s]


final loss:
   train: 0.6930   val : 0.7525


In [55]:
# Save the trained weights (`<modelname>_state.pth`) and, next to them, the
# metadata needed to rebuild the model and its vocabularies
# (`<modelname>_meta.json`).
metadata = {
    "vocab": vocab,
    "tag_vocab": tag_vocab,
    "tag_map": tag_map,
    "split_idx": split_idx,
    "constructor_params": constructor_params,
}

modelname = f"lstm_{split_idx}"
if tag_map is not None:
    modelname += "_mapped"

model_dir = Path("../models")
model_dir.mkdir(parents=True, exist_ok=True)  # saving must not fail on a fresh checkout

torch.save(model.state_dict(), model_dir / f"{modelname}_state.pth")
with open(model_dir / f"{modelname}_meta.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f)

## evaluate


In [56]:
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    predictions = torch.argmax(tag_scores, dim=-1)  # Shape: (batch_size, seq_len)

# Flatten gold and predicted tags over all validation tokens. Each prediction
# row is cut to its example's true length, so padded positions are not scored.
true_tags = [
    tag_vocab[t] for _, true_t in zip(predictions, val_tag_true_idx) for t in true_t
]
pred_tags = [
    tag_vocab[p]
    for pred, true_t in zip(predictions, val_tag_true_idx)
    for p in pred[: len(true_t)].tolist()
]

print(len(true_tags), len(pred_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)
f1_macro = metrics.f1_score(true_tags, pred_tags, average="macro")
print("F1_macro", f1_macro)

# Confusion matrix over the full tag vocabulary.
confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=tag_vocab)
dtplot.heatmap(confmat, tag_vocab, log_scale=True, pseudo_count=10, size=400).show()

1648 1648
accuracy 0.9381067961165048
F1_macro 0.7892996687331592


### evaluate only on tokens whose detected tag is `uk`


In [65]:
# Keep only positions whose detected tag is "uk" (index 1). Per the index
# conversion above, "uk" also covers detected tags outside `tag_vocab`.
true_tags, pred_tags = [], []
for pred_row, true_seq, det_seq in zip(
    predictions, val_tag_true_idx, val_tag_det_idx, strict=True
):
    for p, t, d in zip(pred_row.tolist(), true_seq, det_seq):
        if tag_vocab[d] != "uk":
            continue
        true_tags.append(tag_vocab[t])
        pred_tags.append(tag_vocab[p])
print(len(pred_tags), len(true_tags))

# Only the tags that actually occur in this subset (gold or predicted).
labels_left = sorted(set(pred_tags + true_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)
f1_macro = metrics.f1_score(true_tags, pred_tags, average="macro")
print("F1_macro", f1_macro)

confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=labels_left)
dtplot.heatmap(
    confmat, labels_left, log_scale=True, pseudo_count=10, size=400
).update_layout(title="Confusions (log(count + 10))").show()

474 474
accuracy 0.7974683544303798
F1_macro 0.7414364996273485


## check output


In [58]:
# Recompute validation predictions, print source text / gold tags / predicted
# tags per example, and collect the predictions in `outputs`.
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    predictions = torch.argmax(tag_scores, dim=-1)  # Shape: (batch_size, seq_len)

print("predictions", predictions.size())

outputs = {}
for ex, pred in zip(data["val"].iter_rows(named=True), predictions, strict=True):
    n_tokens = len(ex["tokens"])
    # Cut each row at the example's true length rather than at the first
    # predicted "<pad>" (index 0). The model can emit "<pad>" inside a sequence
    # or a real tag in the padded tail, and neither must change the length.
    ex_pred_tags = [tag_vocab[p] for p in pred[:n_tokens].tolist()]
    assert len(ex["tags"]) == n_tokens, "wrong length"
    print("".join(ex["tokens"]).replace("\n", "\\\\"))
    print(ex["tags"])
    print(ex_pred_tags)
    print()
    outputs[ex["name"] + "_" + ex["lang"]] = {
        "tokens": ex["tokens"],
        "tags": ex_pred_tags,
    }


predictions torch.Size([53, 85])
x=1
['va', 'opas', 'nu']
['va', 'opas', 'nu']

[2, 3, 4]
['brop', 'nu', 'pu', 'ws', 'nu', 'pu', 'ws', 'nu', 'brcl']
['brop', 'nu', 'pu', 'ws', 'nu', 'pu', 'ws', 'nu', 'brcl']

say "Hello world"
['kwio', 'ws', 'st']
['va', 'ws', 'st']

if x == 5:\\    print("five")
['kwfl', 'ws', 'va', 'ws', 'opcm', 'ws', 'nu', 'sy', 'nl', 'id', 'fnfr', 'brop', 'st', 'brcl']
['kwfl', 'ws', 'va', 'ws', 'opcm', 'ws', 'nu', 'sy', 'nl', 'id', 'fnfr', 'brop', 'st', 'brcl']

val num = reader.nextInt()
['kwva', 'ws', 'va', 'ws', 'opas', 'ws', 'va', 'sy', 'fnme', 'brop', 'brcl']
['cl', 'ws', 'va', 'ws', 'opas', 'ws', 'mo', 'sy', 'fnme', 'brop', 'brcl']

examples = util.load_examples()
['va', 'ws', 'opas', 'ws', 'mo', 'sy', 'fnas', 'brop', 'brcl']
['va', 'ws', 'opas', 'ws', 'mo', 'sy', 'fnme', 'brop', 'brcl']

val res = if (num % 2 == 0) "yes" else "no"
['kwva', 'ws', 'va', 'ws', 'opas', 'ws', 'kwfl', 'ws', 'brop', 'va', 'ws', 'opbi', 'ws', 'nu', 'ws', 'opcm', 'ws', 'nu', 'brcl',

## parameter search


In [52]:
from coolsearch import search

reload(search)


def objective(
    embedding_dim,
    hidden_dim,
    n_lstm_layers,
    dropout_lstm,
    train_time,
    class_weight_smoothing,
    label_smoothing,
    bidi,
    lr_start,
    lr_gamma,
):
    """Train a fresh ``LSTMTagger`` for ``train_time`` seconds and score it.

    Reads the notebook globals ``vocab``, ``tag_vocab``, ``tag_counts``,
    ``train_loader``, ``val_loader``, ``val_token_tensors``,
    ``val_tag_det_tensors``, ``val_tag_true_idx`` and ``dev``.

    Returns:
        dict with validation accuracy ``_acc`` and macro F1 ``_f1_macro`` over
        all non-padding validation tokens, plus ``_train_loss`` / ``_val_loss``
        from one extra pass over each loader in eval mode.
    """
    # keyword arguments, matching `constructor_params` in the training cell
    model = models_torch.LSTMTagger(
        token_vocab_size=len(vocab),
        label_vocab_size=len(tag_vocab),
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        n_lstm_layers=n_lstm_layers,
        dropout_lstm=dropout_lstm,
        bidi=bidi,
    ).to(dev)
    tag_weights = class_weights(tag_counts, tag_vocab, class_weight_smoothing)

    loss_function = torch.nn.CrossEntropyLoss(
        weight=tag_weights,
        label_smoothing=label_smoothing,
    ).to(dev)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_start)
    lr_s = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_gamma)

    # Train for a fixed wall-clock budget rather than a fixed number of epochs.
    model.train()
    t_start = default_timer()
    while default_timer() < t_start + train_time:
        models_torch.run_epoch(model, train_loader, loss_function, optimizer)
        lr_s.step()

    model.eval()
    with torch.no_grad():
        train_loss = models_torch.run_epoch(model, train_loader, loss_function)
        val_loss = models_torch.run_epoch(model, val_loader, loss_function)
        tag_scores = model(val_token_tensors, val_tag_det_tensors)
        predictions = torch.argmax(tag_scores, dim=-1)

    # flatten over all validation tokens, dropping each row's padded tail
    true_tags = []
    pred_tags = []
    for pred, true_t in zip(predictions, val_tag_true_idx, strict=True):
        true_tags.extend(tag_vocab[t] for t in true_t)
        pred_tags.extend(tag_vocab[p] for p in pred[: len(true_t)].tolist())

    return {
        "_acc": metrics.accuracy_score(true_tags, pred_tags),
        "_f1_macro": metrics.f1_score(true_tags, pred_tags, average="macro"),
        "_train_loss": train_loss,
        "_val_loss": val_loss,
    }


# Search space for `objective`; the keys mirror its parameters. List values are
# the candidates for that hyperparameter. Scalars (train_time, bidi) are
# presumably held fixed by coolsearch. TODO confirm.
params = {
    "embedding_dim": [16, 18],
    "hidden_dim": [64, 48],
    "n_lstm_layers": [1],
    "dropout_lstm": [0.0],
    "train_time": 10,
    "class_weight_smoothing": [10.0],
    "label_smoothing": [0.1],
    "bidi": True,
    "lr_start": [0.05],
    "lr_gamma": [0.99],
}


# Samples are kept in a CSV named after the model (`modelname` is set in the
# save cell). NOTE(review): the "no new points!" output suggests
# already-evaluated points in that file are skipped. Confirm in coolsearch.
cs = search.CoolSearch(
    objective,
    params,
    n_jobs=1,
    samples_file=f"../search/{modelname}.csv",
)
cs.grid_search(3)
# top 10 samples by validation accuracy, highest first
res = cs.samples.sort(-pl.col("_acc")).head(10)
display(res)
display(res.row(0, named=True))

no new points!


_acc,_f1_macro,_train_loss,_val_loss,bidi,class_weight_smoothing,dropout_lstm,embedding_dim,hidden_dim,label_smoothing,lr_gamma,lr_start,n_lstm_layers,runtime,train_time
f64,f64,f64,f64,bool,f64,f64,i32,i32,f64,f64,f64,i32,f64,i32
0.950243,0.812838,0.677508,0.746807,True,10.0,0.7,16,64,0.1,0.99,0.05,2,10.076232,10
0.949029,0.806093,0.015774,0.129403,True,10.0,0.7,16,128,0.0,0.99,0.05,2,10.142678,10
0.949029,0.795541,0.001185,0.171534,True,5.0,0.5,23,64,0.0,0.99,0.05,2,10.100151,10
0.949029,0.826327,1.203091,1.251939,True,10.0,0.7,16,64,0.2,0.99,0.05,2,10.107578,10
0.948422,0.812129,0.675323,0.743004,True,10.0,0.5,16,64,0.1,0.99,0.05,2,10.070115,10
0.948422,0.783914,0.680538,0.748764,True,10.0,0.6,18,64,0.1,0.99,0.05,2,10.076619,10
0.948422,0.807029,0.684665,0.739446,True,15.0,0.7,16,64,0.1,0.95,0.05,2,10.088209,10
0.947816,0.789102,0.01454,0.140314,True,10.0,0.7,23,128,0.0,0.99,0.05,2,10.178978,10
0.947816,0.808173,0.677559,0.740194,True,10.0,0.65,16,64,0.1,0.99,0.05,2,10.077678,10
0.947816,0.81793,0.678915,0.743723,True,10.0,0.7,16,80,0.1,0.97,0.05,2,10.079424,10


{'_acc': 0.9502427184466019,
 '_f1_macro': 0.8128383024255225,
 '_train_loss': 0.6775082588195801,
 '_val_loss': 0.7468067556619644,
 'bidi': True,
 'class_weight_smoothing': 10.0,
 'dropout_lstm': 0.7,
 'embedding_dim': 16,
 'hidden_dim': 64,
 'label_smoothing': 0.1,
 'lr_gamma': 0.99,
 'lr_start': 0.05,
 'n_lstm_layers': 2,
 'runtime': 10.076232470999912,
 'train_time': 10}

In [53]:
# Marginals of "_acc" for each searched hyperparameter. For each parameter `k`,
# `marg[k]` has the parameter's values under key `k` plus "max" and "mean"
# columns (presumably max/mean accuracy per value).
marg = cs.marginals("_acc")
pars = list(marg.keys())
print(pars)

for k in pars:
    fig = plotting.scatter(
        x=marg[k][k], y=[marg[k]["max"], marg[k]["mean"]]
    ).update_layout(width=400, height=200, title=k)
    fig.show()


['embedding_dim', 'hidden_dim', 'n_lstm_layers', 'dropout_lstm', 'train_time', 'class_weight_smoothing', 'label_smoothing', 'bidi', 'lr_start', 'lr_gamma']
