In [None]:
import numpy as np
import pandas as pd
import torch 
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW

from preprocessing import get_cv_val_ids, get_data_splits
from models import BlitzData, BlitzFrameData, BlitzLSTM, PressureFrameTransformer
from utils import combine_batch, combine_batch2
from training import train_epoch, validate_epoch, train_epoch2, validate_epoch2, get_lstm_output

In [None]:
# Reproducibility: seed every global RNG the notebook (or the project helpers it
# imports) may draw from, and force deterministic cuDNN kernel selection.
import random

SEED = 2024
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick non-deterministic algorithms
torch.manual_seed(SEED)
np.random.seed(SEED)  # numpy's global RNG was previously left unseeded
random.seed(SEED)

In [None]:
data = pd.read_csv("data/feats.csv")

# Missing labels / acceleration are treated as 0.
data["y"] = data["y"].fillna(0)
data["y_pressure"] = data["y_pressure"].fillna(0)
data["acc"] = data["acc"].fillna(0)
# Pass rushers with y == -1 get blitz_prob forced to 1.
data["blitz_prob"] = np.where(
    (data["pass_rusher"] == 1) & (data["y"] == -1),
    1., data["blitz_prob"]
)
data["blitz_prob"] = data["blitz_prob"].fillna(0)
# Per-player row counter within each play (0, 1, 2, ...); later used as a merge key.
data["index"] = data.groupby(["game_id", "play_id", "player_id"]).cumcount()

# Keep plays with at least one blitzer (y == 1), or where more than 4/22 of the
# play's rows are pass rushers (i.e. at least 5 rushers if every frame has 22 players).
play_keys = ["game_id", "play_id"]
has_blitzer = data.groupby(play_keys)["y"].transform(lambda x: np.any(x == 1))
many_rushers = data.groupby(play_keys)["pass_rusher"].transform(lambda x: x.mean() > 4/22)
data = data[has_blitzer | many_rushers].reset_index(drop=True)
# Drop rows without a lag feature, then rebuild a contiguous RangeIndex. Later cells
# select rows by index label (e.g. `index.isin(val_ids[i])`) and `merge` always
# returns a fresh RangeIndex, so labels must stay 0..n-1 to remain consistent.
data = data.dropna(subset=["rel_x_lag"]).reset_index(drop=True)

# Play-level features, indexed by (game_id, play_id).
play_data = pd.read_csv("data/play_feats.csv")
play_data = play_data.set_index(["game_id", "play_id"])
play_data["is_man"] = play_data["is_man"].fillna(0)

In [None]:
# Lookup table: (game_id, play_id) -> numpy array of that play's first row in play_data.
play_data_dict = {
    play_key: play_rows.iloc[0].values
    for play_key, play_rows in play_data.groupby(play_data.index)
}

In [None]:
# Cross-validation folds from the project helper get_cv_val_ids.
# val_ids[i] is later used as a collection of `data` row-index labels
# (`data_lstm.index.isin(val_ids[i])`) and passed to get_data_splits.
# Presumably one fold per week, given the "week {i + 1}" log message — TODO confirm.
val_ids = get_cv_val_ids(data)

In [None]:
# Two-stage cross-validated model, one validation fold per iteration:
#   1. BlitzLSTM on per-player sequences of rows with y == 1 (target: play_pressure).
#   2. PressureFrameTransformer on frame-level data, using the LSTM output ("lstm0")
#      as an extra feature (target: y_pressure).
# Only folds listed in FOLDS_TO_RUN are trained; use range(len(val_ids)) for full CV.
FOLDS_TO_RUN = [7]
LSTM_PATIENCE = 5          # epochs without val-loss improvement before stopping
TRANSFORMER_PATIENCE = 7
MAX_TTS = 100              # rows / samples with tts >= MAX_TTS are excluded
SEQ_KEYS = ["game_id", "play_id", "player_id"]


def _player_sequences(frame, feature_cols):
    """Build (key, X, y) triples, one per (game_id, play_id, player_id) group.

    X is the group's `feature_cols` as an array (n_rows x n_features);
    y is the group's "play_pressure" column as an array.
    """
    return [
        (key, group[feature_cols].values, group["play_pressure"].values)
        for key, group in frame.groupby(SEQ_KEYS)
    ]


def _fit_early_stopping(model, train_step, val_step, patience):
    """Train until `patience` consecutive epochs fail to lower the validation loss.

    train_step() must return the epoch's train loss; val_step() must return
    (val_loss, val_preds). The best-epoch weights are restored into `model`
    and that epoch's val_preds are returned.

    Raises:
        RuntimeError: if no epoch ever produced a comparable (finite) val loss.
    """
    best_loss = float("inf")  # any finite first-epoch loss counts as an improvement
    best_epoch = 0
    best_preds = None
    best_state = None
    epoch = 0
    while epoch - best_epoch < patience:
        train_loss = train_step()
        val_loss, val_preds = val_step()

        print(f"Epoch {epoch}:\n")
        print(f"Train loss: {np.round(train_loss, 4)}, Val loss: {np.round(val_loss, 4)}")

        if val_loss < best_loss:
            best_loss = val_loss
            best_epoch = epoch
            best_preds = val_preds.copy()
            # snapshot weights so the best checkpoint (not the last epoch) is kept
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        epoch += 1

    if best_state is None:
        raise RuntimeError("Validation loss never improved (non-finite?); no checkpoint to restore.")
    model.load_state_dict(best_state)
    return best_preds


