In [1]:
import pandas as pd
import torch
import numpy as np
import model.loader as loader

from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score

from model.module import THAN
from model.loader import MiniBatchSampler

In [2]:
# --- Run configuration -------------------------------------------------------
# Checkpoint naming: the saved model is looked up by prefix + dataset name.
MODEL_PREFIX = 'THAN-mem'
DATA = 'nba'
MODEL_SAVE_PATH = f'./saved_models/{MODEL_PREFIX}-{DATA}.pth'

# Model / sampling hyper-parameters passed to the loader and to THAN below.
NUM_NEIGHBORS = 10
N_DIM = 32
E_DIM = 16
T_DIM = 32
BATCH_SIZE = 200
NUM_LAYER = 1
NUM_HEADS = 4
DROPOUT = 0.1

# CUDA device index; -1 selects the CPU.
GPU = -1

# Candidate edge types scored for every validation edge.
# NOTE(review): the recorded outputs further down show `lbl` values 0.0/1.0,
# which cannot come from classes [1, 2] — confirm which setting produced them.
CLASSES = np.array([1, 2])

# Always a torch.device (previously a bare 'cpu' string on the CPU branch), so
# downstream code sees one consistent type regardless of the GPU setting.
device = torch.device(f'cuda:{GPU}' if GPU != -1 else 'cpu')

In [3]:
def evaluate_score(labels, prob):
    """Compute average precision and ROC-AUC for one batch of scores.

    Args:
        labels: binary ground-truth labels, passed straight to sklearn.
        prob: torch tensor of predicted scores; moved to the CPU and detached
            before being converted to a NumPy array.

    Returns:
        Tuple ``(ap, auc)`` — note the order: average precision first.
    """
    scores = np.array(prob.cpu().detach().numpy())

    roc_auc = roc_auc_score(labels, scores)
    avg_precision = average_precision_score(labels, scores)
    return avg_precision, roc_auc

