In [1]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [2]:
# Directory holding the serialized graph (graph.pk).
data_dir = '/datadrive/data_cs/'
# Upper bound on papers per sampled batch; half of it is also the minimum number
# of papers a timestamp needs in order to be kept for sampling.
batch_size = 256
# Only used to size the number of prefetch jobs (batch_num // 4).
batch_num  = 128
# NOTE(review): epoch_num and samp_num are not referenced by the visible cells.
epoch_num  = 100
samp_num   = 7

device = torch.device("cuda:3")
# dill/pickle executes arbitrary code while loading: only open graph files you trust.
# The context manager closes the file handle deterministically.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [3]:
# Split the graph's timestamps into train (< 2015), validation (2015-2016) and
# test (> 2016) lookup tables. Entries whose time is None are excluded.
# `is not None` is the identity test PEP 8 recommends for None comparisons.
train_range = {t: True for t in graph.times if t is not None and t < 2015}
valid_range = {t: True for t in graph.times if t is not None and 2015 <= t <= 2016}
test_range  = {t: True for t in graph.times if t is not None and t > 2016}

In [4]:
def pf_sample(seed, papers, pairs, t_range, batch_size, test = False):
    """Sample one paper->field batch and its surrounding subgraph.

    Args:
        seed: value passed to ``np.random.seed`` so each worker draws a different batch.
        papers: mapping ``time -> array of paper ids`` to sample from.
        pairs: mapping ``paper id -> list of field ids`` (the labels).
        t_range: time table forwarded to ``sample_subgraph``.
        batch_size: maximum number of papers drawn for the chosen time.
        test: accepted for call compatibility; not used by this function.

    Returns:
        ``(node_feature, node_type, edge_time, edge_index, edge_type,
        field_ids, paper_ids, ylabel)`` where the first five items come from
        ``to_torch``, ``field_ids``/``paper_ids`` are node indices offset by the
        start index of their type, and ``ylabel`` is a ``(sampn, len(cand_list))``
        tensor whose rows are normalised to sum to 1.

    Relies on the module-level ``graph`` and ``cand_list``.
    """
    np.random.seed(seed)
    _time = np.random.choice(list(papers.keys()))
    sampn = min(len(papers[_time]), batch_size)
    pids = np.array(papers[_time])[np.random.choice(len(papers[_time]), sampn, replace = False)]

    # Collect the distinct fields of the sampled papers in first-seen order.
    # The companion set keeps the membership test O(1) instead of scanning the list.
    fids = []
    seen_fids = set()
    for p_id in pids:
        for f_id in pairs[p_id]:
            if f_id not in seen_fids:
                seen_fids.add(f_id)
                fids.append(f_id)

    # Pair every id with the sampled time: rows of [id, time].
    pids = np.stack([pids, np.repeat([_time], sampn)]).T
    fids = np.stack([fids, np.repeat([_time], len(fids))]).T

    feature, times, edge_list, _, _ = sample_subgraph(graph, t_range, \
                inp = {'paper': pids, 'field': fids}, sampled_depth = 4, sampled_number = 100)

    # Drop paper-field edges whose paper-side index is below len(pids).
    # NOTE(review): presumably those indices are the sampled papers, so this hides
    # the links the model has to predict -- confirm against sample_subgraph.
    n_seed = len(pids)
    edge_list['paper']['field']['rev_PF_in_L2'] = [
        e for e in edge_list['paper']['field']['rev_PF_in_L2'] if e[0] >= n_seed]
    edge_list['field']['paper']['PF_in_L2'] = [
        e for e in edge_list['field']['paper']['PF_in_L2'] if e[1] >= n_seed]

    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list, graph)

    # Trace the paper_id and field_id by its own index plus the type start index.
    paper_ids = np.arange(len(pids)) + node_dict['paper'][0]
    field_ids = np.arange(len(fids)) + node_dict['field'][0]

    # Position of every candidate field, built once per call instead of running
    # cand_list.index() (a linear scan) for every label. setdefault keeps the
    # first position, matching list.index semantics.
    cand_index = {}
    for position, cand in enumerate(cand_list):
        cand_index.setdefault(cand, position)

    ylabel = torch.zeros(sampn, len(cand_list))
    for x_id, p_id in enumerate(pids[:,0]):
        for f_id in pairs[p_id]:
            ylabel[x_id][cand_index[f_id]] = 1
    # Turn the multi-hot rows into distributions over the candidate fields.
    ylabel /= ylabel.sum(axis=1).view(-1, 1)
    return node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel
    
