In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW

from preprocessing import get_cv_val_ids, get_data_splits
from models import BlitzFrameData, BlitzFrameTransformer
from utils import combine_batch2
from training import train_epoch2, validate_epoch2, tidy_val_preds

In [None]:
# Frame-level features: missing `y` values are filled with 0, then rows whose
# `rel_x_lag` is missing are dropped.
data = (
    pd.read_csv("data/feats.csv")
    .assign(y=lambda frame: frame["y"].fillna(0))
    .dropna(subset=["rel_x_lag"])
)

# Play-level features indexed by (game_id, play_id); missing `is_man` values
# are filled with 0.
play_data = (
    pd.read_csv("data/play_feats.csv")
    .set_index(["game_id", "play_id"])
    .assign(is_man=lambda frame: frame["is_man"].fillna(0))
)

In [None]:
# Map every distinct index key of `play_data` to the values of the first row
# stored under that key.
play_data_dict = {
    play_key: play_rows.iloc[0].values
    for play_key, play_rows in play_data.groupby(play_data.index)
}

In [None]:
# One entry per cross-validation fold; the training loop below iterates over
# these and passes each entry to get_data_splits to carve out the validation set.
# NOTE(review): the exact contents of each entry are defined by get_cv_val_ids
# in preprocessing — confirm there before relying on their structure.
val_ids = get_cv_val_ids(data)

In [None]:
# Seed PyTorch's default and CUDA generators with the same value, then select
# the GPU when one is available and fall back to the CPU otherwise.
SEED = 2024
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Cross-validated training: for every fold in `val_ids`, fit a fresh
# BlitzFrameTransformer with early stopping on validation loss, keep the
# validation predictions from the best epoch, write them to
# data/val_preds_week_<fold>.csv and collect them in `cv_preds`.

# Columns of `data` handed to get_data_splits as the feature set.
feat_names = ["rel_x", "rel_y", "rel_x_lag", "rel_y_lag", "speed_x", "speed_y",
              "acc", "ox", "oy", "position_id"]

BATCH_SIZE = 128
# Training for a fold stops once `epoch - best_epoch` reaches this value.
PATIENCE = 5

cv_preds = []
for i, ids in enumerate(val_ids):
    print(f"Validating model with week {i + 1} as validation set.")

    train, val = get_data_splits(data, play_data_dict, ids, feat_names)

    train_data = BlitzFrameData(train)
    val_data = BlitzFrameData(val)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=combine_batch2)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, collate_fn=combine_batch2)

    # NOTE(review): the "- 1" presumably excludes one feature column that the
    # model consumes separately (e.g. position_id) — confirm against
    # get_data_splits / BlitzFrameTransformer.
    input_dim = train[0][1].shape[1] - 1
    z_dim = len(train[0][2])
    # NOTE(review): the positional arguments 64, 2, 1 are model hyperparameters
    # whose meaning is defined in models.BlitzFrameTransformer — verify there.
    model = BlitzFrameTransformer(input_dim, z_dim, 64, 2, 1).to(device)
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.1)
    loss_fn = nn.BCELoss()

    # Start from +inf rather than 1.0: BCE loss is not bounded by 1, so a fold
    # whose validation loss never dropped below 1.0 would never record a best
    # epoch. Resetting the "best" state here also prevents a fold from silently
    # reusing the previous fold's predictions.
    best_loss = float("inf")
    best_epoch = 0
    best_preds = None
    best_train_preds = None
    epoch = 0
    while epoch - best_epoch < PATIENCE:

        train_loss, train_preds = train_epoch2(train_loader, model, optimizer, loss_fn, device)
        val_loss, val_preds = validate_epoch2(val_loader, model, loss_fn, device)

        print(f"Epoch {epoch}:\n")
        print(f"Train loss: {np.round(train_loss, 4)}, Val loss: {np.round(val_loss, 4)}")

        if val_loss < best_loss:
            best_loss = val_loss
            best_epoch = epoch
            # Copy so later epochs cannot overwrite the saved predictions.
            best_preds = val_preds.copy()
            best_train_preds = train_preds.copy()

        epoch += 1

    # Only reachable without a best epoch when every validation loss failed the
    # `<` comparison (e.g. NaN); fail loudly instead of writing bogus output.
    if best_preds is None:
        raise RuntimeError(
            f"No epoch produced a comparable validation loss for fold {i + 1}; "
            "check for NaN losses."
        )

    val_preds_df = tidy_val_preds(best_preds, data)
    val_preds_df.to_csv(f"data/val_preds_week_{i + 1}.csv", index=False)
    cv_preds.append(val_preds_df)

In [None]:
# Reload the per-week validation predictions (weeks 1 through 9) from disk and
# stack them into a single frame.
cv_preds = [
    pd.read_csv(f"data/val_preds_week_{week}.csv")
    for week in range(1, 10)
]
cv_preds_df = pd.concat(cv_preds)

In [None]:
# Persist the combined cross-validation predictions; the DataFrame index is
# not written to the file.
cv_preds_df.to_csv("data/cv_preds.csv", index=False)