In [14]:
from collections import Counter

import lightning.pytorch as pl
import polars
import torch
import yaml

from seqpred.data import prep_data, BaseDataset, load_morphers
from seqpred.quantile_morpher import Quantiler, Integerizer
from seqpred.nn import MargeNet

# Read the run configuration from disk.
# NOTE(review): yaml.CLoader is the C build of the full (non-safe) loader;
# if this file could ever come from an untrusted source, switch to
# yaml.CSafeLoader.
with open("cfg/config.yaml") as config_file:
    config = yaml.load(config_file, Loader=yaml.CLoader)

# config["features"] is iterated as (column, type) pairs; turn it into a
# column -> type lookup.
inputs = {name: kind for name, kind in config["features"]}

In [2]:
# Which morpher class handles each feature type named in the config.
morpher_dispatch = {
    "numeric": Quantiler,
    "categorical": Integerizer,
}

# Load the morphers stored in model/morphers.yaml for the configured inputs.
morphers = load_morphers(
    "model/morphers.yaml",
    inputs,
    morpher_dispatch,
)

net = MargeNet.load_from_checkpoint("model/latest.ckpt")
# This notebook only runs inference: put the module in evaluation mode so
# that any dropout / batch-norm layers inside it behave deterministically.
net.eval()
gen_head = net.generator_head

# Set up data
# prep_data returns a pair; only the first element is used in this notebook.
base_data, _ = prep_data(
    data_files=[config["train_data_path"]],
    key_cols=config["keys"],
    morphers=morphers,
)
# Dataset view over the prepared data; return_keys=False is passed so items
# are requested without the key columns.
ds = BaseDataset(
    base_data,
    morphers,
    key_cols=config["keys"],
    return_keys=False,
)

In [36]:
# Pick one example row and look up which pitcher it belongs to.
i = 10000
example_row = base_data.row(i, named=True)
pitcher_id = example_row["pitcher"]
# The dataset's "pitcher" entry for the same index, with a leading
# dimension of size 1 added before it is passed to the network.
pitcher = ds[i]["pitcher"].unsqueeze(0)

print(pitcher_id)

with torch.inference_mode():
    # Compute the predicted pitch-type distribution for this pitcher:
    # embed -> norm -> activation -> "pitch_type" predictor -> softmax.
    pitcher_embed = net.init_embedder(pitcher)
    normed = gen_head.norm(pitcher_embed)
    x = gen_head.activation(normed)
    logits = gen_head.predictors["pitch_type"](x)
    pitch_dist = torch.softmax(logits, dim=-1)

76


In [37]:
# Number of pitch types to draw from the predicted distribution.
n = 200

# Locally seeded generator: re-running this cell reproduces the same sample
# without disturbing torch's global RNG state.
sample_rng = torch.Generator(device=pitch_dist.device).manual_seed(0)

# Draw all n samples in a single call (with replacement, i.e. independent
# draws) instead of n separate single-sample multinomial calls.
pitches = torch.multinomial(
    pitch_dist.view(-1),
    n,
    replacement=True,
    generator=sample_rng,
).tolist()

# Relative frequency of each sampled pitch-type index.
generated_pitches = dict(Counter(pitches))
print({k: v / n for k, v in generated_pitches.items()})

{14: 0.21, 9: 0.305, 7: 0.115, 1: 0.32, 15: 0.05}


In [38]:
# Empirical pitch-type frequencies for the same pitcher, for comparison with
# the sampled frequencies.
pitcher_rows = base_data.filter(polars.col("pitcher") == pitcher_id)
pp = pitcher_rows["pitch_type"]
ratios = pp.value_counts()
# Divide each count by the number of pitch_type values for this pitcher.
total_pitches = pp.count()
ratios.with_columns(count=polars.col("count") / total_pitches)

pitch_type,count
i64,f64
14,0.176471
1,0.365472
9,0.280543
15,0.034807
11,0.000348
7,0.14236
