In [35]:
import polars as pl
# Table display settings: print up to 50 rows and, via the negative value,
# ask polars to show every column instead of truncating wide frames.
pl.Config.set_tbl_rows(50).set_tbl_cols(-1)

polars.config.Config

In [36]:
# Stack the per-frame predictions of both models, then bucket every frame by
# (tackle_frameId - frameId) into 5-frame, left-closed bins.
results_df = (
    pl.concat(
        [
            pl.read_parquet("../data/best_models/zoo/best_model_results.parquet"),
            pl.read_parquet("../data/best_models/transformer/best_model_results.parquet"),
        ]
    )
    .with_columns(
        (pl.col("tackle_frameId") - pl.col("frameId"))
        .cut(range(-10, 51, 5), left_closed=True)
        .alias("frame_distance_from_tackle")
    )
)

results_df.sample(10)

gameId,playId,mirrored,ballCarrierNflId,ballCarrierName,tackle_frameId,tackle_event,tackle_x,tackle_y,tackle_x_rel,tackle_y_rel,play_origin_x,play_origin_y,playResult,tackle_event_enum,frameId,dataset_split,tackle_x_rel_pred,tackle_y_rel_pred,tackle_x_pred,tackle_y_pred,model_type,used_play_features,batch_size,hidden_dim,num_layers,dropout,learning_rate,frame_distance_from_tackle
i64,i64,bool,i64,str,i64,str,f64,f64,f64,f64,f64,f64,i64,i64,i64,str,f32,f32,f64,f64,str,bool,i32,i32,i32,f64,f64,cat
2022092509,3667,False,52414,"""Justin Herbert""",26,"""out_of_bounds""",44.56,0.28,13.72,-14.88,30.84,15.16,7,1,17,"""test""",13.6,-17.4,44.4,-2.2,"""zoo""",False,32,128,2,0.3,0.0001,"""[5, 10)"""
2022091102,2556,False,47856,"""David Montgomery""",70,"""tackle""",41.58,32.07,2.56,6.16,39.02,25.91,15,0,15,"""train""",10.2,2.9,49.2,28.8,"""transformer""",False,32,256,4,0.3,0.0001,"""[50, inf)"""
2022091900,1292,True,47879,"""Dawson Knox""",25,"""tackle""",64.03,39.0,3.94,1.77,60.09,37.23,9,0,15,"""test""",2.4,1.9,62.5,39.1,"""transformer""",False,32,256,4,0.3,0.0001,"""[10, 15)"""
2022091103,2611,True,44860,"""Joe Mixon""",55,"""tackle""",21.83,11.74,10.18,-12.22,11.65,23.96,3,0,51,"""train""",11.3,-13.1,22.9,10.9,"""transformer""",False,32,256,4,0.3,0.0001,"""[0, 5)"""
2022091112,3454,True,44853,"""Dalvin Cook""",52,"""tackle""",49.62,37.5,13.75,14.37,35.87,23.13,6,0,25,"""val""",16.700001,25.0,52.5,48.1,"""zoo""",False,32,128,2,0.3,0.0001,"""[25, 30)"""
2022091200,2667,True,42358,"""Melvin Gordon""",40,"""tackle""",106.2,28.05,7.63,4.58,98.57,23.47,0,0,29,"""train""",8.1,3.0,106.6,26.4,"""transformer""",False,32,256,4,0.3,0.0001,"""[10, 15)"""
2022100900,1252,False,54499,"""Christian Watson""",38,"""tackle""",91.08,5.52,4.18,-27.38,86.9,32.9,1,0,17,"""train""",7.2,-27.299999,94.1,5.6,"""zoo""",False,32,128,2,0.3,0.0001,"""[20, 25)"""
2022100906,1877,False,54476,"""Chris Olave""",13,"""tackle""",68.61,7.01,-1.52,0.74,70.13,6.27,9,0,13,"""train""",-1.7,0.8,68.4,7.1,"""transformer""",False,32,256,4,0.3,0.0001,"""[0, 5)"""
2022092600,2977,True,47911,"""Tony Pollard""",51,"""out_of_bounds""",23.95,52.52,6.79,37.5,17.16,15.02,4,1,39,"""train""",9.4,36.0,26.5,51.0,"""transformer""",False,32,256,4,0.3,0.0001,"""[10, 15)"""
2022110609,2922,True,47853,"""Darrell Henderson""",36,"""tackle""",100.09,24.6,5.35,3.68,94.74,20.92,0,0,14,"""train""",8.8,1.0,103.5,21.9,"""transformer""",False,32,256,4,0.3,0.0001,"""[20, 25)"""


