In [1]:
from wn.data import prepare_matches, DataInterface
import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset
from torch import nn

import pickle
import os

So the idea is to:

  1) Sample a random match
  2) Identify both players
  3) Retrieve their fixed tabular features
  4) Add the match time to the tabular features
  5) Retrieve the players' last *n* matches
  6) Predict for both players (separately)

In [2]:
# Load the cleaned (players, matches) frames from a local cache, or rebuild
# them from the raw CSVs and write the cache.
# NOTE(review): pickle.load can execute arbitrary code — only point this at a
# cache file this notebook wrote itself. The cache is also not invalidated when
# the year range or the cleaning steps below change; delete processed_data.pkl
# after editing them.
if os.path.exists("processed_data.pkl"):

    with open("processed_data.pkl", "rb") as f:
        players, matches = pickle.load(f)

else:

    match_list = [f"../tennis_atp/atp_matches_{year}.csv" for year in range(1968, 2018)]
    matches = prepare_matches(match_list)

    players = pd.read_csv("../tennis_atp/atp_players.csv")

    # Add days elapsed from 1900
    matches["tourney_date"] = pd.to_datetime(matches.tourney_date.astype("str"))
    matches["days_elapsed_date"] = (matches.tourney_date - pd.to_datetime("19000101")).dt.days

    # Removing missing birthday players for now (unparseable dates become NaT)
    players["dob"] = pd.to_datetime(players.dob.astype("str"), errors="coerce")
    players = players[~players.dob.isna()].reset_index(drop=True)
    players["days_elapsed_dob"] = (players.dob - pd.to_datetime("19000101")).dt.days

    # Latest match date per player over both roles (winner or loser), computed
    # with a single groupby instead of scanning the match table once per player
    # (O(players x matches)). Players without any match get NaN.
    # This runs before the filter below, so it also counts matches against
    # opponents whose birthday is unknown.
    appearances = pd.concat([
        matches[["winner_id", "days_elapsed_date"]].rename(columns={"winner_id": "player_id"}),
        matches[["loser_id", "days_elapsed_date"]].rename(columns={"loser_id": "player_id"}),
    ])
    last_played = appearances.groupby("player_id")["days_elapsed_date"].max()
    players["last_match_date"] = players.player_id.map(last_played)

    # Remove matches with players with unknown birthdays
    matches = matches.loc[
        matches.winner_id.isin(players.player_id)
        & matches.loser_id.isin(players.player_id)
    ].reset_index(drop=True)

    with open("processed_data.pkl", "wb") as f:
        pickle.dump((players, matches), f)

In [3]:
# Match-level columns to keep: rank and hand for each role, plus match context.
desired_cols = [
    f"{role}_{field}" for role in ("winner", "loser") for field in ("rank", "hand")
] + ["surface", "tourney_level", "days_elapsed_date"]

# Attach both players' birthdays (days since 1900) to every match. Each inner
# merge drops matches whose player is missing from `players`; the column
# selection after each merge discards the rest of the player record.
winner_dob_added = (
    matches
    .merge(players, "inner", left_on="winner_id", right_on="player_id")
    .loc[:, desired_cols + ["days_elapsed_dob", "loser_id"]]
    .rename(columns={"days_elapsed_dob": "winner_dob"})
)
augmented_matches = (
    winner_dob_added
    .merge(players, "inner", left_on="loser_id", right_on="player_id")
    .loc[:, desired_cols + ["winner_dob", "days_elapsed_dob"]]
    .rename(columns={"days_elapsed_dob": "loser_dob"})
)

