# Environment Buidling

## Imports and Config

We import our data modules, set paths, and define global constants for the season and tensor sizes.

In [13]:
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict, Counter

from pybaseball import statcast

pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)

import sys
import os

# Add the project root to the system path
# Assumes the notebook is run from a sub-directory, e.g. 'notebooks'
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

# Project imports
from src.data.statcast_loader import pull_season_by_month, load_raw_statcast
from src.data.pa_builder import build_pas, attach_pitcher_fatigue_features
from src.data.lineup_builder import infer_batting_order, add_lineup_window, build_positional_encodings
from src.data.batter_features import build_batter_feature_table, build_next_hitter_features
from src.data.bullpen_form import add_pitcher_form_features
from src.data.availability import build_availability
from src.data.reward import add_pa_delta_run_exp, annotate_terminals, fold_smdp_reward

# Paths (notebook in notebooks/, data at project_root/data)
DATA_DIR = Path("../data")
RAW_DIR = DATA_DIR / "raw"
PROC_DIR = DATA_DIR / "processed"
PROC_DIR.mkdir(parents=True, exist_ok=True)

# Configuration
YEARS = [2022, 2023]   # seasons to build
H = 5         # lineup window length (next H hitters)
R = 10        # max bullpen size per team we model
GAMMA = 0.99  # discount for SMDP reward folding
PENALTY_PULL = 0.005  # small penalty when using a pitching change

# Filename tag for multi-year datasets
if len(YEARS) == 1:
    YEAR_TAG = str(YEARS[0])
else:
    YEAR_TAG = f"{min(YEARS)}_{max(YEARS)}"

import warnings
warnings.filterwarnings("ignore")

## 2. Pull & load Statcast pitches

We fetch + cache the full season's Statcast data at pitch level. `statcast_loader` is a thin wrapper around `pybaseball.statcast` that saves Parquets by month.