In [37]:
# from sklearn.metrics import mean_squared_error
# from torch.nn.functional import mse_loss
# from torch import tensor
import numpy as np


def calculate_mse(x: pl.Series, y: pl.Series, xhat: pl.Series, yhat: pl.Series):
    """
    Mean squared error of a two-coordinate prediction.

    The squared x-error and squared y-error are averaged for each row, and
    those per-row values are then averaged over all rows.

    Args:
        x, y: true coordinates.
        xhat, yhat: predicted coordinates, aligned row-by-row with x and y.

    Returns:
        The error as a single numpy scalar.
    """
    residuals = np.stack([xhat.to_numpy() - x.to_numpy(), yhat.to_numpy() - y.to_numpy()])
    per_row = np.square(residuals).mean(axis=0)
    return per_row.mean()

def calculate_mae(x: pl.Series, y: pl.Series, xhat: pl.Series, yhat: pl.Series):
    """
    Mean absolute error of a two-coordinate prediction.

    The absolute x-error and absolute y-error are averaged for each row, and
    those per-row values are then averaged over all rows.

    Args:
        x, y: true coordinates.
        xhat, yhat: predicted coordinates, aligned row-by-row with x and y.

    Returns:
        The error as a single numpy scalar.
    """
    residuals = np.stack([xhat.to_numpy() - x.to_numpy(), yhat.to_numpy() - y.to_numpy()])
    per_row = np.abs(residuals).mean(axis=0)
    return per_row.mean()
    

# Error per (dataset split, model), pivoted to one column per model, plus the
# transformer's relative improvement over the zoo model in percent.
# NOTE: the scores are rounded to one decimal before the percentage is derived,
# so trfm_perc_adv is computed from the rounded values.
(
    results_df
    .group_by(["dataset_split", "model_type"], maintain_order=True)
    .agg(
        pl.map_groups(
            exprs=["tackle_x", "tackle_y", "tackle_x_pred", "tackle_y_pred"],
            # swap calculate_mae for calculate_mse to report squared error instead
            function=lambda cols: calculate_mae(*cols),
            returns_scalar=True,
        )
        .round(1)
        .alias("score")
    )
    .pivot(values="score", columns=["model_type"], index="dataset_split")
    .with_columns(
        ((pl.col("zoo") - pl.col("transformer")) * 100 / pl.col("zoo"))
        .round(1)
        .alias("trfm_perc_adv")
    )
)

dataset_split,zoo,transformer,trfm_perc_adv
str,f64,f64,f64
"""train""",4.1,3.2,22.0
"""val""",4.3,3.4,20.9
"""test""",3.5,2.9,17.1


In [38]:
# Test-split error broken down by the frame-distance-to-tackle bin, one column
# per model, plus the transformer's relative improvement in percent.
# NOTE: n_frames / n_plays are aggregated but do not survive the pivot, which
# keeps only the index column and the pivoted score columns.
test_loss_by_frame_df = (
    results_df
    .filter(pl.col("dataset_split") == "test")
    .group_by(["model_type", "frame_distance_from_tackle"])
    .agg(
        pl.len().alias("n_frames"),
        pl.struct(["gameId", "playId"]).n_unique().alias("n_plays"),
        pl.map_groups(
            exprs=["tackle_x", "tackle_y", "tackle_x_pred", "tackle_y_pred"],
            # swap calculate_mae for calculate_mse to report squared error instead
            function=lambda cols: calculate_mae(*cols),
            returns_scalar=True,
        )
        .round(1)
        .alias("score"),
    )
    .sort("frame_distance_from_tackle")
    .pivot(values="score", columns=["model_type"], index="frame_distance_from_tackle")
    .with_columns(
        ((pl.col("zoo") - pl.col("transformer")) * 100 / pl.col("zoo"))
        .round(1)
        .alias("trfm_perc_adv")
    )
)

