In [1]:
import polars as pl

from src.processing import Stringifier, TimeTokenizer, make_vocabulary

In [2]:
data_path = "./data/raw_data.parquet"


# Config info ----------------------------
sequence_keys = ["game_pk", "player_name", "at_bat_number"]
order_columns = ["pitch_number"]
time_column = "inning"

feature_column_metadata = {
    "pitch_name": {
        "data_type": "categorical",
        "group": "pitch",
        "explicit_missing": True,
    },
    "release_speed": {
        "data_type": "numeric",
        "group": "velo",
        "explicit_missing": True,
    },
    "plate_x": {
        "data_type": "numeric",
        "group": "plate_x",
        "explicit_missing": True,
    },
    "plate_z": {
        "data_type": "numeric",
        "group": "plate_z",
        "explicit_missing": True,
    },
    "description": {
        "data_type": "categorical",
        "group": "description",
        "explicit_missing": True,
    },
    "events": {
        "data_type": "categorical",
        "group": "events",
        "explicit_missing": False,
    },
}

keyword_args = {
    "n_buckets": 32,
}
# -----------------------------------------

df = (
    pl.read_parquet(data_path)
    .filter(pl.col("game_type").is_in(["R", "F", "D", "L", "W"]))
    .select(
        # Keys and order
        *sequence_keys,
        *order_columns,
        # I need to transform the time column, so...
        (
            pl.col("inning")
            + pl.when(pl.col("inning_topbot") == "Top").then(0).otherwise(0.5)
        ).alias("inning"),
        # features
        *[column for column in feature_column_metadata],
    )
)

# A little more setup ----------------------------------

# This is nasty I'd rather not do it.
df = df.with_columns(
    (pl.col(time_column) - pl.col(time_column).shift(1))
    .over(partition_by=sequence_keys, order_by=order_columns)
    .alias("time_diffs")
)

time_tokenizer = TimeTokenizer.from_data(df["time_diffs"])

stringifiers = {
    column: Stringifier.from_data(df[column], **col_args, kwargs=keyword_args)
    for column, col_args in feature_column_metadata.items()
}

complete_vocab = make_vocabulary(
    stringifiers.values(), time_tokenizer, special_tokens=["<SOS>", "<EOS>"]
)
print(f"Vocab size: {len(complete_vocab)}")

df = (
    df.select(
        *sequence_keys,
        *order_columns,
        time_tokenizer.transform(pl.col("time_diffs")).alias("time_diffs"),
        *[s.transform(pl.col(n)).alias(n) for n, s in stringifiers.items()],
    )
    .with_columns(
        pl.concat_list("time_diffs", *[pl.col(n) for n in stringifiers])
        .list.drop_nulls()
        .alias("feature_list")
    )
    .select(*sequence_keys, *order_columns, "feature_list")
    .explode("feature_list")
    .with_columns(
        pl.col("feature_list")
        .replace(complete_vocab)
        .cast(pl.Int64)
        .alias("processed_list"),
    )
)

df = (
    df.sort(*order_columns)
    .group_by(*sequence_keys)  # Within-group order is always kept
    .agg("processed_list", "feature_list")
)

# Add SOS and EOS I guess?
df = df.with_columns(
    pl.concat_list(
        pl.lit(complete_vocab["<SOS>"]),
        pl.col("processed_list"),
        pl.lit(complete_vocab["<EOS>"]),
    ).alias("processed_list"),
    pl.col("processed_list").list.len().alias("sequence_length"),
)

print(f"Max length: {df["sequence_length"].max()}")
print(df)

