In [1]:
import yaml
import torch
import lightning.pytorch as pl
import polars

from seqpred.data import prep_data, BaseDataset, load_morphers
from seqpred.quantile_morpher import Quantiler, Integerizer
from seqpred.nn import MargeNet

# Experiment configuration: data paths, key columns and the feature list.
# yaml.CLoader (the libyaml-backed loader) only exists when PyYAML was built
# against libyaml; fall back to the pure-Python loader so the notebook still
# runs on installs without it.
# NOTE(review): this is a full (unsafe) loader, kept to preserve behavior. If
# cfg/config.yaml is plain data, yaml.CSafeLoader / yaml.SafeLoader would do.
_yaml_loader = getattr(yaml, "CLoader", yaml.Loader)
with open("cfg/config.yaml", encoding="utf-8") as f:
    config = yaml.load(f, Loader=_yaml_loader)

# Feature name -> feature type, built from the [name, type] pairs listed under
# `features` (the types are dispatched on by morpher_dispatch in the next cell).
inputs = {col: tp for [col, tp] in config["features"]}

In [2]:
# Feature type (as declared in the config) -> morpher class used to encode it.
morpher_dispatch = {
    "numeric": Quantiler,
    "categorical": Integerizer,
}

# Rebuild the morphers saved alongside the model for the configured inputs.
morphers = load_morphers(
    "model/morphers.yaml",
    inputs,
    morpher_dispatch,
)

# Load onto the CPU: Lightning otherwise restores tensors to the device they
# were saved from, and the generation cell feeds CPU tensors from `ds` and
# converts outputs with .numpy(), which only works on CPU tensors.
net = MargeNet.load_from_checkpoint("model/latest.ckpt", map_location="cpu")
# Inference only: put dropout / batch-norm style layers into eval mode.
net.eval()
# NOTE(review): not used in the cells shown here; kept for later cells.
gen_head = net.generator_head

# Set up data: prep_data returns a pair; only the first element (the frame
# built from the training file with these morphers) is used here.
base_data, _ = prep_data(
    data_files=[config["train_data_path"]],
    key_cols=config["keys"],
    morphers=morphers,
)
# Dataset view over the same frame; return_keys=False drops the key columns
# from the items it yields.
ds = BaseDataset(
    base_data,
    morphers,
    key_cols=config["keys"],
    return_keys=False,
)

In [5]:
# Pick one row of the data as the conditioning context and draw `n`
# generated pitches for it.
i = 500    # row index into base_data / ds (arbitrary example row)
n = 256    # number of samples to generate for that context
seed = 0   # fixes the sampling so the summary tables below are reproducible

# Read the context row once; these raw values are what the next cell filters on.
context_row = base_data.row(i, named=True)
pitcher_id = context_row["pitcher"]
stand_val = context_row["stand"]

# Repeat the encoded context n times to form a batch of size n.
# NOTE(review): assumes ds[i] corresponds to base_data row i and that these
# entries are 0-d tensors (required for unsqueeze(0).expand(n)) — confirm.
context_item = ds[i]
pitcher = context_item["pitcher"].unsqueeze(0).expand(n)
stand = context_item["stand"].unsqueeze(0).expand(n)

print(f"Context: {pitcher_id} | {stand_val}")

# Identical contexts yield different outputs, so generation samples; seed
# torch's RNG (assumes net.generate draws from torch's generator).
torch.manual_seed(seed)
with torch.inference_mode():
    gen_pitches = net.generate(
        {"pitcher": pitcher, "stand": stand},
        morphers=morphers,
    )

# One column per generated feature. reshape() (unlike view()) also accepts
# non-contiguous tensors, and .cpu() lets .numpy() work whatever device the
# model produced the tensors on.
generated_df = polars.DataFrame(
    {
        feat: values.reshape(-1).cpu().numpy()
        for feat, values in gen_pitches.items()
    }
)

Context: 799 | 0


In [6]:
def pitch_mix_summary(df: polars.DataFrame) -> polars.DataFrame:
    """Summarize a frame of pitches per pitch type, sorted by pitch type.

    Columns:
        count     -- non-null release_speed values in the group divided by the
                     total number of rows in ``df`` (a share, not a raw count)
        avg_speed -- mean release_speed
        avg_x     -- mean plate_x
        avg_z     -- mean plate_z

    NOTE(review): judging by the outputs, these values are in the morphed
    (encoded) space rather than raw units — confirm against prep_data.
    """
    return (
        df.group_by(["pitch_type"])
        .agg(
            count=polars.col("release_speed").count() / len(df),
            avg_speed=polars.col("release_speed").mean(),
            avg_x=polars.col("plate_x").mean(),
            avg_z=polars.col("plate_z").mean(),
        )
        .sort("pitch_type")
    )


# Generated pitches for the chosen context.
gen_dist = pitch_mix_summary(generated_df)

# Real pitches from the same pitcher against the same batter stance.
pp = base_data.filter(
    polars.col("pitcher") == pitcher_id,
    polars.col("stand") == stand_val,
)
real_dist = pitch_mix_summary(pp)

display(gen_dist, real_dist)

pitch_type,count,avg_speed,avg_x,avg_z
i64,f64,f32,f32,f32
0,0.003906,0.2578125,0.359375,0.0234375
2,0.46875,0.28457,0.391992,0.310221
4,0.3046875,0.595052,0.415865,0.721454
13,0.222656,0.408169,0.383635,0.469161


pitch_type,count,avg_speed,avg_x,avg_z
i64,f64,f32,f32,f32
2,0.537657,0.290522,0.338947,0.310494
4,0.278243,0.601151,0.406837,0.703125
13,0.1841,0.419389,0.42081,0.403232
