In [2]:
import pandas as pd
import numpy as np
import torch
from torch.nn.functional import softplus

In [3]:
### load the data and process it
away_data = pd.read_csv('https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_1/Sample_Game_1_RawTrackingData_Away_Team.csv', skiprows=2)
home_data = pd.read_csv('https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_1/Sample_Game_1_RawTrackingData_Home_Team.csv', skiprows=2)


In [5]:
away_data.head()

Unnamed: 0,Period,Frame,Time [s],Player25,Unnamed: 4,Player15,Unnamed: 6,Player16,Unnamed: 8,Player17,...,Player24,Unnamed: 24,Player26,Unnamed: 26,Player27,Unnamed: 28,Player28,Unnamed: 30,Ball,Unnamed: 32
0,1,1,0.04,0.90509,0.47462,0.58393,0.20794,0.67658,0.4671,0.6731,...,0.37833,0.27383,,,,,,,0.45472,0.38709
1,1,2,0.08,0.90494,0.47462,0.58393,0.20794,0.67658,0.4671,0.6731,...,0.37833,0.27383,,,,,,,0.49645,0.40656
2,1,3,0.12,0.90434,0.47463,0.58393,0.20794,0.67658,0.4671,0.6731,...,0.37833,0.27383,,,,,,,0.53716,0.42556
3,1,4,0.16,0.90377,0.47463,0.58351,0.20868,0.6764,0.46762,0.67279,...,0.37756,0.27473,,,,,,,0.55346,0.42231
4,1,5,0.2,0.90324,0.47464,0.58291,0.21039,0.67599,0.46769,0.67253,...,0.37663,0.27543,,,,,,,0.55512,0.4057


In [74]:
device = 'cpu'; dtype = torch.float32


jitter = 1e-12 ## prevents division by zero when player is stationary

home_pos = np.array([np.asarray(home_data.iloc[:,range(3 + j*2,3 + j*2 +2)]) for j in range(14)]) * np.array([105,68])
## set nans to large negative value -- this makes pitch control for players who aren't involved negligibly small everywhere on the pitch
np.nan_to_num(home_pos,copy=False,nan=-1000)
away_pos = np.array([np.asarray(away_data.iloc[:,range(3 + j*2,3 + j*2 +2)]) for j in range(14)]) * np.array([105,68])
np.nan_to_num(away_pos,copy=False,nan=-1000)

ball_pos = (np.asarray(home_data.iloc[:,range(31,33)]) * np.array([105,68]))[:,None,None,:]
tt = np.asarray(home_data['Time [s]'])
dt = tt[1:] - tt[:-1] # delta in time between frames

# velocity -1000 not really necessary?
home_v = (home_pos[:,1:,:] - home_pos[:,:-1,:])/dt[:,None] + jitter
np.nan_to_num(home_v,copy=False,nan=-1000)
away_v = (away_pos[:,1:,:] - away_pos[:,:-1,:])/dt[:,None] + jitter
np.nan_to_num(away_v,copy=False,nan=-1000)

# what is the point of this block
home_pos = home_pos[:,1:,None,None,:]
away_pos = away_pos[:,1:,None,None,:]
home_v = home_v[:,:,None,None,:]
away_v = away_v[:,:,None,None,:]
ball_pos = ball_pos[None,1:]


## set up evaluation grid and set some pitch control parameters (these are taken from the FoT code)
reaction_time = 0.7
max_player_speed = 5.
average_ball_speed = 15.
sigma = np.pi / np.sqrt(3.) / 0.45
lamb = 4.3
n_grid_points_x = 50
n_grid_points_y = 30

# create grid based on tensors
XX,YY = torch.meshgrid(torch.linspace(0,105,n_grid_points_x, device = device, dtype=dtype),torch.linspace(0,68,n_grid_points_y,device=device,dtype=dtype))
ti,wi = np.polynomial.legendre.leggauss(50) ## used for numerical integration later on
ti = torch.tensor(ti,device = device,dtype=dtype)
wi = torch.tensor(wi,device=device,dtype=dtype)
target_position = torch.stack([XX,YY],2)[None,None,:,:,:] # all possible positions
n_frames = home_pos.shape[1]
first_frame = 0
batch_size = 500

# time to intercept empty torch
tti = torch.empty([28,batch_size,n_grid_points_x,n_grid_points_y],device = device,dtype=dtype) # 28 players*500 batches*grid
tmp2 = torch.empty([28,batch_size,n_grid_points_x,n_grid_points_y,1],device = device,dtype=dtype) # 28 players*500 batches*grid * 1
pc = torch.empty([n_frames,n_grid_points_x,n_grid_points_y],device = device,dtype=dtype) # frames * grid

for f in range(int(n_frames/batch_size)):
    # taking 500 frames or the last frames if less than 500 left
    bp = torch.tensor(ball_pos[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    hp = torch.tensor(home_pos[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    hv = torch.tensor(home_v[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    ap = torch.tensor(away_pos[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    av = torch.tensor(away_v[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    ball_travel_time = torch.norm(target_position - bp, dim=4).div_(average_ball_speed) # ball travel time to each location in each frame
    r_reaction_home = hp + hv.mul_(reaction_time) # position after reaction time =
    r_reaction_away = ap + av.mul_(reaction_time) # position + velocity multiplied by reaction time
    r_reaction_home = r_reaction_home - target_position # distance to target position
    r_reaction_away = r_reaction_away - target_position # after reaction time
    
    # time to intercept for home and away filled 
    tti[:14,:ball_travel_time.shape[1]] = torch.norm(r_reaction_home,dim=4).add_(reaction_time).div_(max_player_speed)
    tti[14:,:ball_travel_time.shape[1]] = torch.norm(r_reaction_away,dim=4).add_(reaction_time).div_(max_player_speed)



In [78]:
torch.norm(r_reaction_home,dim=4)

tensor([[[[ 111.0587,  110.3900,  109.7674,  ...,  110.4343,  111.1061,
            111.8230],
          [ 109.0129,  108.3316,  107.6971,  ...,  108.3767,  109.0612,
            109.7915],
          [ 106.9710,  106.2766,  105.6297,  ...,  106.3226,  107.0202,
            107.7643],
          ...,
          [  33.1908,   30.8796,   28.5739,  ...,   31.0374,   33.3489,
             35.6647],
          [  32.9097,   30.5773,   28.2470,  ...,   30.7367,   33.0692,
             35.4033],
          [  32.7666,   30.4233,   28.0802,  ...,   30.5835,   32.9268,
             35.2704]],

         [[ 111.0667,  110.3981,  109.7755,  ...,  110.4424,  111.1141,
            111.8310],
          [ 109.0210,  108.3397,  107.7052,  ...,  108.3848,  109.0692,
            109.7994],
          [ 106.9790,  106.2847,  105.6378,  ...,  106.3306,  107.0282,
            107.7722],
          ...,
          [  33.1921,   30.8811,   28.5755,  ...,   31.0389,   33.3503,
             35.6660],
          [  32.91