In [1]:
import yaml
import torch
import lightning.pytorch as pl
import polars

from seqpred.data import prep_data, BaseDataset, load_morphers
from seqpred.special_morpher import Quantiler, Integerizer, MixtureLossNormalizer
from seqpred.nn import MargeNet

with open("cfg/config.yaml") as f:
    config = yaml.load(f, Loader=yaml.CLoader)

inputs = {col: tp for [col, tp] in config["features"]}

In [2]:
morpher_dispatch = {
    "numeric": Quantiler,
    "categorical": Integerizer,
}

morphers = load_morphers(
    "model/morphers.yaml",
    inputs,
    morpher_dispatch,
)

net = MargeNet.load_from_checkpoint("model/latest.ckpt")
gen_head = net.generator_head

# Set up data
base_data, _ = prep_data(
    data_files=[config["train_data_path"]],
    key_cols=config["keys"],
    morphers=morphers,
)
ds = BaseDataset(
    base_data,
    morphers,
    key_cols=config["keys"],
    return_keys=False,
)

In [3]:
i = 800
n = 256
pitcher_id = base_data.row(i, named=True)["pitcher"]
stand_val = base_data.row(i, named=True)["stand"]

pitcher = ds[i]["pitcher"].unsqueeze(0).expand(n)
stand = ds[i]["stand"].unsqueeze(0).expand(n)

print(f"Context: {pitcher_id} | {stand_val}")

with torch.inference_mode():
    gen_pitches = net.generate(
        {"pitcher": pitcher, "stand": stand},
        temperature=1.0,
    )
    
generated_df = polars.DataFrame(
    {
        feat: v.view(-1).numpy()
        for feat, v in gen_pitches.items()
    }
)

Context: 1202 | 0


In [4]:
gen_dist = generated_df.group_by(["pitch_type"]).agg(
    count=polars.col("release_speed").count() / len(generated_df),
    **{
        f"avg_{feat}": polars.col(feat).mean() 
        for feat in morphers if feat not in config["context_features"]
    }
).sort("pitch_type")

pp = base_data.filter(
    polars.col("pitcher") == pitcher_id,
    polars.col("stand") == stand_val,
)
real_dist = pp.group_by(["pitch_type"]).agg(
    count=polars.col("release_speed").count() / len(pp),
    **{
        f"avg_{feat}": polars.col(feat).mean() 
        for feat in morphers if feat not in config["context_features"]
    }
).sort("pitch_type")

display(gen_dist)
display(real_dist)

pitch_type,count,avg_pitch_type,avg_release_speed,avg_pfx_x,avg_pfx_z,avg_plate_x,avg_plate_z
i64,f64,f64,f32,f32,f32,f32,f32
1,0.175781,1.0,0.336458,0.091319,0.336458,0.349479,0.319792
7,0.503906,7.0,0.722565,0.37718,0.794937,0.48274,0.65298
8,0.105469,8.0,0.054398,0.872975,0.061632,0.542245,0.324074
10,0.1640625,10.0,0.209635,0.84654,0.429315,0.62314,0.34933
12,0.050781,12.0,0.655048,0.105168,0.46875,0.30649,0.433293


pitch_type,count,avg_pitch_type,avg_release_speed,avg_pfx_x,avg_pfx_z,avg_plate_x,avg_plate_z
i64,f64,f64,f32,f32,f32,f32,f32
1,0.214363,1.0,0.341926,0.091569,0.300722,0.367663,0.346804
7,0.460283,7.0,0.715407,0.378712,0.777962,0.459035,0.646554
8,0.092492,8.0,0.051287,0.874632,0.061857,0.532812,0.379687
10,0.150163,10.0,0.202389,0.849864,0.402683,0.506737,0.356261
12,0.082699,12.0,0.676398,0.09817,0.462788,0.285156,0.404811


In [5]:
gen_dist = generated_df.group_by(["pitch_type"]).agg(
    count=polars.col("release_speed").count() / len(generated_df),
    avg_speed=polars.col("release_speed").mean(),
    min_speed=polars.col("release_speed").min(),
    speed_25=polars.col("release_speed").quantile(0.25),
    speed_75=polars.col("release_speed").quantile(0.75),
    max_speed=polars.col("release_speed").max(),
).sort("pitch_type")

real_dist = pp.group_by(["pitch_type"]).agg(
    count=polars.col("release_speed").count() / len(pp),
    avg_speed=polars.col("release_speed").mean(),
    min_speed=polars.col("release_speed").min(),
    speed_25=polars.col("release_speed").quantile(0.25),
    speed_75=polars.col("release_speed").quantile(0.75),
    max_speed=polars.col("release_speed").max(),
).sort("pitch_type")

display(gen_dist)
display(real_dist)

pitch_type,count,avg_speed,min_speed,speed_25,speed_75,max_speed
i64,f64,f32,f32,f32,f32,f32
1,0.175781,0.336458,0.1484375,0.3125,0.375,0.4453125
7,0.503906,0.722565,0.546875,0.6640625,0.765625,0.921875
8,0.105469,0.054398,0.015625,0.0390625,0.0703125,0.1484375
10,0.1640625,0.209635,0.0703125,0.1796875,0.25,0.3046875
12,0.050781,0.655048,0.484375,0.640625,0.6953125,0.7109375


pitch_type,count,avg_speed,min_speed,speed_25,speed_75,max_speed
i64,f64,f32,f32,f32,f32,f32
1,0.214363,0.341926,0.1796875,0.3046875,0.3828125,0.5234375
7,0.460283,0.715407,0.5546875,0.671875,0.7578125,0.8828125
8,0.092492,0.051287,0.015625,0.0390625,0.0625,0.1171875
10,0.150163,0.202389,0.046875,0.1640625,0.2421875,0.3359375
12,0.082699,0.676398,0.5078125,0.640625,0.71875,0.8046875