def prepare_data(pool, process_ids, train_papers, valid_papers):
    """Queue asynchronous sampling jobs on ``pool``.

    One training job is submitted for every entry of ``process_ids`` except the
    last; a single validation job is appended at the end, so ``jobs[-1]`` is the
    validation batch. Uses the module-level pair tables, time ranges and
    ``batch_size``.
    """
    def submit(papers, pairs, t_range):
        # Each job gets its own seed drawn from the global NumPy generator.
        seed = np.random.randint(2**32 - 1)
        return pool.apply_async(pf_sample, args=(seed, papers, pairs, t_range, batch_size))

    jobs = [submit(train_papers, train_pairs, train_range) for _ in process_ids[:-1]]
    jobs.append(submit(valid_papers, valid_pairs, valid_range))
    return jobs

In [5]:
class GNN(nn.Module):
    """Type-aware input projection followed by a stack of RAGCNConv layers."""

    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        """Build one input Linear per node type and ``n_layers`` graph layers."""
        super().__init__()
        # Submodules are registered in the order gcs, aggregat_ws, drop.
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim = in_dim
        self.n_hid = n_hid
        # One in_dim -> n_hid projection per node type.
        self.aggregat_ws = nn.ModuleList([nn.Linear(in_dim, n_hid) for _ in range(num_types)])
        self.drop = nn.Dropout(dropout)
        self.gcs.extend([RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout)
                         for _ in range(n_layers)])

    def set_device(self, device):
        """Record ``device`` on this module and on every graph layer."""
        self.device = device
        for layer in self.gcs:
            layer.device = device

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        """Project nodes by type, apply dropout, then run every graph layer.

        Rows whose type id matches no projection stay zero.
        """
        hidden = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device)
        for t_id, projection in enumerate(self.aggregat_ws):
            mask = (node_type == t_id)
            if mask.sum() != 0:
                hidden[mask] = torch.tanh(projection(node_feature[mask]))
        meta_xs = self.drop(hidden)
        del hidden
        for layer in self.gcs:
            meta_xs = layer(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs

In [6]:
# Paper-Field: build, per split, (a) paper -> list of L2 fields and
# (b) time -> papers published at that time.
paper_ser = {}

train_pairs = {}
valid_pairs = {}
test_pairs  = {}

train_papers = {_time: {} for _time in train_range}
valid_papers = {_time: {} for _time in valid_range}
test_papers  = {_time: {} for _time in test_range}

for f_id, paper_times in graph.edge_list['field']['paper']['PF_in_L2'].items():
    for p_id, _time in paper_times.items():
        if _time in train_range:
            split_pairs, split_papers = train_pairs, train_papers
        elif _time in valid_range:
            split_pairs, split_papers = valid_pairs, valid_papers
        elif _time in test_range:
            split_pairs, split_papers = test_pairs, test_papers
        else:
            # A time outside all three ranges (e.g. None) has no bucket in
            # test_papers; the former catch-all `else` raised KeyError here.
            continue
        split_pairs.setdefault(p_id, []).append(f_id)
        split_papers[_time][p_id] = True


def _finalize_papers(papers_by_time, min_papers):
    """Drop times with fewer than ``min_papers`` papers; turn the rest into id arrays (in place)."""
    for _time in list(papers_by_time.keys()):
        if len(papers_by_time[_time]) < min_papers:
            papers_by_time.pop(_time)
        else:
            papers_by_time[_time] = np.array(list(papers_by_time[_time].keys()))


for split_papers in (train_papers, valid_papers, test_papers):
    _finalize_papers(split_papers, batch_size // 2)

In [31]:
types = graph.get_types()
# Candidate label space: every field key of the PF_in_L2 relation, in dict order.
cand_list = list(graph.edge_list['field']['paper']['PF_in_L2'])
gnn = GNN(in_dim = len(graph.node_feature['paper']['emb'][0]) + 401, n_hid = 256, num_types = len(types), \
          num_relations = len(graph.get_meta_graph()) + 1, n_heads = 8, n_layers = 4).to(device)
# The freshly built GNN above is replaced by the pre-trained checkpoint.
# map_location remaps storages saved on a different device onto `device`, so the
# load does not depend on the GPU index used at pre-training time.
# torch.load unpickles arbitrary objects: only load checkpoints you trust.
gnn = torch.load('../pre-train/save/mt_model.pt', map_location = device).to(device)
classifier = Classifier(256, len(cand_list)).to(device)
model = nn.Sequential(gnn, classifier)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000, eta_min=1e-6)

In [8]:
# Keep a random 10% of the training papers of every time step; time steps left
# with fewer than batch_size // 2 papers are dropped.
sel_train_papers = {}
min_papers = batch_size // 2
for _time, pid in train_papers.items():
    keep = np.random.choice(np.arange(len(pid)), len(pid) // 10, replace=False)
    pid = pid[keep]
    if len(pid) < min_papers:
        continue
    sel_train_papers[_time] = pid

In [32]:
# Bookkeeping for the training loop.
stats, res = [], []
best_val = 0
# Step counter handed to scheduler.step() during training.
train_step = 1500

# Start the worker pool and prefetch the first round of batches.
process_ids = np.arange(batch_num // 4)
pool = mp.Pool(4)
st = time.time()
jobs = prepare_data(pool, process_ids, sel_train_papers, valid_papers)

In [33]:
# KLDivLoss takes log-probabilities as input.
# NOTE(review): presumably Classifier ends in log_softmax -- confirm.
criterion = nn.KLDivLoss(reduction='batchmean')
for epoch in np.arange(200)+1:
    # Prepare training and validation data: collect the prefetched batches,
    # then start a fresh pool that samples the next epoch's batches while this
    # epoch trains.
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    pool = mp.Pool(4)
    jobs = prepare_data(pool, process_ids, sel_train_papers, valid_papers)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))
    
    model.train()
    train_losses = []
    torch.cuda.empty_cache()
    # Every prefetched batch is used twice per epoch.
    for batch in np.arange(2):
        for node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel in train_data:
            node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                   edge_time.to(device), edge_index.to(device), edge_type.to(device))
            res  = classifier.forward(node_rep[paper_ids])
            loss = criterion(res, ylabel.to(device))
            optimizer.zero_grad() 
            torch.cuda.empty_cache()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
            optimizer.step()
            train_losses += [loss.cpu().detach().tolist()]
            train_step += 1
            scheduler.step(train_step)
            del res, loss
    # Validation
    model.eval()
    with torch.no_grad():
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = valid_data
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                   edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res  = classifier.forward(node_rep[paper_ids])
        loss = criterion(res, ylabel.to(device))
        valid_res = []

        # ylabel stays on the CPU, so the ranking must be moved there before it
        # is used as an index (a CUDA index into a CPU tensor is rejected by
        # current PyTorch releases).
        for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
            valid_res += [ai[bi].tolist()]
        valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])
        if valid_ndcg > best_val:
            best_val = valid_ndcg
            torch.save(model, './save/rgt_3.pt')
        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), loss.cpu().detach().tolist(),\
              valid_ndcg))
        stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
        del res, loss
        if epoch % 5 == 0:
            # Test: every fifth epoch, score one freshly sampled test batch.
            node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                           test_pairs, test_range, batch_size, test=True)
            paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
            res  = classifier.forward(paper_rep)
            test_res = []
            for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
                test_res += [ai[bi].tolist()]
            test_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in test_res])
            print(test_ndcg)
            del res
    del train_data, valid_data

