In [1]:
import yaml
import torch
import lightning.pytorch as pl
import polars

from seqpred.data import prep_data, BaseDataset, load_morphers
from seqpred.quantile_morpher import Quantiler, Integerizer
from seqpred.nn import MargeNet

with open("cfg/config.yaml") as f:
    config = yaml.load(f, Loader=yaml.CLoader)

inputs = {col: tp for [col, tp] in config["features"]}

In [2]:
morpher_dispatch = {
    "numeric": Quantiler,
    "categorical": Integerizer,
}

morphers = load_morphers(
    "model/morphers.yaml",
    inputs,
    morpher_dispatch,
)

net = MargeNet.load_from_checkpoint("model/latest.ckpt")
gen_head = net.generator_head

# Set up data
base_data, _ = prep_data(
    data_files=[config["train_data_path"]],
    key_cols=config["keys"],
    morphers=morphers,
)
ds = BaseDataset(
    base_data,
    morphers,
    key_cols=config["keys"],
    return_keys=False,
)

In [8]:
i = 10
n = 256
pitcher_id = base_data.row(i, named=True)["pitcher"]
stand_val = base_data.row(i, named=True)["stand"]

pitcher = ds[i]["pitcher"].unsqueeze(0).expand(n)
stand = ds[i]["stand"].unsqueeze(0).expand(n)

print(f"Context: {pitcher_id} | {stand_val}")

with torch.inference_mode():
    gen_pitches = net.generate(
        {"pitcher": pitcher, "stand": stand},
        morphers=morphers,
    )
    

generated_df = polars.DataFrame(
    {
        feat: v.view(-1).numpy()
        for feat, v in gen_pitches.items()
    }
)

Context: 330 | 1


In [9]:
gen_dist = generated_df.group_by(["pitch_type"]).agg(
    count=polars.col("release_speed").count() / len(generated_df),
    **{
        f"avg_{feat}": polars.col(feat).mean() 
        for feat in morphers if feat not in config["context_features"]
    }
).sort("pitch_type")

pp = base_data.filter(
    polars.col("pitcher") == pitcher_id,
    polars.col("stand") == stand_val,
)
real_dist = pp.group_by(["pitch_type"]).agg(
    count=polars.col("release_speed").count() / len(pp),
    **{
        f"avg_{feat}": polars.col(feat).mean() 
        for feat in morphers if feat not in config["context_features"]
    }
).sort("pitch_type")

display(gen_dist)
display(real_dist)

pitch_type,count,avg_pitch_type,avg_release_speed,avg_pfx_x,avg_pfx_z,avg_plate_x,avg_plate_z
i64,f64,f64,f32,f32,f32,f32,f32
7,0.089844,7.0,0.4375,0.558424,0.246943,0.628397,0.382473
12,0.324219,12.0,0.912556,0.383095,0.924322,0.460655,0.689006
16,0.5859375,16.0,0.290885,0.83724,0.035052,0.52125,0.425677


pitch_type,count,avg_pitch_type,avg_release_speed,avg_pfx_x,avg_pfx_z,avg_plate_x,avg_plate_z
i64,f64,f64,f32,f32,f32,f32,f32
7,0.146402,7.0,0.424391,0.560117,0.250795,0.633475,0.448226
12,0.42928,12.0,0.920114,0.382316,0.926572,0.465544,0.656702
16,0.424318,16.0,0.272432,0.847496,0.035499,0.497761,0.432246
