# NFL Downfield Pass Dominance via TacticAI-Inspired GNNs

- **Competition:** NFL Big Data Bowl 2026 - Analytics. Understand player movement while the ball is in the air.
- **Data:** `train/input_2023_wXX.csv`, `train/output_2023_wXX.csv`, and `Supplementary.csv`
- **Objective:** Build a per-pass and per-route **Downfield Pass Dominance** metric for targeted receivers by comparing GNN-based catch probabilities to contextual baselines, and analyze contested vs. non-contested throws.

Inspired by **TacticAI** (geometric deep learning/GNNs for soccer set pieces) and **CLRS GNN processors**, we represent each pass at ball arrival as a graph of players and learn catch probability. We implement a simplified PyTorch/PyG pipeline (rather than the full JAX-based TacticAI stack) and define dominance as:

`Route Dominance Score = model_predicted_catch_prob - baseline_expected_catch_prob_given_context`

We compute this separately for contested vs. non-contested throws and aggregate to player-route summaries.
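
To make the definition concrete, here is a minimal sketch on a hand-made table. The probabilities are made-up illustrative numbers rather than model outputs, and `model_pred` is a stand-in column name for whichever model's predictions are used.

```python
import pandas as pd

# Toy per-pass table: every value here is illustrative, not a real prediction
toy = pd.DataFrame({
    "player_name": ["A", "A", "B", "B"],
    "route": ["GO", "GO", "GO", "GO"],
    "contested": [1, 0, 1, 1],
    "model_pred": [0.55, 0.80, 0.30, 0.40],
    "baseline_catch_prob": [0.40, 0.70, 0.40, 0.40],
})

# Dominance = model-predicted catch probability minus the contextual baseline
toy["dominance"] = toy["model_pred"] - toy["baseline_catch_prob"]

# Aggregate to player-route summaries, split by the contested flag
summary = (
    toy.groupby(["player_name", "route", "contested"])["dominance"]
    .agg(["mean", "count"])
    .reset_index()
)
print(summary)
```

Player B's two contested targets average slightly below baseline here, while player A is above baseline in both situations.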


In [1]:
# Standard Python + path helpers
import os
import sys
import subprocess
from pathlib import Path

# Numeric + data wrangling
import numpy as np
import pandas as pd

# Plotting libraries for quick visuals
import matplotlib.pyplot as plt
import seaborn as sns

# Train/validation split and evaluation metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score,
    roc_curve,
    accuracy_score,
    confusion_matrix,
    classification_report,
)

# PyTorch core for deep learning
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# PyTorch Geometric gives us graph neural network layers
try:
    import torch_geometric
    from torch_geometric.data import Data as GeoData
    from torch_geometric.loader import DataLoader as GeoDataLoader
    from torch_geometric.nn import GATv2Conv, global_mean_pool
except ImportError:
    # Install PyG dependencies if this environment does not have them
    subprocess.check_call([
        sys.executable,
        "-m",
        "pip",
        "install",
        "torch-geometric",
        "torch-scatter",
        "torch-sparse",
        "torch-cluster",
    ])
    import torch_geometric
    from torch_geometric.data import Data as GeoData
    from torch_geometric.loader import DataLoader as GeoDataLoader
    from torch_geometric.nn import GATv2Conv, global_mean_pool

# Make plots readable out of the gate
sns.set(style="whitegrid", context="talk")
plt.rcParams["figure.figsize"] = (10, 6)

# GPU if available, else CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE



We expect the raw Kaggle exports to live locally with the following structure relative to this notebook:

```
data/
  input_2023_w01.csv
  input_2023_w02.csv
  ...
  output_2023_w01.csv
  output_2023_w02.csv
  ...
  Supplementary.csv
```

All weeks reside directly under `data/` alongside `Supplementary.csv`. Adjust the paths below if your layout differs.


In [None]:
# Point to the folder where the CSVs live
BASE_DIR = Path("data")
TRAIN_DIR = BASE_DIR  # week-level CSVs now live directly under data/
SUPP_PATH = BASE_DIR / "Supplementary.csv"
BASE_DIR, TRAIN_DIR, SUPP_PATH


In [None]:
# Load play-level context (down, distance, coverage, etc.)
supp = pd.read_csv(SUPP_PATH)

# Quick peek at the columns and data types
display(supp.head())
print()
supp.info()


`Supplementary.csv` carries the play-level context (game/play identifiers, down & distance, pass depth/location, coverage tags, EPA, etc.) that we will merge into every graph snapshot for both baseline expectations and model conditioning.


In [None]:
def load_tracking_inputs(weeks=None, train_dir=TRAIN_DIR):
    """Stack the pre-throw tracking CSVs for the weeks we care about."""
    if weeks is None:
        weeks = list(range(1, 19))
    frames = []
    for w in weeks:
        fname = train_dir / f"input_2023_w{w:02d}.csv"
        if not fname.exists():
            print(f"Skipping missing {fname}")
            continue
        df = pd.read_csv(fname)
        df["week"] = w
        df["file_type"] = "input"
        frames.append(df)
    if not frames:
        raise ValueError("No input files were loaded; please confirm the data path.")
    return pd.concat(frames, ignore_index=True)


def load_tracking_outputs(weeks=None, train_dir=TRAIN_DIR):
    """Stack the ball-in-flight tracking CSVs for the same weeks."""
    if weeks is None:
        weeks = list(range(1, 19))
    frames = []
    for w in weeks:
        fname = train_dir / f"output_2023_w{w:02d}.csv"
        if not fname.exists():
            print(f"Skipping missing {fname}")
            continue
        df = pd.read_csv(fname)
        df["week"] = w
        df["file_type"] = "output"
        frames.append(df)
    if not frames:
        raise ValueError("No output files were loaded; please confirm the data path.")
    return pd.concat(frames, ignore_index=True)


In [None]:
# Start with a small subset of weeks so experimentation is fast
DEV_WEEKS = [1, 2, 3]
inputs_raw = load_tracking_inputs(DEV_WEEKS)
outputs_raw = load_tracking_outputs(DEV_WEEKS)

# Basic sanity checks on the dataframes
display(inputs_raw.head())
display(outputs_raw.head())
print("Inputs shape:", inputs_raw.shape)
print("Outputs shape:", outputs_raw.shape)
print()
inputs_raw.info()
print()
outputs_raw.info()


`input_2023_wXX` holds the pre-throw tracking frames for every player (`frame_id` starts at 1), plus metadata such as `num_frames_output` and the ball landing coordinates, while `output_2023_wXX` tracks player positions during the ball's flight. We use `num_frames_output` to pick the last in-flight frame of each output sequence, which we treat as the ball-arrival snapshot.
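
The alignment step can be sketched on a toy table. The values below are invented for illustration; only the column names follow the ones used in this notebook.

```python
import pandas as pd

# Toy in-flight tracking: two players, three frames each (illustrative values)
out = pd.DataFrame({
    "game_id": [1] * 6,
    "play_id": [10] * 6,
    "nfl_id": [1, 1, 1, 2, 2, 2],
    "frame_id": [1, 2, 3, 1, 2, 3],
    "x": [30.0, 31.0, 32.0, 33.0, 34.0, 35.0],
})

# Frame counts as they would be carried over from the input files
n_frames = pd.DataFrame({
    "game_id": [1, 1],
    "play_id": [10, 10],
    "nfl_id": [1, 2],
    "num_frames_output": [3, 3],
})

merged = out.merge(n_frames, on=["game_id", "play_id", "nfl_id"], how="inner")

# The arrival snapshot is the row whose frame_id equals num_frames_output
arrival_toy = merged[merged["frame_id"] == merged["num_frames_output"]]
print(arrival_toy[["nfl_id", "frame_id", "x"]])
```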


In [None]:
FIELD_LENGTH = 120.0
FIELD_WIDTH = 53.3


def standardize_xy(df, x_col="x", y_col="y"):
    """Reflect coordinates so offense always drives to +x with y=0 at the bottom."""
    df = df.copy()
    right_mask = df["play_direction"].str.lower() == "right"
    df["x_std"] = df[x_col]
    df["y_std"] = df[y_col]
    left_mask = ~right_mask
    df.loc[left_mask, "x_std"] = FIELD_LENGTH - df.loc[left_mask, x_col]
    df.loc[left_mask, "y_std"] = FIELD_WIDTH - df.loc[left_mask, y_col]
    return df


def standardize_ball_land(df):
    df = df.copy()
    right_mask = df["play_direction"].str.lower() == "right"
    df["ball_land_x_std"] = df["ball_land_x"]
    df["ball_land_y_std"] = df["ball_land_y"]
    left_mask = ~right_mask
    df.loc[left_mask, "ball_land_x_std"] = FIELD_LENGTH - df.loc[left_mask, "ball_land_x"]
    df.loc[left_mask, "ball_land_y_std"] = FIELD_WIDTH - df.loc[left_mask, "ball_land_y"]
    return df


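def add_velocity_components(df):
    """Decompose speed into x/y components in the standardized frame."""
    df = df.copy()
    # NFL tracking convention: `dir` is in degrees, measured clockwise from the +y axis,
    # so the x-component uses sin and the y-component uses cos
    theta = np.deg2rad(df["dir"].fillna(0.0).values)
    speed = df["s"].fillna(0.0).values
    # standardize_xy flips both axes for left-moving plays, which negates both components
    sign = np.where(df["play_direction"].str.lower() == "right", 1.0, -1.0)
    df["vx"] = speed * np.sin(theta) * sign
    df["vy"] = speed * np.cos(theta) * sign
    return df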


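# Apply the helpers so every play is in the same coordinate frame
inputs = standardize_xy(inputs_raw)
inputs = standardize_ball_land(inputs)
inputs = add_velocity_components(inputs)

# standardize_xy needs play_direction; if the output files do not carry it, borrow it from the inputs
if "play_direction" not in outputs_raw.columns:
    play_dirs = inputs_raw[["game_id", "play_id", "play_direction"]].drop_duplicates()
    outputs_raw = outputs_raw.merge(play_dirs, on=["game_id", "play_id"], how="left")
outputs = standardize_xy(outputs_raw)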


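Standardizing coordinates (flipping both axes on left-moving plays so every offense drives toward +x) removes play direction as a source of variance. This is in the spirit of TacticAI's use of field reflections, but here the symmetry is handled by canonicalizing the input rather than by an equivariant architecture, so the GNN is invariant to play direction by construction and spatial relationships are easier to learn.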
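
A small sanity check of that transform, written as a standalone sketch. `to_standard` is a hypothetical scalar version of `standardize_xy`, and the coordinates are invented.

```python
import math

FIELD_LENGTH = 120.0
FIELD_WIDTH = 53.3


def to_standard(x, y, play_direction):
    """Flip both axes on left-moving plays so the offense always drives toward +x."""
    if play_direction.lower() == "right":
        return x, y
    return FIELD_LENGTH - x, FIELD_WIDTH - y


# Raw coordinates of a receiver and a defender on a left-moving play (illustrative values).
# Driving left, the receiver at x=70 is further downfield than the defender at x=80.
wr_raw, db_raw = (70.0, 33.3), (80.0, 31.3)

wr_std = to_standard(*wr_raw, "left")
db_std = to_standard(*db_raw, "left")

# Distances are preserved, and applying the transform twice returns the raw point
sep_raw = math.dist(wr_raw, db_raw)
sep_std = math.dist(wr_std, db_std)
round_trip = to_standard(*wr_std, "left")
print(wr_std, db_std, sep_raw, sep_std)
```

After standardization the receiver has the larger x value, so "downfield" consistently means +x regardless of the original play direction.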


In [None]:
# Keep only plays that appear in the supplementary context table
supp_keys = supp[["game_id", "play_id"]].drop_duplicates()
inputs = inputs.merge(supp_keys, on=["game_id", "play_id"], how="inner")
outputs = outputs.merge(supp_keys, on=["game_id", "play_id"], how="inner")


def get_targeted_receivers(input_df):
    """Grab the intended receiver (frame 1) for each pass."""
    tr = input_df[input_df["player_role"] == "Targeted Receiver"].copy()
    return tr[tr["frame_id"] == 1]


# This table is one row per pass attempt + targeted receiver
targeted = get_targeted_receivers(inputs)
targeted.head()


Each row in `targeted` corresponds to one pass attempt (`game_id`, `play_id`) and its intended receiver, taken from input frame 1. It supplies the receiver's identity and the ball-landing metadata; the receiver's position at ball arrival is joined in next.


In [None]:
# Attach num_frames_output so we know which output frame is the catch/arrival snapshot
num_frames = inputs[["game_id", "play_id", "nfl_id", "num_frames_output"]].drop_duplicates()

outputs_merged = outputs.merge(
    num_frames,
    on=["game_id", "play_id", "nfl_id"],
    how="inner",
)

# Keep only the arrival frame per player (frame == num_frames_output)
arrival = outputs_merged[
    outputs_merged["frame_id"] == outputs_merged["num_frames_output"]
].copy()

arrival.head()


In [None]:
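# Stitch together pre-throw info with the arrival positions for each target.
# Both tables carry num_frames_output; drop it from `arrival` so the column keeps
# its plain name instead of being split into _pre/_arr variants.
target_arrival = targeted.merge(
    arrival.drop(columns=["num_frames_output"]),
    on=["game_id", "play_id", "nfl_id", "week"],
    suffixes=("_pre", "_arr"),
    how="inner",
)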

target_arrival.head()


In [None]:
def compute_arrival_defender_features(target_row, outputs_arrival, inputs_df):
    """Measure how many defenders are near the receiver when the ball arrives."""
    gid = target_row["game_id"]
    pid = target_row["play_id"]
    frame_end = target_row["num_frames_output"]
    rx = target_row["x_std_arr"]
    ry = target_row["y_std_arr"]

    df_def = outputs_arrival[
        (outputs_arrival["game_id"] == gid)
        & (outputs_arrival["play_id"] == pid)
        & (outputs_arrival["frame_id"] == frame_end)
    ].merge(
        inputs_df[["game_id", "play_id", "nfl_id", "player_side"]].drop_duplicates(),
        on=["game_id", "play_id", "nfl_id"],
        how="left",
    )
    df_def = df_def[df_def["player_side"] == "Defense"].copy()

    if df_def.empty:
        return {
            "sep_nearest": np.nan,
            "sep_second": np.nan,
            "num_def_within_2": np.nan,
            "num_def_within_3": np.nan,
            "num_def_within_5": np.nan,
        }

    dx = df_def["x_std"] - rx
    dy = df_def["y_std"] - ry
    dists = np.sqrt(dx**2 + dy**2)
    dists_sorted = np.sort(dists)

    return {
        "sep_nearest": dists_sorted[0],
        "sep_second": dists_sorted[1] if len(dists_sorted) > 1 else np.nan,
        "num_def_within_2": float((dists <= 2.0).sum()),
        "num_def_within_3": float((dists <= 3.0).sum()),
        "num_def_within_5": float((dists <= 5.0).sum()),
    }


In [None]:
supp_play_lookup = supp.set_index(["game_id", "play_id"])

feature_rows = []
for _, row in target_arrival.iterrows():
    key = (row["game_id"], row["play_id"])
    if key not in supp_play_lookup.index:
        continue
    srow = supp_play_lookup.loc[key]
    feats = {
        "game_id": row["game_id"],
        "play_id": row["play_id"],
        "nfl_id": row["nfl_id"],
        "player_name": row.get("player_name_pre", row.get("player_name", "")),
        "week": row["week"],
        # Contextual info
        "pass_result": srow["pass_result"],
        "pass_length": srow["pass_length"],
        "route": srow["route_of_targeted_receiver"],
        "down": srow["down"],
        "yards_to_go": srow["yards_to_go"],
        "coverage_type": srow["team_coverage_type"],
        "coverage_man_zone": srow["team_coverage_man_zone"],
        # Binary label for modeling
        "caught": 1 if srow["pass_result"] == "C" else 0,
        # Receiver + ball locations at arrival
        "ball_land_x_std": row["ball_land_x_std"],
        "ball_land_y_std": row["ball_land_y_std"],
        "rx_arr": row["x_std_arr"],
        "ry_arr": row["y_std_arr"],
    }
    dx_ball = feats["rx_arr"] - feats["ball_land_x_std"]
    dy_ball = feats["ry_arr"] - feats["ball_land_y_std"]
    feats["dist_receiver_to_ball"] = float(np.sqrt(dx_ball**2 + dy_ball**2))
    # Add the defender separation metrics
    feats.update(compute_arrival_defender_features(row, outputs_merged, inputs))
    feature_rows.append(feats)

# Final per-pass table that fuels EDA + models
features_df = pd.DataFrame(feature_rows)
features_df.head()


Arrival-level separation (`sep_nearest`, `sep_second`, and defender counts within 2, 3, and 5 yards) drives our contested-vs-open labeling, while `dist_receiver_to_ball` measures how far the receiver is from the ball's landing point at arrival.
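
The same separation features can be computed in one vectorized step per pass. The function below is a hypothetical NumPy sketch (`separation_features` is not part of the notebook) run on invented coordinates.

```python
import numpy as np


def separation_features(receiver_xy, defenders_xy, radii=(2.0, 3.0, 5.0)):
    """Distances from one receiver to each defender, plus counts inside each radius."""
    defenders_xy = np.asarray(defenders_xy, dtype=float).reshape(-1, 2)
    if defenders_xy.shape[0] == 0:
        # No defenders tracked at arrival: every feature is undefined
        out = {"sep_nearest": np.nan, "sep_second": np.nan}
        out.update({f"num_def_within_{int(r)}": np.nan for r in radii})
        return out
    offsets = defenders_xy - np.asarray(receiver_xy, dtype=float)
    dists = np.sort(np.linalg.norm(offsets, axis=1))
    out = {
        "sep_nearest": float(dists[0]),
        "sep_second": float(dists[1]) if dists.size > 1 else np.nan,
    }
    out.update({f"num_def_within_{int(r)}": float((dists <= r).sum()) for r in radii})
    return out


# One receiver and three defenders at 1, 4, and 10 yards (illustrative positions)
feats = separation_features((50.0, 20.0), [(51.0, 20.0), (50.0, 24.0), (60.0, 20.0)])
print(feats)
```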


In [None]:
# Simple contested definition: nearest defender within 2 yards at arrival.
# (The within-3-yards count is already implied by that condition; it is kept as an explicit guard.)
features_df["contested"] = (
    (features_df["sep_nearest"] <= 2.0)
    & (features_df["num_def_within_3"].fillna(0) >= 1)
).astype(int)
features_df["contested"].value_counts(dropna=False)


## Exploratory Data Analysis


In [None]:
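# Quick frequency tables to understand label balance and popular routes/coverages.
# A notebook cell only auto-displays its last expression, so show each table explicitly.
display(features_df["pass_result"].value_counts())
print("Overall catch rate:", features_df["caught"].mean())
display(features_df["route"].value_counts().head(20))
display(features_df["coverage_type"].value_counts().head())
display(features_df["coverage_man_zone"].value_counts().head())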


In [None]:
# Separation + outcome relationship
sns.histplot(data=features_df, x="sep_nearest", hue="caught", bins=30, stat="density", kde=True)
plt.title("Nearest defender separation at arrival vs. catch outcome")
plt.show()

# How far the receiver is from the ball vs. defender distance
sns.scatterplot(data=features_df, x="sep_nearest", y="dist_receiver_to_ball", hue="caught", alpha=0.5)
plt.title("Ball distance vs. separation")
plt.show()


In [None]:
# Which routes have the highest catch rates in this sample?
route_stats = (
    features_df.groupby("route")["caught"].agg(["mean", "count"])
    .sort_values("mean", ascending=False)
)
display(route_stats.head(15))

min_samples = 30
route_plot = route_stats[route_stats["count"] >= min_samples].reset_index().head(15)
plt.figure(figsize=(10, 6))
sns.barplot(data=route_plot, x="mean", y="route", hue="count", dodge=False)
plt.title(f"Catch rate by route (>= {min_samples} targets)")
plt.xlabel("Catch rate")
plt.ylabel("")
plt.show()


In [None]:
# Compare catch rates for contested vs non-contested throws
contested_stats = features_df.groupby("contested")["caught"].mean().rename("catch_rate")
display(contested_stats)

plt.figure(figsize=(6, 4))
sns.barplot(x=contested_stats.index.astype(str), y=contested_stats.values)
plt.title("Catch rate by contested flag")
plt.xlabel("Contested (1=yes)")
plt.ylabel("Catch rate")
plt.show()

# Dig into contested throws only to see separation differences
sns.histplot(
    data=features_df[features_df["contested"] == 1],
    x="sep_nearest",
    hue="caught",
    bins=20,
    stat="probability",
)
plt.title("Separation distribution for contested throws")
plt.show()


In [None]:
# Bucket pass length so we can compare short vs deep balls
bins = [-10, 0, 10, 20, 30, 40, 60]
labels = ["behind_LOS", "0-10", "10-20", "20-30", "30-40", "40+"]
features_df["pass_depth_bin"] = pd.cut(features_df["pass_length"], bins=bins, labels=labels)

pass_depth_stats = (
    features_df.groupby(["pass_depth_bin", "contested"])["caught"].mean().reset_index()
)
plt.figure(figsize=(10, 5))
sns.barplot(data=pass_depth_stats, x="pass_depth_bin", y="caught", hue="contested")
plt.title("Catch rate by pass depth bin and contested flag")
plt.xlabel("Pass depth bin (yards downfield)")
plt.ylabel("Catch rate")
plt.show()
pass_depth_stats


In [None]:
# Split the features into numeric and categorical chunks for the baseline model
use_cols_num = [
    "pass_length",
    "down",
    "yards_to_go",
    "sep_nearest",
    "sep_second",
    "num_def_within_2",
    "num_def_within_3",
    "num_def_within_5",
    "dist_receiver_to_ball",
]

use_cols_cat = ["route", "coverage_type", "coverage_man_zone", "pass_depth_bin"]

# Drop rows with missing key numeric fields to keep things simple
features_model = features_df.dropna(subset=use_cols_num).copy()

X_num = features_model[use_cols_num].values
X_cat = pd.get_dummies(features_model[use_cols_cat].astype("category"), dummy_na=True)
X = np.hstack([X_num, X_cat.values])
y = features_model["caught"].values

# Standard 80/20 split for validation
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
X.shape, X_train.shape
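
The split above is row-level, whereas the graph split later in the notebook is game-level. If the two models should be evaluated on comparable folds, one option is to split the tabular rows by `game_id` as well. The helper below is a hypothetical NumPy-only sketch (`split_by_group` is not part of the notebook), shown on made-up game ids.

```python
import numpy as np


def split_by_group(groups, test_size=0.2, seed=42):
    """Return boolean train/val masks such that no group appears on both sides."""
    groups = np.asarray(groups)
    unique = np.unique(groups)
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(unique)
    # Hold out at least one whole group
    n_val = max(1, int(round(test_size * len(unique))))
    val_mask = np.isin(groups, shuffled[:n_val])
    return ~val_mask, val_mask


# Ten passes from five games (illustrative ids)
game_ids_toy = np.array([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
train_mask, val_mask = split_by_group(game_ids_toy)
print(game_ids_toy[train_mask], game_ids_toy[val_mask])
```

The masks can index `X` and `y` directly, for example `X[train_mask]` and `y[val_mask]`. Unlike the stratified split, this does not balance the label across folds.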


In [None]:
# Gradient boosted trees make a strong baseline for tabular features
try:
    import xgboost as xgb
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "xgboost"])
    import xgboost as xgb


dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)

params = {
    "objective": "binary:logistic",
    "eval_metric": "auc",
    "eta": 0.05,
    "max_depth": 4,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
}

evals = [(dtrain, "train"), (dval, "val")]
xgb_model = xgb.train(
    params,
    dtrain,
    num_boost_round=500,
    evals=evals,
    early_stopping_rounds=20,
    verbose_eval=50,
)

val_pred = xgb_model.predict(dval)
print("Baseline GBM AUC:", roc_auc_score(y_val, val_pred))

# Store predictions for every row so we can compute dominance later.
# Rows from the training split receive in-sample predictions, which tend to be optimistic.
features_model = features_model.reset_index(drop=True)
features_model["gbm_pred"] = xgb_model.predict(xgb.DMatrix(X))


In [None]:
def build_graph_for_pass(gid, pid, inputs_df, outputs_df, supp_df):
    """Create a PyG graph where each node is a player at ball arrival."""
    sub_in = inputs_df[(inputs_df["game_id"] == gid) & (inputs_df["play_id"] == pid)]
    tr = sub_in[sub_in["player_role"] == "Targeted Receiver"]
    if tr.empty:
        return None
    num_frames_output = int(tr["num_frames_output"].iloc[0])

    # Grab the arrival frame for this play
    sub_out = outputs_df[
        (outputs_df["game_id"] == gid)
        & (outputs_df["play_id"] == pid)
        & (outputs_df["frame_id"] == num_frames_output)
    ].copy()
    if sub_out.empty:
        return None

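    # Attach player metadata (side, role, velocities) from each player's last input
    # frame, i.e. the one closest to the throw, so every player contributes one row
    player_meta = sub_in.sort_values("frame_id").drop_duplicates(
        subset="nfl_id", keep="last"
    )[
        [
            "game_id",
            "play_id",
            "nfl_id",
            "player_side",
            "player_role",
            "player_position",
            "x_std",
            "y_std",
            "vx",
            "vy",
        ]
    ]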

    df = sub_out.merge(
        player_meta,
        on=["game_id", "play_id", "nfl_id"],
        how="left",
        suffixes=("_arr", ""),
    )

    try:
        supp_row = supp_df.set_index(["game_id", "play_id"]).loc[(gid, pid)]
    except KeyError:
        return None

    # Node features = standardized positions, motion, role, and simple play context
    x_pos = df["x_std_arr"].fillna(df["x_std"]).values
    y_pos = df["y_std_arr"].fillna(df["y_std"]).values
    vx = df["vx"].fillna(0.0).values
    vy = df["vy"].fillna(0.0).values

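    # Encode against fixed category lists so every graph has the same feature width
    # and column order. The category names are assumed from the competition data;
    # adjust them if yours differ. Values outside the lists become all-zero rows.
    side_cats = pd.CategoricalDtype(["Defense", "Offense"])
    role_cats = pd.CategoricalDtype(
        ["Defensive Coverage", "Targeted Receiver", "Other Route Runner", "Passer"]
    )
    side_onehot = pd.get_dummies(df["player_side"].astype(side_cats))
    role_onehot = pd.get_dummies(df["player_role"].astype(role_cats))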

    context = np.column_stack([
        np.full(len(df), supp_row["down"], dtype=np.float32),
        np.full(len(df), supp_row["yards_to_go"], dtype=np.float32),
        np.full(len(df), supp_row["pass_length"], dtype=np.float32),
    ])

    node_features = np.column_stack([
        x_pos,
        y_pos,
        vx,
        vy,
        side_onehot.reindex(df.index).fillna(0.0).values,
        role_onehot.reindex(df.index).fillna(0.0).values,
        context,
    ]).astype(np.float32)

    x_tensor = torch.from_numpy(node_features)
    n_nodes = x_tensor.size(0)
    if n_nodes < 2:
        return None

    # Fully connect the graph so every player exchanges information
    idx = np.arange(n_nodes)
    row_idx, col_idx = np.meshgrid(idx, idx, indexing="ij")
    mask = row_idx != col_idx
    edge_index = np.vstack([row_idx[mask], col_idx[mask]])
    edge_index = torch.from_numpy(edge_index).long()

    label = torch.tensor([1.0 if supp_row["pass_result"] == "C" else 0.0], dtype=torch.float32)

    # Track which node is the targeted receiver so we can read its embedding later
    target_id = int(tr["nfl_id"].iloc[0])
    target_mask = torch.zeros(n_nodes, dtype=torch.float32)
    target_positions = np.where(df["nfl_id"].values == target_id)[0]
    if len(target_positions) == 0:
        return None
    target_mask[target_positions[0]] = 1.0

    data = GeoData(
        x=x_tensor,
        edge_index=edge_index,
        y=label,
        target_mask=target_mask,
        game_id=torch.tensor([int(gid)]),
        play_id=torch.tensor([int(pid)]),
    )
    return data
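
Two pieces of the graph builder are easy to check in isolation: one-hot encoding with a fixed vocabulary (so node features have the same width in every graph) and the fully connected edge list. The sketch below uses hypothetical helper names (`one_hot_fixed`, `fully_connected_edges`) and an assumed `SIDES` vocabulary.

```python
import numpy as np
import pandas as pd

# Assumed vocabulary; replace with the values actually present in the data
SIDES = ["Defense", "Offense"]


def one_hot_fixed(values, categories):
    """One-hot encode against a fixed category list so the width never changes."""
    cat = pd.Series(pd.Categorical(values, categories=categories))
    return pd.get_dummies(cat).to_numpy(dtype=np.float32)


def fully_connected_edges(n_nodes):
    """All ordered pairs (i, j) with i != j, shaped (2, n * (n - 1))."""
    idx = np.arange(n_nodes)
    src, dst = np.meshgrid(idx, idx, indexing="ij")
    keep = src != dst
    return np.vstack([src[keep], dst[keep]])


# A play with only defenders still yields two side columns
only_defense = one_hot_fixed(["Defense", "Defense"], SIDES)
mixed = one_hot_fixed(["Offense", "Defense", "Defense"], SIDES)
edges = fully_connected_edges(3)
print(only_defense.shape, mixed.shape, edges.shape)
```

Without a fixed vocabulary, `pd.get_dummies` only creates columns for the categories present in a given play, so graphs from different plays could end up with different feature widths and fail to batch.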


In [None]:
# Build graphs for up to ~2k passes to keep training manageable
sample_keys = features_df[["game_id", "play_id"]].drop_duplicates()
if len(sample_keys) > 2000:
    sample_keys = sample_keys.sample(n=2000, random_state=42)

graphs = []
for _, row in sample_keys.iterrows():
    g = build_graph_for_pass(row["game_id"], row["play_id"], inputs, outputs_merged, supp)
    if g is not None:
        graphs.append(g)

len(graphs)


In [None]:
# Split by game_id so the same game does not appear in both train and validation
game_ids = np.array([int(g.game_id.item()) for g in graphs])
unique_games = np.unique(game_ids)
train_games, val_games = train_test_split(unique_games, test_size=0.2, random_state=42)

train_graphs = [g for g in graphs if int(g.game_id.item()) in train_games]
val_graphs = [g for g in graphs if int(g.game_id.item()) in val_games]

len(train_graphs), len(val_graphs)


In [None]:
# PyTorch Geometric loaders batch graphs during training
BATCH_SIZE = 32
train_loader = GeoDataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
val_loader = GeoDataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
class PassDominanceGAT(nn.Module):
    """A lightweight graph attention network that focuses on the targeted receiver node."""

    def __init__(self, in_dim, hidden_dim=64, num_heads=4, num_layers=3, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GATv2Conv(in_dim, hidden_dim, heads=num_heads, dropout=dropout, concat=True))
        for _ in range(num_layers - 2):
            self.layers.append(GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=num_heads, dropout=dropout, concat=True))
        self.layers.append(GATv2Conv(hidden_dim * num_heads, hidden_dim, heads=1, dropout=dropout, concat=True))
        self.dropout = nn.Dropout(dropout)
        self.out_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, data):
        x, edge_index, target_mask = data.x, data.edge_index, data.target_mask
        for conv in self.layers:
            x = conv(x, edge_index)
            x = torch.relu(x)
            x = self.dropout(x)
        # Single graphs have no batch vector (PyG may expose it as None), so treat them as one graph
        batch = getattr(data, "batch", None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # Read out the targeted receiver's embedding for each graph in the batch
        target_idx = (target_mask > 0.5).nonzero(as_tuple=False).squeeze(-1)
        target_embeddings = torch.zeros(int(batch.max()) + 1, x.size(1), device=x.device)
        target_embeddings[batch[target_idx]] = x[target_idx]
        logits = self.out_mlp(target_embeddings).squeeze(-1)
        return logits
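
The readout at the end of `forward` picks one row per graph out of a batched node-embedding matrix. The sketch below reproduces that indexing in NumPy so it can be run without PyTorch; `target_readout` is a hypothetical name and the embeddings are dummy values.

```python
import numpy as np


def target_readout(node_emb, batch, target_mask):
    """Pick each graph's targeted-receiver row from a batched node-embedding matrix."""
    n_graphs = int(batch.max()) + 1
    out = np.zeros((n_graphs, node_emb.shape[1]), dtype=node_emb.dtype)
    target_idx = np.flatnonzero(target_mask > 0.5)
    # batch[target_idx] gives the graph id of every target node
    out[batch[target_idx]] = node_emb[target_idx]
    return out


# Two graphs batched together: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
node_emb = np.arange(10, dtype=np.float32).reshape(5, 2)
batch = np.array([0, 0, 0, 1, 1])
target_mask = np.array([0, 1, 0, 0, 1], dtype=np.float32)

readout = target_readout(node_emb, batch, target_mask)
print(readout)
```

This relies on each graph having exactly one target node, which `build_graph_for_pass` enforces by returning `None` when the targeted receiver is missing.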


In [None]:
def train_gnn(model, train_loader, val_loader, num_epochs=15, lr=1e-3):
    """Standard train loop with validation AUC tracking."""
    model = model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.BCEWithLogitsLoss()
    best_auc = 0.0
    best_state = None
    for epoch in range(1, num_epochs + 1):
        model.train()
        train_losses = []
        for batch in train_loader:
            batch = batch.to(DEVICE)
            logits = model(batch)
            # view(-1) keeps labels 1-D even when the last batch holds a single graph
            labels = batch.y.to(DEVICE).float().view(-1)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        model.eval()
        val_logits, val_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(DEVICE)
                logits = model(batch)
                # view(-1) keeps labels 1-D even when the last batch holds a single graph
                labels = batch.y.to(DEVICE).float().view(-1)
                val_logits.append(logits.cpu().numpy())
                val_labels.append(labels.cpu().numpy())
        val_logits = np.concatenate(val_logits)
        val_labels = np.concatenate(val_labels)
        val_probs = 1 / (1 + np.exp(-val_logits))
        val_auc = roc_auc_score(val_labels, val_probs)
        print(f"Epoch {epoch:02d} | Train Loss {np.mean(train_losses):.4f} | Val AUC {val_auc:.4f}")
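        if val_auc > best_auc:
            best_auc = val_auc
            # Clone the tensors: state_dict() returns live references that later epochs overwrite
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}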
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_auc
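
The loop above selects the best epoch by validation AUC. As a reminder of what that metric measures, here is a small NumPy sketch of AUC as the fraction of correctly ranked (positive, negative) pairs. `pairwise_auc` is a hypothetical helper for illustration, not a replacement for `roc_auc_score`, and the logits are invented.

```python
import numpy as np


def pairwise_auc(labels, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly; ties count half."""
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=float)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    if pos.size == 0 or neg.size == 0:
        raise ValueError("AUC needs at least one positive and one negative label")
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size)


labels_toy = [1, 0, 0, 1]
logits = np.array([2.0, -1.0, 0.5, -0.5])
probs = 1.0 / (1.0 + np.exp(-logits))

auc_probs = pairwise_auc(labels_toy, probs)
auc_logits = pairwise_auc(labels_toy, logits)
print(auc_probs, auc_logits)
```

Because the sigmoid is strictly increasing, it does not change the ranking, so AUC is the same whether it is computed on logits or on probabilities.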


In [None]:
# Kick off GNN training (falls back gracefully if no graphs were built)
if graphs:
    in_dim = graphs[0].x.size(1)
    gnn_model = PassDominanceGAT(in_dim)
    gnn_model, best_val_auc = train_gnn(gnn_model, train_loader, val_loader, num_epochs=15)
    print("Best validation AUC (GNN):", best_val_auc)
else:
    gnn_model = None
    print("No graphs available for training.")


In [None]:
def predict_gnn(model, graphs):
    """Run the trained GNN on every graph to obtain catch probabilities."""
    if model is None or not graphs:
        return pd.DataFrame(columns=["game_id", "play_id", "gnn_pred"])
    model.eval()
    preds = []
    with torch.no_grad():
        for g in graphs:
            logits = model(g.to(DEVICE))
            prob = torch.sigmoid(logits).cpu().numpy().flatten()[0]
            preds.append({
                "game_id": int(g.game_id.item()),
                "play_id": int(g.play_id.item()),
                "gnn_pred": prob,
            })
    return pd.DataFrame(preds)


gnn_pred_df = predict_gnn(gnn_model, graphs)
gnn_pred_df.head()


In [None]:
# Compute a "what normally happens" baseline for each context bucket
features_df["down_bucket"] = features_df["down"].clip(upper=4)
features_df["coverage_type_simplified"] = features_df["coverage_type"].fillna("Unknown")
context_cols = ["route", "pass_depth_bin", "down_bucket", "coverage_type_simplified"]

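baseline_table = (
    # observed=True keeps only context buckets that actually occur (pass_depth_bin is categorical)
    features_df.groupby(context_cols, observed=True)["caught"].agg(["mean", "count"]).reset_index()
    .rename(columns={"mean": "baseline_catch_prob", "count": "baseline_count"})
)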

features_df = features_df.merge(baseline_table, on=context_cols, how="left")
features_df.head()
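
Each bucket mean above includes the pass's own outcome, which matters most when `baseline_count` is small. One optional refinement, sketched here on an invented table and not used elsewhere in this notebook, is a leave-one-out mean that excludes the row itself. The column names `baseline_in_sample` and `baseline_loo` are hypothetical.

```python
import pandas as pd

# Toy outcomes grouped by a single context column (illustrative values)
toy = pd.DataFrame({
    "route": ["GO", "GO", "GO", "SLANT"],
    "caught": [1, 0, 1, 1],
})

grp = toy.groupby("route")["caught"]
cell_sum = grp.transform("sum")
cell_n = grp.transform("count")

# In-sample bucket mean, equivalent to the groupby mean used above
toy["baseline_in_sample"] = cell_sum / cell_n

# Leave-one-out mean: drop the pass's own outcome; undefined for single-pass buckets
toy["baseline_loo"] = ((cell_sum - toy["caught"]) / (cell_n - 1)).where(cell_n > 1)
print(toy)
```

Buckets with a single pass get a missing leave-one-out baseline, which makes thin context cells visible instead of silently returning the pass's own result.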


In [None]:
# Merge both model predictions back onto the master per-pass table
gbm_preds = features_model[["game_id", "play_id", "nfl_id", "gbm_pred"]]
features_df = features_df.merge(gbm_preds, on=["game_id", "play_id", "nfl_id"], how="left")

features_df = features_df.merge(gnn_pred_df, on=["game_id", "play_id"], how="left")
features_df[["gbm_pred", "gnn_pred"]].describe()


In [None]:
# Dominance = model probability minus baseline expectation
features_df["dominance_gbm"] = features_df["gbm_pred"] - features_df["baseline_catch_prob"]
features_df["dominance_gnn"] = features_df["gnn_pred"] - features_df["baseline_catch_prob"]
features_df["dominance_gbm_contested"] = np.where(features_df["contested"] == 1, features_df["dominance_gbm"], np.nan)
features_df["dominance_gnn_contested"] = np.where(features_df["contested"] == 1, features_df["dominance_gnn"], np.nan)
features_df[["dominance_gbm", "dominance_gnn"]].describe()


In [None]:
def summarize_player_route_dominance(df, min_targets=15, dominance_col="dominance_gnn"):
    """Average dominance by player+route so we can rank who beats expectations."""
    group_cols = ["player_name", "nfl_id", "route"]
    agg = (
        df.groupby(group_cols)[dominance_col]
        .agg(["mean", "count"])
        .rename(columns={"mean": "mean_dominance", "count": "num_targets"})
        .reset_index()
    )
    agg = agg[agg["num_targets"] >= min_targets]
    return agg.sort_values("mean_dominance", ascending=False)


player_route_dom = summarize_player_route_dominance(features_df, min_targets=15, dominance_col="dominance_gnn")
player_route_dom.head(20)


In [None]:
# Same idea but only looking at contested targets
player_route_dom_contested = summarize_player_route_dominance(
    features_df, min_targets=10, dominance_col="dominance_gnn_contested"
)
player_route_dom_contested.head(20)


In [None]:
def plot_route_leaders(df, route_name, top_n=15, dominance_col="dominance_gnn"):
    """Visual helper to show who dominates a given route concept."""
    subset = df[df["route"].str.lower() == route_name.lower()]
    if subset.empty:
        print(f"No targets for route {route_name}")
        return
    top = subset.head(top_n)
    plt.figure(figsize=(8, 6))
    sns.barplot(data=top, x="mean_dominance", y="player_name", palette="viridis")
    plt.title(f"Top {top_n} {route_name} route dominance ({dominance_col})")
    plt.xlabel("Mean dominance over baseline")
    plt.ylabel("")
    plt.show()


plot_route_leaders(player_route_dom, "go")


In [None]:
def plot_player_routes(df, player_name, dominance_col="dominance_gnn"):
    """Show which routes a specific player excels at relative to context."""
    sub = df[df["player_name"].str.contains(player_name, case=False, na=False)].copy()
    if sub.empty:
        print(f"No passes found for {player_name}")
        return
    agg = (
        sub.groupby("route")[dominance_col]
        .agg(["mean", "count"])
        .sort_values("mean", ascending=False)
        .reset_index()
    )
    display(agg)
    plt.figure(figsize=(8, 4))
    sns.barplot(data=agg, x="mean", y="route")
    plt.title(f"{player_name} route dominance ({dominance_col})")
    plt.xlabel("Mean dominance over baseline")
    plt.ylabel("")
    plt.show()


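# Substring match on player_name (case-insensitive). The name must appear in the loaded
# 2023 weeks; otherwise the helper only prints a "No passes found" message.
plot_player_routes(features_df, "Keon Coleman", "dominance_gnn")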
plot_player_routes(features_df, "Keon Coleman", "dominance_gnn_contested")


### Exporting Dominance Tables for External Analysis
Saving tidy tables lets us plug the dominance metrics back into a Kaggle notebook or scouting workflow without recomputing the whole pipeline.


In [None]:
OUTPUT_DIR = Path("artifacts")
OUTPUT_DIR.mkdir(exist_ok=True)

features_export_cols = [
    "game_id",
    "play_id",
    "nfl_id",
    "player_name",
    "route",
    "pass_length",
    "pass_depth_bin",
    "down",
    "yards_to_go",
    "coverage_type",
    "coverage_man_zone",
    "contested",
    "gbm_pred",
    "gnn_pred",
    "baseline_catch_prob",
    "dominance_gbm",
    "dominance_gnn",
    "dominance_gbm_contested",
    "dominance_gnn_contested",
]

features_df[features_export_cols].to_parquet(OUTPUT_DIR / "pass_dominance_per_target.parquet", index=False)
player_route_dom.to_parquet(OUTPUT_DIR / "player_route_dominance_overall.parquet", index=False)
player_route_dom_contested.to_parquet(OUTPUT_DIR / "player_route_dominance_contested.parquet", index=False)

OUTPUT_DIR


## Findings and Next Steps

- Separation and ball-alignment features clearly stratify catch outcomes: completions concentrate at `sep_nearest > 2 yds` and `dist_receiver_to_ball < 1 yd`, while contested throws (`sep_nearest <= 2 yds`) show a steep drop in catch rate.
- The tabular GBM provides a strong contextual baseline (its validation score is printed above as `Baseline GBM AUC`) and highlights the explanatory power of static arrival snapshots.
- The TacticAI-inspired GAT attends over all players simultaneously; its `Best validation AUC (GNN)` readout shows how much relational signal it captures. The two AUCs come from different splits (row-level for the GBM, game-level for the GNN), so compare them with care. Positive dominance (`gnn_pred - baseline_catch_prob`) indicates a receiver beating expectations given route/depth/down/coverage context; negative scores flag underperformance.
- Aggregations show which player-route combos (e.g., go routes) and which contested specialists (using `dominance_gnn_contested`) dominate beyond baseline expectation, subject to the `min_targets` thresholds, which few players reach when only the three development weeks are loaded.

**Extensions**
- Incorporate multiple frames from ball release to arrival with temporal GNNs/LSTMs to capture defender leverage changes.
- Engineer geometric features such as relative leverage vectors (is the defender between ball and receiver) in addition to Euclidean separation.
- Explore CLRS-style processor modules for node/edge/global updates to mimic the original TacticAI architecture more faithfully.
- Expand dominance to downstream outcomes (EPA, YAC) or to quarterback decision support, and export per-player CSVs for Kaggle submissions or scouting dashboards.
- Integrate full-season weeks and hyperparameter tuning, then persist `features_df` with dominance columns for further analytics or Kaggle notebook handoff.
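
As a starting point for the leverage-vector extension listed above, one simple formulation projects the defender onto the line from the receiver to the ball's landing point. This is a hypothetical sketch (`defender_leverage` and its return convention are assumptions, not part of the notebook), shown on invented coordinates.

```python
import math


def defender_leverage(receiver_xy, defender_xy, ball_xy):
    """Project the defender onto the receiver-to-ball segment.

    Returns (t, lateral): t between 0 and 1 means the defender sits between the
    receiver and the landing point along that line; lateral is the perpendicular
    distance from the line, in the same units as the inputs.
    """
    rx, ry = receiver_xy
    dx, dy = ball_xy[0] - rx, ball_xy[1] - ry
    seg_len_sq = dx * dx + dy * dy
    px, py = defender_xy[0] - rx, defender_xy[1] - ry
    if seg_len_sq == 0.0:
        # Receiver is already at the landing point; fall back to plain separation
        return 0.0, math.hypot(px, py)
    t = (px * dx + py * dy) / seg_len_sq
    lateral = abs(px * dy - py * dx) / math.sqrt(seg_len_sq)
    return t, lateral


# Receiver at (50, 20), ball landing 4 yards downfield, defender halfway and 1 yard off the line
t, lateral = defender_leverage((50.0, 20.0), (52.0, 21.0), (54.0, 20.0))
print(t, lateral)
```

A defender with `t` between 0 and 1 and a small `lateral` is in the passing lane, which plain Euclidean separation cannot distinguish from a defender trailing at the same distance.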