# The batches prefetched during the last epoch are never consumed; shut the
# final pool down so its worker processes do not linger.
pool.close()
pool.join()

Data Preparation: 94.6s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 1 (115.4s)  LR: 0.00060 Train Loss: 7.40  Valid Loss: 6.46  Valid NDCG: 0.2066
Data Preparation: 2.3s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2 (112.8s)  LR: 0.00069 Train Loss: 6.17  Valid Loss: 6.40  Valid NDCG: 0.2093
Data Preparation: 2.3s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 3 (113.0s)  LR: 0.00078 Train Loss: 6.02  Valid Loss: 6.17  Valid NDCG: 0.2375
Data Preparation: 2.3s
Epoch: 4 (111.6s)  LR: 0.00085 Train Loss: 5.96  Valid Loss: 6.13  Valid NDCG: 0.2307
Data Preparation: 2.4s
Epoch: 5 (110.7s)  LR: 0.00091 Train Loss: 5.85  Valid Loss: 6.30  Valid NDCG: 0.2181
0.20911976395511278
Data Preparation: 7.2s
Epoch: 6 (112.6s)  LR: 0.00096 Train Loss: 5.79  Valid Loss: 6.05  Valid NDCG: 0.2319
Data Preparation: 2.4s
Epoch: 7 (111.6s)  LR: 0.00099 Train Loss: 5.71  Valid Loss: 6.31  Valid NDCG: 0.2256
Data Preparation: 2.7s
Epoch: 8 (112.0s)  LR: 0.00100 Train Loss: 5.71  Valid Loss: 6.19  Valid NDCG: 0.2313
Data Preparation: 2.5s
Epoch: 9 (111.8s)  LR: 0.00099 Train Loss: 5.62  Valid Loss: 6.00  Valid NDCG: 0.2364
Data Preparation: 2.5s
Epoch: 10 (112.0s)  LR: 0.00096 Train Loss: 5.59  Valid Loss: 6.07  Valid NDCG: 0.2254
0.22829141039908102
Data Preparation: 13.2s
Epoch: 11 (112.5s)  LR: 0.00092 Train Loss: 5.55  Valid Loss: 6.11  Valid NDCG: 0.2339

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 12 (112.6s)  LR: 0.00086 Train Loss: 5.55  Valid Loss: 6.09  Valid NDCG: 0.2378
Data Preparation: 2.7s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 13 (113.5s)  LR: 0.00079 Train Loss: 5.51  Valid Loss: 5.93  Valid NDCG: 0.2478
Data Preparation: 2.7s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 14 (113.3s)  LR: 0.00070 Train Loss: 5.39  Valid Loss: 5.73  Valid NDCG: 0.2512
Data Preparation: 2.7s
Epoch: 15 (112.8s)  LR: 0.00061 Train Loss: 5.39  Valid Loss: 6.25  Valid NDCG: 0.2055
0.25163706358561366
Data Preparation: 7.6s
Epoch: 16 (112.9s)  LR: 0.00051 Train Loss: 5.48  Valid Loss: 6.07  Valid NDCG: 0.2400
Data Preparation: 2.5s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 17 (113.0s)  LR: 0.00042 Train Loss: 5.32  Valid Loss: 5.90  Valid NDCG: 0.2559
Data Preparation: 2.5s
Epoch: 18 (111.2s)  LR: 0.00032 Train Loss: 5.31  Valid Loss: 5.88  Valid NDCG: 0.2548
Data Preparation: 2.4s
Epoch: 19 (112.6s)  LR: 0.00024 Train Loss: 5.33  Valid Loss: 5.95  Valid NDCG: 0.2362
Data Preparation: 2.5s
Epoch: 20 (111.4s)  LR: 0.00016 Train Loss: 5.36  Valid Loss: 5.87  Valid NDCG: 0.2514
0.22481632790788486
Data Preparation: 7.7s
Epoch: 21 (112.3s)  LR: 0.00009 Train Loss: 5.40  Valid Loss: 5.95  Valid NDCG: 0.2430
Data Preparation: 2.5s
Epoch: 22 (112.4s)  LR: 0.00005 Train Loss: 5.31  Valid Loss: 5.91  Valid NDCG: 0.2402
Data Preparation: 2.7s
Epoch: 23 (111.3s)  LR: 0.00001 Train Loss: 5.31  Valid Loss: 5.99  Valid NDCG: 0.2433
Data Preparation: 2.8s
Epoch: 24 (111.4s)  LR: 0.00000 Train Loss: 5.30  Valid Loss: 6.06  Valid NDCG: 0.2314
Data Preparation: 2.6s
Epoch: 25 (112.9s)  LR: 0.00001 Train Loss: 5.34  Valid Loss: 5.62  Valid NDCG: 0.2469
0.21591237279

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 29 (112.9s)  LR: 0.00020 Train Loss: 5.27  Valid Loss: 5.77  Valid NDCG: 0.2617
Data Preparation: 2.5s
Epoch: 30 (112.7s)  LR: 0.00029 Train Loss: 5.30  Valid Loss: 5.88  Valid NDCG: 0.2557
0.25172371790490744
Data Preparation: 7.7s
Epoch: 31 (112.5s)  LR: 0.00038 Train Loss: 5.19  Valid Loss: 5.91  Valid NDCG: 0.2369
Data Preparation: 2.5s
Epoch: 32 (112.7s)  LR: 0.00048 Train Loss: 5.27  Valid Loss: 5.85  Valid NDCG: 0.2450
Data Preparation: 2.4s
Epoch: 33 (112.6s)  LR: 0.00057 Train Loss: 5.31  Valid Loss: 5.86  Valid NDCG: 0.2411
Data Preparation: 2.8s
Epoch: 34 (112.3s)  LR: 0.00067 Train Loss: 5.29  Valid Loss: 5.93  Valid NDCG: 0.2465
Data Preparation: 2.7s
Epoch: 35 (110.6s)  LR: 0.00075 Train Loss: 5.22  Valid Loss: 5.95  Valid NDCG: 0.2428
0.2602171257109083
Data Preparation: 7.8s
Epoch: 36 (113.0s)  LR: 0.00083 Train Loss: 5.24  Valid Loss: 6.31  Valid NDCG: 0.2187
Data Preparation: 2.4s
Epoch: 37 (112.4s)  LR: 0.00090 Train Loss: 5.26  Valid Loss: 6.03  Valid NDCG: 0

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 43 (176.7s)  LR: 0.00093 Train Loss: 5.09  Valid Loss: 5.80  Valid NDCG: 0.2634
Data Preparation: 2.9s
Epoch: 44 (175.3s)  LR: 0.00088 Train Loss: 5.13  Valid Loss: 5.90  Valid NDCG: 0.2454
Data Preparation: 3.2s
Epoch: 45 (177.1s)  LR: 0.00081 Train Loss: 5.12  Valid Loss: 5.97  Valid NDCG: 0.2471
0.23533486918338087
Data Preparation: 9.4s
Epoch: 46 (174.9s)  LR: 0.00072 Train Loss: 5.16  Valid Loss: 5.99  Valid NDCG: 0.2418
Data Preparation: 2.9s
Epoch: 47 (177.3s)  LR: 0.00063 Train Loss: 5.06  Valid Loss: 5.83  Valid NDCG: 0.2620
Data Preparation: 2.9s
Epoch: 48 (177.4s)  LR: 0.00054 Train Loss: 5.07  Valid Loss: 5.83  Valid NDCG: 0.2524
Data Preparation: 2.6s
Epoch: 49 (175.1s)  LR: 0.00044 Train Loss: 5.03  Valid Loss: 5.92  Valid NDCG: 0.2462
Data Preparation: 2.6s
Epoch: 50 (178.3s)  LR: 0.00035 Train Loss: 5.04  Valid Loss: 5.90  Valid NDCG: 0.2599
0.25467404998078685
Data Preparation: 13.3s
Epoch: 51 (176.3s)  LR: 0.00026 Train Loss: 5.06  Valid Loss: 5.86  Valid NDCG:

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 52 (174.6s)  LR: 0.00018 Train Loss: 5.05  Valid Loss: 5.64  Valid NDCG: 0.2712
Data Preparation: 2.7s
Epoch: 53 (176.5s)  LR: 0.00011 Train Loss: 5.02  Valid Loss: 5.84  Valid NDCG: 0.2463
Data Preparation: 2.6s
Epoch: 54 (173.0s)  LR: 0.00006 Train Loss: 5.03  Valid Loss: 6.03  Valid NDCG: 0.2477
Data Preparation: 2.5s
Epoch: 55 (174.0s)  LR: 0.00002 Train Loss: 5.06  Valid Loss: 6.11  Valid NDCG: 0.2419
0.22905887913146702
Data Preparation: 7.8s
Epoch: 56 (164.0s)  LR: 0.00000 Train Loss: 5.08  Valid Loss: 6.08  Valid NDCG: 0.2382
Data Preparation: 2.7s
Epoch: 57 (177.7s)  LR: 0.00000 Train Loss: 5.02  Valid Loss: 5.76  Valid NDCG: 0.2636
Data Preparation: 2.5s
Epoch: 58 (170.7s)  LR: 0.00002 Train Loss: 5.01  Valid Loss: 6.19  Valid NDCG: 0.2346
Data Preparation: 2.5s


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 59 (172.4s)  LR: 0.00006 Train Loss: 5.04  Valid Loss: 5.64  Valid NDCG: 0.2732
Data Preparation: 2.9s
Epoch: 60 (175.5s)  LR: 0.00012 Train Loss: 5.13  Valid Loss: 5.87  Valid NDCG: 0.2586
0.22685614839766957
Data Preparation: 9.2s
Epoch: 61 (168.6s)  LR: 0.00018 Train Loss: 5.05  Valid Loss: 5.79  Valid NDCG: 0.2584
Data Preparation: 3.1s
Epoch: 62 (176.4s)  LR: 0.00027 Train Loss: 5.10  Valid Loss: 5.83  Valid NDCG: 0.2389
Data Preparation: 2.7s
Epoch: 63 (175.6s)  LR: 0.00036 Train Loss: 4.98  Valid Loss: 6.10  Valid NDCG: 0.2392
Data Preparation: 2.6s
Epoch: 64 (175.0s)  LR: 0.00045 Train Loss: 4.88  Valid Loss: 5.74  Valid NDCG: 0.2536
Data Preparation: 3.0s
Epoch: 65 (155.8s)  LR: 0.00055 Train Loss: 5.17  Valid Loss: 5.83  Valid NDCG: 0.2534
0.2369753825287796
Data Preparation: 12.4s
Epoch: 66 (118.9s)  LR: 0.00064 Train Loss: 4.90  Valid Loss: 6.16  Valid NDCG: 0.2297
Data Preparation: 2.6s
Epoch: 67 (171.0s)  LR: 0.00073 Train Loss: 5.01  Valid Loss: 5.84  Valid NDCG: 

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Epoch: 114 (179.3s)  LR: 0.00039 Train Loss: 4.73  Valid Loss: 5.57  Valid NDCG: 0.2817
Data Preparation: 2.7s
Epoch: 115 (179.8s)  LR: 0.00030 Train Loss: 4.74  Valid Loss: 6.35  Valid NDCG: 0.2282
0.22918068716989223
Data Preparation: 8.0s
Epoch: 116 (174.3s)  LR: 0.00022 Train Loss: 4.77  Valid Loss: 6.37  Valid NDCG: 0.2301
Data Preparation: 2.7s
Epoch: 117 (172.3s)  LR: 0.00014 Train Loss: 4.76  Valid Loss: 6.17  Valid NDCG: 0.2357
Data Preparation: 2.7s
Epoch: 118 (175.0s)  LR: 0.00008 Train Loss: 4.96  Valid Loss: 6.41  Valid NDCG: 0.2125
Data Preparation: 2.7s
Epoch: 119 (177.4s)  LR: 0.00004 Train Loss: 4.92  Valid Loss: 6.02  Valid NDCG: 0.2417
Data Preparation: 2.6s
Epoch: 120 (175.6s)  LR: 0.00001 Train Loss: 4.65  Valid Loss: 6.14  Valid NDCG: 0.2308
0.22571512619416986
Data Preparation: 13.8s
Epoch: 121 (167.4s)  LR: 0.00000 Train Loss: 4.87  Valid Loss: 6.13  Valid NDCG: 0.2330
Data Preparation: 2.6s
Epoch: 122 (177.4s)  LR: 0.00001 Train Loss: 4.70  Valid Loss: 6.31  Va

