# Expected Pocket Time
This notebook will compute the expected pocket time for each play within the full dataset.

## Overview
There are 2 neural networks being trained here, the first one (Model A) is designed to predict a player's movement in the next frame of a play given the current state in the frame being looked at. The implication of this is that the predicted movement represents an average movement that specific player will take in his current situation. Given that the model is only trained on the data we currently have, this is not a seasonal average. 

The second neural network is one that predicts the total time left the QB has in the pocket. This model is trained ONLY on sack plays where the total time in the pocket was reached. It is trained on a frame-by-frame basis (hence a current field state) and will compute the total seconds remaining that the QB has in the pocket given that field state.

All together when both models are being used, we can decompose the changes in pocket time remaning as the play moves forward in time. The jumps in the pocket time remaning are likely caused by an incident that happened on the field relative to either the offensive line or defensive line. We can then compute the average movement the O-line and D-line make given their field states to produce an average change pocket time. This average can be subtracted from the actual value to show whether each player is above or below their average in gaining or loosing the QB's time in the pocket.


In [4]:
import numpy as np
import pickle
from tqdm import tqdm
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt

#Set by user -> location of the plays dictionary
data_dir = "/Users/admin/Desktop/data_bowl/data_bowl_github/"

#Open the dictionary
f = open(data_dir+'full_plays_dictionary.pckl', 'rb')
dat = pickle.load(f)
f.close()

In [5]:
#Set a list of functions needed to compute this
def GetClosePlayers(dat_dict,player,plID,frID,radius=250):
    '''
    This function will return a dictionary of the all players on the field from closest
    to farthest from the player in specified in the inputs.
    
    INPUTS:
        dat_dict -> dict. the plays dictionary
        player -> str. the player ID
        plID -> str. the play ID
        frID -> str. the frame ID
        
    RETURNS:
        A sorted dictionary from closest to farthest players relative to
        the input player.
    '''
    
    #Open the dictionary
    curr=dat_dict['pass'][plID][frID]
    ofIDs=list(curr.keys()) #List of all player IDs for this game ID and frame ID
    cx0=float(curr[player][0]) #X-position
    cy0=float(curr[player][1]) #Y-position
    
    #Create empty 1D arrays to be filled (distance, x-dist, y-dist, ID, possession)
    distv = np.zeros(len(ofIDs)) * np.nan
    xrelv=np.zeros(len(ofIDs)) * np.nan
    yrelv=np.zeros(len(ofIDs)) * np.nan
    idv=np.zeros(len(ofIDs)) * np.nan
    pos = np.zeros(len(ofIDs)).astype(object) * np.nan
    
    for count, idd in enumerate(ofIDs):
        if idd=='-999.0': #Skip over the football itself
            continue
        #Take the difference in x and y position from the specified player to the others
        xrel=float(curr[idd][0])-cx0
        yrel=float(curr[idd][1])-cy0
        dist_tot=np.sqrt(xrel**2+yrel**2) #Distance formula
        poss = curr[idd][-3]
        if dist_tot>radius:
            continue
        #Fill the arrays
        distv[count] = dist_tot
        idv[count] = idd
        xrelv[count] = xrel
        yrelv[count] = yrel
        pos[count] = poss

    #Sort the arrays here
    sortperm=np.argsort(distv)
    idv = idv[sortperm]
    distv = distv[sortperm]
    yrelv = yrelv[sortperm]
    xrelv = xrelv[sortperm]
    pos = pos[sortperm]
    
    return({'idv':idv,'distv':distv,'xrelv':xrelv,'yrelv':yrelv, 'pos': pos})

def os2xyspeed(o,s):
    '''
    This function will take in the orientation of the player and return their x,y triangular coordinates.
    INPUTS:
        o,s -> 1D arrays of orientation and speed
    RETURNS:
        x,y -> arrays of input dimension corresponding to the 
        triangulated coordinates.
    '''
    
    if o<90:
        ot=o*2*np.pi/360
        x=s*np.sin(ot)
        y=s*np.cos(ot)
    elif o<180:
        ot=(180-o)*2*np.pi/360
        x=s*np.sin(ot)
        y=-s*np.cos(ot)
    elif o<180:
        ot=(o-180)*2*np.pi/360
        x=-s*np.sin(ot)
        y=-s*np.cos(ot)
    else:
        ot=(360-o)*2*np.pi/360
        x=-s*np.sin(ot)
        y=s*np.cos(ot)
    return([x,y])

def GetFieldStateFeatures(plID,frID,dat_dict):
    '''
    This function will analyze the current state of the field and produce arrays that are meant to 
    train a neural network.
    
    INPUTS:
        plID -> str. the play ID
        frID -> str. the frame ID
        dat_dict -> the plays dictionary
        
    RETURNS:
        x -> list of arrays corresponding to the field state
        x_players -> list of arrays corresponding to the player ID in x
    '''
    
    plID,frID = str(plID),str(frID)
    
    teamkick=play_d[plID][4] #Find the posession Team
    homeACR=game_d[plID.split("-")[0]][4] #Find Home Team
    awayACR=game_d[plID.split("-")[0]][5] #Find Away Team
    if teamkick==homeACR:
        homekick=True
    else:
        homekick=False
#-------------------------------------------------------------------------------------------------------------------
    #reverse x coordinates for all objects if playdirection it to the right
    revx=False
    if dat_dict['pass'][plID]['1']['-999.0'][8].strip("\"")=='right':
        revx=True
