In [1]:
import polars as pl

from src.processing import Stringifier, TimeTokenizer, make_vocabulary

In [2]:
data_path = "./data/raw_data.parquet"


# Config info ----------------------------
sequence_keys = ["game_pk"]
order_columns = ["at_bat_number", "pitch_number"]
time_column = "inning"

feature_column_metadata = [
    ["pitch_name", "categorical", "pitch"],
    ["release_speed", "numeric", "velo"],
]

keyword_args = {
    "n_buckets": 32,
}
# -----------------------------------------

df = pl.read_parquet(data_path).select(
    # Keys and order
    *sequence_keys,
    *order_columns,
    # I need to transform the time column, so...
    (
        pl.col("inning")
        + pl.when(pl.col("inning_topbot") == "Top").then(0).otherwise(0.5)
    ).alias("inning"),
    # features
    *[col for col, _, _ in feature_column_metadata],
)

# A little more setup ----------------------------------

# This is nasty I'd rather not do it.
df = df.with_columns(
    (pl.col(time_column) - pl.col(time_column).shift(1))
    .over(partition_by=sequence_keys, order_by=order_columns)
    .alias("time_diffs")
)

time_tokenizer = TimeTokenizer.from_data(df["time_diffs"])

stringifiers = {
    col_name: Stringifier.from_data(df[col_name], data_type, group, kwargs=keyword_args)
    for col_name, data_type, group in feature_column_metadata
}

complete_vocab = make_vocabulary(stringifiers.values(), time_tokenizer)
print(len(complete_vocab))

df = (
    df.select(
        *sequence_keys,
        *order_columns,
        time_tokenizer.transform(pl.col("time_diffs")).alias("time_diffs"),
        *[s.transform(pl.col(n)).alias(n) for n, s in stringifiers.items()],
    )
    .with_columns(
        pl.concat_list("time_diffs", *[pl.col(n) for n in stringifiers])
        .list.drop_nulls()
        .alias("feature_list")
    )
    .select(*sequence_keys, *order_columns, "feature_list")
    .explode("feature_list")
    .with_columns(
        pl.col("feature_list")
        .replace(complete_vocab)
        .cast(pl.Int64)
        .alias("processed_list"),
    )
)

df = (
    df.group_by(*sequence_keys)
    .agg("processed_list", "feature_list")
    .with_columns(pl.col("processed_list").list.len().alias("sequence_length"))
)
print(df["sequence_length"].max())
print(df)

n_buckets=32
53
903
shape: (2_666, 4)
┌─────────┬──────────────────────┬─────────────────────────────────┬─────────────────┐
│ game_pk ┆ processed_list       ┆ feature_list                    ┆ sequence_length │
│ ---     ┆ ---                  ┆ ---                             ┆ ---             │
│ i64     ┆ list[i64]            ┆ list[str]                       ┆ u32             │
╞═════════╪══════════════════════╪═════════════════════════════════╪═════════════════╡
│ 718729  ┆ [52, 7, … 17]        ┆ ["pitch = 4-Seam Fastball", "v… ┆ 607             │
│ 716487  ┆ [30, 4, … 44]        ┆ ["pitch = Curveball", "velo = … ┆ 780             │
│ 717946  ┆ [16, 34, … 42]       ┆ ["pitch = Cutter", "velo = 97.… ┆ 728             │
│ 717797  ┆ [47, 41, … 31]       ┆ ["pitch = Slider", "velo = 88.… ┆ 627             │
│ 717404  ┆ [12, 41, … 2]        ┆ ["pitch = Changeup", "velo = 8… ┆ 573             │
│ …       ┆ …                    ┆ …                               ┆ …               │
│ 718