In [4]:
def _from_perspective(frame, p1, p2, won):
    """Return every match in ``frame`` seen from the ``p1`` side.

    Columns ``{p1}_rank/hand/dob`` are renamed to ``p1_*`` and
    ``{p2}_rank/hand/dob`` to ``p2_*``. The match-level columns are kept, all
    missing values are filled with -1, and a constant ``won`` label column is
    appended last.
    """
    player_fields = ("rank", "hand", "dob")
    renames = {
        **{f"{p1}_{field}": f"p1_{field}" for field in player_fields},
        **{f"{p2}_{field}": f"p2_{field}" for field in player_fields},
    }
    match_cols = ["surface", "tourney_level", "days_elapsed_date"]
    return (
        frame[list(renames) + match_cols]
        .fillna(-1)
        .assign(won=won)
        .rename(renames, axis=1)
    )


winner_matches = _from_perspective(augmented_matches, "winner", "loser", won=1)
loser_matches = _from_perspective(augmented_matches, "loser", "winner", won=0)

# Both halves carry the same row labels; ignore_index avoids every index value
# appearing twice in the combined frame.
condensed_matches = pd.concat([winner_matches, loser_matches], ignore_index=True)

In [5]:
# Both players share the same feature types; build p1_* then p2_* entries,
# followed by the match-level columns (order matters: it drives encoding order).
_player_types = {"rank": "numeric", "hand": "categorical", "dob": "time"}
column_types = {
    f"{side}_{field}": kind
    for side in ("p1", "p2")
    for field, kind in _player_types.items()
}
column_types.update({
    "surface": "categorical",
    "tourney_level": "categorical",
    "days_elapsed_date": "time",
})

match_interface = DataInterface(column_types)

match_interface.complete(condensed_matches)

In [18]:
def tr(dt, col_name, interface):
    """Encode one DataFrame column as a tensor according to its declared type.

    Args:
        dt: pandas Series holding the column's values.
        col_name: Column name, used to look up ``interface[col_name]``.
        interface: Mapping-like object. ``interface[col_name][0]`` is the column
            type: "numeric", "time", or anything else (treated as categorical).
            For "numeric", ``interface[col_name][2]`` is a 1-D tensor of edges
            ``q`` (assumed sorted ascending, as torch.searchsorted requires).
            For categorical, ``interface[col_name][1]`` maps each raw value to
            an integer code.

    Returns:
        * numeric: tensor of shape (len(dt), len(q)) using a "thermometer"
          encoding. Every edge strictly below the value is 1. The edge at or
          just above the value holds the value's linear position between that
          edge and the previous one. All remaining entries are 0. Values
          <= q[0] encode to all zeros; values above q[-1] encode to all ones
          (previously this raised IndexError).
        * time: float tensor of shape (len(dt), 1) holding the raw values.
          NOTE(review): these are unscaled day counts — consider normalising.
        * categorical: int32 tensor of shape (len(dt), 1) holding the codes.

    Raises:
        KeyError: a categorical value is missing from the code mapping.
    """
    print(f"Encoding {col_name}")

    col_type = interface[col_name][0]

    if col_type == "numeric":
        q = interface[col_name][2]
        n_edges = q.shape[0]
        values = torch.tensor(dt.to_numpy(), dtype=torch.float)

        # filled[i] = number of edges strictly below values[i].
        filled = torch.searchsorted(q, values)

        # Every position left of `filled` is fully "on".
        encoded = torch.zeros([filled.shape[0], n_edges])
        positions = torch.arange(n_edges)
        encoded[positions.unsqueeze(0) < filled.unsqueeze(1)] = 1

        # The edge at index `filled` gets a fractional value, but only when it
        # exists (filled < n_edges) and has a lower neighbour (filled > 0).
        # Rows past the last edge stay all ones.
        rows = torch.nonzero((filled > 0) & (filled < n_edges), as_tuple=True)[0]
        upper = filled[rows]
        lower_edge = q[upper - 1]
        frac = (values[rows] - lower_edge) / (q[upper] - lower_edge)
        encoded[rows, upper] = frac.to(encoded.dtype)

        return encoded

    elif col_type == "time":
        return torch.tensor(dt.to_numpy(), dtype=torch.float).unsqueeze(-1)

    else:
        codes = [interface[col_name][1][x] for x in dt.to_numpy()]
        return torch.tensor(codes, dtype=torch.int).unsqueeze(-1)

