In [3]:
#################################################################
# Load data
#################################################################

import os
os.environ["NFL_HOME"] = "/home/sam/repos/hobby-repos/nfl/"

from common.data_loader import DataLoader

# Get raw data
loader = DataLoader()
games_df, plays_df, players_df, location_data_df = loader.get_data(weeks=[week for week in range (1,10)])

print(location_data_df.head(10))


       gameId  playId    nflId     displayName  frameId    frameType  \
0  2022091200      64  35459.0  Kareem Jackson        1  BEFORE_SNAP   
1  2022091200      64  35459.0  Kareem Jackson        2  BEFORE_SNAP   
2  2022091200      64  35459.0  Kareem Jackson        3  BEFORE_SNAP   
3  2022091200      64  35459.0  Kareem Jackson        4  BEFORE_SNAP   
4  2022091200      64  35459.0  Kareem Jackson        5  BEFORE_SNAP   
5  2022091200      64  35459.0  Kareem Jackson        6  BEFORE_SNAP   
6  2022091200      64  35459.0  Kareem Jackson        7  BEFORE_SNAP   
7  2022091200      64  35459.0  Kareem Jackson        8  BEFORE_SNAP   
8  2022091200      64  35459.0  Kareem Jackson        9  BEFORE_SNAP   
9  2022091200      64  35459.0  Kareem Jackson       10  BEFORE_SNAP   

                    time  jerseyNumber club playDirection      x      y     s  \
0  2022-09-13 00:16:03.5          22.0  DEN         right  51.06  28.55  0.72   
1  2022-09-13 00:16:03.6          22.0  DEN  

In [4]:
#################################################################
# Filter down plays df to situations we are interested in
#################################################################
print("Filtering data...")
original_play_length = len(plays_df)
print(f'Total plays: {original_play_length}')

plays_df = plays_df[plays_df['playNullifiedByPenalty'] == 'N']
print(f'Total plays after filtering out penalties: {len(plays_df)}')

plays_df = plays_df[plays_df['pff_manZone'].isin(['Man', 'Zone'])]
print(f'Total plays after filtering to valid Man or Zone classifications: {len(plays_df)}')

plays_df = plays_df[plays_df['gameId'].isin(location_data_df['gameId'].unique())]
print(f'Total plays after matching plays_df to location_data_df: {len(plays_df)}')

print(plays_df.head())


Filtering data...
Total plays: 16124
Total plays after filtering out penalties: 16124
Total plays after filtering to valid Man or Zone classifications: 15114
Total plays after matching plays_df to location_data_df: 15114
       gameId  playId                                    playDescription  \
0  2022102302    2655  (1:54) (Shotgun) J.Burrow pass short middle to...   
1  2022091809    3698  (2:13) (Shotgun) J.Burrow pass short right to ...   
2  2022103004    3146  (2:00) (Shotgun) D.Mills pass short right to D...   
3  2022110610     348  (9:28) (Shotgun) P.Mahomes pass short left to ...   
4  2022102700    2799  (2:16) (Shotgun) L.Jackson up the middle to TB...   

   quarter  down  yardsToGo possessionTeam defensiveTeam yardlineSide  \
0        3     1         10            CIN           ATL          CIN   
1        4     1         10            CIN           DAL          CIN   
2        4     3         12            HOU           TEN          HOU   
3        1     2         10   

In [5]:
#################################################################
# Create merged df that has gameId, playId, frameID all before SNAP, with x, y, and offense/defense
#################################################################
import pandas as pd
import numpy as np

# Create a copy of the location tracking data, cut it down to columns we care about
loc_trimmed_df = location_data_df.copy()
keep_cols  = [
    'gameId',
    'playId',
    'nflId',
    'frameId',
    'frameType',
    'club',
    'x',
    'y',
]
loc_trimmed_df = location_data_df.loc[:, keep_cols]

# Cut down location tracking data copy to only before the snap and where the team isn't valid
loc_trimmed_df.query('frameType == "BEFORE_SNAP" and club != "football"', inplace=True)

