In [1]:
import json
import sys
from importlib import reload

import polars as pl
import torch
from datatools import plotting as dtplot
from datatools import tabular as dttab
from sklearn import metrics
from torch.utils.data import DataLoader, Dataset

sys.path.append("..")

import plotting
from src import models_torch, text_process, util

# Re-execute the local modules so the current on-disk versions are used.
reload(plotting)
reload(models_torch)
dtplot.set_plotly_template()

# Use the GPU when torch can see one, otherwise the CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"

print("device:", dev)

device: cuda


## load data


In [2]:
reload(util)
tag_map = None
data = util.load_examples_json(
    tag_map=tag_map,
    # ["python", "pseudo", "rust"],
    split_idx_id="0116",
)

# Expression for the extra "tags_det" column: the tokens of an example are joined
# back into one string, passed to text_process.process, and item [1] of the
# result is stored as a list of strings.
tags_det_expr = pl.col("tokens").map_elements(
    lambda tks: text_process.process("".join(tks))[1], pl.List(pl.String)
)
data = {
    split: frame.with_columns(tags_det=tags_det_expr) for split, frame in data.items()
}

# Per split: print its name, then the language distribution on an indented line.
for k, df in data.items():
    print(k, end="\n\t")
    dttab.value_counts(df["lang"], verbose=True, as_dict=True)


Loaded 206 examples
    train: 120
    test: 36
    val: 50
train
	18 unique (lang):  'python', 'php', 'matlab', 'pseudo', 'rust' ,...
test
	14 unique (lang):  'python', 'php', 'matlab', 'dart', 'r' ,...
val
	14 unique (lang):  'python', 'php', 'dart', 'js', 'matlab' ,...


## most common tokens


In [3]:
# Tags whose tokens may enter the token vocabulary (used as the filter in make_vocab).
# fmt: off
VOCAB_TAGS = [
    "kwfl", "kwty", "kwop", "kwmo", "kwva", "kwde", "kwfn", "kwim", "kwio",
    "id", "ws", "nl",
    "brop", "brcl", "sy", "pu",
    "bo", "li",
    "opcm", "opbi", "opun", "opas",
    "an", "uk",
]
# fmt: on


def make_vocab(examples: pl.DataFrame, insert=("<pad>", "<unk>")):
    """Build the token vocabulary and the token -> index map.

    The "tokens" and "tags" columns are exploded together, so each row of
    ``examples`` is assumed to hold two lists of equal length (TODO confirm
    against the loader). Only tokens whose tag is listed in ``VOCAB_TAGS`` are
    kept; they are ordered by descending count, ties broken by descending token.

    Args:
        examples: frame with list columns "tokens" and "tags".
        insert: special tokens placed at the start of the vocabulary, in order
            (any iterable of strings; a list still works).

    Returns:
        ``(vocab, token2idx)`` where ``vocab`` is a list of tokens and
        ``token2idx`` maps each token to its position in ``vocab``.
    """
    # Copy so the (immutable) default or a caller-owned list is never shared.
    specials = list(insert)
    vocab_cands = (
        examples.select(pl.col("tokens", "tags").explode())
        .filter(pl.col("tags").is_in(VOCAB_TAGS))
        .group_by("tokens")
        .agg(pl.len().alias("count"))
        .sort("count", "tokens", descending=True)
    )
    # Skip data tokens that collide with a special token: a duplicate entry would
    # make token2idx point at the later position instead of the reserved one.
    vocab = specials + [
        t for t in vocab_cands["tokens"].to_list() if t not in specials
    ]
    token2idx = {t: i for i, t in enumerate(vocab)}

    return vocab, token2idx


vocab, token2idx = make_vocab(data["train"])

# Training tokens that did not make it into the vocabulary, most frequent first.
leftover = (
    data["train"]
    .select(pl.col("tokens").explode())
    .filter(~pl.col("tokens").is_in(vocab))
    .group_by("tokens")
    .agg(pl.len().alias("count"))
    .sort("count", "tokens", descending=True)
)["tokens"].to_list()

