In [2]:
import yaml
import torch
import lightning.pytorch as pl
import polars

from seqpred.data import prep_data, BaseDataset, load_morphers
from seqpred.quantile_morpher import Quantiler, Integerizer
from seqpred.nn import MargeNet

# Load the experiment config. yaml.CSafeLoader is libyaml-backed like CLoader
# but only constructs plain YAML types (dicts, lists, scalars), so the config
# file cannot instantiate arbitrary Python objects via !!python tags.
with open("cfg/config.yaml", encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.CSafeLoader)

# Feature column name -> feature type, built from the config's list of
# [column, type] pairs. The type is presumably one of the morpher_dispatch
# keys used when loading morphers.
inputs = {col: tp for col, tp in config["features"]}

In [16]:
def choose_options(p: torch.Tensor, generator: "torch.Generator | None" = None) -> torch.Tensor:
    """Draw one categorical sample per row of a probability matrix.

    Uses inverse-CDF sampling: each row of ``p`` is accumulated into a CDF and a
    uniform draw is located in it; the chosen option is the first CDF entry that
    is >= the draw.

    Args:
        p: ``(n, k)`` tensor whose rows are (approximately) normalised
            probabilities over ``k`` options, e.g. the output of ``softmax``.
        generator: optional CPU ``torch.Generator`` for reproducible draws;
            defaults to the global torch RNG (original behaviour).

    Returns:
        ``(n,)`` int64 tensor of chosen option indices in ``[0, k - 1]``.
    """
    # (n, k) running totals; non-decreasing along dim 1 for non-negative p.
    cdf = p.cumsum(dim=1)
    # One uniform draw per row, moved to the CDF's device and dtype.
    rand = torch.rand(cdf.shape[0], 1, generator=generator).to(cdf)
    # Number of CDF entries strictly below the draw == index of the first entry
    # >= the draw. Same result as ranking the draw among the CDF via a double
    # argsort, without sorting every row.
    choices = torch.searchsorted(cdf, rand).squeeze(1)
    # Float rounding can leave cdf[:, -1] slightly below 1; a draw above it would
    # produce index k (out of range for a k-way embedding), so clamp to k - 1.
    return choices.clamp(max=cdf.shape[1] - 1)

In [4]:
# Morpher class to use for each feature type declared in the config.
morpher_dispatch = {
    "numeric": Quantiler,
    "categorical": Integerizer,
}

# Load the saved morphers (presumably fitted during training) for the
# configured input features.
morphers = load_morphers(
    "model/morphers.yaml",
    inputs,
    morpher_dispatch,
)

# map_location="cpu": the sampling code builds CPU input tensors and calls
# .numpy() on the results, but the checkpoint may have been saved from a GPU run.
net = MargeNet.load_from_checkpoint("model/latest.ckpt", map_location="cpu")
# torch.inference_mode() does not switch modules to eval mode; do it explicitly
# so any dropout / batch-norm layers behave deterministically while sampling.
net.eval()
gen_head = net.generator_head

# Set up data: the training file, prepared with the same morphers, plus a
# dataset wrapper used to fetch model-ready feature tensors by row index.
base_data, _ = prep_data(
    data_files=[config["train_data_path"]],
    key_cols=config["keys"],
    morphers=morphers,
)
ds = BaseDataset(
    base_data,
    morphers,
    key_cols=config["keys"],
    return_keys=False,
)

In [53]:
i = 100_000  # training-row index whose pitcher we condition on
n = 256  # number of synthetic pitches to draw
SEED = 0
# NOTE(review): assumed number of release_speed bins produced by the morpher;
# generated bin indices are divided by this to compare with the real values.
VELO_BINS = 128

# choose_options draws from the global torch RNG; seed for reproducible samples.
torch.manual_seed(SEED)

pitcher_id = base_data.row(i, named=True)["pitcher"]
# Broadcast the single pitcher feature to n rows so each row yields an
# independent sample.
pitcher = ds[i]["pitcher"].unsqueeze(0).expand(n)

print(pitcher_id)

with torch.inference_mode():
    # Stage 1: sample a pitch type conditioned on the pitcher embedding alone.
    pitcher_embed = net.init_embedder(pitcher)
    x = gen_head.activation(gen_head.norm(pitcher_embed))

    pitch_dist = torch.softmax(gen_head.predictors["pitch_type"](x), dim=-1)
    pitches = choose_options(pitch_dist)
    pitch_embeddings = gen_head.embedders["pitch_type"](pitches)

    # Stage 2: sample a release-speed bin conditioned on pitcher + sampled pitch.
    combined_embeddings = pitcher_embed + pitch_embeddings
    x = gen_head.activation(gen_head.norm(combined_embeddings))
    velo_dist = torch.softmax(gen_head.predictors["release_speed"](x), dim=-1)
    velo = choose_options(velo_dist)

# reshape (not view) tolerates non-contiguous results; .cpu() lets .numpy()
# work even if the model was moved to a GPU.
generated_df = polars.DataFrame(
    {
        "pitch": pitches.reshape(-1).cpu().numpy(),
        "velo": velo.reshape(-1).cpu().numpy() / VELO_BINS,
    }
)

375


In [54]:
def summarize_speeds(df: polars.DataFrame, group_col: str, speed_col: str) -> polars.DataFrame:
    """Summarise ``speed_col`` per value of ``group_col``.

    Returns one row per group, sorted by ``group_col``, with:
      - ``count``: non-null ``speed_col`` values in the group divided by the
        total number of rows in ``df`` (i.e. the group's share of rows),
      - ``avg_speed`` / ``min_speed`` / ``max_speed``: mean, min and max of
        ``speed_col`` within the group.
    """
    speed = polars.col(speed_col)
    return (
        df.group_by([group_col])
        .agg(
            count=speed.count() / len(df),
            avg_speed=speed.mean(),
            min_speed=speed.min(),
            max_speed=speed.max(),
        )
        .sort(group_col)
    )


# Generated samples for the chosen pitcher.
gen_dist = summarize_speeds(generated_df, "pitch", "velo")

# Real pitches thrown by the same pitcher in the prepared training data.
pp = base_data.filter(polars.col("pitcher") == pitcher_id)
real_dist = summarize_speeds(pp, "pitch_type", "release_speed")

display(gen_dist)
display(real_dist)

pitch,count,avg_speed,min_speed,max_speed
i64,f64,f64,f64,f64
1,0.316406,0.800733,0.03125,0.9765625
4,0.125,0.752441,0.015625,0.9765625
7,0.128906,0.083333,0.0078125,0.9609375
10,0.2421875,0.467868,0.015625,0.9765625
16,0.1875,0.246582,0.0078125,0.9765625


pitch_type,count,avg_speed,min_speed,max_speed
i64,f64,f32,f32,f32
1,0.347468,0.818495,0.609375,0.96875
4,0.107993,0.791578,0.578125,0.9375
7,0.11989,0.160862,0.046875,0.3203125
10,0.211104,0.434745,0.296875,0.5859375
16,0.213545,0.245547,0.0703125,0.453125
