In [1]:
import pandas as pd
import torch
import numpy as np
import model.loader as loader

from model.module import THAN
from model.loader import MiniBatchSampler

In [2]:
# --- Which checkpoint to evaluate, and on which dataset ---
MODEL_PREFIX = 'THAN-mem'
DATA = 'nba'
MODEL_SAVE_PATH = f'./saved_models/{MODEL_PREFIX}-{DATA}.pth'

# --- Model / sampling hyper-parameters ---
# NOTE(review): presumably these must match the training run that produced the
# checkpoint, otherwise load_state_dict will reject it — confirm.
NUM_NEIGHBORS = 10   # forwarded to model.link_contrast during evaluation
N_DIM = 32           # passed to the data loader
E_DIM = 16           # passed to the data loader
T_DIM = 32           # passed to the THAN constructor
BATCH_SIZE = 200
NUM_LAYER = 1
NUM_HEADS = 4
DROPOUT = 0.1

# --- Runtime settings ---
GPU = -1                     # -1 selects the CPU; any other value picks cuda:<GPU>
CLASSES = np.array([1, 2])   # candidate classes handed to MiniBatchSampler

device = 'cpu' if GPU == -1 else torch.device('cuda:{}'.format(GPU))

In [3]:
def _expand_candidate_edges(data, batches, counts, classes):
    """Flatten one sampler step into arrays pairing every edge with every candidate type.

    Each index array ``batches[i]`` is tiled ``len(batches)`` times, and each entry of
    ``classes`` is repeated ``len(batch)`` times. Position ``k`` of a batch's slice
    therefore pairs edge ``batch[k % len(batch)]`` with candidate type
    ``classes[k // len(batch)]``. The 0/1 target is 1.0 exactly where the candidate
    type equals ``classes[i]``.

    Returns:
        ``(src, dst, ts, src_type, dst_type, edge_type, targets)``. The first six are
        int arrays and ``targets`` is float64. All have length
        ``counts.sum() * len(batches)``.
    """
    tiles = len(batches)
    total = int(counts.sum() * tiles)

    src = np.empty(total, dtype=int)
    dst = np.empty(total, dtype=int)
    ts = np.empty(total, dtype=int)
    src_type = np.empty(total, dtype=int)
    dst_type = np.empty(total, dtype=int)
    edge_type = np.empty(total, dtype=int)
    targets = np.empty(total)

    start = 0
    for i, batch in enumerate(batches):
        # The slice assignments only line up when len(batch) == counts[i] and
        # len(classes) == len(batches); presumably MiniBatchSampler guarantees both.
        stop = start + int(counts[i] * tiles)
        src[start:stop] = np.tile(data.src_l[batch], tiles)
        dst[start:stop] = np.tile(data.dst_l[batch], tiles)
        ts[start:stop] = np.tile(data.ts_l[batch], tiles)
        src_type[start:stop] = np.tile(data.u_type_l[batch], tiles)
        dst_type[start:stop] = np.tile(data.v_type_l[batch], tiles)
        candidate_types = np.repeat(classes, len(batch))
        edge_type[start:stop] = candidate_types
        targets[start:stop] = (candidate_types == classes[i]).astype(np.float64)
        start = stop
    return src, dst, ts, src_type, dst_type, edge_type, targets