#-------------------------------------------------------------------------------------------------------------------
    gcpout=GetClosePlayers(dat,'-999.0',plID,frID,radius=1000) #Outputs all players data in relation to the ball
    
    idx = np.where(gcpout['pos'] == 'QB')[0][0]
    bholder = str(gcpout['idv'][idx]) #Get Player ID of the QB
    
    bx=float(dat_dict['pass'][plID][frID]['-999.0'][0]) #X-Coordinate of the ball
    by=float(dat_dict['pass'][plID][frID]['-999.0'][1]) #Y-Coordinate of the ball
    cx=float(dat_dict['pass'][plID][frID][bholder][0]) #X-Coordinate of the QB
    cy=float(dat_dict['pass'][plID][frID][bholder][1]) #Y-Coordinate of the QB
    cs=float(dat_dict['pass'][plID][frID][bholder][2]) #Speed of the QB at the given frame
    ca=float(dat_dict['pass'][plID][frID][bholder][3]) #Acceleration of the QB at the given frame
    cdis=float(dat_dict['pass'][plID][frID][bholder][4]) #Distance Traveled by the QB between last frame and current frame
    co=float(dat_dict['pass'][plID][frID][bholder][5]) #Orientation of the QB 
    temp=os2xyspeed(co,cs)
    cxspd=temp[0] #X-Velocity of the QB
    cyspd=temp[1] #Y-Velocity of the QB
    if revx:
        cxspd=-cxspd #If Play was to the right, flip X-Coordinates
        bx=120-bx
        cx=120-cx
    x=[]
    x_players = []
    #------------------------------------------------------------------------------------------------------------------- 
    
    #Iterate through all players now
    for i in range(len(gcpout['idv'])): 
        cID=str(gcpout['idv'][i]) 
        if cID !='nan':
            
            cx=float(dat_dict['pass'][plID][frID][cID][0])
            cy=float(dat_dict['pass'][plID][frID][cID][1])
            cs=float(dat_dict['pass'][plID][frID][cID][2])
            ca=float(dat_dict['pass'][plID][frID][cID][3])
            cdis=float(dat_dict['pass'][plID][frID][cID][4])
            co=float(dat_dict['pass'][plID][frID][cID][5])# relx, rely, dist, xspeed, yspeed, acceleration, kicking_team_boolean, SORT BY DISTANCE TO BALL
            temp=os2xyspeed(co,cs)
            cxspd=temp[0]
            cyspd=temp[1]
            if revx:
                cxspd=-cxspd
                cx=120-cx
            cretind=0
            if (homekick and dat_dict['pass'][plID][frID][cID][11].strip("\"")=="home") or ((not homekick) and dat_dict['pass'][plID][frID][cID][11].strip("\"")=="away"):
                cretind=1
            x+=[(cx-bx)/20,(cy-by)/20,gcpout['distv'][i]/13,cxspd,cyspd,ca,cretind]
            x_players+= [cID]*7

    return(x,x_players)

def GetClosePlayers_5(dat_dict,player,plID,frID,radius=250):
    '''
    Similar to the function GetClosePlayers, this will return the first 5 players.
    
    INPUTS:
        dat_dict -> dict. containing the full dataset
        player -> str. the player ID
        plID -> str. the play ID
        frID -> str. the frame ID
   
    RETURNS:
        dictionary containing the first 5 players closest to the player specified.
    '''
    #Read in the dictionary
    curr=dat_dict['pass'][plID][frID]
    ofIDs=list(curr.keys())
    cx0=float(curr[player][0])
    cy0=float(curr[player][1])
    
    #Create empty arrays to be filled (distance, x-dist, y-dist, ID, possesion)
    distv = np.zeros(len(ofIDs)) * np.nan
    xrelv=np.zeros(len(ofIDs)) * np.nan
    yrelv=np.zeros(len(ofIDs)) * np.nan
    idv=np.zeros(len(ofIDs)) * np.nan
    pos = np.zeros(len(ofIDs)).astype(object) * np.nan
    
    #Loop through each player ID
    for count, idd in enumerate(ofIDs):
        if idd!='-999.0': #remove the football location, no need for that
            xrel=float(curr[idd][0])-cx0
            yrel=float(curr[idd][1])-cy0
            dist_tot=np.sqrt(xrel**2+yrel**2)
            poss = curr[idd][-3]
            if dist_tot<radius:
                distv[count] = dist_tot
                idv[count] = idd
                xrelv[count] = xrel
                yrelv[count] = yrel
                pos[count] = poss

    #Sort the data here
    sortperm=np.argsort(distv)
    idv = idv[sortperm]
    distv = distv[sortperm]
    yrelv = yrelv[sortperm]
    xrelv = xrelv[sortperm]
    pos = pos[sortperm]  
    
    return({'idv':idv,'distv':distv,'xrelv':xrelv,'yrelv':yrelv, 'pos': pos})

def testerrv(net,testloader):
    '''
    This function tests the trained neural network.
    
    INPUTS:
        net -> the model object
        testloader -> the testing set
    RETURNS:
        the correlation between the testing and training set of the neural networks performance
        
    '''
    outerr=[]
    testout=np.array([])
    trainout=np.array([])
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            
            #Calculate outputs by running images through the network
            images=images.float()
            labels=labels.float()
            outputs = net(images).squeeze()
            
            #The class with the highest energy is what we choose as prediction
            runtot=0
            denom=float(labels.shape[0])
            for i in range(int(denom)):
                runtot+=abs(labels[i]-outputs[i])/denom
            outerr.append(float(runtot))
            testout=np.append(testout,torch.Tensor.numpy(outputs))
            trainout=np.append(trainout,torch.Tensor.numpy(labels))
    
    return(np.corrcoef(testout,trainout)[0,1])

def model_a_assimilation(dat_dict,plID,frID,frIDs,j,playerID):
    '''
    This function assimates and trains the first round of neural networks.
    
    INPUTS:
        dat_dict -> dict. the full plays dictionary
        plID -> str. the play ID
        frID -> str. the frame ID
        frIDs -> list of all frame IDs for this play
        j -> int. an iterator
        playerID -> str. the player ID
        
    RETURNS:
        curx -> list containing the predictor variables needed to predict the movement 
        of the player specified given the current state of the field.
    
    '''
    
    teampos = play_d[plID][4] #Finds posession Team
    homeACR = game_d[plID.split("-")[0]][4] #Finds Home Team
    awayACR = game_d[plID.split("-")[0]][5] #Finds Away Team
    if teampos == homeACR:
        homepos = True
    else:
        homepos = False
    
    #Get the outcome vector: [deltax,deltay,a,xspd,yspd]
    cx = float(dat_dict['pass'][plID][frID][playerID][0])
    cy = float(dat_dict['pass'][plID][frID][playerID][1])
    cs = float(dat_dict['pass'][plID][frID][playerID][2])
    ca = float(dat_dict['pass'][plID][frID][playerID][3])
    cdis = float(dat_dict['pass'][plID][frID][playerID][4])
    co = float(dat_dict['pass'][plID][frID][playerID][5])
    nx = float(dat_dict['pass'][plID][frIDs[j+1]][playerID][0])
    ny = float(dat_dict['pass'][plID][frIDs[j+1]][playerID][1])
    ns = float(dat_dict['pass'][plID][frIDs[j+1]][playerID][2])
    na = float(dat_dict['pass'][plID][frIDs[j+1]][playerID][3])
    ndis = float(dat_dict['pass'][plID][frIDs[j+1]][playerID][4])
    no = float(dat_dict['pass'][plID][frIDs[j+1]][playerID][5])
    temp = os2xyspeed(no,ns)
    nxspd = temp[0]
    nyspd = temp[1]
    temp = os2xyspeed(co,cs)
    cxspd = temp[0]
    cyspd = temp[1]
