# NN CNN and LSTM Parallel

In [1]:
#import libraries
import pandas as pd
import numpy as np
import torch.nn as nn
import warnings
import time
import torch.optim as optim
warnings.filterwarnings("ignore")
pd.set_option('display.max_columns', None)
warnings.filterwarnings("ignore")
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

In [2]:
# Run the companion notebook so the helper functions it defines can be used here.
# NOTE(review): `data_tensors`, called below, is not defined in this notebook and
# presumably comes from Model_prep.ipynb — confirm.
%run Model_prep.ipynb

In [3]:
# Read the synthetic training sample (path is relative to this notebook's directory)
train = pd.read_csv(filepath_or_buffer="../Data/train_sample_synthetic.csv")

In [4]:
# Transform the gamePlayId variable to account for synthetic data.
# A row is flagged when its (gamePlayId, frameId, nflId) key has already appeared
# earlier in the frame; keep='first' leaves the first occurrence unflagged.
duplicates_mask = train.duplicated(['gamePlayId', 'frameId', 'nflId'], keep='first')

# Every flagged row gets '.1' appended to its gamePlayId so that it no longer
# shares a play id with the first occurrence.
train.loc[duplicates_mask, 'gamePlayId'] = train.loc[duplicates_mask, 'gamePlayId'] + '.1'

In [5]:
# Build the model tensors from the training frame with "tackle_binary_single" as the
# second argument (presumably the target column — confirm against `data_tensors`).
# `data_tensors` is not defined in this notebook; presumably it comes from Model_prep.ipynb.
# Per the shapes printed in the next cell: x is (2618, 140, 11, 87) and y / mask are
# (2618, 140, 11); the model code below reads these axes as
# (batch, frames, players[, features]).
x, y, mask = data_tensors(train, "tackle_binary_single")

In [6]:
# Sanity check: one shape per line for the features, the labels and the mask
print(x.shape, y.shape, mask.shape, sep="\n")

torch.Size([2618, 140, 11, 87])
torch.Size([2618, 140, 11])
torch.Size([2618, 140, 11])


In [7]:
# 2618 samples / 374 per batch = exactly 7 batches, so drop_last discards nothing
# for the tensor shapes printed above.
batch_size = 374

# Pair every feature tensor with its label tensor
train_data = TensorDataset(x, y)

# Order is kept fixed (no shuffling): a later cell flattens the predictions and
# lines them up with `mask`, which relies on the original sample order.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)

In [8]:
# Choose the compute device once; later cells move tensors to it via `.to(device)`.
# torch.cuda.is_available() is True only when a GPU can be used, otherwise we stay on the CPU.
is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")

