In [1]:
import polars as pl
from pre_snap_prediction.data import process_data
from pre_snap_prediction.modeling import route_clustering

In [2]:
# Load the week-1 tracking data and the player-level play table.
# Both CSVs encode missing values as the literal string "NA".
tracking, player_play = (
    pl.read_csv(csv_path, null_values="NA")
    for csv_path in (
        "../data/tracking_week_1.csv",
        "../data/player_play.csv",
    )
)

In [3]:
# First preprocessing step on the raw tracking frame.
# Presumably mirrors plays moving left so that every play shares one
# direction — confirm in process_data.inverse_left_directed_plays.
inverse_tracking = process_data.inverse_left_directed_plays(
    tracking,
)

In [4]:
# Combine the direction-normalised tracking with the player/play table.
# Presumably keeps only the tracking rows of players running a route —
# confirm in process_data.get_route_tracking.
route_tracking = process_data.get_route_tracking(
    inverse_tracking,
    player_play,
)

In [5]:
# Derive a direction for each route (exact definition lives in
# process_data.get_route_direction — TODO confirm what column it adds).
route_tracking_dir = process_data.get_route_direction(
    route_tracking,
)

In [None]:
# Uses the direction computed in the previous step.
# Presumably mirrors right-directed routes so all routes are comparable
# in one orientation — confirm in process_data.inverse_right_route.
inverse_route_tracking = process_data.inverse_right_route(
    route_tracking_dir,
)

In [7]:
# Final tracking-level preprocessing. The result is reused twice below:
# as input to feature computation and as the frame clusters are joined onto.
processed_route_tracking = process_data.process_route_tracking(
    inverse_route_tracking,
)

In [8]:
# Build the feature table consumed by the outlier and clustering models.
# Presumably one row per route — TODO confirm in compute_route_features.
route_features = process_data.compute_route_features(
    processed_route_tracking,
)

In [None]:
# Display the computed feature table for a visual check before modeling.
route_features

In [10]:
# Fit the outlier detector on the full feature table.
# NOTE(review): if the underlying model is stochastic, make sure a seed is
# fixed inside train_outliers_model so reruns are reproducible.
outliers_model = route_clustering.train_outliers_model(
    route_features,
)

In [None]:
# Score the same feature table the detector was trained on.
outliers_route = route_clustering.predict_outliers(
    route_features,
    outliers_model,
)

In [12]:
# Keep only the routes that survive outlier removal; these are the rows
# the clustering model is trained and applied on.
valid_route_features = route_clustering.remove_outliers(
    outliers_route,
)

In [13]:
# Fit the route clustering model on the outlier-free features.
# NOTE(review): if the clustering is stochastic, confirm that
# train_route_clustering fixes a random seed.
clustering_model = route_clustering.train_route_clustering(
    valid_route_features,
)

In [None]:
# Assign a cluster to each outlier-free route.
# NOTE(review): the helper is spelled `predict_route_cluters` in the module
# API as called here; if that is a typo it has to be renamed upstream first.
clusters_route = route_clustering.predict_route_cluters(
    valid_route_features,
    clustering_model,
)

In [15]:
# Attach the cluster assignments to the frame-level tracking data.
# The tracking frame still contains routes dropped as outliers; presumably
# those end up with a null `cluster` (the next cell filters on that) —
# confirm the join type in join_clusters_to_data.
clusters_route_tracking = route_clustering.join_clusters_to_data(
    processed_route_tracking,
    clusters_route,
)

In [17]:
# Restrict the tracking data to rows that actually received a cluster label.
has_cluster = pl.col("cluster").is_not_null()
valid_clusters_route_tracking = clusters_route_tracking.filter(has_cluster)

In [None]:
import matplotlib.pyplot as plt

# Overview figure: every tracked route point (top) versus only the points
# that received a cluster label (bottom). Points are coloured by
# `route_frameId`, and both panels share the same drawing settings, so the
# only difference between them is the input frame.
fig, axs = plt.subplots(2, 1, figsize=(12, 8))
overview_panels = [
    ("All processed routes", clusters_route_tracking),
    ("Routes with a non-null cluster", valid_clusters_route_tracking),
]
for ax, (panel_title, panel_frame) in zip(axs, overview_panels):
    ax.scatter(
        panel_frame["relative_x"].to_numpy(),
        panel_frame["relative_y"].to_numpy(),
        c=panel_frame["route_frameId"].to_numpy(),
        cmap="viridis",
        s=5,
        alpha=0.1,
    )
    # Equal aspect keeps route shapes undistorted.
    ax.set_aspect('equal', adjustable='box')
    # Title and axis labels let each panel stand alone when skimmed.
    ax.set_title(panel_title)
    ax.set_xlabel("relative_x")
    ax.set_ylabel("relative_y")
fig.tight_layout()


In [None]:
# Plot the most populated clusters, one panel per cluster, largest first.
# Each title shows the cluster id, its row count in the tracking data, and
# the most frequent `routeRan` label among its rows.
TOP_N_CLUSTERS = 20  # maximum number of clusters to draw

# Sort by count (descending) and break ties on the cluster id so the
# selection and panel order are reproducible across runs.
clusters_count = (
    valid_clusters_route_tracking["cluster"]
    .value_counts()
    .sort(["count", "cluster"], descending=[True, False])
)
top_clusters = clusters_count.head(TOP_N_CLUSTERS)
unique_clusters = top_clusters["cluster"]
n_panels = len(unique_clusters)

if n_panels == 0:
    # plt.subplots raises on zero rows; report instead of crashing the cell.
    print("No clustered routes to plot.")
else:
    # squeeze=False always returns a 2-D array of Axes, so a single cluster
    # is indexable too; taking column 0 restores the usual 1-D layout.
    fig, axs = plt.subplots(
        n_panels, 1, figsize=(12, n_panels * 4), squeeze=False
    )
    axs = axs[:, 0]
    cluster_rows = top_clusters.select("cluster", "count").iter_rows()
    for ax, (cluster, count) in zip(axs, cluster_rows):
        cluster_route_tracking = valid_clusters_route_tracking.filter(
            pl.col("cluster") == cluster
        )
        ax.scatter(
            cluster_route_tracking["relative_x"].to_numpy(),
            cluster_route_tracking["relative_y"].to_numpy(),
            c=cluster_route_tracking["route_frameId"].to_numpy(),
            cmap='viridis',
            s=5,
            alpha=0.2,
        )
        ax.set_aspect('equal', adjustable='box')
        # mode() may return several tied values in no fixed order; sorting
        # makes the reported label deterministic.
        route_mode = cluster_route_tracking['routeRan'].mode().sort()[0]
        ax.set_title(f"Cluster: {cluster}, Count: {count}, Mode: {route_mode}")
        ax.set_xlabel("relative_x")
        ax.set_ylabel("relative_y")
    fig.tight_layout()