In [1]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [2]:
# ---- Experiment configuration -------------------------------------------
# NOTE(review): absolute, machine-specific path; adjust for your environment.
data_dir = '/datadrive/data_cs/'
batch_size = 256   # number of seed papers sampled per batch (passed to pf_sample)
batch_num  = 128   # batch_num // 4 sampling jobs are scheduled per epoch
epoch_num  = 100   # number of training epochs
samp_num   = 7     # not referenced by the cells visible in this notebook

device = torch.device("cuda:0")
# Use a context manager so the file handle is closed once the graph is loaded.
# NOTE: dill (like pickle) executes arbitrary code on load; only load trusted files.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [23]:
def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None):
    '''
        Sample a sub-graph by budget-based expansion around a set of seed nodes.

        Sampled nodes are stored per node type in layer_data as
        {node_id: [ser, time]}, where ser is the node's position inside its type.
        Candidate nodes are stored per node type in budget as
        {node_id: [sampled_score, time]}.
        After sampling, the adjacency of the sampled nodes is rebuilt from
        graph.edge_list.

        Args:
            graph:          object exposing node_feature (per-type tables indexed by
                            node id) and edge_list
                            (target_type -> source_type -> relation -> target_id -> {source_id: time}).
            time_range:     container of admissible timestamps (tested with `in`).
            sampled_depth:  number of expansion rounds.
            sampled_number: max nodes sampled per type per round, and max neighbours
                            inspected per (node, relation).
            inp:            optional {node_type: iterable of (node_id, time)} seeds.
                            If None, seed papers are drawn from one random timestamp.

        Returns:
            feature:   {node_type: 2-D float array, one row per sampled node}
            times:     {node_type: list of node times}
            edge_list: target_type -> source_type -> relation -> [[target_ser, source_ser], ...]
            indxs:     {node_type: list of original node ids, in ser order}
    '''
    layer_data  = defaultdict( #target_type
                        lambda: {} # {target_id: [ser, time]}
                    )
    budget     = defaultdict( #source_type
                                    lambda: defaultdict(  #source_id
                                        lambda: [0., 2000] #sampled_score, source_time
                            ))
    '''
        For each node being sampled, we find all its neighbours and add their
        normalized degree contribution to the budget.
        Some nodes have very many neighbours (such as fields, venues); for those
        we only consider a random subset of at most sampled_number neighbours.
    '''
    def add_budget(te, target_id, target_time, layer_data, budget):
        for source_type in te:
            tes = te[source_type]
            for relation_type in tes:
                if relation_type == 'self':
                    continue
                adl = tes[relation_type][target_id]
                if len(adl) < sampled_number:
                    sampled_ids = list(adl.keys())
                else:
                    sampled_ids = np.random.choice(list(adl.keys()), sampled_number, replace = False)
                for source_id in sampled_ids:
                    source_time = adl[source_id]
                    # Edges without a timestamp inherit the time of the target node.
                    if source_time is None:
                        source_time = target_time
                    '''
                        If the node's time is out of range or already being sampled, skip
                        Otherwise, accumulate the normalized degree.
                    '''
                    if source_time not in time_range or source_id in layer_data[source_type]:
                        continue
                    budget[source_type][source_id][0] += 1. / len(sampled_ids)
                    budget[source_type][source_id][1] = source_time

    if inp is None:
        _time = np.random.choice(list(time_range.keys()))
        res = graph.node_feature['paper'][graph.node_feature['paper']['time'] == _time]
        sampn = min(len(res), sampled_number)
        rand_paper_ids  = np.random.choice(list(res.index), sampn, replace = False)
        '''
            First adding the sampled nodes then updating budget.
        '''
        for _id in rand_paper_ids:
            layer_data['paper'][_id] = [len(layer_data['paper']), _time]
        for _id in rand_paper_ids:
            add_budget(graph.edge_list['paper'], _id, _time, layer_data, budget)
    else:
        '''
            First adding the sampled nodes then updating budget.
        '''
        for _type in inp:
            for _id, _time in inp[_type]:
                layer_data[_type][_id] = [len(layer_data[_type]), _time]
        for _type in inp:
            te = graph.edge_list[_type]
            for _id, _time in inp[_type]:
                add_budget(te, _id, _time, layer_data, budget)
    '''
        We repeatedly expand the sampled graph, sampled_depth times.
        Each time we sample a fixed number of nodes for each budget,
        based on the accumulated degree.
    '''
    for layer in range(sampled_depth):
        # Snapshot the types: add_budget may register new types while we iterate.
        sts = list(budget.keys())
        for source_type in sts:
            te = graph.edge_list[source_type]
            keys  = np.array(list(budget[source_type].keys()))
            if len(keys) == 0:
                # A budget drained in an earlier round stays registered but empty;
                # there is nothing to sample (and a 1-D `vals` could not be sliced).
                continue
            vals  = np.array(list(budget[source_type].values()))
            if sampled_number > len(keys):
                '''
                    Directly sample all the nodes
                '''
                sampled_ids = np.arange(len(keys))
            else:
                '''
                    Sample based on accumulated degree
                '''
                score = vals[:,0] ** 2
                score = score / np.sum(score)
                sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False) 
            sampled_keys = keys[sampled_ids]
            sampled_tims = vals[sampled_ids][:, 1]
            '''
                First adding the sampled nodes then updating budget.
            '''
            for k, t in zip(sampled_keys, sampled_tims):
                layer_data[source_type][k] = [len(layer_data[source_type]), t]
            for k, t in zip(sampled_keys, sampled_tims):
                add_budget(te, int(k), int(t), layer_data, budget)
                budget[source_type].pop(k)    
    '''
        Prepare feature, time and adjacency matrix for the sampled graph
    '''
    feature = {}
    times   = {}
    indxs   = {}
    for _type in layer_data:
        idxs = list(layer_data[_type].keys())
        tims = [entry[1] for entry in layer_data[_type].values()]
        if 'node_emb' in graph.node_feature[_type]:
            # Read the embedding from this type's own table (not a fixed one).
            feature[_type] = np.array(list(graph.node_feature[_type].loc[idxs, 'node_emb']), dtype=float)
        else:
            # Types without a 'node_emb' column get a 400-wide zero placeholder.
            feature[_type] = np.zeros([len(layer_data[_type]), 400])
        # Append the 'emb' column and one log10(citation + 0.01) column.
        feature[_type] = np.concatenate((feature[_type], list(graph.node_feature[_type].loc[idxs, 'emb']),\
            np.log10(np.array(list(graph.node_feature[_type].loc[idxs, 'citation'])).reshape(-1, 1) + 0.01)), axis=1)
        indxs[_type] = idxs
        times[_type] = tims
        
    edge_list = defaultdict( #target_type
                        lambda: defaultdict(  #source_type
                            lambda: defaultdict(  #relation_type
                                lambda: [] # [target_id, source_id] 
                                    )))
    # Every sampled node gets a self-loop.
    for _type in layer_data:
        for _key in layer_data[_type]:
            _ser = layer_data[_type][_key][0]
            edge_list[_type][_type]['self'] += [[_ser, _ser]]
    '''
        Reconstruct sampled adjacency matrix by checking whether each
        link exists in the original graph
    '''
    for target_type in graph.edge_list:
        te = graph.edge_list[target_type]
        for source_type in te:
            tes = te[source_type]
            for relation_type in tes:
                # These relations are left out of the sampled adjacency.
                if relation_type in ['APA_coauthor', 'rev_APV_in', 'IPI_coauthor', 'APV_in']:
                    continue
                tesr = tes[relation_type]
                for target_key in layer_data[target_type]:
                    target_ser = layer_data[target_type][target_key][0]
                    tesrt = tesr[target_key]
                    for source_key in layer_data[source_type]:
                        source_ser = layer_data[source_type][source_key][0]
                        '''
                            Check whether each link (target_id, source_id) exists in original adjacency matrix
                        '''
                        if source_key in tesrt:
                            edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]]
    return feature, times, edge_list, indxs

