In [84]:
from wn.data import prepare_matches
import pandas as pd

import torch
from torch.utils.data import Dataset

In [82]:
# One CSV per season. range() excludes its stop value, so 1968-2017 are loaded.
match_list = [
    f"../tennis_atp/atp_matches_{year}.csv"
    for year in range(1968, 2018)
]
matches = prepare_matches(match_list)

# Player metadata (the dob column is parsed in a later cell).
players = pd.read_csv(
    "../tennis_atp/atp_players.csv"
)

So the idea is to:

  1) Sample a random match
  2) Identify both players
  3) Retrieve their fixed tabular features
  4) Add the match date to the tabular features
  5) Retrieve the players' last *n* matches
  6) Predict for both players (separately)

In [83]:
def _parse_yyyymmdd(dates: pd.Series, errors: str = "raise") -> pd.Series:
    """Parse a YYYYMMDD date column into datetime64.

    Accepts integer, float or string input. Going through ``astype("str")``
    is unsafe for float columns: ``read_csv`` yields float64 for an integer
    column that has missing values, which stringifies as ``"19700101.0"`` --
    not a valid YYYYMMDD string. Parsing numerically first avoids that.
    Columns that are already datetime are returned unchanged, so re-running
    this cell does not break.

    Args:
        dates: Column to parse.
        errors: Forwarded to ``pd.to_numeric`` and ``pd.to_datetime``;
            ``"coerce"`` turns unparseable values into NaT instead of raising.

    Returns:
        A datetime64 Series aligned with ``dates``; missing inputs become NaT.
    """
    if pd.api.types.is_datetime64_any_dtype(dates):
        return dates
    # Nullable Int64 keeps missing values as <NA> instead of forcing floats.
    as_int = pd.to_numeric(dates, errors=errors).astype("Int64")
    return pd.to_datetime(as_int.astype("string"), format="%Y%m%d", errors=errors)


# Add days elapsed from 1900
matches["tourney_date"] = _parse_yyyymmdd(matches.tourney_date)
matches["days_elapsed_date"] = (matches.tourney_date - pd.to_datetime("19000101")).dt.days

# Remove players with a missing or unparseable birthday (for now)
players["dob"] = _parse_yyyymmdd(players.dob, errors="coerce")
players = players.dropna(subset=["dob"]).reset_index(drop=True)

# Keep only matches whose winner and loser both have a known birthday
matches = matches.loc[
    matches.winner_id.isin(players.player_id)
    & matches.loser_id.isin(players.player_id)
].reset_index(drop=True)

In [109]:
class MatchDataset(Dataset):
    """Per-player samples built from a table of matches.

    Every match contributes two samples, one for the winner and one for the
    loser. A sample bundles the match row, that player's earlier matches within
    a look-back window, and that player's row(s) from ``players``.

    Args:
        matches: Match table with ``winner_id``, ``loser_id`` and
            ``days_elapsed_date`` columns.
        players: Player table with a ``player_id`` column.
        t_interface: Intended to specify what data to return and in what
            form. Stored but not used yet.
        s_interface: Same as ``t_interface``, presumably for the sequence
            (history) part of a sample -- TODO confirm. Stored but not used yet.
        history_days: Length, in days, of the look-back window used to build
            ``match_history``. Defaults to 365.
    """

    def __init__(
        self,
        matches,
        players,
        t_interface: dict,
        s_interface: dict,
        history_days: int = 365,
    ):
        super().__init__()

        self.matches = matches
        self.players = players

        # The idea here is to specify what data you want and in what form.
        self.t_interface = t_interface
        self.s_interface = s_interface
        self.history_days = history_days

        # One entry per (match, participant): all winners first, then all losers.
        rows = list(self.matches.itertuples())
        self.index = [(row, row.winner_id) for row in rows] + [
            (row, row.loser_id) for row in rows
        ]

    def __len__(self):
        """Number of samples: two per match."""
        return len(self.index)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            dict with keys:
              - ``"match_features"``: the match row (namedtuple from
                ``DataFrame.itertuples``).
              - ``"match_history"``: rows of ``matches`` involving the player
                with ``days_elapsed_date`` in
                ``[match_date - history_days, match_date)``, oldest first.
              - ``"fixed_features"``: rows of ``players`` for the player.
        """
        match, player_id = self.index[idx]
        match_date = match.days_elapsed_date
        days = self.matches.days_elapsed_date

        played = self.matches.winner_id.eq(player_id) | self.matches.loser_id.eq(player_id)
        # Upper bound is strict, so the match itself and any other match with
        # the same date are excluded from the history.
        in_window = days.ge(match_date - self.history_days) & days.lt(match_date)
        player_matches = self.matches.loc[played & in_window].sort_values(
            "days_elapsed_date", kind="stable"
        )

        # Look up the sampled player in this dataset's own player table.
        player_fixed_features = self.players.loc[self.players.player_id.eq(player_id)]

        return {
            "match_features": match,
            "match_history": player_matches,
            "fixed_features": player_fixed_features,
        }

In [110]:
# Interfaces are not consumed by MatchDataset yet, so leave them unset.
ds = MatchDataset(
    matches,
    players,
    t_interface=None,
    s_interface=None,
)