lstm_preds = []
all_preds = []
for i in FOLDS_TO_RUN:

    print(f"Training with week {i + 1} as validation set.\n")

    # ---- Stage 1: per-player LSTM ----
    data_lstm = data.loc[(data["y"] == 1) & (data["tts"] < MAX_TTS)].copy()
    # Remap position ids 6, 7, 9, 11, 14 -> 0..4; any other id also maps to 0 (default).
    pos = data_lstm["position_id"]
    data_lstm["position_id"] = np.select(
        [pos == 6, pos == 7, pos == 9, pos == 11, pos == 14],
        [0, 1, 2, 3, 4],
        default=0,
    )
    is_val = data_lstm.index.isin(val_ids[i])
    train_df = data_lstm.loc[~is_val]
    val_df = data_lstm.loc[is_val]

    # feature sequences by player for lstm (n_frames x n_features)
    lstm_feats = ["rel_x", "rel_y", "speed_x", "speed_y", "acc",
                  "blitz_prob_norm", "tts", "position_id"]
    train = _player_sequences(train_df, lstm_feats)
    val = _player_sequences(val_df, lstm_feats)

    train_loader = DataLoader(BlitzData(train), batch_size=64, collate_fn=combine_batch, shuffle=True)
    val_loader = DataLoader(BlitzData(val), batch_size=64, collate_fn=combine_batch)

    input_dim = train[0][1].shape[1]
    model = BlitzLSTM(input_dim, hidden_dim=16, output_dim=1, num_lstm_layers=1)
    optimizer = AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
    loss_fn = nn.BCELoss()

    best_lstm_preds = _fit_early_stopping(
        model,
        lambda: train_epoch(train_loader, model, optimizer, loss_fn, final_output=True),
        lambda: validate_epoch(val_loader, model, loss_fn, final_output=True),
        patience=LSTM_PATIENCE,
    )

    # save best-epoch lstm preds (built as a list, concatenated once)
    pred_frames = []
    for key, arr in best_lstm_preds.items():
        frame = pd.DataFrame(arr, columns=["pred"])
        frame["game_id"] = key[0]
        frame["play_id"] = key[1]
        frame["player_id"] = key[2]
        pred_frames.append(frame)
    val_preds_df = pd.concat(pred_frames)[["game_id", "play_id", "player_id", "pred"]]
    lstm_preds.append(val_preds_df)

    # model now holds the best-epoch LSTM weights
    all_data_loader = DataLoader(BlitzData(train + val), batch_size=128, collate_fn=combine_batch)
    all_df = get_lstm_output(model, all_data_loader)

    # ---- Stage 2: frame-level transformer with lstm output as a feature ----
    if "lstm0" in data.columns:
        data = data.drop(columns="lstm0")
    # merge() returns a fresh RangeIndex; restore the original row labels because
    # val_ids[i] is matched against index labels. validate="m:1" guarantees the
    # left join neither drops nor duplicates rows, so labels line up one-to-one.
    merged = data.merge(all_df, on=["game_id", "play_id", "player_id", "index"],
                        how="left", validate="m:1")
    merged.index = data.index
    data = merged
    # rows with y == -1 get a fixed lstm0 of 0.45 (constant carried over as-is)
    data["lstm0"] = np.where(
        data["y"] == -1, 0.45, data["lstm0"]
    )
    data["lstm0"] = data["lstm0"].fillna(0.)

    frame_feats = ["rel_x", "rel_y", "rel_x_lag", "rel_y_lag", "speed_x", "speed_y",
                   "acc", "pass_rusher", "lstm0", "tts", "position_id"]
    train, val = get_data_splits(data,
                                 play_data_dict,
                                 val_ids[i],
                                 feats=frame_feats,
                                 label="y_pressure",
                                 mirror_train=True)
    # column -2 of X is assumed to be "tts" (second-to-last in frame_feats) — TODO confirm
    train = [x for x in train if x[1][0, -2] < MAX_TTS]
    val = [x for x in val if x[1][0, -2] < MAX_TTS]

    train_loader = DataLoader(BlitzFrameData(train), batch_size=64, shuffle=True, collate_fn=combine_batch2)
    val_loader = DataLoader(BlitzFrameData(val), batch_size=64, collate_fn=combine_batch2)

    # NOTE(review): the model presumably consumes the last 2 X columns separately — confirm
    input_dim = train[0][1].shape[1] - 2
    z_input_dim = len(train[0][2])
    device = torch.device("cpu")
    model = PressureFrameTransformer(input_dim, z_input_dim, 32, 1, 1)
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.03)
    loss_fn = nn.BCELoss()

    best_frame_preds = _fit_early_stopping(
        model,
        lambda: train_epoch2(train_loader, model, optimizer, loss_fn, device, pool=True)[0],
        lambda: validate_epoch2(val_loader, model, loss_fn, device, pool=True),
        patience=TRANSFORMER_PATIENCE,
    )

    # best-epoch frame-level preds keyed by (game_id, play_id, frame_id)
    df = pd.DataFrame(
        list(best_frame_preds.values()),
        columns=["pred"],
        index=pd.MultiIndex.from_tuples(list(best_frame_preds.keys()),
                                        names=["game_id", "play_id", "frame_id"]),
    ).reset_index()
    all_preds.append(df)

In [None]:
# Stack the per-fold frame-level prediction tables row-wise.
all_preds_df = pd.concat(objs=all_preds, axis=0)

In [None]:
# Persist the cross-validated pressure predictions (row index not written).
all_preds_df.to_csv(path_or_buf="data/cv_pressure_preds.csv", index=False)