In [1]:
import sys
sys.path.append('../')

In [2]:
from functools import partial
import multiprocessing as mp
from utils import *

In [3]:
# Directory that holds the serialized graph and the auxiliary pickles read by this notebook.
data_dir = '/datadrive/data/'
# NOTE(review): dill/pickle deserialization can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle as soon as the graph has been loaded.
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [4]:
# Membership tables that split the values of graph.times into train / validation / test.
# Entries that are None are left out of every split.
# NOTE(review): the thresholds suggest the values are publication years -- confirm against the graph builder.
train_range = {t: True for t in graph.times if t is not None and t <= 2015}
valid_range = {t: True for t in graph.times if t is not None and 2015 < t < 2018}
test_range  = {t: True for t in graph.times if t is not None and t >= 2018}

In [5]:
class Matcher(nn.Module):
    '''
    Bilinear-style matcher: projects ``x`` and ``y`` with two separate linear
    layers and scores every row of ``y`` against every row of ``x``.

    n_hid   : width of the incoming representations (both linear layers are n_hid -> n_hid).
    dropout : dropout probability applied to the projected representations.
    '''
    def __init__(self, n_hid, dropout = 0.5):
        super(Matcher, self).__init__()
        self.left_linear    = nn.Linear(n_hid,  n_hid)
        self.right_linear   = nn.Linear(n_hid,  n_hid)
        self.drop     = nn.Dropout(dropout)
        # Cached projection of ``x``; only filled and read when forward() runs with test=True.
        self.mem      = None
    def forward(self, x, y, test = False):
        '''
        Return log-softmax scores of shape (rows of y, rows of x), normalised over the rows of x.

        With ``test=True`` the projection of ``x`` is computed once (without dropout)
        and reused by every later ``test=True`` call; nothing in this class resets
        the cache, so later values of ``x`` are ignored until ``self.mem`` is cleared.
        '''
        if test:
            # Identity check: comparing a tensor to None with != relies on special-cased tensor semantics.
            if self.mem is None:
                self.mem = self.left_linear(x)
            tx = self.mem
        else:
            tx = self.drop(self.left_linear(x))
        ty = self.drop(self.right_linear(y))
        return torch.log_softmax(torch.mm(ty, tx.T), dim=-1)

In [6]:
def dcg_at_k(r, k):
    '''
    Discounted cumulative gain of the first ``k`` relevance scores in ``r``.

    The first score is taken undiscounted and the score at 1-based rank i >= 2
    is divided by log2(i). Returns 0. when there are no scores to consider.
    '''
    # np.asfarray was removed in NumPy 2.0; asarray with dtype=float is the equivalent call.
    r = np.asarray(r, dtype=float)[:k]
    if r.size:
        return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1)))
    return 0.

def ndcg_at_k(r, k):
    '''
    Normalised DCG at ``k``: the DCG of ``r`` divided by the DCG of the same
    scores sorted in descending order. Returns 0. when that ideal DCG is zero.
    '''
    ideal_dcg = dcg_at_k(sorted(r, reverse=True), k)
    return dcg_at_k(r, k) / ideal_dcg if ideal_dcg else 0.

def mean_reciprocal_rank(rs):
    '''
    Mean over the rankings in ``rs`` of 1 / (1-based position of the first
    non-zero entry); a ranking without any non-zero entry contributes 0.
    '''
    reciprocal_ranks = []
    for ranking in rs:
        hit_positions = np.asarray(ranking).nonzero()[0]
        if hit_positions.size:
            reciprocal_ranks.append(1. / (hit_positions[0] + 1))
        else:
            reciprocal_ranks.append(0.)
    return np.mean(reciprocal_ranks)

In [7]:
batch_size = 512   # papers drawn per sampled batch (passed to pf_sample)
batch_num  = 128   # batch_num // 8 sampling jobs are submitted per epoch
epoch_num  = 100   # number of passes of the training loop
samp_num   = 7     # NOTE(review): not referenced in the visible cells -- confirm it is still needed

# Use the third GPU when the machine has one; otherwise fall back to the CPU
# instead of failing later on the first .to(device) call.
device = torch.device("cuda:2" if torch.cuda.is_available() and torch.cuda.device_count() > 2 else "cpu")

