In [1]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [2]:
data_dir = '/datadrive/data_cs/'
batch_size = 64
batch_num  = 128
epoch_num  = 1000
samp_num   = 64 - 1

device = torch.device("cuda:2")
graph = dill.load(open(data_dir + 'graph.pk', 'rb'))

In [46]:
graph.edge_list = edge_list

In [48]:
dill.dump(graph, open(data_dir + 'graph.pk', 'wb'))

In [3]:
train_range = {t: True for t in graph.times if t != None and t < 2015}
valid_range = {t: True for t in graph.times if t != None and t >= 2015  and t <= 2016}
test_range  = {t: True for t in graph.times if t != None and t > 2016}

In [4]:
class GNN(nn.Module):
    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        super(GNN, self).__init__()
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim    = in_dim
        self.n_hid     = n_hid
        self.aggregat_ws   = nn.ModuleList()
        self.drop          = nn.Dropout(dropout)
        for t in range(num_types):
            self.aggregat_ws.append(nn.Linear(in_dim, n_hid))
        for l in range(n_layers):
            self.gcs.append(RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout))
    def set_device(self, device):
        self.device = device
        for gc in self.gcs:
            gc.device = device
    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        res = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device)
        for t_id in range(self.num_types):
            aggregat_w = self.aggregat_ws[t_id]
            idx = (node_type == t_id)
            if idx.sum() == 0:
                continue
            res[idx] = torch.tanh(aggregat_w(node_feature[idx]))
        meta_xs = self.drop(res)
        del res
        for gc in self.gcs:
            meta_xs = gc(meta_xs, node_type, edge_index, edge_type, edge_time)
        return meta_xs

In [6]:
def multi_mask(task_set, feature, time, edge_list, batch_size):
    rem_lists   = []
    neg_nums    = []
    ori_lists   = []
    for target_type, source_type, rel_type in task_set:
        edges = np.array(edge_list[target_type][source_type][rel_type])
        '''
            edge: (target_id, source_id)
        '''
        ori_lists += [edges]
        remn  = min((len(edges)-1) // 2, batch_size)
        rem_ids = np.random.choice(np.arange(len(edges)), remn, replace = False)
        lft_ids = np.array([i for i in range(len(edges)) if i not in rem_ids])
        lft_edges = edges[lft_ids]
        rem_lists += [edges[rem_ids]]
        neg_nums  += [len(time[target_type])]
        edge_list[target_type][source_type][rel_type] = list(lft_edges)
        edge_list[source_type][target_type]['rev_' + rel_type] = list(np.stack((lft_edges[:,1], lft_edges[:,0])).T)
    node_feature, node_type, node_time, edge_index, edge_type, node_dict, _ = to_torch(feature, time, edge_list, graph)
    return neg_nums, rem_lists, node_feature, node_type, node_time, edge_index, edge_type, node_dict, ori_lists
def mt_sample(seed, time_range, task_set, sampled_depth = 3, sampled_number = 100, batch_size = batch_size):
    np.random.seed(seed)
    train_feature, train_time, train_edge_list, _ = \
            sample_subgraph(graph, time_range=train_range, sampled_depth = sampled_depth, sampled_number = sampled_number)
    return multi_mask(task_set, train_feature, train_time, train_edge_list, batch_size)

def prepare_data(pool, process_ids, task_set):
    jobs = []
    for process_id in process_ids[:-1]:
        p = pool.apply_async(mt_sample, args=(np.random.randint(2**32 - 1), train_range, task_set))
        jobs.append(p)
    p = pool.apply_async(mt_sample, args=(np.random.randint(2**32 - 1), valid_range, task_set))
    jobs.append(p)
    return jobs

In [7]:
types = graph.get_types()
gnn = GNN(in_dim = len(graph.node_feature['paper']['emb'][0]) + 401, n_hid = 256, num_types = len(types), \
          num_relations = len(graph.get_meta_graph()) + 1, n_heads = 8, n_layers = 3).to(device)
fp_predictor = Matcher(256, 8).to(device)
vp_predictor = Matcher(256, 8).to(device)
pp_predictor = Matcher(256, 8).to(device)
ap_predictor = Matcher(256, 8).to(device)
ai_predictor = Matcher(256, 8).to(device)
model = nn.Sequential(gnn, fp_predictor, vp_predictor, pp_predictor, ap_predictor, ai_predictor)

In [8]:
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 1000, eta_min=1e-6)

