In [126]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [None]:
# --- Experiment configuration -------------------------------------------------
# NOTE(review): absolute, machine-specific path; consider making it configurable.
data_dir = '/datadrive/data/'
batch_size = 512   # number of (paper, candidate-authors) pairs sampled per batch
batch_num  = 128   # number of training batches per epoch
epoch_num  = 100   # number of epochs run by the training cell
samp_num   = 7     # not referenced by any of the cells visible in this notebook

# NOTE(review): hard-coded GPU index; this fails on machines with fewer than 3 GPUs.
device = torch.device("cuda:2")

# Load the pickled graph inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE: dill (like pickle) can execute arbitrary code while loading; only load
# files from a trusted source.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [3]:
# Chronological split of the time stamps found in `graph.times`:
#   train: t <= 2015, valid: 2015 < t < 2018, test: t >= 2018.
# Entries whose time is None are excluded from all three ranges.
# The dicts are used purely for membership tests (`t in train_range`).
train_range = {t: True for t in graph.times if t is not None and t <= 2015}
valid_range = {t: True for t in graph.times if t is not None and 2015 < t < 2018}
test_range  = {t: True for t in graph.times if t is not None and t >= 2018}

In [8]:
# Author Disambiguation
#
# Build (paper, candidate-authors) pairs. For every group of authors sharing a
# name and every paper written by one of them, the candidate array lists the
# true author first, followed by the other same-name authors. Pairs are routed
# to train / valid / test according to the time stored on the author->paper edge.

# Close the file handle deterministically.
# NOTE: dill (like pickle) can execute arbitrary code while loading; only load
# files from a trusted source.
with open(data_dir + 'author_dict.pk', 'rb') as author_dict_file:
    author_dict = dill.load(author_dict_file)

# Map each name's authors through graph.node_forward['author'], keeping only
# names shared by more than two authors.
# NOTE(review): this filter uses `> 2` while `len_author` below uses `> 1`;
# confirm which threshold is intended.
ds_authors = [[graph.node_forward['author'][ai] for ai in names]
              for names in author_dict.values() if len(names) > 2]

len_author = [len(names) for names in author_dict.values() if len(names) > 1]
# sb.distplot(np.log(len_author) / np.log(10))
# plt.xticks(np.arange(4), [1, 10, 100, 1000])
# plt.xlabel('Same-name Author Number', fontsize = 15)
# plt.show()
train_pairs = []
valid_pairs = []
test_pairs  = []
for same_name_author_list in tqdm(ds_authors):
    same_name_author_list = np.array(same_name_author_list)
    n_candidates = len(same_name_author_list)
    for author_id, author in enumerate(same_name_author_list):
        # Candidate order: the true author at position 0, then every other
        # same-name author in original order. It depends only on the author,
        # so it is built once here rather than once per paper.
        candidate_order = [author_id] + [i for i in range(n_candidates) if i != author_id]
        al = same_name_author_list[candidate_order]
        papers = graph.edge_list['author']['paper']['rev_AP_write'][author]
        for paper, _time in papers.items():
            if _time in train_range:
                train_pairs.append([paper, al])
            elif _time in valid_range:
                valid_pairs.append([paper, al])
            elif _time in test_range:
                # Explicit membership test: papers whose time falls in none of
                # the ranges (e.g. a None time stamp) are skipped instead of
                # silently ending up in the test set.
                test_pairs.append([paper, al])

In [204]:
# Scoring model, optimizer and learning-rate schedule.
# NOTE(review): 400 is presumably the feature / hidden width expected by
# Matcher; confirm against its definition in model.py.
matcher = Matcher(400)
matcher = matcher.to(device)

# Adam with its default hyper-parameters.
optimizer = torch.optim.Adam(params=matcher.parameters())

# Cosine annealing towards eta_min over T_max scheduler steps. The training
# cell steps the scheduler once per training batch.
# NOTE(review): T_max=2000 is shorter than batch_num * epoch_num steps
# (12,800 with the visible config); confirm the behaviour past T_max is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=2000, eta_min=1e-5)

In [202]:
def mask_softmax(pred, size):
    """Group-wise softmax cross-entropy with the target at index 0 of each group.

    `pred` holds the scores of several groups concatenated along dim 0; `size`
    lists the length of each group in order. For every group a log-softmax is
    taken over that group's scores only, the log-probability of its first
    element is divided by the group length, and the results are summed.

    Args:
        pred: tensor of concatenated scores (the callers in this notebook pass
            the matcher output for a whole batch).
        size: iterable of positive ints, the number of scores in each group.

    Returns:
        The negated sum as a tensor. For an empty `size` this is a zero tensor
        with the dtype and device of `pred` (rather than a plain Python int).
    """
    # Start from a tensor so the return type does not depend on `size` being non-empty.
    loss = pred.new_zeros(())
    stx = 0
    for l in size:
        group_log_prob = torch.log_softmax(pred[stx: stx + l], dim=-1)
        # Index 0 is the positive candidate; dividing by l scales each group's
        # term by its number of candidates.
        loss = loss + group_log_prob[0] / l
        stx += l
    return -loss

