## Nested Logit Model for Baseball Strike Prediction

We model a specific strike event, defined here as a ball hit in the air and caught by a fielder, as a **three-stage nested decision process**:

1. **Pitcher Decision**  
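   The pitcher chooses a pitch type $P_{\text{type}} \in \{ \text{fastball}, \text{curveball}, \text{slider} \}$ (written $P$ below). The choice sets the distribution of the ball's velocity $v$ and spin $s$, each drawn independently around a pitch-specific mean:  
   $$  
   v \sim \mathcal{N}(\mu_{v,P}, \sigma_v^2), \qquad s \sim \mathcal{N}(\mu_{s,P}, \sigma_s^2)  
   $$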

2. **Hitter Decision (Swing or Not Swing)**  
   The hitter decides whether to swing based on the pitch characteristics. The swing probability is modeled with logistic regression, where the pitch type $P$ enters through indicator (one-hot) variables:  
   $$  
   \Pr(\text{swing} = 1 \mid v, s, P) = \frac{1}{1 + \exp(-\beta_0 - \beta_1 v - \beta_2 s - \beta_3 P)}  
   $$

3. **Ball Location and Strike Outcome**  
   Given the pitcher’s decision and the hitter’s action, the ball's landing location $\mathbf{x} = (x, y)$ is drawn from a distribution whose mean depends on both decisions:  
   $$  
   \mathbf{x} \sim \mathcal{N}(\mu_{P, \text{swing}}, \Sigma)  
   $$

   The ball can be caught by fielders at fixed positions $\mathbf{f}_i$, $i=1,\dots,4$, each with a catch radius $r$.

   The probability of a strike (caught ball) depends on the distance to the nearest fielder and on whether the ball is in the air ($\text{airball} \in \{0, 1\}$):  
   $$  
   \Pr(\text{strike} = 1 \mid \mathbf{x}, \text{airball}) = \frac{1}{1 + \exp\left(-\alpha_0 - \alpha_1 \min_i \|\mathbf{x} - \mathbf{f}_i\| - \alpha_2 \cdot \text{airball} \right)}  
   $$
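
The two logistic stages above can be sketched as plain functions. This is only an illustration, not the fitted model: the helper names `sigmoid`, `swing_probability`, and `strike_probability` are hypothetical, the $\beta$ defaults simply expand the swing rule used in the simulation code ($0.15(v - 88) - 0.003(s - 2500) = -5.7 + 0.15v - 0.003s$), the $\alpha$ values are made up, and the pitch-type term $\beta_3 P$ is omitted for brevity.

```python
import math

def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z))."""
    return 1.0 / (1.0 + math.exp(-z))

def swing_probability(v, s, beta=(-5.7, 0.15, -0.003)):
    """Pr(swing = 1 | v, s) for coefficients (beta_0, beta_1, beta_2)."""
    b0, b1, b2 = beta
    return sigmoid(b0 + b1 * v + b2 * s)

def strike_probability(min_dist, airball, alpha=(-4.0, -20.0, 6.0)):
    """Pr(strike = 1 | x, airball), with min_dist = min_i ||x - f_i||.

    The alpha values are assumptions: alpha_1 < 0 makes a catch less
    likely as the ball lands farther from the nearest fielder.
    """
    a0, a1, a2 = alpha
    return sigmoid(a0 + a1 * min_dist + a2 * airball)

# Typical fastball vs. typical curveball (means from the simulation).
p_fast = swing_probability(95, 2400)
p_curve = swing_probability(78, 2800)

# Air ball near a fielder, air ball far away, ground ball near a fielder.
p_near = strike_probability(0.05, 1)
p_far = strike_probability(0.50, 1)
p_ground = strike_probability(0.05, 0)

print(f"swing: fastball={p_fast:.3f}, curveball={p_curve:.3f}")
print(f"strike: near={p_near:.3f}, far={p_far:.3f}, ground={p_ground:.3f}")
```

With these signs, faster pitches with less spin draw more swings, and the strike probability falls as the distance to the nearest fielder grows.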




In summary:

- We simulate pitch characteristics (velocity and spin) from the pitcher’s choice of pitch type.
- The hitter’s swing decision depends on those pitch features.
- The ball’s location depends on both the pitcher’s and the hitter’s decisions.
- A strike occurs when the ball is in the air and lands within a fielder’s catch radius.
- We approximate the nested logit with one logistic regression per stage, predicting swing, contact-quality, and strike probabilities.
- Key assumption: the spatial component (location on the field) and the distance to the nearest fielder are the main drivers of strike likelihood.
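
The catch rule from the summary can be illustrated on its own. The sketch below uses the fielder coordinates and catch radius from the simulation code in this notebook; the helper names `nearest_fielder` and `is_strike` are hypothetical and exist only for this example.

```python
import math

# Fielder positions on the unit square and the catch radius r
# (values copied from the simulation cell).
FIELDERS = {
    "LF": (0.2, 0.8),
    "CF": (0.5, 0.85),
    "RF": (0.8, 0.8),
    "SS": (0.45, 0.45),
}
CATCH_RADIUS = 0.18

def nearest_fielder(x, y):
    """Return (name, distance) for the fielder closest to (x, y)."""
    return min(
        ((name, math.dist((x, y), pos)) for name, pos in FIELDERS.items()),
        key=lambda pair: pair[1],
    )

def is_strike(x, y, in_air):
    """A strike needs the ball in the air AND within the catch radius."""
    _, distance = nearest_fielder(x, y)
    return bool(in_air) and distance <= CATCH_RADIUS

# A fly ball just in front of the center fielder is caught ...
print(is_strike(0.5, 0.8, in_air=True))
# ... the same location on the ground is not ...
print(is_strike(0.5, 0.8, in_air=False))
# ... and a fly ball near home plate is out of everyone's reach.
print(is_strike(0.5, 0.1, in_air=True))
```

Because the simulated label is a hard threshold on distance, a logistic model of `strike` can only approximate the boundary of each catch circle with a smooth curve.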



## Generate and clean data

In [36]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
from matplotlib import animation

np.random.seed(42)

#  Simulate pitcher decisions
n = 10000
pitch_types = ['fastball', 'curveball', 'slider']
pitcher_decision = np.random.choice(pitch_types, size=n)

velocity_mean = {'fastball': 95, 'curveball': 78, 'slider': 85}
spin_mean = {'fastball': 2400, 'curveball': 2800, 'slider': 2600}
velocities = np.array([np.random.normal(velocity_mean[p], 2) for p in pitcher_decision])
spins = np.array([np.random.normal(spin_mean[p], 100) for p in pitcher_decision])

#  Simulate hitter swing (depends on velocity and spin; pitch type acts only through them)
swing_probs = 1 / (1 + np.exp(-0.15 * (velocities - 88) + 0.003 * (spins - 2500)))
swing = np.random.binomial(1, swing_probs)

#  Simulate hitter quality of contact (hit_weak: 1=weak hit, 0=strong hit)
# Weak hit more likely if swing = 1 but velocity lower
hit_weak_probs = 1 / (1 + np.exp(0.2 * (velocities - 90) - 2 * swing))
hit_weak = np.random.binomial(1, hit_weak_probs)

#  Simulate ball in air or not (depends on pitch type, swing, hit quality)
airball_prob_by_pitch = {'fastball': 0.7, 'curveball': 0.5, 'slider': 0.6}
in_air = np.array([
    np.random.binomial(1, airball_prob_by_pitch[p] * s * (1 - hw*0.7))  # weaker hits less likely to be in air
    for p, s, hw in zip(pitcher_decision, swing, hit_weak)
])

#  Simulate ball location depending on pitcher decision + hitter contact quality (and some noise)
loc_means = {
    'fastball': [0.5, 0.6],
    'curveball': [0.3, 0.5],
    'slider': [0.6, 0.4]
}
loc_stds = 0.12

locations = np.zeros((n, 2))
for i in range(n):
    base_x, base_y = loc_means[pitcher_decision[i]]
    # Weak hits tend to land closer to the infield (smaller y)
    y_adj = -0.15 if hit_weak[i] == 1 else 0
    loc_x = np.random.normal(base_x, loc_stds)
    loc_y = np.random.normal(base_y + y_adj, loc_stds)
    locations[i, 0] = np.clip(loc_x, 0, 1)
    locations[i, 1] = np.clip(loc_y, 0, 1)

#  Define realistic baseball diamond fielders (4 fielders: LF, CF, RF, SS)
fielder_positions = np.array([
    [0.2, 0.8],  # Left fielder
    [0.5, 0.85],  # Center fielder
    [0.8, 0.8],  # Right fielder
    [0.45, 0.45]  # Shortstop
])
catch_radius = 0.18
distances = cdist(locations, fielder_positions)
min_distances = distances.min(axis=1)

#  Strike = caught ball in air within catch radius
strike = (in_air == 1) & (min_distances <= catch_radius)
strike = strike.astype(int)

#  Build DataFrame
df = pd.DataFrame({
    'pitcher_decision': pitcher_decision,
    'velocity': velocities,
    'spin': spins,
    'swing': swing,
    'hit_weak': hit_weak,
    'in_air': in_air,
    'x': locations[:, 0],
    'y': locations[:, 1],
    'min_dist_to_fielder': min_distances,
    'strike': strike
})

#  One-hot encode pitcher decision
ohe = OneHotEncoder(sparse_output=False)
pitch_encoded = ohe.fit_transform(df[['pitcher_decision']])
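
The per-row list comprehensions used above for velocity and spin can also be written as a single vectorized draw. The following is a sketch of that alternative, assuming NumPy's `Generator` API; it uses its own local generator, so it produces a different random sequence from the cell above and would not reproduce the reported metrics exactly.

```python
import numpy as np

rng = np.random.default_rng(42)  # local generator; the global seed is untouched

pitch_types = np.array(['fastball', 'curveball', 'slider'])
velocity_mean = {'fastball': 95, 'curveball': 78, 'slider': 85}

n = 10_000
pitches = rng.choice(pitch_types, size=n)

# Look up the mean for every row, then draw all velocities in one call.
mean_per_row = np.array([velocity_mean[p] for p in pitches], dtype=float)
velocities = rng.normal(loc=mean_per_row, scale=2.0)

# Sanity check: per-pitch sample means should sit close to the configured means.
for p in pitch_types:
    print(p, round(float(velocities[pitches == p].mean()), 2))
```

The same pattern applies to spin by swapping in `spin_mean` and a standard deviation of 100.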

## Fit nested logit models

In [37]:
# Stage 1: Swing model
swing_features = np.hstack([
    pitch_encoded,
    df[['velocity', 'spin']].values
])
swing_model = LogisticRegression(max_iter=1000)
swing_model.fit(swing_features, df['swing'])
swing_pred_probs = swing_model.predict_proba(swing_features)[:, 1]

# Stage 2: Hit quality (hit_weak) model
hit_weak_features = np.hstack([
    pitch_encoded,
    df[['velocity', 'spin', 'swing']].values
])
hit_weak_model = LogisticRegression(max_iter=1000)
hit_weak_model.fit(hit_weak_features, df['hit_weak'])
hit_weak_pred_probs = hit_weak_model.predict_proba(hit_weak_features)[:, 1]

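# Stage 3: Strike model (conditioned on pitcher, ball location, and ball traits)
# Note: the simulated strike label is a deterministic function of in_air and
# min_dist_to_fielder, so including both makes this stage easy to fit.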
strike_features = np.hstack([
    pitch_encoded,
    df[['velocity', 'spin', 'swing', 'hit_weak', 'in_air', 'x', 'y', 'min_dist_to_fielder']].values
])
strike_model = LogisticRegression(max_iter=1000)
strike_model.fit(strike_features, df['strike'])
strike_pred_probs = strike_model.predict_proba(strike_features)[:, 1]

#  Evaluation metrics
from sklearn.metrics import log_loss, brier_score_loss, f1_score

metrics = {
    "Swing Model": {
        "Log-Loss": log_loss(df['swing'], swing_pred_probs),
        "Brier Score": brier_score_loss(df['swing'], swing_pred_probs),
        "F1 Score": f1_score(df['swing'], swing_model.predict(swing_features))
    },
    "Hit Weak Model": {
        "Log-Loss": log_loss(df['hit_weak'], hit_weak_pred_probs),
        "Brier Score": brier_score_loss(df['hit_weak'], hit_weak_pred_probs),
        "F1 Score": f1_score(df['hit_weak'], hit_weak_model.predict(hit_weak_features))
    },
    "Strike Model": {
        "Log-Loss": log_loss(df['strike'], strike_pred_probs),
        "Brier Score": brier_score_loss(df['strike'], strike_pred_probs),
        "F1 Score": f1_score(df['strike'], strike_model.predict(strike_features))
    }
}

print("Evaluation Metrics:")
for model_name, m in metrics.items():
    print(f"\n{model_name}:")
    for metric_name, value in m.items():
        print(f"  {metric_name}: {value:.4f}")




Evaluation Metrics:

Swing Model:
  Log-Loss: 0.4650
  Brier Score: 0.1488
  F1 Score: 0.7323

Hit Weak Model:
  Log-Loss: 0.4456
  Brier Score: 0.1433
  F1 Score: 0.8854

Strike Model:
  Log-Loss: 0.0453
  Brier Score: 0.0131
  F1 Score: 0.9037
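
The metrics above are computed on the same rows that were used for fitting, so they describe in-sample fit. One way to check generalization is a held-out split. The sketch below shows the idea on a small stand-in dataset for the swing stage only; the data-generating values, the `StandardScaler` step, and the 75/25 split are assumptions made for this example, not part of the notebook's pipeline.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import brier_score_loss, log_loss
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Hypothetical stand-in data: velocity and spin drive a binary swing outcome.
n = 5000
velocity = rng.normal(86, 7, size=n)
spin = rng.normal(2600, 180, size=n)
logit = 0.15 * (velocity - 88) - 0.003 * (spin - 2500)
swing = rng.binomial(1, 1 / (1 + np.exp(-logit)))

X = np.column_stack([velocity, spin])
X_train, X_test, y_train, y_test = train_test_split(
    X, swing, test_size=0.25, random_state=0, stratify=swing
)

# Scaling keeps the solver well conditioned; fit on the training rows only.
model = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
model.fit(X_train, y_train)

train_probs = model.predict_proba(X_train)[:, 1]
test_probs = model.predict_proba(X_test)[:, 1]

train_ll = log_loss(y_train, train_probs)
test_ll = log_loss(y_test, test_probs)
test_brier = brier_score_loss(y_test, test_probs)

print(f"train log-loss: {train_ll:.4f}")
print(f"test  log-loss: {test_ll:.4f}")
print(f"test  Brier:    {test_brier:.4f}")
```

A test log-loss close to the training log-loss suggests the stage is not overfitting; a large gap would be a reason to revisit the features.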


## Animation with heatmap of predicted strike probability per pitch type

In [38]:
fig, ax = plt.subplots(figsize=(7, 7))

# Field shape - draw baseball diamond (home plate at bottom center)
def draw_field(ax):
    # Bases
    base_coords = np.array([
        [0.5, 0.1],   # Home plate
        [0.6, 0.2],   # 1st base
        [0.5, 0.3],   # 2nd base
        [0.4, 0.2],   # 3rd base
        [0.5, 0.1]    # Close loop home
    ])
    ax.plot(base_coords[:,0], base_coords[:,1], color='brown', lw=2)
    ax.scatter(base_coords[:-1,0], base_coords[:-1,1], s=100, color='brown')

    # Fielder circles (line only)
    for pos in fielder_positions:
        circle = plt.Circle(pos, catch_radius, fill=False, edgecolor='green', lw=2)
        ax.add_patch(circle)

    ax.set_xlim(0,1)
    ax.set_ylim(0,1)
    ax.set_aspect('equal')
    ax.set_title("Predicted Strike Probability Heatmap with Strike Locations")
    ax.set_xlabel("Field X")
    ax.set_ylabel("Field Y")

# Prepare grid to predict strike probabilities on
grid_res = 50
x_grid = np.linspace(0,1,grid_res)
y_grid = np.linspace(0,1,grid_res)
xx, yy = np.meshgrid(x_grid, y_grid)
grid_points = np.column_stack([xx.ravel(), yy.ravel()])

# We will create heatmaps for each pitch type
pitch_type_indices = {pt: i for i, pt in enumerate(ohe.categories_[0])}


def update(frame_idx):
    ax.clear()
    draw_field(ax)
    pt = pitch_types[frame_idx]
    idx = pitch_type_indices[pt]

    # Construct features for prediction on grid (assume avg velocity, spin)
    avg_vel = velocity_mean[pt]
    avg_spin = spin_mean[pt]
    pitch_onehot = np.zeros((grid_points.shape[0], len(pitch_types)))
    pitch_onehot[:, idx] = 1

    # For grid prediction, assume swing=1, hit_weak=0, in_air=1
    # (the scenario most favorable to a caught ball)
    swing_grid = np.ones((grid_points.shape[0],1))
    hit_weak_grid = np.zeros((grid_points.shape[0],1))
    in_air_grid = np.ones((grid_points.shape[0],1))

    # Distance to closest fielder per grid point
    dist_grid = cdist(grid_points, fielder_positions).min(axis=1).reshape(-1,1)

    # Build strike features for prediction
    strike_feats_grid = np.hstack([
        pitch_onehot,
        np.full((grid_points.shape[0],1), avg_vel),
        np.full((grid_points.shape[0],1), avg_spin),
        swing_grid,
        hit_weak_grid,
        in_air_grid,
        grid_points,
        dist_grid
    ])

    pred_strike_probs = strike_model.predict_proba(strike_feats_grid)[:,1].reshape(grid_res, grid_res)

    # Plot heatmap
    # Fix the color scale to [0, 1] so frames are comparable across pitch types
    im = ax.imshow(pred_strike_probs, origin='lower', extent=(0,1,0,1),
                   cmap='coolwarm', alpha=0.8, vmin=0, vmax=1)

    # Plot actual strikes for current pitch type
    mask = df['pitcher_decision'] == pt
    strikes_x = df.loc[mask & (df['strike']==1), 'x']
    strikes_y = df.loc[mask & (df['strike']==1), 'y']
    ax.scatter(strikes_x, strikes_y, color='black', s=10, label='Actual Strikes')

    ax.legend(loc='upper right')
    ax.set_title(f"Strike Probability Heatmap - Pitch: {pt}")

    return [im]

anim = animation.FuncAnimation(fig, update, frames=len(pitch_types), interval=2000, blit=False, repeat=True)

%matplotlib inline
from IPython.display import HTML

# Render the animation as an interactive HTML/JavaScript widget in the notebook
HTML(anim.to_jshtml())
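
The heatmap code above looks up one-hot column indices from `ohe.categories_` rather than from the order of `pitch_types`. The short sketch below shows why that matters: `OneHotEncoder` orders string categories alphabetically, so the column order differs from the declaration order. It is a standalone illustration with its own encoder, and it converts the default sparse output with `.toarray()` instead of passing `sparse_output=False`.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Declaration order used in the simulation.
pitch_types = ['fastball', 'curveball', 'slider']
sample = np.array(pitch_types).reshape(-1, 1)

encoder = OneHotEncoder()
encoded = encoder.fit_transform(sample).toarray()

# Columns follow the encoder's sorted categories, not the declaration order.
categories = encoder.categories_[0].tolist()
column_of = {pt: i for i, pt in enumerate(categories)}

print(categories)
print(column_of)

# Build the one-hot row for a fastball using the looked-up column.
row = np.zeros(len(categories))
row[column_of['fastball']] = 1
print(row)
```

Indexing with `pitch_types.index(pt)` instead would silently assign a fastball grid to the curveball column.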