In [None]:
# Install PyTorch and a convenient MLB stats data source library
!pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip -q install pybaseball
!pip -q install scikit-learn


[notice] A new release of pip is available: 25.1.1 -> 26.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
# Pull a small Statcast sample via pybaseball.statcast() for initial exploration
import pandas as pd
from pybaseball import statcast

# A small window keeps the pull fast; adjust later once we validate the pipeline
start_date = "2024-04-01"
end_date = "2024-04-07"

sc = statcast(start_dt=start_date, end_dt=end_date)

print("Statcast sample shape:", sc.shape)
print("Columns (first 40):", list(sc.columns[:40]))
display(sc.head(5))

This is a large query, it may take a moment to complete


  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
  data_copy[column] = data_copy[column].apply(pd.to_datetime, errors='ignore', format=date_format)
100%|██████████| 7/7 [00:29<00:00,  4.23s/it]


Statcast sample shape: (26072, 118)
Columns (first 40): ['pitch_type', 'game_date', 'release_speed', 'release_pos_x', 'release_pos_z', 'player_name', 'batter', 'pitcher', 'events', 'description', 'spin_dir', 'spin_rate_deprecated', 'break_angle_deprecated', 'break_length_deprecated', 'zone', 'des', 'game_type', 'stand', 'p_throws', 'home_team', 'away_team', 'type', 'hit_location', 'bb_type', 'balls', 'strikes', 'game_year', 'pfx_x', 'pfx_z', 'plate_x', 'plate_z', 'on_3b', 'on_2b', 'on_1b', 'outs_when_up', 'inning', 'inning_topbot', 'hc_x', 'hc_y', 'tfs_deprecated']


Unnamed: 0,pitch_type,game_date,release_speed,release_pos_x,release_pos_z,player_name,batter,pitcher,events,description,...,batter_days_until_next_game,api_break_z_with_gravity,api_break_x_arm,api_break_x_batter_in,arm_angle,attack_angle,attack_direction,swing_path_tilt,intercept_ball_minus_batter_pos_x_inches,intercept_ball_minus_batter_pos_y_inches
2621,CU,2024-04-07,85.2,-2.31,6.16,"Johnson, Pierce",664983,572955,field_out,hit_into_play,...,2,3.78,-0.73,0.73,36.0,21.26133,-9.620944,34.858224,44.484484,39.961796
2734,CU,2024-04-07,86.0,-2.33,6.15,"Johnson, Pierce",664983,572955,,called_strike,...,2,3.52,-0.78,0.78,41.6,,,,,
2793,CU,2024-04-07,85.8,-2.27,6.2,"Johnson, Pierce",553993,572955,field_out,hit_into_play,...,1,3.49,-0.91,-0.91,42.2,14.211988,7.795369,34.318302,45.151528,27.733033
2878,CU,2024-04-07,86.5,-2.3,6.24,"Johnson, Pierce",572233,572955,single,hit_into_play,...,1,3.39,-0.9,-0.9,38.3,4.061736,9.400294,24.872673,33.739691,28.735222
3004,CU,2024-04-07,86.1,-2.18,6.16,"Johnson, Pierce",572233,572955,,ball,...,1,3.51,-1.14,-1.14,40.2,,,,,


In [4]:
# Basic EDA: define a whiff label from Statcast pitch `description` and inspect data quality
import numpy as np
import pandas as pd

# Peek at the pitch outcome taxonomy present in this sample
print("Unique description count:", sc['description'].nunique(dropna=True))
print("Top 20 descriptions:\n", sc['description'].value_counts(dropna=False).head(20))

# Define a simple whiff label: swinging strike variants
whiff_descriptions = {
    'swinging_strike',
    'swinging_strike_blocked',
}
sc = sc.copy()
sc['is_whiff'] = sc['description'].isin(whiff_descriptions).astype('int8')

print("\nWhiff rate overall:", sc['is_whiff'].mean().round(4), "(n=", len(sc), ")")

# Check missingness for a minimal, interpretable feature set
feature_cols = [
    'pitch_type','release_speed','pfx_x','pfx_z','plate_x','plate_z',
    'balls','strikes','stand','p_throws','release_pos_x','release_pos_z'
]
missing = (sc[feature_cols].isna().mean().sort_values(ascending=False))
print("\nMissingness (fraction) for candidate features:\n", missing)

# Show a few raw rows including label + key features
display(sc[feature_cols + ['description','is_whiff']].sample(8, random_state=7))