In [14]:
# Helper to trim Statcast columns before saving
def trim_statcast_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the Statcast columns required by the bullpen RL pipeline.
    This reduces disk space and memory usage.
    """
    needed_cols = [
        # IDs / time
        "game_pk", "game_date", "game_year", "pitch_number", "at_bat_number",
        # players / teams
        "batter", "pitcher", "home_team", "away_team",
        "stand", "p_throws",
        # game state
        "inning", "inning_topbot",
        "outs_when_up", "balls", "strikes",
        "on_1b", "on_2b", "on_3b",
        "bat_score", "fld_score", "bat_score_diff", "home_score_diff",
        # pitch result / classification
        "events", "type", "description", "zone", "bb_type",
        # batted-ball metrics
        "launch_speed", "launch_angle", "hit_distance_sc",
        "woba_value", "estimated_woba_using_speedangle", "launch_speed_angle",
        # run expectancy change
        "delta_run_exp",
        # pitcher form
        "release_speed", "release_spin_rate",
        "pitcher_days_since_prev_game",
        # batter/pitcher context
        "n_priorpa_thisgame_player_at_bat",
    ]

    missing = [c for c in needed_cols if c not in df.columns]
    if missing:
        print(f"WARNING: missing expected Statcast columns: {missing}")

    keep_cols = [c for c in needed_cols if c in df.columns]
    return df[keep_cols].copy()

In [15]:
# Pull season by month (cached). Set force=True to re-download.
pull_season_by_month(YEARS, force=False)

# Load all monthly Parquets into a single DataFrame
pitches = load_raw_statcast(YEARS)
pitches = trim_statcast_columns(pitches)
print("Pitches shape:", pitches.shape)
pitches.head()

[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_03.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_04.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_05.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_06.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_07.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_08.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_09.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw/statcast_2022_10.parquet
[statcast_loader] Using cached /Users/ethanbobrik/Projects/MLB-Bullpen-Strategy/data/raw

Unnamed: 0,game_pk,game_date,game_year,pitch_number,at_bat_number,batter,pitcher,home_team,away_team,stand,p_throws,inning,inning_topbot,outs_when_up,balls,strikes,on_1b,on_2b,on_3b,bat_score,fld_score,bat_score_diff,home_score_diff,events,type,description,zone,bb_type,launch_speed,launch_angle,hit_distance_sc,woba_value,estimated_woba_using_speedangle,launch_speed_angle,delta_run_exp,release_speed,release_spin_rate,pitcher_days_since_prev_game,n_priorpa_thisgame_player_at_bat
0,661032,2022-04-26,2022,1,4,621493,663474,LAA,CLE,R,R,1,Bot,0,0,0,,,,0,0,0,0,,S,called_strike,8,,,,,,,,-0.038,90.6,2192,6,0
1,661032,2022-04-26,2022,2,4,621493,663474,LAA,CLE,R,R,1,Bot,0,0,1,,,,0,0,0,0,,S,called_strike,6,,,,,,,,-0.051,92.6,2209,6,0
2,661032,2022-04-26,2022,3,4,621493,663474,LAA,CLE,R,R,1,Bot,0,0,2,,,,0,0,0,0,,S,foul,1,,64.5,44.0,181.0,,,,0.0,93.0,2197,6,0
3,661032,2022-04-26,2022,4,4,621493,663474,LAA,CLE,R,R,1,Bot,0,0,2,,,,0,0,0,0,,S,foul,11,,72.2,56.0,187.0,,,,0.0,92.2,2188,6,0
4,661032,2022-04-26,2022,5,4,621493,663474,LAA,CLE,R,R,1,Bot,0,0,2,,,,0,0,0,0,field_out,X,hit_into_play,2,fly_ball,94.6,47.0,284.0,0.0,0.003,3.0,-0.157,78.8,2110,6,0


In [16]:
pitches.shape

(1547975, 39)

In [17]:
pitches.iloc[:,:].isna().sum()

game_pk                                   0
game_date                                 0
game_year                                 0
pitch_number                              0
at_bat_number                             0
batter                                    0
pitcher                                   0
home_team                                 0
away_team                                 0
stand                                     0
p_throws                                  0
inning                                    0
inning_topbot                             0
outs_when_up                              0
balls                                     0
strikes                                   0
on_1b                               1072512
on_2b                               1253385
on_3b                               1400999
bat_score                                 0
fld_score                                 0
bat_score_diff                            0
home_score_diff                 

## 3. Build plate-appearance (PA) table

We compress pitch-level data into one row per PA start, then attach pitcher fatigue features.

In [18]:
# 3.1 Build PA table (one row per PA start with game state)
pa = build_pas(pitches)
print("PA shape (before fatigue):", pa.shape)
pa.head()

PA shape (before fatigue): (407660, 50)


Unnamed: 0,game_pk,game_date,game_year,pitch_number,at_bat_number,batter,pitcher_on_mound,home_team,away_team,stand,p_throws,inning,inning_topbot,outs,balls,strikes,on_1b,on_2b,on_3b,bat_score,fld_score,bat_score_diff,home_score_diff,events,type,description,zone,bb_type,launch_speed,launch_angle,hit_distance_sc,woba_value,estimated_woba_using_speedangle,launch_speed_angle,delta_run_exp,release_speed,release_spin_rate,pitcher_days_since_prev_game,n_priorpa_thisgame_player_at_bat,half,base_state,batting_team,fielding_team,score_diff,batter_is_left,batter_is_right,batter_is_switch,pitcher_is_left,pitcher_is_right,is_platoon_advantage
0,661032,2022-04-26,2022,1,4,621493,663474,LAA,CLE,R,R,1,Bot,0,0,0,,,,0,0,0,0,field_out,S,called_strike,8,fly_ball,64.5,44.0,181.0,0.0,0.003,3.0,-0.038,90.6,2192,6,0,1,0,LAA,CLE,0,0,1,0,0,1,0
1,661032,2022-04-26,2022,1,5,660271,663474,LAA,CLE,L,R,1,Bot,1,0,0,,,,0,0,0,0,field_out,X,hit_into_play,8,ground_ball,108.1,-23.0,4.0,0.0,0.161,2.0,-0.246,92.5,2036,6,0,1,0,LAA,CLE,0,1,0,0,0,1,1
2,661032,2022-04-26,2022,1,6,545361,663474,LAA,CLE,R,R,1,Bot,2,0,0,,,,0,0,0,0,strikeout,S,called_strike,9,,72.7,17.0,168.0,0.0,0.0,,-0.043,91.6,2314,6,0,1,0,LAA,CLE,0,0,1,0,0,1,0
3,661032,2022-04-26,2022,1,1,664702,663776,LAA,CLE,R,L,1,Top,0,0,0,,,,0,0,0,0,strikeout,B,ball,14,,,,,0.0,0.0,,0.037,94.7,2105,7,0,0,0,CLE,LAA,0,0,1,0,1,0,1
4,661032,2022-04-26,2022,1,2,642708,663776,LAA,CLE,R,L,1,Top,1,0,0,,,,0,0,0,0,field_out,S,called_strike,3,ground_ball,69.5,-30.0,3.0,0.0,0.062,2.0,-0.038,95.3,2207,7,0,0,0,CLE,LAA,0,0,1,0,1,0,1


In [19]:
# 3.2 Attach pitcher fatigue features (pitch_count, TTO)
pa = attach_pitcher_fatigue_features(pa, pitches)
print("PA shape (after fatigue):", pa.shape)
pa[["game_pk","inning","half","at_bat_number","pitcher_on_mound","pitch_count","tto"]].head()

PA shape (after fatigue): (407660, 53)


Unnamed: 0,game_pk,inning,half,at_bat_number,pitcher_on_mound,pitch_count,tto
0,661032,1,0,1,663776,1,0
1,661032,1,0,2,663776,5,0
2,661032,1,0,3,663776,7,0
3,661032,1,1,4,663474,1,0
4,661032,1,1,5,663474,6,0


## 4. Infer lineup & add upcoming hitters

We infer lineup index per batter and, for each PA, compute the next H hitters in the actual sequence.

In [20]:
# 4.1 Infer a lineup index per batter within each (game_pk, batting_team)
pa = infer_batting_order(pa)

# 4.2 Add next_hitters_ids: list of next H batter IDs in the actual PA order
pa = add_lineup_window(pa, H=H)

pa[["game_pk","inning","half","batter","lineup_idx","next_hitters_ids"]].head()

Unnamed: 0,game_pk,inning,half,batter,lineup_idx,next_hitters_ids
0,661032,1,0,664702,0,"[642708, 608070, 614177, 680911, 640458]"
1,661032,1,0,642708,1,"[608070, 614177, 680911, 640458, 676391]"
2,661032,1,0,608070,2,"[614177, 680911, 640458, 676391, 595978]"
6,661032,2,0,614177,3,"[680911, 640458, 676391, 595978, 665926]"
7,661032,2,0,680911,4,"[640458, 676391, 595978, 665926, 664702]"


## 5. Build batter features and lineup tensor

From pitch-level Statcast, we aggregate per-batter plate discipline & quality-of-contact metrics. 

Then we map each PA's `next_hitters` into a [B,H, d_hit] lineup feature tensor.

In [21]:
# 5.1 Batter feature table from pitch-level data
batter_feat = build_batter_feature_table(pitches, min_pa=20)
print("Batter feature table shape:", batter_feat.shape)
batter_feat.head()

Batter feature table shape: (838, 18)


Unnamed: 0_level_0,pa_count,pitch_count,swing_rate,whiff_rate,contact_rate,chase_rate,z_swing_rate,z_contact_rate,ball_rate,called_strike_rate,hard_hit_rate,barrel_rate,xwoba_mean,bb_rate_pa,k_rate_pa,handed_L,handed_R,handed_S
batter,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
405395,366.0,1432.0,0.483939,0.189033,0.810967,0.185754,0.64018,0.899297,0.353352,0.162709,0.30427,0.010676,0.093077,0.101093,0.153005,0.0,1.0,0.0
408234,847.0,3196.0,0.490926,0.24028,0.75972,0.150188,0.703034,0.836547,0.356383,0.152065,0.198826,0.010906,0.076774,0.07438,0.214876,0.0,1.0,0.0
425877,290.0,1040.0,0.5875,0.198036,0.801964,0.204808,0.752363,0.829146,0.283654,0.128846,0.157143,0.018367,0.072735,0.024138,0.151724,0.0,1.0,0.0
429664,131.0,468.0,0.568376,0.199248,0.800752,0.222222,0.743119,0.895062,0.311966,0.119658,0.28169,0.018779,0.078412,0.045802,0.206107,1.0,0.0,0.0
435559,169.0,623.0,0.476726,0.185185,0.814815,0.134831,0.682692,0.896714,0.367576,0.155698,0.177686,0.008264,0.069436,0.106509,0.177515,0.0,1.0,0.0


In [22]:
# 5.2 Convert next_hitters_ids into lineup tensor [B, H, d_hit]
B = len(pa)
next_hitters_feats = build_next_hitter_features(
    pa=pa,
    batter_features=batter_feat,
    H=H,
    feature_cols=None,      # use all batter feature columns
    default_zero=True,
)
print("next_hitters_feats shape:", next_hitters_feats.shape)  # (B, H, d_hit)

next_hitters_feats shape: (407660, 5, 18)


In [23]:
# 5.3 Simple positional encodings for lineup window [B, H, d_pos]
pos_enc = build_positional_encodings(num_positions=B, H=H)
print("pos_enc shape:", pos_enc.shape)

pos_enc shape: (407660, 5, 5)


## 6. Compute pitcher form & usage

`add_pitcher_form_features` adds rolling strike/zone metrics, velocity/spin deltas, and game-usage fatigue for each pitch.

In [24]:
pitches_form = add_pitcher_form_features(pitches)
print("Pitches with form shape:", pitches_form.shape)
pitches_form[[
    "pitcher","game_pk","game_date",
    "zone_ewma_h10","strike_ewma_h10","strike30_shrunk",
    "release_speed_delta","release_spin_delta",
    "pitches_in_game","last_outing_pitches","days_since_used","b2b_flag"
]].head()

Pitches with form shape: (1547975, 54)


Unnamed: 0,pitcher,game_pk,game_date,zone_ewma_h10,strike_ewma_h10,strike30_shrunk,release_speed_delta,release_spin_delta,pitches_in_game,last_outing_pitches,days_since_used,b2b_flag
0,405395,661984,2022-05-15,0.0,0.0,,0.0,0.0,1,0.0,99,0
1,405395,661984,2022-05-15,0.0,0.0,,0.45,40.5,2,0.0,99,0
2,405395,661984,2022-05-15,0.0,0.069315,,0.7,14.333333,3,0.0,99,0
3,405395,661984,2022-05-15,0.0,0.06451,,-1.2,-64.25,4,0.0,99,0
4,405395,661984,2022-05-15,0.069315,0.060039,0.6,1.76,35.8,5,0.0,99,0


In [25]:
# 6.1 Build a global time_index for pitches and PA starts

# Ensure datetime
pitches_form = pitches_form.copy()
pitches_form["game_date"] = pd.to_datetime(pitches_form["game_date"])

# Chronological order within each game at the pitch level
pitches_form = pitches_form.sort_values(
    ["game_pk", "inning", "inning_topbot", "at_bat_number", "pitch_number"]
)

# Sequence index per game
pitches_form["seq_in_game"] = (
    pitches_form.groupby("game_pk").cumcount().astype("int64")
)

# Global time_index: date (ns) + within-game sequence
pitches_form["time_index"] = (
    pitches_form["game_date"].view("int64") + pitches_form["seq_in_game"]
)

# Attach PA start time_index from first pitch of each PA
first_pitch = (
    pitches_form.sort_values(["game_pk", "at_bat_number", "pitch_number"])
    .groupby(["game_pk", "at_bat_number"], as_index=False)
    .head(1)[["game_pk", "at_bat_number", "time_index"]]
    .rename(columns={"time_index": "time_index_pa"})
)

pa = pa.copy()
pa["game_date"] = pd.to_datetime(pa["game_date"])

pa = pa.merge(
    first_pitch,
    on=["game_pk", "at_bat_number"],
    how="left",
)

# Fallback: if no first pitch found (should be rare), use game_date only
pa["time_index_pa"] = pa["time_index_pa"].fillna(
    pa["game_date"].view("int64")
).astype("int64")

pa[["game_pk","at_bat_number","time_index_pa"]].head()

Unnamed: 0,game_pk,at_bat_number,time_index_pa
0,661032,1,1650931200000000011
1,661032,2,1650931200000000015
2,661032,3,1650931200000000017
3,661032,7,1650931200000000029
4,661032,8,1650931200000000033


## 7. Build a crude bullpen roster per team

We infer each pitcher's team at pitch time and, for each team, pick the top R pitchers by pitch count as our "bullpen" set.

In [26]:
# Infer pitcher_team at pitch level using inning_topbot + home/away
pitches_form = pitches_form.copy()
pitches_form["pitcher_team"] = np.where(
    pitches_form["inning_topbot"] == "Top",
    pitches_form["home_team"],
    pitches_form["away_team"],
)

# Make sure pitches are in proper chronological order
pitches_chrono = pitches_form.sort_values(
    ["game_pk", "inning", "inning_topbot", "at_bat_number", "pitch_number"]
)

# relief_appearances[team] is a Counter of how many games each pitcher relieved
relief_appearances = defaultdict(Counter)

# Group by game and pitcher_team (i.e., each team's pitching usage in that game)
for (game_pk, team), g in pitches_chrono.groupby(["game_pk", "pitcher_team"], sort=False):
    # g: all pitches thrown by this team in this game, in time order
    # Starter = pitcher on the first pitch for this team in this game
    starter = g.iloc[0]["pitcher"]

    # Any pitcher other than the starter in this game is a reliever (for this game)
    relievers_in_game = g.loc[g["pitcher"] != starter, "pitcher"].unique()

    # Increment relief appearance counts (per game)
    for pid in relievers_in_game:
        relief_appearances[team][pid] += 1

# Now build team_relievers_map:
# For each team, sort pitchers by how many games they relieved,
# and take the full list (we'll still cap to R later when building tensors).
team_relievers_map = {}
for team, counter in relief_appearances.items():
    # pitchers sorted by relief-game count, descending
    pitchers_sorted = [pid for pid, _ in counter.most_common()]
    team_relievers_map[team] = pitchers_sorted

# Quick sanity check: show first few teams
len(team_relievers_map), list(team_relievers_map.items())[:3]


(30,
 [('CLE',
   [np.int64(661403),
    np.int64(663986),
    np.int64(660853),
    np.int64(656529),
    np.int64(669212),
    np.int64(680704),
    np.int64(675916),
    np.int64(543766),
    np.int64(675540),
    np.int64(682120),
    np.int64(543238),
    np.int64(547184),
    np.int64(668948),
    np.int64(621593),
    np.int64(625643),
    np.int64(663752),
    np.int64(622065),
    np.int64(596057),
    np.int64(663455),
    np.int64(621057),
    np.int64(679777),
    np.int64(663531),
    np.int64(519043),
    np.int64(683769),
    np.int64(668684),
    np.int64(663474),
    np.int64(656349),
    np.int64(680940),
    np.int64(664139),
    np.int64(676520),
    np.int64(663613),
    np.int64(676391),
    np.int64(668676),
    np.int64(687330),
    np.int64(657228),
    np.int64(663671),
    np.int64(665178),
    np.int64(667280),
    np.int64(679033),
    np.int64(663892),
    np.int64(681259),
    np.int64(681856),
    np.int64(668964),
    np.int64(669437),
    np.int64(6719

## 8. Build reliever features and availability masks per PA

For each PA, we:
1. Identify the fielding team
2. Get that team's bullpen list from `team_relievers_map`
3. For each reliever, grab their latest form snapshot up to this `game_date`
4. Construct a reliever feature tensor [B,R,d_rel]
5. Build an availability mask [B,R] using rest & last-outing constraints

In [27]:
rel_form_cols = [
    "zone_ewma_h10",
    "strike_ewma_h10",
    "strike30_shrunk",
    "release_speed_delta",
    "release_spin_delta",
    "last_outing_pitches",
    "days_since_used",
    "b2b_flag",
]
d_rel = len(rel_form_cols)

B = len(pa)
reliever_feats = np.zeros((B, R, d_rel), dtype=np.float32)
avail_mask = np.zeros((B, R), dtype=bool)
reliever_id_slots = np.zeros((B, R), dtype=np.int64)

# Make sure types we need are consistent
pitches_form = pitches_form.copy()
pitches_form["pitcher_team"] = pitches_form["pitcher_team"].astype(str)
pitches_form["time_index"] = pitches_form["time_index"].astype("int64")

pa = pa.copy()
pa["fielding_team"] = pa["fielding_team"].astype(str)
pa["time_index_pa"] = pa["time_index_pa"].astype("int64")

# 8.1 Build a long table of (PA, reliever-slot) pairs with time_index_pa
rows = []  # (pa_index, reliever_slot, reliever_id, fielding_team, time_index_pa)

for idx, row in enumerate(pa.itertuples(index=False)):
    team = row.fielding_team
    t_rel = team_relievers_map.get(team, [])[:R]

    n_rel = len(t_rel)
    if n_rel > 0:
        reliever_id_slots[idx, :n_rel] = t_rel

    for slot, rid in enumerate(t_rel):
        rows.append((idx, slot, int(rid), team, int(row.time_index_pa)))

if rows:
    pa_rel_df = pd.DataFrame(
        rows,
        columns=["pa_index", "reliever_slot", "reliever_id", "fielding_team", "time_index_pa"],
    )

    pa_rel_df["reliever_id"] = pa_rel_df["reliever_id"].astype("int64")
    pa_rel_df["fielding_team"] = pa_rel_df["fielding_team"].astype(str)
    pa_rel_df["time_index_pa"] = pa_rel_df["time_index_pa"].astype("int64")

    # 8.2 Build pitch-level snapshots from pitches_form
    # We only need the reliever form/usage columns + time_index
    snapshot_cols = ["pitcher", "pitcher_team", "time_index"] + rel_form_cols

    form_snapshots = pitches_form[snapshot_cols].rename(
        columns={
            "pitcher": "reliever_id",
            "pitcher_team": "fielding_team",
        }
    )

    form_snapshots["reliever_id"] = form_snapshots["reliever_id"].astype("int64")
    form_snapshots["fielding_team"] = form_snapshots["fielding_team"].astype(str)
    form_snapshots["time_index"] = form_snapshots["time_index"].astype("int64")

    # Sort once for each side
    form_snapshots = form_snapshots.sort_values(
        ["reliever_id", "fielding_team", "time_index"]
    )
    pa_rel_df = pa_rel_df.sort_values(
        ["reliever_id", "fielding_team", "time_index_pa"]
    )

    # Group by (reliever_id, fielding_team) on both sides
    snap_groups = {
        key: g for key, g in form_snapshots.groupby(["reliever_id", "fielding_team"])
    }
    pa_groups = {
        key: g for key, g in pa_rel_df.groupby(["reliever_id", "fielding_team"])
    }

    # Allocate fatigue arrays for availability
    days_since_full = np.full((B, R), np.nan, dtype=float)
    last_out_full = np.full((B, R), np.nan, dtype=float)

    # 8.3 For each (reliever, team), match snapshots via searchsorted
    for key, g_pa in pa_groups.items():
        if key not in snap_groups:
            continue

        snap_g = snap_groups[key]

        # Arrays of time indexes
        snap_t = snap_g["time_index"].to_numpy(dtype="int64")
        pa_t = g_pa["time_index_pa"].to_numpy(dtype="int64")

        # For each PA time, find last snapshot time <= PA time
        idx_snap = np.searchsorted(snap_t, pa_t, side="right") - 1
        valid = idx_snap >= 0
        if not valid.any():
            continue

        g_pa_valid = g_pa.iloc[valid]
        snap_rows = snap_g.iloc[idx_snap[valid]]

        pa_idx_arr = g_pa_valid["pa_index"].to_numpy(dtype=np.int64)
        slot_arr = g_pa_valid["reliever_slot"].to_numpy(dtype=np.int64)

        # Fill reliever feature tensor [B, R, d_rel]
        feat_mat = snap_rows[rel_form_cols].fillna(0.0).to_numpy(dtype=np.float32)
        reliever_feats[pa_idx_arr, slot_arr, :] = feat_mat

        # Store fatigue info for availability
        days_since_full[pa_idx_arr, slot_arr] = snap_rows["days_since_used"].to_numpy()
        last_out_full[pa_idx_arr, slot_arr] = snap_rows["last_outing_pitches"].to_numpy()

else:
    # No relievers at all (very unlikely), keep shapes compatible
    days_since_full = np.full((B, R), np.nan, dtype=float)
    last_out_full = np.full((B, R), np.nan, dtype=float)

# 8.4 Availability
already_used_by_game = defaultdict(set)
rest_min_days = 1
last_outing_max = 35

for idx, row in enumerate(pa.itertuples(index=False)):
    game = row.game_pk
    already_used_ids = already_used_by_game[game]

    relievers_row = reliever_id_slots[idx]       # [R]
    ds_row = days_since_full[idx]                # [R]
    lo_row = last_out_full[idx]                  # [R]

    # List of actual relievers (exclude zero padding)
    team_relievers = [int(rid) for rid in relievers_row if rid != 0]

    # Build per-reliever snapshot mapping for availability
    form_snapshot = {}
    for slot, rid in enumerate(relievers_row):
        if rid == 0:
            continue

        days_since = ds_row[slot]
        last_p = lo_row[slot]

        days_val = 99.0 if np.isnan(days_since) else float(days_since)
        last_val = 0.0 if np.isnan(last_p) else float(last_p)

        form_snapshot[int(rid)] = pd.Series(
            {
                "days_since_used": days_val,
                "last_outing_pitches": last_val,
            }
        )

    mask_small = build_availability(
        team_relievers=team_relievers,
        form_snapshot=form_snapshot,
        already_used_ids=already_used_ids,
        rest_min_days=rest_min_days,
        last_outing_max=last_outing_max,
    )

    n_rel = len(team_relievers)
    if n_rel > 0:
        avail_mask[idx, :n_rel] = mask_small

    # Mark current pitcher as used in this game
    already_used_ids.add(row.pitcher_on_mound)

print("reliever_feats shape:", reliever_feats.shape)
print("avail_mask shape:", avail_mask.shape)
print("reliever_id_slots shape:", reliever_id_slots.shape)

reliever_feats shape: (407660, 10, 8)
avail_mask shape: (407660, 10)
reliever_id_slots shape: (407660, 10)


## 9. Add PA-level Run Expectancy deltaas and mark terminals

We aggregate Statcast's pitch-level `delta_run_exp` into a per-PA column `delta_re_pa`, then mark half-inning and game terminal PAs.

In [28]:
# 9.1 Aggregate pitch-level delta_run_exp to PA-level delta_re_pa
pa = add_pa_delta_run_exp(
    pa=pa,
    pitches=pitches,
    delta_col="delta_run_exp",
    out_col="delta_re_pa",
)
pa[["game_pk","at_bat_number","delta_re_pa"]].head()

Unnamed: 0,game_pk,at_bat_number,delta_re_pa
0,661032,1,-0.247
1,661032,2,-0.246
2,661032,3,-0.255
3,661032,7,-0.246
4,661032,8,-0.246


In [29]:
# 9.2 Annotate half-inning and game terminals
pa = annotate_terminals(pa)
pa[["game_pk","inning","half","at_bat_number","half_inning_over","game_over"]].head()

Unnamed: 0,game_pk,inning,half,at_bat_number,half_inning_over,game_over
0,661032,1,0,1,False,False
1,661032,1,0,2,False,False
2,661032,1,0,3,True,False
3,661032,2,0,7,False,False
4,661032,2,0,8,False,False


## 10. Fold SMDP rewards using delta_re_pa

For each sequence of PAs with the same (`game_pk`,`fielding_team`,`half`), we:
- Treat each PA as a decision point
- Detect whether the pitcher changed before the next PA
- Fold -`delta_re_pa` over up to 3 batters (or until half/game ends)
- Record (`reward`, `next_state_idx`,`done`) and simple action labels

In [30]:
B = len(pa)

reward_folded = np.zeros(B, dtype=np.float32)
next_idx_arr = np.zeros(B, dtype=np.int64)
done_arr = np.zeros(B, dtype=bool)

# Actions:
#   0    = stay with current pitcher
#   1..R = pull and bring in reliever in slot (a-1) of reliever_id_slots
num_actions = 1 + R
action_idx = np.zeros(B, dtype=np.int64)          # 0..R
next_pitcher_id = np.full(B, -1, dtype=np.int64)  # for logging / analysis

# Diagnostics
total_changes = 0
mapped_changes = 0
unmapped_counter = Counter()

# Attach stable original index
pa_with_idx = pa.reset_index().rename(columns={"index": "orig_idx"})
pa_sorted = pa_with_idx.sort_values(
    ["game_pk", "fielding_team", "inning", "half", "at_bat_number"]
)

for (g, team, half), gdf in pa_sorted.groupby(
    ["game_pk", "fielding_team", "half"], sort=False
):
    # gdf: all PAs for one (game_pk, fielding_team, half) in time order
    gdf = gdf.reset_index(drop=True)
    orig_idx_group = gdf["orig_idx"].to_numpy()
    local_n = len(gdf)

    for local_i in range(local_n):
        global_i = int(orig_idx_group[local_i])

        # If this is the last PA in this (game, team, half), no future to roll over
        if local_i >= local_n - 1:
            reward_folded[global_i] = 0.0
            next_idx_arr[global_i] = global_i
            done_arr[global_i] = bool(gdf.loc[local_i, "game_over"])
            action_idx[global_i] = 0
            next_pitcher_id[global_i] = -1
            continue

        cur_pitcher = gdf.loc[local_i, "pitcher_on_mound"]
        next_pitcher = gdf.loc[local_i + 1, "pitcher_on_mound"]

        pulled_real = cur_pitcher != next_pitcher

        # Default: stay with current pitcher
        act = 0
        mapped_pull = False

        if pulled_real:
            total_changes += 1

            # Map actual next pitcher into one of the bullpen slots at this PA
            relievers_here = reliever_id_slots[global_i]  # shape [R]
            slots = np.where(relievers_here == next_pitcher)[0]

            if len(slots) > 0:
                # Slot index 0..R-1 -> action 1..R
                act = int(slots[0]) + 1
                mapped_pull = True
                mapped_changes += 1
            else:
                # Next pitcher not in our modeled R relievers
                unmapped_counter[int(next_pitcher)] += 1

            next_pitcher_id[global_i] = int(next_pitcher)
        else:
            next_pitcher_id[global_i] = -1

        action_idx[global_i] = act

        # Only penalize a pull in the reward if we actually modeled it as a pull
        pulled_for_reward = mapped_pull

        # Fold SMDP reward over up to 3 batters using run expectancy deltas
        R_val, local_next_idx, done_flag = fold_smdp_reward(
            pa_seq=gdf,
            start_idx=local_i,
            gamma=GAMMA,
            penalty_pull=PENALTY_PULL,
            pulled=pulled_for_reward,
            max_horizon=3,
            delta_re_col="delta_re_pa",
        )

        reward_folded[global_i] = R_val
        global_next_idx = int(orig_idx_group[local_next_idx])
        next_idx_arr[global_i] = global_next_idx
        done_arr[global_i] = done_flag

print("reward_folded shape:", reward_folded.shape)
print("unique action_idx:", np.unique(action_idx))

print("Number of pitcher changes detected:", total_changes)
print("Mapped to reliever slots:", mapped_changes)

reward_folded shape: (407660,)
unique action_idx: [ 0  1  2  3  4  5  6  7  8  9 10]
Number of pitcher changes detected: 37010
Mapped to reliever slots: 24032


## 11. Build core state vectors for RL

We construct compact PA-level state vectors and corresponding next-state vectors, using PA features that summarize game state and pitcher fatigue.

In [32]:
state_cols = [
    "inning",
    "half",
    "outs",
    "base_state",
    "score_diff",
    "pitch_count",
    "tto",
    # Handedness / matchup features
    "is_platoon_advantage",
    "batter_is_left",
    "batter_is_right",
    "batter_is_switch",
    "pitcher_is_left",
    "pitcher_is_right",
]

state_vec = pa[state_cols].to_numpy(dtype=np.float32)
next_state_vec = state_vec[next_idx_arr]

print("state_vec shape:", state_vec.shape)
print("next_state_vec shape:", next_state_vec.shape)

state_vec shape: (407660, 13)
next_state_vec shape: (407660, 13)


## 12. Save all tensors to disk

We store everything needed for RL training in a single `.npz` file under `data/processed/`.
This file is what the `src/rl` will load.

In [33]:
outfile = PROC_DIR / f"rl_tensors_{YEAR_TAG}.npz"
np.savez_compressed(
    outfile,
    # State
    state_vec=state_vec,
    next_state_vec=next_state_vec,
    next_hitters_feats=next_hitters_feats,
    pos_enc=pos_enc,
    reliever_feats=reliever_feats,
    avail_mask=avail_mask,
    reliever_id_slots=reliever_id_slots,

    # Actions & rewards
    action_idx=action_idx,            # 0 = stay, 1..R = specific reliever slot
    next_pitcher_id=next_pitcher_id,  # for debugging / interpretation
    reward_folded=reward_folded,
    done=done_arr,

    # Bookkeeping
    pa_index=np.arange(B),
)
outfile

PosixPath('../data/processed/rl_tensors_2022_2023.npz')

## 13. Save PA table with rewards & actions

This Parquet is handy for sanity-checking the dataset and doing EDA.

In [34]:
pa_export = pa.copy()
pa_export["reward_folded"] = reward_folded
pa_export["next_state_idx"] = next_idx_arr
pa_export["action_idx"] = action_idx
pa_export["next_pitcher_id"] = next_pitcher_id

pa_export_file = PROC_DIR / f"pa_decisions_{YEAR_TAG}.parquet"
pa_export.to_parquet(pa_export_file, index=False)
pa_export_file

PosixPath('../data/processed/pa_decisions_2022_2023.parquet')