<a href="https://colab.research.google.com/github/Blandalytics/baseball_snippets/blob/main/Strikezone_Heatmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [10]:
def zone_framework():
  """Build the 10 x 10 location grid and label every cell with a zone number.

  The grid is the cross product of ``h_zone`` and ``v_zone`` values from -4 to 5.
  The central 6 x 6 cells (-2..3 on both axes) are grouped 2 x 2 into the nine
  in-zone groups 1-9: the number increases with ``h_zone`` first (1, 2, 3) and
  then with ``v_zone`` (1-3, 4-6, 7-9). Every remaining cell belongs to one of
  four out-of-zone quadrants:

    11: h_zone <= 0 and v_zone <= 0      12: h_zone >= 1 and v_zone <= 0
    13: h_zone <= 0 and v_zone >= 1      14: h_zone >= 1 and v_zone >= 1

  NOTE(review): groups 1-3 and 11-12 sit on the *low* v_zone values. Confirm
  that this matches the vertical orientation the consumers of this grid expect.

  Returns:
    DataFrame with 100 rows and the integer columns ``h_zone``, ``v_zone`` and
    ``statcast_zone``, ordered by ``h_zone`` and then ``v_zone``, with a default
    0..99 index.
  """
  # Specify the zone limits (10 zones = -4 to 5)
  zone_start = -4
  zone_cutoff = 5
  zone_values = range(zone_start, zone_cutoff + 1)

  # The 9 in-zone groups cover the inner 6 x 6 cells; the 2 outermost cells on
  # every side are out of zone
  inner_start = zone_start + 2
  inner_cutoff = zone_cutoff - 2

  def classify(h_zone, v_zone):
    """Return the zone number of a single grid cell."""
    if inner_start <= h_zone <= inner_cutoff and inner_start <= v_zone <= inner_cutoff:
      # 2 x 2 cells per group, 3 groups per row of groups
      return 1 + (h_zone - inner_start) // 2 + 3 * ((v_zone - inner_start) // 2)
    # Out of zone: pick the quadrant relative to the middle of the grid
    return 11 + (1 if h_zone >= 1 else 0) + (2 if v_zone >= 1 else 0)

  # Build every record first and create the DataFrame once; DataFrame.append was
  # removed in pandas 2.0 and copied the whole frame on every call
  records = [
      {'h_zone': h_zone, 'v_zone': v_zone, 'statcast_zone': classify(h_zone, v_zone)}
      for h_zone in zone_values
      for v_zone in zone_values
  ]
  return pd.DataFrame(records, columns=['h_zone', 'v_zone', 'statcast_zone'])

In [11]:
zone_framework()

Unnamed: 0,h_zone,v_zone,statcast_zone
0,-4.0,-4.0,11
1,-4.0,-3.0,11
2,-4.0,-2.0,11
3,-4.0,-1.0,11
4,-4.0,0.0,11
...,...,...,...
95,5.0,1.0,14
96,5.0,2.0,14
97,5.0,3.0,14
98,5.0,4.0,14


