In [1]:
import os
import sys
sys.path.append('../')
from utils import *
from model import *

In [2]:
# --- Run configuration -------------------------------------------------------
# NOTE(review): absolute, machine-specific path; adjust to where graph.pk lives.
data_dir = '/datadrive/data_cs/'
batch_size = 64    # number of papers drawn per sampled batch (see ad_sample)
batch_num  = 128   # the training cell prefetches batch_num // 4 batches per epoch
epoch_num  = 100   # number of epochs run by the training loop
samp_num   = 3     # NOTE(review): not referenced by any of the cells visible here

device = torch.device("cuda:0")

# Load the pre-processed graph. The file is opened in a `with` block so the
# handle is closed as soon as deserialisation finishes.
# SECURITY NOTE: dill/pickle deserialisation can execute arbitrary code; only
# load graph files that come from a trusted source.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [3]:
# Chronological split of the time stamps found in graph.times:
#   train: before 2015, valid: 2015-2016 inclusive, test: after 2016.
# Each range is a dict used as a set (time -> True); missing (None) time stamps
# are dropped. Identity comparison is the correct way to test for None.
train_range = {t: True for t in graph.times if t is not None and t < 2015}
valid_range = {t: True for t in graph.times if t is not None and 2015 <= t <= 2016}
test_range  = {t: True for t in graph.times if t is not None and t > 2016}

In [4]:
# Group author rows by their 'name' field: name -> list of author row indices.
# Names that map to several indices are the ambiguous ones used later on.
name_count = defaultdict(list)
author_rows = graph.node_feature['author']
for author_idx, author_row in tqdm(author_rows.iterrows(), total = len(author_rows)):
    name_count[author_row['name']].append(author_idx)

HBox(children=(IntProgress(value=0, max=510189), HTML(value='')))




In [5]:
# Author Disambiguation
#
# For every name shared by more than four author nodes, and for every paper
# linked to one of those authors through 'rev_AP_write_first', record a
# candidate array: the linked author first, followed by all other authors with
# the same name. Papers are bucketed into train / valid / test by their time.
#
#   *_pairs  : paper id -> list of candidate arrays (true author at position 0)
#   *_papers : time -> papers seen at that time (dict used as a set while
#              building, converted to a list at the end)
train_pairs = {}
valid_pairs = {}
test_pairs  = {}

train_papers = {_time: {} for _time in train_range}
valid_papers = {_time: {} for _time in valid_range}
test_papers  = {_time: {} for _time in test_range}

# Looked up once instead of on every inner-loop iteration.
first_author_edges = graph.edge_list['author']['paper']['rev_AP_write_first']

for name in name_count:
    same_name_author_list = np.array(name_count[name])
    if len(same_name_author_list) <= 4:
        continue
    for author_id, author in enumerate(same_name_author_list):
        # The candidate array depends only on the author, not on the paper, so
        # it is built once per author. The same (read-only) array object is
        # then shared by all of this author's papers.
        pem_ids = [k for k in range(len(same_name_author_list)) if k != author_id]
        al = same_name_author_list[np.array([author_id] + pem_ids)]

        papers_of_author = first_author_edges[author]
        for p_id in papers_of_author:
            _time = papers_of_author[p_id]
            # Only plain Python ints are accepted as time stamps.
            if type(_time) != int:
                continue
            if _time in train_range:
                bucket_pairs, bucket_papers = train_pairs, train_papers
            elif _time in valid_range:
                bucket_pairs, bucket_papers = valid_pairs, valid_papers
            else:
                # Anything that is neither train nor valid is treated as test.
                # NOTE(review): this assumes such a time is a key of test_range;
                # otherwise the lookup below raises KeyError -- confirm.
                bucket_pairs, bucket_papers = test_pairs, test_papers
            bucket_pairs.setdefault(p_id, []).append(al)
            bucket_papers[_time][p_id] = True

