# Analysis
Compare the OBP predicted by the computed optimal policy with the empirical (observed) OBP in each at-bat state (ball–strike count), for two pitcher/batter matchup conditions.

In [1]:
from collections import defaultdict

import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

from src.data.data_loading import BaseballData
from src.model.at_bat import AtBatState, AtBatOutcome
from src.model.players import Pitcher, Batter
from src.policy.optimal_policy import calculate_swing_outcome_distribution, calculate_batter_patience_distribution, calculate_pitcher_control_distribution, precalculate_transition_distribution, calculate_optimal_policy

In [2]:
# Load the baseball dataset used by every later cell (the output below shows it was read from a cache).
bd = BaseballData.load_with_cache()

Loading baseball data from cache... done


In [3]:
# Matchup conditions to compare. Each predicate receives a pitch and checks the
# OBP percentiles of that at-bat's pitcher and batter.
# NOTE(review): the labels assume a high pitcher `obp_percentile` marks a strong
# pitcher; confirm the direction against how the percentile is computed.
conditions = {
    'strong pitcher vs. weak batter': lambda p: p.at_bat.pitcher.obp_percentile > 0.75 and p.at_bat.batter.obp_percentile < 1/2,
    'weak pitcher vs. strong batter': lambda p: p.at_bat.pitcher.obp_percentile < 0.25 and p.at_bat.batter.obp_percentile > 1/2,
}

# Every count with 0-3 balls and 0-2 strikes.
states = {AtBatState(balls=balls, strikes=strikes) for balls in range(4) for strikes in range(3)}

# empirical_obp[condition][state]: after the normalisation loop at the end of this cell,
# the fraction of matching pitches thrown in `state` whose at-bat outcome was AtBatOutcome.BASE.
empirical_obp: dict[str, dict[AtBatState, float]] = defaultdict(lambda: defaultdict(float))
# matches[condition][state]: number of matching pitches thrown in `state`.
matches: dict[str, dict[AtBatState, int]] = defaultdict(lambda: defaultdict(int))

# Distinct (pitcher, batter) pairs per condition, plus every pitcher and batter involved;
# these feed the distribution calculations in the following cells.
matchups: dict[str, set[tuple[Pitcher, Batter]]] = defaultdict(set)
pitchers = set()
batters = set()

for pitch in bd.pitches:
    pitch_state = pitch.at_bat_state
    # The count filter does not depend on the condition, so test it once per pitch.
    if pitch_state not in states:
        continue
    for condition_name, condition in conditions.items():
        if not condition(pitch):
            continue
        matches[condition_name][pitch_state] += 1
        empirical_obp[condition_name][pitch_state] += pitch.at_bat.state.outcome == AtBatOutcome.BASE
        matchups[condition_name].add((pitch.at_bat.pitcher, pitch.at_bat.batter))
        pitchers.add(pitch.at_bat.pitcher)
        batters.add(pitch.at_bat.batter)

# Turn the per-state on-base counts into rates (states never observed stay at 0.0).
for condition_name in conditions:
    for state in states:
        if matches[condition_name][state] > 0:
            empirical_obp[condition_name][state] /= matches[condition_name][state]

In [7]:
# NOTE: this call has been observed to cause serious memory pressure; it receives
# every (pitcher, batter) matchup from every condition in a single list.
swing_outcome_distribution = calculate_swing_outcome_distribution([
    pitcher_batter
    for condition_key in conditions
    for pitcher_batter in matchups[condition_key]
])

Calculating swing outcomes:   0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

0.178355

In [None]:
# Control distributions for every pitcher that appeared in a selected matchup.
pitcher_control_distribution = calculate_pitcher_control_distribution([*pitchers])

In [None]:
# Patience distributions for every batter that appeared in a selected matchup.
batter_patience_distribution = calculate_batter_patience_distribution([*batters])

In [None]:
# predicted_obp[condition][state]: mean, over the condition's matchups, of the optimal-policy
# value at `state`.
# NOTE(review): the empirical OBP weights each matchup by how many pitches it contributed,
# whereas this is an unweighted mean over matchups — the two can differ when a few
# matchups dominate the pitch counts.
predicted_obp: dict[str, dict[AtBatState, float]] = defaultdict(lambda: defaultdict(float))
for condition_name in conditions:
    condition_matchups = matchups[condition_name]
    for pitcher, batter in tqdm(condition_matchups, desc=condition_name):
        transition_distribution = precalculate_transition_distribution(
            pitcher,
            batter,
            swing_outcome_distribution[(pitcher, batter)],
            pitcher_control_distribution[pitcher],
            batter_patience_distribution[batter],
        )
        _policy, values = calculate_optimal_policy(pitcher, batter, transition_distribution)

        # The values[AtBatState] corresponds to the OBP at that point
        for state in states:
            predicted_obp[condition_name][state] += values[state]

    # Average the summed values so the result is an OBP comparable to the empirical rate,
    # rather than a sum that grows with the number of matchups.
    if condition_matchups:
        for state in states:
            predicted_obp[condition_name][state] /= len(condition_matchups)

In [None]:
# Order counts as 0-0, 0-1, 0-2, 1-0, ... for the x axis.
sorted_states = sorted(states, key=lambda s: (s.balls, s.strikes))

fig, ax = plt.subplots(figsize=(12, 6))

# Split a fixed group width among the conditions so neighbouring groups never overlap,
# however many conditions are defined (two conditions -> 0.35 each).
group_width = 0.7
bar_width = group_width / max(len(conditions), 1)
index = np.arange(len(sorted_states))

for i, condition in enumerate(conditions):
    empirical_obp_values = [empirical_obp[condition][state] for state in sorted_states]
    predicted_obp_values = [predicted_obp[condition][state] for state in sorted_states]

    bar_position = index + i * bar_width
    # Predicted bars are drawn translucent and hatched on top of the empirical bars at the
    # same position, so each pair can be compared directly.
    ax.bar(bar_position, empirical_obp_values, bar_width, label=f'Empirical {condition}')
    ax.bar(bar_position, predicted_obp_values, bar_width, label=f'Predicted {condition}', alpha=0.5, hatch='\\')

# Adding labels and titles
ax.set_xlabel('At-Bat States (Balls, Strikes)')
ax.set_ylabel('OBP')
ax.set_title('Empirical and Predicted OBP by At-Bat State and Condition')
# Center each tick under its group of bars.
ax.set_xticks(index + bar_width * max(len(conditions) - 1, 0) / 2)
ax.set_xticklabels([f'{state.balls}-{state.strikes}' for state in sorted_states], rotation=45)
ax.set_ylim(0, 1)
ax.legend()

# Display the plot
fig.tight_layout()
plt.show()