In [1]:
import pandas as pd
import torch
import numpy as np
import model.loader as loader

from model.module import THAN
from model.loader import MiniBatchSampler

In [2]:
# --- Checkpoint location -------------------------------------------------
# The saved-model path is derived from the model prefix and dataset name.
MODEL_PREFIX = 'THAN-mem'
DATA = 'nba'
MODEL_SAVE_PATH = f'./saved_models/{MODEL_PREFIX}-{DATA}.pth'

# --- Sizes handed to the data loader (N_DIM, E_DIM) and to THAN (T_DIM) ---
N_DIM = 32
E_DIM = 16
T_DIM = 32

# --- Model construction arguments (num_layers / n_head / dropout) ---------
NUM_LAYER = 1
NUM_HEADS = 4
DROPOUT = 0.1

# --- Evaluation settings ---------------------------------------------------
# NUM_NEIGHBORS is forwarded to model.link_contrast; BATCH_SIZE and CLASSES
# are forwarded to MiniBatchSampler.
NUM_NEIGHBORS = 10
BATCH_SIZE = 200
CLASSES = np.array([1, 2])

# --- Device selection --------------------------------------------------------
# GPU == -1 selects the CPU; any other value is used as the CUDA device index.
GPU = -1
if GPU == -1:
    device = 'cpu'
else:
    device = torch.device('cuda:{}'.format(GPU))

In [3]:
def eval_one_epoch(model: THAN, batch_sampler, data):
    """Score every batch produced by ``batch_sampler`` with gradients disabled.

    The sampler is reset, then polled with ``get_batch_index()`` until it
    returns ``counts`` that is ``None`` or sums to zero. On each poll, every
    index batch is replicated once per entry in ``batches`` and paired with
    each value in ``classes``; the label is 1.0 where the paired class equals
    ``classes[i]`` for the i-th batch and 0.0 elsewhere.

    Args:
        model: THAN model; switched to eval mode before scoring.
        batch_sampler: object exposing ``reset()`` and ``get_batch_index()``.
        data: object exposing ``src_l``, ``dst_l``, ``ts_l``, ``u_type_l`` and
            ``v_type_l``, each indexable by a batch of indices.

    Returns:
        Three lists with one entry per sampler poll: the value returned by
        ``model.link_contrast``, the source-id array and the destination-id
        array that were fed to it.
    """
    probs_per_step, srcs_per_step, dsts_per_step = [], [], []
    with torch.no_grad():
        model = model.eval()
        batch_sampler.reset()
        while True:
            batches, counts, classes = batch_sampler.get_batch_index()
            if counts is None or counts.sum() == 0:
                break

            n_tiles = len(batches)
            total = int(counts.sum() * n_tiles)

            # Six integer buffers plus one float buffer, all of length `total`.
            (src_l_cut, dst_l_cut, ts_l_cut,
             src_utype_l_cut, dst_utype_l_cut, etype_l) = (
                np.empty(total, dtype=int) for _ in range(6)
            )
            lbls = np.empty(total)

            lo = 0
            for cls_pos, batch in enumerate(batches):
                # NOTE(review): the slice width relies on counts[cls_pos]
                # matching len(batch) and len(classes) matching n_tiles --
                # confirm against MiniBatchSampler.
                hi = lo + int(counts[cls_pos] * n_tiles)
                window = slice(lo, hi)

                src_l_cut[window] = np.tile(data.src_l[batch], n_tiles)
                dst_l_cut[window] = np.tile(data.dst_l[batch], n_tiles)
                ts_l_cut[window] = np.tile(data.ts_l[batch], n_tiles)
                src_utype_l_cut[window] = np.tile(data.u_type_l[batch], n_tiles)
                dst_utype_l_cut[window] = np.tile(data.v_type_l[batch], n_tiles)

                # Each class value repeated len(batch) times, class by class.
                candidate_types = np.repeat(classes, len(batch))
                etype_l[window] = candidate_types
                is_own_class = candidate_types == classes[cls_pos]
                lbls[window] = is_own_class.astype(np.float64)

                lo = hi

            step_prob = model.link_contrast(src_l_cut, dst_l_cut, ts_l_cut,
                                            src_utype_l_cut, dst_utype_l_cut,
                                            etype_l, lbls, NUM_NEIGHBORS)
            probs_per_step.append(step_prob)
            srcs_per_step.append(src_l_cut)
            dsts_per_step.append(dst_l_cut)

    return probs_per_step, srcs_per_step, dsts_per_step

In [4]:
# Load the dataset named by DATA; only the first and last of the five
# returned values are used in this notebook.
g, _, _, _, val = loader.load_and_split_data_train_test_val(
    DATA,
    N_DIM,
    E_DIM,
)

# Neighbor finder built from `g`, bounded by its max index and told how many
# edge types it has.
val_ngh_finder = loader.get_neighbor_finder(
    g,
    g.max_idx,
    num_edge_type=g.num_e_type,
)

# Sampler over the edge types of `val`, constructed in 'test' mode.
val_batch_sampler = MiniBatchSampler(
    val.e_type_l,
    BATCH_SIZE,
    'test',
    CLASSES,
)