In [205]:
def sample_batch(pairs, n_samples):
    """Sample (paper, candidates) pairs and assemble the matcher inputs.

    Draws `n_samples` indices into `pairs` with `np.random.choice` (i.e. with
    replacement) and, for each sampled pair, looks up the 'w2v' features of the
    paper and of every candidate author in `graph.node_feature`.

    Args:
        pairs: list of `[paper, candidate_array]` entries, where the true
            author is the first element of `candidate_array`.
        n_samples: number of pairs to draw.

    Returns:
        (paper_vecs, author_vecs, labels, sizes): the first three are tensors
        on `device`, concatenated over all sampled pairs (the paper vector is
        repeated once per candidate; `labels` is 1 at the first candidate of
        each pair and 0 elsewhere). `sizes` is the list of candidate counts.
    """
    paper_vecs, author_vecs, labels, sizes = [], [], [], []
    for _id in np.random.choice(len(pairs), n_samples):
        paper, al = pairs[_id]
        paper_vec = torch.FloatTensor(graph.node_feature['paper'].loc[paper, 'w2v'])
        paper_vecs.append(paper_vec.repeat(len(al), 1))
        author_vecs.append(torch.FloatTensor(list(graph.node_feature['author'].loc[al, 'w2v'])))
        label = torch.zeros(len(al))
        label[0] = 1  # the true author is always the first candidate
        labels.append(label)
        sizes.append(len(al))
    return (torch.cat(paper_vecs).to(device),
            torch.cat(author_vecs).to(device),
            torch.cat(labels).to(device),
            sizes)


def rank_labels(pred, labels, sizes):
    """Return, per group, the labels reordered by descending predicted score.

    Args:
        pred: concatenated scores for all groups.
        labels: concatenated 0/1 labels aligned with `pred`.
        sizes: list with the number of candidates in each group.

    Returns:
        A list with one list of label values per group.
    """
    return [group_labels[group_pred.argsort(descending = True)].cpu().detach().tolist()
            for group_pred, group_labels in zip(pred.split(sizes), labels.split(sizes))]


# `stats`, `best_val` and `criterion` are initialised but not used in this cell.
stats = []
train_step = 0
best_val   = 100
criterion = nn.BCELoss()
for epoch in range(1, epoch_num + 1):
    # ---- Train ----
    matcher.train()
    train_losses = []
    for _ in range(batch_num):
        train_paper_vecs, train_author_vecs, _labels, train_size = sample_batch(train_pairs, batch_size)

        pred = matcher(train_author_vecs, train_paper_vecs)
        loss = mask_softmax(pred, train_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_step += 1
        # One scheduler step per optimizer step; passing an explicit step
        # index to `step()` is deprecated in PyTorch.
        scheduler.step()

    # ---- Valid / Test ----
    # Evaluation needs no gradients: no_grad avoids building autograd graphs
    # for the validation and test forward passes.
    matcher.eval()
    with torch.no_grad():
        valid_paper_vecs, valid_author_vecs, valid_label, valid_size = sample_batch(valid_pairs, batch_size)
        pred = matcher(valid_author_vecs, valid_paper_vecs)
        valid_loss = mask_softmax(pred, valid_size).item()
        valid_res = rank_labels(pred, valid_label, valid_size)

        test_paper_vecs, test_author_vecs, test_label, test_sizes = sample_batch(test_pairs, batch_size)
        pred = matcher(test_author_vecs, test_paper_vecs)
        test_res = rank_labels(pred, test_label, test_sizes)

    # ---- Report ----
    valid_ndcg   = np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])
    test_ndcg    = np.average([ndcg_at_k(resi, len(resi)) for resi in test_res])
    test_ndcg_10 = np.average([ndcg_at_k(resi, 10) for resi in test_res])
    # NOTE(review): the logged Test MRR (~0.15) looks inconsistent with the
    # logged Test NDCG (~0.83) for rankings with a single relevant item.
    # Verify whether mean_reciprocal_rank expects a list of rankings rather
    # than one ranking at a time.
    test_mrr     = np.average([mean_reciprocal_rank(resi) for resi in test_res])
    print(("Epoch: %d  LR: %.4f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f  Test NDCG: %.4f  Test NDCG@10: %.4f  Test MRR: %.4f") % \
          (epoch, optimizer.param_groups[0]['lr'], np.average(train_losses), valid_loss,
           valid_ndcg, test_ndcg, test_ndcg_10, test_mrr))

Epoch: 1  LR: 0.0010 Train Loss: 94.77  Valid Loss: 33.71  Valid NDCG: 0.7983  Test NDCG: 0.8309  Test NDCG@10: 0.8227  Test MRR: 0.1529
Epoch: 2  LR: 0.0010 Train Loss: 75.52  Valid Loss: 34.82  Valid NDCG: 0.8199  Test NDCG: 0.8311  Test NDCG@10: 0.8159  Test MRR: 0.1681
Epoch: 3  LR: 0.0009 Train Loss: 75.15  Valid Loss: 31.09  Valid NDCG: 0.8261  Test NDCG: 0.7742  Test NDCG@10: 0.7483  Test MRR: 0.1435
Epoch: 4  LR: 0.0008 Train Loss: 71.12  Valid Loss: 26.84  Valid NDCG: 0.8009  Test NDCG: 0.8325  Test NDCG@10: 0.8196  Test MRR: 0.1398
Epoch: 5  LR: 0.0008 Train Loss: 59.98  Valid Loss: 21.00  Valid NDCG: 0.8136  Test NDCG: 0.8543  Test NDCG@10: 0.8427  Test MRR: 0.1460
Epoch: 6  LR: 0.0007 Train Loss: 53.53  Valid Loss: 22.28  Valid NDCG: 0.7984  Test NDCG: 0.8185  Test NDCG@10: 0.8021  Test MRR: 0.1342
Epoch: 7  LR: 0.0006 Train Loss: 48.22  Valid Loss: 22.40  Valid NDCG: 0.8222  Test NDCG: 0.8567  Test NDCG@10: 0.8405  Test MRR: 0.1544


KeyboardInterrupt: 