In [126]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [None]:
# Run configuration.
# NOTE(review): absolute, machine-specific data path — consider making it configurable.
data_dir = '/datadrive/data/'
batch_size = 512   # number of papers drawn per sampled mini-batch (see ad_sample)
batch_num  = 128   # batch_num // 8 sampling jobs are queued per epoch
epoch_num  = 100
samp_num   = 7     # not used in the cells visible here

device = torch.device("cuda:2")
# Close the file deterministically instead of leaking the handle.
# NOTE(review): dill deserialisation can execute arbitrary code — only load trusted files.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [3]:
# Temporal split over the graph's time stamps (entries that are None are skipped):
#   train: t <= 2015,   valid: 2015 < t < 2018,   test: t >= 2018.
# Each range is a dict used as an ordered set (every value is True).
train_range = dict.fromkeys((t for t in graph.times if t is not None and t <= 2015), True)
valid_range = dict.fromkeys((t for t in graph.times if t is not None and 2015 < t < 2018), True)
test_range  = dict.fromkeys((t for t in graph.times if t is not None and t >= 2018), True)

In [8]:
'''
Author Disambiguation
'''
# NOTE(review): dill deserialisation can execute arbitrary code — only load trusted files.
with open(data_dir + 'author_dict.pk', 'rb') as author_file:
    author_dict = dill.load(author_file)
# Groups of same-name authors (names shared by more than two people), mapped to graph node ids.
ds_authors  = [[graph.node_forward['author'][ai] for ai in author_dict[k]]
               for k in author_dict if len(author_dict[k]) > 2]

# Sizes of every ambiguous-name group (kept for inspection / plotting).
len_author = [len(author_dict[k]) for k in author_dict if len(author_dict[k]) > 1]

# pairs:  paper_id -> list of candidate-author arrays, the true author always first.
# papers: year -> {paper_id: True} for the papers of that split.
train_pairs = {}
valid_pairs = {}
test_pairs  = {}

train_papers = {_time: {} for _time in train_range}
valid_papers = {_time: {} for _time in valid_range}
test_papers  = {_time: {} for _time in test_range}

# A paper is assigned to the first split whose range contains its year.
# Years outside every range (including None) are skipped; the original
# fell through to the test split and crashed on test_papers[_time].
splits = [(train_range, train_pairs, train_papers),
          (valid_range, valid_pairs, valid_papers),
          (test_range,  test_pairs,  test_papers)]

for same_name_author_list in tqdm(ds_authors, total = len(ds_authors)):
    same_name_author_list = np.array(same_name_author_list)
    n_same = len(same_name_author_list)
    for author_id, author in enumerate(same_name_author_list):
        # Candidate list for this author: the author itself first, then every
        # other author sharing the name. Independent of the paper, so built once.
        order = [author_id] + [i for i in range(n_same) if i != author_id]
        al = same_name_author_list[np.array(order)]
        written = graph.edge_list['author']['paper']['rev_AP_write'][author]
        for p_id in written:
            # Bug fix: the original indexed with the undefined name `paper`.
            _time = written[p_id]
            for t_range, pairs, papers in splits:
                if _time in t_range:
                    pairs.setdefault(p_id, []).append(al)
                    papers[_time][p_id] = True
                    break

