In [22]:
import polars as pl
pl.Config.set_tbl_rows(50)
pl.Config.set_tbl_cols(-1)

# results_path = "../../data/models/transformer_no_play/H32_L1/model_preds.parquet"
results_path = "../../data/models/zoo_no_play/H32_L1/model_preds.parquet"

results_df = (
    pl.read_parquet(results_path)
    .with_columns(
        frame_distance_from_tackle = (pl.col("tackle_frameId") - pl.col("frameId")).cut(range(-10, 51, 5)),
    )
)

results_df.sample(3)

gameId,playId,mirrored,frameId,tackle_x_pred,tackle_y_pred,dataset_split,ballCarrierNflId,ballCarrierName,tackle_frameId,tackle_event,tackle_x,tackle_y,tackle_event_enum,frame_distance_from_tackle
i64,i64,bool,i64,f32,f32,str,i64,str,i64,str,f64,f64,i64,cat
2022091800,1925,True,34,80.550003,41.849998,"""val""",46155,"""Mark Andrews""",36,"""tackle""",97.6,33.83,0,"""(0, 5]"""
2022110608,3372,True,27,107.949997,32.509998,"""train""",54506,"""Kenneth Walker""",30,"""touchdown""",109.96,25.48,2,"""(0, 5]"""
2022102308,478,False,32,53.689999,34.080002,"""train""",54572,"""Dameon Pierce""",48,"""tackle""",84.85,28.77,0,"""(15, 20]"""


In [23]:
# from sklearn.metrics import mean_squared_error
from torch.nn.functional import mse_loss
from torch import tensor

mse_df = (
    results_df
    .group_by(["dataset_split"])
    .agg(
        mse_loss = pl.map_groups(
                exprs=["tackle_x_pred", "tackle_y_pred", "tackle_x", "tackle_y"],
                function=lambda list_of_series: mse_loss(tensor([list_of_series[0], list_of_series[1]]), tensor([list_of_series[2], list_of_series[3]])).item(),
                returns_scalar=True,
            )
    )
)

mse_df

dataset_split,mse_loss
str,f64
"""val""",244.68782
"""test""",238.521866
"""train""",237.844406


In [24]:
mse_df = (
    results_df
    .group_by(["dataset_split", "frame_distance_from_tackle"])
    .agg(
        mse_loss = pl.map_groups(
                exprs=["tackle_x_pred", "tackle_y_pred", "tackle_x", "tackle_y"],
                function=lambda list_of_series: mse_loss(tensor([list_of_series[0], list_of_series[1]]), tensor([list_of_series[2], list_of_series[3]])).item(),
                returns_scalar=True,
            ).round(1)
    ).pivot(values="mse_loss", columns="dataset_split", index="frame_distance_from_tackle").sort("frame_distance_from_tackle")
)
mse_df

frame_distance_from_tackle,val,train,test
cat,f64,f64,f64
"""(-5, 0]""",235.4,219.0,224.4
"""(0, 5]""",229.0,211.7,218.2
"""(5, 10]""",223.1,208.8,210.8
"""(10, 15]""",223.8,211.2,213.7
"""(15, 20]""",227.6,215.0,216.2
"""(20, 25]""",230.5,220.6,218.8
"""(25, 30]""",233.8,228.2,229.9
"""(30, 35]""",245.5,240.1,246.9
"""(35, 40]""",260.7,257.2,259.8
"""(40, 45]""",276.4,279.6,280.4


In [42]:
y_df = pl.read_parquet("../../data/split_prepped_data/test_targets.parquet")

y_df.sample(10)

gameId,playId,mirrored,ballCarrierNflId,ballCarrierName,tackle_frameId,tackle_event,tackle_x,tackle_y,tackle_x_rel,tackle_y_rel,tackle_event_enum
i64,i64,bool,i64,str,i64,str,f64,f64,f64,f64,i64
2022110604,245,True,53454,"""Travis Etienne""",45,"""tackle""",80.85,19.88,9.73,-9.39,0
2022110604,2670,False,53454,"""Travis Etienne""",55,"""tackle""",106.05,26.22,7.07,4.82,0
2022101610,1238,False,47857,"""Devin Singletary""",47,"""tackle""",41.88,26.71,12.5,-5.6,0
2022091100,679,True,48374,"""Olamide Zaccheaus""",13,"""fumble""",56.1,30.24,2.03,-7.71,4
2022090800,593,False,42448,"""Jamison Crowder""",29,"""tackle""",46.29,47.47,12.83,8.44,0
2022102310,335,True,47839,"""Mecole Hardman""",27,"""tackle""",38.66,7.77,5.51,-7.54,0
2022103002,1456,False,43424,"""Dak Prescott""",44,"""tackle""",98.03,33.37,7.33,10.25,0
2022102310,3586,True,47839,"""Mecole Hardman""",37,"""touchdown""",108.91,12.2,4.77,-21.28,2
2022103011,2104,True,45186,"""Matt Breida""",59,"""tackle""",36.13,11.56,8.19,-18.31,0
2022100209,2315,True,53579,"""Kenneth Gainwell""",42,"""tackle""",39.61,36.19,9.41,6.95,0


In [50]:
X_df['x'].plot.hist()

In [49]:
X_df['x_rel'].plot.hist()

In [46]:
play_df = pl.read_csv("../../data/bdb_2024/plays.csv", null_values=["", "NA"])

play_df.filter((pl.col("gameId") == 2022110604) & (pl.col("playId") == 2670))