# Keep only time steps with at least batch_size papers, because ad_sample draws
# batch_size papers from a single time step without replacement.
train_papers = {k:list(train_papers[k].keys()) for k in train_papers if len(train_papers[k]) >= batch_size}
valid_papers = {k:list(valid_papers[k].keys()) for k in valid_papers if len(valid_papers[k]) >= batch_size}
test_papers  = {k:list(test_papers[k].keys()) for k in test_papers if len(test_papers[k]) >= batch_size}

In [6]:
def ad_sample(seed, papers, pairs, t_range, batch_size, test = False):
    """
    Sample one author-disambiguation batch at a random time step and convert
    the sampled sub-graph to tensors.

    Args:
        seed: value passed to np.random.seed, so a call is reproducible for a
            given seed (it reseeds NumPy's global RNG in the calling process).
        papers: dict time -> list of paper ids; one time is drawn at random and
            batch_size papers are drawn from it without replacement.
        pairs: dict paper id -> list of candidate author arrays. Element 0 of
            each array is labelled 1 (the linked author), the rest 0.
        t_range: forwarded unchanged to sample_subgraph.
        batch_size: number of papers in the batch.
        test: unused; kept for interface compatibility.

    Returns:
        node_feature, node_type, edge_time, edge_index, edge_type: as returned
            by to_torch.
        author_ids, paper_ids: positions of the seed authors / papers, computed
            as their local index plus node_dict[type][0].
        ylabel: dict mapping a seed paper's position to an array of candidate
            author positions, with the linked author first.
    """
    np.random.seed(seed)
    _time = np.random.choice(list(papers.keys()))
    pids = np.array(papers[_time])[np.random.choice(len(papers[_time]), batch_size, replace = False)]

    aids = []                   # distinct candidate authors, in first-seen order
    aid_index = {}              # author id -> position in `aids` (O(1) lookup)
    edge = defaultdict(dict)    # local paper idx -> {local author idx: 0/1 label}
    eals = []                   # [local paper idx, [local author idx, ...]] per candidate array
    for x_id, p_id in enumerate(pids):
        for al in pairs[p_id]:
            for a_id in al:
                if a_id not in aid_index:
                    aid_index[a_id] = len(aids)
                    aids.append(a_id)
            eal = [aid_index[a_id] for a_id in al]
            # First candidate is the positive, all remaining ones are negatives.
            edge[x_id][eal[0]] = 1
            for neg_idx in eal[1:]:
                edge[x_id][neg_idx] = 0
            eals.append([x_id, eal])

    # Seed nodes are passed to the sampler as (id, time) rows.
    pids = np.stack([pids, np.repeat([_time], batch_size)]).T
    aids = np.stack([aids, np.repeat([_time], len(aids))]).T

    feature, times, edge_list, _ = sample_subgraph(graph, t_range, \
                inp = {'paper': pids, 'author': aids}, sampled_depth = 3, sampled_number = 100)

    # Remove every paper<->candidate-author first-author edge from the sampled
    # graph so the model cannot read the answer off the graph structure.
    # NOTE(review): this relies on the seed papers/authors keeping local indices
    # 0..n-1 in input order, and on edge entries being indexed as
    # [paper, author] for AP_write_first and [author, paper] for the reverse
    # relation -- confirm against sample_subgraph.
    edge_list['paper']['author']['AP_write_first'] = [
        i for i in edge_list['paper']['author']['AP_write_first']
        if not (i[0] in edge and i[1] in edge[i[0]])
    ]
    edge_list['author']['paper']['rev_AP_write_first'] = [
        i for i in edge_list['author']['paper']['rev_AP_write_first']
        if not (i[1] in edge and i[0] in edge[i[1]])
    ]

    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list, graph)

    # Trace the paper and author positions by their own local index plus the
    # start index of their node type.
    paper_ids = np.arange(len(pids)) + node_dict['paper'][0]
    author_ids = np.arange(len(aids)) + node_dict['author'][0]
    ylabel = {}
    for x_id, eal in eals:
        ylabel[x_id + node_dict['paper'][0]] = np.array(eal) + node_dict['author'][0]
    return node_feature, node_type, edge_time, edge_index, edge_type, author_ids, paper_ids, ylabel
    