In [None]:
def prep_zone_locations(df, horizontal_location_column, vertical_location_column, 
                        strikezone_top_column='sz_top', strikezone_bottom_column='sz_bot'):
  """Bin every pitch into the 10 x 10 location grid and attach its zone number.

  Horizontal bins use fixed cut points at multiples of 5/24 (presumably feet
  from the middle of the plate -- TODO confirm units against the data source).
  Bins -4..0 cover the negative side and are closed on their lower edge; bins
  1..5 cover zero and the positive side and are closed on their upper edge.
  -4 is left and out of the strikezone, 5 is right and out of the strikezone
  (catcher's perspective).

  Vertical bins are sized per row: the span between the strike-zone bottom and
  top is split into six equal bins (-2..3), with one extra bin of the same
  height directly below (-3) and above (4) and open-ended bins beyond those
  (-4 below, 5 above). Every vertical bin is closed on its upper edge.

  Rows with a missing location or strike-zone bound cannot be binned and are
  left out of the result, instead of being counted in the centre cell.

  Args:
    df: pitch-level DataFrame. It is not modified.
    horizontal_location_column: name of the horizontal location column.
    vertical_location_column: name of the vertical location column.
    strikezone_top_column: name of the column with the top of the strike zone.
    strikezone_bottom_column: name of the column with the bottom of the strike zone.

  Returns:
    A new DataFrame holding the binnable rows of ``df`` plus the columns
    ``h_zone``, ``v_zone`` and ``statcast_zone`` (from ``zone_framework``).
    The index is reset by the merge.
  """
  px = df[horizontal_location_column]
  pz = df[vertical_location_column]
  sz_top = df[strikezone_top_column]
  sz_bot = df[strikezone_bottom_column]
  sz_height = sz_top - sz_bot

  # Comparisons against NaN are always False, which would silently leave such
  # rows in the default (0, 0) cell, so remember which rows can be binned at all
  binnable = (px.notna() & pz.notna() & sz_top.notna() & sz_bot.notna()).to_numpy()

  ## 10 Horizontal Zone Locations
  # The default of 0 covers [-5/24, 0)
  h_zone = pd.Series(0, index=df.index, dtype='int64')
  horizontal_bins = [
      (-1, (px < -5/24) & (px >= -5/12)),
      (-2, (px < -5/12) & (px >= -15/24)),
      (-3, (px < -15/24) & (px >= -10/12)),
      (-4, (px < -10/12)),
      # Upper edge is inclusive so that exactly 5/24 does not fall between
      # bins 1 and 2 and end up in the default bin
      (1, (px <= 5/24) & (px >= 0)),
      (2, (px > 5/24) & (px <= 5/12)),
      (3, (px > 5/12) & (px <= 15/24)),
      (4, (px > 15/24) & (px <= 10/12)),
      (5, (px > 10/12)),
  ]
  for zone, in_bin in horizontal_bins:
    h_zone = h_zone.mask(in_bin, zone)

  ## 10 Vertical Zone Locations (dynamic, based on batter)
  # The default of 0 covers the bin just below the middle of the strikezone
  v_zone = pd.Series(0, index=df.index, dtype='int64')
  vertical_bins = [
      (-1, (pz <= sz_height/3 + sz_bot) & (pz > sz_height/6 + sz_bot)),
      (-2, (pz <= sz_height/6 + sz_bot) & (pz > sz_bot)),
      (-3, (pz <= sz_bot) & (pz > sz_bot - sz_height/6)),
      (-4, (pz <= sz_bot - sz_height/6)),
      (1, (pz > sz_top - sz_height/2) & (pz <= sz_top - sz_height/3)),
      (2, (pz > sz_top - sz_height/3) & (pz <= sz_top - sz_height/6)),
      (3, (pz > sz_top - sz_height/6) & (pz <= sz_top)),
      (4, (pz > sz_top) & (pz <= sz_top + sz_height/6)),
      (5, (pz > sz_top + sz_height/6)),
  ]
  for zone, in_bin in vertical_bins:
    v_zone = v_zone.mask(in_bin, zone)

  # Attach the bins to a new frame (the caller's frame stays untouched); plain
  # arrays are used so that nothing is re-aligned on a non-unique index
  binned = df.loc[binnable].assign(h_zone=h_zone.to_numpy()[binnable],
                                   v_zone=v_zone.to_numpy()[binnable])

  # Merge Statcast zones back onto original data
  return binned.merge(zone_framework(),how='inner',on=['h_zone','v_zone'])

