In [1]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [2]:
data_dir = '/datadrive/data_cs/'
batch_size = 256
batch_num  = 128
epoch_num  = 100
samp_num   = 7

device = torch.device("cuda:3")
graph = dill.load(open(data_dir + 'graph.pk', 'rb'))

In [3]:
train_range = {t: True for t in graph.times if t != None and t < 2015}
valid_range = {t: True for t in graph.times if t != None and t >= 2015  and t <= 2016}
test_range  = {t: True for t in graph.times if t != None and t > 2016}

In [13]:
def pf_sample(seed, papers, pairs, t_range, batch_size, test = False):
    np.random.seed(seed)
    _time = np.random.choice(list(papers.keys()))
    sampn = min(len(papers[_time]), batch_size)
    pids = np.array(papers[_time])[np.random.choice(len(papers[_time]), sampn, replace = False)]
    fids = []
    edge = defaultdict(lambda: {})
    for x_id, p_id in enumerate(pids):
        f_ids = pairs[p_id]
        for f_id in f_ids:
            if f_id not in fids:
                fids += [f_id]
    pids = np.stack([pids, np.repeat([_time], sampn)]).T
    fids = np.stack([fids, np.repeat([_time], len(fids))]).T
 
    feature, times, edge_list, _, _ = sample_subgraph(graph, t_range, \
                inp = {'paper': pids, 'field': fids}, sampled_depth = 4, sampled_number = 100)

    el = []
    for i in edge_list['paper']['field']['rev_PF_in_L2']:
        if i[0] < len(pids):
            continue
        el += [i]
    edge_list['paper']['field']['rev_PF_in_L2'] = el

    el = []
    for i in edge_list['field']['paper']['PF_in_L2']:
        if i[1] < len(pids):
            continue
        el += [i]
    print(len(edge_list['field']['paper']['PF_in_L2']), len(el))
    edge_list['field']['paper']['PF_in_L2'] = el
    
    
    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list, graph)
    '''
        Trace the paper_id and field_id by its own index plus the type start index
    '''
    paper_ids = np.arange(len(pids)) + node_dict['paper'][0]
    field_ids = np.arange(len(fids)) + node_dict['field'][0]
    ylabel = torch.zeros(sampn, len(cand_list))
    for x_id, p_id in enumerate(pids[:,0]):
        for f_id in pairs[p_id]:
            ylabel[x_id][cand_list.index(f_id)] = 1
    ylabel /= ylabel.sum(axis=1).view(-1, 1)
    return node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel, field_ids
    
def prepare_data(pool, process_ids, train_papers, valid_papers):
    jobs = []
    for process_id in process_ids[:-1]:
        p = pool.apply_async(pf_sample, args=(np.random.randint(2**32 - 1), train_papers, \
                                               train_pairs, train_range, batch_size))
        jobs.append(p)
    p = pool.apply_async(pf_sample, args=(np.random.randint(2**32 - 1), valid_papers, \
                                           valid_pairs, valid_range, batch_size))
    jobs.append(p)
    return jobs

In [5]:
class GNN(nn.Module):
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        super(GNN, self).__init__()
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim    = in_dim
        self.n_hid     = n_hid
        self.aggregat_ws   = nn.ModuleList()
        self.drop          = nn.Dropout(dropout)
        for t in range(num_types):
            self.aggregat_ws.append(nn.Linear(in_dim, n_hid))
        for l in range(n_layers):
            self.gcs.append(RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout))
    def set_device(self, device):
        self.device = device
        for gc in self.gcs:
            gc.device = device
    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        res = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device)
        for t_id in range(self.num_types):
            aggregat_w = self.aggregat_ws[t_id]
            idx = (node_type == t_id)
            if idx.sum() == 0:
                continue
            res[idx] = torch.tanh(aggregat_w(node_feature[idx]))
        meta_xs = self.drop(res)
        del res
        for gc in self.gcs:
            meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs

In [14]:
'''
Paper-Field
'''
paper_ser = {}

train_pairs = {}
valid_pairs = {}
test_pairs  = {}

train_papers = {_time: {} for _time in train_range}
valid_papers = {_time: {} for _time in valid_range}
test_papers  = {_time: {} for _time in test_range}