In [14]:
class CNN_LSTMModel(nn.Module):
    """Parallel CNN + LSTM network that outputs one probability per (frame, player).

    Both branches read the same input of shape ``(batch, frames, players, features)``:

    * CNN branch: ``nn.Conv2d`` treats dim 1 as the channel axis, so the frames act
      as channels and the 1x1 convolutions / adaptive pooling operate on the
      (players, features) plane. The feature axis is pooled down to width 1 and the
      last convolution maps back to ``input_channels`` channels, which yields
      ``(batch, frames, players)``.
    * LSTM branch: each frame is flattened to ``players * features`` values, run
      through the LSTM and projected to ``n_players`` values per frame.

    The two branch outputs are summed and squashed with a sigmoid.

    Args:
        input_channels: Number of frames per sample (used as the Conv2d channel count
            and as the number of output channels of the last convolution).
        n_features: Number of features per player.
        hidden_size: Hidden size of the LSTM.
        num_layers: Number of stacked LSTM layers.
        num_classes: Unused; kept so existing callers do not break.
        n_players: Number of players per frame. Defaults to 11, the value that was
            previously hard-coded.
    """

    def __init__(self, input_channels, n_features, hidden_size, num_layers, num_classes, n_players=11):
        super(CNN_LSTMModel, self).__init__()
        self.n_players = n_players

        # Define the CNN layers
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=input_channels, out_channels=n_features, kernel_size=(1, 1), stride=1, padding=0),
            nn.ReLU(),
            # None keeps the players axis as-is; only the feature axis is pooled
            nn.AdaptiveMaxPool2d((None, 2)),
            nn.Conv2d(in_channels=n_features, out_channels=n_features*2, kernel_size=(1, 1), stride=1, padding=0),
            nn.ReLU(),
            nn.AdaptiveMaxPool2d((None, 1)),
            # This ReLU is a no-op (max-pooling non-negative values stays non-negative);
            # it is kept so the layer indices / state_dict keys do not change.
            nn.ReLU(),
            # The output channel count must equal the number of frames so the result
            # can be reshaped to (batch, frames, players) in forward().
            nn.Conv2d(in_channels=n_features*2, out_channels=input_channels, kernel_size=(1, 1), stride=1, padding=0)
            # TODO (open question from the author): should in_channels be the
            # features or the frames?
        )

        # LSTM layers: one time step per frame, each step sees all players' features
        self.lstm = nn.LSTM(input_size=n_features*n_players, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        # Fully connected layer: LSTM hidden state -> one value per player
        self.fc_LSTM = nn.Linear(hidden_size, n_players)

    def forward(self, x):
        """Return probabilities of shape ``(batch, frames, players)``.

        Args:
            x: Float tensor of shape ``(batch, frames, players, features)``.
        """
        batch_size, frames, players, features = x.size()

        # CNN branch: (batch, frames, players, 1) -> (batch, frames, players)
        out_cnn = self.cnn(x).reshape(batch_size, frames, players)

        # LSTM branch: flatten the players/features of every frame into one vector.
        # reshape (rather than view) also accepts non-contiguous inputs.
        x_lstm = x.reshape(batch_size, frames, players * features)
        out_lstm, _ = self.lstm(x_lstm)    # (batch, frames, hidden_size)
        out_lstm = self.fc_LSTM(out_lstm)  # (batch, frames, n_players)

        # Combine CNN and LSTM outputs, then map to (0, 1)
        return torch.sigmoid(out_cnn + out_lstm)


In [10]:
def train_nn(train_loader, learn_rate=0.05, batch_size=374,hidden_dim=348, n_layers = 2, EPOCHS=5):
    """Train a CNN_LSTMModel with Adam and binary cross-entropy and return it.

    Args:
        train_loader: Iterable of ``(x, label)`` batches. Each ``x`` is passed to the
            model as float on the module-level ``device``; each ``label`` is compared
            with the model output by ``nn.BCELoss``.
        learn_rate: Learning rate for the Adam optimizer.
        batch_size: Unused; the batch size is determined by ``train_loader``.
        hidden_dim: Hidden size handed to the model's LSTM.
        n_layers: Number of LSTM layers.
        EPOCHS: Number of passes over ``train_loader``.

    Returns:
        The trained model, left on ``device`` and in training mode.
    """
    # Setting dimension inputs
    input_dim = 140 #num of frames
    n_features = 87 #number of features

    # Define loss function, optimizer and model
    model = CNN_LSTMModel(input_channels = input_dim, n_features=n_features,
                        hidden_size = hidden_dim, num_layers = n_layers, num_classes = 2)
    # The batches are moved to `device` below, so the weights must live there too;
    # otherwise training fails with a device mismatch whenever CUDA is available.
    # Done before the optimizer is built so it tracks the on-device parameters.
    model.to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learn_rate)

    model.train()
    print("Starting Training of ParallelCNNLSTMModel")
    epoch_times = []
    n_batches = len(train_loader)

    for epoch in range(1, EPOCHS + 1):
        start_time = time.time()
        avg_loss = 0.0  # running sum of batch losses; divided by a count when reported
        counter = 0

        for x, label in train_loader:
            counter += 1
            optimizer.zero_grad()

            # Forward pass
            out = model(x.to(device).float())

            # Compute loss
            loss = criterion(out, label.to(device).float())

            # Backpropagation
            loss.backward()
            optimizer.step()

            avg_loss += loss.item()

            if counter % 200 == 0:
                print("Epoch {}... Step: {}/{}... Average Loss for Epoch: {:.4f}".format(epoch, counter, n_batches, avg_loss / counter))

        current_time = time.time()
        print("Epoch {}/{} Done, Total Loss: {:.4f}".format(epoch, EPOCHS, avg_loss / n_batches))
        print("Total Time Elapsed: {:.2f} seconds".format(current_time - start_time))
        epoch_times.append(current_time - start_time)

    print("Total Training Time: {:.2f} seconds".format(sum(epoch_times)))
    return model