print(loc_trimmed_df.head())


       gameId  playId    nflId  frameId    frameType club      x      y
0  2022091200      64  35459.0        1  BEFORE_SNAP  DEN  51.06  28.55
1  2022091200      64  35459.0        2  BEFORE_SNAP  DEN  51.13  28.57
2  2022091200      64  35459.0        3  BEFORE_SNAP  DEN  51.20  28.59
3  2022091200      64  35459.0        4  BEFORE_SNAP  DEN  51.26  28.62
4  2022091200      64  35459.0        5  BEFORE_SNAP  DEN  51.32  28.65


In [6]:
# Create a copy of the plays data and cut it down to columns we care about
plays_trimmed_df = plays_df.copy()
keep_cols_from_plays = ['gameId','playId','possessionTeam', 'defensiveTeam']
plays_trimmed_df = plays_trimmed_df.loc[:, keep_cols_from_plays].drop_duplicates()

print(plays_trimmed_df.head())

       gameId  playId possessionTeam defensiveTeam
0  2022102302    2655            CIN           ATL
1  2022091809    3698            CIN           DAL
2  2022103004    3146            HOU           TEN
3  2022110610     348             KC           TEN
4  2022102700    2799            BAL            TB


In [7]:

# Merge the two datasets such that we can have the possession and defensive team for each row
merged_df = pd.merge(plays_trimmed_df, loc_trimmed_df, on=['gameId', 'playId'], how='inner')

# Tag the "side" of the player for each row (that being "off" or "def")
merged_df['side'] = np.where(merged_df['club'] == merged_df['possessionTeam'], 'off', 'def')

# Drop some columns we don't need anymore
merged_df.drop(['possessionTeam', 'defensiveTeam', 'club', 'frameType'], axis=1, inplace=True)

# Sort for deterministic frame ordering
merged_df = merged_df.sort_values(['gameId','playId','frameId'])

# Let's see what we have
print(merged_df.head())


              gameId  playId    nflId  frameId      x      y side
29059338  2022090800      56  35472.0        1  89.48  29.52  off
29059483  2022090800      56  38577.0        1  81.93  28.52  def
29059628  2022090800      56  41239.0        1  82.90  29.84  def
29059773  2022090800      56  42392.0        1  88.80  30.19  off
29059918  2022090800      56  42489.0        1  91.08  28.34  off


In [None]:

# Decide the target sequence length using the median number of frames per play ---
frame_counts = (merged_df
                .groupby(['gameId','playId'])['frameId']
                .nunique())
median_T = int(np.median(frame_counts.values))
print(f"Using median sequence length T = {median_T}")

median_T = 50 # Just setting this to keep play count high

Using median sequence length T = 107


In [None]:
# Built dataset where frame is only 50 frames and each side has exactly 11 players
off_series = {}
def_series = {}

skipped_inccorect_players = []   # plays where offense or defense had >11 unique players
skipped_too_short = []           # plays with fewer than median_T frames