0.23900581907530474
Data Preparation: 13.5s
Epoch: 186 (172.2s)  LR: 0.00000 Train Loss: 4.59  Valid Loss: 6.20  Valid NDCG: 0.2214
Data Preparation: 2.6s
Epoch: 187 (173.9s)  LR: 0.00002 Train Loss: 4.47  Valid Loss: 6.14  Valid NDCG: 0.2197
Data Preparation: 2.5s
Epoch: 188 (175.3s)  LR: 0.00006 Train Loss: 4.36  Valid Loss: 6.15  Valid NDCG: 0.2347
Data Preparation: 2.9s
Epoch: 189 (178.1s)  LR: 0.00011 Train Loss: 4.44  Valid Loss: 6.09  Valid NDCG: 0.2422
Data Preparation: 2.5s
Epoch: 190 (177.9s)  LR: 0.00018 Train Loss: 4.73  Valid Loss: 6.34  Valid NDCG: 0.2190
0.23215874286078772
Data Preparation: 9.2s
Epoch: 191 (172.6s)  LR: 0.00026 Train Loss: 4.62  Valid Loss: 6.19  Valid NDCG: 0.2285
Data Preparation: 2.8s
Epoch: 192 (174.9s)  LR: 0.00035 Train Loss: 4.68  Valid Loss: 6.39  Valid NDCG: 0.2213
Data Preparation: 2.7s
Epoch: 193 (174.0s)  LR: 0.00045 Train Loss: 4.57  Valid Loss: 6.23  Valid NDCG: 0.2232
Data Preparation: 2.8s
Epoch: 194 (176.4s)  LR: 0.00054 Train Loss: 4.4