In [15]:
# Train with train_nn's defaults (lr 0.05, hidden size 348, 2 LSTM layers, 5 epochs)
# and keep the fitted model for the prediction cells below.
# NOTE(review): the logged loss is identical (0.0791) from epoch 2 onward and a later
# cell finds no rows with model_probs > 0, which suggests the model may have collapsed
# to a constant output — consider a smaller learning rate; confirm.
CNN_LSTM = train_nn(train_loader)

Starting Training of ParallelCNNLSTMModel
Epoch 1/5 Done, Total Loss: 0.1608
Total Time Elapsed: 57.90 seconds
Epoch 2/5 Done, Total Loss: 0.0791
Total Time Elapsed: 28.98 seconds
Epoch 3/5 Done, Total Loss: 0.0791
Total Time Elapsed: 30.33 seconds
Epoch 4/5 Done, Total Loss: 0.0791
Total Time Elapsed: 33.18 seconds
Epoch 5/5 Done, Total Loss: 0.0791
Total Time Elapsed: 31.56 seconds
Total Training Time: 181.97 seconds


In [16]:
# Preview the training frame before predictions are attached (all columns are shown
# because display.max_columns was set to None in the first cell).
display(train)

Unnamed: 0,gameId,playId,nflId,frameId,x,y,unitDir,unitO,force,home,preSnapWinProbabilityDefense,bcx,bcy,bcs,bca,bco,bcdir,bcweight,bcPosition,bcForce,play_type,c1Dist,c2Dist,c3Dist,c4Dist,c5Dist,c6Dist,c7Dist,c8Dist,c9Dist,c10Dist,bcDist,c1Ang,c2Ang,c3Ang,c4Ang,c5Ang,c6Ang,c7Ang,c8Ang,c9Ang,c10Ang,bcAng,a,s,tackles_ingame,assists_ingame,ff_ingame,misses_ingame,tackle_efficiency_ingame,tackle_rating_ingame,rolling_tackles,rolling_assists,rolling_ff,rolling_mt,DL,LB,DB,QB,RB,WR,TE,OL,tackle_binary_all,tackle_binary_single,tackle_nonbinary_all,tackle_nonbinary_single,down,yardsToGo,defendersInTheBox,offenseFormation,absoluteYardlineNumber,timeSinceStart,surface,inside_outside,presnapDefScoreDiff,weight,position,gamePlayId
0,2.022091e+09,56.0,38577.0,6.0,41.89,28.740000,87.71,79.47,288.200000,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,3.195387,10.116071,10.461855,10.882909,12.035414,12.635874,12.701657,13.169210,14.799963,23.582173,7.067538,75.239538,89.059902,122.727777,103.094000,104.570954,166.159966,86.642246,80.903275,77.812451,158.729540,16.542527,2.62,3.35,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,242,ILB,2022090800.056.0
1,2.022091e+09,56.0,41239.0,6.0,27.85,29.960000,247.65,276.16,364.000000,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,1.400321,1.783620,2.496898,3.993257,4.414386,4.674409,8.228657,17.168183,21.436532,32.008038,13.527265,113.577579,163.037991,150.980380,102.982031,53.447313,68.130076,59.944064,110.580937,72.972206,65.058378,136.944687,2.86,3.62,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,280,DT,2022090800.056.0
2,2.022091e+09,56.0,42816.0,6.0,49.38,7.660000,8.33,61.57,346.254545,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,1.233207,10.014569,22.204274,22.838312,26.325539,26.712411,27.894992,30.064028,31.255438,33.017583,29.415605,89.937075,111.358111,93.020417,127.965101,125.346367,127.777247,120.762863,123.703959,122.414454,123.035494,99.957121,4.14,2.60,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,184,CB,2022090800.056.0
3,2.022091e+09,56.0,43294.0,6.0,41.85,37.850000,268.50,230.96,116.290909,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,8.993442,13.196030,14.422794,14.850576,15.279797,15.418982,16.539265,16.979061,21.643128,32.342421,2.828003,22.070958,48.171854,61.668459,56.543692,67.764015,42.186107,43.622479,30.430267,8.319643,15.291336,35.450938,1.23,5.88,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,1,0,1.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,208,CB,2022090800.056.0
4,2.022091e+09,56.0,43298.0,6.0,27.89,33.140000,293.53,249.12,241.090909,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,0.773886,2.104305,3.431049,5.466160,7.240836,7.311580,10.903687,17.517377,23.554390,34.387191,12.502404,48.770529,7.667234,0.177919,21.988835,3.672038,5.689796,3.646058,54.242746,21.039989,15.533146,77.770945,2.21,1.34,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,240,DE,2022090800.056.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
765595,2.022111e+09,3299.0,46077.0,15.0,12.79,18.193333,316.93,310.26,108.727273,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,3.984533,4.382237,5.707136,6.483949,7.567331,8.460715,10.242080,13.548432,28.643975,32.620414,8.129514,168.904899,131.239251,119.593791,123.840114,142.731481,132.325061,109.533515,71.450954,119.378206,93.568031,2.659818,1.04,3.33,1,3,0,0,1.0,0.625000,19.0,8.0,0.0,1.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,230,ILB,2022110700.03299.0.1
765596,2.022111e+09,3299.0,52436.0,15.0,8.71,22.943333,258.42,310.69,217.713636,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,2.731227,3.903870,4.236189,5.379303,5.468830,5.596445,9.395664,16.089006,25.507371,32.148177,14.348286,26.430338,145.542147,96.569250,119.442347,109.991615,143.103012,131.173606,107.609510,166.381267,141.013615,57.285925,2.11,0.73,5,1,0,0,1.0,0.916667,13.0,6.0,0.0,1.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,227,ILB,2022110700.03299.0.1
765597,2.022111e+09,3299.0,52627.0,15.0,21.55,15.803333,234.64,223.00,324.545455,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,9.378406,10.874213,10.960716,11.698825,12.671910,12.947185,13.876059,14.057859,30.055178,30.284795,3.859961,164.330810,101.540464,92.785483,103.348108,123.014910,74.978507,106.076921,99.121240,168.151488,140.891366,6.384482,3.40,4.72,0,0,0,0,0.0,0.000000,18.0,0.0,1.0,1.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,210,SS,2022110700.03299.0.1
765598,2.022111e+09,3299.0,53460.0,15.0,13.66,21.483333,328.08,0.26,82.472727,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,1.311869,2.306339,3.114643,4.410771,4.687057,5.225562,6.897710,11.490213,25.241626,29.562625,10.078492,155.731200,110.415185,118.791190,132.562842,149.086416,130.282696,94.091822,47.831149,108.379240,79.661956,26.219242,0.72,0.44,1,1,0,0,1.0,0.750000,4.0,2.0,0.0,2.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,252,OLB,2022110700.03299.0.1