In [None]:
# Heatmap of where a batter gains/loses value for each pitch type seen
def strikezone_heatmap(df, player_name_lookup, ax, target_stat, 
                       name_lookup_column='batter_name', pitch=''):
  """Draw one player's mean ``target_stat`` per zone as a 10 x 10 heatmap on ``ax``.

  Args:
    df: pitch-level DataFrame. It must contain ``name_lookup_column``,
      ``target_stat``, 'pitch_name', 'zone' and 'statcast_zone' (the latter as
      added by ``prep_zone_locations``).
    player_name_lookup: pattern searched for in ``name_lookup_column`` (passed
      to ``Series.str.contains``, so it is treated as a regular expression).
      The first matching name is used.
    ax: matplotlib Axes to draw on.
    target_stat: name of the column that is averaged per zone.
    name_lookup_column: name of the column holding the player names.
    pitch: a single pitch name or a list-like of pitch names to restrict the
      player's data to. ``''`` (default) or ``None`` keeps every pitch.

  Raises:
    IndexError: if no name in ``name_lookup_column`` matches
      ``player_name_lookup``.
  """
  # Find first instance of player with player_name_lookup string in name
  # (na=False makes missing names count as "no match")
  matching_names = (
      df
      .loc[df[name_lookup_column].str.contains(player_name_lookup, na=False),
           name_lookup_column]
      .unique()
      )
  if len(matching_names) == 0:
    raise IndexError(
        f"No value in '{name_lookup_column}' matches '{player_name_lookup}'")
  player_name = matching_names[0]
  
  # Restrict to a specific pitch type, if provided
  if pitch is None or (isinstance(pitch, str) and pitch == ''):
    heatmap_data = df
  else:
    # isin() rejects a bare string, so wrap a single pitch name in a list
    pitch_names = [pitch] if isinstance(pitch, str) else list(pitch)
    heatmap_data = df.loc[df['pitch_name'].isin(pitch_names)]
  
  # The color scale is shared by all players and pitch types: it is based on
  # the per-player means of the in-zone groups (1-9) of the full data
  colorscale_df = (
      df.loc[df['zone'].isin(range(1,10))]
      .groupby([name_lookup_column,'zone'])[target_stat].mean()
  )
  
  # Specify colorbar thresholds (symmetric around 0; abs() keeps vmin below
  # vmax even if the 90th percentile happens to be negative)
  vmax = abs(colorscale_df.quantile(0.9))
  vmin = -vmax
  vcenter = 0
  
  # Apply batter's values to zone framework
  heatmap_df = (
      heatmap_data.loc[heatmap_data[name_lookup_column]==player_name] # Filter batter
      .groupby('statcast_zone', as_index=False)[target_stat].mean() # Find mean of each zone
      .merge(zone_framework(),how='right',on='statcast_zone') # Apply to zone framework
      )

  # Generate the heatmap on the requested axes. pivot() takes keyword-only
  # arguments since pandas 2.0.
  # NOTE(review): seaborn draws the first pivot row at the top, so the lowest
  # v_zone ends up at the top of the image -- confirm this is the intended
  # orientation.
  sns.heatmap(heatmap_df.pivot(index='v_zone', columns='h_zone', values=target_stat),
              vmin=vmin, 
              center=vcenter,
              vmax=vmax,
              cbar=False,
              cmap='vlag',
              ax=ax)
  
  blank_color = '#333333' # Filler color for locations with no data
  ax.set_facecolor(blank_color)

  ## Add strikezone lines
  # Within-strikezone Grid
  ax.axhline(4, xmin=0.2, xmax=0.8, color=blank_color, linewidth=2)
  ax.axhline(6, xmin=0.2, xmax=0.8, color=blank_color, linewidth=2)
  ax.axvline(4, ymin=0.2, ymax=0.8, color=blank_color, linewidth=2)
  ax.axvline(6, ymin=0.2, ymax=0.8, color=blank_color, linewidth=2)

  # Out-of-strikezone Grid
  ax.axhline(5, xmin=0.8, xmax=1, color=blank_color, linewidth=2)
  ax.axhline(5, xmin=0, xmax=0.2, color=blank_color, linewidth=2)
  ax.axvline(5, ymin=0.8, ymax=1, color=blank_color, linewidth=2)
  ax.axvline(5, ymin=0, ymax=0.2, color=blank_color, linewidth=2)

  # Strikezone
  ax.axhline(2, xmin=0.2, xmax=0.8, color='black', linewidth=2*2)
  ax.axhline(8, xmin=0.2, xmax=0.8, color='black', linewidth=2*2)
  ax.axvline(2, ymin=0.2, ymax=0.8, color='black', linewidth=2*2)
  ax.axvline(8, ymin=0.2, ymax=0.8, color='black', linewidth=2*2)

  ax.set(xlabel=None, ylabel=None)
  ax.set_xticklabels([])
  ax.set_yticklabels([])
  ax.tick_params(left=False, bottom=False)

  # Only touch the spines of the axes that was drawn on
  sns.despine(ax=ax, left=False, bottom=False)