In [None]:
stats = []
models = [fp_predictor, vp_predictor, pp_predictor, ap_predictor, ai_predictor]
task_set = [['field', 'paper', 'PF_in_L2'],\
            ['venue', 'paper', 'PV_Conference'],\
            ['paper', 'paper', 'PP_cite'],\
            ['paper', 'author', 'AP_write_first'],\
            ['affiliation', 'author', 'in']]

pool = mp.Pool(4)
process_ids = np.arange(batch_num // 4)
st = time.time()
jobs = prepare_data(pool, process_ids, task_set)
train_step = 1500
best_val   = 0

for epoch in np.arange(epoch_num)+1:
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    pool = mp.Pool(4)
    jobs = prepare_data(pool, process_ids, task_set)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))
    
    train_losses = []
    model.train()
    torch.cuda.empty_cache()
    for neg_nums, rem_lists, node_feature, node_type, node_time, edge_index, edge_type, node_dict, ori_lists in train_data:
        for target_size, predictor, rem_edges, ori_edges, (target_type, source_type, rel_type) in \
                zip(neg_nums, models, rem_lists, ori_lists, task_set):
            '''
                Train
            '''
            sn = np.min([target_size-1, samp_num])
            positive_target_ids, source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1]
            negative_target_ids = np.array([neg_sample(target_size, sn, \
                edges[0][edges[1] == s_id].tolist()) for edges, s_id in zip(ori_edges, source_ids)])
            target_ids = np.concatenate((positive_target_ids, negative_target_ids), axis=-1) + node_dict[target_type][0]

            node_emb = gnn.forward(node_feature.to(device), node_type.to(device), \
                        node_time.to(device), edge_index.to(device), edge_type.to(device))
            source_emb = node_emb[source_ids].repeat(1, target_ids.shape[1]\
                                              ).view(target_ids.shape[0] * target_ids.shape[1], -1)

            target_emb = node_emb[target_ids].view(target_ids.shape[0] * target_ids.shape[1], -1)
            res = predictor.forward(source_emb, target_emb, pair=True).view(-1, target_ids.shape[1])
            loss = -torch.log_softmax(res, dim=-1)[:,0].mean()
            optimizer.zero_grad() 
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
            optimizer.step()
            torch.cuda.empty_cache()
            train_losses += [loss.cpu().detach().tolist()]
            train_step += 1
            scheduler.step(train_step)
    '''
        Valid
    '''
    model.eval()
    with torch.no_grad():
        neg_nums, rem_lists, node_feature, node_type, node_time, edge_index, edge_type, node_dict, ori_lists = valid_data
        valid_losses = []
        valid_accs   = []
        for target_size, predictor, rem_edges, ori_edges, (target_type, source_type, rel_type) in \
                zip(neg_nums, models, rem_lists, ori_lists, task_set):
            '''
                Valid
            '''
            positive_target_ids, source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1]
            negative_target_ids = np.array([neg_sample(target_size, sn, \
                edges[0][edges[1] == s_id].tolist()) for edges, s_id in zip(ori_edges, source_ids)])
            target_ids = np.concatenate((positive_target_ids, negative_target_ids), axis=-1) + node_dict[target_type][0]

            node_emb = gnn.forward(node_feature.to(device), node_type.to(device), \
                        node_time.to(device), edge_index.to(device), edge_type.to(device))
            source_emb = node_emb[source_ids].repeat(1, target_ids.shape[1]\
                                              ).view(target_ids.shape[0] * target_ids.shape[1], -1)

            target_emb = node_emb[target_ids].view(target_ids.shape[0] * target_ids.shape[1], -1)
            res = predictor.forward(source_emb, target_emb, pair=True).view(-1, target_ids.shape[1])
            valid_losses +=  [-torch.log_softmax(res, dim=-1)[:,0]]
            s = (res.argmax(dim=1) == 0).sum()
            valid_accs += [s.tolist() / batch_size]
        valid_losses = torch.cat(valid_losses).cpu().detach().tolist()
        s = (res.argmax(dim=1) == 0).sum()
        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %f  Valid Loss: %f  Valid Acc: %f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), np.average(valid_losses), np.average(valid_accs)))
        if np.average(valid_accs) > best_val:
            best_val = np.average(valid_accs)
            torch.save(gnn, './save/mt_model.pt')
        stats += [[train_losses, valid_losses]]