In [7]:
# Year-based split of the graph's timestamps: train < 2015, valid 2015-2016, test > 2016.
# Each range is a {year: True} lookup table; untimed (None) entries are left out.
train_range = dict.fromkeys((t for t in graph.times if t is not None and t < 2015), True)
valid_range = dict.fromkeys((t for t in graph.times if t is not None and 2015 <= t <= 2016), True)
test_range  = dict.fromkeys((t for t in graph.times if t is not None and t > 2016), True)

In [8]:
def pf_sample(seed, papers, pairs, t_range, batch_size, test = False):
    '''
        Build one paper-field prediction batch.

        Draws up to batch_size papers from one random timestamp, samples a
        sub-graph around them and their labelled fields, removes the
        paper<->field edges that encode the labels from the sampled graph,
        and converts everything to tensors.

        Args:
            seed:       value passed to np.random.seed (makes the worker deterministic).
            papers:     {time: array-like of paper ids}.
            pairs:      {paper_id: list of field ids}.
            t_range:    admissible timestamps, forwarded to sample_subgraph.
            batch_size: upper bound on the number of seed papers.
            test:       accepted for call-site compatibility; not used in this function.

        Returns:
            node_feature, node_type, edge_time, edge_index, edge_type (from to_torch),
            field_ids and paper_ids (positions of the seed nodes in the tensor layout),
            ylabel (one row per paper over cand_list, each row normalized to sum to 1).
    '''
    np.random.seed(seed)
    _time = np.random.choice(list(papers.keys()))
    sampn = min(len(papers[_time]), batch_size)
    pids = np.array(papers[_time])[np.random.choice(len(papers[_time]), sampn, replace = False)]
    # field id -> position in first-seen order; a dict avoids the quadratic
    # `f_id not in fids` / `fids.index(f_id)` scans over a list.
    field_pos = {}
    edge = defaultdict(lambda: {})
    for x_id, p_id in enumerate(pids):
        for f_id in pairs[p_id]:
            edge[x_id][field_pos.setdefault(f_id, len(field_pos))] = True
    fids = list(field_pos)
    pids = np.stack([pids, np.repeat([_time], sampn)]).T
    fids = np.stack([fids, np.repeat([_time], len(fids))]).T
 
    feature, times, edge_list, _ = sample_subgraph(graph, t_range, \
                inp = {'paper': pids, 'field': fids}, sampled_depth = 3, sampled_number = 100)

    # Drop the label edges in both directions. Seed nodes are inserted first by
    # sample_subgraph, so their ser equals their position in pids / fids.
    edge_list['paper']['field']['rev_PF_in_L2'] = [
        i for i in edge_list['paper']['field']['rev_PF_in_L2']
        if not (i[0] in edge and i[1] in edge[i[0]])
    ]
    edge_list['field']['paper']['PF_in_L2'] = [
        i for i in edge_list['field']['paper']['PF_in_L2']
        if not (i[1] in edge and i[0] in edge[i[1]])
    ]
    
    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list, graph)
    '''
        Trace the paper_id and field_id by its own index plus the type start index
    '''
    paper_ids = np.arange(len(pids)) + node_dict['paper'][0]
    field_ids = np.arange(len(fids)) + node_dict['field'][0]
    # Column lookup for the label matrix (first occurrence wins, like list.index).
    cand_index = {}
    for col, cand in enumerate(cand_list):
        cand_index.setdefault(cand, col)
    ylabel = torch.zeros(sampn, len(cand_list))
    for x_id, p_id in enumerate(pids[:,0]):
        for f_id in pairs[p_id]:
            ylabel[x_id][cand_index[f_id]] = 1
    ylabel /= ylabel.sum(axis=1).view(-1, 1)
    return node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel
    