#-------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------
    ### get 5 closest players to the player specified
    gcpout=GetClosePlayers(dat,playerID,plID,frID,radius=250)
#-------------------------------------------------------------------------------------------------------------------
#-------------------------------------------------------------------------------------------------------------------

    #X-VARIABLE
    curx=[ca,cxspd,cyspd,cdis] ####(PREDICTOR) Current Field State for reference player####
#-------------------------------------------------------------------------------------------------------------------     
    for l in range(len(gcpout)):
        ID2=str(gcpout['idv'][l])
        cx2=float(dat_dict['pass'][plID][frID][ID2][0])
        cy2=float(dat_dict['pass'][plID][frID][ID2][1])
        cs2=float(dat_dict['pass'][plID][frID][ID2][2])
        ca2=float(dat_dict['pass'][plID][frID][ID2][3])
        cdis2=float(dat_dict['pass'][plID][frID][ID2][4])
        co2=float(dat_dict['pass'][plID][frID][ID2][5])
        temp=os2xyspeed(co2,cs2)
        cxspd2=temp[0]
        cyspd2=temp[1]

        #sameteam = 1 if reference player and closest player in question is on the same team
        if dat_dict['pass'][plID][frID][ID2][11]==dat_dict['pass'][plID][frID][playerID][11]:
            sameteam=1
        else:
            sameteam=0

        #(PREDICTOR) Adds distance, difference in x&y coords, x&y speed, 
        #same team indicator for all players (including reference player)
        curx+=[np.sqrt((cx2-cx)**2+(cy2-cy)**2),cx2-cx,cy2-cy,ca2,cxspd2,cyspd2,sameteam]
#-------------------------------------------------------------------------------------------------------------------
    # (PREDICTOR) Adds x&y coords of the ball
    cx2=float(dat_dict['pass'][plID][frID]['-999.0'][0])
    cy2=float(dat_dict['pass'][plID][frID]['-999.0'][1])

    #Indicator to determine who is posession team
    cretind=0
    if (homepos and dat_dict['pass'][plID][frID][playerID][11].strip("\"")=="home") or ((not homepos) and dat_dict['pass'][plID][frID][playerID][11].strip("\"")=="away"):
        cretind=1
#-------------------------------------------------------------------------------------------------------------------            
    #(PREDICTOR) Adds Footabll x&y coords, posession team indicator (1 = posession, 0 = defending)
    curx+=[cx2,cy2,cretind]

    return curx

def player_pos2player_id(dat_dict,gameId,Player_Pos):
    '''
    This function will return the player ID for the specified position.
    
    INPUTS:
        dat_dict -> dict. the full plays dictionary
        gameID -> str. the game ID
        Player_Pos -. str. the player position
        
    RETURNS:
        pid -> list of the player IDs corresponding to that position
    '''
    
    #Get the list of all player IDs
    player_IDs_dict = list(dat_dict['pass'][gameId]['1'].keys())
    pid = []
    for p in player_IDs_dict:
        if dat_dict['pass'][games[i]]['1'][str(p)][-3] == Player_Pos:
            pid.append(p)
    return pid


#Initiate the neural network
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(154, 200)
        self.fc2 = nn.Linear(200,100)
        self.fc3 = nn.Linear(100,20)
        self.fc4 = nn.Linear(20, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)#.squeeze()
        return x
net = Net()


In [None]:
import pandas as pd
2021090900-2330
WEEK = '1'
#d = pd.read_csv(data_dir+"raw_data/week" +WEEK+ ".csv")
np.unique(d.loc[np.where((d.gameId==2021090900)&(d.playId==2298))[0]].event.values)

In [6]:
#Organize the data to ingest into the model
os.chdir(data_dir)

WEEKS = ['1','2','3','4','5','6','7','8']
catchframe = {}
passframe = {}
sackframe = {}
pID_s=set()
pID_p=set()
pID_sa=set()

for WEEK in WEEKS:
    filein=open(data_dir+"raw_data/week" +WEEK+ ".csv",'r')
    line=filein.readline()

    #Get all playIDs when ball was snapped and onward
    filein.seek(0)
    for line in filein:
        line_v=line.split(",")
        if line_v[15].strip('\n').strip('""') == 'ball_snap' or line_v[15].strip('\n').strip('""') =="autoevent_ballsnap":
            gpID = line_v[0]+"-"+line_v[1]
            pID_s.add(gpID)
            catchframe[gpID] = line_v[3]
            
    #Get all playIDs when ball was thrown and onward
    filein.seek(0)
    for line in filein:
        line_v=line.split(",")
        if line_v[15].strip('\n').strip('""') == 'pass_forward' or line_v[15].strip('\n').strip('""') =="autoevent_passforward" or line_v[15].strip('\n').strip('""') =="qb_sack" or line_v[15].strip('\n').strip('""') =="qb_strip_sack":
            gpID = line_v[0]+"-"+line_v[1]
            pID_p.add(gpID)
            passframe[gpID] = line_v[3]
            
    #Get all playIDs if the Qb was sacked
    filein.seek(0)
    for line in filein:
        line_v=line.split(",")
        if line_v[15].strip('\n').strip('""') == 'qb_sack':
            gpID = line_v[0]+"-"+line_v[1]
            pID_sa.add(gpID)
            sackframe[gpID] = line_v[3]


filein2=open(data_dir+"raw_data/games.csv",'r')
filein2.readline()
game_d={}
for line in filein2:
    line_v=[x.strip('\"') for x in line.strip("\n").split(",")]
    game_d[line_v[0]]=line_v[1:]

filein3=open(data_dir+"raw_data/plays.csv",'r')
filein3.readline()
play_d={}
for line in filein3:
    line_v=[x.strip('\"') for x in line.strip("\n").split(",")]
    play_d[line_v[0]+"-"+line_v[1]]=line_v[2:]
    


In [7]:
#Create a second neural network that is trained ONLY on sack plays (where the true pocket time was reached)
games = list(pID_sa)
frames_tot = np.sum(np.array([len(list(dat['pass'][games[i]].keys())) for i in range(len(games))]))
Y = np.zeros((frames_tot))*np.nan #[playId,seconds remaining]
X = np.zeros((frames_tot,22*7))*np.nan #[playId, stats for each player]