def prepare_data(pool, process_ids):
    """
    Queue asynchronous sampling jobs on `pool`.

    One training job is submitted for every entry of process_ids except the
    last; a single validation job is submitted afterwards. Each job gets its
    own random seed drawn from NumPy's global RNG.

    Returns the list of AsyncResult handles: training jobs first, the
    validation job last.
    """
    train_jobs = [
        pool.apply_async(ad_sample, args=(np.random.randint(2**32 - 1), train_papers,
                                          train_pairs, train_range, batch_size))
        for _ in process_ids[:-1]
    ]
    valid_job = pool.apply_async(ad_sample, args=(np.random.randint(2**32 - 1), valid_papers,
                                                  valid_pairs, valid_range, batch_size))
    return train_jobs + [valid_job]

In [7]:
class GNN(nn.Module):
    """
    Graph encoder: a per-node-type linear projection followed by a stack of
    RAGCNConv layers (RAGCNConv is defined outside this cell).
    """

    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        """
        in_dim / n_hid: input feature size and hidden size.
        num_types: number of node types; one input projection is created per type.
        num_relations, n_heads, dropout: forwarded to every RAGCNConv layer.
        n_layers: number of RAGCNConv layers.
        """
        super().__init__()
        self.num_types = num_types
        self.in_dim    = in_dim
        self.n_hid     = n_hid
        # Sub-modules are registered in the order gcs, aggregat_ws, drop; the
        # projections are constructed before the convolution layers.
        self.gcs         = nn.ModuleList()
        self.aggregat_ws = nn.ModuleList([nn.Linear(in_dim, n_hid) for _ in range(num_types)])
        self.drop        = nn.Dropout(dropout)
        self.gcs.extend([
            RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout)
            for _ in range(n_layers)
        ])

    def set_device(self, device):
        """Store `device` on this module and on each convolution layer."""
        self.device = device
        for layer in self.gcs:
            layer.device = device

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        """
        Project each node with the linear layer of its type (tanh activation),
        apply dropout, then run the result through all convolution layers.
        Nodes whose type id is outside range(num_types) keep a zero vector
        before the convolutions.
        """
        hidden = torch.zeros(node_feature.size(0), self.n_hid, device = node_feature.device)
        for t_id, projection in enumerate(self.aggregat_ws):
            of_type = (node_type == t_id)
            if of_type.any():
                hidden[of_type] = torch.tanh(projection(node_feature[of_type]))
        meta_xs = self.drop(hidden)
        for layer in self.gcs:
            meta_xs = layer(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs

In [8]:
# Build the encoder, the paper/author matcher, and the optimisation objects.
types = graph.get_types()
gnn = GNN(
    # Paper embedding width plus 401 extra feature dimensions.
    # NOTE(review): the origin of the 401 is not visible in this notebook -- confirm.
    in_dim = len(graph.node_feature['paper']['emb'][0]) + 401,
    n_hid = 256,
    num_types = len(types),
    num_relations = len(graph.get_meta_graph()) + 1,
    n_heads = 8,
    n_layers = 3,
).to(device)
matcher = Matcher(256, n_heads = 8).to(device)
# nn.Sequential only bundles both modules (parameters, train/eval mode,
# checkpointing); the training loop calls gnn and matcher individually.
model = nn.Sequential(gnn, matcher)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 1000, eta_min = 0)

In [9]:
def mask_softmax(pred, size):
    """
    Group-wise softmax loss over consecutive segments of `pred`.

    `pred` is split into consecutive segments whose lengths are given by
    `size`. For each segment a log-softmax is taken and the log-probability of
    the segment's first entry, divided by the segment length, is accumulated.
    The negated sum is returned (plain 0 when `size` is empty).
    """
    total = 0
    offset = 0
    for seg_len in size:
        segment = pred[offset: offset + seg_len]
        # Entry 0 of every segment is the positive candidate.
        total = total + torch.log_softmax(segment, dim=-1)[0] / seg_len
        offset = offset + seg_len
    return -total