def prepare_data(pool, process_ids):
    '''
        Schedule asynchronous pf_sample jobs on the given pool.

        One training job is submitted for every entry of process_ids except the
        last, followed by a single validation job. Each job gets its own random
        seed. Returns the list of async results, validation job last.
    '''
    train_jobs = [
        pool.apply_async(pf_sample, args=(np.random.randint(2**32 - 1), train_papers,
                                          train_pairs, train_range, batch_size))
        for _ in process_ids[:-1]
    ]
    valid_job = pool.apply_async(pf_sample, args=(np.random.randint(2**32 - 1), valid_papers,
                                                  valid_pairs, valid_range, batch_size))
    return train_jobs + [valid_job]

In [9]:
class GNN(nn.Module):
    '''
        Heterogeneous GNN encoder: a per-node-type linear projection followed by
        a stack of RAGCNConv layers.
    '''
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        super().__init__()
        self.in_dim    = in_dim
        self.n_hid     = n_hid
        self.num_types = num_types
        # Sub-modules are registered in the order gcs, aggregat_ws, drop.
        self.gcs         = nn.ModuleList()
        # One input projection (in_dim -> n_hid) per node type.
        self.aggregat_ws = nn.ModuleList([nn.Linear(in_dim, n_hid) for _ in range(num_types)])
        self.drop        = nn.Dropout(dropout)
        self.gcs.extend([RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout)
                         for _ in range(n_layers)])

    def set_device(self, device):
        '''Record the device on this module and on every convolution layer.'''
        self.device = device
        for conv in self.gcs:
            conv.device = device

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        '''
            Project each node with its type-specific layer (tanh), apply dropout,
            then run the result through every convolution layer in order.
        '''
        hidden = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device)
        for t_id in range(self.num_types):
            mask = (node_type == t_id)
            # Only project types that actually occur in this batch.
            if mask.sum() != 0:
                hidden[mask] = torch.tanh(self.aggregat_ws[t_id](node_feature[mask]))
        meta_xs = self.drop(hidden)
        del hidden
        for conv in self.gcs:
            meta_xs = conv(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs

In [10]:
'''
Paper-Field
'''
paper_ser = {}

# {paper_id: [field_id, ...]} per split.
train_pairs = {}
valid_pairs = {}
test_pairs  = {}

# {time: {paper_id: True}} per split; converted to {time: array of paper ids} below.
train_papers = {_time: {} for _time in train_range}
valid_papers = {_time: {} for _time in valid_range}
test_papers  = {_time: {} for _time in test_range}

# (time range, pairs, papers) for each split; the first matching range wins.
_splits = (
    (train_range, train_pairs, train_papers),
    (valid_range, valid_pairs, valid_papers),
    (test_range,  test_pairs,  test_papers),
)

for f_id, _field_papers in graph.edge_list['field']['paper']['PF_in_L2'].items():
    for p_id, _time in _field_papers.items():
        # An edge whose time is in none of the ranges (e.g. None) is skipped;
        # sending it to the test split would fail on test_papers[_time].
        for _range, _pairs, _papers in _splits:
            if _time in _range:
                _pairs.setdefault(p_id, []).append(f_id)
                _papers[_time][p_id] = True
                break

# Drop timestamps with fewer than half a batch of papers and turn the
# remaining per-time paper sets into arrays that pf_sample can index.
for _papers in (train_papers, valid_papers, test_papers):
    for _time in list(_papers.keys()):
        if len(_papers[_time]) < batch_size // 2:
            _papers.pop(_time)
        else:
            _papers[_time] = np.array(list(_papers[_time].keys()))

In [11]:
types = graph.get_types()
# Candidate label set: every field id that has at least one PF_in_L2 entry.
cand_list = list(graph.edge_list['field']['paper']['PF_in_L2'].keys())
# Input width = paper 'emb' length + 401, matching sample_subgraph's features:
# a 400-wide node_emb (or zero) block plus one log-citation column.
gnn = GNN(
    in_dim = len(graph.node_feature['paper']['emb'][0]) + 401,
    n_hid = 256,
    num_types = len(types),
    # presumably the extra relation is the 'self' loop added by sample_subgraph; TODO confirm
    num_relations = len(graph.get_meta_graph()) + 1,
    n_heads = 8,
    n_layers = 3,
).to(device)
# gnn = torch.load('../pre-train/save/cpc_model.pt').to(device)
classifier = Classifier(256, len(cand_list)).to(device)
model = nn.Sequential(gnn, classifier)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)

