In [1]:
import polars as pl

from src.processing import Stringifier, TimeTokenizer, make_vocabulary

In [3]:
data_path = "./data/raw_data.parquet"


# Config info ----------------------------
sequence_keys = ["game_pk"]
order_columns = ["at_bat_number", "pitch_number"]
time_column = "inning"

feature_column_metadata = [
    ["pitch_name", "categorical", "pitch"],
    ["release_speed", "numeric", "velo"],
]

keyword_args = {
    "n_buckets": 32,
}
# -----------------------------------------

df = (
    pl.read_parquet(data_path)
    .select(
        # Keys and order
        *sequence_keys,
        *order_columns,
        # I need to transform the time column, so...
        (pl.col("inning") + pl.when(pl.col("inning_topbot") == "Top").then(0).otherwise(0.5)).alias("inning"),
        # features
        *[col for col, _, _ in feature_column_metadata],
    )
)

# A little more setup ----------------------------------

# This is nasty I'd rather not do it.
df = df.with_columns(
        (pl.col(time_column) - pl.col(time_column).shift(1))
        .over(partition_by=sequence_keys, order_by=order_columns)
        .alias("time_diffs")
    )

time_tokenizer = TimeTokenizer.from_data(df["time_diffs"])

stringifiers = {
    col_name: Stringifier.from_data(df[col_name], data_type, group, kwargs=keyword_args)
    for col_name, data_type, group in feature_column_metadata
}

complete_vocab = make_vocabulary(stringifiers.values(), time_tokenizer)
print(len(complete_vocab))

df =(
    df.select(
        *sequence_keys,
        *order_columns,
        time_tokenizer.transform(pl.col("time_diffs")).alias("time_diffs"),
        *[s.transform(pl.col(n)).alias(n) for n, s in stringifiers.items()]
    )
    .with_columns(
        pl.concat_list("time_diffs", *[pl.col(n) for n in stringifiers]).list.drop_nulls().alias("feature_list")
    )
    .select(
        *sequence_keys,
        *order_columns,
        "feature_list"
    )
    .explode("feature_list")
    .with_columns(pl.col("feature_list").replace(complete_vocab).alias("processed_list"))
)

df.group_by(*sequence_keys).agg("processed_list", "feature_list")


n_buckets=32
53


game_pk,processed_list,feature_list
i64,list[str],list[str]
716764,"[""34"", ""27"", … ""31""]","[""pitch = Slider"", ""velo = 89.5-90.2"", … ""velo = 90.8-91.4""]"
717139,"[""24"", ""15"", … ""5""]","[""pitch = Split-Finger"", ""velo = 88.8-89.5"", … ""velo = 96.8-97.9""]"
716359,"[""36"", ""41"", … ""20""]","[""pitch = Eephus"", ""velo = (-inf)-77.2"", … ""velo = 82.4-83.1""]"
717812,"[""44"", ""28"", … ""43""]","[""pitch = Changeup"", ""velo = 86.9-87.5"", … ""velo = 95.7-96.2""]"
718330,"[""42"", ""3"", … ""10""]","[""pitch = 4-Seam Fastball"", ""velo = 91.4-91.9"", … ""velo = 92.4-92.8""]"
…,…,…
718425,"[""34"", ""31"", … ""32""]","[""pitch = Slider"", ""velo = 90.8-91.4"", … ""velo = 94.8-95.2""]"
718044,"[""34"", ""11"", … ""45""]","[""pitch = Slider"", ""velo = 87.5-88.2"", … ""velo = 95.2-95.7""]"
717642,"[""33"", ""25"", … ""8""]","[""pitch = Curveball"", ""velo = 83.1-83.8"", … ""velo = 93.2-93.6""]"
718297,"[""33"", ""29"", … ""32""]","[""pitch = Curveball"", ""velo = 81.6-82.4"", … ""velo = 94.8-95.2""]"