for f_id in graph.edge_list['field']['paper']['PF_in_L2']:
    for p_id in graph.edge_list['field']['paper']['PF_in_L2'][f_id]:
        _time = graph.edge_list['field']['paper']['PF_in_L2'][f_id][p_id]
        if _time in train_range:
            if p_id not in train_pairs:
                train_pairs[p_id] = []
            train_pairs[p_id] += [f_id]
            train_papers[_time][p_id] = True
        elif _time in valid_range:
            if p_id not in valid_pairs:
                valid_pairs[p_id] = []
            valid_pairs[p_id] += [f_id]
            valid_papers[_time][p_id] = True
        else:
            if p_id not in test_pairs:
                test_pairs[p_id] = []
            test_pairs[p_id] += [f_id]
            test_papers[_time][p_id] = True
for _time in list(train_papers.keys()):
    if len(train_papers[_time]) < batch_size // 2:
        train_papers.pop(_time)
    else:
        train_papers[_time] = np.array(list(train_papers[_time].keys()))
for _time in list(valid_papers.keys()):
    if len(valid_papers[_time]) < batch_size // 2:
        valid_papers.pop(_time)
    else:
        valid_papers[_time] = np.array(list(valid_papers[_time].keys()))
for _time in list(test_papers.keys()):
    if len(test_papers[_time]) < batch_size // 2:
        test_papers.pop(_time)
    else:
        test_papers[_time] = np.array(list(test_papers[_time].keys()))

In [15]:
types = graph.get_types()
cand_list = list(graph.edge_list['field']['paper']['PF_in_L2'])
gnn = GNN(in_dim = len(graph.node_feature['paper']['emb'][0]) + 401, n_hid = 256, num_types = len(types), \
          num_relations = len(graph.get_meta_graph()) + 1, n_heads = 8, n_layers = 4).to(device)
# gnn = torch.load('../pre-train/save/mt_model.pt').to(device)
classifier = Classifier(256, len(cand_list)).to(device)
model = nn.Sequential(gnn, classifier)
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000, eta_min=1e-6)

In [16]:
sel_train_papers = {}
for _time in train_papers:
    pid = train_papers[_time]
    pid = pid[np.random.choice(np.arange(len(pid)), len(pid) // 1, replace=False)]
    if len(pid) >= batch_size // 2:
        sel_train_papers[_time] = pid

In [17]:
stats = []
res = []
best_val   = 0

pool = mp.Pool(4)
process_ids = np.arange(batch_num // 4)
st = time.time()
jobs = prepare_data(pool, process_ids, sel_train_papers, valid_papers)
train_step = 1500


criterion = nn.KLDivLoss(reduction='batchmean')
for epoch in np.arange(200)+1:
    '''
        Prepare Training and Validation Data
    '''
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    pool = mp.Pool(4)
    jobs = prepare_data(pool, process_ids, sel_train_papers, valid_papers)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))
    
    model.train()
    train_losses = []
    torch.cuda.empty_cache()
    for batch in np.arange(2):
        for node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel in train_data:
            node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                   edge_time.to(device), edge_index.to(device), edge_type.to(device))
            res  = classifier.forward(node_rep[paper_ids])
            loss = criterion(res, ylabel.to(device))
            optimizer.zero_grad() 
            torch.cuda.empty_cache()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
            optimizer.step()
            train_losses += [loss.cpu().detach().tolist()]
            train_step += 1
            scheduler.step(train_step)
            del res, loss
    '''
        Valid
    '''
    model.eval()
    with torch.no_grad():
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = valid_data
        node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                                   edge_time.to(device), edge_index.to(device), edge_type.to(device))
        res  = classifier.forward(node_rep[paper_ids])
        loss = criterion(res, ylabel.to(device))
        valid_res = []

        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            valid_res += [ai[bi].tolist()]
        valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])
        if valid_ndcg > best_val:
            best_val = valid_ndcg
            torch.save(model, './save/rgt_3.pt')
        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), loss.cpu().detach().tolist(),\
              valid_ndcg))
        stats += [[np.average(train_losses), loss.cpu().detach().tolist()]]
        del res, loss
        if epoch % 5 == 0:
            '''
                Test
            '''
            _time = np.random.choice(list(test_papers.keys()))
            node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                           test_pairs, test_range, batch_size, test=True)
            paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
            res  = classifier.forward(paper_rep)
            test_res = []
            for ai, bi in zip(ylabel, res.argsort(descending = True)):
                test_res += [ai[bi].tolist()]
            test_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in test_res])
            print(test_ndcg)
            del res
    del train_data, valid_data