In [24]:
# ---- Training loop -------------------------------------------------------
# Each epoch consumes the batches prepared by the worker pool during the
# previous epoch and (except on the last epoch) schedules the next ones.
stats = []                                  # [train_loss, valid_loss] per epoch
process_ids = np.arange(batch_num // 4)
train_step = 1500                           # scheduler position before the first update
best_val   = 0                              # best validation NDCG so far
# nn.KLDivLoss expects log-probabilities as input;
# presumably Classifier emits log_softmax outputs. TODO confirm
criterion = nn.KLDivLoss(reduction='batchmean')

pool = mp.Pool(4)
st = time.time()
jobs = prepare_data(pool, process_ids)
try:
    for epoch in np.arange(epoch_num)+1:
        # Prepare training and validation data.
        train_data = [job.get() for job in jobs[:-1]]
        valid_data = jobs[-1].get()
        pool.close()
        pool.join()
        pool = None
        # Prefetch the next epoch only if there is one; otherwise a fresh pool
        # would be left running with results nobody collects.
        if epoch < epoch_num:
            pool = mp.Pool(4)
            jobs = prepare_data(pool, process_ids)
        et = time.time()
        print('Data Preparation: %.1fs' % (et - st))
        
        model.train()
        train_losses = []
        torch.cuda.empty_cache()
        # Every prepared batch is used twice per epoch.
        for batch in np.arange(2):
            for node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel in train_data:
                node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                       edge_time.to(device), edge_index.to(device), edge_type.to(device))
                res  = classifier.forward(node_rep[paper_ids])
                loss = criterion(res, ylabel.to(device))
                optimizer.zero_grad() 
                torch.cuda.empty_cache()
                loss.backward()
                
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
                optimizer.step()
                train_losses += [loss.cpu().detach().tolist()]
                train_step += 1
                # The schedule is driven by the explicit step counter (starts at 1500).
                scheduler.step(train_step)
                del res, loss
        # Validation.
        model.eval()
        with torch.no_grad():
            node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = valid_data
            node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                       edge_time.to(device), edge_index.to(device), edge_type.to(device))
            res  = classifier.forward(node_rep[paper_ids])
            loss = criterion(res, ylabel.to(device))
            valid_res = []

            # ylabel lives on the CPU, so bring the ranking indices back before indexing.
            for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
                valid_res += [ai[bi].tolist()]
            valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])
            if valid_ndcg > best_val:
                best_val = valid_ndcg
                torch.save(model, './save/rgt_1.pt')
            st = time.time()
            print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
                  (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), loss.cpu().detach().tolist(),\
                  valid_ndcg))
            stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
            del res, loss
            if epoch % 5 == 0:
                # Periodic test on one freshly sampled test batch.
                node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                               test_pairs, test_range, batch_size, test=True)
                paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                          edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
                res  = classifier.forward(paper_rep)
                test_res = []
                for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
                    test_res += [ai[bi].tolist()]
                test_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in test_res])
                print(test_ndcg)
                del res
        del train_data, valid_data