# Quick numeric summaries to spot outliers / weird ranges
num_cols = ['release_speed','pfx_x','pfx_z','plate_x','plate_z','release_pos_x','release_pos_z']
display(sc[num_cols].describe(percentiles=[.01,.05,.5,.95,.99]).T)

Unique description count: 15
Top 20 descriptions:
 description
ball                       8723
foul                       4556
hit_into_play              4452
called_strike              4410
swinging_strike            2644
blocked_ball                633
foul_tip                    295
swinging_strike_blocked     135
hit_by_pitch                 83
automatic_ball               83
foul_bunt                    40
automatic_strike             10
missed_bunt                   6
bunt_foul_tip                 1
pitchout                      1
Name: count, dtype: int64

Whiff rate overall: 0.1066 (n= 26072 )

Missingness (fraction) for candidate features:
 pitch_type       0.003644
release_speed    0.003567
pfx_x            0.003567
pfx_z            0.003567
plate_x          0.003567
plate_z          0.003567
release_pos_x    0.003567
release_pos_z    0.003567
strikes          0.000000
balls            0.000000
p_throws         0.000000
stand            0.000000
dtype: float64


Unnamed: 0,pitch_type,release_speed,pfx_x,pfx_z,plate_x,plate_z,balls,strikes,stand,p_throws,release_pos_x,release_pos_z,description,is_whiff
1225,KC,80.6,0.07,-0.24,-0.35,1.26,0,2,R,L,2.06,6.3,ball,0
670,FF,91.3,0.63,1.38,-0.38,3.8,3,2,L,L,1.73,5.51,swinging_strike,1
3281,FF,91.7,0.02,1.68,-0.21,2.66,0,0,L,R,-0.22,6.92,hit_into_play,0
958,FF,92.7,0.16,0.93,0.52,2.03,2,2,R,L,1.81,5.83,foul,0
685,KC,75.2,-0.91,-0.7,0.18,0.61,0,1,R,L,2.45,6.5,swinging_strike,1
1482,FF,93.6,-0.46,1.04,0.44,2.98,3,0,R,R,-2.34,5.76,called_strike,0
3312,ST,80.5,1.32,0.37,0.07,3.11,2,2,R,R,-1.04,6.19,foul_tip,0
2716,SI,91.6,-1.47,0.3,-0.86,2.9,1,0,R,R,-4.38,4.24,foul,0


Unnamed: 0,count,mean,std,min,1%,5%,50%,95%,99%,max
release_speed,25979.0,88.714862,5.866247,39.9,74.3,78.6,89.5,96.8,98.7,102.4
pfx_x,25979.0,-0.121336,0.921258,-2.95,-1.69,-1.48,-0.16,1.39,1.65,2.4
pfx_z,25979.0,0.610208,0.709276,-1.85,-1.3,-0.67,0.63,1.59,1.74,2.14
plate_x,25979.0,0.055496,0.828552,-3.79,-1.87,-1.31,0.05,1.41,2.03,3.81
plate_z,25979.0,2.291348,0.973069,-2.82,-0.1,0.68,2.31,3.86,4.54,8.0
release_pos_x,25979.0,-0.810267,1.845483,-4.51,-3.53,-2.91,-1.48,2.611,3.33,4.56
release_pos_z,25979.0,5.794603,0.519465,0.95,4.31,4.97,5.83,6.54,6.85,7.3


In [5]:
# Build a clean modeling table for a whiff model (features + label) and make a time-based split
import numpy as np
import pandas as pd

# Ensure game_date is datetime
sc_model = sc.copy()
sc_model['game_date'] = pd.to_datetime(sc_model['game_date'])

# Define features + target
feature_cols = [
    'pitch_type','release_speed','pfx_x','pfx_z','plate_x','plate_z',
    'balls','strikes','stand','p_throws','release_pos_x','release_pos_z'
]

target_col = 'is_whiff'

# Keep only rows with non-missing target and required features
needed_cols = feature_cols + [target_col, 'game_date']
df_model = sc_model[needed_cols].copy()

# Drop missing rows
before = len(df_model)
df_model = df_model.dropna(subset=feature_cols + [target_col])
after = len(df_model)
print(f"Rows before dropna: {before:,} | after: {after:,} | dropped: {before-after:,} ({(before-after)/before:.2%})")

# Cast types for categorical columns
cat_cols = ['pitch_type', 'stand', 'p_throws']
for c in cat_cols:
    df_model[c] = df_model[c].astype('category')

df_model[target_col] = df_model[target_col].astype('int8')

# Sort by date and make a simple time-based split (last day as test)
df_model = df_model.sort_values('game_date').reset_index(drop=True)
last_date = df_model['game_date'].max()
test_mask = df_model['game_date'].eq(last_date)