In [10]:
stats = []
pool = mp.Pool(4)
process_ids = np.arange(batch_num // 4)
st = time.time()
jobs = prepare_data(pool, process_ids)
# The step counter handed to the scheduler deliberately starts at 1500, i.e.
# part-way through the cosine schedule.
train_step = 1500
best_val   = 0
res = []
# NOTE(review): `criterion` is never used below; the loss comes from mask_softmax.
criterion = nn.KLDivLoss(reduction='batchmean')

# torch.save below writes into ./save; create it up front so the first improved
# epoch does not fail after several minutes of training.
os.makedirs('./save', exist_ok=True)

try:
    for epoch in np.arange(epoch_num)+1:
        # ---- Prepare training and validation data ----
        # Collect the batches prefetched during the previous epoch (all jobs but
        # the last are training batches, the last one is the validation batch),
        # then recycle the pool and immediately start prefetching the next epoch.
        train_data = [job.get() for job in jobs[:-1]]
        valid_data = jobs[-1].get()
        pool.terminate()
        pool.join()
        pool = mp.Pool(4)
        jobs = prepare_data(pool, process_ids)
        et = time.time()
        print('Data Preparation: %.1fs' % (et - st))

        # ---- Train ----
        model.train()
        train_losses = []
        torch.cuda.empty_cache()
        # Every prefetched batch is reused for four passes.
        for batch in np.arange(4):
            for node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel in train_data:
                # Modules are called directly (not via .forward) so hooks run.
                node_rep = gnn(node_feature.to(device), node_type.to(device), \
                               edge_time.to(device), edge_index.to(device), edge_type.to(device))
                train_paper_vecs = []
                train_author_vecs = []
                train_size  = []
                # For each paper, pair its representation with every candidate
                # author representation (true author first).
                for p_id in ylabel:
                    al = ylabel[p_id]
                    train_paper_vecs +=  [node_rep[p_id].repeat(len(al), 1)]
                    train_author_vecs += [node_rep[al]]
                    train_size  += [len(al)]
                train_paper_vecs  = torch.cat(train_paper_vecs).to(device)
                train_author_vecs = torch.cat(train_author_vecs).to(device)
                res = matcher(train_author_vecs, train_paper_vecs, pair=True)
                loss = mask_softmax(res, train_size)

                optimizer.zero_grad()
                torch.cuda.empty_cache()
                loss.backward()
                optimizer.step()
                train_losses.append(loss.item())
                train_step += 1
                # Explicit step index: the learning rate follows the cosine
                # curve evaluated at train_step.
                scheduler.step(train_step)
                del loss, res, node_rep

        # ---- Valid ----
        model.eval()
        with torch.no_grad():
            node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = valid_data
            node_rep = gnn(node_feature.to(device), node_type.to(device), \
                           edge_time.to(device), edge_index.to(device), edge_type.to(device))
            valid_paper_vecs = []
            valid_author_vecs = []
            valid_size  = []
            valid_label = []
            for p_id in ylabel:
                al = ylabel[p_id]
                valid_paper_vecs +=  [node_rep[p_id].repeat(len(al), 1)]
                valid_author_vecs += [node_rep[al]]
                valid_size  += [len(al)]
                # One-hot relevance vector: the first candidate is the true author.
                label = torch.zeros(len(al))
                label[0] = 1
                valid_label += [label]
            valid_paper_vecs  = torch.cat(valid_paper_vecs).to(device)
            valid_author_vecs = torch.cat(valid_author_vecs).to(device)
            res = matcher(valid_author_vecs, valid_paper_vecs, pair=True)
            loss = mask_softmax(res, valid_size)
            valid_res = []
            ser = 0
            for s, l in zip(valid_size, valid_label):
                p = res[ser: ser + s]
                # The labels live on the CPU while the scores are on `device`;
                # move the ranking indices to the CPU before indexing, since
                # indexing a CPU tensor with CUDA indices is an error in recent
                # PyTorch releases.
                r = l[p.argsort(descending = True).cpu()]
                valid_res += [r.tolist()]
                ser += s
            valid_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])
            # Checkpoint whenever the validation NDCG improves.
            if valid_ndcg > best_val:
                best_val = valid_ndcg
                torch.save(model, './save/gat.pt')
            st = time.time()
            print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
                  (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), loss.item(),\
                  valid_ndcg))

            # NOTE(review): the disabled test evaluation below reads `test_data`,
            # which none of the visible cells produce (prepare_data only queues
            # training and validation jobs).