finally:
    # Make sure no worker processes outlive the cell (e.g. after an interrupt).
    if pool is not None:
        pool.terminate()
        pool.join()

Data Preparation: 44.4s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 1 (102.2s)  LR: 0.00060 Train Loss: 8.48  Valid Loss: 6.82  Valid NDCG: 0.2055
Data Preparation: 2.7s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2 (107.8s)  LR: 0.00069 Train Loss: 6.50  Valid Loss: 6.46  Valid NDCG: 0.2064
Data Preparation: 2.7s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 3 (105.4s)  LR: 0.00078 Train Loss: 6.34  Valid Loss: 6.46  Valid NDCG: 0.2179
Data Preparation: 2.8s
Epoch: 4 (105.6s)  LR: 0.00085 Train Loss: 6.17  Valid Loss: 6.55  Valid NDCG: 0.2124
Data Preparation: 2.6s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 5 (106.7s)  LR: 0.00091 Train Loss: 6.08  Valid Loss: 6.31  Valid NDCG: 0.2215
0.2065509787238171
Data Preparation: 10.7s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 6 (104.8s)  LR: 0.00096 Train Loss: 5.97  Valid Loss: 6.21  Valid NDCG: 0.2252
Data Preparation: 2.6s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 7 (105.1s)  LR: 0.00099 Train Loss: 5.90  Valid Loss: 6.09  Valid NDCG: 0.2390
Data Preparation: 2.6s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 8 (105.1s)  LR: 0.00100 Train Loss: 5.76  Valid Loss: 6.17  Valid NDCG: 0.2411
Data Preparation: 2.9s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 9 (108.3s)  LR: 0.00099 Train Loss: 5.71  Valid Loss: 6.00  Valid NDCG: 0.2501
Data Preparation: 2.8s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 10 (105.1s)  LR: 0.00096 Train Loss: 5.64  Valid Loss: 6.03  Valid NDCG: 0.2534
0.24316243173578475
Data Preparation: 11.0s
Epoch: 11 (105.8s)  LR: 0.00092 Train Loss: 5.66  Valid Loss: 6.04  Valid NDCG: 0.2496
Data Preparation: 2.6s
Epoch: 12 (102.5s)  LR: 0.00086 Train Loss: 5.44  Valid Loss: 6.05  Valid NDCG: 0.2385
Data Preparation: 2.5s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 13 (105.2s)  LR: 0.00079 Train Loss: 5.48  Valid Loss: 5.89  Valid NDCG: 0.2626
Data Preparation: 2.6s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 14 (104.4s)  LR: 0.00070 Train Loss: 5.53  Valid Loss: 5.80  Valid NDCG: 0.2695
Data Preparation: 2.5s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 15 (102.5s)  LR: 0.00061 Train Loss: 5.39  Valid Loss: 5.67  Valid NDCG: 0.2711
0.24418578938370966
Data Preparation: 11.3s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 16 (106.3s)  LR: 0.00051 Train Loss: 5.43  Valid Loss: 5.71  Valid NDCG: 0.2736
Data Preparation: 2.9s
Epoch: 17 (106.1s)  LR: 0.00042 Train Loss: 5.37  Valid Loss: 5.65  Valid NDCG: 0.2727
Data Preparation: 3.0s
Epoch: 18 (106.3s)  LR: 0.00032 Train Loss: 5.36  Valid Loss: 5.68  Valid NDCG: 0.2639
Data Preparation: 2.8s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 19 (107.0s)  LR: 0.00024 Train Loss: 5.35  Valid Loss: 5.52  Valid NDCG: 0.2889
Data Preparation: 3.0s
Epoch: 20 (105.5s)  LR: 0.00016 Train Loss: 5.30  Valid Loss: 5.71  Valid NDCG: 0.2654
0.26081106024485756
Data Preparation: 11.6s
Epoch: 21 (100.6s)  LR: 0.00009 Train Loss: 5.23  Valid Loss: 5.60  Valid NDCG: 0.2813
Data Preparation: 2.9s
Epoch: 22 (105.1s)  LR: 0.00005 Train Loss: 5.26  Valid Loss: 5.71  Valid NDCG: 0.2667
Data Preparation: 2.9s
Epoch: 23 (104.0s)  LR: 0.00001 Train Loss: 5.28  Valid Loss: 5.63  Valid NDCG: 0.2736
Data Preparation: 2.9s
Epoch: 24 (102.7s)  LR: 0.00000 Train Loss: 5.29  Valid Loss: 5.71  Valid NDCG: 0.2668
Data Preparation: 2.9s
Epoch: 25 (104.9s)  LR: 0.00001 Train Loss: 5.32  Valid Loss: 5.75  Valid NDCG: 0.2757
0.26879638834007635
Data Preparation: 11.8s
Epoch: 26 (106.0s)  LR: 0.00003 Train Loss: 5.35  Valid Loss: 5.57  Valid NDCG: 0.2829
Data Preparation: 3.0s
Epoch: 27 (103.6s)  LR: 0.00007 Train Loss: 5.25  Valid Loss: 5.69  Valid NDCG

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 34 (103.9s)  LR: 0.00067 Train Loss: 5.14  Valid Loss: 5.53  Valid NDCG: 0.3000
Data Preparation: 2.8s
Epoch: 35 (104.8s)  LR: 0.00075 Train Loss: 5.09  Valid Loss: 5.45  Valid NDCG: 0.2958
0.28298487061919647
Data Preparation: 11.6s
Epoch: 36 (106.4s)  LR: 0.00083 Train Loss: 5.15  Valid Loss: 5.51  Valid NDCG: 0.2867
Data Preparation: 3.0s
Epoch: 37 (107.4s)  LR: 0.00090 Train Loss: 5.18  Valid Loss: 5.58  Valid NDCG: 0.2953
Data Preparation: 3.0s
Epoch: 38 (104.3s)  LR: 0.00095 Train Loss: 5.06  Valid Loss: 5.54  Valid NDCG: 0.2922
Data Preparation: 2.9s
Epoch: 39 (105.5s)  LR: 0.00098 Train Loss: 5.02  Valid Loss: 5.64  Valid NDCG: 0.2861
Data Preparation: 2.9s
Epoch: 40 (103.7s)  LR: 0.00100 Train Loss: 4.97  Valid Loss: 5.50  Valid NDCG: 0.2739
0.29186716247919037
Data Preparation: 11.6s
Epoch: 41 (100.8s)  LR: 0.00100 Train Loss: 4.83  Valid Loss: 5.45  Valid NDCG: 0.2995
Data Preparation: 3.0s
Epoch: 42 (103.9s)  LR: 0.00097 Train Loss: 4.95  Valid Loss: 5.55  Valid NDCG

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 43 (105.0s)  LR: 0.00093 Train Loss: 4.99  Valid Loss: 5.33  Valid NDCG: 0.3004
Data Preparation: 3.0s
Epoch: 44 (105.6s)  LR: 0.00088 Train Loss: 4.89  Valid Loss: 5.53  Valid NDCG: 0.2975
Data Preparation: 3.0s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 45 (105.3s)  LR: 0.00081 Train Loss: 4.87  Valid Loss: 5.39  Valid NDCG: 0.3169
0.32118589836254946
Data Preparation: 11.4s
Epoch: 46 (104.7s)  LR: 0.00072 Train Loss: 4.83  Valid Loss: 5.40  Valid NDCG: 0.3006
Data Preparation: 2.7s
Epoch: 47 (106.4s)  LR: 0.00063 Train Loss: 4.85  Valid Loss: 5.44  Valid NDCG: 0.3072
Data Preparation: 3.0s
Epoch: 48 (105.0s)  LR: 0.00054 Train Loss: 4.80  Valid Loss: 5.32  Valid NDCG: 0.3123
Data Preparation: 3.1s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 49 (105.9s)  LR: 0.00044 Train Loss: 4.78  Valid Loss: 5.20  Valid NDCG: 0.3264
Data Preparation: 2.8s
Epoch: 50 (103.2s)  LR: 0.00035 Train Loss: 4.78  Valid Loss: 5.24  Valid NDCG: 0.3152
0.25008213419777886
Data Preparation: 11.3s
Epoch: 51 (104.9s)  LR: 0.00026 Train Loss: 4.82  Valid Loss: 5.37  Valid NDCG: 0.2968
Data Preparation: 3.1s
Epoch: 52 (101.9s)  LR: 0.00018 Train Loss: 4.64  Valid Loss: 5.18  Valid NDCG: 0.3189
Data Preparation: 3.0s
Epoch: 53 (103.9s)  LR: 0.00011 Train Loss: 4.69  Valid Loss: 5.37  Valid NDCG: 0.3018
Data Preparation: 2.8s
Epoch: 54 (107.8s)  LR: 0.00006 Train Loss: 4.75  Valid Loss: 5.36  Valid NDCG: 0.3187
Data Preparation: 3.0s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 55 (100.2s)  LR: 0.00002 Train Loss: 4.59  Valid Loss: 5.31  Valid NDCG: 0.3300
0.294239288405747
Data Preparation: 11.4s
Epoch: 56 (103.8s)  LR: 0.00000 Train Loss: 4.77  Valid Loss: 5.20  Valid NDCG: 0.3148
Data Preparation: 3.0s
Epoch: 57 (104.6s)  LR: 0.00000 Train Loss: 4.66  Valid Loss: 5.22  Valid NDCG: 0.3063
Data Preparation: 3.0s
Epoch: 58 (103.5s)  LR: 0.00002 Train Loss: 4.72  Valid Loss: 5.22  Valid NDCG: 0.3258
Data Preparation: 2.9s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 59 (103.4s)  LR: 0.00006 Train Loss: 4.71  Valid Loss: 4.98  Valid NDCG: 0.3331
Data Preparation: 3.0s
Epoch: 60 (108.7s)  LR: 0.00012 Train Loss: 4.78  Valid Loss: 5.17  Valid NDCG: 0.3282
0.33473263476907605
Data Preparation: 11.6s
Epoch: 61 (103.2s)  LR: 0.00018 Train Loss: 4.63  Valid Loss: 5.16  Valid NDCG: 0.3219
Data Preparation: 2.8s
Epoch: 62 (105.4s)  LR: 0.00027 Train Loss: 4.70  Valid Loss: 5.16  Valid NDCG: 0.3278
Data Preparation: 2.7s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 63 (104.4s)  LR: 0.00036 Train Loss: 4.61  Valid Loss: 5.10  Valid NDCG: 0.3340
Data Preparation: 2.8s
Epoch: 64 (108.3s)  LR: 0.00045 Train Loss: 4.81  Valid Loss: 5.26  Valid NDCG: 0.3225
Data Preparation: 3.0s
Epoch: 65 (104.2s)  LR: 0.00055 Train Loss: 4.68  Valid Loss: 5.14  Valid NDCG: 0.3193
0.24243862079081255
Data Preparation: 11.1s
Epoch: 66 (106.2s)  LR: 0.00064 Train Loss: 4.66  Valid Loss: 5.53  Valid NDCG: 0.3114
Data Preparation: 2.8s
Epoch: 67 (106.9s)  LR: 0.00073 Train Loss: 4.67  Valid Loss: 5.35  Valid NDCG: 0.3075
Data Preparation: 3.0s
Epoch: 68 (105.0s)  LR: 0.00081 Train Loss: 4.60  Valid Loss: 5.28  Valid NDCG: 0.3230
Data Preparation: 2.7s
Epoch: 69 (107.6s)  LR: 0.00088 Train Loss: 4.68  Valid Loss: 5.39  Valid NDCG: 0.3012
Data Preparation: 3.1s
Epoch: 70 (106.8s)  LR: 0.00094 Train Loss: 4.74  Valid Loss: 5.20  Valid NDCG: 0.3224
0.25286838839648934
Data Preparation: 11.5s
Epoch: 71 (105.0s)  LR: 0.00098 Train Loss: 4.61  Valid Loss: 5.13  Valid NDCG

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 72 (106.7s)  LR: 0.00100 Train Loss: 4.53  Valid Loss: 5.10  Valid NDCG: 0.3391
Data Preparation: 3.1s
Epoch: 73 (104.2s)  LR: 0.00100 Train Loss: 4.48  Valid Loss: 5.29  Valid NDCG: 0.3171
Data Preparation: 3.0s
Epoch: 74 (102.9s)  LR: 0.00098 Train Loss: 4.53  Valid Loss: 5.22  Valid NDCG: 0.3221
Data Preparation: 3.1s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 75 (103.9s)  LR: 0.00095 Train Loss: 4.42  Valid Loss: 5.05  Valid NDCG: 0.3449
0.2854326859358558
Data Preparation: 11.7s
Epoch: 76 (104.5s)  LR: 0.00089 Train Loss: 4.48  Valid Loss: 5.25  Valid NDCG: 0.3207
Data Preparation: 3.0s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 77 (102.7s)  LR: 0.00083 Train Loss: 4.56  Valid Loss: 4.86  Valid NDCG: 0.3645
Data Preparation: 3.1s
Epoch: 78 (105.1s)  LR: 0.00075 Train Loss: 4.55  Valid Loss: 5.12  Valid NDCG: 0.3386
Data Preparation: 3.1s
Epoch: 79 (106.3s)  LR: 0.00066 Train Loss: 4.64  Valid Loss: 4.97  Valid NDCG: 0.3353
Data Preparation: 3.0s
Epoch: 80 (103.6s)  LR: 0.00056 Train Loss: 4.44  Valid Loss: 5.16  Valid NDCG: 0.3229
0.3267764033475853
Data Preparation: 11.7s
Epoch: 81 (106.6s)  LR: 0.00047 Train Loss: 4.59  Valid Loss: 5.00  Valid NDCG: 0.3358
Data Preparation: 3.1s
Epoch: 82 (107.0s)  LR: 0.00037 Train Loss: 4.56  Valid Loss: 5.06  Valid NDCG: 0.3317
Data Preparation: 3.1s
Epoch: 83 (104.0s)  LR: 0.00028 Train Loss: 4.45  Valid Loss: 4.94  Valid NDCG: 0.3492
Data Preparation: 2.8s
Epoch: 84 (107.0s)  LR: 0.00020 Train Loss: 4.56  Valid Loss: 4.91  Valid NDCG: 0.3363
Data Preparation: 2.8s
Epoch: 85 (106.4s)  LR: 0.00013 Train Loss: 4.40  Valid Loss: 5.07  Valid NDCG: 0.3379
0.30661991991