In [None]:
def ad_sample(seed, papers, pairs, t_range, batch_size, test = False):
    """Sample one author-disambiguation mini-batch together with its subgraph.

    Args:
        seed: seed for numpy's global RNG, so every worker draws a different batch.
        papers: ``{year: {paper_id: True}}`` — the papers of this split, grouped by year.
        pairs: ``{paper_id: [candidate_array, ...]}`` — one array of author node ids per
            ambiguous author slot of the paper; entry 0 of each array is the true author.
        t_range: time range forwarded to ``sample_subgraph``.
        batch_size: maximum number of papers drawn (all from one randomly chosen year).
        test: unused; kept for call compatibility.

    Returns:
        ``(node_feature, node_type, edge_time, edge_index, edge_type, author_ids,
        paper_ids, ylabel)`` where ``author_ids`` / ``paper_ids`` are the node indices of
        the sampled authors / papers and ``ylabel`` maps a paper node index to a list of
        candidate-author node-index arrays (true author first).

    Raises:
        ValueError: if ``papers`` contains no paper at all.
    """
    np.random.seed(seed)
    # Only years that actually contain papers can be sampled from.
    years = [t for t in papers if len(papers[t]) > 0]
    if not years:
        raise ValueError('ad_sample: no papers to sample from')
    _time = np.random.choice(years)
    # Bug fix: papers[_time] is a dict; np.array(dict) gives a 0-d object array.
    candidates = np.array(list(papers[_time].keys()))
    n_papers = min(batch_size, len(candidates))   # avoid replace=False failing on small years
    pids = candidates[np.random.choice(len(candidates), n_papers, replace = False)]

    # Distinct candidate authors in first-seen order, with O(1) row lookup
    # (replaces the undefined `fids.index(...)`).
    author_row = {}
    # edge[paper_row][author_row]: write edges of the labelled pairs, removed below
    # so the target link is not visible to the GNN (1 = true author, 0 = other candidate).
    edge = defaultdict(dict)
    eals = []   # [(paper_row, [author_row, ...]), ...], true author first
    for x_id, p_id in enumerate(pids):
        for al in pairs[p_id]:
            eal = []
            for a_id in al:
                if a_id not in author_row:
                    author_row[a_id] = len(author_row)
                eal.append(author_row[a_id])
            edge[x_id][eal[0]] = 1
            for a_row in eal[1:]:
                edge[x_id][a_row] = 0
            eals.append((x_id, eal))
    aids = list(author_row.keys())

    pids = np.stack([pids, np.repeat([_time], len(pids))]).T
    aids = np.stack([aids, np.repeat([_time], len(aids))]).T

    feature, times, edge_list = sample_subgraph(graph, t_range, \
                inp = {'paper': pids, 'author': aids}, sampled_depth = 4, sampled_number = 128)

    # NOTE(review): assumes edge_list entries are [target_index, source_index] and that the
    # input nodes keep their row order within each type, as the original code did — confirm
    # against sample_subgraph.
    edge_list['paper']['author']['AP_write'] = [
        i for i in edge_list['paper']['author']['AP_write']
        if not (i[0] in edge and i[1] in edge[i[0]])]
    edge_list['author']['paper']['rev_AP_write'] = [
        i for i in edge_list['author']['paper']['rev_AP_write']
        if not (i[1] in edge and i[0] in edge[i[1]])]

    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list)
    '''
        Trace the paper_id and author_id by its own index plus the type start index
    '''
    paper_start = node_dict['paper'][0]
    author_start = node_dict['author'][0]
    paper_ids = np.arange(len(pids)) + paper_start
    author_ids = np.arange(len(aids)) + author_start
    # Bug fix: the original iterated (x_id, eal) pairs as ints and offset author rows
    # by the *paper* start index.
    ylabel = defaultdict(list)
    for x_id, eal in eals:
        ylabel[x_id + paper_start].append(np.array(eal) + author_start)
    return node_feature, node_type, edge_time, edge_index, edge_type, author_ids, paper_ids, dict(ylabel)
    
def prepare_data(pool, process_ids):
    """Queue asynchronous batch sampling on ``pool``.

    One training job is submitted for every entry of ``process_ids`` except the
    last, followed by a single validation job; each job gets a fresh random seed.

    Returns:
        List of ``AsyncResult`` handles; the final handle is the validation batch.
    """
    def submit(papers, pairs, t_range):
        seed = np.random.randint(2**32 - 1)
        return pool.apply_async(ad_sample, args=(seed, papers, pairs, t_range, batch_size))

    jobs = [submit(train_papers, train_pairs, train_range) for _ in process_ids[:-1]]
    jobs.append(submit(valid_papers, valid_pairs, valid_range))
    return jobs

In [204]:
# Encoder (GNN) and matching head (Matcher) wrapped in one module so a single
# optimizer/scheduler drives both; gnn and matcher alias the two sub-modules.
model = nn.Sequential(GNN(400), Matcher(200, n_heads = 8)).to(device)
gnn, matcher = model[0], model[1]
optimizer = torch.optim.Adam(params=model.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000, eta_min=0)