def eval_one_epoch(model: 'THAN', batch_sampler, data, num_neighbors=None):
    """Score every edge yielded by ``batch_sampler`` against each candidate edge type.

    Args:
        model: trained THAN model. It is put in eval mode and queried through
            ``model.link_contrast`` under ``torch.no_grad()``.
        batch_sampler: object exposing ``reset()`` and ``get_batch_index()``. The latter
            returns ``(batches, counts, classes)``, and the epoch ends when ``counts`` is
            ``None`` or sums to 0.
        data: split object with ``src_l``, ``dst_l``, ``ts_l``, ``u_type_l`` and
            ``v_type_l`` arrays indexed by the sampler's batch indices.
        num_neighbors: forwarded to ``link_contrast``. When ``None``, the notebook-level
            ``NUM_NEIGHBORS`` is read at call time.

    Returns:
        ``(prob, etype, src, dst, ts)``: five equal-length 1-D float64 arrays with one
        entry per (edge, candidate type) pair.
        - ``etype`` is the candidate edge type that was scored, not the 0/1 target.
        - float64 is kept because later cells join ``ts`` against a float column.
        - Empty arrays are returned when the sampler yields nothing.
    """
    if num_neighbors is None:
        num_neighbors = NUM_NEIGHBORS

    # Accumulate per-step arrays in lists and concatenate once at the end. Growing an
    # array with np.concatenate inside the loop re-copies every earlier result each step.
    prob_parts, etype_parts, src_parts, dst_parts, ts_parts = [], [], [], [], []
    with torch.no_grad():
        model = model.eval()
        batch_sampler.reset()
        while True:
            batches, counts, classes = batch_sampler.get_batch_index()
            if counts is None or counts.sum() == 0:
                break
            (src, dst, ts, src_type, dst_type,
             etype, targets) = _expand_candidate_edges(data, batches, counts, classes)
            prob = model.link_contrast(src, dst, ts, src_type, dst_type,
                                       etype, targets, num_neighbors)
            prob_parts.append(np.asarray(prob, dtype=np.float64))
            etype_parts.append(etype)
            src_parts.append(src)
            dst_parts.append(dst)
            ts_parts.append(ts)

    def _flatten(parts):
        # float64 matches what the previous incremental concatenation produced.
        if not parts:
            return np.empty(0, dtype=np.float64)
        return np.concatenate(parts).astype(np.float64)

    return (_flatten(prob_parts), _flatten(etype_parts), _flatten(src_parts),
            _flatten(dst_parts), _flatten(ts_parts))

In [4]:
# Load the graph and the validation split (the three middle return values are unused here),
# then build the neighbour finder and the evaluation sampler over the validation edge types.
# NOTE(review): the neighbour finder is built from the full graph `g`, not from `val`;
# confirm this is intended for validation.
g, _, _, _, val = loader.load_and_split_data_train_test_val(
    DATA, N_DIM, E_DIM
)
val_ngh_finder = loader.get_neighbor_finder(
    g,
    g.max_idx,
    num_edge_type=g.num_e_type,
)
val_batch_sampler = MiniBatchSampler(
    val.e_type_l,
    BATCH_SIZE,
    'test',
    CLASSES,
)

In [5]:
# Rebuild the architecture with the configured hyper-parameters, then restore the trained weights.
model = THAN(
    val_ngh_finder, g.n_feat, g.e_feat, g.e_type_feat, g.num_n_type, g.num_e_type,
    T_DIM,
    num_layers=NUM_LAYER, n_head=NUM_HEADS, dropout=DROPOUT, device=device,
)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU can
# still be restored when GPU = -1 (CPU-only host).
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

<All keys matched successfully>

In [6]:
# Score the validation split: flat arrays with one entry per (edge, candidate type) pair.
prob, lbls, src_l, dst_l, ts_l = eval_one_epoch(
    model=model, batch_sampler=val_batch_sampler, data=val
)

test batch 6/5	

In [7]:
# Column name -> array. `lbl` holds the candidate edge type that was scored, not a 0/1 target.
preddict = dict(prob=prob, lbl=lbls, src_id=src_l, dst_id=dst_l, ts=ts_l)

In [8]:
# One row per (edge, candidate type) score.
preds = pd.DataFrame.from_dict(preddict, orient='columns')

In [9]:
# Raw game log. `ts` is the game date as Unix seconds (float) so it can be joined with
# the model's `ts` column further down.
# 'int64' rather than builtin `int`: on Windows with NumPy < 2, `int` resolves to int32,
# which pandas refuses for datetime64 columns. 'int64' yields epoch nanoseconds everywhere.
matches = pd.read_csv('./data/raw/nba/games.csv').assign(
    ts=lambda _d: pd.to_datetime(_d['game_date']).astype('int64') / 10**9
)

In [10]:
# Team lookup table from the processed-data folder; its first, unnamed column holds the NBA team id.
team_dict = pd.read_csv(filepath_or_buffer='./data/processed/nba/teams_dict.csv')

In [11]:
# Name the two unnamed CSV columns: the NBA team id and the integer node id used by the model.
team_dict = team_dict.rename({'Unnamed: 0': 'team_id', '0': 'id'}, axis='columns')