In [None]:
# Plot the per-epoch losses recorded by the training loop
# (column 0: mean train loss, column 1: validation loss).
stats = np.array(stats)
epochs_axis = np.arange(1, len(stats) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_axis, stats[:,0], label='Train loss')
ax.plot(epochs_axis, stats[:,1], label='Valid loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs. validation loss')
ax.legend()
plt.show()

In [None]:
# Evaluate the model as it stands after the last training epoch on
# 10 freshly sampled test batches; report mean and variance of NDCG and MRR.
model.eval()
gnn, classifier = model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        # pf_sample draws its own timestamp from test_papers.
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        # ylabel lives on the CPU, so bring the ranking indices back before indexing.
        for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))

In [None]:
# Reload the checkpoint with the best validation NDCG saved by the training loop.
# map_location places it on the configured device regardless of where it was saved.
# NOTE: this unpickles a full model object; only load checkpoints you trust.
best_model = torch.load('./save/rgt_1.pt', map_location=device)

In [None]:
# Evaluate the best-validation checkpoint on 10 freshly sampled test batches;
# report mean and variance of NDCG and MRR.
# NOTE: this rebinds the notebook-level names gnn and classifier to the
# checkpoint's sub-modules.
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        # pf_sample draws its own timestamp from test_papers.
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        # ylabel lives on the CPU, so bring the ranking indices back before indexing.
        for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))