In [8]:
def sample_subgraph(graph, time_range, sampled_depth = 2, sampled_number = 8, inp = None):
    '''
    Sample Sub-Graph based on the connection of other nodes with currently sampled nodes

    graph          : object exposing ``edge_list[target_type][source_type][relation_type][target_id]``
                     (a mapping source_id -> time), ``node_feature[type]`` (a table with 'w2v' and
                     'time' columns, looked up with ``.loc``) and ``get_types()``.
    time_range     : container of admissible times; candidates whose time is not in it are skipped.
    sampled_depth  : number of expansion rounds.
    sampled_number : nodes sampled per node type and round; the random seed set is
                     ``2 * sampled_number`` papers.
    inp            : ``{node_type: [[node_id, time], ...]}`` seed nodes. When omitted (None, or the
                     legacy default False) random papers are used as seeds.

    Returns ``(feature, times, edge_list)``:
      feature[type]   list of 'w2v' rows of the sampled nodes, in sampling order
      times[type]     float array with the time of each sampled node
      edge_list[target_type][source_type][relation_type]   list of [target_index, source_index]
                      pairs indexing into the per-type node lists ('self' holds the self loops)
    '''
    def add_budget(te, target_id, target_time, layer_data, budget, maxnum = sampled_depth * sampled_number):
        # Add the neighbours of one sampled node to the budget: for every relation at most
        # ``maxnum`` neighbours are considered and each receives 1 / (number considered).
        for source_type in te:
            tes = te[source_type]
            for relation_type in tes:
                if relation_type == 'self':
                    continue
                adl = tes[relation_type][target_id]
                if len(adl) < maxnum:
                    sampled_ids = list(adl.keys())
                else:
                    sampled_ids = np.random.choice(list(adl.keys()), maxnum, replace = False)
                for source_id in sampled_ids:
                    source_time = adl[source_id]
                    if source_time is None:
                        # Neighbours without a time of their own inherit the target's time.
                        source_time = target_time
                    k = encode(source_id, source_time)
                    if source_time not in time_range or k in layer_data[source_type]:
                        continue
                    budget[source_type][k] += 1. / len(sampled_ids)
    def decode(s):
        # Inverse of encode(): '<id>-<time>' -> float array [id, time].
        idx = s.find('-')
        return np.array([s[:idx], s[idx+1:]], dtype=float)
    def encode(i, t):
        # Key identifying a node together with the time it was sampled at.
        return '%s-%s' % (i, t)

    layer_data  = defaultdict( #target_type
                        lambda: {} # {target_id + time: position within the type}
                    )
    budget     = defaultdict( #source_type
                                lambda: defaultdict(  #source_id + source_time
                                    lambda: 0. #sampled_score
                            ))

    # ``inp is False`` keeps the historical default working: it used to fall through
    # to the seeded branch and fail while iterating over a bool.
    if inp is None or inp is False:
        rand_paper_ids  = np.random.choice(range(len(graph.node_feature['paper'])), sampled_number * 2, replace = False)
        rand_paper_time = np.array(list(graph.node_feature['paper'].loc[rand_paper_ids, 'time']))
        for _id, _time in zip(rand_paper_ids, rand_paper_time):
            if _time not in time_range:
                continue
            layer_data['paper'][encode(_id, _time)] = len(layer_data['paper'])
        for _id, _time in zip(rand_paper_ids, rand_paper_time):
            if _time not in time_range:
                continue
            add_budget(graph.edge_list['paper'], _id, _time, layer_data, budget)
    else:
        # inp: {_type: [[_id, _time]]}
        # All seeds are registered first so that add_budget never re-proposes one of them.
        for _type in inp:
            for _id, _time in inp[_type]:
                layer_data[_type][encode(_id, _time)] = len(layer_data[_type])
        for _type in inp:
            te = graph.edge_list[_type]
            for _id, _time in inp[_type]:
                add_budget(te, _id, _time, layer_data, budget)
    for layer in range(sampled_depth):
        for source_type in graph.get_types():
            te = graph.edge_list[source_type]
            keys  = np.array(list(budget[source_type].keys()))
            # Sanity check: a budgeted key should never already be part of the sample.
            for k in keys:
                if k in layer_data[source_type]:
                    print('bug', layer)

            if sampled_number > len(keys):
                # Fewer candidates than requested: take all of them.
                sampled_ids = np.arange(len(keys))
            else:
                # Draw without replacement, proportionally to the accumulated budget.
                score = np.array(list(budget[source_type].values()))
                score = score / np.sum(score)
                sampled_ids = np.random.choice(len(score), sampled_number, p = score, replace = False)
            sampled_ids = keys[sampled_ids]
            for k in sampled_ids:
                layer_data[source_type][k] = len(layer_data[source_type])
            for k in sampled_ids:
                source_id, source_time = decode(k)
                add_budget(te, int(source_id), int(source_time), layer_data, budget)
                budget[source_type].pop(k)

    feature = {}
    times   = {}
    # Per type: (integer node id, position) of every sampled node, decoded once so the
    # edge construction below does not re-parse the keys for every candidate pair.
    sampled_nodes = {}
    for target_type in layer_data:
        # reshape keeps the array two-dimensional when a type ended up without sampled nodes.
        idxs  = np.array([decode(key) for key in layer_data[target_type]], dtype=float).reshape(-1, 2)
        # Ids are integer valued; use integer labels for the table and dictionary lookups.
        node_ids = idxs[:, 0].astype(int)
        sampled_nodes[target_type] = list(zip(node_ids.tolist(), layer_data[target_type].values()))
        feature[target_type] = list(graph.node_feature[target_type].loc[node_ids, 'w2v'])
        times[target_type]   = idxs[:,1]
    edge_list = defaultdict( #target_type
                        lambda: defaultdict(  #source_type
                            lambda: defaultdict(  #relation_type
                                lambda: [] # [target_id, source_id]
                                    )))
    for target_type in layer_data:
        for target_id in layer_data[target_type]:
            target_ser = layer_data[target_type][target_id]
            edge_list[target_type][target_type]['self'] += [[target_ser, target_ser]]
    # Keep every edge of the full graph whose two endpoints were both sampled.
    for target_type in graph.edge_list:
        te = graph.edge_list[target_type]
        targets = sampled_nodes.get(target_type, [])
        for source_type in te:
            tes = te[source_type]
            sources = sampled_nodes.get(source_type, [])
            for relation_type in tes:
                tesr = tes[relation_type]
                for target_node, target_ser in targets:
                    tesrt = tesr[target_node]
                    for source_node, source_ser in sources:
                        if source_node in tesrt:
                            edge_list[target_type][source_type][relation_type] += [[target_ser, source_ser]]
    return feature, times, edge_list