print(f"VOCAB ({len(vocab)}):", vocab)
print(f"LEFT  ({len(leftover)}):", leftover)


VOCAB (126): ['<pad>', '<unk>', ' ', '\n', ')', '(', ',', '=', '.', ';', '    ', ']', '[', ':', '{', '}', '\n\n', '        ', '  ', '::', 'if', '*', 'return', '==', '\\', 'for', 'in', '+', 'import', '<', '-', '>>>', '>', 'int', '->', 'use', '|', 'function', 'echo', 'True', '...', 'while', 'from', 'const', 'true', 'public', 'let', 'def', 'as', '%', 'void', 'var', 'new', 'float', 'else', 'class', '            ', '     ', 'usize', 'pub', 'mut', 'mod', 'false', 'double', '?', '<=', '<-', '/', '&', 'val', 'synchronized', 'not', 'instanceof', 'foreach', 'extends', 'end', 'char', '>=', '.^', '..', '+=', '++', '**', "'", '!=', '   ', '\n\n\n', '||', 'xy', 'with', 'typeof', 'tuple', 'tight', 'throws', 'puts', 'private', 'php', 'on', 'of', 'next', 'late', 'is', 'hold', 'fn', 'done', 'do', 'crate', 'close', 'clear', 'clc', 'bool', 'axis', 'async', 'all', '_', '^', 'SELECT', 'False', 'FROM', 'FALSE', '=>', '/=', '--', '*=', '                ', '\t']
LEFT  (564): ['0', '1', 'x', 'i', '2', 'print', 

In [4]:
# token counts for the training split only
token_counts = dttab.value_counts(
    data["train"]["tokens"].explode(), verbose=True, as_dict=True
)

# tag counts for all data, closed tag set
tag_counts = dttab.value_counts(
    pl.concat([data[split] for split in ("train", "val", "test")])["tags"].explode(),
    verbose=True,
    as_dict=True,
)


688 unique (tokens):  ' ', '\n', ')', '(', ',' ,...
35 unique (tags):  'ws', 'va', 'brop', 'brcl', 'nl' ,...


## make a vocab!

- add padding to both tokens and tags
- also, convert tokens and tags to integers


In [5]:
# tags
tag_vocab = ["<pad>", "uk"] + list(tag_counts.keys())
tag2idx = {t: i for i, t in enumerate(tag_vocab)}

print(f"vocab ({len(vocab)} tokens):", vocab)
print(f"vocab ({len(tag_vocab)} tags)  :", tag_vocab)

# Convert tokens and labels to indices
# these are lists of lists!
train_token_idx = [
    [token2idx.get(t, 1) for t in seq] for seq in data["train"]["tokens"]
]
train_tag_true_idx = [[tag2idx[t] for t in seq] for seq in data["train"]["tags"]]
train_tag_det_idx = [
    [tag2idx.get(t, 1) for t in seq] for seq in data["train"]["tags_det"]
]


# validation data
val_token_idx = [[token2idx.get(t, 1) for t in seq] for seq in data["val"]["tokens"]]
val_tag_true_idx = [[tag2idx[t] for t in seq] for seq in data["val"]["tags"]]
val_tag_det_idx = [[tag2idx.get(t, 1) for t in seq] for seq in data["val"]["tags_det"]]


vocab (126 tokens): ['<pad>', '<unk>', ' ', '\n', ')', '(', ',', '=', '.', ';', '    ', ']', '[', ':', '{', '}', '\n\n', '        ', '  ', '::', 'if', '*', 'return', '==', '\\', 'for', 'in', '+', 'import', '<', '-', '>>>', '>', 'int', '->', 'use', '|', 'function', 'echo', 'True', '...', 'while', 'from', 'const', 'true', 'public', 'let', 'def', 'as', '%', 'void', 'var', 'new', 'float', 'else', 'class', '            ', '     ', 'usize', 'pub', 'mut', 'mod', 'false', 'double', '?', '<=', '<-', '/', '&', 'val', 'synchronized', 'not', 'instanceof', 'foreach', 'extends', 'end', 'char', '>=', '.^', '..', '+=', '++', '**', "'", '!=', '   ', '\n\n\n', '||', 'xy', 'with', 'typeof', 'tuple', 'tight', 'throws', 'puts', 'private', 'php', 'on', 'of', 'next', 'late', 'is', 'hold', 'fn', 'done', 'do', 'crate', 'close', 'clear', 'clc', 'bool', 'axis', 'async', 'all', '_', '^', 'SELECT', 'False', 'FROM', 'FALSE', '=>', '/=', '--', '*=', '                ', '\t']
vocab (37 tags)  : ['<pad>', 'uk', 'ws', 

### class weights?


In [6]:
reload(dttab)


# All tags of the training split, one row per token.
all_tags = data["train"]["tags"].explode()

# NOTE: this rebinds `tag_counts`. Earlier in the notebook the name holds counts
# over train+val+test (used to build `tag_vocab`); from here on it holds
# train-only counts, which is what `class_weights` is called with below.
tag_counts = dttab.value_counts(all_tags, sort_by="value", as_dict=True)


def class_weights(tag_counts: dict, tag_vocab: list[str], smoothing=0.1):
    """Smoothed inverse-frequency class weights, normalised to sum to 1.

    Args:
        tag_counts: mapping from tag to its number of occurrences.
        tag_vocab: tags in index order; the result has one weight per entry.
        smoothing: constant added to every inverse count before normalising.
            Larger values make the weights more uniform.

    Returns:
        1-D float tensor of length ``len(tag_vocab)``. A tag that is missing
        from ``tag_counts`` (or has a count of 0) contributes only ``smoothing``.
    """
    raw_weights = []
    for tag in tag_vocab:
        count = tag_counts.get(tag, 0)
        # Unseen tags get an inverse frequency of 0 instead of dividing by zero.
        inverse_freq = 1 / count if count else 0.0
        raw_weights.append(inverse_freq + smoothing)
    tag_weights = torch.tensor(raw_weights)
    # Tensor-level reduction instead of Python's element-by-element sum().
    tag_weights /= tag_weights.sum()
    return tag_weights


### Prepare data for model


In [7]:
reload(models_torch)

# Pad the index sequences of each split into tensors on the selected device.
print("train:")
train_token_tensors, train_tag_true_tensors, train_tag_det_tensors = (
    models_torch.seqs2padded_tensor(seqs, device=dev)
    for seqs in (train_token_idx, train_tag_true_idx, train_tag_det_idx)
)
print("val:")
val_token_tensors, val_tag_true_tensors, val_tag_det_tensors = (
    models_torch.seqs2padded_tensor(seqs, device=dev)
    for seqs in (val_token_idx, val_tag_true_idx, val_tag_det_idx)
)

# Tokens, true tags and deterministic tags must line up element by element.
assert (
    train_token_tensors.shape
    == train_tag_true_tensors.shape
    == train_tag_det_tensors.shape
)
assert val_token_tensors.shape == val_tag_true_tensors.shape == val_tag_det_tensors.shape


train:
padded tensor: (120, 314) cuda:0
padded tensor: (120, 314) cuda:0
padded tensor: (120, 314) cuda:0
val:
padded tensor: (50, 111) cuda:0
padded tensor: (50, 111) cuda:0
padded tensor: (50, 111) cuda:0


In [8]:
class SequenceDataset(Dataset):
    """Index-aligned triples of (tokens, deterministic labels, true labels)."""

    def __init__(self, tokens, labels_det, labels_true):
        """Store the three containers; they must all have the same length."""
        n_examples = len(tokens)
        if len(labels_det) != n_examples or len(labels_true) != n_examples:
            raise ValueError("inconsistent lengths")
        print("dataset", n_examples)
        self.tokens, self.labels_det, self.labels_true = tokens, labels_det, labels_true

    def __len__(self):
        """Number of examples."""
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return ``(tokens, labels_det, labels_true)`` at position ``idx``."""
        return tuple(
            field[idx] for field in (self.tokens, self.labels_det, self.labels_true)
        )


# Create dataset and dataloader

# Training batch size; currently the same value on both devices.
if dev in ("cpu", "cuda"):
    bs_train = 8

# Shuffle only the training data; validation uses a batch twice as large.
train_loader = DataLoader(
    SequenceDataset(train_token_tensors, train_tag_det_tensors, train_tag_true_tensors),
    shuffle=True,
    batch_size=bs_train,
)
val_loader = DataLoader(
    SequenceDataset(val_token_tensors, val_tag_det_tensors, val_tag_true_tensors),
    shuffle=False,
    batch_size=bs_train * 2,
)

dataset 120
dataset 50


## model


In [9]:
from timeit import default_timer

from tqdm import tqdm

reload(models_torch)

epochs = 2000  # upper bound on the number of epochs
train_time = 5  # wall-clock budget in seconds; training stops once exceeded (None = no limit)

print("vocab lengths", len(vocab), len(tag_vocab))
tag_weights = class_weights(tag_counts, tag_vocab, 10.0)

constructor_params = {
    "token_vocab_size": len(vocab),
    "label_vocab_size": len(tag_vocab),
    "embedding_dim": 16,
    "hidden_dim": 100,
    "n_lstm_layers": 2,
    "dropout_lstm": 0.7,
    "bidi": True,
}

# Model built from constructor_params (the same dict is saved with the model later)
model = models_torch.LSTMTagger(**constructor_params).to(dev)

loss_function = torch.nn.CrossEntropyLoss(weight=tag_weights).to(dev)

optimizer = torch.optim.Adam(model.parameters(), lr=0.025)

lr_s = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

# Training loop

losses_train = []
losses_val = []

t_start = default_timer()
for epoch in tqdm(range(epochs)):
    # TRAINING
    # run_epoch is defined in models_torch and not visible here, so the mode is
    # set explicitly: dropout on for training, off for validation.
    model.train()
    train_loss = models_torch.run_epoch(model, train_loader, loss_function, optimizer)
    losses_train.append(train_loss)

    # VALIDATION
    model.eval()
    with torch.no_grad():
        val_loss = models_torch.run_epoch(model, val_loader, loss_function)
        losses_val.append(val_loss)

    lr_s.step()

    # Stop early once the wall-clock budget is used up.
    if train_time is not None and default_timer() > t_start + train_time:
        break
plotting.scatter(y=[losses_train, losses_val]).show()

# One value per line.
print(
    "final loss:",
    f"  train: {losses_train[-1]:.4f}",
    f"  val : {losses_val[-1]:.4f}",
    sep="\n",
)


vocab lengths 126 37


  1%|▏         | 25/2000 [00:05<06:36,  4.98it/s]


final loss:
   train: 0.0104   val : 0.0897


In [10]:
# save model: weights as a state dict, plus the metadata needed to rebuild
# the tagger (vocabularies, tag map and constructor arguments) as JSON
metadata = dict(
    vocab=vocab,
    tag_vocab=tag_vocab,
    tag_map=tag_map,
    constructor_params=constructor_params,
)
torch.save(model.state_dict(), "../models/MODEL_state.pth")
with open("../models/MODEL_meta.json", "w") as f:
    f.write(json.dumps(metadata))

## evaluate


In [11]:
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    predictions = tag_scores.argmax(dim=-1)  # Shape: (batch_size, seq_len)

# Flatten to one tag string per real (unpadded) token; the length of each true
# sequence tells how many leading predictions of the padded row to keep.
pred_tags = []
true_tags = []

for pred_row, true_row in zip(predictions, val_tag_true_idx):
    n_real = len(true_row)
    true_tags += [tag_vocab[i] for i in true_row]
    pred_tags += [tag_vocab[i] for i in pred_row[:n_real].tolist()]

print(len(true_tags), len(pred_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)
f1_macro = metrics.f1_score(true_tags, pred_tags, average="macro")
print("F1_macro", f1_macro)

confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=tag_vocab)

heatmap_fig = dtplot.heatmap(
    confmat,
    tag_vocab,
    log_scale=True,
    pseudo_count=10,
    size=400,
)
heatmap_fig.show()

1743 1743
accuracy 0.9483648881239243
F1_macro 0.8407378377918766


### eval only on non-det tokens (deterministic tag is `uk`)


In [12]:
# Keep only positions whose deterministic tag is "uk" and collect the
# (true, predicted) tag strings for them.
tag_pairs = [
    (tag_vocab[t], tag_vocab[p.item()])
    for pred, true_t, det in zip(
        predictions, val_tag_true_idx, val_tag_det_idx, strict=True
    )
    for p, t, d in zip(pred, true_t, det)
    if tag_vocab[d] == "uk"
]
true_tags = [true_tag for true_tag, _ in tag_pairs]
pred_tags = [pred_tag for _, pred_tag in tag_pairs]
print(len(pred_tags), len(true_tags))

# Restrict the confusion matrix to the tags that actually occur here.
labels_left = sorted(set(pred_tags) | set(true_tags))

acc = metrics.accuracy_score(true_tags, pred_tags)
print("accuracy", acc)
f1_macro = metrics.f1_score(true_tags, pred_tags, average="macro")
print("F1_macro", f1_macro)

confmat = metrics.confusion_matrix(true_tags, pred_tags, labels=labels_left)

heatmap_fig = dtplot.heatmap(
    confmat,
    labels_left,
    log_scale=True,
    pseudo_count=10,
    size=400,
)
heatmap_fig.show()

509 509
accuracy 0.8467583497053045
F1_macro 0.7779950294814507


## save output


In [13]:
model.eval()
with torch.no_grad():
    tag_scores = model(val_token_tensors, val_tag_det_tensors)
    predictions = torch.argmax(tag_scores, dim=-1)  # Shape: (batch_size, seq_len)

print("predictions", predictions.size())

# Predicted tags per validation example, keyed by "<name>_<lang>".
outputs = {}
for ex, pred in zip(data["val"].iter_rows(named=True), predictions, strict=True):
    # Cut the padded prediction row at the true token count (same rule as the
    # evaluation cell). Stopping at the first predicted "<pad>" instead would
    # truncate the example whenever the model predicts padding mid-sequence.
    n_tokens = len(ex["tokens"])
    pred_tags = [tag_vocab[p] for p in pred[:n_tokens].tolist()]
    assert len(pred_tags) == n_tokens, "wrong length"
    # Show the source on one line: newlines are replaced by two backslashes.
    print("".join(ex["tokens"]).replace("\n", "\\\\"))
    print(ex["tags"])
    print(pred_tags)
    print()
    outputs[ex["name"] + "_" + ex["lang"]] = {"tokens": ex["tokens"], "tags": pred_tags}
with open("../output/pred_output.json", "w", encoding="utf-8") as f:
    json.dump(outputs, f)

predictions torch.Size([50, 111])
x = 337\\y = 99\\z = x + y
['va', 'ws', 'opas', 'ws', 'nu', 'nl', 'va', 'ws', 'opas', 'ws', 'nu', 'nl', 'va', 'ws', 'opas', 'ws', 'va', 'ws', 'opbi', 'ws', 'va']
['va', 'ws', 'opas', 'ws', 'nu', 'nl', 'va', 'ws', 'opas', 'ws', 'nu', 'nl', 'va', 'ws', 'opas', 'ws', 'va', 'ws', 'opbi', 'ws', 'va']

var zz, xy int = 11, 33
['kwva', 'ws', 'va', 'pu', 'ws', 'va', 'ws', 'kwty', 'ws', 'opas', 'ws', 'nu', 'pu', 'ws', 'nu']
['kwva', 'ws', 'va', 'pu', 'ws', 'kwty', 'ws', 'kwty', 'ws', 'opas', 'ws', 'nu', 'pu', 'ws', 'nu']

x = True\\y = False\\\\z = x and y
['va', 'ws', 'opas', 'ws', 'bo', 'nl', 'va', 'ws', 'opas', 'ws', 'bo', 'nl', 'va', 'ws', 'opas', 'ws', 'va', 'ws', 'kwop', 'ws', 'va']
['va', 'ws', 'opas', 'ws', 'bo', 'nl', 'va', 'ws', 'opas', 'ws', 'bo', 'nl', 'va', 'ws', 'opas', 'ws', 'va', 'ws', 'va', 'ws', 'va']

a = 30\\b = a*76\\\\k = 112 ; n = a*12
['va', 'ws', 'opas', 'ws', 'nu', 'nl', 'va', 'ws', 'opas', 'ws', 'va', 'opbi', 'nu', 'nl', 'va', 'ws', '

# parameter search


In [15]:
from coolsearch import search

# importlib.reload re-executes the already-imported module, so the current
# on-disk version of coolsearch.search is the one used below.
reload(search)


def objective(
    embedding_dim,
    hidden_dim,
    n_lstm_layers,
    dropout_lstm,
    train_time,
    class_weight_smoothing,
    bidi,
    lr_start,
    lr_gamma,
):
    """Train a fresh LSTMTagger for ``train_time`` seconds and score it.

    The model is trained on ``train_loader`` with Adam (initial rate
    ``lr_start``, multiplied by ``lr_gamma`` after every epoch) and a
    cross-entropy loss weighted by ``class_weights`` with the given smoothing.

    Returns:
        dict with validation accuracy (``_acc``), validation macro F1
        (``_f1_macro``) and the final train/validation losses computed in
        eval mode (``_train_loss``, ``_val_loss``).
    """
    tagger = models_torch.LSTMTagger(
        len(vocab),
        len(tag_vocab),
        embedding_dim,
        hidden_dim,
        n_lstm_layers,
        dropout_lstm,
        bidi,
    ).to(dev)
    weights = class_weights(tag_counts, tag_vocab, class_weight_smoothing)
    criterion = torch.nn.CrossEntropyLoss(weight=weights).to(dev)
    adam = torch.optim.Adam(tagger.parameters(), lr=lr_start)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(adam, gamma=lr_gamma)

    # Train whole epochs until the wall-clock budget is used up.
    deadline = default_timer() + train_time
    while default_timer() < deadline:
        models_torch.run_epoch(tagger, train_loader, criterion, adam)
        scheduler.step()

    # Final losses and validation predictions, without gradient tracking.
    tagger.eval()
    with torch.no_grad():
        train_loss = models_torch.run_epoch(tagger, train_loader, criterion)
        val_loss = models_torch.run_epoch(tagger, val_loader, criterion)
        val_predictions = tagger(val_token_tensors, val_tag_det_tensors).argmax(dim=-1)

    # Flatten to tag strings, dropping the padded tail of every prediction row.
    flat_true = []
    flat_pred = []
    for pred_row, true_row in zip(val_predictions, val_tag_true_idx):
        flat_true += [tag_vocab[i] for i in true_row]
        flat_pred += [tag_vocab[i] for i in pred_row[: len(true_row)].tolist()]

    return {
        "_acc": metrics.accuracy_score(flat_true, flat_pred),
        "_f1_macro": metrics.f1_score(flat_true, flat_pred, average="macro"),
        "_train_loss": train_loss,
        "_val_loss": val_loss,
    }


# Search space for `objective`: one entry per keyword argument. List entries
# give the candidate values; scalar entries (train_time, bidi) appear to be
# treated as fixed by coolsearch — TODO confirm against its documentation.
params = {
    "embedding_dim": [16, 24, 32],
    "hidden_dim": [64, 128],
    "n_lstm_layers": [2],
    "dropout_lstm": [0.3, 0.5, 0.7],
    "train_time": 10,
    "class_weight_smoothing": [10.0, 20.0],
    "bidi": True,
    "lr_start": [0.05, 0.01],
    "lr_gamma": [0.99],
}


cs = search.CoolSearch(
    objective,
    params,
    n_jobs=1,
    samples_file="../search/lstm_0116.csv",
)
# NOTE(review): the meaning of the argument 3 is defined by coolsearch and not
# visible here; each point trains for `train_time` seconds, so this cell is slow.
cs.grid_search(3)
# Ten best samples by macro F1 (sorting on the negated column = descending).
res = cs.samples.sort(-pl.col("_f1_macro")).head(10)
display(res)
# Best sample as a dict.
display(res.row(0, named=True))

Searching 72 new parameter points


100%|██████████| 72/72 [12:13<00:00, 10.18s/it]

Sum of runtime: 733.04 s. Elapsed time 733.10 s.
Overhead: 0.0552 s.





_acc,_f1_macro,_train_loss,_val_loss,bidi,class_weight_smoothing,dropout_lstm,embedding_dim,hidden_dim,lr_gamma,lr_start,n_lstm_layers,runtime,train_time
f64,f64,f64,f64,bool,f64,f64,i32,i32,f64,f64,i32,f64,i32
0.959839,0.859647,0.000105,0.083483,True,10.0,0.7,32,128,0.99,0.01,2,10.300203,10
0.957544,0.857774,5.8e-05,0.089083,True,20.0,0.3,32,64,0.99,0.01,2,10.090659,10
0.960413,0.857146,4e-05,0.082477,True,20.0,0.5,32,64,0.99,0.01,2,10.097066,10
0.956397,0.85695,6.7e-05,0.082223,True,20.0,0.3,16,64,0.99,0.01,2,10.119179,10
0.956397,0.856522,0.000197,0.095253,True,10.0,0.7,24,64,0.99,0.01,2,10.111013,10
0.956971,0.856358,0.000359,0.078814,True,20.0,0.7,32,128,0.99,0.01,2,10.295966,10
0.958692,0.855137,9.1e-05,0.083989,True,10.0,0.5,32,128,0.99,0.01,2,10.322374,10
0.952381,0.855093,6.2e-05,0.082385,True,20.0,0.5,16,64,0.99,0.01,2,10.11508,10
0.958118,0.854376,7.3e-05,0.08803,True,10.0,0.5,24,128,0.99,0.01,2,10.280692,10
0.958118,0.85174,6.4e-05,0.075031,True,10.0,0.5,32,64,0.99,0.01,2,10.123544,10


{'_acc': 0.9598393574297188,
 '_f1_macro': 0.8596474686539787,
 '_train_loss': 0.00010493211108647908,
 '_val_loss': 0.08348282799124718,
 'bidi': True,
 'class_weight_smoothing': 10.0,
 'dropout_lstm': 0.7,
 'embedding_dim': 32,
 'hidden_dim': 128,
 'lr_gamma': 0.99,
 'lr_start': 0.01,
 'n_lstm_layers': 2,
 'runtime': 10.300203221999254,
 'train_time': 10}

In [16]:
marg = cs.marginals("_acc")

pars = [*marg.keys()]

print(pars)

# One small figure per parameter: max and mean of "_acc" against its values.
for k in pars:
    stats = marg[k]
    fig = plotting.scatter(x=stats[k], y=[stats["max"], stats["mean"]])
    fig.update_layout(title=k, width=400, height=200)
    fig.show()


['embedding_dim', 'hidden_dim', 'n_lstm_layers', 'dropout_lstm', 'train_time', 'class_weight_smoothing', 'bidi', 'lr_start', 'lr_gamma']