test_loss_by_frame_df

frame_distance_from_tackle,zoo,transformer,trfm_perc_adv
cat,f64,f64,f64
"""[0, 5)""",2.5,1.1,56.0
"""[5, 10)""",2.2,1.0,54.5
"""[10, 15)""",2.1,1.4,33.3
"""[15, 20)""",2.3,1.9,17.4
"""[20, 25)""",2.8,2.5,10.7
"""[25, 30)""",3.2,3.0,6.3
"""[30, 35)""",3.7,3.5,5.4
"""[35, 40)""",4.4,4.3,2.3
"""[40, 45)""",5.6,5.4,3.6
"""[45, 50)""",6.9,6.6,4.3


In [33]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Per-frame player tracking features; the glob picks up every *_features file.
tracking_df = pl.read_parquet("../data/split_prepped_data/*_features.parquet")

# Play-level metadata. distanceToGoal is 100 - yardlineNumber when the
# possession team is on its own side of the field (possessionTeam equals
# yardlineSide), and yardlineNumber otherwise.
play_df = pl.read_csv(
    "../data/bdb_2024/plays.csv",
    null_values=["", "NA", "na", "nan", "NaN", "NAN"],
).with_columns(
    pl.when(pl.col("possessionTeam") == pl.col("yardlineSide"))
    .then(100 - pl.col("yardlineNumber"))
    .otherwise(pl.col("yardlineNumber"))
    .alias("distanceToGoal")
)
def animate_play(tracking_df: pl.DataFrame, play_df: pl.DataFrame, results_df: pl.DataFrame, gameId: int, playId: int, mirrored: bool = False) -> go.Figure:
    """
    Animate a play from the tracking data and the results of the models.

    The field is drawn vertically: the tracking `y` column is plotted on the
    horizontal axis and the tracking `x` column on the vertical axis. Players
    are animated frame by frame together with each model's predicted tackle
    location, and the true tackle location is drawn as a static marker.

    Args:
        tracking_df: Per-frame tracking rows. Filtered on gameId, playId and
            mirrored; the function also reads side, club, x, y, frameId, nflId,
            displayName, jerseyNumber, is_ball_carrier, vx and vy.
        play_df: Play-level rows. Filtered on gameId and playId; the function
            reads distanceToGoal, down, yardsToGo and playDescription.
        results_df: Per-frame model predictions. Filtered on gameId, playId and
            mirrored; the function reads ballCarrierName, tackle_x, tackle_y,
            model_type, frameId, tackle_x_pred and tackle_y_pred.
        gameId: Game identifier of the play to animate.
        playId: Play identifier of the play to animate.
        mirrored: Which value of the `mirrored` column to select in
            tracking_df and results_df.

    Returns:
        The plotly figure built by px.scatter, with the field markings,
        annotations and title added.

    Raises:
        AssertionError: If any of the three filtered frames has no rows.
    """
    # Narrow every input to the requested play and switch to pandas for plotly express.
    mvmt_df = tracking_df.filter((pl.col("gameId") == gameId) & (pl.col("playId") == playId) & (pl.col("mirrored") == mirrored)).to_pandas().round(2)
    play_df = play_df.filter((pl.col("gameId") == gameId) & (pl.col("playId") == playId)).to_pandas()
    model_results_df = results_df.filter((pl.col("gameId") == gameId) & (pl.col("playId") == playId) & (pl.col("mirrored") == mirrored)).to_pandas()
    assert len(mvmt_df) > 0
    assert len(play_df) > 0
    assert len(model_results_df) > 0

    # side is stored as 1 / -1; relabel so it matches the keys of color_discrete_map below
    mvmt_df["side"] = mvmt_df["side"].replace({1: "OFF", -1: "DEF"})
    # display(mvmt_df.sample(3), play_df, model_results_df)

    # get some info
    distToGoal = play_df["distanceToGoal"].values[0]
    down = play_df["down"].values[0]
    yards_to_go = play_df["yardsToGo"].values[0]
    play_description = play_df["playDescription"].values[0]
    ballCarrierName = model_results_df["ballCarrierName"].values[0]
    off_club = mvmt_df.loc[mvmt_df["side"] == "OFF", "club"].values[0]
    def_club = mvmt_df.loc[mvmt_df["side"] == "DEF", "club"].values[0]

    # Named "y" because the tracking x column is plotted on the vertical axis.
    mvmt_y_min = mvmt_df["x"].min()
    mvmt_y_max = mvmt_df["x"].max()

    # True tackle location, taken from the first matching results row.
    tkl_x, tkl_y = model_results_df[["tackle_x", "tackle_y"]].values[0]
    # Get prediction data for animation
    # The prediction rows are reshaped to share column names with the tracking
    # rows so both can be concatenated and animated in a single px.scatter call.
    tkl_mvmt_df = model_results_df.sort_values("frameId")[["model_type", "frameId", "tackle_x_pred", "tackle_y_pred"]]
    tkl_mvmt_df = tkl_mvmt_df.rename(columns={"tackle_x_pred": "x", "tackle_y_pred": "y"})
    tkl_mvmt_df['displayName'] = tkl_mvmt_df['model_type']
    # Synthetic negative ids give each model its own animation_group value.
    tkl_mvmt_df['nflId'] = tkl_mvmt_df['model_type'].map({"zoo": -10, "transformer": -20})
    tkl_mvmt_df["club"] = tkl_mvmt_df['model_type']
    tkl_mvmt_df["side"] = tkl_mvmt_df['model_type']
    tkl_mvmt_df['symbol'] = tkl_mvmt_df['model_type']
    tkl_mvmt_df["size"] = 1


    # set some things
    # colors_df = pl.read_csv("../data/team_colors.csv")
    # club_color_map = dict(colors_df.select(["club", "secondaryCol"]).rows())
    # Players get size 2 versus 1 for the prediction markers (scaled by size_max below).
    mvmt_df["size"] = 2
    mvmt_df["text_color"] = "black"

    # Different symbols for different positions
    mvmt_df["symbol"] = "player"
    mvmt_df.loc[mvmt_df["is_ball_carrier"] == 1, "symbol"] = "ball_carrier"
    # mvnt.loc[mvnt["side"] == "BALL", "symbol"] = "diamond_0"
    symbol_map = {"player": "circle", "ball_carrier": "hexagon", "zoo": "x", "transformer": "x"}

    # Data to display on hover
    hover_data = {
        "displayName": True,
        "club": False,
        "side": False,
        "jerseyNumber": False,
        "is_ball_carrier": True,
        "symbol": False,
        "frameId": False,
        "x": True,
        "y": True,
        "vx": True,
        "vy": True,
        "size": False,
    }

    # Field extents in plot coordinates (horizontal = tracking y, vertical = tracking x).
    # NOTE(review): 160 / 3 and 120 look like the field width and the field
    # length including end zones, in yards - confirm against the data spec.
    X_LEFT = 0
    X_MIDDLE = 160 / 6.0
    X_RIGHT = 160 / 3.0
    Y_MIN = 0
    Y_MAX = 120
    Y_MIDDLE = 60
    fig = px.scatter(
        pd.concat([mvmt_df, tkl_mvmt_df]),
        x="y",
        y="x",
        animation_frame="frameId",
        animation_group="nflId",
        hover_name="displayName",
        hover_data=hover_data,
        text="jerseyNumber",
        width=1000,
        height=900,
        # range_x=[-160 / 6.0, 160 / 6.0],
        range_x=[X_LEFT-2, X_RIGHT+2],
        size="size",
        size_max=15,
        color="side",
        color_discrete_map={"OFF": "#39FF14", "DEF": "#FF69B4", "zoo": "blue", "transformer": "gold"},
        opacity=0.9,
        symbol="symbol",
        symbol_map=symbol_map,
    )

    # Add marker for tackle location
    # (a static trace, so it stays in place while the other markers animate)
    fig.add_trace(
        go.Scatter(
            x=[tkl_y],
            y=[tkl_x],
            mode="markers",
            marker=dict(color="green", size=15, symbol="x"),
            hoverinfo="none",
            showlegend=False,
            opacity=0.8,
        )
    )

    # Add line of scrimmage
    # NOTE(review): this places the target goal line at Y_MAX - 10, i.e. it
    # assumes the offense always moves toward larger x - confirm upstream.
    los = Y_MAX-10-distToGoal
    fig.add_shape(
        type="line", x0=X_LEFT, y0=los, x1=X_RIGHT, y1=los, line=dict(color="rgba(137, 207, 240, 0.2)", width=3, dash="dash")
    )
    # Add yards to go line
    fig.add_shape(
        type="line",
        x0=X_LEFT,
        y0=los + yards_to_go,
        x1=X_RIGHT,
        y1=los + yards_to_go,
        line=dict(color="rgba(255, 255, 0, 0.2)", width=3, dash="dash"),
    )
    # Add border to the field
    fig.add_shape(
        type="rect", x0=X_LEFT, y0=Y_MIN, x1=X_RIGHT, y1=Y_MAX, line=dict(color="rgba(255, 255, 255, 0.5)", width=10)
    )
    # endzone
    # (only the end zone at the Y_MAX end is outlined)
    fig.add_shape(
        type="rect", x0=X_LEFT, y0=Y_MAX-10, x1=X_RIGHT, y1=Y_MAX, line=dict(color="#39FF14", width=6), opacity=0.4,
    )
    # Add the path traces to the figure first to place them in the background
    # for trace in path_traces:
    #     fig.add_trace(trace)

    # set play speed
    # The same duration is written into the frame args of every button of the
    # first updatemenu created by px.scatter for the animation.
    frame_duration = 100
    for button in fig.layout.updatemenus[0].buttons:
        button["args"][1]["frame"]["duration"] = frame_duration
    # set aspect ratio
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    # background color
    fig.update_layout(paper_bgcolor="#333333", plot_bgcolor="#363636", font_color="white", font_size=14)
    # turn off axis
    fig.update_xaxes(showgrid=False, zeroline=False, showticklabels=False)
    # grid line thickness
    fig.update_yaxes(
        showgrid=True,
        gridwidth=3,
        gridcolor="rgba(237, 234, 222, 0.1)",
        linewidth=0,
        linecolor="rgba(0, 0, 0, 0.01)",
        mirror=True,
        showticklabels=False,
    )
    # set y axis range
    # (the vertical axis is limited to the span of tracking x seen in this play)
    fig.update_yaxes(range=[mvmt_y_min, mvmt_y_max])
    # set yaxes ticks to 10 yards
    # fig.update_yaxes(tick0=0, dtick=10)
    # text size
    fig.update_layout(uniformtext_minsize=2, uniformtext_mode="hide")
    # hide legend
    fig.update_layout(showlegend=False)
    # text color of jersey numbers
    # (the text_color column, "black" for every tracking row, is passed to all traces)
    fig.update_traces(textfont=dict(family="Tahoma", size=12, color=mvmt_df["text_color"]))
    fig.update_traces(marker_line_width=0)

    # hide x and y labels
    fig.update_xaxes(title_text="")
    fig.update_yaxes(title_text="")

    # add hash marks, yard markers, etc
    # One step per unit between the two goal lines: multiples of 10 get a line
    # plus a number on both sides, other multiples of 5 get a thinner line, and
    # everything else gets a pair of short hash marks.
    for y_loc in range(Y_MIN+10, Y_MAX-10+1, 1):
        if y_loc % 10 == 0:
            # Number counts up to 50 at y_loc == 60 and back down beyond it;
            # 0 (at both goal lines) is shown as "E Z".
            ydln = y_loc-10 if y_loc <= 60 else 110 - y_loc
            ydln_txt = str(ydln) if ydln != 0 else "E Z"
            fig.add_shape(
                type="line", x0=X_LEFT, y0=y_loc, x1=X_RIGHT, y1=y_loc, line=dict(color="white", width=2), opacity=0.05
            )
            fig.add_annotation(
                x=X_LEFT+4, y=y_loc, text=ydln_txt, showarrow=False, font=dict(color="white", size=60), textangle=90, opacity=0.05
            )
            fig.add_annotation(
                x=X_RIGHT-4, y=y_loc, text=ydln_txt, showarrow=False, font=dict(color="white", size=60), textangle=270, opacity=0.05
            )
        elif y_loc % 5 == 0:
            fig.add_shape(
                type="line", x0=X_LEFT, y0=y_loc, x1=X_RIGHT, y1=y_loc, line=dict(color="white", width=1), opacity=0.05
            )
        else:
            # Zero-height rectangles (y0 == y1): only their border line is
            # visible, which draws the hash marks either side of X_MIDDLE.
            fig.add_shape(
                type="rect", x0=X_MIDDLE-9.5, y0=y_loc, x1=X_MIDDLE-8.5, y1=y_loc, line=dict(color="white", width=3), opacity=0.05
            )
            fig.add_shape(
                type="rect", x0=X_MIDDLE+9.5, y0=y_loc, x1=X_MIDDLE+8.5, y1=y_loc, line=dict(color="white", width=3), opacity=0.05
            )

    # Faint watermark text: offense club in the outlined end zone, branding at midfield.
    fig.add_annotation(x=X_MIDDLE, y=Y_MAX-5, text=off_club, showarrow=False, font=dict(color="white", size=90), opacity=0.2)
    fig.add_annotation(x=X_MIDDLE, y=Y_MIDDLE+2, text="SŪMER", showarrow=False, font=dict(color="white", size=50), opacity=0.05)
    fig.add_annotation(x=X_MIDDLE, y=Y_MIDDLE-2, text="SPORTS", showarrow=False, font=dict(color="white", size=50), opacity=0.05)

    # Add play description
    if play_description is not None:
        # make list of 100 character slices
        # (fixed-width slices, so a line can break in the middle of a word)
        play_desc_list = [play_description[i : i + 100] for i in range(0, len(play_description), 100)]
        for i, play_desc_txt in enumerate(play_desc_list):
            # Stack the lines 5 units apart, starting 5 units below the lowest tracking x.
            text_y_loc = mvmt_y_min - 5 - 5 * i
            fig.add_annotation(
                x=X_MIDDLE, y=text_y_loc, text=play_desc_txt, showarrow=False, font=dict(color="white", size=14), opacity=0.8
            )

    # set title
    # offense = play_info["offense"].values[0].upper()
    # defense = play_info["defense"].values[0].upper()
    # down = int(play_info["down"].values[0])
    # yards_to_go = int(play_info["yards_to_go"].values[0])
    # quarter = int(play_info["quarter"].values[0])
    # game_clock = play_info["game_clock"].values[0]
    fig.update_layout(
        title=f"{ballCarrierName} | {gameId} {playId} | {off_club} vs {def_club} | Down: {down} | YTG: {yards_to_go} | DTG: {distToGoal}",
        font_size=12,
        title_x=0.5,
        title_y=0.98,
    )

    return fig

