In [None]:
import os
import warnings

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from dataset import nfl_tackle_data
from model import tackleNetwork
from train import tackle_model_trainer

In [None]:
# Train on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if _use_cuda else "cpu")
device

In [None]:
# Dataset locations handed to `nfl_tackle_data` in the next cell.
# Relative path: resolved against the kernel's current working directory.
FEATURES = "../features"
# Master file name. NOTE(review): presumably looked up inside FEATURES by nfl_tackle_data — confirm.
MASTER_FNAME = "master.csv"

In [None]:
# Build the training dataset and wrap it in a shuffling mini-batch loader.
train_nfl_tackle_data = nfl_tackle_data(FEATURES, MASTER_FNAME)

# Cap worker processes at the number of CPUs actually available, keeping the
# original 20 as the upper bound, so smaller machines are not oversubscribed.
# os.cpu_count() may return None, hence the `or 1` fallback.
NUM_WORKERS = min(20, os.cpu_count() or 1)

train_nfl_tackle_data_loader = DataLoader(
    train_nfl_tackle_data,
    batch_size=128,
    shuffle=True,
    drop_last=True,  # discard a trailing partial batch so every batch has 128 samples
    # Pinned host memory only helps host->GPU copies; on CPU PyTorch would just
    # warn and ignore it, so request it only when training on CUDA.
    pin_memory=device.type == "cuda",
    num_workers=NUM_WORKERS,
)

In [None]:
# hyperparameters
# NOTE(review): the meaning of each weight is defined by tackle_model_trainer;
# they are only forwarded from here.
class_weight = 0.8
x_loss_weight = 0.9
epochs = 100
sensitivity_weight = 0.4
lr = 0.002
betas = (0.9, 0.999)  # presumably forwarded to the Adam optimizer as (beta1, beta2)

# Seed BEFORE the model is constructed. nn layers draw their default initial
# weights from torch's global RNG. Seeding after tackleNetwork(...) (as before)
# left the initial weights unseeded and runs unreproducible. The DataLoader's
# shuffle order is drawn when iteration starts inside fit(), so it is covered too.
seed = 90
torch.manual_seed(seed)

activation = nn.LeakyReLU()
model = tackleNetwork(activation).to(device)
optimizer = 'Adam'  # optimizer given by name; presumably resolved inside tackle_model_trainer

# Run label built from the loss weights. NOTE(review): not used within this cell;
# presumably consumed by a later cell — remove if unused.
model_name = f'clss-wt-{class_weight}_x-lss-wt_{x_loss_weight}_x-sens-wt_{sensitivity_weight}'

# NOTE(review): this suppresses *all* warnings raised during trainer setup and
# training; consider narrowing the filter to the specific category being silenced.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    model_trainer = tackle_model_trainer(
        model,
        device,
        optimizer,
        epochs,
        class_weight,
        x_loss_weight,
        train_nfl_tackle_data_loader,
        sensitivity_weight=sensitivity_weight,
        lr=lr,
        betas=betas
    )

    model_trainer.save_hyperparameters()
    model_trainer.fit()