Data Preparation: 156.5s
Epoch: 1 (571.6s)  LR: 0.00073 Train Loss: 7.193311  Valid Loss: 4.212434  Valid Acc: 0.037500


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.1s
Epoch: 2 (584.3s)  LR: 0.00091 Train Loss: 4.982474  Valid Loss: 3.974452  Valid Acc: 0.068750


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.0s
Epoch: 3 (569.8s)  LR: 0.00100 Train Loss: 4.332902  Valid Loss: 3.829248  Valid Acc: 0.028125
Data Preparation: 2.7s
Epoch: 4 (599.3s)  LR: 0.00096 Train Loss: 4.035325  Valid Loss: 3.711625  Valid Acc: 0.078125


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.7s
Epoch: 5 (557.2s)  LR: 0.00082 Train Loss: 3.832633  Valid Loss: 3.623439  Valid Acc: 0.059375
Data Preparation: 2.8s
Epoch: 6 (572.6s)  LR: 0.00061 Train Loss: 3.657501  Valid Loss: 3.664063  Valid Acc: 0.059375
Data Preparation: 3.0s
Epoch: 7 (583.7s)  LR: 0.00037 Train Loss: 3.621540  Valid Loss: 3.524964  Valid Acc: 0.053125
Data Preparation: 2.8s
Epoch: 8 (576.9s)  LR: 0.00016 Train Loss: 3.543070  Valid Loss: 3.515449  Valid Acc: 0.062500
Data Preparation: 2.8s
Epoch: 9 (589.3s)  LR: 0.00003 Train Loss: 3.500578  Valid Loss: 3.551407  Valid Acc: 0.040625
Data Preparation: 3.0s
Epoch: 10 (571.6s)  LR: 0.00001 Train Loss: 3.501643  Valid Loss: 3.506605  Valid Acc: 0.071875
Data Preparation: 3.2s
Epoch: 11 (564.4s)  LR: 0.00010 Train Loss: 3.553417  Valid Loss: 3.338795  Valid Acc: 0.109375


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.2s
Epoch: 12 (580.0s)  LR: 0.00029 Train Loss: 3.489405  Valid Loss: 3.462789  Valid Acc: 0.071875
Data Preparation: 2.8s
Epoch: 13 (587.8s)  LR: 0.00052 Train Loss: 3.506860  Valid Loss: 3.220992  Valid Acc: 0.100000
Data Preparation: 2.8s
Epoch: 14 (564.2s)  LR: 0.00075 Train Loss: 3.522716  Valid Loss: 3.394563  Valid Acc: 0.075000
Data Preparation: 3.3s
Epoch: 15 (579.2s)  LR: 0.00093 Train Loss: 3.512809  Valid Loss: 3.296485  Valid Acc: 0.068750
Data Preparation: 3.3s
Epoch: 16 (572.0s)  LR: 0.00100 Train Loss: 3.525861  Valid Loss: 3.317275  Valid Acc: 0.062500
Data Preparation: 2.9s
Epoch: 17 (584.1s)  LR: 0.00096 Train Loss: 3.501419  Valid Loss: 3.494350  Valid Acc: 0.078125
Data Preparation: 2.9s
Epoch: 18 (582.4s)  LR: 0.00081 Train Loss: 3.526415  Valid Loss: 3.309427  Valid Acc: 0.056250
Data Preparation: 3.0s
Epoch: 19 (571.8s)  LR: 0.00059 Train Loss: 3.494969  Valid Loss: 3.510880  Valid Acc: 0.062500
Data Preparation: 3.0s
Epoch: 20 (582.2s)  LR: 0

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.1s
Epoch: 58 (572.3s)  LR: 0.00052 Train Loss: 3.425837  Valid Loss: 3.308850  Valid Acc: 0.084375
Data Preparation: 3.0s
Epoch: 59 (568.0s)  LR: 0.00028 Train Loss: 3.357271  Valid Loss: 3.157411  Valid Acc: 0.137500


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.3s
Epoch: 60 (551.5s)  LR: 0.00010 Train Loss: 3.356827  Valid Loss: 3.056056  Valid Acc: 0.115625
Data Preparation: 3.0s
Epoch: 61 (560.8s)  LR: 0.00001 Train Loss: 3.342914  Valid Loss: 3.184962  Valid Acc: 0.053125
Data Preparation: 3.1s
Epoch: 62 (578.4s)  LR: 0.00003 Train Loss: 3.331947  Valid Loss: 3.301248  Valid Acc: 0.071875
Data Preparation: 2.9s
Epoch: 63 (582.9s)  LR: 0.00016 Train Loss: 3.324354  Valid Loss: 3.068801  Valid Acc: 0.109375
Data Preparation: 3.0s
Epoch: 64 (568.9s)  LR: 0.00038 Train Loss: 3.311912  Valid Loss: 3.339999  Valid Acc: 0.065625
Data Preparation: 3.0s
Epoch: 65 (588.7s)  LR: 0.00062 Train Loss: 3.293715  Valid Loss: 3.308818  Valid Acc: 0.106250
Data Preparation: 3.1s
Epoch: 66 (565.0s)  LR: 0.00083 Train Loss: 3.367169  Valid Loss: 3.395196  Valid Acc: 0.100000
Data Preparation: 3.1s
Epoch: 67 (551.4s)  LR: 0.00097 Train Loss: 3.409278  Valid Loss: 3.317049  Valid Acc: 0.090625
Data Preparation: 3.0s
Epoch: 68 (580.0s)  LR: 0

