This notebook converts the `Metrica_PitchControl.py` script, written by Laurie Shaw ([@EightyFivePoint](https://twitter.com/EightyFivePoint)), from pure Python/NumPy into [Cython](https://cython.org/) code in order to compare computation speed. Unfortunately, since I've fallen way behind on keeping up with all the [Friends of Tracking](https://www.youtube.com/channel/UCUBFJYcag8j2rm_9HkrrA7w/featured) videos recently 😔, the code here is based on the pitch control code from Laurie's initial commit [`e047ede`](https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py).

In [None]:
# Clone the repos
!git clone https://github.com/metrica-sports/sample-data.git
!git clone https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking.git
# I took too long to finish this, and the pitch control script has since been
# updated :) Reset to the original commit for now
!cd LaurieOnTracking && git reset --hard e047ede88e11030a9755ee783614f9c960664c01

import sys
sys.path.insert(1, '/content/LaurieOnTracking')
from LaurieOnTracking import Metrica_IO as mio
from LaurieOnTracking import Metrica_PitchControl as mpc
from LaurieOnTracking import Metrica_Velocities as mvelo
from LaurieOnTracking import Metrica_Viz as mviz
import numpy as np
import random

The notebook needs to load the Cython extension first. Then the `%%cython` magic can compile a notebook cell into C code.

If you use the `--annotate` flag with the `%%cython` declaration, like this:

```
%%cython --annotate
```

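the cell will output the Cython code with each line highlighted by how much Python interaction it requires: white lines translate directly to C, while darker yellow lines make more calls into the Python C-API. Clicking a highlighted line shows the C code generated for it.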

In [None]:
%load_ext cython

In [None]:
%%cython
cimport cython
from libc.math cimport exp, log, M_PI, pow, sqrt
import numpy as np
cimport numpy as np

################################################################################
# Constant params. In the Python script, most of these are stored in the params
# dictionary returned by default_model_params().
# https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py#L143-L161
################################################################################
cdef:
  # Maximum player acceleration in m/s^2
  double MAX_PLAYER_ACCEL = 7.0
  # Max player speed in m/s
  double MAX_PLAYER_SPEED = 5.0
  # Player reaction time in seconds
  double REACTION_TIME = 0.7
  # Average ball travel speed in m/s
  double AVG_BALL_SPEED = 15.0
  # Upper limit on integral time
  double MAX_INTGL_TIME = 10.0
  # Integration timestep (dt)
  double INTGL_DT = 0.04
  # Assume convergence when PPCF>0.99 at a given location
  double MODEL_CONVERGE_TOL = 0.01
  # Standard deviation of sigmoid, determines uncertainty in player arrival time
  double TTI_SIGMA = 0.45
  # Ball control rate constant for the attacking team
  double LAMBDA_ATT = 4.3
  # Ball control rate constant for the defending team
  double LAMBDA_DEF = 4.3
  # Coefficient to determine when to skip the pitch control calculation
  double TIME_TO_CTRL_VETO = 3.0
  double TIME_TO_CTRL_ATT = time_until_ctrl_override_for_team_cy(LAMBDA_ATT)
  double TIME_TO_CTRL_DEF = time_until_ctrl_override_for_team_cy(LAMBDA_DEF)

################################################################################
# This is calculated once in default_model_params()
# https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py#L159-L160
################################################################################
@cython.cdivision(True)
cdef double time_until_ctrl_override_for_team_cy(double team_lambda):
  return (TIME_TO_CTRL_VETO * log(10.0) *
          (sqrt(3.0) * TTI_SIGMA / M_PI + 1 / team_lambda))

################################################################################
# In the Metrica_PitchControl.py script, this is a class function on the
# player() class. Since Python classes don't translate well to Cython, this
# function calculates time-to-intercept on demand given a player's
# coordinates.
# https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py#L110-L116
################################################################################
@cython.cdivision(True)
cdef double simple_time_to_intercept_cy(double position_x,
                                        double position_y,
                                        double velocity_x,
                                        double velocity_y,
                                        double final_pos_x,
                                        double final_pos_y):
  cdef:
    double tti
    double r_reaction_x
    double r_reaction_y
    double norm_x
    double norm_y
    double norm
    double sum_of_sqrs
  # The player keeps moving at the current velocity during the reaction time
  r_reaction_x = position_x + velocity_x * REACTION_TIME
  r_reaction_y = position_y + velocity_y * REACTION_TIME
  norm_x = final_pos_x - r_reaction_x
  norm_y = final_pos_y - r_reaction_y
  # The next 2 lines re-create np.linalg.norm()
  sum_of_sqrs = pow(norm_x, 2) + pow(norm_y, 2)
  norm = sqrt(sum_of_sqrs)
  # ...then runs straight to the target at max speed
  tti = REACTION_TIME + norm / MAX_PLAYER_SPEED
  return tti

################################################################################
# This is also a class function on the player() class, but here it is
# calculated on demand using pure C code.
# https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py#L118-L121
################################################################################
@cython.cdivision(True)
cdef double prob_intercept_cy(double time_t, double time_to_intercept):
  cdef double prob
  prob = 1.0 / (1.0 + exp(-1.0 * M_PI / sqrt(3.0) / TTI_SIGMA *
                          (time_t - time_to_intercept)))
  return prob

################################################################################
# In Python this is done with a list comprehension, which stores each
# player's time-to-intercept as an attribute on the player object. Here we
# store each player's time-to-intercept in an array and return the minimum
# time-to-intercept.
# https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py#L244-L245
################################################################################
@cython.boundscheck(False)
@cython.wraparound(False)
cdef double calculate_min_time_to_intercept_cy(list players,
                                               double target_pos_x,
                                               double target_pos_y,
                                               double[:] tti_arr):
  # Assumes some player arrives within MAX_INTGL_TIME; otherwise the result is capped here
  cdef double tau_min_att = MAX_INTGL_TIME
  cdef double simp_tti_att
  cdef size_t idx = 0
  for player in players:
    simp_tti_att = simple_time_to_intercept_cy(player.position[0],
                                               player.position[1],
                                               player.velocity[0],
                                               player.velocity[1],
                                               target_pos_x,
                                               target_pos_y)
    tti_arr[idx] = simp_tti_att
    if simp_tti_att < tau_min_att:
      tau_min_att = simp_tti_att
    idx += 1
  return tau_min_att

################################################################################
# The main pitch control calculation. Most of it follows the Python code, but
# is more verbose because of C. There is still some reliance on Numpy which
# could be converted to C with some effort.
# https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py#L217
################################################################################
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
def calculate_pitch_control_at_target_cy(double target_pos_x,
                                         double target_pos_y,
                                         list attacking_players,
                                         list defending_players,
                                         int num_att_players,
                                         int num_def_players,
                                         double ball_start_pos_x,
                                         double ball_start_pos_y):
  cdef:
    double sum_of_sqrs
    double norm
    double ball_travel_time
  cdef double ball_travel_time_norm_arr[2]
  cdef Py_ssize_t i
  # Calculate ball travel time
  ball_travel_time_norm_arr[0] = target_pos_x - ball_start_pos_x
  ball_travel_time_norm_arr[1] = target_pos_y - ball_start_pos_y
  sum_of_sqrs = 0.0
  for i in range(2):
    sum_of_sqrs += pow(ball_travel_time_norm_arr[i], 2)
  norm = sqrt(sum_of_sqrs)
  ball_travel_time = norm / AVG_BALL_SPEED
  # Calculate arrival times for attacking team
  cdef double player_ttis_att[11]
  cdef double min_tti_att
  min_tti_att = calculate_min_time_to_intercept_cy(attacking_players,
                                                   target_pos_x,
                                                   target_pos_y,
                                                   player_ttis_att)
  # Calculate arrival times for defending team
  cdef double player_ttis_def[11]
  cdef double min_tti_def
  min_tti_def = calculate_min_time_to_intercept_cy(defending_players,
                                                   target_pos_x,
                                                   target_pos_y,
                                                   player_ttis_def)
  cdef double closer_obj_att
  cdef double closer_obj_def
  cdef double[:] dt_array
  cdef double[:] ppcf_att
  cdef double[:] ppcf_def
  cdef double ptot = 0.0
  cdef int idx = 1
  cdef double T
  cdef double player_tti
  cdef double d_ppcf_dt = 0.0
  cdef double d_ppcf_dt_intgl
  cdef Py_ssize_t player_tti_idx = 1
  cdef double player_ppcf_att[11]
  cdef double player_ppcf_def[11]
  cdef Py_ssize_t init_idx = 0
  for init_idx in range(11):
    player_ppcf_att[init_idx] = 0.0
    player_ppcf_def[init_idx] = 0.0
  if ball_travel_time > min_tti_def:
    closer_obj_att = ball_travel_time
  else:
    closer_obj_att = min_tti_def
  if ball_travel_time > min_tti_att:
    closer_obj_def = ball_travel_time
  else:
    closer_obj_def = min_tti_att
  if (min_tti_att - closer_obj_att) >= TIME_TO_CTRL_DEF:
    return 0.0, 1.0
  elif (min_tti_def - closer_obj_def) >= TIME_TO_CTRL_ATT:
    return 1.0, 0.0
  else:
    dt_array = np.arange(ball_travel_time - INTGL_DT,
                         ball_travel_time + MAX_INTGL_TIME,
                         INTGL_DT)
    ppcf_att = np.zeros_like(dt_array)
    ppcf_def = np.zeros_like(dt_array)
    while 1.0 - ptot > MODEL_CONVERGE_TOL and idx < dt_array.shape[0]:
      T = dt_array[idx]
      for player_tti_idx in range(num_att_players):
        player_tti = player_ttis_att[player_tti_idx]
        # Skip players who trail the fastest attacker by too much
        if player_tti - min_tti_att > TIME_TO_CTRL_ATT:
          continue
        d_ppcf_dt = ((1.0 - ppcf_att[idx - 1] - ppcf_def[idx - 1]) *
                     prob_intercept_cy(T, player_tti) * LAMBDA_ATT)
        # assert d_ppcf_dt > 0
        # Cumulative probability for this player (valid indices 0..10)
        player_ppcf_att[player_tti_idx] += d_ppcf_dt * INTGL_DT
        ppcf_att[idx] += player_ppcf_att[player_tti_idx]
      for player_tti_idx in range(num_def_players):
        player_tti = player_ttis_def[player_tti_idx]
        # Skip players who trail the fastest defender by too much
        if player_tti - min_tti_def > TIME_TO_CTRL_DEF:
          continue
        d_ppcf_dt = ((1.0 - ppcf_att[idx - 1] - ppcf_def[idx - 1]) *
                     prob_intercept_cy(T, player_tti) * LAMBDA_DEF)
        # assert d_ppcf_dt > 0
        player_ppcf_def[player_tti_idx] += d_ppcf_dt * INTGL_DT
        ppcf_def[idx] += player_ppcf_def[player_tti_idx]
      ptot = ppcf_def[idx] + ppcf_att[idx]
      idx += 1
    # The loop stops at idx == dt_array.shape[0] when it runs out of timesteps
    if idx >= dt_array.shape[0]:
      print('Integration failed to converge')
    return ppcf_att[idx - 1], ppcf_def[idx - 1]

################################################################################
# This function is the entry point from Python code. It creates the X and Y
# grids for the pitch regions and returns the NumPy arrays of pitch control for
# each team, similar to the generate_pitch_control_for_event() function in
# Python.
# https://github.com/Friends-of-Tracking-Data-FoTD/LaurieOnTracking/blob/e047ede88e11030a9755ee783614f9c960664c01/Metrica_PitchControl.py#L190-L215
################################################################################
@cython.cdivision(True)
def calculate_pitch_control_cy(double pitch_dimen_x,
                               double pitch_dimen_y,
                               int n_grid_cells_x,
                               int n_grid_cells_y,
                               list home_players,
                               list away_players,
                               double ball_start_pos_x,
                               double ball_start_pos_y):
  cdef Py_ssize_t x, y
  cdef np.ndarray[np.float64_t, ndim=2] pitch_ctrl_a_cy
  cdef np.ndarray[np.float64_t, ndim=2] pitch_ctrl_d_cy
  cdef np.ndarray[np.float64_t, ndim=1] xgrid, ygrid, target_pos
  xgrid = np.linspace(-pitch_dimen_x/2.0, pitch_dimen_x/2.0, n_grid_cells_x)
  ygrid = np.linspace(-pitch_dimen_y/2.0, pitch_dimen_y/2.0, n_grid_cells_y)
  pitch_ctrl_a_cy = np.zeros(shape=(n_grid_cells_y, n_grid_cells_x))
  pitch_ctrl_d_cy = np.zeros(shape=(n_grid_cells_y, n_grid_cells_x))
  for y in range(n_grid_cells_y):
    for x in range(n_grid_cells_x):
      target_pos = np.array([xgrid[x], ygrid[y]])
      pitch_ctrl_a_cy[y, x], pitch_ctrl_d_cy[y, x] = \
          calculate_pitch_control_at_target_cy(target_pos[0],
                                               target_pos[1],
                                               home_players,
                                               away_players,
                                               len(home_players),
                                               len(away_players),
                                               ball_start_pos_x,
                                               ball_start_pos_y)
  return pitch_ctrl_a_cy, pitch_ctrl_d_cy
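
Because the `cdef` helpers above can't be called from Python, a plain-Python version of the same per-target calculation can be handy for spot-checking individual values. The sketch below follows the formulas and skip rules in the Cython cell, but it takes precomputed time-to-intercept lists instead of player objects. That signature and the helper names are assumptions for illustration, not Laurie's API. One deliberate difference: `min()` here returns the true minimum, without the `MAX_INTGL_TIME` starting value used by `calculate_min_time_to_intercept_cy`.

```python
import math

import numpy as np

# Same constant values as the cdef block in the Cython cell.
REACTION_TIME = 0.7
MAX_PLAYER_SPEED = 5.0
MAX_INTGL_TIME = 10.0
INTGL_DT = 0.04
MODEL_CONVERGE_TOL = 0.01
TTI_SIGMA = 0.45
LAMBDA_ATT = 4.3
LAMBDA_DEF = 4.3
TIME_TO_CTRL_VETO = 3.0


def time_until_ctrl_override(team_lambda):
    return (TIME_TO_CTRL_VETO * math.log(10.0) *
            (math.sqrt(3.0) * TTI_SIGMA / math.pi + 1.0 / team_lambda))


TIME_TO_CTRL_ATT = time_until_ctrl_override(LAMBDA_ATT)  # about 3.32 s
TIME_TO_CTRL_DEF = time_until_ctrl_override(LAMBDA_DEF)


def simple_time_to_intercept(position, velocity, target):
    # Drift at the current velocity for REACTION_TIME, then run straight
    # to the target at MAX_PLAYER_SPEED.
    reaction = np.asarray(position) + np.asarray(velocity) * REACTION_TIME
    dist = np.linalg.norm(np.asarray(target) - reaction)
    return REACTION_TIME + dist / MAX_PLAYER_SPEED


def prob_intercept(t, tti):
    # Logistic arrival-time uncertainty with standard deviation TTI_SIGMA.
    return 1.0 / (1.0 + math.exp(-math.pi / math.sqrt(3.0) / TTI_SIGMA * (t - tti)))


def pitch_control_at_target(ttis_att, ttis_def, ball_travel_time):
    """Return (ppcf_att, ppcf_def) for one target from per-player TTIs."""
    tau_att, tau_def = min(ttis_att), min(ttis_def)
    if tau_att - max(ball_travel_time, tau_def) >= TIME_TO_CTRL_DEF:
        return 0.0, 1.0
    if tau_def - max(ball_travel_time, tau_att) >= TIME_TO_CTRL_ATT:
        return 1.0, 0.0
    # Same skip rule as the Cython loops: ignore players who trail their
    # fastest teammate by more than the override time.
    att = [t for t in ttis_att if t - tau_att <= TIME_TO_CTRL_ATT]
    dfn = [t for t in ttis_def if t - tau_def <= TIME_TO_CTRL_DEF]
    dt_array = np.arange(ball_travel_time - INTGL_DT,
                         ball_travel_time + MAX_INTGL_TIME, INTGL_DT)
    ppcf_att = np.zeros_like(dt_array)
    ppcf_def = np.zeros_like(dt_array)
    player_att = [0.0] * len(att)  # cumulative probability per player
    player_def = [0.0] * len(dfn)
    ptot, i = 0.0, 1
    while 1.0 - ptot > MODEL_CONVERGE_TOL and i < dt_array.size:
        t = dt_array[i]
        remaining = 1.0 - ppcf_att[i - 1] - ppcf_def[i - 1]
        for k, tti in enumerate(att):
            player_att[k] += remaining * prob_intercept(t, tti) * LAMBDA_ATT * INTGL_DT
            ppcf_att[i] += player_att[k]
        for k, tti in enumerate(dfn):
            player_def[k] += remaining * prob_intercept(t, tti) * LAMBDA_DEF * INTGL_DT
            ppcf_def[i] += player_def[k]
        ptot = ppcf_att[i] + ppcf_def[i]
        i += 1
    return ppcf_att[i - 1], ppcf_def[i - 1]
```

For example, `pitch_control_at_target([1.0], [10.0], 0.5)` takes the early exit and returns `(1.0, 0.0)`, while two identical single-player teams split control evenly once the integral converges.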

Prepare some data to use for comparing the run time of both the Python and Cython versions.

In [None]:
def populate_tracking_dataframes():
  tracking_home = mio.tracking_data(DATA_DIR, MATCH_ID, 'Home')
  tracking_away = mio.tracking_data(DATA_DIR, MATCH_ID, 'Away')
  events = mio.read_event_data(DATA_DIR, MATCH_ID)

  tracking_home = mio.to_metric_coordinates(tracking_home)
  tracking_away = mio.to_metric_coordinates(tracking_away)
  events = mio.to_metric_coordinates(events)

  tracking_home = mvelo.calc_player_velocities(tracking_home, smoothing=True)
  tracking_away = mvelo.calc_player_velocities(tracking_away, smoothing=True)

  return mio.to_single_playing_direction(tracking_home, tracking_away, events)

DATA_DIR = '/content/sample-data/data'
MATCH_ID = 2
tracking_home, tracking_away, events = populate_tracking_dataframes()
shots = events[events['Type']=='SHOT']
goals = shots[shots['Subtype'].str.contains('-GOAL')].copy()
params = mpc.default_model_params(3)
field_dimen = (106.0, 68.0)
n_grid_cells_x = 50
n_grid_cells_y = int(n_grid_cells_x * field_dimen[1] / field_dimen[0])

# Event ID of the last goal
event_id = goals.index[-1]
# Use the start frame from 4 events prior to the goal
start_frame = events.loc[event_id - 4]['Start Frame']
end_frame = events.loc[event_id]['End Frame'] + 10  # buffer some frames at end

home_team = tracking_home.loc[start_frame:end_frame]
away_team = tracking_away.loc[start_frame:end_frame]
attacking_team = events.loc[event_id].Team

ball_start_pos = np.array([home_team.loc[start_frame]['ball_x'],
                           home_team.loc[start_frame]['ball_y']])
home_players = mpc.initialise_players(home_team.loc[start_frame],
                                      'Home',
                                      params)
away_players = mpc.initialise_players(away_team.loc[start_frame],
                                      'Away',
                                      params)

Reading team: home
Reading team: away
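
The grid settings in the cell above give 50 × 32 target positions, which is where the 1600 total in the comparison further down comes from. A quick standalone check of that arithmetic, reusing only the notebook's `field_dimen` and `n_grid_cells_x` values:

```python
import numpy as np

field_dimen = (106.0, 68.0)
n_grid_cells_x = 50
# int() truncates 50 * 68 / 106 = 32.07... down to 32.
n_grid_cells_y = int(n_grid_cells_x * field_dimen[1] / field_dimen[0])
n_targets = n_grid_cells_x * n_grid_cells_y

# np.linspace includes both touchlines, so 50 points span 49 gaps.
xgrid = np.linspace(-field_dimen[0] / 2.0, field_dimen[0] / 2.0, n_grid_cells_x)
x_spacing = xgrid[1] - xgrid[0]  # 106 / 49, roughly 2.16 m

print(n_grid_cells_y, n_targets, round(x_spacing, 2))  # prints: 32 1600 2.16
```

Each of those 1600 targets triggers one call to the per-target pitch control function, in both the Python and the Cython versions.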


This code cell defines wrapper functions that call the Python and Cython code with matching inputs, so both can be timed with `%timeit`.

In [None]:
def time_cython():
  return calculate_pitch_control_cy(field_dimen[0], field_dimen[1],
                                    n_grid_cells_x, n_grid_cells_y,
                                    home_players, away_players,
                                    ball_start_pos[0], ball_start_pos[1])

def time_python():
  xgrid = np.linspace(-field_dimen[0]/2.0, field_dimen[0]/2.0, n_grid_cells_x)
  ygrid = np.linspace(-field_dimen[1]/2.0, field_dimen[1]/2.0, n_grid_cells_y)
  pitch_ctrl_a_py = np.zeros(shape=(len(ygrid), len(xgrid)))
  pitch_ctrl_d_py = np.zeros(shape=(len(ygrid), len(xgrid)))
  for y in range(len(ygrid)):
    for x in range(len(xgrid)):
      target_pos = np.array([xgrid[x], ygrid[y]])
      pitch_ctrl_a_py[y, x], pitch_ctrl_d_py[y, x] = \
        mpc.calculate_pitch_control_at_target(target_pos,
                                              home_players,
                                              away_players,
                                              ball_start_pos,
                                              params)
  return pitch_ctrl_a_py, pitch_ctrl_d_py
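
Outside IPython, `%timeit` isn't available. A rough equivalent using the standard-library `timeit` module is sketched below; `best_time_per_call` is a hypothetical helper, and like the `%timeit` output in this notebook it reports the best of several repeats.

```python
import timeit


def best_time_per_call(func, number=1, repeat=3):
    """Best average seconds per call of func() over `repeat` runs of `number` calls."""
    # timeit.repeat accepts a zero-argument callable as the statement.
    runs = timeit.repeat(func, number=number, repeat=repeat)
    return min(runs) / number


if __name__ == "__main__":
    # Hypothetical stand-in workload; in the notebook you would pass
    # time_python or time_cython instead.
    print(f"{best_time_per_call(lambda: sum(range(10_000)), number=100):.6f} s")
```

In the notebook, `best_time_per_call(time_python)` and `best_time_per_call(time_cython, number=10)` would roughly mirror the two `%timeit` lines, except that `%timeit` picks the loop count automatically.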



First, check that the Python and Cython outputs match. All but 10 of the 1600 attacking pitch control values agree to within 0.0095; in those 10 cells, Python returns exactly 1.0 while Cython returns about 0.990, just past the 0.99 convergence threshold.
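
Python's exact 1.0 values come from its early-exit check, which skips the integral when the defending team's fastest player arrives at least `TIME_TO_CTRL_ATT` after the later of the ball and the fastest attacker. The Cython version has the same check, but `calculate_min_time_to_intercept_cy` starts its running minimum at `MAX_INTGL_TIME` (10 s), while the Python version's minimum has no such cap. If every defender needs more than 10 s, the Cython minimum is capped, the check can fail, and the integral then stops just past 0.99. That may explain the mismatches, but it is not verified here. The sketch below reproduces the effect with hypothetical times; `capped_min`, `uncapped_min`, and `attack_wins_outright` are illustrative names.

```python
import math

MAX_INTGL_TIME = 10.0
TIME_TO_CTRL_ATT = 3.32  # rounded value of the constant computed in the Cython cell


def capped_min(ttis):
    # Mirrors calculate_min_time_to_intercept_cy: the running minimum
    # starts at MAX_INTGL_TIME, so the result can never exceed 10 s.
    tau = MAX_INTGL_TIME
    for t in ttis:
        if t < tau:
            tau = t
    return tau


def uncapped_min(ttis):
    # Starts from +infinity, so the true minimum is returned.
    tau = math.inf
    for t in ttis:
        if t < tau:
            tau = t
    return tau


def attack_wins_outright(tau_att, tau_def, ball_travel_time):
    # The early-exit test that returns (1.0, 0.0) without integrating.
    return tau_def - max(ball_travel_time, tau_att) >= TIME_TO_CTRL_ATT


# Hypothetical times (s): every defender needs more than 10 s to arrive.
def_ttis = [11.5, 12.8, 14.0]
tau_att, ball_time = 7.0, 5.0
print(attack_wins_outright(tau_att, uncapped_min(def_ttis), ball_time))  # True
print(attack_wins_outright(tau_att, capped_min(def_ttis), ball_time))    # False
```

Starting the Cython running minimum from +infinity instead would remove the cap, though the comparison would then need to be rerun to confirm the effect.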

In [None]:
att_pitch_ctrl_cy, def_pitch_ctrl_cy = time_cython()
att_pitch_ctrl_py, def_pitch_ctrl_py = time_python()

# Compare the attacking pitch control grids cell by cell and count the matches.
num_matching = 0
total = len(att_pitch_ctrl_cy) * len(att_pitch_ctrl_cy[0])
for i in range(len(att_pitch_ctrl_cy)):
  att_pitch_ctrl_cy_i = att_pitch_ctrl_cy[i]
  att_pitch_ctrl_py_i = att_pitch_ctrl_py[i]
  for j in range(len(att_pitch_ctrl_cy_i)):
    if abs(att_pitch_ctrl_cy_i[j] - att_pitch_ctrl_py_i[j]) > 0.0095:
      print('cy:', att_pitch_ctrl_cy_i[j])
      print('py:', att_pitch_ctrl_py_i[j])
    else:
      num_matching += 1
print(f'{num_matching} matching numbers out of {total}')

cy: 0.9903509669588355
py: 1.0
cy: 0.990028735671969
py: 1.0
cy: 0.9903987442101314
py: 1.0
cy: 0.9904091336198543
py: 1.0
cy: 0.9904533379717876
py: 1.0
cy: 0.9903116434526638
py: 1.0
cy: 0.9903117025418253
py: 1.0
cy: 0.9903910194078184
py: 1.0
cy: 0.9901012757503339
py: 1.0
cy: 0.9900327822642294
py: 1.0
1590 matching numbers out of 1600
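
The nested comparison loop can also be written with NumPy array operations. Here is a sketch using the same 0.0095 absolute threshold, with `compare_grids` as a hypothetical helper name:

```python
import numpy as np


def compare_grids(cy, py, tol=0.0095):
    """Count cells within tol of each other; return (matching, total, mismatch_idx)."""
    cy = np.asarray(cy, dtype=float)
    py = np.asarray(py, dtype=float)
    mismatch = np.abs(cy - py) > tol  # same test as the loop in the cell above
    n_mismatch = int(np.count_nonzero(mismatch))
    return cy.size - n_mismatch, cy.size, np.argwhere(mismatch)


if __name__ == "__main__":
    a = np.zeros((2, 3))
    b = a.copy()
    b[1, 2] = 1.0
    matching, total, where = compare_grids(a, b)
    print(f"{matching} matching numbers out of {total}")  # 5 matching numbers out of 6
    print(where.tolist())                                  # [[1, 2]]
```

In the notebook, `compare_grids(att_pitch_ctrl_cy, att_pitch_ctrl_py)` would give the same counts as the loop, plus the (row, column) indices of the mismatched cells.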


Finally, timing both approaches shows the Cython code running about 50x faster than the Python code (55.7 ms vs. 2.79 s per call). 🔥🔥🔥

In [None]:
%timeit time_python()
%timeit time_cython()

1 loop, best of 3: 2.79 s per loop
10 loops, best of 3: 55.7 ms per loop
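
The ~50x figure comes straight from the two `%timeit` results:

```python
python_s = 2.79     # time_python(): 2.79 s per loop
cython_s = 55.7e-3  # time_cython(): 55.7 ms per loop
speedup = python_s / cython_s
print(f"{speedup:.1f}x")  # 50.1x
```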