In [None]:
# Each row of `stats` is [mean train loss, validation loss] for one epoch.
# `stats` itself is left as a list: rebinding it to an ndarray would turn the
# training cell's `stats += [[...]]` into a broadcast addition on a re-run.
loss_curves = np.array(stats)
plt.plot(loss_curves[:,0], label='Train Loss')
plt.plot(loss_curves[:,1], label='Valid Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Score the current (last-epoch) model on 10 freshly sampled test batches and
# report mean/variance of NDCG and MRR.
model.eval()
gnn, classifier = model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        # ylabel is a CPU tensor; move the ranking to the CPU before indexing
        # (a CUDA index into a CPU tensor is rejected by current PyTorch).
        for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))

In [None]:
# Reload the checkpoint written by the training loop whenever validation NDCG improved.
# torch.load unpickles arbitrary objects: only load checkpoints you trust.
best_model = torch.load('./save/rgt_3.pt')

In [None]:
# Score the best-validation checkpoint on 10 freshly sampled test batches and
# report mean/variance of NDCG and MRR.
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        # ylabel is a CPU tensor; move the ranking to the CPU before indexing
        # (a CUDA index into a CPU tensor is rejected by current PyTorch).
        for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))

In [None]:
# without pre-train

In [None]:
# Each row of `stats` is [mean train loss, validation loss] for one epoch.
# `stats` itself is left as a list: rebinding it to an ndarray would turn the
# training cell's `stats += [[...]]` into a broadcast addition on a re-run.
loss_curves = np.array(stats)
plt.plot(loss_curves[:,0], label='Train Loss')
plt.plot(loss_curves[:,1], label='Valid Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
def _report_test_metrics(net, rounds = 10):
    """Score ``net`` (a (gnn, classifier) pair) on ``rounds`` sampled test batches.

    Prints mean and variance of NDCG, then of MRR, and returns
    ``(test_ndcg, test_mrr)``.
    """
    gnn_part, classifier_part = net
    test_res = []
    with torch.no_grad():
        for _ in range(rounds):
            node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                           test_pairs, test_range, batch_size, test=True)
            paper_rep = gnn_part.forward(node_feature.to(device), node_type.to(device), \
                          edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
            res = classifier_part.forward(paper_rep)
            # ylabel is a CPU tensor; move the ranking to the CPU before indexing
            # (a CUDA index into a CPU tensor is rejected by current PyTorch).
            for ai, bi in zip(ylabel, res.argsort(descending = True).cpu()):
                test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))
    return test_ndcg, test_mrr