def to_torch(feature, time, edge_list):
    '''
    Flatten a per-type sub-graph into tensors over one global node index.

    feature[type] / time[type] hold the node rows and node times of each type, and
    edge_list[target_type][source_type][relation_type] holds [target, source] pairs
    of per-type positions.

    Returns (node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict)
    where node_dict[type] = [offset of the type in the global index, type number] and
    edge_dict maps each relation name to its number. edge_index has the source ids in
    row 0 and the target ids in row 1; edge_time is (target time - source time + 120).
    '''
    node_dict = {}
    node_feature, node_type, node_time = [], [], []
    node_num = 0
    for type_ser, t in enumerate(feature):
        count = len(feature[t])
        # Types are laid out one after another, in the iteration order of ``feature``.
        node_dict[t] = [node_num, type_ser]
        node_feature.extend(feature[t])
        node_time.extend(time[t])
        node_type.extend([type_ser] * count)
        node_num += count

    edge_index, edge_type, edge_time = [], [], []
    edge_dict = {}
    for target_type, by_source in edge_list.items():
        for source_type, by_relation in by_source.items():
            for relation_type, pairs in by_relation.items():
                # Relations are numbered in order of first appearance.
                relation_ser = edge_dict.setdefault(relation_type, len(edge_dict))
                for ti, si in pairs:
                    sid = si + node_dict[source_type][0]
                    tid = ti + node_dict[target_type][0]
                    edge_index.append([sid, tid])
                    edge_type.append(relation_ser)
                    edge_time.append(node_time[tid] - node_time[sid] + 120)

    return (
        torch.FloatTensor(node_feature),
        torch.LongTensor(node_type),
        torch.LongTensor(edge_time),
        torch.LongTensor(edge_index).t(),
        torch.LongTensor(edge_type),
        node_dict,
        edge_dict,
    )