for (g, p), grp in merged_df.groupby(['gameId','playId'], sort=False):
    # Skip if >11 players on either side
    off_n = grp.loc[grp['side']=='off', 'nflId'].nunique()
    def_n = grp.loc[grp['side']=='def', 'nflId'].nunique()
    if off_n != 11 or def_n != 11:
        skipped_inccorect_players.append((g,p))
        continue

    # B) Left->right slot order using median x (tie-break median y) per side
    off_stats = (grp[grp['side']=='off']
                 .groupby('nflId')
                 .agg(x_med=('x','median'), y_med=('y','median'))
                 .sort_values(['x_med','y_med']))
    def_stats = (grp[grp['side']=='def']
                 .groupby('nflId')
                 .agg(x_med=('x','median'), y_med=('y','median'))
                 .sort_values(['x_med','y_med']))

    off_order = off_stats.index.tolist()[:11]
    def_order = def_stats.index.tolist()[:11]
    if len(off_order) == 0 or len(def_order) == 0:
        # nothing useful to build on this play
        skipped_too_short.append((g,p))  # treat as unusable
        continue

    off_id2slot = {pid:i for i,pid in enumerate(off_order)}
    def_id2slot = {pid:i for i,pid in enumerate(def_order)}

    tmp = grp.copy()
    tmp['slot'] = np.where(tmp['side']=='off',
                           tmp['nflId'].map(off_id2slot),
                           tmp['nflId'].map(def_id2slot))

    # Keep only slotted players (0..10)
    tmp = tmp[tmp['slot'].between(0,10)]

    # Make frame index and pivot
    frames = np.sort(tmp['frameId'].unique())
    T = len(frames)

    # If this play has fewer than the target T, skip it
    if T < median_T:
        skipped_too_short.append((g,p))
        continue

    # Optionally trim to the last median_T frames (closest to snap)
    if T > median_T:
        frames = frames[-median_T:]

    goff = tmp[tmp['side']=='off']
    gdef = tmp[tmp['side']=='def']

    off_x = (goff.pivot_table(index='frameId', columns='slot', values='x')
                  .reindex(frames).reindex(columns=range(11), fill_value=np.nan))
    off_y = (goff.pivot_table(index='frameId', columns='slot', values='y')
                  .reindex(frames).reindex(columns=range(11), fill_value=np.nan))
    def_x = (gdef.pivot_table(index='frameId', columns='slot', values='x')
                  .reindex(frames).reindex(columns=range(11), fill_value=np.nan))
    def_y = (gdef.pivot_table(index='frameId', columns='slot', values='y')
                  .reindex(frames).reindex(columns=range(11), fill_value=np.nan))

    off_arr = np.stack([off_x.to_numpy(), off_y.to_numpy()], axis=-1)  # (median_T, 11, 2)
    def_arr = np.stack([def_x.to_numpy(), def_y.to_numpy()], axis=-1)  # (median_T, 11, 2)

    off_series[(g,p)] = off_arr
    def_series[(g,p)] = def_arr

print(f"Kept plays: {len(off_series)}")
print(f"Skipped (>11 players): {len(skipped_inccorect_players)}")
print(f"Skipped (<{median_T} frames): {len(skipped_too_short)}")

Kept plays: 13651
Skipped (>11 players): 0
Skipped (<50 frames): 1458


In [11]:
import torch
from torch.utils.data import Dataset, DataLoader

class PlaysDataset(Dataset):
    def __init__(self, off_series, def_series, labels_dict):  # labels_dict[(gameId,playId)] -> 0/1
        X_list, y_list = [], []
        for key in off_series.keys():
            off_arr = off_series[key]     # (T, 11, 2)
            def_arr = def_series[key]     # (T, 11, 2)
            X = np.concatenate([off_arr, def_arr], axis=1).reshape(off_arr.shape[0], -1)  # (T, 44)
            if np.isnan(X).any():
                # You can impute; for now, skip if NaNs present
                continue
            X_list.append(torch.from_numpy(X).float())
            y_list.append(torch.tensor(labels_dict[key], dtype=torch.long))
        self.X = X_list
        self.y = y_list

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # return (T, F), label
        return self.X[idx], self.y[idx]

def collate_batch(batch):
    # All sequences already same T → simple stack
    Xs, ys = zip(*batch)
    return torch.stack(Xs, dim=0), torch.stack(ys, dim=0)  # (B, T, F), (B,)


In [12]:
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, input_size=44, hidden_size=64, num_layers=1, dropout=0.0, bidir=False, num_classes=2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidir,
        )
        out_dim = hidden_size * (2 if bidir else 1)
        self.head = nn.Sequential(
            nn.LayerNorm(out_dim),
            nn.Linear(out_dim, num_classes)
        )

    def forward(self, x):  # x: (B, T, F)
        out, (h_n, c_n) = self.lstm(x)        # out: (B, T, H)
        last = out[:, -1, :]                   # use last timestep representation
        logits = self.head(last)               # (B, C)
        return logits


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Build labels dict mapping (gameId, playId) -> 0/1
label_map = {'Man': 1, 'Zone': 0}
labels_df = plays_df[['gameId','playId','pff_manZone']].drop_duplicates()
labels_dict = {(r.gameId, r.playId): label_map[r.pff_manZone] for r in labels_df.itertuples()}