In [12]:
# Attach the home team's node id (`team_id`, `id`; renamed to `*_x` by the away-team merge).
# NOTE: not idempotent; re-running without reloading `matches` adds extra columns.
matches = matches.merge(team_dict, how='left', left_on='team_id_home', right_on='team_id')

In [17]:
# Attach the away team's node id. The name clash with the home-team columns yields
# `team_id_x`/`id_x` (home) and `team_id_y`/`id_y` (away).
# NOTE: not idempotent; run exactly once, right after the home-team merge.
matches = matches.merge(team_dict, how='left', left_on='team_id_away', right_on='team_id')

In [14]:
# Inspect the raw model scores: one row per (edge, candidate type) pair.
preds

Unnamed: 0,prob,lbl,src_id,dst_id,ts
0,0.328476,0.0,7.0,24.0,1.683590e+09
1,0.331991,0.0,24.0,7.0,1.683763e+09
2,0.406535,0.0,25.0,12.0,1.683763e+09
3,0.314098,0.0,7.0,17.0,1.684282e+09
4,0.351626,0.0,7.0,17.0,1.684454e+09
...,...,...,...,...,...
1739,0.432262,1.0,25.0,18.0,1.707178e+09
1740,0.483009,1.0,17.0,28.0,1.707264e+09
1741,0.577512,1.0,7.0,6.0,1.707264e+09
1742,0.544715,1.0,25.0,31.0,1.707350e+09


In [18]:
# Inspect the game log after both team-id merges (id_x = home node id, id_y = away node id).
matches

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,pf_away,pts_away,plus_minus_away,video_available_away,season_type,ts,team_id_x,id_x,team_id_y,id_y
0,0,22021,1610612749,MIL,Milwaukee Bucks,22100001,2021-10-19,MIL vs. BKN,W,240,...,17,104,-23,1,Regular Season,1.634602e+09,1610612749,18,1610612751,20
1,1,22021,1610612747,LAL,Los Angeles Lakers,22100002,2021-10-19,LAL vs. GSW,L,240,...,18,121,7,1,Regular Season,1.634602e+09,1610612747,16,1610612744,13
2,2,22021,1610612740,NOP,New Orleans Pelicans,22100009,2021-10-20,NOP vs. PHI,L,240,...,11,117,20,1,Regular Season,1.634688e+09,1610612740,9,1610612755,24
3,3,22021,1610612762,UTA,Utah Jazz,22100011,2021-10-20,UTA vs. OKC,W,240,...,15,86,-21,1,Regular Season,1.634688e+09,1610612762,31,1610612760,29
4,4,22021,1610612750,MIN,Minnesota Timberwolves,22100008,2021-10-20,MIN vs. HOU,W,240,...,25,106,-18,1,Regular Season,1.634688e+09,1610612750,19,1610612745,14
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3604,3604,12023,1610612742,DAL,Dallas Mavericks,12300070,2023-10-20,DAL vs. DET,W,240,...,20,104,-10,1,Pre Season,1.697760e+09,1610612742,11,1610612765,34
3605,3605,12023,1610612761,TOR,Toronto Raptors,12300069,2023-10-20,TOR vs. WAS,W,240,...,20,98,-36,1,Pre Season,1.697760e+09,1610612761,30,1610612764,33
3606,3606,12023,1610612745,HOU,Houston Rockets,12300071,2023-10-20,HOU vs. MIA,W,240,...,27,104,-6,1,Pre Season,1.697760e+09,1610612745,14,1610612748,17
3607,3607,12023,1610612744,GSW,Golden State Warriors,12300073,2023-10-20,GSW vs. SAS,L,240,...,15,122,5,1,Pre Season,1.697760e+09,1610612744,13,1610612759,28


In [20]:
# Inner join: keep only games for which the model scored the (home id, away id, ts) triple.
match_pred = matches.merge(
    preds,
    how='inner',
    left_on=['id_x', 'id_y', 'ts'],
    right_on=['src_id', 'dst_id', 'ts'],
)