Vocab size: 166
Max length: 96
shape: (1_004_698, 6)
┌─────────┬──────────────────┬───────────────┬────────────────┬──────────────────┬─────────────────┐
│ game_pk ┆ player_name      ┆ at_bat_number ┆ processed_list ┆ feature_list     ┆ sequence_length │
│ ---     ┆ ---              ┆ ---           ┆ ---            ┆ ---              ┆ ---             │
│ i64     ┆ str              ┆ i64           ┆ list[i64]      ┆ list[str]        ┆ u32             │
╞═════════╪══════════════════╪═══════════════╪════════════════╪══════════════════╪═════════════════╡
│ 662468  ┆ Farmer, Buck     ┆ 62            ┆ [1, 59, … 0]   ┆ ["pitch =        ┆ 21              │
│         ┆                  ┆               ┆                ┆ Slider", "velo = ┆                 │
│         ┆                  ┆               ┆                ┆ (83…             ┆                 │
│ 633843  ┆ Gilbert, Logan   ┆ 38            ┆ [1, 52, … 0]   ┆ ["pitch =        ┆ 31              │
│         ┆                  ┆        

In [10]:
a = pl.read_parquet(data_path)

label_mapping = {
    "catcher_interf": 0,
    "double": 0,
    "sac_bunt": 1,
    "triple_play": 1,
    "fielders_choice_out": 1,
    "fielders_choice": 1,
    "hit_by_pitch": 0,
    "truncated_pa": 0,
    "sac_fly": 1,
    "field_out": 1,
    "strikeout": 1,
    "single": 0,
    "double_play": 1,
    "triple": 0,
    "force_out": 1,
    "field_error": 0,
    "grounded_into_double_play": 1,
    "home_run": 0,
    "walk": 0,
    "sac_fly_double_play": 1,
    "None": 0,
    "strikeout_double_play": 1,
    "intent_walk": 0,
    "sac_bunt_double_play": 1,
}

a = a.with_columns(
    pl.col("events").replace_strict(label_mapping, return_dtype=pl.Int64)
)

a = a.with_columns(
    pl.col("events").max().over(partition_by=["game_pk", "at_bat_number"]).alias("is_out")
)

a.select(
    "game_pk",
    "at_bat_number",
    "pitch_number",
    "is_out",
)

game_pk,at_bat_number,pitch_number,is_out
i64,i64,i64,i64
775296,89,4,1
775296,89,3,1
775296,89,2,1
775296,89,1,1
775296,88,7,1
…,…,…,…
564927,3,3,1
564927,3,2,1
564927,3,1,1
564927,2,1,1


In [8]:
pl.read_parquet(data_path).filter(pl.col("events") == "sac_bunt_double_play")

pitch_type,game_date,release_speed,release_pos_x,release_pos_z,player_name,batter,pitcher,events,description,spin_dir,spin_rate_deprecated,break_angle_deprecated,break_length_deprecated,zone,des,game_type,stand,p_throws,home_team,away_team,type,hit_location,bb_type,balls,strikes,game_year,pfx_x,pfx_z,plate_x,plate_z,on_3b,on_2b,on_1b,outs_when_up,inning,inning_topbot,…,post_away_score,post_home_score,post_bat_score,post_fld_score,if_fielding_alignment,of_fielding_alignment,spin_axis,delta_home_win_exp,delta_run_exp,bat_speed,swing_length,estimated_slg_using_speedangle,delta_pitcher_run_exp,hyper_speed,home_score_diff,bat_score_diff,home_win_exp,bat_win_exp,age_pit_legacy,age_bat_legacy,age_pit,age_bat,n_thruorder_pitcher,n_priorpa_thisgame_player_at_bat,pitcher_days_since_prev_game,batter_days_since_prev_game,pitcher_days_until_next_game,batter_days_until_next_game,api_break_z_with_gravity,api_break_x_arm,api_break_x_batter_in,arm_angle,attack_angle,attack_direction,swing_path_tilt,intercept_ball_minus_batter_pos_x_inches,intercept_ball_minus_batter_pos_y_inches
str,datetime[ns],f64,f64,f64,str,i64,i64,str,str,i64,i64,i64,i64,i64,str,str,str,str,str,str,str,i64,str,i64,i64,i64,f64,f64,f64,f64,i64,i64,i64,i64,i64,str,…,i64,i64,i64,i64,str,str,i64,f64,f64,f64,f64,f64,f64,f64,i64,i64,f64,f64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""SI""",2022-10-23 00:00:00,95.3,-0.83,6.37,"""Holmes, Clay""",665161,605280,"""sac_bunt_double_play""","""hit_into_play""",,,,,11,"""Jeremy Pena ground bunts into …","""L""","""R""","""R""","""NYY""","""HOU""","""X""",1,"""ground_ball""",0,0,2022,-1.68,0.42,-1.54,3.12,,,514888,0,9,"""Top""",…,6,5,6,5,"""Standard""","""Standard""",231,0.053,0.015,,,0.296,-0.015,88.0,-1,1,0.140575,0.859425,29,24,29,25,1,4,5.0,1.0,,5,2.11,1.68,1.68,43.4,,,,,
"""FF""",2021-08-29 00:00:00,90.2,-1.99,5.22,"""Mahle, Tyler""",666200,641816,"""sac_bunt_double_play""","""hit_into_play""",,,,,1,"""Jesus Luzardo ground bunts int…","""R""","""L""","""R""","""MIA""","""CIN""","""X""",5,"""ground_ball""",0,0,2021,-1.18,1.31,-0.49,2.87,,,642423,0,5,"""Bot""",…,0,2,2,0,"""Strategic""","""Strategic""",226,-0.048,-0.109,,,0.194,0.109,88.0,2,2,0.834329,0.834329,26,23,27,24,2,1,5.0,5.0,6.0,5,1.51,1.18,-1.18,35.8,,,,,
"""SI""",2021-04-30 00:00:00,90.6,-2.79,6.02,"""Gant, John""",664141,607231,"""sac_bunt_double_play""","""hit_into_play""",,,,,1,"""JT Brubaker ground bunts into …","""R""","""R""","""R""","""PIT""","""STL""","""X""",1,"""ground_ball""",0,1,2021,-1.25,1.19,-0.31,3.0,,,664789,0,3,"""Bot""",…,2,0,0,2,"""Strategic""","""Strategic""",224,-0.09,-0.072,,,0.242,0.072,88.0,-2,-2,0.372098,0.372098,28,27,29,28,1,0,6.0,14.0,6.0,5,1.6,1.25,1.25,45.3,,,,,
"""SI""",2019-08-10 00:00:00,93.5,2.01,5.72,"""Liriano, Francisco""",572761,434538,"""sac_bunt_double_play""","""hit_into_play""",,,,,7,"""Matt Carpenter ground bunts in…","""R""","""L""","""L""","""STL""","""PIT""","""X""",2,"""ground_ball""",0,0,2019,1.23,0.96,-0.44,2.07,,542303.0,657557,1,6,"""Bot""",…,1,3,3,1,"""Infield shift""","""Standard""",128,-0.051,-0.2,,,,0.2,,2,2,0.866162,0.866162,35,33,36,34,1,2,4.0,1.0,3.0,1,1.67,1.23,1.23,,,,,,
"""SL""",2019-08-01 00:00:00,89.8,1.28,6.38,"""Kershaw, Clayton""",622534,477132,"""sac_bunt_double_play""","""hit_into_play""",,,,,2,"""Manuel Margot ground bunts int…","""R""","""R""","""L""","""LAD""","""SD""","""X""",5,"""ground_ball""",1,0,2019,-0.19,0.85,-0.01,3.38,,,665487,0,1,"""Top""",…,0,0,0,0,"""Standard""","""Standard""",193,0.074,-0.136,,,0.253,0.136,88.0,0,0,0.464156,0.535844,31,24,31,25,1,0,5.0,2.0,5.0,1,1.97,-0.19,0.19,,,,,,
"""CU""",2019-04-01 00:00:00,74.6,-1.67,6.23,"""Wainwright, Adam""",502042,425794,"""sac_bunt_double_play""","""hit_into_play""",,,,,12,"""Chris Archer ground bunts into…","""R""","""R""","""R""","""PIT""","""STL""","""X""",2,"""ground_ball""",1,2,2019,1.7,-0.96,0.92,2.63,,,570481,0,4,"""Bot""",…,0,4,4,0,"""Strategic""","""Strategic""",61,-0.023,-0.016,,,,0.016,,4,4,0.922565,0.922565,37,30,38,31,2,1,,,6.0,6,5.24,-1.7,-1.7,,,,,,


In [3]:
df[0]

game_pk,player_name,at_bat_number,processed_list,feature_list,sequence_length
i64,str,i64,list[i64],list[str],u32
717517,"""Willingham, Amos""",70,"[130, 117, … 17]","[""pitch = Slider"", ""velo = (84.4. 85.1]"", … ""events = field_out""]",21


In [4]:
df[0]["feature_list"].item().to_list()

['pitch = Slider',
 'velo = (84.4. 85.1]',
 'plate_x = (0.39. 0.46]',
 'plate_z = (-inf, 0.43]',
 'description = ball',
 'pitch = 4-Seam Fastball',
 'velo = (95.2. 95.7]',
 'plate_x = (0.07. 0.13]',
 'plate_z = (3.04. 3.15]',
 'description = called_strike',
 'pitch = 4-Seam Fastball',
 'velo = (95.7. 96.2]',
 'plate_x = (-0.93. -0.82]',
 'plate_z = (2.77. 2.86]',
 'description = foul',
 'pitch = 4-Seam Fastball',
 'velo = (95.7. 96.2]',
 'plate_x = (-0.39. -0.32]',
 'plate_z = (2.86. 2.95]',
 'description = hit_into_play',
 'events = field_out']

In [5]:
df[0]["processed_list"].item().to_list()

[130,
 117,
 57,
 33,
 50,
 118,
 2,
 158,
 116,
 77,
 138,
 2,
 156,
 142,
 53,
 151,
 2,
 156,
 139,
 102,
 39,
 134,
 17]

In [6]:
print(complete_vocab["<SOS>"])
print(complete_vocab["<EOS>"])

130
17