#             if epoch % 5 == 0:
#                 '''
#                     Test
#                 '''
#                 node_feature, node_type, edge_time, edge_index, edge_type, field_ids, paper_ids, ylabel = test_data
#                 node_rep = gnn.forward(node_feature.to(device), node_type.to(device), \
#                                        edge_time.to(device), edge_index.to(device), edge_type.to(device))
#                 test_paper_vecs = []
#                 test_author_vecs = []
#                 test_size  = []
#                 test_label = []
#                 for p_id in ylabel:
#                     al = ylabel[p_id]
#                     test_paper_vecs +=  [node_rep[p_id].repeat(len(al), 1)]
#                     test_author_vecs += [node_rep[al]]
#                     test_size  += [len(al)]
#                     label = torch.zeros(len(al))
#                     label[0] = 1
#                     test_label += [label]
#                 test_paper_vecs  = torch.cat(test_paper_vecs).to(device)
#                 test_author_vecs = torch.cat(test_author_vecs).to(device)
#                 res = matcher.forward(test_author_vecs, test_paper_vecs, pair=True)
#                 loss = mask_softmax(res, test_size)
#                 test_res = []
#                 ser = 0
#                 for s, l in zip(test_size, test_label):
#                     p = res[ser: ser + s]
#                     r = l[p.argsort(descending = True)]
#                     test_res += [r.cpu().detach().tolist()]
#                     ser += s
#                 test_ndcg = np.average([ndcg_at_k(resi, len(resi)) for resi in test_res])
#                 print(test_ndcg)
finally:
    # Always stop the prefetch workers, including after an error or a
    # KeyboardInterrupt, so no sampling processes are left running.
    pool.terminate()
    pool.join()

Data Preparation: 272.3s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 1 (642.9s)  LR: 0.00069 Train Loss: 25.74  Valid Loss: 8.57  Valid NDCG: 0.7529
Data Preparation: 2.7s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 2 (670.7s)  LR: 0.00085 Train Loss: 12.01  Valid Loss: 6.40  Valid NDCG: 0.8881
Data Preparation: 2.9s
Epoch: 3 (671.2s)  LR: 0.00096 Train Loss: 5.18  Valid Loss: 3.98  Valid NDCG: 0.8709
Data Preparation: 2.9s
Epoch: 4 (673.6s)  LR: 0.00100 Train Loss: 4.05  Valid Loss: 3.13  Valid NDCG: 0.8695
Data Preparation: 2.6s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 5 (671.0s)  LR: 0.00096 Train Loss: 3.74  Valid Loss: 3.37  Valid NDCG: 0.9019
Data Preparation: 2.6s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 6 (664.1s)  LR: 0.00086 Train Loss: 3.52  Valid Loss: 2.71  Valid NDCG: 0.9366
Data Preparation: 2.9s
Epoch: 7 (663.8s)  LR: 0.00070 Train Loss: 3.49  Valid Loss: 5.29  Valid NDCG: 0.8908
Data Preparation: 2.8s
Epoch: 8 (665.4s)  LR: 0.00051 Train Loss: 3.37  Valid Loss: 2.59  Valid NDCG: 0.9273
Data Preparation: 2.8s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 9 (672.6s)  LR: 0.00032 Train Loss: 3.25  Valid Loss: 2.89  Valid NDCG: 0.9731
Data Preparation: 2.6s
Epoch: 10 (672.8s)  LR: 0.00016 Train Loss: 3.03  Valid Loss: 2.61  Valid NDCG: 0.9225
Data Preparation: 3.1s


  "type " + obj.__name__ + ". It won't be checked "