In [5]:
# Build the model with the same configuration constants used for the saved
# checkpoint, then restore its weights.
model = THAN(
    val_ngh_finder,
    g.n_feat,
    g.e_feat,
    g.e_type_feat,
    g.num_n_type,
    g.num_e_type,
    T_DIM,
    num_layers=NUM_LAYER,
    n_head=NUM_HEADS,
    dropout=DROPOUT,
    device=device,
)
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written from a CUDA device still loads when running on CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))

<All keys matched successfully>

In [6]:
# Run the evaluation loop with the sampler built over `val`; each returned
# value is a list with one entry per sampler iteration.
prob, src_l, dst_l = eval_one_epoch(model, val_batch_sampler, val)

test batch 6/5	

In [8]:
# Display the full list of tensors returned by eval_one_epoch.
# NOTE(review): this prints every element; consider showing a slice or
# summary instead to keep the notebook output small.
prob

[tensor([0.4545, 0.4512, 0.4416, 0.4393, 0.4471, 0.4546, 0.4571, 0.4613, 0.4437,
         0.4529, 0.4542, 0.4542, 0.4520, 0.4539, 0.4467, 0.4335, 0.4451, 0.4575,
         0.4441, 0.4372, 0.4508, 0.4471, 0.4528, 0.4388, 0.4426, 0.4601, 0.4480,
         0.4526, 0.4318, 0.4518, 0.4560, 0.4501, 0.4432, 0.4368, 0.4355, 0.4299,
         0.4541, 0.4413, 0.4478, 0.4367, 0.4594, 0.4431, 0.4263, 0.4544, 0.4409,
         0.4487, 0.4374, 0.4575, 0.4318, 0.4510, 0.4597, 0.4554, 0.4469, 0.4457,
         0.4279, 0.4479, 0.4469, 0.4360, 0.4261, 0.4634, 0.4451, 0.4528, 0.4353,
         0.4331, 0.4461, 0.4427, 0.4366, 0.4397, 0.4556, 0.4495, 0.4469, 0.4422,
         0.4353, 0.4307, 0.4453, 0.4344, 0.4376, 0.4421, 0.4383, 0.4505, 0.5473,
         0.5364, 0.5495, 0.5492, 0.5524, 0.5406, 0.5440, 0.5561, 0.5580, 0.5525,
         0.5526, 0.5563, 0.5495, 0.5488, 0.5482, 0.5586, 0.5536, 0.5616, 0.5570,
         0.5503, 0.5422, 0.5618, 0.5418, 0.5570, 0.5547, 0.5627, 0.5446, 0.5497,
         0.5525, 0.5422, 0.5

In [3]:
# Load the raw NBA games table from CSV into a DataFrame.
# NOTE(review): the displayed frame has 'Unnamed: 0.1' and 'Unnamed: 0'
# columns -- presumably saved index columns; confirm before relying on them.
matches = pd.read_csv('./data/raw/nba/games.csv')

In [4]:
# Display the loaded games table (pandas elides the middle rows/columns).
matches

Unnamed: 0.1,Unnamed: 0,season_id,team_id_home,team_abbreviation_home,team_name_home,game_id,game_date,matchup_home,wl_home,min,...,reb_away,ast_away,stl_away,blk_away,tov_away,pf_away,pts_away,plus_minus_away,video_available_away,season_type
0,0,22021,1610612749,MIL,Milwaukee Bucks,22100001,2021-10-19,MIL vs. BKN,W,240,...,44,19,3,9,13,17,104,-23,1,Regular Season
1,1,22021,1610612747,LAL,Los Angeles Lakers,22100002,2021-10-19,LAL vs. GSW,L,240,...,50,30,9,2,17,18,121,7,1,Regular Season
2,2,22021,1610612740,NOP,New Orleans Pelicans,22100009,2021-10-20,NOP vs. PHI,L,240,...,47,24,9,5,13,11,117,20,1,Regular Season
3,3,22021,1610612762,UTA,Utah Jazz,22100011,2021-10-20,UTA vs. OKC,W,240,...,50,19,4,2,15,15,86,-21,1,Regular Season
4,4,22021,1610612750,MIN,Minnesota Timberwolves,22100008,2021-10-20,MIN vs. HOU,W,240,...,41,21,13,3,24,25,106,-18,1,Regular Season
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3604,3604,12023,1610612742,DAL,Dallas Mavericks,12300070,2023-10-20,DAL vs. DET,W,240,...,58,29,5,6,15,20,104,-10,1,Pre Season
3605,3605,12023,1610612761,TOR,Toronto Raptors,12300069,2023-10-20,TOR vs. WAS,W,240,...,42,22,10,4,10,20,98,-36,1,Pre Season
3606,3606,12023,1610612745,HOU,Houston Rockets,12300071,2023-10-20,HOU vs. MIA,W,240,...,39,22,10,5,21,27,104,-6,1,Pre Season
3607,3607,12023,1610612744,GSW,Golden State Warriors,12300073,2023-10-20,GSW vs. SAS,L,240,...,43,36,12,6,12,15,122,5,1,Pre Season