In [34]:
# Pick one random validation/test frame and animate the play it belongs to.
# The ids are taken from the sampled row itself, so `result_row` always
# describes the play being animated (previously the row and the ids came from
# two independent samples of the same filtered frame).
result_row = results_df.filter(pl.col("dataset_split").is_in(["val", "test"])).sample(1)
gid, pid = result_row.select(["gameId", "playId"]).to_numpy()[0]

# Plays worth revisiting - uncomment one line to override the random pick:
# gid, pid = 2022100206, 988 # Austin Ekeler reverse TD
# gid, pid = 2022092505, 2441 # Lamar scramble with a broken tackle; zoo model goes crazy near the end
# gid, pid = 2022102305, 380 # blocker causes prediction to jump
# gid, pid = 2022092501, 2398 # huge hole opens up and prediction jumps; good at frames 50-60. Zoo doesn't seem to generalize well in open-field situations, and even the transformer struggles much later
# gid, pid = 2022100909, 1582 # Derrick Henry dead run -> huge run
# gid, pid = 2022103004, 2631 # Dameon Pierce weaving run
# gid, pid = 2022103002, 3044 # jet sweep shows poor generalization from the zoo model

# Echo the ids so an interesting random pick can be copied into the list above.
print(f"gid, pid = {gid}, {pid}")
animate_play(tracking_df, play_df, results_df, gid, pid)

gid, pid = 2022103002, 801
