In [18]:
import pandas as pd
import numpy as np

import importlib
import models.play_success_bayes as psb
importlib.reload(psb)
import tools
import tools.tracking_utils as tracking_utils
import tools.play_viz as play_viz
import models

importlib.reload(tools)
importlib.reload(models)
importlib.reload(tracking_utils)
importlib.reload(play_viz)
import models.pipeline as psdata
importlib.reload(psdata)
from models import BayesianPlaySuccessModel
from tools import get_din_dout, animate_week_play, visualize_predictions
from tools.tracking_utils import frames_from_input, frames_from_output_merged
from models.pipeline import build_play_frame_features, prob_by_step_for_play, build_prefix_training_data
from matplotlib import rc
rc("animation", html="jshtml")

def get_game_play_index(input_path):
    df_ids = pd.read_csv(input_path, usecols=["game_id", "play_id"])
    df_ids = df_ids.drop_duplicates().sort_values(["game_id", "play_id"])
    return df_ids

In [19]:
THRESHOLD = 2.0
n_weeks = 1

feature_cols = [
    "abs_yardline_at_throw",
    "dist_target_to_land",
    "num_defenders_close",
    "breakaway",
    "tackle_range",
    "red_zone",
    "ball_x",
    "ball_y",
    "dist_to_nearest_defender",
    "dist_to_ball_land_position",
    "dist_to_bounds",
]

ps_model = BayesianPlaySuccessModel(feature_cols=feature_cols)

train_prefix_df = build_prefix_training_data(
    weeks=n_weeks,
    model=ps_model,
    threshold=THRESHOLD,
)

print(train_prefix_df.columns)
print(train_prefix_df.head())


Weeks:   0%|          | 0/1 [00:00<?, ?it/s]

Plays (week 1):   0%|          | 0/819 [00:00<?, ?it/s]

Index(['game_id', 'play_id', 'frame_id', 'week', 'abs_yardline_at_throw',
       'dist_target_to_land', 'num_defenders_close', 'breakaway',
       'tackle_range', 'red_zone', 'ball_x', 'ball_y',
       'dist_to_nearest_defender', 'dist_to_ball_land_position',
       'dist_to_bounds', 'play_success'],
      dtype='object')
      game_id  play_id  frame_id  week  abs_yardline_at_throw  \
0  2023090700      101         1     1                   42.0   
1  2023090700      101         2     1                   42.0   
2  2023090700      101         3     1                   42.0   
3  2023090700      101         4     1                   42.0   
4  2023090700      101         5     1                   42.0   

   dist_target_to_land  num_defenders_close  breakaway  tackle_range  \
0            17.402401                  0.0        1.0           0.0   
1            16.807438                  0.0        1.0           0.0   
2            16.205282                  0.0        1.0           0.0 

In [20]:
ps_model.fit(train_prefix_df)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [beta0, beta]


Output()

Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 169 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


In [44]:
week = 2
index_df = get_game_play_index(f"train/input_2023_w{week:02d}.csv")
idx = 4
gi = index_df.iloc[idx]["game_id"]
pi = index_df.iloc[idx]["play_id"]

# build d_play with frame_id
d_in, d_out = get_din_dout(week, gi, pi)
frames_pre = frames_from_input(d_in)
frames_post = frames_from_output_merged(d_in, d_out)

max_pre_frame = max(frames_pre.keys())
frames_post_shifted = {
    f + max_pre_frame: df.assign(frame_id=f + max_pre_frame)
    for f, df in frames_post.items()
}
all_frames = {**frames_pre, **frames_post_shifted}

d_play = (
    pd.concat(
        [df.assign(frame_id=f) for f, df in all_frames.items()],
        ignore_index=True,
    )
    .sort_values("frame_id")
)

p_by_frame = ps_model.frame_prob_dict(d_play, debug=True)

37 frames


In [45]:
train_prefix_df["play_success"].mean()
train_prefix_df[ps_model.feature_cols].corrwith(train_prefix_df["play_success"])
train_prefix_df[ps_model.feature_cols].describe()

  c /= stddev[:, None]
  c /= stddev[None, :]


Unnamed: 0,abs_yardline_at_throw,dist_target_to_land,num_defenders_close,breakaway,tackle_range,red_zone,ball_x,ball_y,dist_to_nearest_defender,dist_to_ball_land_position,dist_to_bounds
count,23239.0,23239.0,23239.0,23239.0,23239.0,23239.0,23239.0,23239.0,23239.0,23239.0,23239.0
mean,60.122811,7.711943,0.40815,0.505099,0.254443,0.0,59.909287,26.535193,5.404646,7.711943,15.749664
std,23.512641,6.101623,0.614948,0.499985,0.435557,0.0,26.34746,10.613371,7.325021,6.101623,6.092711
min,11.0,0.019998,0.0,0.0,0.0,0.0,1.33,-1.69,0.02,0.019998,0.39
25%,40.0,2.935967,0.0,0.0,0.0,0.0,38.459999,23.36,1.968324,2.935967,11.35
50%,59.0,6.340545,0.0,1.0,0.0,0.0,59.31,26.29,3.757047,6.340545,16.12
75%,79.0,10.723458,1.0,1.0,1.0,0.0,81.81,30.1,6.285833,10.723458,20.5
max,109.0,43.854711,4.0,1.0,1.0,0.0,119.779999,57.330002,50.0,43.854711,26.65


In [46]:

ani = visualize_predictions(
    model=None,
    week=week,
    game_id=gi,
    play_id=pi,
    horizon=0,
    show_paths=False,
    show_cones=False,
    p_by_frame=p_by_frame,
)
ani