In [1]:
import pandas as pd
import numpy as np
import torch
from torch.nn.functional import softplus

## set GPU or CPU (uncomment second line to use CPU)
# device = torch.device('cuda:0'); dtype = torch.float16
device = 'cpu'; dtype = torch.float32

In [2]:
away_data = pd.read_csv('https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_1/Sample_Game_1_RawTrackingData_Away_Team.csv', skiprows=2)
home_data = pd.read_csv('https://raw.githubusercontent.com/metrica-sports/sample-data/master/data/Sample_Game_1/Sample_Game_1_RawTrackingData_Home_Team.csv', skiprows=2)


In [4]:

jitter = 1e-12 ## prevents division by zero when player is stationary

## set nans to large negative value -- this makes pitch control for players who aren't involved negligibly small everywhere
# on the pitch
home_pos = np.array([np.asarray(home_data.iloc[:,range(3 + j*2,3 + j*2 +2)]) for j in range(14)]) * np.array([105,68])
np.nan_to_num(home_pos,copy=False,nan=-1000)
away_pos = np.array([np.asarray(away_data.iloc[:,range(3 + j*2,3 + j*2 +2)]) for j in range(14)]) * np.array([105,68])
np.nan_to_num(away_pos,copy=False,nan=-1000)
ball_pos = (np.asarray(home_data.iloc[:,range(31,33)]) * np.array([105,68]))[:,None,None,:]


# defining the delta in seconds between frames
tt = np.asarray(home_data['Time [s]'])
dt = tt[1:] - tt[:-1] 
# velocity in both x and y axis + jitter 
# jitter prevents division by zero when player is stationary
home_v = (home_pos[:,1:,:] - home_pos[:,:-1,:])/dt[:,None] + jitter
np.nan_to_num(home_v,copy=False,nan=-1000)
away_v = (away_pos[:,1:,:] - away_pos[:,:-1,:])/dt[:,None] + jitter
np.nan_to_num(away_v,copy=False,nan=-1000)
# velocity -1000 not really necessary?




# severe changes in the shape of the arrays
# I am not 100% how and why this is necessary
# This is probably necessary for later 
home_pos = home_pos[:,1:,None,None,:]
away_pos = away_pos[:,1:,None,None,:]
home_v = home_v[:,:,None,None,:]
away_v = away_v[:,:,None,None,:]
ball_pos = ball_pos[None,1:]

## set up evaluation grid and set some pitch control parameters (these are taken from the FoT code)
reaction_time = 0.7
max_player_speed = 5.
average_ball_speed = 15.
sigma = np.pi / np.sqrt(3.) / 0.45
lamb = 4.3
n_grid_points_x = 50
n_grid_points_y = 30
# create grid based on tensors
XX,YY = torch.meshgrid(torch.linspace(0,105,n_grid_points_x, device = device, dtype=dtype),
                       torch.linspace(0,68,n_grid_points_y,device=device,dtype=dtype))
target_position = torch.stack([XX,YY],2)[None,None,:,:,:] # all possible positions


# the weights and the x-points of the Gauss–Legendre quadrature are set up and stored as torch tensors
# I assume that, as the Gauss–Legendre quadrature spans the range of [-1, 1] (length of 2), the 50 points selected
# represent the number of frames in a 2 second time period based on a framerate of 25Hz (i.e. 25 frames / second)
ti,wi = np.polynomial.legendre.leggauss(50) ## used for numerical integration later on
ti = torch.tensor(ti,device = device,dtype=dtype)
wi = torch.tensor(wi,device=device,dtype=dtype)




n_frames = home_pos.shape[1]
first_frame = 0
batch_size = 250