Data Preparation: 3.4s
Epoch: 129 (546.6s)  LR: 0.00049 Train Loss: 3.412929  Valid Loss: 3.390638  Valid Acc: 0.078125
Data Preparation: 3.1s
Epoch: 130 (529.5s)  LR: 0.00073 Train Loss: 3.420853  Valid Loss: 3.507815  Valid Acc: 0.087500
Data Preparation: 3.0s
Epoch: 131 (533.1s)  LR: 0.00091 Train Loss: 3.427823  Valid Loss: 3.274329  Valid Acc: 0.068750
Data Preparation: 3.0s
Epoch: 132 (540.2s)  LR: 0.00100 Train Loss: 3.424181  Valid Loss: 3.505622  Valid Acc: 0.081250
Data Preparation: 3.1s
Epoch: 133 (545.4s)  LR: 0.00097 Train Loss: 3.429288  Valid Loss: 3.520567  Valid Acc: 0.065625
Data Preparation: 3.5s
Epoch: 134 (533.6s)  LR: 0.00083 Train Loss: 3.412040  Valid Loss: 3.335769  Valid Acc: 0.106250
Data Preparation: 3.0s
Epoch: 135 (550.2s)  LR: 0.00062 Train Loss: 3.386748  Valid Loss: 3.329660  Valid Acc: 0.062500
Data Preparation: 3.0s
Epoch: 136 (538.6s)  LR: 0.00038 Train Loss: 3.403004  Valid Loss: 3.120133  Valid Acc: 0.093750
Data Preparation: 3.1s
Epoch: 137 (518.8

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.3s
Epoch: 197 (579.0s)  LR: 0.00100 Train Loss: 3.386912  Valid Loss: 3.348182  Valid Acc: 0.078125
Data Preparation: 3.0s
Epoch: 198 (580.4s)  LR: 0.00091 Train Loss: 3.376766  Valid Loss: 3.404176  Valid Acc: 0.090625
Data Preparation: 3.2s
Epoch: 199 (566.9s)  LR: 0.00073 Train Loss: 3.392063  Valid Loss: 3.289126  Valid Acc: 0.090625
Data Preparation: 3.1s
Epoch: 200 (560.5s)  LR: 0.00050 Train Loss: 3.378901  Valid Loss: 3.106820  Valid Acc: 0.100000
Data Preparation: 3.4s
Epoch: 201 (567.6s)  LR: 0.00027 Train Loss: 3.341902  Valid Loss: 3.436722  Valid Acc: 0.093750
Data Preparation: 3.0s
Epoch: 202 (569.3s)  LR: 0.00009 Train Loss: 3.348789  Valid Loss: 3.137346  Valid Acc: 0.115625
Data Preparation: 3.6s
Epoch: 203 (581.4s)  LR: 0.00000 Train Loss: 3.323877  Valid Loss: 3.181622  Valid Acc: 0.078125
Data Preparation: 3.2s
Epoch: 204 (554.3s)  LR: 0.00004 Train Loss: 3.390142  Valid Loss: 3.315120  Valid Acc: 0.087500
Data Preparation: 3.2s
Epoch: 205 (602.6

Data Preparation: 3.1s
Epoch: 266 (540.1s)  LR: 0.00017 Train Loss: 3.333254  Valid Loss: 3.247920  Valid Acc: 0.090625
Data Preparation: 3.2s
Epoch: 267 (541.1s)  LR: 0.00003 Train Loss: 3.294555  Valid Loss: 3.290626  Valid Acc: 0.096875
Data Preparation: 3.1s
Epoch: 268 (516.2s)  LR: 0.00000 Train Loss: 3.320540  Valid Loss: 3.061470  Valid Acc: 0.115625
Data Preparation: 3.5s
Epoch: 269 (545.4s)  LR: 0.00009 Train Loss: 3.317026  Valid Loss: 3.172194  Valid Acc: 0.109375
Data Preparation: 3.4s
Epoch: 270 (528.3s)  LR: 0.00027 Train Loss: 3.300713  Valid Loss: 3.121057  Valid Acc: 0.128125
Data Preparation: 3.6s
Epoch: 271 (549.9s)  LR: 0.00051 Train Loss: 3.314428  Valid Loss: 3.116534  Valid Acc: 0.103125
Data Preparation: 3.6s
Epoch: 272 (566.4s)  LR: 0.00074 Train Loss: 3.286019  Valid Loss: 3.257643  Valid Acc: 0.109375
Data Preparation: 3.6s
Epoch: 273 (538.6s)  LR: 0.00092 Train Loss: 3.338220  Valid Loss: 3.050644  Valid Acc: 0.121875
Data Preparation: 3.5s
Epoch: 274 (528.9

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


Data Preparation: 3.4s
Epoch: 281 (546.5s)  LR: 0.00001 Train Loss: 3.296891  Valid Loss: 3.284798  Valid Acc: 0.131250
Data Preparation: 3.4s
Epoch: 282 (555.1s)  LR: 0.00011 Train Loss: 3.266341  Valid Loss: 3.208198  Valid Acc: 0.121875
Data Preparation: 3.4s
Epoch: 283 (532.1s)  LR: 0.00029 Train Loss: 3.293627  Valid Loss: 3.152039  Valid Acc: 0.087500
Data Preparation: 3.5s
Epoch: 284 (534.8s)  LR: 0.00053 Train Loss: 3.285497  Valid Loss: 3.104111  Valid Acc: 0.096875
Data Preparation: 3.5s
Epoch: 285 (543.7s)  LR: 0.00076 Train Loss: 3.293271  Valid Loss: 3.422727  Valid Acc: 0.090625
Data Preparation: 3.5s
Epoch: 286 (538.1s)  LR: 0.00093 Train Loss: 3.319184  Valid Loss: 3.382679  Valid Acc: 0.100000
Data Preparation: 3.5s
Epoch: 287 (538.5s)  LR: 0.00100 Train Loss: 3.307029  Valid Loss: 3.329804  Valid Acc: 0.087500
Data Preparation: 3.5s
Epoch: 288 (543.8s)  LR: 0.00095 Train Loss: 3.332467  Valid Loss: 3.124009  Valid Acc: 0.140625
Data Preparation: 3.1s
Epoch: 289 (551.0

In [135]:
def neg_sample(size, num, pos_id):
    s = np.arange(size)
    np.random.shuffle(s)
    return [si for si in s if si not in pos_id]

In [136]:
model.eval()
with torch.no_grad():
    neg_nums, rem_lists, node_feature, node_type, node_time, edge_index, edge_type, node_dict, ori_lists = valid_data
    valid_losses = []
    valid_accs   = []
    for target_size, predictor, rem_edges, ori_edges, (target_type, source_type, rel_type) in \
            zip(neg_nums, models, rem_lists, ori_lists, task_set):
        '''
            Valid
        '''
        positive_target_ids, source_ids = rem_edges[:,0].reshape(-1, 1), rem_edges[:,1]
        negative_target_ids = np.array([neg_sample(target_size, sn, \
            ori_edges[:, 0][ori_edges[:, 1] == s_id].tolist()) for  s_id in source_ids])
        
        
        sn = np.min([len(_id) for _id in negative_target_ids] + [samp_num])
        source_ids = source_ids + node_dict[source_type][0]
        target_ids = np.concatenate((positive_target_ids, negative_target_ids), axis=-1) + node_dict[target_type][0]
        
        node_emb = gnn.forward(node_feature.to(device), node_type.to(device), \
                    node_time.to(device), edge_index.to(device), edge_type.to(device))
        source_emb = node_emb[source_ids].repeat(1, target_ids.shape[1]\
                                          ).view(target_ids.shape[0] * target_ids.shape[1], -1)

        target_emb = node_emb[target_ids].view(target_ids.shape[0] * target_ids.shape[1], -1)
        res = predictor.forward(source_emb, target_emb, pair=True).view(-1, target_ids.shape[1])

ValueError: all the input arrays must have same number of dimensions

In [139]:
samp_num

63

In [None]:
gnn = torch.load('./save/mt_model.pt')

In [None]:
size = 0
for s in graph.edge_list:
    for t in graph.edge_list[s]:
        for e in graph.edge_list[s][t]:
            for i in graph.edge_list[s][t][e]:
                size += len(graph.edge_list[s][t][e][i])
size
size = 0
for s in graph.node_feature:
    size += len(graph.node_feature[s])
size

In [None]:
size = 0
for s in graph.node_feature:
    size += len(graph.node_feature[s])
size

In [12]:
epoch

470

In [None]:
np.unique(list(graph.node_feature['venue']['attr']))

In [None]:
graph.node_feature['venue'][graph.node_feature['venue']['attr']=='Patent']

In [None]:
1116163, 5574903, 

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import euclidean_distances

In [None]:
topconf = graph.node_feature['venue'].nlargest(256, 'citation')

In [None]:
res = list(graph.node_feature['venue'][graph.node_feature['venue']['name'] == 'WWW']['emb'])

In [None]:
cos = euclidean_distances(res, list(topconf['emb']))[0]
for li in np.argsort(cos)[:50]:
    print(list(topconf['name'])[li])

In [None]:
gnn.to(device)
gnn.eval()
topconf = graph.node_feature['venue'].nlargest(100, 'citation')
with torch.no_grad():
    _time = 2000
    vids = np.stack([graph.node_feature['venue'].nlargest(100, 'citation').index.values, np.repeat([_time], 100)]).T
    conf_emb1 = []
    conf_emb2 = []
    for i in range(10):
        print(i)
        feature, times, edge_list, _ = sample_subgraph(graph, {t: True for t in graph.times if t != None}, \
                        inp = {'venue': vids}, sampled_depth = 4, sampled_number = 256)
        node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
                    to_torch(feature, times, edge_list, graph)
        venue_ids = np.arange(len(vids)) + node_dict['venue'][0]

        node_feature, node_type, edge_time, edge_index, edge_type = node_feature.to(device), node_type.to(device), \
                              edge_time.to(device), edge_index.to(device), edge_type.to(device)
        res = torch.zeros(node_feature.size(0), gnn.n_hid).to(node_feature.device)
        for t_id in range(gnn.num_types):
            aggregat_w = gnn.aggregat_ws[t_id]
            idx = (node_type == t_id)
            if idx.sum() == 0:
                continue
            res[idx] = torch.tanh(aggregat_w(node_feature[idx]))
        meta_xs = gnn.drop(res)
        del res
        meta_xs = gnn.gcs[0](meta_xs, node_type, edge_index, edge_type, edge_time)
        emb = meta_xs[venue_ids].cpu().detach().numpy()
        conf_emb1 += [emb]
        cos = cosine_similarity(emb)
        for li in np.argsort(-cos[list(topconf['name']).index('KDD')])[:50]:
            print(list(topconf['name'])[li])
        print('-' * 100)
        dis = euclidean_distances(emb)
        for li in np.argsort(dis[list(topconf['name']).index('KDD')])[:50]:
            print(list(topconf['name'])[li])
        print('=' * 100)
        
        
        meta_xs = gnn.gcs[1](meta_xs, node_type, edge_index, edge_type, edge_time)
        emb = meta_xs[venue_ids].cpu().detach().numpy()
        conf_emb2 += [emb]
        cos = cosine_similarity(emb)
        for li in np.argsort(-cos[list(topconf['name']).index('KDD')])[:50]:
            print(list(topconf['name'])[li])
        print('-' * 100)
        dis = euclidean_distances(emb)
        for li in np.argsort(dis[list(topconf['name']).index('KDD')])[:50]:
            print(list(topconf['name'])[li])
        print('=' * 100)

In [None]:
coss = [euclidean_distances(emb) for emb in conf_emb1]
cos = np.average(coss, axis=0)
for li in np.argsort(cos[list(topconf['name']).index('ACL')])[:50]:
    print(list(topconf['name'])[li])
print('-' * 100)

In [None]:
coss = [euclidean_distances(emb) for emb in conf_emb2]
cos = np.average(coss, axis=0)
for li in np.argsort(cos[list(topconf['name']).index('ACL')])[:50]:
    print(list(topconf['name'])[li])
print('-' * 100)

In [None]:
time_emb = {2010: [conf_emb1, conf_emb2]}

In [None]:
time_emb[1990] = [conf_emb1, conf_emb2]

In [None]:
time_emb[2030] = [conf_emb1, conf_emb2]

In [None]:
time_emb[1990][1]

In [None]:
node_dict = {j: i for i, j in enumerate(graph.get_types())}

In [None]:
edge_dict = {e[2]: i for i, e in enumerate(graph.get_meta_graph())}
edge_dict['self'] = len(edge_dict)

In [None]:
_s = 0
for i in graph.edge_list['author']['affiliation']['rev_in']:
    _s += len(graph.edge_list['author']['affiliation']['rev_in'][i])

In [None]:
_s

In [None]:
len(graph.edge_list['author']['affiliation']['rev_in'])

In [None]:
gnn, _ = torch.load('../paper-field-L2/save/rgt_1.pt')

In [None]:

for i, j, k in graph.get_meta_graph():
    if i == 'affiliation':
        print(i,j,k)
        print(gnn.gcs[0].relation_ws[node_dict[i]][edge_dict[k]][node_dict[j]].mean())

In [None]:
paper venue rev_PV_Conference
tensor(1.0509, device='cuda:1', grad_fn=<MeanBackward0>)