# Last-epoch model.
model.eval()
gnn, classifier = model
test_ndcg, test_mrr = _report_test_metrics(model)

# Best-validation checkpoint of this run.
# torch.load unpickles arbitrary objects: only load checkpoints you trust.
best_model = torch.load('./save/rgt_1.pt')

best_model.eval()
gnn, classifier = best_model
test_ndcg, test_mrr = _report_test_metrics(best_model)

In [None]:
# Index labels of every field row whose 'attr' column equals 'L2'.
l2_fields = graph.node_feature['field'][graph.node_feature['field']['attr'] == 'L2'].index.values

In [None]:
# Display the field row selected by l2_fields[424].
# NOTE(review): l2_fields holds index labels while .iloc is positional; the two
# agree only if the frame uses a default 0..n-1 index -- confirm.
graph.node_feature['field'].iloc[[l2_fields[424]]]

In [None]:
# Display the field row selected by l2_fields[637].
# NOTE(review): l2_fields holds index labels while .iloc is positional; the two
# agree only if the frame uses a default 0..n-1 index -- confirm.
graph.node_feature['field'].iloc[[l2_fields[637]]]

In [None]:
# Entries 424 and 637 of l2_fields.
l2_fields[424], l2_fields[637]

In [None]:
# For every L2 field, keep its paper mapping when the field has more than 100
# and fewer than 10000 papers. (The unused `mxn` counter was removed.)
# NOTE(review): positions in `papers` skip the filtered-out fields, so they do
# not line up with positions in l2_fields -- confirm wherever both are indexed
# with the same number.
field_papers = graph.edge_list['field']['paper']['PF_in_L2']
papers = [field_papers[f] for f in l2_fields if 100 < len(field_papers[f]) < 10000]
len(papers)

