In [1]:
import polars as pl
from pre_snap_prediction.data import process_data
from pre_snap_prediction.modeling import route_clustering

In [2]:
# Load the raw player tracking data.
# NOTE(review): the argument 3 presumably selects the week(s) of tracking files to read — confirm
# against process_data.read_tracking_csv.
tracking = process_data.read_tracking_csv(
    3
)

In [3]:
# Per-player / per-play table; literal "NA" cells in the CSV are read as nulls.
player_play = pl.read_csv(
    "../data/player_play.csv",
    null_values="NA",
)

In [4]:
# Normalise play direction (presumably mirrors left-directed plays — see
# process_data.inverse_left_directed_plays).
inverse_tracking = process_data.inverse_left_directed_plays(
    tracking
)

In [5]:
route_tracking = process_data.get_route_tracking(inverse_tracking, player_play)

In [6]:
route_tracking = process_data.get_route_direction(route_tracking)

In [7]:
route_tracking = process_data.inverse_right_route(route_tracking)

In [8]:
route_tracking = process_data.process_route_tracking(route_tracking)

In [9]:
# Per-route feature table; it feeds both the outlier model and the route clustering.
route_features = process_data.compute_route_features(
    route_tracking
)

In [None]:
# Inspect the computed route features.
route_features

In [11]:
# Fit the outlier detector on week-1 routes only; it is then applied to every week.
outliers_model = route_clustering.train_outliers_model(
    route_features.filter(pl.col("week") == 1)
)

In [None]:
# Score all routes (every week) with the week-1 outlier model.
outliers_route = route_clustering.predict_outliers(
    route_features,
    outliers_model,
)

In [13]:
# Keep only the routes that remove_outliers does not flag as outliers.
valid_route_features = route_clustering.remove_outliers(
    outliers_route
)

In [14]:
# Fit the route clustering on week-1 inlier routes only.
clustering_model = route_clustering.train_route_clustering(
    valid_route_features.filter(pl.col("week") == 1)
)

In [None]:
# Assign a cluster to every inlier route (all weeks).
# NOTE: `predict_route_cluters` (sic) is the project API's spelling.
clusters_route = route_clustering.predict_route_cluters(
    valid_route_features,
    clustering_model,
)

In [16]:
# Combine the play annotations with the route clusters (see get_modified_route_mode).
route_mode = route_clustering.get_modified_route_mode(
    player_play,
    clusters_route,
)

In [None]:
# Inspect the route-mode table.
route_mode

In [18]:
# Attach the cluster labels back onto the per-frame route tracking data.
clusters_route_tracking = route_clustering.join_clusters_to_data(
    route_tracking,
    clusters_route,
)

In [19]:
# Per-cluster reception-zone bounds (relative_x_min/max and relative_y_min/max columns).
clusters_reception_zone = route_clustering.get_clusters_reception_zones(
    player_play,
    clusters_route_tracking,
)

In [None]:
# Inspect the per-cluster reception zones.
clusters_reception_zone

In [None]:
from pre_snap_prediction.visualization import Field

# Field dimensions (presumably yards: 120 long including end zones, 53.3 wide).
FIELD_LENGTH = 120
FIELD_WIDTH = 53.3
# Example reception zone, as offsets from the field centre.
# NOTE(review): these values look hand-copied from `clusters_reception_zone`;
# TODO derive them from that table instead of hard-coding.
ZONE_X_MIN, ZONE_X_MAX = 5.39, 8.14
ZONE_Y_MIN, ZONE_Y_MAX = 3.0, 7.42

field = Field()

half_length = FIELD_LENGTH / 2
half_width = FIELD_WIDTH / 2
# Plotly draws a "circle" shape inside the x0..x1 / y0..y1 bounding box, i.e. an ellipse
# spanning the zone's extent.
field.fig.add_shape(
    type="circle",
    x0=half_length + ZONE_X_MIN,
    x1=half_length + ZONE_X_MAX,
    y0=half_width + ZONE_Y_MIN,
    y1=half_width + ZONE_Y_MAX,
    opacity=0.7,
    fillcolor="PaleTurquoise",
    line_color="LightSeaGreen",
)
# Make the figure the cell's output explicitly.
field.fig

In [22]:
# Keep only frames whose route received a cluster label (presumably the routes removed as
# outliers come back from the join with a null cluster).
valid_clusters_route_tracking = clusters_route_tracking.drop_nulls(subset=["cluster"])

In [23]:
import matplotlib.pyplot as plt

In [None]:
# Compare all route frames against only the frames of routes that were assigned a cluster.
# Points are coloured by route_frameId.
route_panels = [
    (clusters_route_tracking, "All routes"),
    (valid_clusters_route_tracking, "Routes assigned to a cluster"),
]
fig, axs = plt.subplots(len(route_panels), 1, figsize=(12, 8))
for ax, (panel_tracking, panel_title) in zip(axs, route_panels):
    ax.scatter(
        panel_tracking["relative_x"].to_numpy(),
        panel_tracking["relative_y"].to_numpy(),
        c=panel_tracking["route_frameId"].to_numpy(),
        cmap="viridis",
        s=5,
        alpha=0.1,
    )
    ax.set_aspect("equal", adjustable="box")
    ax.set_title(panel_title)
    ax.set_xlabel("relative_x")
    ax.set_ylabel("relative_y")


In [None]:
# One panel per most-frequent cluster: its reception zone (as a circle) under the route points.
N_TOP_CLUSTERS = 20

clusters_count = clusters_route["cluster"].value_counts().sort("count", descending=True)
top_clusters = clusters_count.head(N_TOP_CLUSTERS)
unique_clusters = top_clusters["cluster"]

# squeeze=False keeps `axs` 2-D even when only one cluster is plotted, so indexing is uniform.
fig, axs = plt.subplots(
    len(unique_clusters), 1, figsize=(12, len(unique_clusters) * 4), squeeze=False
)
for ax, (cluster, cluster_count) in zip(
    axs[:, 0], top_clusters.select("cluster", "count").iter_rows()
):
    cluster_reception_zone = clusters_reception_zone.filter(pl.col("cluster") == cluster)
    if cluster_reception_zone.height > 0:
        zone = cluster_reception_zone.row(0, named=True)
        x_extent = zone["relative_x_max"] - zone["relative_x_min"]
        y_extent = zone["relative_y_max"] - zone["relative_y_min"]
        # plt.Circle takes a RADIUS: half of the mean extent, so the circle roughly spans
        # the zone's bounding box (a divisor of 2 would draw it twice as large).
        circle = plt.Circle(
            (
                (zone["relative_x_min"] + zone["relative_x_max"]) / 2,
                (zone["relative_y_min"] + zone["relative_y_max"]) / 2,
            ),
            (x_extent + y_extent) / 4,
            color="PaleTurquoise",
            fill=True,
            alpha=0.5,
        )
        ax.add_patch(circle)

    cluster_route_tracking = valid_clusters_route_tracking.filter(pl.col("cluster") == cluster)
    ax.scatter(
        cluster_route_tracking["relative_x"].to_numpy(),
        cluster_route_tracking["relative_y"].to_numpy(),
        c=cluster_route_tracking["route_frameId"].to_numpy(),
        cmap="viridis",
        s=5,
        alpha=0.2,
    )

    ax.set_aspect("equal", adjustable="box")
    ax.set_title(f"Cluster: {cluster}, Count: {cluster_count}")