2361 1167
2153 1020
2279 1086
2234 1098
1781 910
2237 1068
2198 1011
2258 1076
2095 982
2124 1029
2065 979
2142 1007
2314 1138
2061 1004
2296 1059
2256 1110
2121 999
2234 1089
1760 964
1754 884
1769 944
1351 743
2096 979
2334 1124
2210 1043
2234 1115
2108 985
2091 1028
1797 898
2154 997
2242 1076
2182 1058
Data Preparation: 305.2s


ValueError: too many values to unpack (expected 8)

2143 1013
2033 987
2334 1141
2246 1108
2079 1026
2131 1025
2144 994
2428 1184
1883 991
2206 1087
2448 1201
2140 1017
2095 988
2239 1076
2046 980
2164 1047
2353 1108
2443 1176
1772 905
1967 942
2049 916
2049 950
2296 1130
2148 1009
2073 958
2395 1125
1960 955
2123 1034
2036 994
2446 1216
2201 1050
2050 965


In [18]:
print(1)

1


In [None]:
stats = np.array(stats)
plt.plot(stats[:,0])
plt.plot(stats[:,1])
plt.show()

In [None]:
model.eval()
gnn, classifier = model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        _time = np.random.choice(list(test_papers.keys()))
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))

In [None]:
best_model = torch.load('./save/rgt_3.pt')

In [None]:
best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        _time = np.random.choice(list(test_papers.keys()))
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))

In [None]:
# without pre-train

In [None]:
stats = np.array(stats)
plt.plot(stats[:,0])
plt.plot(stats[:,1])
plt.show()

In [None]:
model.eval()
gnn, classifier = model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        _time = np.random.choice(list(test_papers.keys()))
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))
    
best_model = torch.load('./save/rgt_1.pt')

best_model.eval()
gnn, classifier = best_model
with torch.no_grad():
    test_res = []
    for _ in range(10):
        _time = np.random.choice(list(test_papers.keys()))
        node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = pf_sample(np.random.randint(2 ** 32 - 1), test_papers, \
                                                       test_pairs, test_range, batch_size, test=True)
        paper_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
                      edge_time.to(device), edge_index.to(device), edge_type.to(device))[paper_ids]
        res = classifier.forward(paper_rep)
        for ai, bi in zip(ylabel, res.argsort(descending = True)):
            test_res += [ai[bi].tolist()]
    test_ndcg = [ndcg_at_k(resi, len(resi)) for resi in test_res]
    print(np.average(test_ndcg), np.var(test_ndcg))
    test_mrr = mean_reciprocal_rank(test_res)
    print(np.average(test_mrr), np.var(test_mrr))

In [None]:
l2_fields = graph.node_feature['field'][graph.node_feature['field']['attr'] == 'L2'].index.values

In [None]:
graph.node_feature['field'].iloc[[l2_fields[424]]]

In [None]:
graph.node_feature['field'].iloc[[l2_fields[637]]]

In [None]:
l2_fields[424], l2_fields[637]

In [None]:
papers = []
mxn = 0
for f in l2_fields:
    p = graph.edge_list['field']['paper']['PF_in_L2'][f]
    if len(p) > 100 and len(p) < 10000:
        papers += [p]
len(papers)

In [None]:
citation_matrix = np.zeros([len(papers), len(papers)])

In [None]:
for i1, j1 in enumerate(papers):
    if i1 % 10 == 0:
        print('%d / %d' % (i1, len(papers)))
    for i2, j2 in enumerate(papers):
        if i2 > i1: 
            continue
        cnt = 0
        for p1 in j1:
            for p2 in graph.edge_list['paper']['paper']['PP_cite'][p1]:
                if p2 in j2:
                    cnt += 1
        citation_matrix[i1][i2] = cnt / len(j1) / len(j2)