In [21]:
# At this point each game still appears once per scored candidate type (see paired rows below).
match_pred

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,season_type,ts,team_id_x,id_x,team_id_y,id_y,prob,lbl,src_id,dst_id
0,2736,42022,1610612747,LAL,Los Angeles Lakers,42200233,2023-05-06,LAL vs. GSW,W,240,...,Playoffs,1.683331e+09,1610612747,16,1610612744,13,0.439145,0.0,16.0,13.0
1,2736,42022,1610612747,LAL,Los Angeles Lakers,42200233,2023-05-06,LAL vs. GSW,W,240,...,Playoffs,1.683331e+09,1610612747,16,1610612744,13,0.671465,1.0,16.0,13.0
2,2737,42022,1610612748,MIA,Miami Heat,42200203,2023-05-06,MIA vs. NYK,W,240,...,Playoffs,1.683331e+09,1610612748,17,1610612752,21,0.535768,0.0,17.0,21.0
3,2737,42022,1610612748,MIA,Miami Heat,42200203,2023-05-06,MIA vs. NYK,W,240,...,Playoffs,1.683331e+09,1610612748,17,1610612752,21,0.427525,1.0,17.0,21.0
4,2738,42022,1610612755,PHI,Philadelphia 76ers,42200214,2023-05-07,PHI vs. BOS,W,265,...,Playoffs,1.683418e+09,1610612755,24,1610612738,7,0.458000,0.0,24.0,7.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1739,3606,12023,1610612745,HOU,Houston Rockets,12300071,2023-10-20,HOU vs. MIA,W,240,...,Pre Season,1.697760e+09,1610612745,14,1610612748,17,0.384517,1.0,14.0,17.0
1740,3607,12023,1610612744,GSW,Golden State Warriors,12300073,2023-10-20,GSW vs. SAS,L,240,...,Pre Season,1.697760e+09,1610612744,13,1610612759,28,0.312178,0.0,13.0,28.0
1741,3607,12023,1610612744,GSW,Golden State Warriors,12300073,2023-10-20,GSW vs. SAS,L,240,...,Pre Season,1.697760e+09,1610612744,13,1610612759,28,0.541512,1.0,13.0,28.0
1742,3608,12023,1610612749,MIL,Milwaukee Bucks,12300072,2023-10-20,MIL vs. MEM,W,240,...,Pre Season,1.697760e+09,1610612749,18,1610612763,32,0.572879,0.0,18.0,32.0


In [28]:
# Keep, for every game, only the candidate type the model scored highest.
match_pred = (
    match_pred
    .sort_values(by='prob', ascending=False)
    .drop_duplicates(subset='game_id', keep='first')
)

In [29]:
# Games whose top-scoring candidate type is 1.
match_pred.query('lbl == 1')

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,season_type,ts,team_id_x,id_x,team_id_y,id_y,prob,lbl,src_id,dst_id
1419,3445,22023,1610612744,GSW,Golden State Warriors,22300650,2024-01-27,GSW vs. LAL,L,290,...,Regular Season,1.706314e+09,1610612744,13,1610612747,16,0.785682,1.0,13.0,16.0
1039,3255,22023,1610612744,GSW,Golden State Warriors,22300463,2024-01-02,GSW vs. ORL,W,240,...,Regular Season,1.704154e+09,1610612744,13,1610612753,22,0.764928,1.0,13.0,22.0
985,3228,22023,1610612756,PHX,Phoenix Suns,22300436,2023-12-29,PHX vs. CHA,W,240,...,Regular Season,1.703808e+09,1610612756,25,1610612766,35,0.760085,1.0,25.0,35.0
325,2898,22023,1610612744,GSW,Golden State Warriors,22300176,2023-11-11,GSW vs. CLE,L,240,...,Regular Season,1.699661e+09,1610612744,13,1610612739,8,0.755800,1.0,13.0,8.0
455,2963,22023,1610612764,WAS,Washington Wizards,22300219,2023-11-20,WAS vs. MIL,L,240,...,Regular Season,1.700438e+09,1610612764,33,1610612749,18,0.754467,1.0,33.0,18.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
589,3030,22023,1610612753,ORL,Orlando Magic,22300259,2023-11-29,ORL vs. WAS,W,240,...,Regular Season,1.701216e+09,1610612753,22,1610612764,33,0.398102,1.0,22.0,33.0
1031,3251,22023,1610612756,PHX,Phoenix Suns,22300456,2024-01-01,PHX vs. POR,W,240,...,Regular Season,1.704067e+09,1610612756,25,1610612757,26,0.387332,1.0,25.0,26.0
33,2752,42022,1610612738,BOS,Boston Celtics,42200301,2023-05-17,BOS vs. MIA,L,240,...,Playoffs,1.684282e+09,1610612738,7,1610612748,17,0.380673,1.0,7.0,17.0
687,3079,22023,1610612763,MEM,Memphis Grizzlies,22301221,2023-12-08,MEM vs. MIN,L,240,...,Regular Season,1.701994e+09,1610612763,32,1610612750,19,0.377093,1.0,32.0,19.0