count = 0
for i in tqdm(range(len(games))):
    #Get the frame at which the QB was sacked
    try:
        frames = list(dat['pass'][games[i]].keys())

        #Get the frame at which the ball was snapped
        ball_snap = int(catchframe[games[i]])
        #ball_pass = int(passframe[games[i]])
        sack_frame = int(sackframe[games[i]])
    except KeyError:
        sack_frame = int(sackframe[games[i]])
        
    tot_frames = frames[ball_snap:sack_frame]
    for j in range(len(tot_frames)): #up until QB was sacked (including)
        X[count,:] = GetFieldStateFeatures(games[i],tot_frames[j],dat)[0]
        Y[count] = ((int(sack_frame)-int(tot_frames[j]))*0.1)
        
        count += 1
        
#Shave off frames after the sack
ishave = np.where(np.isnan(Y))[0][0]
X,Y = X[:ishave],Y[:ishave]

xt = torch.tensor(X)
yt = torch.tensor(Y)
trainset=torch.utils.data.TensorDataset(xt[:int(len(xt)*.9)],yt[:int(len(xt)*.9)])
testset=torch.utils.data.TensorDataset(xt[int(len(xt)*.9):],yt[int(len(xt)*.9):])

batch_size=64
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=True, num_workers=0)

#Train the model
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.00001)#, momentum=0.95)
testcor=[]
traincor=[]
for epoch in range(30):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs=inputs.float()
        labels=labels.float()
        optimizer.zero_grad()
        outputs = net(inputs).squeeze()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 1000))
            running_loss = 0.0
    testerr=testerrv(net,testloader)
print('Finished Training') #4.8 loss



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:05<00:00, 84.54it/s]


Finished Training


In [25]:
#Use both models on all plays to compute the expected pocket time
games = list(dat['pass'].keys())
return_dict = {}
for i in tqdm(range(len(games))):
    try:
        return_dict[str(games[i])] = {}
        return_dict[str(games[i])]['EPT_ave'] = {}
        return_dict[str(games[i])]['EPT_decomp'] = {}

        #Create arrays for the function
        frames_tot = np.sum(np.array([len(list(dat['pass'][games[i]].keys())) for i in range(len(games))]))
        X = np.zeros((frames_tot,22*7))*np.nan #[playId, stats for each player]
        X_players = np.zeros((frames_tot,22*7))*np.nan


        count = 0
        #Get the frame at which the QB was sacked
        frames = list(dat['pass'][games[i]].keys())

        #Get the frame at which the ball was snapped
        ball_snap = int(catchframe[games[i]])
        ball_pass = int(passframe[games[i]])

        tot_frames = frames[ball_snap:ball_pass+1]
        for j in range(len(tot_frames)-1): #up until QB was sacked (including)
            playerID_DE = player_pos2player_id(dat,games[i],'DE')

            if '43292.0' in playerID_DE:
                print('yes')
            count += 1
    except KeyError:
        print(games[i])
        
        

  1%|▌                                                                                                                  | 45/8555 [00:24<1:14:01,  1.92it/s]

2021090900-2298


  1%|▌                                                                                                                  | 46/8555 [00:24<1:14:17,  1.91it/s]

2021090900-2330


  1%|█▍                                                                                                                | 110/8555 [00:58<1:15:39,  1.86it/s]

2021091200-843


  2%|█▋                                                                                                                | 131/8555 [01:10<1:18:47,  1.78it/s]

2021091200-2063


  2%|█▊                                                                                                                | 133/8555 [01:11<1:20:04,  1.75it/s]

2021091200-2145


  2%|█▊                                                                                                                | 134/8555 [01:12<1:20:09,  1.75it/s]

2021091200-2214


  2%|██▏                                                                                                               | 164/8555 [01:28<1:15:43,  1.85it/s]

2021091200-4131


  2%|██▎                                                                                                               | 170/8555 [01:32<1:14:48,  1.87it/s]

2021091200-4367


  2%|██▎                                                                                                               | 174/8555 [01:34<1:16:04,  1.84it/s]

2021091201-345


  3%|███                                                                                                               | 227/8555 [02:03<1:17:22,  1.79it/s]

2021091201-2983


  3%|███▏                                                                                                              | 238/8555 [02:09<1:15:59,  1.82it/s]

2021091201-3756


  4%|████▏                                                                                                             | 312/8555 [02:52<1:15:48,  1.81it/s]

2021091202-3606


  5%|█████▊                                                                                                            | 432/8555 [04:00<1:14:17,  1.82it/s]

2021091204-1979


  5%|█████▉                                                                                                            | 442/8555 [04:06<1:14:41,  1.81it/s]

2021091204-2652


  6%|██████▎                                                                                                           | 476/8555 [04:24<1:13:41,  1.83it/s]

2021091204-4135


  6%|██████▍                                                                                                           | 486/8555 [04:30<1:14:24,  1.81it/s]

2021091204-4581


  6%|██████▌                                                                                                           | 494/8555 [04:34<1:14:27,  1.80it/s]

2021091205-404


  6%|██████▉                                                                                                           | 524/8555 [04:51<1:14:52,  1.79it/s]

2021091205-2097


  6%|███████▎                                                                                                          | 547/8555 [05:04<1:12:15,  1.85it/s]

2021091205-3548


  7%|███████▋                                                                                                          | 573/8555 [05:18<1:05:14,  2.04it/s]

2021091206-269


  7%|███████▋                                                                                                          | 574/8555 [05:19<1:06:19,  2.01it/s]

2021091206-439


  7%|███████▋                                                                                                          | 575/8555 [05:19<1:07:57,  1.96it/s]

2021091206-506


  7%|███████▊                                                                                                          | 588/8555 [05:25<1:05:07,  2.04it/s]

2021091206-1296


  7%|████████                                                                                                          | 603/8555 [05:32<1:06:10,  2.00it/s]

2021091206-2210


  7%|████████▎                                                                                                         | 622/8555 [05:42<1:07:36,  1.96it/s]

2021091206-3311


  7%|████████▍                                                                                                         | 629/8555 [05:46<1:08:12,  1.94it/s]

2021091206-3761


  8%|████████▋                                                                                                         | 656/8555 [05:59<1:07:22,  1.95it/s]

2021091207-1650


  8%|████████▊                                                                                                         | 660/8555 [06:01<1:06:37,  1.98it/s]

2021091207-1792


  8%|████████▊                                                                                                         | 664/8555 [06:03<1:06:30,  1.98it/s]

2021091207-2074


  8%|█████████▍                                                                                                        | 709/8555 [06:27<1:08:33,  1.91it/s]

2021091208-649


  9%|█████████▊                                                                                                        | 739/8555 [06:42<1:08:16,  1.91it/s]

2021091208-2219


  9%|█████████▊                                                                                                        | 740/8555 [06:43<1:07:51,  1.92it/s]

2021091208-2322


  9%|██████████▏                                                                                                       | 763/8555 [06:55<1:07:28,  1.92it/s]