dataset = PlaysDataset(off_series, def_series, labels_dict)
idx_train, idx_val = train_test_split(np.arange(len(dataset)), test_size=0.2, random_state=0, stratify=[dataset.y[i].item() for i in range(len(dataset))])

subset = torch.utils.data.Subset
train_loader = DataLoader(subset(dataset, idx_train), batch_size=64, shuffle=True, collate_fn=collate_batch)
val_loader   = DataLoader(subset(dataset, idx_val),   batch_size=128, shuffle=False, collate_fn=collate_batch)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMClassifier(input_size=44, hidden_size=64, num_layers=1, dropout=0.0, bidir=False).to(device)


# Zone dominates class weighting, calc distro then assign man a higher waiting on the CE loss
y_all = np.array([dataset.y[i].item() for i in range(len(dataset))], dtype=int) # Build an array of all dataset labels
y_train = y_all[idx_train] # Slice to the training fold
classes = np.array([0, 1], dtype=int)  # 0=Zone, 1=Man
w = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
print("Class weights (Zone, Man):", w)
class_weights = torch.tensor(w, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):

    # Train
    model.train()
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        logits = model(X)
        loss = criterion(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Val
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            pred = model(X).argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    print(f"Epoch {epoch+1}: val acc = {correct/total:.3f}")


Class weights (Zone, Man): [0.6932453  1.79369251]
Epoch 1: val acc = 0.518
Epoch 2: val acc = 0.709
Epoch 3: val acc = 0.602
Epoch 4: val acc = 0.616
Epoch 5: val acc = 0.535
Epoch 6: val acc = 0.676
Epoch 7: val acc = 0.699
Epoch 8: val acc = 0.722
Epoch 9: val acc = 0.710
Epoch 10: val acc = 0.711
Epoch 11: val acc = 0.678
Epoch 12: val acc = 0.726
Epoch 13: val acc = 0.683
Epoch 14: val acc = 0.279
Epoch 15: val acc = 0.461
Epoch 16: val acc = 0.728
Epoch 17: val acc = 0.680
Epoch 18: val acc = 0.723
Epoch 19: val acc = 0.400
Epoch 20: val acc = 0.708
Epoch 21: val acc = 0.730
Epoch 22: val acc = 0.731
Epoch 23: val acc = 0.380
Epoch 24: val acc = 0.709
Epoch 25: val acc = 0.732
Epoch 26: val acc = 0.373
Epoch 27: val acc = 0.408
Epoch 28: val acc = 0.726
Epoch 29: val acc = 0.469
Epoch 30: val acc = 0.729
Epoch 31: val acc = 0.713
Epoch 32: val acc = 0.742
Epoch 33: val acc = 0.712
Epoch 34: val acc = 0.736
Epoch 35: val acc = 0.728
Epoch 36: val acc = 0.664
Epoch 37: val acc = 0.

In [22]:
from sklearn.metrics import classification_report

model.eval()
all_preds, all_true = [], []

with torch.no_grad():
    for X, y in val_loader:
        X, y = X.to(device), y.to(device)
        pred = model(X).argmax(dim=1)
        all_preds.extend(pred.cpu().numpy())
        all_true.extend(y.cpu().numpy())

print(classification_report(all_true, all_preds, target_names=["Zone", "Man"]))

              precision    recall  f1-score   support

        Zone       0.73      0.98      0.84      1970
         Man       0.58      0.09      0.15       761

    accuracy                           0.73      2731
   macro avg       0.66      0.53      0.50      2731
weighted avg       0.69      0.73      0.65      2731