# time to intercept empty torch
tti = torch.empty([28,batch_size,n_grid_points_x,n_grid_points_y],device = device,dtype=dtype)
# 28 players*500 batches*grid
tmp2 = torch.empty([28,batch_size,n_grid_points_x,n_grid_points_y,1],device = device,dtype=dtype)
# 28 players*500 batches*grid * 1
pc = torch.empty([n_frames,n_grid_points_x,n_grid_points_y],device = device,dtype=dtype)
# frames * grid

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [45]:
#for f in range(int(n_frames/batch_size)):
for f in range(1):
    bp = torch.tensor(ball_pos[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    hp = torch.tensor(home_pos[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    hv = torch.tensor(home_v[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    ap = torch.tensor(away_pos[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    av = torch.tensor(away_v[:,(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))],device = device,dtype=dtype)
    ball_travel_time = torch.norm(target_position - bp, dim=4).div_(average_ball_speed)
    r_reaction_home = hp + hv.mul_(reaction_time)
    r_reaction_away = ap + av.mul_(reaction_time)
    r_reaction_home = r_reaction_home - target_position
    r_reaction_away = r_reaction_away - target_position
    tti[:14,:ball_travel_time.shape[1]] = torch.norm(r_reaction_home,dim=4).add_(reaction_time).div_(max_player_speed)
    tti[14:,:ball_travel_time.shape[1]] = torch.norm(r_reaction_away,dim=4).add_(reaction_time).div_(max_player_speed)

    y = torch.zeros([28,bp.shape[1],n_grid_points_x,n_grid_points_y],device = device,dtype=dtype)

    for tt in range(500):
        sumy = torch.sum(y, dim=0) # control over all players
        if torch.min(sumy) > 0.99:
            break
        y += 0.04 * lamb * (1. - sumy) * 1. / (1. + torch.exp(-sigma * (0.04*tt + ball_travel_time - tti)))
        # dt * lambda * probability that noone else control it * probability of intercept as in eq (3)
        # for every dt after ball arrived until convergence  
    pc[(first_frame + f*batch_size):(np.minimum(first_frame + (f+1)*batch_size,int(first_frame + n_frames)))] = y[:14].sum(0)


In [53]:
0.04 * lamb * (1. - sumy) * 1. / (1. + torch.exp(-sigma * (0.04*tt + ball_travel_time - tti)))


tensor([[[[0., 0., 0.,  ..., 0., nan, 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0.,

In [23]:
1. / (1. + torch.exp(-sigma * (0.02*tt + ball_travel_time - tti)))

tensor([[[[9.9987e-01, 9.9997e-01, 9.9999e-01,  ..., 9.9999e-01,
           9.9997e-01, 9.9986e-01],
          [9.9977e-01, 9.9995e-01, 9.9999e-01,  ..., 9.9999e-01,
           9.9995e-01, 9.9977e-01],
          [9.9956e-01, 9.9991e-01, 9.9998e-01,  ..., 9.9998e-01,
           9.9990e-01, 9.9958e-01],
          ...,
          [1.0773e-22, 1.4066e-22, 1.8048e-22,  ..., 6.0781e-22,
           5.0345e-22, 4.0750e-22],
          [3.4435e-23, 4.4922e-23, 5.7605e-23,  ..., 1.8555e-22,
           1.5375e-22, 1.2458e-22],
          [1.1026e-23, 1.4369e-23, 1.8410e-23,  ..., 5.6857e-23,
           4.7134e-23, 3.8230e-23]],

         [[9.9996e-01, 9.9999e-01, 1.0000e+00,  ..., 1.0000e+00,
           9.9998e-01, 9.9993e-01],
          [9.9993e-01, 9.9999e-01, 1.0000e+00,  ..., 9.9999e-01,
           9.9997e-01, 9.9988e-01],
          [9.9986e-01, 9.9997e-01, 9.9999e-01,  ..., 9.9999e-01,
           9.9995e-01, 9.9979e-01],
          ...,
          [5.1955e-23, 6.5663e-23, 8.1531e-23,  ..., 2.1168

In [35]:
def prin(T, tti):
    return 1. / (1. + np.exp(-sigma * (T - tti)))

def prin2(T, tti, tt):
    return 1. / (1. + np.exp(-sigma * (0.02*tt + T - tti)))


In [36]:
prin(3,2)

0.982547490913071

In [44]:
prin2(3,2,0)

0.982547490913071

In [34]:
tt*0.02

4.98

In [15]:
sumy.shape

torch.Size([250, 50, 30])

In [16]:
dt

array([0.04, 0.04, 0.04, ..., 0.04, 0.04, 0.04])

tensor([[[[8.9910e-01, 9.5694e-01, 9.8441e-01,  ..., 9.9535e-01,
           9.8955e-01, 9.7865e-01],
          [7.3303e-01, 8.6176e-01, 9.4115e-01,  ..., 9.8268e-01,
           9.6461e-01, 9.3468e-01],
          [4.5462e-01, 6.3836e-01, 8.0363e-01,  ..., 9.3767e-01,
           8.8851e-01, 8.2034e-01],
          ...,
          [1.7559e-25, 5.3561e-26, 1.5440e-26,  ..., 2.4233e-26,
           4.9053e-26, 9.1709e-26],
          [5.3482e-26, 1.6404e-26, 4.7755e-27,  ..., 1.1272e-26,
           2.5340e-26, 5.1834e-26],
          [1.7252e-26, 5.4167e-27, 1.6196e-27,  ..., 5.0041e-27,
           1.2235e-26, 2.7266e-26]],

         [[8.9903e-01, 9.5690e-01, 9.8439e-01,  ..., 9.9534e-01,
           9.8954e-01, 9.7863e-01],
          [7.3358e-01, 8.6212e-01, 9.4133e-01,  ..., 9.8273e-01,
           9.6472e-01, 9.3485e-01],
          [4.5615e-01, 6.3989e-01, 8.0478e-01,  ..., 9.3810e-01,
           8.8916e-01, 8.2121e-01],
          ...,
          [2.1125e-25, 6.6063e-26, 1.9410e-26,  ..., 2.7960

In [18]:
lamb

4.3

In [20]:
sumy

tensor([[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         ...,
         [0.9999, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [0.9997, 0.9999, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

        [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
         [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.