print("Date range:", df_model['game_date'].min().date(), "→", df_model['game_date'].max().date())
print("Train rows:", (~test_mask).sum(), "| Test rows:", test_mask.sum())
print("Train whiff rate:", df_model.loc[~test_mask, target_col].mean().round(4))
print("Test  whiff rate:", df_model.loc[test_mask, target_col].mean().round(4))

display(df_model.head())

# Expose split indices for downstream modeling cells
train_idx = np.flatnonzero(~test_mask)
test_idx = np.flatnonzero(test_mask)
print("train_idx/test_idx ready.")

Rows before dropna: 26,072 | after: 25,977 | dropped: 95 (0.36%)
Date range: 2024-04-01 → 2024-04-07
Train rows: 21938 | Test rows: 4039
Train whiff rate: 0.1083
Test  whiff rate: 0.1


Unnamed: 0,pitch_type,release_speed,pfx_x,pfx_z,plate_x,plate_z,balls,strikes,stand,p_throws,release_pos_x,release_pos_z,is_whiff,game_date
0,FF,97.3,0.71,1.55,0.16,1.75,0,0,R,L,2.14,5.78,0,2024-04-01
1,FF,94.0,-1.08,1.24,-1.17,2.7,1,0,L,R,-1.43,6.08,0,2024-04-01
2,FF,93.9,-0.82,1.31,-0.58,3.68,2,0,L,R,-1.36,6.17,0,2024-04-01
3,FF,93.7,-0.71,1.22,0.86,2.95,3,0,L,R,-1.28,6.16,0,2024-04-01
4,FF,92.9,-0.58,1.2,0.81,1.29,3,1,L,R,-1.21,6.1,0,2024-04-01


train_idx/test_idx ready.


In [15]:
# Train a simple baseline model (logistic regression) to predict whiff probability
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, log_loss, brier_score_loss, classification_report

# Split data
X = df_model[feature_cols]
y = df_model[target_col].astype(int)
X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]

cat_cols = ['pitch_type', 'stand', 'p_throws']
num_cols = [c for c in feature_cols if c not in cat_cols]

# Preprocessing: impute (just in case), scale numerics, one-hot encode categoricals
preprocess = ColumnTransformer(
    transformers=[
        (
            'num',
            Pipeline(steps=[
                ('imputer', SimpleImputer(strategy='median')),
                ('scaler', StandardScaler()),
            ]),
            num_cols,
        ),
        (
            'cat',
            Pipeline(steps=[
                ('imputer', SimpleImputer(strategy='most_frequent')),
                ('onehot', OneHotEncoder(handle_unknown='ignore')),
            ]),
            cat_cols,
        ),
    ],
    remainder='drop',
)

clf = LogisticRegression(
    max_iter=2000,
    class_weight='balanced',
    n_jobs=None,
    solver='lbfgs',
)

pipe = Pipeline(steps=[('preprocess', preprocess), ('model', clf)])

pipe.fit(X_train, y_train)

# Probabilities + metrics
yhat_train = pipe.predict_proba(X_train)[:, 1]
yhat_test = pipe.predict_proba(X_test)[:, 1]

metrics = {
    'train_auc': roc_auc_score(y_train, yhat_train),
    'test_auc': roc_auc_score(y_test, yhat_test),
    'train_logloss': log_loss(y_train, yhat_train, labels=[0, 1]),
    'test_logloss': log_loss(y_test, yhat_test, labels=[0, 1]),
    'train_brier': brier_score_loss(y_train, yhat_train),
    'test_brier': brier_score_loss(y_test, yhat_test),
}

print("Baseline logistic regression metrics:")
for k, v in metrics.items():
    print(f"  {k:>14}: {v:.4f}")

# Quick threshold report at 0.5 (not necessarily optimal; mostly for sanity)
print("\nClassification report on test (threshold=0.5):")
print(classification_report(y_test, (yhat_test >= 0.5).astype(int), digits=3))

# Expose for downstream use
baseline_pipe = pipe
baseline_yhat_test = yhat_test

Baseline logistic regression metrics:
       train_auc: 0.6210
        test_auc: 0.6081
   train_logloss: 0.6710
    test_logloss: 0.6648
     train_brier: 0.2391
      test_brier: 0.2361

Classification report on test (threshold=0.5):
              precision    recall  f1-score   support

           0      0.922     0.608     0.733      3635
           1      0.132     0.537     0.212       404

    accuracy                          0.601      4039
   macro avg      0.527     0.573     0.473      4039
weighted avg      0.843     0.601     0.681      4039