In [9]:
def pf_sample(seed, papers, pairs, t_range, batch_size):
    '''
    Build one paper-field prediction batch around ``batch_size`` papers of a single time.

    seed       : value passed to np.random.seed so that each worker draws its own batch.
    papers     : {time: array of paper ids}; one time is picked at random.
    pairs      : {paper id: [field ids]} giving the labels of each paper.
    t_range    : admissible times, forwarded to sample_subgraph.
    batch_size : number of papers drawn (without replacement) for the chosen time.

    The sub-graph is sampled from the module-level ``graph``; the paper-field edges that
    are to be predicted are removed from it. Returns
    (node_feature, edge_index, field_ids, paper_ids, ylabel) where field_ids / paper_ids
    index rows of node_feature and ylabel is a (batch_size, number of fields) matrix whose
    rows are normalised by their sum.
    '''
    np.random.seed(seed)
    _time = np.random.choice(list(papers.keys()))
    pids = np.array(papers[_time])[np.random.choice(len(papers[_time]), batch_size, replace = False)]
    # field id -> column of that field, numbered in order of first appearance.
    # A dictionary gives constant-time lookups instead of scanning a list per label.
    field_pos = {}
    # edge[x_id][y_id] marks "paper at row x_id is labelled with field at column y_id".
    edge = defaultdict(dict)
    for x_id, p_id in enumerate(pids):
        for f_id in pairs[p_id]:
            edge[x_id][field_pos.setdefault(f_id, len(field_pos))] = True
    fids = list(field_pos)
    pids = np.stack([pids, np.repeat([_time], batch_size)]).T
    fids = np.stack([fids, np.repeat([_time], len(fids))]).T

    feature, times, edge_list = sample_subgraph(graph, t_range, \
                inp = {'paper': pids, 'field': fids}, sampled_depth = 4, sampled_number = 128)

    def is_label(paper_ser, field_ser):
        # Checking ``in edge`` first avoids creating entries in the defaultdict.
        return paper_ser in edge and field_ser in edge[paper_ser]

    # Drop the edges that encode the labels, in both directions, so they cannot leak into the input.
    edge_list['paper']['field']['rev_PF_in'] = [
        i for i in edge_list['paper']['field']['rev_PF_in'] if not is_label(i[0], i[1])]
    edge_list['field']['paper']['PF_in'] = [
        i for i in edge_list['field']['paper']['PF_in'] if not is_label(i[1], i[0])]

    node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list)
    # The seed papers / fields are taken to be the first nodes of their type, so their
    # global ids are the type offset plus their position in the batch.
    paper_ids = np.arange(len(pids)) + node_dict['paper'][0]
    field_ids = np.arange(len(fids)) + node_dict['field'][0]
    ylabel = torch.zeros(batch_size, len(field_ids))
    for x_id in edge:
        for y_id in edge[x_id]:
            ylabel[x_id][y_id] = 1
    ylabel /= ylabel.sum(axis=1).view(-1, 1)
    return node_feature, edge_index, field_ids, paper_ids, ylabel
    
def prepare_data(pool):
    '''
    Queue the sampling jobs of one epoch on ``pool`` and return their async results:
    one training job for every entry of ``process_ids`` except the last, followed by
    a single validation job as the final element.
    '''
    def submit(papers, pairs, t_range):
        # Every job receives its own freshly drawn seed.
        seed = np.random.randint(2**32 - 1)
        return pool.apply_async(pf_sample, args=(seed, papers, pairs, t_range, batch_size))

    jobs = [submit(train_papers, train_pairs, train_range) for _ in process_ids[:-1]]
    jobs.append(submit(valid_papers, valid_pairs, valid_range))
    return jobs

In [10]:
from torch_geometric.nn import GCNConv
class GNN(nn.Module):
    '''
    Linear adapter (n_hid -> n_hid // 2) followed by ``n_layers`` GCNConv layers
    of width n_hid // 2, with dropout after the adapter and after every layer.
    '''
    def __init__(self, n_hid, n_layers, dropout = 0.5):
        super().__init__()
        width = n_hid // 2
        self.gcs = nn.ModuleList()
        self.adapt = nn.Linear(n_hid, width)
        self.drop  = nn.Dropout(dropout)
        # The convolutions are created after the adapter, one per requested layer.
        self.gcs.extend([GCNConv(width, width) for _ in range(n_layers)])
    def set_device(self, device):
        '''Record ``device`` as an attribute on the model and on each convolution (moves nothing).'''
        self.device = device
        for conv in self.gcs:
            conv.device = device
    def forward(self, node_feature, edge_index):
        '''Return the node representations after the adapter and all convolution layers.'''
        hidden = self.adapt(node_feature)
        hidden = self.drop(F.elu(hidden))
        for conv in self.gcs:
            hidden = conv(hidden, edge_index)
            hidden = self.drop(hidden)
        return hidden

In [11]:
'''
Paper-Field
'''
# NOTE(review): dill/pickle deserialization can execute arbitrary code; only load trusted files.
with open(data_dir + 'field_dict.pk', 'rb') as field_dict_file:
    field_dict = dill.load(field_dict_file)