In [19]:
# Encode every column declared in the interface, in declaration order.
# Encoding is comparatively expensive, so the result is worth saving to disk.
input_data = dict(
    (column, tr(condensed_matches[column], column, match_interface))
    for column in match_interface.type_map
)

Encoding p1_rank
Encoding p1_hand
Encoding p1_dob
Encoding p2_rank
Encoding p2_hand
Encoding p2_dob
Encoding surface
Encoding tourney_level
Encoding days_elapsed_date


In [22]:
# Persist the encoded tensors so the encoding step need not be re-run.
with open("tensor_list.pkl", "wb") as out_file:
    pickle.dump(input_data, out_file)

### Scratch tests (everything below this point)

In [None]:
# Small smoke-test batch: the first 16 rows of every encoded column.
testo = {name: tensor[:16, :] for name, tensor in input_data.items()}
# Display shape and dtype per column.
{name: (tensor.shape, tensor.dtype) for name, tensor in testo.items()}

{'p1_rank': (torch.Size([16, 16]), torch.float32),
 'p1_hand': (torch.Size([16, 1]), torch.int32),
 'p1_dob': (torch.Size([16, 1]), torch.float32),
 'p2_rank': (torch.Size([16, 16]), torch.float32),
 'p2_hand': (torch.Size([16, 1]), torch.int32),
 'p2_dob': (torch.Size([16, 1]), torch.float32),
 'surface': (torch.Size([16, 1]), torch.int32),
 'tourney_level': (torch.Size([16, 1]), torch.int32),
 'days_elapsed_date': (torch.Size([16, 1]), torch.float32)}

In [29]:
from wn import net
from importlib import reload

In [63]:
# Re-execute wn.net so edits to the module are picked up without a kernel restart.
net = reload(net)

In [64]:
# Build the tabular input layer from match_interface.
# NOTE(review): the meaning of the positional arguments 8 and 24 is not visible
# here (presumably embedding / output sizes) — confirm against
# wn.net.TabularInputLayer.
l = net.TabularInputLayer(match_interface, 8, 24)

In [66]:
# Smoke test: forward pass of the layer on the 16-row sample batch `testo`.
l(testo)

tensor([[[-7.0640e-02, -2.2788e-01,  2.1345e-01,  ..., -3.6157e-01,
          -3.2931e-02,  1.6494e-01],
         [-7.4056e-01, -2.1442e+00,  1.3673e+00,  ..., -1.4030e+00,
          -2.7392e-01,  5.5361e-01],
         [-2.1954e+03, -9.2082e-01,  3.3879e-01,  ..., -2.1911e+00,
           4.4208e-01, -1.1874e+00],
         ...,
         [ 1.2902e-01,  8.4794e-01, -1.9731e+00,  ...,  1.1609e+00,
           6.3384e-01, -4.0906e-01],
         [ 1.0622e+00,  5.0879e-01, -1.6926e-01,  ...,  3.6009e-02,
           9.1177e-01,  2.5332e+00],
         [ 1.0707e+04, -6.5655e-01, -2.0995e-01,  ..., -8.0056e-01,
           1.1965e+00,  1.2512e+00]],

        [[-7.0640e-02, -2.2788e-01,  2.1345e-01,  ..., -3.6157e-01,
          -3.2931e-02,  1.6494e-01],
         [-1.6983e+00,  6.3188e-01, -1.9206e+00,  ..., -1.4030e+00,
          -2.7392e-01,  5.5361e-01],
         [-2.6823e+03,  8.9038e-01,  9.7681e-01,  ..., -2.1911e+00,
           4.4208e-01, -1.1874e+00],
         ...,
         [ 1.2902e-01,  8