In [202]:
def mask_softmax(pred, size):
    """Grouped softmax loss where the first entry of each group is the positive.

    ``pred`` is consumed as consecutive groups whose lengths are given by ``size``.
    For each group, the log-softmax score of its first element is divided by the
    group length. The negated sum over all groups is returned (``0`` for no groups).
    """
    total = 0
    offset = 0
    for group_len in size:
        chunk = pred[offset: offset + group_len]
        total = total + torch.log_softmax(chunk, dim=-1)[0] / group_len
        offset += group_len
    return -total

In [None]:
import os


def _pair_vectors(node_rep, ylabel):
    """Gather the (paper, candidate-author) representation pairs of one batch.

    ``ylabel`` maps a paper's node index to a list of candidate-author node-index
    arrays whose first entry is the true author. Returns the stacked paper vectors,
    the stacked author vectors and the candidate-group sizes, all in matching order.
    """
    paper_vecs, author_vecs, sizes = [], [], []
    for x_id, candidate_lists in ylabel.items():
        for al in candidate_lists:
            paper_vecs.append(node_rep[x_id].repeat(len(al), 1))
            author_vecs.append(node_rep[al])
            sizes.append(len(al))
    return torch.cat(paper_vecs), torch.cat(author_vecs), sizes


def _evaluate(data):
    """Score one sampled batch; return ``(loss, mean NDCG)`` with candidate 0 as the positive."""
    node_feature, node_type, edge_time, edge_index, edge_type, author_ids, paper_ids, ylabel = data
    node_rep = gnn.forward(node_feature.to(device), edge_index.to(device))
    paper_vecs, author_vecs, sizes = _pair_vectors(node_rep, ylabel)
    scores = matcher.forward(author_vecs, paper_vecs, pair=True)
    loss = mask_softmax(scores, sizes)
    ranked = []
    offset = 0
    for s in sizes:
        # One label tensor per group (the original sliced the label list by element offset).
        label = torch.zeros(s)
        label[0] = 1
        # Move the ranking to CPU so it can index the CPU label tensor.
        order = scores[offset: offset + s].argsort(descending = True).cpu()
        ranked.append(label[order].tolist())
        offset += s
    ndcg = np.average([ndcg_at_k(r, len(r)) for r in ranked])
    return loss, ndcg


stats = []
pool = mp.Pool(16)
process_ids = np.arange(batch_num // 8)
st = time.time()
# The original cell used an undefined `test_data`; sample one held-out test batch up front.
test_data = pool.apply_async(ad_sample, args=(np.random.randint(2**32 - 1), test_papers,
                                              test_pairs, test_range, batch_size)).get()
jobs = prepare_data(pool, process_ids)
train_step = 3000
best_val   = 0
os.makedirs('./save', exist_ok=True)   # torch.save below fails if the directory is missing
for epoch in range(1, epoch_num + 1):
    '''
        Prepare Training and Validation Data
    '''
    # Collect the batches queued last epoch, then immediately queue the next epoch's.
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.terminate()
    pool.join()
    pool = mp.Pool(16)
    jobs = prepare_data(pool, process_ids)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))

    '''
        Train
    '''
    model.train()
    train_losses = []
    for _ in range(8):
        for node_feature, node_type, edge_time, edge_index, edge_type, author_ids, paper_ids, ylabel in train_data:
            if not ylabel:   # nothing to score in this batch
                continue
            node_rep = gnn.forward(node_feature.to(device), edge_index.to(device))
            # Bug fix: the original overwrote each candidate list with the undefined `train_pairs[_id]`.
            train_paper_vecs, train_author_vecs, train_size = _pair_vectors(node_rep, ylabel)
            res = matcher.forward(train_author_vecs, train_paper_vecs, pair=True)
            loss = mask_softmax(res, train_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
            train_step += 1
            # NOTE(review): passing the step count is deprecated in PyTorch; kept to preserve the schedule.
            scheduler.step(train_step)
    '''
        Valid
    '''
    model.eval()
    with torch.no_grad():
        valid_loss, valid_ndcg = _evaluate(valid_data)
        if valid_ndcg > best_val:
            best_val = valid_ndcg
            torch.save(model, './save/gat.pt')
        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), valid_loss.item(),\
              valid_ndcg))

        if epoch % 5 == 0:
            '''
                Test
            '''
            test_loss, test_ndcg = _evaluate(test_data)
            print(test_ndcg)

# Release the worker pool still holding the unused next-epoch jobs.
pool.terminate()
pool.join()