2021091208-3874


  9%|██████████▌                                                                                                       | 795/8555 [07:12<1:06:28,  1.95it/s]

2021091209-1295


  9%|██████████▊                                                                                                       | 812/8555 [07:20<1:06:49,  1.93it/s]

2021091209-2293


 10%|██████████▊                                                                                                       | 814/8555 [07:21<1:06:55,  1.93it/s]

2021091209-2364


 10%|███████████▉                                                                                                      | 897/8555 [08:05<1:05:29,  1.95it/s]

2021091211-100


 10%|███████████▉                                                                                                      | 898/8555 [08:05<1:08:15,  1.87it/s]

2021091211-142


 11%|████████████                                                                                                      | 908/8555 [08:10<1:05:35,  1.94it/s]

2021091211-640


 11%|████████████▊                                                                                                     | 957/8555 [08:36<1:05:23,  1.94it/s]

2021091212-788


 11%|████████████▉                                                                                                     | 967/8555 [08:41<1:05:20,  1.94it/s]

2021091212-1239


 12%|█████████████▏                                                                                                    | 991/8555 [08:54<1:05:10,  1.93it/s]

2021091212-2474


 12%|█████████████▎                                                                                                    | 996/8555 [08:56<1:05:55,  1.91it/s]

2021091212-2804


 12%|█████████████▎                                                                                                    | 997/8555 [08:57<1:05:30,  1.92it/s]

2021091212-2857


 12%|█████████████▎                                                                                                    | 999/8555 [08:58<1:04:59,  1.94it/s]

2021091212-2905


 12%|█████████████▊                                                                                                   | 1046/8555 [09:22<1:04:23,  1.94it/s]

2021091213-1616


 12%|██████████████                                                                                                   | 1065/8555 [09:32<1:05:14,  1.91it/s]

2021091213-2757


 13%|██████████████▎                                                                                                  | 1083/8555 [09:41<1:04:54,  1.92it/s]

2021091300-148


 13%|██████████████▎                                                                                                  | 1087/8555 [09:43<1:05:00,  1.91it/s]

2021091300-356


 13%|██████████████▍                                                                                                  | 1089/8555 [09:45<1:04:18,  1.93it/s]

2021091300-407


 13%|██████████████▉                                                                                                  | 1131/8555 [10:07<1:11:33,  1.73it/s]

2021091300-2526


 13%|███████████████▏                                                                                                 | 1146/8555 [10:15<1:09:30,  1.78it/s]

2021091300-3256


 14%|███████████████▎                                                                                                 | 1160/8555 [10:23<1:07:47,  1.82it/s]

2021091300-3980


 14%|███████████████▎                                                                                                 | 1161/8555 [10:24<1:07:13,  1.83it/s]

2021091300-4094


 14%|███████████████▌                                                                                                 | 1178/8555 [10:33<1:08:01,  1.81it/s]

2021091600-235


 14%|███████████████▌                                                                                                 | 1179/8555 [10:34<1:07:20,  1.83it/s]

2021091600-261


 14%|███████████████▉                                                                                                 | 1207/8555 [10:49<1:06:58,  1.83it/s]

2021091600-1991


 14%|████████████████▏                                                                                                | 1225/8555 [10:59<1:05:48,  1.86it/s]

2021091600-3001


 14%|████████████████▎                                                                                                | 1232/8555 [11:03<1:04:32,  1.89it/s]

2021091600-3392


 14%|████████████████▎                                                                                                | 1238/8555 [11:06<1:04:00,  1.91it/s]

2021091600-3743


 15%|████████████████▉                                                                                                | 1287/8555 [11:31<1:04:03,  1.89it/s]

2021091900-1850


 15%|█████████████████                                                                                                | 1296/8555 [11:36<1:04:30,  1.88it/s]

2021091900-2467


 15%|█████████████████▏                                                                                               | 1299/8555 [11:38<1:04:34,  1.87it/s]

2021091900-2784


 15%|█████████████████▏                                                                                               | 1301/8555 [11:39<1:03:18,  1.91it/s]

2021091900-2879


 16%|█████████████████▋                                                                                               | 1338/8555 [11:59<1:04:09,  1.87it/s]

2021091901-988


 16%|█████████████████▋                                                                                               | 1339/8555 [11:59<1:03:57,  1.88it/s]

2021091901-1059


 16%|█████████████████▊                                                                                               | 1350/8555 [12:05<1:03:23,  1.89it/s]

2021091901-1741


 16%|██████████████████                                                                                               | 1363/8555 [12:12<1:04:48,  1.85it/s]

2021091901-2488


 16%|██████████████████                                                                                               | 1368/8555 [12:15<1:04:57,  1.84it/s]

2021091901-2764


 16%|██████████████████▍                                                                                              | 1396/8555 [12:30<1:01:16,  1.95it/s]

2021091902-1029


 16%|██████████████████▌                                                                                              | 1406/8555 [12:35<1:01:58,  1.92it/s]

2021091902-1619


 17%|██████████████████▋                                                                                              | 1414/8555 [12:39<1:02:09,  1.91it/s]

2021091902-2492


 17%|██████████████████▊                                                                                              | 1423/8555 [12:44<1:09:57,  1.70it/s]

2021091902-3163


 17%|██████████████████▊                                                                                              | 1424/8555 [12:45<1:08:13,  1.74it/s]

2021091902-3293


 17%|██████████████████▉                                                                                              | 1435/8555 [12:51<1:02:12,  1.91it/s]

2021091902-3820


 17%|███████████████████▏                                                                                             | 1456/8555 [13:02<1:02:18,  1.90it/s]

2021091903-1170


 17%|███████████████████▎                                                                                             | 1463/8555 [13:06<1:03:07,  1.87it/s]

2021091903-1494


 17%|███████████████████▍                                                                                             | 1473/8555 [13:11<1:03:52,  1.85it/s]

2021091903-2050


 17%|███████████████████▍                                                                                             | 1476/8555 [13:13<1:02:55,  1.87it/s]

2021091903-2194


 17%|███████████████████▌                                                                                             | 1477/8555 [13:13<1:02:31,  1.89it/s]

2021091903-2276


 17%|███████████████████▌                                                                                             | 1485/8555 [13:17<1:02:56,  1.87it/s]

2021091903-2531


 18%|████████████████████▎                                                                                            | 1534/8555 [13:43<1:01:17,  1.91it/s]

2021091904-1614


 18%|████████████████████▎                                                                                            | 1539/8555 [13:46<1:05:30,  1.79it/s]

2021091904-1816


 18%|████████████████████▌                                                                                            | 1555/8555 [13:55<1:02:44,  1.86it/s]