In [17]:
CNN_LSTM.eval()
with torch.no_grad():
    # One tensor of predictions per batch, stored on the CPU so that they can be
    # converted to NumPy afterwards even when the model runs on a GPU.
    all_predictions = []

    # Score the same loader that was used for training (this notebook has no
    # held-out data); its fixed order keeps predictions aligned with the samples.
    for test_batch in train_loader:
        inputs, _ = test_batch

        # Same device and dtype handling as in the training loop
        inputs = inputs.to(device).float()

        # Forward pass to get the output from the model
        outputs = CNN_LSTM(inputs)

        # Store the predictions for this batch
        all_predictions.append(outputs.cpu())

7

In [18]:
# Join the per-batch predictions in torch and only then convert to NumPy:
# .cpu() makes this work whether the predictions live on the GPU or the CPU.
flattened_values = torch.cat(all_predictions).detach().cpu().numpy().ravel()

# The loader uses drop_last=True, so a trailing partial batch would leave fewer
# predictions than mask entries; fail with a clear message instead of an index error.
if flattened_values.size != mask.numel():
    raise ValueError(
        "Got {} predictions for {} mask entries; the prediction loader must cover "
        "every sample exactly once.".format(flattened_values.size, mask.numel())
    )

# Keep only the entries where the mask equals 1
masked_flat = flattened_values[(mask.ravel() == 1).cpu().numpy()]