In [None]:
possible_ls = []
th = 0.002
for i in range(len(citation_matrix)):
    for j in range(len(citation_matrix)):
        if i >= j:
            inter = citation_matrix[i][j]
            left  = citation_matrix[i][i]
            right = citation_matrix[j][j]
#             if inter>th and left >= inter and right >= 0.7 * left and right <= left and right>=inter:
            if inter>th and left >= inter and right>=inter:
                if len(list(papers[i].keys()))>=2.5*len(list(papers[j].keys())):
                    print(i,j, len(list(papers[i].keys())), len(list(papers[j].keys())))
                    possible_ls.append([i,j])

In [None]:
paper_list_0 = list(papers[637].keys())
paper_list_1 = list(papers[424].keys())
print(len(paper_list_0), len(paper_list_1))
_papers = paper_list_0 + paper_list_1
paper_id = {j:i for i, j in enumerate(_papers)}
paper_labels = {}
for p in paper_list_0:
    paper_labels[p] = 0
for p in paper_list_1:
    paper_labels[p] = 1

In [None]:
_papers = paper_list_0 + paper_list_1
paper_id = {j:i for i, j in enumerate(_papers)}
paper_labels = {}
for p in paper_list_0:
    paper_labels[p] = 0
for p in paper_list_1:
    paper_labels[p] = 1

In [None]:
edges = []
degrees = defaultdict(lambda: 0)
for pi in _papers:
    for pj in  graph.edge_list['paper']['paper']['PP_cite'][pi]:
        if pj in paper_id:
            edges += [paper_id[pi], paper_id[pj]]
            edges += [paper_id[pj], paper_id[pi]]
            degrees[pi] += 1
            degrees[pj] += 1

In [None]:
removes = {}
for p in paper_list_0:
    if degrees[p] <= 5 or degrees[p] > 60:
        removes[p] = True
for p in paper_list_1:
    if degrees[p] <= 20 or degrees[p] > 60:
        removes[p] = True
_papers = [p for p in _papers if p not in removes]
paper_id = {j:i for i, j in enumerate(_papers)}

In [None]:
edges = []
degrees = defaultdict(lambda: 0)
rec = {}
for pi in _papers:
    for pj in graph.edge_list['paper']['paper']['PP_cite'][pi]:
        if pj in paper_id:
            if pi != pj:
                e = [[paper_id[pi], paper_id[pj]]]
                if "%d_%d" % (paper_id[pi], paper_id[pj]) not in rec:
                    edges += [[paper_id[pi], paper_id[pj]]]
                    edges += [[paper_id[pj], paper_id[pi]]]
                    rec["%d_%d" % (paper_id[pi], paper_id[pj])] = True
                    rec["%d_%d" % (paper_id[pj], paper_id[pi])] = True
                    degrees[pi] += 1
                    degrees[pj] += 1

In [None]:
features = np.array(list(graph.node_feature['paper'].loc[_papers, 'emb']))

In [None]:
labels = []
paper_list_0 = []
paper_list_1 = []
for p in _papers:
    labels.append(paper_labels[p])
    if paper_labels[p] == 0:
        paper_list_0 += [p]
    if paper_labels[p] == 1:
        paper_list_1 += [p]
print(len(paper_list_0), len(paper_list_1))

In [None]:
cnt = defaultdict(lambda: 0)
for l in labels:
    cnt[l] += 1
cnt

In [None]:
dill.dump((edges, features, labels), open('./data.pk', 'wb'))

In [None]:
[len(paper_list_0), len(paper_list_1)]

In [None]:
tr = defaultdict(lambda: defaultdict(lambda: 0))
for e in np.array(edges).reshape(-1, 2):
    tr[labels[e[0]]][labels[e[1]]] += 1

In [None]:
tr[0][0] / len(paper_list_0) / len(paper_list_0)

In [None]:
tr[0][1] / len(paper_list_0) / len(paper_list_1)

In [None]:
tr[1][1] / len(paper_list_1) / len(paper_list_1)

In [None]:
tr[1][1]

In [None]:
len(paper_list_1) * len(paper_list_1)