2021091904-2640


 18%|████████████████████▋                                                                                            | 1563/8555 [13:59<1:02:10,  1.87it/s]

2021091904-3071


 18%|████████████████████▊                                                                                            | 1577/8555 [14:06<1:01:33,  1.89it/s]

2021091905-427


 19%|████████████████████▉                                                                                            | 1589/8555 [14:13<1:01:08,  1.90it/s]

2021091905-1102


 19%|█████████████████████▏                                                                                           | 1601/8555 [14:19<1:02:15,  1.86it/s]

2021091905-1730


 19%|█████████████████████▏                                                                                           | 1606/8555 [14:22<1:00:40,  1.91it/s]

2021091905-1889


 19%|█████████████████████▍                                                                                           | 1623/8555 [14:31<1:00:19,  1.92it/s]

2021091905-2891


 19%|█████████████████████▋                                                                                           | 1640/8555 [14:40<1:01:19,  1.88it/s]

2021091905-3779


 20%|██████████████████████                                                                                           | 1671/8555 [14:56<1:00:19,  1.90it/s]

2021091906-1989


 20%|██████████████████████                                                                                           | 1675/8555 [14:58<1:01:11,  1.87it/s]

2021091906-2272


 20%|██████████████████████▍                                                                                          | 1697/8555 [15:10<1:01:28,  1.86it/s]

2021091906-3336


 20%|███████████████████████▏                                                                                           | 1722/8555 [15:24<59:54,  1.90it/s]

2021091907-843


 20%|███████████████████████▎                                                                                           | 1736/8555 [15:31<59:10,  1.92it/s]

2021091907-1836


 20%|███████████████████████▎                                                                                           | 1737/8555 [15:31<56:45,  2.00it/s]

2021091907-1857


 20%|███████████████████████▍                                                                                           | 1743/8555 [15:35<58:16,  1.95it/s]

2021091907-2273


 20%|███████████████████████▍                                                                                           | 1745/8555 [15:36<58:46,  1.93it/s]

2021091907-2459


 20%|███████████████████████▏                                                                                         | 1752/8555 [15:39<1:01:09,  1.85it/s]

2021091907-2890


 21%|███████████████████████▏                                                                                         | 1756/8555 [15:41<1:00:21,  1.88it/s]

2021091907-3224


 21%|███████████████████████▏                                                                                         | 1759/8555 [15:43<1:01:17,  1.85it/s]

2021091907-3627


 21%|███████████████████████▏                                                                                         | 1760/8555 [15:44<1:01:11,  1.85it/s]

2021091907-3653


 21%|███████████████████████▌                                                                                         | 1782/8555 [15:55<1:01:00,  1.85it/s]

2021091908-1028


 22%|████████████████████████▊                                                                                          | 1843/8555 [16:28<59:25,  1.88it/s]

2021091909-569


 22%|████████████████████████▍                                                                                        | 1847/8555 [16:30<1:00:56,  1.83it/s]

2021091909-802


 22%|████████████████████████▍                                                                                        | 1849/8555 [16:32<1:00:10,  1.86it/s]

2021091909-1175


 22%|█████████████████████████▏                                                                                         | 1876/8555 [16:46<58:58,  1.89it/s]

2021091909-3273


 22%|█████████████████████████▎                                                                                         | 1882/8555 [16:49<59:27,  1.87it/s]

2021091909-3565


 22%|█████████████████████████▋                                                                                         | 1914/8555 [17:06<58:32,  1.89it/s]

2021091910-848


 23%|██████████████████████████▏                                                                                        | 1944/8555 [17:22<58:01,  1.90it/s]

2021091910-2359


 24%|███████████████████████████                                                                                        | 2012/8555 [17:59<57:35,  1.89it/s]

2021091911-1972


 24%|███████████████████████████▍                                                                                       | 2040/8555 [18:14<59:08,  1.84it/s]

2021091911-3554


 24%|███████████████████████████▌                                                                                       | 2051/8555 [18:20<57:16,  1.89it/s]

2021091912-110


 24%|███████████████████████████▊                                                                                       | 2070/8555 [18:30<58:31,  1.85it/s]

2021091912-1171


 24%|███████████████████████████▉                                                                                       | 2079/8555 [18:35<58:18,  1.85it/s]

2021091912-1868


 25%|████████████████████████████▎                                                                                      | 2103/8555 [18:48<57:08,  1.88it/s]

2021091912-3422


 25%|████████████████████████████▍                                                                                      | 2118/8555 [18:56<57:11,  1.88it/s]

2021091912-4213


 25%|█████████████████████████████                                                                                      | 2160/8555 [19:18<57:21,  1.86it/s]

2021091913-2561


 25%|█████████████████████████████▏                                                                                     | 2167/8555 [19:22<57:34,  1.85it/s]

2021091913-2960


 26%|█████████████████████████████▍                                                                                     | 2187/8555 [19:33<56:14,  1.89it/s]

2021092000-354


 26%|█████████████████████████████▌                                                                                     | 2196/8555 [19:38<56:41,  1.87it/s]

2021092000-1073


 26%|█████████████████████████████▊                                                                                     | 2220/8555 [19:51<56:45,  1.86it/s]

2021092000-2587


 26%|█████████████████████████████▉                                                                                     | 2231/8555 [19:57<57:17,  1.84it/s]

2021092000-3365


 26%|██████████████████████████████▏                                                                                    | 2250/8555 [20:07<56:05,  1.87it/s]

2021092300-620


 27%|██████████████████████████████▋                                                                                    | 2283/8555 [20:24<55:04,  1.90it/s]

2021092300-2422


 27%|██████████████████████████████▋                                                                                    | 2287/8555 [20:26<55:43,  1.87it/s]

2021092300-2713


 27%|███████████████████████████████▍                                                                                   | 2336/8555 [20:52<53:54,  1.92it/s]

2021092600-1440


 27%|███████████████████████████████▍                                                                                   | 2340/8555 [20:54<54:17,  1.91it/s]

2021092600-1636


 28%|███████████████████████████████▊                                                                                   | 2365/8555 [21:07<53:48,  1.92it/s]

2021092600-3245


 28%|███████████████████████████████▊                                                                                   | 2371/8555 [21:10<53:32,  1.93it/s]

2021092600-3495


 28%|███████████████████████████████▉                                                                                   | 2373/8555 [21:12<53:54,  1.91it/s]

2021092600-3544


 28%|███████████████████████████████▉                                                                                   | 2375/8555 [21:13<53:42,  1.92it/s]

2021092600-3639


 28%|███████████████████████████████▉                                                                                   | 2380/8555 [21:15<52:13,  1.97it/s]