In [None]:
# Zero-initialised square matrix with one row and one column per entry of `papers`.
citation_matrix = np.zeros([len(papers), len(papers)])

In [None]:
# Fill the lower triangle (including the diagonal): entry [row][col] is the
# number of PP_cite links from papers of group `row` into group `col`, divided
# by the product of the two group sizes.
cite_links = graph.edge_list['paper']['paper']['PP_cite']
n_groups = len(papers)
for row, src_group in enumerate(papers):
    if row % 10 == 0:
        print('%d / %d' % (row, n_groups))
    for col in range(row + 1):
        dst_group = papers[col]
        hits = sum(1 for p1 in src_group for p2 in cite_links[p1] if p2 in dst_group)
        citation_matrix[row][col] = hits / len(src_group) / len(dst_group)

In [None]:
# Collect group pairs (i, j), j <= i, whose cross-citation density exceeds `th`
# without exceeding either group's own density, and where group i has at least
# 2.5 times as many papers as group j.
possible_ls = []
th = 0.002
n_groups = len(citation_matrix)
for i in range(n_groups):
    left = citation_matrix[i][i]
    for j in range(i + 1):
        inter = citation_matrix[i][j]
        right = citation_matrix[j][j]
        # Stricter variant tried earlier:
        # inter>th and left >= inter and right >= 0.7 * left and right <= left and right>=inter
        if not (inter > th and left >= inter and right >= inter):
            continue
        size_i, size_j = len(papers[i]), len(papers[j])
        if size_i >= 2.5 * size_j:
            print(i, j, size_i, size_j)
            possible_ls.append([i, j])

In [None]:
# Two groups of papers: label 0 for group 637, label 1 for group 424.
paper_list_0 = list(papers[637])
paper_list_1 = list(papers[424])
print(len(paper_list_0), len(paper_list_1))
_papers = paper_list_0 + paper_list_1
# Position of every paper in the concatenated list.
paper_id = {paper: position for position, paper in enumerate(_papers)}
# A paper present in both groups ends up with label 1.
paper_labels = {paper: 0 for paper in paper_list_0}
paper_labels.update((paper, 1) for paper in paper_list_1)