In [30]:
# Games whose top-scoring candidate type is 0.
match_pred.query('lbl == 0')

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,season_type,ts,team_id_x,id_x,team_id_y,id_y,prob,lbl,src_id,dst_id
504,2988,22023,1610612753,ORL,Orlando Magic,22300043,2023-11-24,ORL vs. BOS,W,240,...,Regular Season,1.700784e+09,1610612753,22,1610612738,7,0.613858,0.0,22.0,7.0
1246,3359,22023,1610612761,TOR,Toronto Raptors,22300563,2024-01-15,TOR vs. BOS,L,240,...,Regular Season,1.705277e+09,1610612761,30,1610612738,7,0.610173,0.0,30.0,7.0
288,2880,22023,1610612755,PHI,Philadelphia 76ers,22300159,2023-11-08,PHI vs. BOS,W,240,...,Regular Season,1.699402e+09,1610612755,24,1610612738,7,0.603082,0.0,24.0,7.0
144,2808,22023,1610612764,WAS,Washington Wizards,22300103,2023-10-30,WAS vs. BOS,L,240,...,Regular Season,1.698624e+09,1610612764,33,1610612738,7,0.599015,0.0,33.0,7.0
1618,3546,12023,1610612752,NYK,New York Knicks,12300010,2023-10-09,NYK vs. BOS,W,240,...,Pre Season,1.696810e+09,1610612752,21,1610612738,7,0.589947,0.0,21.0,7.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1032,3252,22023,1610612746,LAC,LA Clippers,22300458,2024-01-01,LAC vs. MIA,W,240,...,Regular Season,1.704067e+09,1610612746,15,1610612748,17,0.429697,0.0,15.0,17.0
1202,3337,22023,1610612737,ATL,Atlanta Hawks,22300533,2024-01-12,ATL vs. IND,L,240,...,Regular Season,1.705018e+09,1610612737,6,1610612754,23,0.427052,0.0,6.0,23.0
1480,3476,22023,1610612739,CLE,Cleveland Cavaliers,22300675,2024-01-31,CLE vs. DET,W,240,...,Regular Season,1.706659e+09,1610612739,8,1610612765,34,0.420794,0.0,8.0,34.0
1554,3513,22023,1610612751,BKN,Brooklyn Nets,22300720,2024-02-06,BKN vs. DAL,L,240,...,Regular Season,1.707178e+09,1610612751,20,1610612742,11,0.416668,0.0,20.0,11.0


In [31]:
# Ground truth: 1 if the home team won, 0 if it lost (any other value maps to NaN).
# Built with `assign` so the frame is rebound rather than mutated in place.
match_pred = match_pred.assign(label=match_pred['wl_home'].map({'W': 1, 'L': 0}))

In [34]:
# 1 when the top-scoring candidate type (`lbl`) equals the actual outcome (`label`), else 0.
# NOTE(review): assumes candidate type 1 encodes "home win" and 0 "home loss"; confirm
# against the edge-type encoding used when the graph was built.
match_pred = match_pred.assign(correct=(match_pred['label'] == match_pred['lbl']).astype(int))