gameId,playId,ballCarrierId,ballCarrierDisplayName,playDescription,quarter,down,yardsToGo,possessionTeam,defensiveTeam,yardlineSide,yardlineNumber,gameClock,preSnapHomeScore,preSnapVisitorScore,passResult,passLength,penaltyYards,prePenaltyPlayResult,playResult,playNullifiedByPenalty,absoluteYardlineNumber,offenseFormation,defendersInTheBox,passProbability,preSnapHomeTeamWinProbability,preSnapVisitorTeamWinProbability,homeTeamWinProbabilityAdded,visitorTeamWinProbilityAdded,expectedPoints,expectedPointsAdded,foulName1,foulName2,foulNFLId1,foulNFLId2
i64,i64,i64,str,str,i64,i64,i64,str,str,str,i64,str,i64,i64,str,i64,i64,i64,i64,str,i64,str,i64,f64,f64,f64,f64,f64,f64,f64,str,str,i64,str
2022110604,2670,53454,"""Travis Etienne""","""(:42) (Shotgun) T.Etienne up t…",3,2,2,"""JAX""","""LV""","""LV""",6,"""0:42""",17,20,,,,1,1,"""N""",16,"""SHOTGUN""",6,0.219358,0.562637,0.437363,-0.018162,0.018162,5.641214,-0.351027,,,,


In [43]:
X_df = pl.read_parquet("../../data/split_prepped_data/test_features.parquet")

# tdf = X_df.filter((pl.col("gameId") == 2022110609) & (pl.col("playId") == 3817) & (pl.col("mirrored") == False)).sort(["frameId", "nflId"])
tdf = X_df.filter((pl.col("gameId") == 2022110604) & (pl.col("playId") == 2670) & (pl.col("mirrored") == False)).sort(["frameId", "nflId"])

tdf.filter(pl.col("is_ball_carrier") == 1)

gameId,playId,nflId,displayName,frameId,time,jerseyNumber,club,x,y,s,a,dis,o,dir,event,down,yardsToGo,absoluteYardlineNumber,weight,height_inches,is_ball_carrier,side,sx,sy,ox,oy,mirrored,play_origin_x,play_origin_y,x_rel,y_rel
i64,i64,i64,str,i64,str,i64,str,f64,f64,f64,f64,f64,f64,f64,str,i64,i64,i64,i64,i64,i64,i32,f64,f64,f64,f64,bool,f64,f64,f64,f64
2022110604,2670,53454,"""Travis Etienne""",1,"""2022-11-06 15:08:50.200000""",1,"""JAX""",98.98,21.4,0.04,0.04,0.01,170.49,242.43,,2,2,16,200,70,1,1,0.018513,0.035458,0.986257,-0.16522,false,98.98,21.4,0.0,0.0
2022110604,2670,53454,"""Travis Etienne""",2,"""2022-11-06 15:08:50.299999""",1,"""JAX""",98.99,21.39,0.03,0.03,0.01,179.56,230.42,,2,2,16,200,70,1,1,0.019115,0.023122,0.999971,-0.007679,false,98.98,21.4,0.01,-0.01
2022110604,2670,53454,"""Travis Etienne""",3,"""2022-11-06 15:08:50.400000""",1,"""JAX""",98.99,21.4,0.03,0.03,0.01,182.42,234.12,,2,2,16,200,70,1,1,0.017583,0.024307,0.999108,0.042224,false,98.98,21.4,0.01,0.0
2022110604,2670,53454,"""Travis Etienne""",4,"""2022-11-06 15:08:50.500000""",1,"""JAX""",98.99,21.42,0.04,0.04,0.02,181.25,237.99,,2,2,16,200,70,1,1,0.021203,0.033918,0.999762,0.021815,false,98.98,21.4,0.01,0.02
2022110604,2670,53454,"""Travis Etienne""",5,"""2022-11-06 15:08:50.599999""",1,"""JAX""",99.0,21.44,0.04,0.04,0.03,184.81,240.74,,2,2,16,200,70,1,1,0.019551,0.034896,0.996478,0.083852,false,98.98,21.4,0.02,0.04
2022110604,2670,53454,"""Travis Etienne""",6,"""2022-11-06 15:08:50.700000""",1,"""JAX""",99.02,21.44,0.04,0.04,0.02,194.5,232.14,"""ball_snap""",2,2,16,200,70,1,1,0.024549,0.031581,0.968148,0.25038,false,98.98,21.4,0.04,0.04
2022110604,2670,53454,"""Travis Etienne""",7,"""2022-11-06 15:08:50.799999""",1,"""JAX""",99.03,21.46,0.04,0.04,0.02,197.75,234.78,,2,2,16,200,70,1,1,0.023069,0.032678,0.952396,0.304864,false,98.98,21.4,0.05,0.06
2022110604,2670,53454,"""Travis Etienne""",8,"""2022-11-06 15:08:50.900000""",1,"""JAX""",99.04,21.48,0.05,0.23,0.02,198.64,235.77,,2,2,16,200,70,1,1,0.028126,0.041339,0.947546,0.319621,false,98.98,21.4,0.06,0.08
2022110604,2670,53454,"""Travis Etienne""",9,"""2022-11-06 15:08:51.000000""",1,"""JAX""",99.06,21.52,0.23,1.86,0.04,215.29,235.1,,2,2,16,200,70,1,1,0.131594,0.188635,0.816238,0.577715,false,98.98,21.4,0.08,0.12
2022110604,2670,53454,"""Travis Etienne""",10,"""2022-11-06 15:08:51.099999""",1,"""JAX""",99.1,21.56,0.66,4.0,0.06,223.91,234.03,,2,2,16,200,70,1,1,0.387659,0.534154,0.72043,0.693528,false,98.98,21.4,0.12,0.16
