In [None]:
from importlib import reload
from pathlib import Path

import numpy as np
import polars as pl
import torch
from plotly import graph_objects as go
from torch.utils.data import DataLoader

from src import data_util
from src import plotlyplot as pp
from src.evt_seq import torch_model as tm
from src.evt_seq import util

# Re-execute the project helper modules so edits under src/ are picked up
# without restarting the kernel.
reload(pp)
reload(data_util)

# Apply the project's shared plotly template to every figure below.
pp.set_plotly_template()

data_conf = data_util.load_config()
print(data_conf)
# load_data yields the event frame and the event-type list used as the vocab later on.
df, evt_types = data_util.load_data(data_conf)

# (type_name, count) pairs for every event type, most frequent first.
evt_types_w_count: list[tuple[str, int]] = (
    df.group_by("type_name")
    .agg(pl.len())
    .sort("len", descending=True)
    .rows()
)

display(df.tail())


In [None]:
reload(util)

# Encode the event log as integer sequences over the `evt_types` vocabulary.
# NOTE(review): max_gap_mins presumably breaks sequences at idle gaps and seed
# presumably fixes the split randomness — confirm in util.EvtSeqData.
data = util.EvtSeqData(
    df,
    vocab=evt_types,
    max_gap_mins=15,
    blksz=5,
    seed=1337,
)

splits = data.as_integers()

reload(tm)

# One CBOW dataset per split, all sharing the block size configured above.
dsets = {}
for split_name, split_ids in splits.items():
    dsets[split_name] = tm.EventCBOWDataset(split_ids, data.blksz)

# Report every split only after all datasets have been built.
for split_name, split_ds in dsets.items():
    print(f"{split_name}: {split_ds}")


In [None]:
# Reference point for the training curves: a model that predicts a uniform
# distribution over the c event types has cross-entropy ln(c) (also with label
# smoothing, since the targets still sum to 1). CE itself is unbounded above, so
# this is a "no-skill" baseline, not a worst case.
c = len(data.vocab)
print(f"Uniform-guess CE loss baseline: {c=} -> {np.log(c)=:.3f}")

## Training


In [None]:
# Seed torch so weight init and training-batch shuffling are reproducible
# (the data split above is seeded separately through EvtSeqData(seed=...)).
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)

model = tm.CBOWFFModel(len(data.vocab), emb_dim=3, block_size=data.blksz)

# Validation runs as a single batch; an empty split would mean batch_size=0,
# which DataLoader rejects with a less obvious error.
assert len(dsets["val"]) > 0, "validation split is empty"

loss_train, loss_val, lr_hist = tm.train_cbow_ff(
    model,
    # Reshuffle on every pass over the loader so mini-batches are not the same
    # fixed sequence each epoch.
    # NOTE(review): assumes EventCBOWDataset is a map-style Dataset
    # (DataLoader refuses shuffle=True for an IterableDataset) — confirm.
    DataLoader(dsets["train"], batch_size=128, shuffle=True),
    DataLoader(dsets["val"], batch_size=len(dsets["val"])),  # all at once
    label_smoothing=0.01,
    start_lr=1e-3,
    lrs_patience=10,
    stop_patience=30,
)

# NOTE(review): x axis is the index in the histories returned by
# train_cbow_ff — presumably one entry per epoch; confirm.
go.Figure(
    [
        go.Scatter(y=loss_train, name="train"),
        go.Scatter(y=loss_val, name="val"),
    ]
).update_layout(title="CBOW loss", yaxis_title="loss").show()

# LR drops multiplicatively under a patience-based schedule, so a log axis keeps
# every step readable.
go.Figure(
    go.Scatter(y=lr_hist, name="LR"),
).update_layout(
    title="Learning rate", yaxis_title="learning rate", yaxis_type="log"
).show()

In [None]:
# Extract the learned embedding matrix from the trained model and save it next
# to the vocabulary so rows can be mapped back to event types later.
# Detach from autograd before copying to host memory (idiomatic order).
embs = model.emb.weight.detach().cpu().numpy()
print(f"{embs.shape=}")
fp = Path("tmp_data/embs_cbow.npz")
# np.savez does not create missing directories; ensure the target folder exists
# so a fresh checkout does not fail with FileNotFoundError.
fp.parent.mkdir(parents=True, exist_ok=True)
np.savez(fp, embs=embs, vocab=data.vocab)