In [36]:
# Games the model predicted correctly.
match_pred.query('correct == 1')

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,team_id_x,id_x,team_id_y,id_y,prob,lbl,src_id,dst_id,label,correct
1039,3255,22023,1610612744,GSW,Golden State Warriors,22300463,2024-01-02,GSW vs. ORL,W,240,...,1610612744,13,1610612753,22,0.764928,1.0,13.0,22.0,1,1
985,3228,22023,1610612756,PHX,Phoenix Suns,22300436,2023-12-29,PHX vs. CHA,W,240,...,1610612756,25,1610612766,35,0.760085,1.0,25.0,35.0,1,1
1695,3584,12023,1610612739,CLE,Cleveland Cavaliers,12300045,2023-10-16,CLE vs. MRA,W,240,...,1610612739,8,50009,5,0.751956,1.0,8.0,5.0,1,1
1329,3400,22023,1610612755,PHI,Philadelphia 76ers,22300608,2024-01-22,PHI vs. SAS,W,240,...,1610612755,24,1610612759,28,0.745686,1.0,24.0,28.0,1,1
83,2777,22023,1610612748,MIA,Miami Heat,22300068,2023-10-25,MIA vs. DET,W,240,...,1610612748,17,1610612765,34,0.743666,1.0,17.0,34.0,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
257,2864,22023,1610612748,MIA,Miami Heat,22300150,2023-11-06,MIA vs. LAL,W,240,...,1610612748,17,1610612747,16,0.403593,1.0,17.0,16.0,1,1
927,3199,22023,1610612748,MIA,Miami Heat,22300404,2023-12-25,MIA vs. PHI,W,240,...,1610612748,17,1610612755,24,0.400815,1.0,17.0,24.0,1,1
425,2948,22023,1610612741,CHI,Chicago Bulls,22300204,2023-11-18,CHI vs. MIA,W,240,...,1610612741,10,1610612748,17,0.400381,1.0,10.0,17.0,1,1
589,3030,22023,1610612753,ORL,Orlando Magic,22300259,2023-11-29,ORL vs. WAS,W,240,...,1610612753,22,1610612764,33,0.398102,1.0,22.0,33.0,1,1


In [38]:
# Per season type: count = games evaluated, sum = games predicted correctly.
match_pred.groupby(by='season_type')['correct'].aggregate(['count', 'sum'])

Unnamed: 0_level_0,count,sum
season_type,Unnamed: 1_level_1,Unnamed: 2_level_1
Playoffs,31,14
Pre Season,73,44
Regular Season,768,432


In [41]:
# Games the home team actually won.
match_pred.query('label == 1')

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,team_id_x,id_x,team_id_y,id_y,prob,lbl,src_id,dst_id,label,correct
1039,3255,22023,1610612744,GSW,Golden State Warriors,22300463,2024-01-02,GSW vs. ORL,W,240,...,1610612744,13,1610612753,22,0.764928,1.0,13.0,22.0,1,1
985,3228,22023,1610612756,PHX,Phoenix Suns,22300436,2023-12-29,PHX vs. CHA,W,240,...,1610612756,25,1610612766,35,0.760085,1.0,25.0,35.0,1,1
1695,3584,12023,1610612739,CLE,Cleveland Cavaliers,12300045,2023-10-16,CLE vs. MRA,W,240,...,1610612739,8,50009,5,0.751956,1.0,8.0,5.0,1,1
1329,3400,22023,1610612755,PHI,Philadelphia 76ers,22300608,2024-01-22,PHI vs. SAS,W,240,...,1610612755,24,1610612759,28,0.745686,1.0,24.0,28.0,1,1
83,2777,22023,1610612748,MIA,Miami Heat,22300068,2023-10-25,MIA vs. DET,W,240,...,1610612748,17,1610612765,34,0.743666,1.0,17.0,34.0,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
257,2864,22023,1610612748,MIA,Miami Heat,22300150,2023-11-06,MIA vs. LAL,W,240,...,1610612748,17,1610612747,16,0.403593,1.0,17.0,16.0,1,1
927,3199,22023,1610612748,MIA,Miami Heat,22300404,2023-12-25,MIA vs. PHI,W,240,...,1610612748,17,1610612755,24,0.400815,1.0,17.0,24.0,1,1
425,2948,22023,1610612741,CHI,Chicago Bulls,22300204,2023-11-18,CHI vs. MIA,W,240,...,1610612741,10,1610612748,17,0.400381,1.0,10.0,17.0,1,1
589,3030,22023,1610612753,ORL,Orlando Magic,22300259,2023-11-29,ORL vs. WAS,W,240,...,1610612753,22,1610612764,33,0.398102,1.0,22.0,33.0,1,1