# Create a new DataFrame from the flattened array
df_flattened = pd.DataFrame(masked_flat, columns=['model_probs'])

In [19]:
# Attach the predictions as a `model_probs` column (rows are aligned on the index).
# Any `model_probs` column left by a previous run of this cell is dropped first, so
# re-running does not create duplicate columns (which would turn
# train["model_probs"] into a DataFrame and break the filter below).
train = pd.concat([train.drop(columns=['model_probs'], errors='ignore'), df_flattened], axis=1)
display(train)

Unnamed: 0,gameId,playId,nflId,frameId,x,y,unitDir,unitO,force,home,preSnapWinProbabilityDefense,bcx,bcy,bcs,bca,bco,bcdir,bcweight,bcPosition,bcForce,play_type,c1Dist,c2Dist,c3Dist,c4Dist,c5Dist,c6Dist,c7Dist,c8Dist,c9Dist,c10Dist,bcDist,c1Ang,c2Ang,c3Ang,c4Ang,c5Ang,c6Ang,c7Ang,c8Ang,c9Ang,c10Ang,bcAng,a,s,tackles_ingame,assists_ingame,ff_ingame,misses_ingame,tackle_efficiency_ingame,tackle_rating_ingame,rolling_tackles,rolling_assists,rolling_ff,rolling_mt,DL,LB,DB,QB,RB,WR,TE,OL,tackle_binary_all,tackle_binary_single,tackle_nonbinary_all,tackle_nonbinary_single,down,yardsToGo,defendersInTheBox,offenseFormation,absoluteYardlineNumber,timeSinceStart,surface,inside_outside,presnapDefScoreDiff,weight,position,gamePlayId,model_probs
0,2.022091e+09,56.0,38577.0,6.0,41.89,28.740000,87.71,79.47,288.200000,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,3.195387,10.116071,10.461855,10.882909,12.035414,12.635874,12.701657,13.169210,14.799963,23.582173,7.067538,75.239538,89.059902,122.727777,103.094000,104.570954,166.159966,86.642246,80.903275,77.812451,158.729540,16.542527,2.62,3.35,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,242,ILB,2022090800.056.0,0.0
1,2.022091e+09,56.0,41239.0,6.0,27.85,29.960000,247.65,276.16,364.000000,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,1.400321,1.783620,2.496898,3.993257,4.414386,4.674409,8.228657,17.168183,21.436532,32.008038,13.527265,113.577579,163.037991,150.980380,102.982031,53.447313,68.130076,59.944064,110.580937,72.972206,65.058378,136.944687,2.86,3.62,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,280,DT,2022090800.056.0,0.0
2,2.022091e+09,56.0,42816.0,6.0,49.38,7.660000,8.33,61.57,346.254545,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,1.233207,10.014569,22.204274,22.838312,26.325539,26.712411,27.894992,30.064028,31.255438,33.017583,29.415605,89.937075,111.358111,93.020417,127.965101,125.346367,127.777247,120.762863,123.703959,122.414454,123.035494,99.957121,4.14,2.60,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,184,CB,2022090800.056.0,0.0
3,2.022091e+09,56.0,43294.0,6.0,41.85,37.850000,268.50,230.96,116.290909,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,8.993442,13.196030,14.422794,14.850576,15.279797,15.418982,16.539265,16.979061,21.643128,32.342421,2.828003,22.070958,48.171854,61.668459,56.543692,67.764015,42.186107,43.622479,30.430267,8.319643,15.291336,35.450938,1.23,5.88,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,1,0,1.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,208,CB,2022090800.056.0,0.0
4,2.022091e+09,56.0,43298.0,6.0,27.89,33.140000,293.53,249.12,241.090909,1,0.413347,40.15,35.590000,4.61,4.82,245.73,157.80,191,WR,418.463636,pass,0.773886,2.104305,3.431049,5.466160,7.240836,7.311580,10.903687,17.517377,23.554390,34.387191,12.502404,48.770529,7.667234,0.177919,21.988835,3.672038,5.689796,3.646058,54.242746,21.039989,15.533146,77.770945,2.21,1.34,0,0,0,0,0.0,0.000000,0.0,0.0,0.0,0.0,3,2,6,1,1,3,1,5,0,0,0.0,0.0,1,10,6.0,SHOTGUN,85,0,turf,inside,0,240,DE,2022090800.056.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
765595,2.022111e+09,3299.0,46077.0,15.0,12.79,18.193333,316.93,310.26,108.727273,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,3.984533,4.382237,5.707136,6.483949,7.567331,8.460715,10.242080,13.548432,28.643975,32.620414,8.129514,168.904899,131.239251,119.593791,123.840114,142.731481,132.325061,109.533515,71.450954,119.378206,93.568031,2.659818,1.04,3.33,1,3,0,0,1.0,0.625000,19.0,8.0,0.0,1.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,230,ILB,2022110700.03299.0.1,0.0
765596,2.022111e+09,3299.0,52436.0,15.0,8.71,22.943333,258.42,310.69,217.713636,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,2.731227,3.903870,4.236189,5.379303,5.468830,5.596445,9.395664,16.089006,25.507371,32.148177,14.348286,26.430338,145.542147,96.569250,119.442347,109.991615,143.103012,131.173606,107.609510,166.381267,141.013615,57.285925,2.11,0.73,5,1,0,0,1.0,0.916667,13.0,6.0,0.0,1.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,227,ILB,2022110700.03299.0.1,0.0
765597,2.022111e+09,3299.0,52627.0,15.0,21.55,15.803333,234.64,223.00,324.545455,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,9.378406,10.874213,10.960716,11.698825,12.671910,12.947185,13.876059,14.057859,30.055178,30.284795,3.859961,164.330810,101.540464,92.785483,103.348108,123.014910,74.978507,106.076921,99.121240,168.151488,140.891366,6.384482,3.40,4.72,0,0,0,0,0.0,0.000000,18.0,0.0,1.0,1.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,210,SS,2022110700.03299.0.1,0.0
765598,2.022111e+09,3299.0,53460.0,15.0,13.66,21.483333,328.08,0.26,82.472727,0,0.995175,18.98,12.923333,1.80,3.90,84.28,62.29,204,WR,361.636364,pass,1.311869,2.306339,3.114643,4.410771,4.687057,5.225562,6.897710,11.490213,25.241626,29.562625,10.078492,155.731200,110.415185,118.791190,132.562842,149.086416,130.282696,94.091822,47.831149,108.379240,79.661956,26.219242,0.72,0.44,1,1,0,0,1.0,0.750000,4.0,2.0,0.0,2.0,1,5,5,1,1,3,1,5,0,0,0.0,0.0,2,19,5.0,SHOTGUN,18,3217,turf,inside,21,252,OLB,2022110700.03299.0.1,0.0


In [20]:
# Rows for which the model produced a strictly positive probability
train.loc[train["model_probs"] > 0]

Unnamed: 0,gameId,playId,nflId,frameId,x,y,unitDir,unitO,force,home,preSnapWinProbabilityDefense,bcx,bcy,bcs,bca,bco,bcdir,bcweight,bcPosition,bcForce,play_type,c1Dist,c2Dist,c3Dist,c4Dist,c5Dist,c6Dist,c7Dist,c8Dist,c9Dist,c10Dist,bcDist,c1Ang,c2Ang,c3Ang,c4Ang,c5Ang,c6Ang,c7Ang,c8Ang,c9Ang,c10Ang,bcAng,a,s,tackles_ingame,assists_ingame,ff_ingame,misses_ingame,tackle_efficiency_ingame,tackle_rating_ingame,rolling_tackles,rolling_assists,rolling_ff,rolling_mt,DL,LB,DB,QB,RB,WR,TE,OL,tackle_binary_all,tackle_binary_single,tackle_nonbinary_all,tackle_nonbinary_single,down,yardsToGo,defendersInTheBox,offenseFormation,absoluteYardlineNumber,timeSinceStart,surface,inside_outside,presnapDefScoreDiff,weight,position,gamePlayId,model_probs