# Graph ids of the fields whose first field_dict entry is 'L1'.
ids = np.array([graph.node_forward['field'][k] for k in field_dict if field_dict[k][0] == 'L1'])

paper_ser = {}

# {paper id: [field ids]} per split.
train_pairs = {}
valid_pairs = {}
test_pairs  = {}

# {time: {paper id: True}} per split; turned into {time: array of paper ids} below.
train_papers = {_time: {} for _time in train_range}
valid_papers = {_time: {} for _time in valid_range}
test_papers  = {_time: {} for _time in test_range}

splits = [
    (train_range, train_pairs, train_papers),
    (valid_range, valid_pairs, valid_papers),
    (test_range,  test_pairs,  test_papers),
]

pf_edges = graph.edge_list['field']['paper']['PF_in']
for f_id in ids:
    for p_id, _time in pf_edges[f_id].items():
        # An edge goes to the first split whose range contains its time. Edges whose time
        # belongs to no range (e.g. None) are skipped; sending them to the test split
        # would fail because test_papers has no entry for such a time.
        for split_range, split_pairs, split_papers in splits:
            if _time in split_range:
                split_pairs.setdefault(p_id, []).append(f_id)
                split_papers[_time][p_id] = True
                break

# Drop the times that cannot fill a whole batch and freeze the rest into id arrays.
for split_papers in (train_papers, valid_papers, test_papers):
    for _time in list(split_papers.keys()):
        if len(split_papers[_time]) < batch_size:
            split_papers.pop(_time)
        else:
            split_papers[_time] = np.array(list(split_papers[_time].keys()))

In [None]:
# Encoder over 400-dimensional node features; its output width (400 // 2) is the matcher's input width.
gnn = GNN(400, n_layers=2)
gnn = gnn.to(device)
matcher = Matcher(200)
matcher = matcher.to(device)
# One optimizer over both modules, with a cosine learning-rate schedule (T_max of 5000 steps).
optimizer = torch.optim.Adam([*gnn.parameters(), *matcher.parameters()])
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5000, eta_min=1e-5)

In [None]:
stats = []      # not used in this cell
pool = mp.Pool(16)
# One sampling job per entry; prepare_data uses the last slot for the validation batch.
process_ids = np.arange(batch_num // 8)
st = time.time()
jobs = prepare_data(pool)
train_step = 0
best_val   = 0  # not used in this cell
res = []
criterion = nn.KLDivLoss(reduction='batchmean')
try:
    for epoch in np.arange(epoch_num)+1:
        # Prepare Training and Validation Data: collect the batches sampled in the
        # background, then restart the pool and queue the batches of the next epoch.
        train_data = [job.get() for job in jobs[:-1]]
        valid_data = jobs[-1].get()
        pool.terminate()
        pool.join()
        pool = mp.Pool(16)
        jobs = prepare_data(pool)
        et = time.time()
        print('Data Preparation: %.1fs' % (et - st))

        # Both modules contain dropout, so both are switched to training mode.
        gnn.train()
        matcher.train()
        train_losses = []
        # The sampled batches are reused for 8 passes before new ones are fetched.
        for batch in np.arange(8):
            for node_feature, edge_index, field_ids, paper_ids, ylabel in train_data:
                node_rep = gnn(node_feature.to(device), edge_index.to(device))
                res  = matcher(node_rep[field_ids], node_rep[paper_ids])
                loss = criterion(res, ylabel.to(device))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_losses += [loss.item()]
                train_step += 1
                # One scheduler step per optimizer step; the explicit epoch argument is deprecated.
                scheduler.step()

        # Valid: dropout off in both modules and no autograd bookkeeping.
        gnn.eval()
        matcher.eval()
        node_feature, edge_index, field_ids, paper_ids, ylabel = valid_data
        with torch.no_grad():
            node_rep = gnn(node_feature.to(device), edge_index.to(device))
            res  = matcher(node_rep[field_ids], node_rep[paper_ids])
            loss = criterion(res, ylabel.to(device))
        # ylabel lives on the CPU, so the ranking indices are moved there before indexing.
        ranking = res.argsort(descending = True).cpu()
        valid_res = [ai[bi].tolist() for ai, bi in zip(ylabel, ranking)]

        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %.2f  Valid Loss: %.2f  Valid NDCG: %.4f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), loss.item(),\
              np.average([ndcg_at_k(resi, len(resi)) for resi in valid_res])))
finally:
    # Stop the workers that are still sampling for an epoch that will not run.
    pool.terminate()
    pool.join()