Epoch: 11 (661.5s)  LR: 0.00004 Train Loss: 2.82  Valid Loss: 2.06  Valid NDCG: 0.9786
Data Preparation: 2.8s
Epoch: 12 (621.1s)  LR: 0.00000 Train Loss: 2.60  Valid Loss: 2.63  Valid NDCG: 0.9219
Data Preparation: 2.7s
Epoch: 13 (618.5s)  LR: 0.00003 Train Loss: 2.66  Valid Loss: 3.71  Valid NDCG: 0.9174
Data Preparation: 2.7s
Epoch: 14 (611.5s)  LR: 0.00013 Train Loss: 2.71  Valid Loss: 2.01  Valid NDCG: 0.9462
Data Preparation: 2.7s
Epoch: 15 (610.3s)  LR: 0.00029 Train Loss: 2.65  Valid Loss: 2.22  Valid NDCG: 0.9295
Data Preparation: 2.7s
Epoch: 16 (615.7s)  LR: 0.00047 Train Loss: 2.60  Valid Loss: 2.60  Valid NDCG: 0.9652
Data Preparation: 2.8s
Epoch: 17 (611.9s)  LR: 0.00067 Train Loss: 2.87  Valid Loss: 2.42  Valid NDCG: 0.9679
Data Preparation: 2.9s
Epoch: 18 (615.2s)  LR: 0.00083 Train Loss: 3.21  Valid Loss: 1.91  Valid NDCG: 0.9610
Data Preparation: 2.9s
Epoch: 19 (613.8s)  LR: 0.00095 Train Loss: 2.64  Valid Loss: 2.44  Valid NDCG: 0.9457
Data Preparation: 2.9s
Epoch: 20 

  "type " + obj.__name__ + ". It won't be checked "


Epoch: 26 (628.8s)  LR: 0.00018 Train Loss: 2.34  Valid Loss: 1.96  Valid NDCG: 0.9922
Data Preparation: 2.8s
Epoch: 27 (630.7s)  LR: 0.00006 Train Loss: 2.57  Valid Loss: 2.56  Valid NDCG: 0.9204
Data Preparation: 3.3s
Epoch: 28 (624.7s)  LR: 0.00000 Train Loss: 2.53  Valid Loss: 2.23  Valid NDCG: 0.9571
Data Preparation: 3.0s
Epoch: 29 (617.7s)  LR: 0.00002 Train Loss: 2.52  Valid Loss: 2.28  Valid NDCG: 0.9634
Data Preparation: 3.3s
Epoch: 30 (621.6s)  LR: 0.00011 Train Loss: 2.10  Valid Loss: 2.80  Valid NDCG: 0.9864
Data Preparation: 2.9s
Epoch: 31 (628.5s)  LR: 0.00026 Train Loss: 2.40  Valid Loss: 2.67  Valid NDCG: 0.9501
Data Preparation: 3.2s
Epoch: 32 (625.7s)  LR: 0.00045 Train Loss: 2.45  Valid Loss: 1.78  Valid NDCG: 0.9452
Data Preparation: 3.1s
Epoch: 33 (628.5s)  LR: 0.00064 Train Loss: 2.24  Valid Loss: 1.81  Valid NDCG: 0.9749
Data Preparation: 3.0s
Epoch: 34 (631.0s)  LR: 0.00081 Train Loss: 2.33  Valid Loss: 1.43  Valid NDCG: 0.9731
Data Preparation: 3.0s
Epoch: 35 

Process ForkPoolWorker-209:
Process ForkPoolWorker-210:
Process ForkPoolWorker-211:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/ziniu/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/ziniu/anaconda3/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/home/ziniu/anaconda3/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ziniu/anaconda3/lib/python3.7/multiprocessing/pool.py", line 121, in worker
    result = (True, func(*args, **kwds))
Process ForkPoolWorker-212:
  File "/home/ziniu/anaconda3/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-6-2fb815af380e>", line 24, in ad_sample
    inp = {'paper': pids, 'author': aids}, sampled_depth = 3, sampled_number = 100)
  File "/ho

KeyboardInterrupt: 

In [None]:
# Scratch check on the most recent validation batch: valid_label and valid_size
# each receive one entry per paper in ylabel, so both lengths should be equal.
len(valid_label), len(valid_size)