## Extra Features

An LSTM tagger that uses extra features generated from the input text.


In [None]:
import sys
from importlib import reload

from plotly import graph_objects as go

sys.path.append("..")
from datatools import plotting

from src import torch_util, util

# Apply the project's plotly template before any figure is created.
plotting.set_plotly_template()

# Load the example splits for split id "0301"; `data` is indexed by split name.
data = util.load_examples_json(split_idx_id="0301")
train_split = data["train"]
display(train_split.head(5))  # preview the first five training rows

# Vocabularies and index lookups are built from the training split only.
vocab_parts = util.make_vocab(train_split)
vocab, token2idx, tag_vocab, tag2idx = vocab_parts
print(f"vocab: {len(vocab)} tokens | tag_vocab: {len(tag_vocab)} tags")

## Make features


In [None]:
# Pick up any edits made to torch_util since it was first imported.
reload(torch_util)

# Show the extra features produced for two small hand-written token sequences.
for sample_tokens in (
    ["Sys", "print", "\n", "9", "\n", "=="],
    ["print", "(", ")"],
):
    print(torch_util.make_extra_feats(sample_tokens))

## A model


In [None]:
# Pick up any edits made to torch_util since it was first imported.
reload(torch_util)

NF = 3  # number of extra features (passed as `n_extra` / `extra_feats` below)
BS = 4  # second positional argument of data2torch -- presumably the batch size

# Constructor arguments shared by both taggers.
constr_params = dict(
    embedding_dim=16,
    hidden_dim=64,
    n_lstm_layers=2,
    dropout_lstm=0.3,
    bidi=True,
    token_vocab_size=len(vocab),
    label_vocab_size=len(tag_vocab),
)


def _report_batch_shapes(heading, loader, model):
    """Print the input shapes of the first batch of `loader` and the model output shape.

    The first element of the batch is treated as a mapping of keyword
    arguments for `model`. The fetched batch is returned to the caller.
    """
    print(heading)
    batch = next(iter(loader))
    model_inputs = batch[0]
    for k, field in model_inputs.items():
        print(f"  - {k}", field.shape)
    print("  -> out", model(**model_inputs).shape)
    return batch


print("\nModels:")
model_base = torch_util.LSTMTagger(**constr_params)
model_feats = torch_util.LSTMTagger(**constr_params, n_extra=NF)

# Loaders without extra features, one per data split.
dls_base = {
    split: torch_util.data2torch(frame, BS, token2idx, tag2idx)
    for split, frame in data.items()
}
ex = _report_batch_shapes("data shape (base):", dls_base["train"], model_base)

# Loaders that additionally carry the NF extra features, one per data split.
dls_feats = {
    split: torch_util.data2torch(frame, BS, token2idx, tag2idx, extra_feats=NF)
    for split, frame in data.items()
}
ex = _report_batch_shapes(
    "\ndata shape (w. feats):", dls_feats["train"], model_feats
)


## Train both models and compare


In [None]:
N_epoch = 5  # roughly 40+ epochs are needed to converge
# Idea, not wired up: torch.optim.lr_scheduler.CosineAnnealingLR()


def _announce_and_train(banner, model, loaders):
    """Print `banner`, then train `model` on the "train"/"val" loaders.

    Returns whatever `torch_util.train_loop` returns. A fresh plateau
    configuration dict is built on every call, so the two runs never share
    one dict object.
    """
    print(banner)
    plateau_cfg = {"factor": 0.75, "patience": 5}
    return torch_util.train_loop(
        model,
        loaders["train"],
        loaders["val"],
        epochs=N_epoch,
        reduce_lr_on_plat=plateau_cfg,
    )


losses_base = _announce_and_train("BASE MODEL", model_base, dls_base)
losses_feats = _announce_and_train("FEATS MODEL", model_feats, dls_feats)

In [None]:
# Draw one figure per key recorded for the base run, overlaying the base and
# feats series. Each figure is titled with its key; without a title the
# figures are indistinguishable from one another.
for k in losses_base:
    go.Figure(
        data=[
            go.Scatter(y=losses_base[k], name="base"),
            go.Scatter(y=losses_feats[k], name="feats"),
        ],
        layout={"title": {"text": str(k)}},
    ).show()
