In [11]:
import pandas as pd
import numpy as np
import torch

import model.utils as utils
import model.loader as loader

from model.loader import MiniBatchSampler
from driver import Driver
from tqdm import tqdm

In [12]:
# ---------------------------------------------------------------------------
# Experiment configuration: every tunable value of the notebook lives here.
# ---------------------------------------------------------------------------

# Dataset identifier; passed to the loader and embedded in the run/file names.
DATA = 'lol'

# CUDA device index; the device-selection cell treats -1 as "use the CPU".
GPU = 0

# Dimensions handed to the loader (N_DIM, E_DIM) and to the Driver (T_DIM).
# NOTE(review): presumably node-, edge- and time-encoding sizes - confirm.
N_DIM = 32
E_DIM = 16
T_DIM = 32

# Model hyper-parameters forwarded to the Driver.
N_LAYER = 1
N_HEAD = 4
N_DEGREE = 10
DROPOUT = 0.1

# Flag forwarded to every neighbor finder.
UNIFORM = False

# Optimisation settings.
BATCH_SIZE = 200
EPOCHS = 100
LEARNING_RATE = 0.001
BETA = 0.01  # forwarded to the Driver; NOTE(review): meaning not visible here


# Where the Driver is told to store the model.
MODEL_SAVE_PATH = f'./saved_models/experiment-{DATA}.pth'

In [13]:
# Logger named after the dataset and batch size, so runs can be told apart.
logger = utils.get_logger(f"experiment_{DATA}_bs{BATCH_SIZE}")

# Fixed seed so repeated runs of the notebook are comparable.
utils.set_random_seed(2022)

INFO:root:



In [14]:
# Load the dataset and unpack the two graph objects, the three splits and
# p_classes from the loader.
(
    g,
    g_val,
    train,
    val,
    test,
    p_classes,
) = loader.load_and_split_data_train_test_val(DATA, N_DIM, E_DIM)

# One neighbor finder per phase. Each is built from a different object
# (train split, g_val, full g) but all share g.max_idx and g.num_e_type.
train_ngh_finder = loader.get_neighbor_finder(
    train, g.max_idx, UNIFORM, num_edge_type=g.num_e_type
)
val_ngh_finder = loader.get_neighbor_finder(
    g_val, g.max_idx, UNIFORM, num_edge_type=g.num_e_type
)
# NOTE(review): the edge-type count is passed positionally here but by keyword
# above - confirm the fourth positional parameter is num_edge_type.
test_ngh_finder = loader.get_neighbor_finder(
    g, g.max_idx, UNIFORM, g.num_e_type
)

# One mini-batch sampler per split, each built from that split's e_type_l.
train_batch_sampler = MiniBatchSampler(
    train.e_type_l, BATCH_SIZE, 'train', p_classes
)
val_batch_sampler = MiniBatchSampler(
    val.e_type_l, BATCH_SIZE, 'val', p_classes
)
test_batch_sampler = MiniBatchSampler(
    test.e_type_l, BATCH_SIZE, 'test', p_classes
)

In [15]:
# Resolve the compute device. GPU == -1 selects the CPU, any other value is
# used as a CUDA device index. The result is always a torch.device (the CPU
# branch used to yield a bare 'cpu' string), so downstream code sees one type.
device = torch.device(f'cuda:{GPU}' if GPU != -1 else 'cpu')

# Build the experiment driver from the graphs, splits, neighbor finders,
# batch samplers and hyper-parameters prepared in the cells above.
# NOTE(review): the meaning of the trailing `None` argument is not visible
# here - check Driver's signature.
driver = Driver(g, g_val, train, val, test, p_classes, train_ngh_finder,
                val_ngh_finder, test_ngh_finder, train_batch_sampler,
                val_batch_sampler, test_batch_sampler, device, T_DIM,
                N_LAYER, N_HEAD, DROPOUT, N_DEGREE,
                BETA, LEARNING_RATE, MODEL_SAVE_PATH, None)

INFO:model.module:Aggregation uses attention model
INFO:model.module:Using time encoding