2021092600-3942


 28%|████████████████████████████████                                                                                   | 2383/8555 [21:17<52:34,  1.96it/s]

2021092600-4119


 28%|████████████████████████████████                                                                                   | 2388/8555 [21:19<53:14,  1.93it/s]

2021092601-166


 28%|████████████████████████████████▏                                                                                  | 2391/8555 [21:21<53:09,  1.93it/s]

2021092601-358


 28%|████████████████████████████████▎                                                                                  | 2401/8555 [21:26<51:35,  1.99it/s]

2021092601-908


 28%|████████████████████████████████▌                                                                                  | 2421/8555 [21:36<52:53,  1.93it/s]

2021092601-2172


 28%|████████████████████████████████▌                                                                                  | 2424/8555 [21:38<52:25,  1.95it/s]

2021092601-2294


 29%|█████████████████████████████████                                                                                  | 2458/8555 [21:55<51:05,  1.99it/s]

2021092602-799


 29%|█████████████████████████████████                                                                                  | 2459/8555 [21:56<50:58,  1.99it/s]

2021092602-857


 29%|████████████████████████████████▌                                                                                | 2461/8555 [21:57<1:00:16,  1.69it/s]

2021092602-941


 29%|█████████████████████████████████                                                                                  | 2463/8555 [21:58<59:53,  1.70it/s]

2021092602-1058


 29%|████████████████████████████████▌                                                                                | 2467/8555 [22:01<1:04:40,  1.57it/s]

2021092602-1246


 29%|█████████████████████████████████▌                                                                                 | 2496/8555 [22:17<54:53,  1.84it/s]

2021092602-2850


 30%|██████████████████████████████████                                                                                 | 2531/8555 [22:36<53:58,  1.86it/s]

2021092603-1055


 30%|██████████████████████████████████                                                                                 | 2532/8555 [22:36<53:51,  1.86it/s]

2021092603-1111


 30%|██████████████████████████████████                                                                                 | 2536/8555 [22:38<54:19,  1.85it/s]

2021092603-1302


 30%|██████████████████████████████████▏                                                                                | 2543/8555 [22:42<52:43,  1.90it/s]

2021092603-1875


 30%|██████████████████████████████████▎                                                                                | 2548/8555 [22:45<53:38,  1.87it/s]

2021092603-2262


 30%|██████████████████████████████████▋                                                                                | 2579/8555 [23:01<52:33,  1.89it/s]

2021092604-232


 30%|██████████████████████████████████▋                                                                                | 2580/8555 [23:02<52:19,  1.90it/s]

2021092604-304


 30%|██████████████████████████████████▊                                                                                | 2586/8555 [23:05<54:33,  1.82it/s]

2021092604-564


 31%|███████████████████████████████████▎                                                                               | 2624/8555 [23:25<52:01,  1.90it/s]

2021092604-2625


 31%|███████████████████████████████████▍                                                                               | 2637/8555 [23:32<52:56,  1.86it/s]

2021092604-3368


 31%|███████████████████████████████████▍                                                                               | 2639/8555 [23:33<52:37,  1.87it/s]

2021092604-3610


 31%|███████████████████████████████████▋                                                                               | 2654/8555 [23:41<54:14,  1.81it/s]

2021092604-4253


 31%|███████████████████████████████████▉                                                                               | 2674/8555 [23:52<52:25,  1.87it/s]

2021092605-1069


 31%|████████████████████████████████████                                                                               | 2687/8555 [23:59<53:55,  1.81it/s]

2021092605-1784


 31%|████████████████████████████████████▏                                                                              | 2690/8555 [24:01<53:39,  1.82it/s]

2021092605-1911


 31%|████████████████████████████████████▏                                                                              | 2691/8555 [24:02<53:23,  1.83it/s]

2021092605-1940


 32%|████████████████████████████████████▍                                                                              | 2707/8555 [24:10<52:21,  1.86it/s]

2021092605-2915


 32%|████████████████████████████████████▍                                                                              | 2713/8555 [24:13<51:50,  1.88it/s]

2021092605-3112


 32%|████████████████████████████████████▍                                                                              | 2714/8555 [24:14<51:40,  1.88it/s]

2021092605-3197


 32%|████████████████████████████████████▌                                                                              | 2717/8555 [24:16<50:55,  1.91it/s]

2021092605-3296


 32%|████████████████████████████████████▌                                                                              | 2722/8555 [24:18<51:40,  1.88it/s]

2021092605-3684


 32%|████████████████████████████████████▋                                                                              | 2733/8555 [24:25<52:00,  1.87it/s]


KeyboardInterrupt: 