In [None]:
# Rebuilds _papers, paper_id and paper_labels from paper_list_0 / paper_list_1.
# These statements are identical to the tail of the preceding cell.
_papers = paper_list_0 + paper_list_1
# Position of every paper in the concatenated list.
paper_id = {j:i for i, j in enumerate(_papers)}
paper_labels = {}
for p in paper_list_0:
    paper_labels[p] = 0
# A paper present in both lists ends up with label 1.
for p in paper_list_1:
    paper_labels[p] = 1

In [None]:
# First pass over the citation links among the selected papers; the resulting
# `degrees` table drives the degree filter that follows.
# Edges are stored as [source, target] pairs, matching the format used when the
# edge list is rebuilt after filtering (they used to be appended as a flat list
# of ints here).
edges = []
degrees = defaultdict(lambda: 0)
for pi in _papers:
    for pj in graph.edge_list['paper']['paper']['PP_cite'][pi]:
        if pj in paper_id:
            edges += [[paper_id[pi], paper_id[pj]]]
            edges += [[paper_id[pj], paper_id[pi]]]
            degrees[pi] += 1
            degrees[pj] += 1

In [None]:
# Keep label-0 papers with 5 < degree <= 60 and label-1 papers with
# 20 < degree <= 60; then re-index the survivors.
removes = {}
for group, min_degree in ((paper_list_0, 5), (paper_list_1, 20)):
    for p in group:
        if not (min_degree < degrees[p] <= 60):
            removes[p] = True
_papers = [p for p in _papers if p not in removes]
paper_id = dict(zip(_papers, range(len(_papers))))

In [None]:
# Rebuild the edge list over the filtered papers: self-citations are skipped and
# every unordered pair is added once, in both directions. `degrees` counts one
# per endpoint of each distinct pair. (The unused `e` assignment was removed.)
edges = []
degrees = defaultdict(lambda: 0)
rec = {}
for pi in _papers:
    src = paper_id[pi]
    for pj in graph.edge_list['paper']['paper']['PP_cite'][pi]:
        if pj not in paper_id or pi == pj:
            continue
        dst = paper_id[pj]
        if "%d_%d" % (src, dst) in rec:
            continue
        edges += [[src, dst]]
        edges += [[dst, src]]
        # Mark both orientations so the reverse citation is not added again.
        rec["%d_%d" % (src, dst)] = True
        rec["%d_%d" % (dst, src)] = True
        degrees[pi] += 1
        degrees[pj] += 1

In [None]:
# Stack the 'emb' entries of the retained papers, in the order of _papers.
features = np.array(list(graph.node_feature['paper'].loc[_papers, 'emb']))

In [None]:
# Labels aligned with _papers, plus the surviving papers of each class.
labels = [paper_labels[p] for p in _papers]
paper_list_0 = [p for p in _papers if paper_labels[p] == 0]
paper_list_1 = [p for p in _papers if paper_labels[p] == 1]
print(len(paper_list_0), len(paper_list_1))

In [None]:
# Number of papers per label value.
cnt = defaultdict(lambda: 0)
for l in labels:
    cnt[l] += 1
cnt

In [None]:
# Persist (edges, features, labels); the context manager flushes and closes the
# file deterministically instead of leaving it to garbage collection.
with open('./data.pk', 'wb') as data_file:
    dill.dump((edges, features, labels), data_file)

In [None]:
# Number of papers with label 0 and with label 1.
[len(paper_list_0), len(paper_list_1)]

In [None]:
# tr[a][b] counts the edges going from a label-a paper to a label-b paper.
tr = defaultdict(lambda: defaultdict(lambda: 0))
for src, dst in np.array(edges).reshape(-1, 2):
    tr[labels[src]][labels[dst]] += 1

In [None]:
# Label-0 -> label-0 edge count divided by len(paper_list_0) squared.
tr[0][0] / len(paper_list_0) / len(paper_list_0)

In [None]:
# Label-0 -> label-1 edge count divided by the product of the two class sizes.
tr[0][1] / len(paper_list_0) / len(paper_list_1)

In [None]:
# Label-1 -> label-1 edge count divided by len(paper_list_1) squared.
tr[1][1] / len(paper_list_1) / len(paper_list_1)

In [None]:
# Raw label-1 -> label-1 edge count.
tr[1][1]

In [None]:
# Number of ordered label-1 paper pairs (the denominator of the label-1 density).
len(paper_list_1) * len(paper_list_1)