In [16]:
# Run the driver for EPOCHS epochs. It returns six values; the first three are
# discarded and the last three are kept as train_acc_l, test_acc_l and loss_l.
# NOTE(review): binding `_` at notebook level shadows IPython's last-output
# variable for the rest of the session.
_, _, _, train_acc_l, test_acc_l, loss_l = driver.eval_epochs(EPOCHS)

100%|██████████| 100/100 [27:21<00:00, 16.41s/it]


In [17]:
# Sliding-window evaluation: for each pair of consecutive time steps, rebuild
# the driver's train/test data for that window, reset the model, retrain it
# and record the test result.
#
# The per-window splits and curves use window-local names so that the
# notebook-level `train`, `test`, `train_acc_l` and `loss_l` from earlier
# cells are not overwritten. Previously, re-running this cell fed the last
# window's `test` into get_time_steps.
time_steps = loader.get_time_steps(test, p_classes, 10)

# NOTE(review): np.argmax returns a 0-based index, which is then passed to
# train_window as its argument - confirm this is not off by one if an epoch
# count is expected.
# NOTE(review): the epoch is selected on test accuracy; consider using
# validation accuracy to avoid tuning on the test split.
best_epoch = np.argmax(test_acc_l)

window_acc = []
window_corr = []
for i in tqdm(range(len(time_steps) - 1)):
    window_train, window_test = loader.split_data_window(g, time_steps[i], time_steps[i + 1])

    # Point the driver at this window's data, neighbor finder and samplers.
    driver.train = window_train
    driver.test = window_test
    driver.train_ngh_finder = loader.get_neighbor_finder(window_train, g.max_idx, UNIFORM, num_edge_type=g.num_e_type)
    driver.train_batch_sampler = MiniBatchSampler(window_train.e_type_l, BATCH_SIZE, 'train', p_classes)
    driver.test_batch_sampler = MiniBatchSampler(window_test.e_type_l, BATCH_SIZE, 'test', p_classes)

    # Reset the model before training on this window.
    driver.reset_model()

    # Only memory_backup is used below; the per-window curves are kept under
    # their own names for inspection.
    window_train_acc_l, window_loss_l, memory_backup = driver.train_window(best_epoch)
    test_acc, corr = driver.test_window(memory_backup)
    window_acc.append(test_acc)
    window_corr.append(corr)

  0%|          | 0/10 [00:00<?, ?it/s]

INFO:model.module:Aggregation uses attention model
INFO:model.module:Using time encoding
100%|██████████| 27/27 [07:51<00:00, 17.47s/it]
 10%|█         | 1/10 [07:56<1:11:24, 476.05s/it]INFO:model.module:Aggregation uses attention model
INFO:model.module:Using time encoding
100%|██████████| 27/27 [07:32<00:00, 16.77s/it]
 20%|██        | 2/10 [15:34<1:02:04, 465.51s/it]INFO:model.module:Aggregation uses attention model
INFO:model.module:Using time encoding
100%|██████████| 27/27 [08:21<00:00, 18.56s/it]
 30%|███       | 3/10 [24:00<56:30, 484.35s/it]  INFO:model.module:Aggregation uses attention model
INFO:model.module:Using time encoding
100%|██████████| 27/27 [08:44<00:00, 19.44s/it]
 40%|████      | 4/10 [32:51<50:15, 502.60s/it]INFO:model.module:Aggregation uses attention model
INFO:model.module:Using time encoding
100%|██████████| 27/27 [08:37<00:00, 19.18s/it]
 50%|█████     | 5/10 [41:34<42:29, 509.93s/it]INFO:model.module:Aggregation uses attention model
INFO:model.module:Using

In [18]:
# Show the test accuracy recorded for each window, in window order.
window_acc

[0.5926767,
 0.57349616,
 0.56185085,
 0.5527931,
 0.52106994,
 0.5532411,
 0.529576,
 0.52678776,
 0.52122283,
 0.5007017]

In [19]:
# Flatten the per-window results into a single NumPy array, moving every
# element to host memory first.
corr = np.array([
    item.cpu()
    for window_items in window_corr
    for item in window_items
])

In [20]:
# Sum of the flattened `corr` entries over all windows, divided by their count.
np.sum(corr) / len(corr)

0.5480173312126364