In [16]:
#Use both models on all plays to compute the expected pocket time
games = list(dat['pass'].keys())
return_dict = {}
for i in tqdm(range(len(games))):
    i = 44
    #try:
    return_dict[str(games[i])] = {}
    return_dict[str(games[i])]['EPT_ave'] = {}
    return_dict[str(games[i])]['EPT_decomp'] = {}

    #Create arrays for the function
    frames_tot = np.sum(np.array([len(list(dat['pass'][games[i]].keys())) for i in range(len(games))]))
    X = np.zeros((frames_tot,22*7))*np.nan #[playId, stats for each player]
    X_players = np.zeros((frames_tot,22*7))*np.nan


    count = 0
    #Get the frame at which the QB was sacked
    frames = list(dat['pass'][games[i]].keys())

    #Get the frame at which the ball was snapped
    ball_snap = int(catchframe[games[i]])
    ball_pass = int(passframe[games[i]])

    tot_frames = frames[ball_snap:ball_pass+1]
    for j in range(len(tot_frames)-1): #up until QB was sacked (including)
        X[count,:],X_players[count,:] = GetFieldStateFeatures(games[i],tot_frames[j],dat)

        #Get offsenive linemam
        playerID_C = player_pos2player_id(dat,games[i],'C')
        playerID_G = player_pos2player_id(dat,games[i],'G')
        playerID_T = player_pos2player_id(dat,games[i],'T')

        #Get defensive lineman
        playerID_NT = player_pos2player_id(dat,games[i],'NT')
        playerID_DT = player_pos2player_id(dat,games[i],'DT')
        playerID_DE = player_pos2player_id(dat,games[i],'DE')

        if '43292.0' in playerID_DE:
            print('yes')

        all_players = np.concatenate((playerID_C,playerID_G,playerID_T,playerID_NT,playerID_DT,playerID_DE))

        #Initialize empty arrays
        if j == 0:
            playerID = np.zeros((frames_tot,all_players.size))
            X_a = np.zeros((all_players.size,frames_tot,6*7))*np.nan #[O-line player,playId, 5 closest players + reference player w/ 7 data points]

        playerID[count,:] = all_players
        playerID_tags = np.concatenate((['C']*len(playerID_C),['G']*len(playerID_G),['T']*len(playerID_T),['NT']*len(playerID_NT),['DT']*len(playerID_DT),['DE']*len(playerID_DE)))[:6]

        X_a[:,count,:] = [model_a_assimilation(dat,games[i],tot_frames[j],tot_frames,j,str(pid)) for pid in playerID[count,:]]

        count += 1
    #Remove nan data points (frames after the ball was thrown)
    ishave = np.where(np.isnan(X))[0][0]
    X = X[:ishave]
    X_players = X_players[:ishave]
    X_a = X_a[:,:ishave,:]
    xt = torch.tensor(X)

    #Call the model and make the prediction
    pred=np.zeros((X.shape[0]))*np.nan

    with torch.no_grad():
        for count,data in enumerate(xt):
            outputs = net(data.float()).squeeze()
            pred[count] = float(outputs)

    return_dict[str(games[i])]['EPT'] = pred

    #Put the field states into model a for the O-line players to find contributions to the total expected time
    net_a = torch.load(data_dir+'model_a_full.pt') #model a
    xt_a = torch.tensor(X_a)

    #Call the model and make the prediction for each O-line player
    pred_a=np.zeros((X_a.shape[0],X_a.shape[1],5))*np.nan #[O-line Player, frames, (dx,dy,a,u,v)]
    for c in range(X_a.shape[0]):
        with torch.no_grad():
            for count,data in enumerate(xt_a[c]):
                outputs = net_a(data.float()).squeeze()
                pred_a[c,count,:] = np.array(outputs).astype(float)


    #Compute the field states needed for model B from model A's output a: (dx,dy,a,u,v) b: (xball,yball,dist,u,v,a)
    playerID = playerID[0].tolist()
    pred_newb_return = []
    for i_ in range(pred_a.shape[0]):
        X_newb = np.copy(X)
        for j in range(X_newb.shape[0]):
            #dx relative to the football
            X_newb[j][np.where(X_players[j]==float(playerID[i_]))[0][0]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][0]] + pred_a[i_][j][0]

            #dy relative to the football
            X_newb[j][np.where(X_players[j]==float(playerID[i_]))[0][1]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][1]] + pred_a[i_][j][1]

            #distance relative to the football
            X_newb[j][np.where(X_players[j]==float(playerID[i_]))[0][2]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][2]] + (np.sqrt((pred_a[i_][j][0]**2)+(pred_a[i_][j][1]**2)))

            #u component of the speed
            X_newb[j][np.where(X_players[j]==float(playerID[i_]))[0][3]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][3]] + pred_a[i_][j][3]

            #v component of the speed
            X_newb[j][np.where(X_players[j]==float(playerID[i_]))[0][4]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][4]] + pred_a[i_][j][4]

            #Acceleration of the player
            X_newb[j][np.where(X_players[j]==float(playerID[i_]))[0][5]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][5]] + pred_a[i_][j][2]

        #Call the model and make the prediction of the contribution
        xt_newb = torch.tensor(X_newb)
        pred_newb=np.zeros((len(X_newb)))*np.nan

        with torch.no_grad():
            for count,data in enumerate(xt_newb):
                outputs = net(data.float()).squeeze()
                pred_newb[count] = float(outputs)
        return_dict[str(games[i])]['EPT_ave'][str(playerID[i_])] = pred_newb

    #Compute the field states needed for model B -> second controlled experiment
    for i_ in range(len(playerID)):
        X_newb2 = np.copy(X)
        for j in range(X_newb2.shape[0]-1):
            #dx relative to the football
            X_newb2[j][np.where(X_players[j]==float(playerID[i_]))[0][0]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][0]] + X[j+1][np.where(X_players[j+1]==float(playerID[i_]))[0][0]]

            #dy relative to the football
            X_newb2[j][np.where(X_players[j]==float(playerID[i_]))[0][1]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][1]] + X[j+1][np.where(X_players[j+1]==float(playerID[i_]))[0][1]]

            #distance relative to the football
            X_newb2[j][np.where(X_players[j]==float(playerID[i_]))[0][2]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][2]] + (np.sqrt((X[j+1][np.where(X_players[j+1]==float(playerID[i_]))[0][0]]**2)+(X[j+1][np.where(X_players[j+1]==float(playerID[i_]))[0][1]]**2)))

            #u component of the speed
            X_newb2[j][np.where(X_players[j]==float(playerID[i_]))[0][3]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][3]] + X[j+1][np.where(X_players[j+1]==float(playerID[i_]))[0][3]]

            #v component of the speed
            X_newb2[j][np.where(X_players[j]==float(playerID[i_]))[0][4]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][4]] + X[j+1][np.where(X_players[j+1]==float(playerID[i_]))[0][4]]

            #Acceleration of the player
            X_newb2[j][np.where(X_players[j]==float(playerID[i_]))[0][5]] = X[j][np.where(X_players[j]==float(playerID[i_]))[0][5]] + X[j+1][np.where(X_players[j+1]==float(playerID[i_]))[0][2]]


        #Call the model and make the prediction of the contribution
        xt_newb2 = torch.tensor(X_newb2)
        pred_newb2 = np.zeros((len(X_newb2)))*np.nan

        with torch.no_grad():
            for count,data in enumerate(xt_newb2):
                outputs = net(data.float()).squeeze()
                pred_newb2[count] = float(outputs)
        return_dict[str(games[i])]['EPT_decomp'][str(playerID[i_])] = pred_newb2
    #except KeyError:
        #pass


  0%|                                                                                                                              | 0/8555 [00:00<?, ?it/s]


KeyError: '2021090900-2298'

In [60]:
#Dump this data into a pickle file for later analysis
with open('EPT_variables.pkl', 'wb') as f:
    pickle.dump([return_dict], f)

In [52]:
return_dict[str(games[i])].keys()

dict_keys(['EPT_ave', 'EPT_decomp', 'EPT'])

In [66]:
dat['pass'][str(games[0])]['1']['25511.0']

['37.77',
 '24.22',
 '0.29',
 '0.3',
 '0.03',
 '165.16',
 '84.99',
 '12.0',
 'right',
 'QB',
 '-1.73852913314477',
 'home']

In [26]:
games[i]

'2021092605-4162'

In [27]:
ball_snap = int(catchframe['2021091913-2775'])
ball_pass = int(passframe['2021091913-2775'])
ball_snap,ball_pass

(6, 55)