def eval_one_epoch(model: THAN, batch_sampler, data):
    """Run one evaluation pass and collect per-candidate link scores.

    For every batch group returned by ``batch_sampler`` each edge is repeated
    once per candidate edge type and scored with ``model.link_contrast``. The
    scores are then normalised so that the candidates of one edge sum to 1.
    The mean per-batch AUC and AP are printed, in that order.

    Args:
        model: THAN model; put into eval mode and run under ``torch.no_grad``.
        batch_sampler: object providing ``reset()`` and ``get_batch_index()``.
            The latter must return ``(batches, counts, classes)`` and signals
            exhaustion with ``counts`` being ``None`` or summing to zero.
        data: edge container read through ``src_l``, ``dst_l``, ``ts_l``,
            ``u_type_l`` and ``v_type_l``.

    Returns:
        Tuple ``(prob, etype, src, dst, ts, lab)`` of 1-D float64 arrays with
        one entry per (edge, candidate edge type) pair. ``lab`` is 1.0 where
        the candidate type equals the class of the group the edge came from.
        All arrays are empty when the sampler yields no batches.
    """
    val_ap, val_auc = [], []
    # Per-iteration chunks, joined once after the loop. The previous version
    # re-concatenated the growing arrays on every iteration (quadratic copying).
    prob_chunks, etype_chunks = [], []
    src_chunks, dst_chunks, ts_chunks, lab_chunks = [], [], [], []

    with torch.no_grad():
        model = model.eval()
        batch_sampler.reset()
        while True:
            batches, counts, classes = batch_sampler.get_batch_index()
            if counts is None or counts.sum() == 0:
                break
            # One candidate per batch group: assumes len(classes) == len(batches)
            # and counts[i] == len(batches[i]) — the slice assignments below
            # only line up under those conditions.
            tiles = len(batches)
            total = int(counts.sum() * tiles)

            src_l_cut = np.empty(total, dtype=int)
            dst_l_cut = np.empty(total, dtype=int)
            ts_l_cut = np.empty(total, dtype=int)
            src_utype_l_cut = np.empty(total, dtype=int)
            dst_utype_l_cut = np.empty(total, dtype=int)
            etype_l = np.empty(total, dtype=int)
            lbls = np.empty(total)

            s_idx = 0
            for i, batch in enumerate(batches):
                e_idx = s_idx + int(counts[i] * tiles)
                # Each edge is repeated `tiles` times consecutively ...
                src_l_cut[s_idx: e_idx] = np.repeat(data.src_l[batch], tiles)
                dst_l_cut[s_idx: e_idx] = np.repeat(data.dst_l[batch], tiles)
                ts_l_cut[s_idx: e_idx] = np.repeat(data.ts_l[batch], tiles)
                src_utype_l_cut[s_idx: e_idx] = np.repeat(data.u_type_l[batch],
                                                          tiles)
                dst_utype_l_cut[s_idx: e_idx] = np.repeat(data.v_type_l[batch],
                                                          tiles)
                # ... and paired with every candidate edge type in turn.
                etype_slice = np.tile(classes, len(batch))
                etype_l[s_idx: e_idx] = etype_slice
                lbls[s_idx: e_idx] = (etype_slice == classes[i]).astype(np.float64)
                s_idx = e_idx

            prob = model.link_contrast(src_l_cut, dst_l_cut, ts_l_cut,
                                       src_utype_l_cut, dst_utype_l_cut,
                                       etype_l, lbls, NUM_NEIGHBORS)
            # Normalise across the candidates of each edge (one row per edge).
            prob = prob.reshape((len(prob) // tiles, tiles))
            prob = prob / prob.sum(1, keepdim=True)
            prob = prob.reshape(len(prob) * tiles)

            ap, auc = evaluate_score(lbls, prob)
            val_ap.append(ap)
            val_auc.append(auc)

            # Explicit device transfer: handing a CUDA tensor to NumPy raises,
            # so the implicit conversion only ever worked on the CPU.
            prob_chunks.append(prob.detach().cpu().numpy())
            etype_chunks.append(etype_l)
            src_chunks.append(src_l_cut)
            dst_chunks.append(dst_l_cut)
            ts_chunks.append(ts_l_cut)
            lab_chunks.append(lbls)

    def _join(chunks):
        # float64 is what the former incremental concatenation onto an empty
        # list produced; kept so downstream merges see the same dtypes.
        if not chunks:
            return np.array([])
        return np.concatenate(chunks).astype(np.float64)

    if val_auc:
        print(np.mean(val_auc), np.mean(val_ap))
    else:
        # Same text as before for an empty run, without NumPy's empty-mean warning.
        print(float('nan'), float('nan'))

    return (_join(prob_chunks), _join(etype_chunks), _join(src_chunks),
            _join(dst_chunks), _join(ts_chunks), _join(lab_chunks))

In [4]:
# Load the graph plus the validation split, then build the neighbour finder
# and the batch sampler used for evaluation.
g, _, _, _, val = loader.load_and_split_data_train_test_val(DATA, N_DIM, E_DIM)
val_ngh_finder = loader.get_neighbor_finder(g, g.max_idx, num_edge_type=g.num_e_type)
val_batch_sampler = MiniBatchSampler(val.e_type_l, BATCH_SIZE, 'test', CLASSES)

model = THAN(
    val_ngh_finder,
    g.n_feat,
    g.e_feat,
    g.e_type_feat,
    g.num_n_type,
    g.num_e_type,
    T_DIM,
    num_layers=NUM_LAYER,
    n_head=NUM_HEADS,
    dropout=DROPOUT,
    device=device,
)
# map_location remaps the stored tensors onto the configured device. Without it
# a checkpoint saved on a GPU cannot be loaded on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

<All keys matched successfully>

In [5]:
# Evaluate on the validation split and gather the flat per-candidate results
# into one DataFrame (one row per edge / candidate edge type pair).
prob, lbls, src_l, dst_l, ts_l, lab = eval_one_epoch(model, val_batch_sampler, val)
preddict = dict(prob=prob, lbl=lbls, src_id=src_l, dst_id=dst_l, ts=ts_l, lab=lab)
# Timestamps are shifted back by one before joining with the raw games —
# presumably undoing an offset added during preprocessing; TODO confirm.
preds = pd.DataFrame(preddict).assign(ts=lambda frame: frame['ts'] - 1)

0.5733506790123457 0.575352410577892


In [6]:
# Raw games with `ts` = game date as seconds since the Unix epoch (float).
# Subtracting the epoch and dividing by one second is independent of the
# datetime resolution pandas infers and of the platform's `int` width, unlike
# the earlier `astype(int) / 10**9`, which assumed int64 nanoseconds.
matches = pd.read_csv('./data/raw/nba/games.csv').assign(
    ts=lambda _d: (pd.to_datetime(_d['game_date']) - pd.Timestamp('1970-01-01'))
    / pd.Timedelta(seconds=1)
)

# Mapping from raw team ids to the ids used in `preds` (column `id`) —
# presumably the graph node ids; verify against the preprocessing step.
team_dict = pd.read_csv('./data/processed/nba/teams_dict.csv')
team_dict = team_dict.rename(columns={'Unnamed: 0': 'team_id', '0': 'id'})

# Attach the mapped id of the home team, then of the away team. The second
# merge collides on `team_id` / `id`, so pandas' default suffixes yield
# `id_x` (home) and `id_y` (away).
matches = pd.merge(matches, team_dict, how='left', left_on='team_id_home', right_on='team_id')
matches = pd.merge(matches, team_dict, how='left', left_on='team_id_away', right_on='team_id')

# matches = pd.read_csv('./data/raw/soccer/matches.csv').assign(ts = lambda _d: (pd.to_datetime(_d['date']).astype(int) / 10**9)).drop(columns=['id'])

# team_dict = pd.read_csv('./data/processed/soccer/teams_dict.csv')
# team_dict = team_dict.rename(columns={'Unnamed: 0': 'team_id', '0': 'id'})

# matches = pd.merge(matches, team_dict, how='left', left_on='home_team_api_id', right_on='team_id')
# matches = pd.merge(matches, team_dict, how='left', left_on='away_team_api_id', right_on='team_id')

In [7]:
# match_pred = pd.merge(matches, preds, left_on=['id_x', 'id_y', 'ts'], right_on=['src_id', 'dst_id', 'ts'])
# match_pred = match_pred.sort_values('prob', ascending=False).drop_duplicates('match_api_id')

# match_pred['label'] = match_pred['home_team_goal'] - match_pred['away_team_goal']
# match_pred.loc[match_pred['label'] > 0, 'label'] = 1
# match_pred.loc[match_pred['label'] == 0, 'label'] = 2
# match_pred.loc[match_pred['label'] < 0, 'label'] = 0

# Join games with their scored candidates on (home id, away id, timestamp).
# Keep only the highest-probability candidate per game, then derive the
# ground-truth label from the home result and flag whether the kept candidate's
# edge type (`lbl`) matches it.
match_pred = (
    matches
    .merge(preds, left_on=['id_x', 'id_y', 'ts'], right_on=['src_id', 'dst_id', 'ts'])
    .sort_values('prob', ascending=False)
    .drop_duplicates('game_id')
    .assign(
        label=lambda frame: frame['wl_home'].map({'W': 1, 'L': 0}),
        correct=lambda frame: (frame['label'] == frame['lbl']).astype(int),
    )
)

In [8]:
# Number of matched games per ground-truth outcome (`label`, mapped from `wl_home`).
match_pred.groupby('label')['correct'].count()

label
0    377
1    495
Name: correct, dtype: int64

In [9]:
# Number of matched games per predicted edge type (`lbl` of the
# highest-probability candidate kept for each game).
match_pred.groupby('lbl')['correct'].count()

lbl
0.0     22
1.0    850
Name: correct, dtype: int64

In [10]:
# Number of correct predictions per ground-truth outcome.
match_pred.groupby('label')['correct'].sum()

label
0     10
1    483
Name: correct, dtype: int64

In [11]:
# Total number of games whose kept candidate matches the ground-truth label.
match_pred['correct'].sum()

493

In [12]:
# Cross-check: sum of `lab` (the per-candidate match flag collected during
# evaluation) over the rows marked correct.
match_pred.loc[match_pred['correct'] == 1, 'lab'].sum()

493.0

In [13]:
# Display the merged per-game prediction table (pandas truncates the view).
match_pred

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,id_x,team_id_y,id_y,prob,lbl,src_id,dst_id,lab,label,correct
1263,3367,22023,1610612740,NOP,New Orleans Pelicans,22300575,2024-01-17,NOP vs. CHA,W,240,...,9,1610612766,35,0.988011,1.0,9.0,35.0,1.0,1,1
769,3120,22023,1610612738,BOS,Boston Celtics,22300319,2023-12-14,BOS vs. CLE,W,240,...,7,1610612739,8,0.985195,1.0,7.0,8.0,1.0,1,1
519,2995,22023,1610612754,IND,Indiana Pacers,22300047,2023-11-24,IND vs. DET,W,240,...,23,1610612765,34,0.985117,1.0,23.0,34.0,1.0,1,1
283,2877,22023,1610612749,MIL,Milwaukee Bucks,22300165,2023-11-08,MIL vs. DET,W,240,...,18,1610612765,34,0.985075,1.0,18.0,34.0,1.0,1,1
469,2970,22023,1610612755,PHI,Philadelphia 76ers,22300040,2023-11-21,PHI vs. CLE,L,265,...,24,1610612739,8,0.981349,1.0,24.0,8.0,0.0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72,2772,22023,1610612759,SAS,San Antonio Spurs,22300073,2023-10-25,SAS vs. DAL,L,240,...,28,1610612742,11,0.500000,0.0,28.0,11.0,1.0,0,1
1606,3540,12023,1610612765,DET,Detroit Pistons,12300005,2023-10-08,DET vs. PHX,L,265,...,34,1610612756,25,0.500000,0.0,34.0,25.0,1.0,0,1
1340,3406,22023,1610612765,DET,Detroit Pistons,22300606,2024-01-22,DET vs. MIL,L,240,...,34,1610612749,18,0.500000,0.0,34.0,18.0,1.0,0,1
540,3006,22023,1610612739,CLE,Cleveland Cavaliers,22300251,2023-11-26,CLE vs. TOR,W,240,...,8,1610612761,30,0.500000,0.0,8.0,30.0,0.0,1,0


In [14]:
# Games evaluated (`count`) and correct predictions (`sum`) per season type.
match_pred.groupby('season_type')['correct'].agg(['count', 'sum'])

Unnamed: 0_level_0,count,sum
season_type,Unnamed: 1_level_1,Unnamed: 2_level_1
Playoffs,31,18
Pre Season,73,46
Regular Season,768,429


In [26]:
# Rows whose ground-truth `label` is 1 (home result 'W').
match_pred.loc[match_pred['label'] == 1]

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,team_id_x,id_x,team_id_y,id_y,prob,lbl,src_id,dst_id,label,correct
182,2827,22023,1610612745,HOU,Houston Rockets,22300122,2023-11-01,HOU vs. CHA,W,240,...,1610612745,14,1610612766,35,0.925131,0.0,14.0,35.0,1,0
1258,3365,22023,1610612757,POR,Portland Trail Blazers,22300578,2024-01-17,POR vs. BKN,W,240,...,1610612757,26,1610612751,20,0.902718,0.0,26.0,20.0,1,0
597,3034,22023,1610612752,NYK,New York Knicks,22300268,2023-11-30,NYK vs. DET,W,240,...,1610612752,21,1610612765,34,0.891393,1.0,21.0,34.0,1,1
1735,3604,12023,1610612742,DAL,Dallas Mavericks,12300070,2023-10-20,DAL vs. DET,W,240,...,1610612742,11,1610612765,34,0.887924,1.0,11.0,34.0,1,1
197,2834,22023,1610612740,NOP,New Orleans Pelicans,22300129,2023-11-02,NOP vs. DET,W,240,...,1610612740,9,1610612765,34,0.886000,1.0,9.0,34.0,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1159,3315,22023,1610612741,CHI,Chicago Bulls,22300523,2024-01-10,CHI vs. HOU,W,265,...,1610612741,10,1610612745,14,0.337758,1.0,10.0,14.0,1,1
4,2738,42022,1610612755,PHI,Philadelphia 76ers,42200214,2023-05-07,PHI vs. BOS,W,265,...,1610612755,24,1610612738,7,0.316121,0.0,24.0,7.0,1,0
1553,3512,22023,1610612739,CLE,Cleveland Cavaliers,22300714,2024-02-05,CLE vs. SAC,W,240,...,1610612739,8,1610612758,27,0.304550,1.0,8.0,27.0,1,1
1184,3328,22023,1610612743,DEN,Denver Nuggets,22300539,2024-01-12,DEN vs. NOP,W,240,...,1610612743,12,1610612740,9,0.304403,0.0,12.0,9.0,1,0


In [20]:
# Split the games into 30 equal-frequency timestamp bins and compute the share
# of correct predictions in each bin.
# No sort is needed before qcut: quantile edges do not depend on row order and
# the column assignment aligns on the index.
match_pred['bin'] = pd.qcut(match_pred['ts'], 30)
# `bin` is categorical; passing `observed` explicitly keeps the current
# behaviour (all bins listed) and avoids the pandas FutureWarning about the
# changing default, which the recorded output suggests was raised here.
binned = match_pred.groupby('bin', observed=False)['correct'].agg(['sum', 'count'])
binned['perf'] = binned['sum'] / binned['count']

  binned = match_pred.groupby('bin')['correct'].agg(['sum', 'count'])


In [21]:
# Display per-bin totals and accuracy (`perf` = sum / count).
binned

Unnamed: 0_level_0,sum,count,perf
bin,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
"(1411775999.999, 1413590400.0]",30,190,0.157895
"(1413590400.0, 1414800000.0]",31,162,0.191358
"(1414800000.0, 1416614400.0]",27,147,0.183673
"(1416614400.0, 1417824000.0]",32,162,0.197531
"(1417824000.0, 1419033600.0]",25,158,0.158228
"(1419033600.0, 1421452800.0]",48,185,0.259459
"(1421452800.0, 1422748800.0]",27,156,0.173077
"(1422748800.0, 1423958400.0]",31,166,0.186747
"(1423958400.0, 1425168000.0]",19,145,0.131034
"(1425168000.0, 1426896000.0]",36,196,0.183673
