In [1]:
import sys
sys.path.append('../')
from utils import *
from model import *

In [2]:
# --- Run configuration -------------------------------------------------------
data_dir = '/datadrive/data_cs/'   # directory holding the pickled graph ('graph.pk')
batch_size = 128      # default cap on edges masked per task (see mt_sample / multi_mask)
batch_num  = 128      # process_ids is later built as np.arange(batch_num // 4)
epoch_num  = 1000     # number of iterations of the outer training loop
samp_num   = 128 - 1  # negatives drawn per positive edge (1 positive + 127 negatives)

device = torch.device("cuda:2")
# NOTE(review): dill/pickle deserialisation can execute arbitrary code; only
# load graph files from a trusted source.
# The context manager closes the file handle once the graph is loaded
# (the bare open() previously left it open).
with open(data_dir + 'graph.pk', 'rb') as graph_file:
    graph = dill.load(graph_file)

In [3]:
# Split the graph's timestamps into train / valid / test lookup dicts
# (presumably publication years -- TODO confirm). Entries whose time is None
# are dropped; `is not None` is the idiomatic identity test for None.
train_range = {t: True for t in graph.times if t is not None and t < 2015}
valid_range = {t: True for t in graph.times if t is not None and 2015 <= t <= 2016}
test_range  = {t: True for t in graph.times if t is not None and t > 2016}

In [4]:
class GNN(nn.Module):
    """Per-node-type input projection followed by a stack of RAGCNConv layers.

    Every node type owns its own ``nn.Linear(in_dim, n_hid)``; the projected
    features go through ``tanh`` and dropout and are then passed through
    ``n_layers`` ``RAGCNConv`` layers, one after another.
    """

    def __init__(self, in_dim, n_hid, num_types, num_relations, n_heads, n_layers, dropout = 0.3):
        super().__init__()
        # Registered first so that sub-module / parameter ordering stays stable.
        self.gcs = nn.ModuleList()
        self.num_types = num_types
        self.in_dim = in_dim
        self.n_hid = n_hid
        # One input projection per node type.
        self.aggregat_ws = nn.ModuleList([nn.Linear(in_dim, n_hid) for _ in range(num_types)])
        self.drop = nn.Dropout(dropout)
        for _ in range(n_layers):
            self.gcs.append(RAGCNConv(n_hid, n_hid, num_types, num_relations, n_heads, dropout))

    def set_device(self, device):
        """Record ``device`` on this module and on each convolution layer."""
        self.device = device
        for conv in self.gcs:
            conv.device = device

    def forward(self, node_feature, node_type, edge_time, edge_index, edge_type):
        """Project the input features per node type, then run every conv layer.

        ``node_type`` is compared against each type id to build a row mask over
        ``node_feature``; rows whose type never matches stay zero. Note that the
        conv layers receive ``edge_time`` as their *last* argument.
        """
        projected = torch.zeros(node_feature.size(0), self.n_hid).to(node_feature.device)
        for t_id in range(self.num_types):
            mask = (node_type == t_id)
            if mask.any():
                projected[mask] = torch.tanh(self.aggregat_ws[t_id](node_feature[mask]))
        hidden = self.drop(projected)
        del projected
        for conv in self.gcs:
            hidden = conv(hidden, node_type, edge_index, edge_type, edge_time)
        return hidden

In [5]:
def neg_sample(size, num, pos_id):
    """Draw ``num`` distinct negative ids, none of which appears in ``pos_id``.

    Candidates are drawn one at a time with ``np.random.choice(size)`` and are
    rejected when already chosen or when contained in ``pos_id``.

    Args:
        size: population passed to ``np.random.choice`` (an int ``n`` means
            ``range(n)``).
        num: number of negatives to return.
        pos_id: container of ids that must not be returned (anything that
            supports ``in``).

    Returns:
        A list of ``num`` distinct ids in the order they were drawn.

    Raises:
        ValueError: if ``num`` is negative, or if fewer than ``num`` admissible
            candidates exist. The previous implementation looped forever in
            both situations.
    """
    if num < 0:
        raise ValueError("num must be non-negative, got %r" % (num,))
    # Number of distinct values np.random.choice can ever produce.
    population = int(size) if np.ndim(size) == 0 else len(np.unique(size))
    chosen = {}        # dict keeps insertion (draw) order
    rejected = set()   # candidates already found in pos_id
    while len(chosen) != num:
        if len(chosen) + len(rejected) >= population:
            # Every candidate has been seen and is either chosen or excluded.
            raise ValueError(
                "cannot draw %d negatives: only %d of %d candidates are admissible"
                % (num, len(chosen), population))
        s = np.random.choice(size)
        if s in chosen or s in rejected:
            continue
        if s in pos_id:
            rejected.add(s)
            continue
        chosen[s] = True
    return list(chosen.keys())

In [6]:
def edge_loss(gnn, predictor, rem_edges, node_feature, node_type, node_time, edge_index, edge_type, node_dict, samp_num, target_size, source_type=None, target_type=None):
    """Per-edge ranking loss for a batch of held-out (target, source) edges.

    For every row of ``rem_edges`` the true target (column 0) is ranked against
    ``samp_num`` randomly drawn negative targets; the loss of a row is the
    negative log-softmax score of the true target, which sits at index 0 of
    the candidate list.

    Args:
        gnn: encoder; called as ``gnn.forward(node_feature, node_type,
            node_time, edge_index, edge_type)``.
        predictor: scorer; called as ``predictor.forward(src, tgt, pair=True)``.
        rem_edges: array whose column 0 holds target ids and column 1 source
            ids, both local to their node type.
        node_dict: maps a node type to a sequence whose first entry is added to
            local ids as an offset.
        samp_num: negatives per positive.
        target_size: population size handed to ``neg_sample``.
        source_type, target_type: node-type keys into ``node_dict``. When left
            as ``None`` the module-level variables of the same name are used,
            which is what this function relied on before.

    Returns:
        1-D tensor with one loss value per row of ``rem_edges``.

    Also reads the module-level ``device``.
    """
    # Backward compatibility: earlier versions read these names from globals.
    if source_type is None:
        source_type = globals()['source_type']
    if target_type is None:
        target_type = globals()['target_type']

    positive_target_ids = rem_edges[:, 0].reshape(-1, 1)
    source_ids = rem_edges[:, 1] + node_dict[source_type][0]
    # Fixed: neg_sample's signature is (size, num, pos_id); the two trailing
    # arguments used to be swapped.
    negative_target_ids = np.array([neg_sample(target_size, samp_num, pos_id) for pos_id in positive_target_ids])
    # The positive goes first, so column 0 of every row is the true target.
    target_ids = np.concatenate((positive_target_ids, negative_target_ids), axis=-1) + node_dict[target_type][0]

    node_emb = gnn.forward(node_feature.to(device), node_type.to(device),
                           node_time.to(device), edge_index.to(device), edge_type.to(device))
    n_edges, n_cand = target_ids.shape
    # Repeat every source embedding once per candidate so rows line up with target_emb.
    source_emb = node_emb[source_ids].repeat(1, n_cand).view(n_edges * n_cand, -1)
    target_emb = node_emb[target_ids].view(n_edges * n_cand, -1)
    # pair=True matches the two scoring call sites in the training loop.
    # NOTE(review): confirm against Matcher.forward that pair=True is the
    # intended row-wise scoring mode.
    res = predictor.forward(source_emb, target_emb, pair=True).view(-1, n_cand)
    return -torch.log_softmax(res, dim=-1)[:, 0]

In [7]:
def multi_mask(task_set, feature, time, edge_list, batch_size):
    """Hold out a random batch of edges per task and tensorise the rest.

    For each ``(target_type, source_type, rel_type)`` in ``task_set`` up to
    ``batch_size`` edges (always leaving at least one) are removed from
    ``edge_list[target_type][source_type][rel_type]``. The surviving edges
    replace that list, and their column-swapped copies replace
    ``edge_list[source_type][target_type]['rev_' + rel_type]``. The masked
    edge list is then converted with ``to_torch`` (using the module-level
    ``graph``).

    The masking is applied to a deep copy, so the caller's ``edge_list`` is
    left untouched. The previous code tried to restore it by rebinding a local
    name, which had no effect on the caller's object.

    Returns:
        ``(neg_nums, rem_lists, node_feature, node_type, node_time,
        edge_index, edge_type, node_dict)``. ``rem_lists[i]`` holds the
        held-out edges of task ``i`` and ``neg_nums[i]`` is
        ``len(time[target_type])`` for that task.
    """
    masked_edges = copy.deepcopy(edge_list)
    rem_lists = []
    neg_nums = []
    for target_type, source_type, rel_type in task_set:
        edges = np.array(masked_edges[target_type][source_type][rel_type])
        # Keep at least one edge of this relation in the graph.
        remn = min(len(edges) - 1, batch_size)
        rem_ids = np.random.choice(np.arange(len(edges)), remn, replace = False)
        # Boolean mask keeps the original edge order and avoids the quadratic
        # `i not in rem_ids` scan.
        keep_mask = np.ones(len(edges), dtype=bool)
        keep_mask[rem_ids] = False
        ori_edges = edges[keep_mask]
        rem_lists.append(edges[rem_ids])
        neg_nums.append(len(time[target_type]))
        masked_edges[target_type][source_type][rel_type] = list(ori_edges)
        # Reverse relation: same surviving edges with the two columns swapped.
        masked_edges[source_type][target_type]['rev_' + rel_type] = list(ori_edges[:, [1, 0]])
    node_feature, node_type, node_time, edge_index, edge_type, node_dict, _ = to_torch(feature, time, masked_edges, graph)
    return neg_nums, rem_lists, node_feature, node_type, node_time, edge_index, edge_type, node_dict
def mt_sample(seed, time_range, task_set, sampled_depth = 3, sampled_number = 100, batch_size = batch_size):
    """Sample one sub-graph restricted to ``time_range`` and mask edges for every task.

    Args:
        seed: value passed to ``np.random.seed`` before sampling.
        time_range: time filter forwarded to ``sample_subgraph``.
        task_set: iterable of ``(target_type, source_type, rel_type)`` triples.
        sampled_depth, sampled_number: forwarded to ``sample_subgraph``.
        batch_size: maximum number of edges held out per task.

    Returns:
        The 8-tuple produced by ``multi_mask``.
    """
    np.random.seed(seed)
    # Fixed: the caller's time_range is honoured. It used to be ignored in
    # favour of the global train_range, so "validation" batches were sampled
    # from the training period.
    sub_feature, sub_time, sub_edge_list, _ = sample_subgraph(
        graph, time_range=time_range, sampled_depth=sampled_depth, sampled_number=sampled_number)
    return multi_mask(task_set, sub_feature, sub_time, sub_edge_list, batch_size)

def prepare_data(pool, process_ids, task_set):
    """Queue asynchronous sampling jobs on ``pool``.

    One ``mt_sample`` job over ``train_range`` is submitted for every entry of
    ``process_ids`` except the last, followed by a single job over
    ``valid_range``. Each job gets its own random seed.

    Returns:
        The async results, training jobs first and the validation job last.
    """
    def submit(time_range):
        seed = np.random.randint(2**32 - 1)
        return pool.apply_async(mt_sample, args=(seed, time_range, task_set))

    jobs = [submit(train_range) for _ in process_ids[:-1]]
    jobs.append(submit(valid_range))
    return jobs

In [None]:
# Node types of the graph; their count sets the number of per-type input projections.
types = graph.get_types()
cand_list = list(graph.edge_list['field']['paper']['PF_in_L1'])

# Shared encoder. The input width is the paper embedding length plus 401 extra
# dimensions (NOTE(review): the origin of 401 is not visible here -- confirm
# against to_torch). One extra relation is reserved on top of the meta graph.
gnn = GNN(
    in_dim = len(graph.node_feature['paper']['emb'][0]) + 401,
    n_hid = 256,
    num_types = len(types),
    num_relations = len(graph.get_meta_graph()) + 1,
    n_heads = 8,
    n_layers = 3,
).to(device)

# Five identically configured matchers, created in this order.
fp_predictor, vp_predictor, pp_predictor, ap_predictor, ai_predictor = [
    Matcher(256, 8).to(device) for _ in range(5)
]

# nn.Sequential is used as a container so that model.parameters() spans the
# encoder and all predictors.
model = nn.Sequential(gnn, fp_predictor, vp_predictor, pp_predictor, ap_predictor, ai_predictor)

In [None]:
# AdamW with library-default hyper-parameters over the encoder and all predictors.
optimizer = torch.optim.AdamW(params=model.parameters())
# Cosine schedule with a period parameter of 1000 steps and a floor of 1e-6.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)

In [None]:
stats = []
models = [fp_predictor, vp_predictor, pp_predictor, ap_predictor, ai_predictor]
# One (target_type, source_type, relation) triple per predictor in `models`,
# in the same order.
task_set = [['field', 'paper', 'PF_in_L2'],
            ['venue', 'paper', 'PV_Conference'],
            ['paper', 'paper', 'PP_cite'],
            ['paper', 'author', 'AP_write_first'],
            ['affiliation', 'author', 'in']]


def _task_logits(predictor, rem_edges, target_size, source_type, target_type,
                 node_feature, node_type, node_time, edge_index, edge_type, node_dict):
    """Score each held-out edge's true target against ``samp_num`` sampled negatives.

    Returns a 2-D tensor with one row per held-out edge and ``samp_num + 1``
    columns; column 0 is the score of the true target. Shared by the training
    and validation passes, which used to duplicate this code.
    """
    positive_target_ids = rem_edges[:, 0].reshape(-1, 1)
    source_ids = rem_edges[:, 1] + node_dict[source_type][0]
    # Negatives exclude the positive target and the ids found in edge_index[1]
    # wherever edge_index[0] equals the source id.
    # NOTE(review): s_id carries the node_dict offset while neg_sample draws
    # from range(target_size) -- confirm both live in the same id space.
    negative_target_ids = np.array([
        neg_sample(target_size, samp_num,
                   edge_index[1][edge_index[0] == s_id].tolist() + [pos_id])
        for pos_id, s_id in zip(positive_target_ids, source_ids)])
    # Positive first, then negatives; shifted by the target type's offset.
    target_ids = np.concatenate((positive_target_ids, negative_target_ids), axis=-1) + node_dict[target_type][0]

    node_emb = gnn.forward(node_feature.to(device), node_type.to(device),
                           node_time.to(device), edge_index.to(device), edge_type.to(device))
    n_edges, n_cand = target_ids.shape
    # Repeat each source embedding once per candidate so rows align with target_emb.
    source_emb = node_emb[source_ids].repeat(1, n_cand).view(n_edges * n_cand, -1)
    target_emb = node_emb[target_ids].view(n_edges * n_cand, -1)
    return predictor.forward(source_emb, target_emb, pair=True).view(-1, n_cand)


pool = mp.Pool(4)
process_ids = np.arange(batch_num // 4)
st = time.time()
jobs = prepare_data(pool, process_ids, task_set)
train_step = 1500
best_val   = 0

for epoch in range(1, epoch_num + 1):
    # Collect the batches prepared in the background, then immediately start
    # preparing the next epoch's batches on a fresh pool.
    train_data = [job.get() for job in jobs[:-1]]
    valid_data = jobs[-1].get()
    pool.close()
    pool.join()
    pool = mp.Pool(4)
    jobs = prepare_data(pool, process_ids, task_set)
    et = time.time()
    print('Data Preparation: %.1fs' % (et - st))

    # ---- Train ----
    train_losses = []
    model.train()
    torch.cuda.empty_cache()
    for neg_nums, rem_lists, node_feature, node_type, node_time, edge_index, edge_type, node_dict in train_data:
        for target_size, predictor, rem_edges, (target_type, source_type, rel_type) in zip(neg_nums, models, rem_lists, task_set):
            res = _task_logits(predictor, rem_edges, target_size, source_type, target_type,
                               node_feature, node_type, node_time, edge_index, edge_type, node_dict)
            # Softmax cross-entropy with the positive at class index 0.
            loss = -torch.log_softmax(res, dim=-1)[:, 0].mean()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.2)
            optimizer.step()
            torch.cuda.empty_cache()
            train_losses += [loss.item()]
            train_step += 1
            # NOTE(review): passing an explicit step index to scheduler.step()
            # is deprecated in recent PyTorch releases; kept to preserve the
            # existing learning-rate trajectory.
            scheduler.step(train_step)

    # ---- Valid ----
    model.eval()
    with torch.no_grad():
        neg_nums, rem_lists, node_feature, node_type, node_time, edge_index, edge_type, node_dict = valid_data
        valid_losses = []
        valid_accs   = []
        for target_size, predictor, rem_edges, (target_type, source_type, rel_type) in zip(neg_nums, models, rem_lists, task_set):
            res = _task_logits(predictor, rem_edges, target_size, source_type, target_type,
                               node_feature, node_type, node_time, edge_index, edge_type, node_dict)
            valid_losses += [-torch.log_softmax(res, dim=-1)[:, 0]]
            # A prediction is correct when the positive (column 0) scores highest.
            s = (res.argmax(dim=1) == 0).sum()
            # Fixed: normalise by the number of edges actually evaluated. A task
            # can hold out fewer than batch_size edges, in which case dividing
            # by batch_size under-reported the accuracy.
            valid_accs += [s.tolist() / max(res.size(0), 1)]
        valid_losses = torch.cat(valid_losses).cpu().detach().tolist()
        st = time.time()
        print(("Epoch: %d (%.1fs)  LR: %.5f Train Loss: %f  Valid Loss: %f  Valid Acc: %f") % \
              (epoch, (st-et), optimizer.param_groups[0]['lr'], np.average(train_losses), np.average(valid_losses), np.average(valid_accs)))
        # Checkpoint the encoder whenever the mean validation accuracy improves.
        if np.average(valid_accs) > best_val:
            best_val = np.average(valid_accs)
            torch.save(gnn, './save/mt_model.pt')
        stats += [[train_losses, valid_losses]]

# The loop always prefetches one epoch ahead; discard the final, unused prefetch
# so that its worker processes do not linger.
pool.terminate()
pool.join()

In [215]:
# Reload the encoder checkpoint that the training loop writes to this path.
# NOTE(review): the file holds a whole pickled module, so the GNN class must be
# defined before this runs; only load checkpoints from a trusted source.
gnn = torch.load('./save/mt_model.pt')

In [13]:
# Sum of the lengths of every innermost container in the four-level nested
# graph.edge_list. NOTE: this total is overwritten below before the cell ends,
# so only the node-feature total is displayed.
size = sum(
    len(graph.edge_list[s][t][e][i])
    for s in graph.edge_list
    for t in graph.edge_list[s]
    for e in graph.edge_list[s][t]
    for i in graph.edge_list[s][t][e]
)
size
# Sum of the per-type node_feature lengths; this is the value the cell displays.
size = sum(len(graph.node_feature[s]) for s in graph.node_feature)
size

55749031

In [14]:
# Sum of the per-type node_feature lengths.
size = sum(len(graph.node_feature[ntype]) for ntype in graph.node_feature)
size

1116163

In [196]:
# Distinct values of the 'attr' column of the venue feature table.
np.unique(list(graph.node_feature['venue']['attr']))

array(['Conference', 'Journal', 'Patent', 'Repository'], dtype='<U10')

In [198]:
# Venue rows whose 'attr' equals 'Patent'.
graph.node_feature['venue'][graph.node_feature['venue']['attr']=='Patent']

Unnamed: 0,attr,citation,id,name,node_emb,type,emb
1324,Patent,0,2764484637,international journal of chemical sciences,"[0.740164, -0.589690, -0.403322, 0.115058, 0.1...",venue,"[-0.1699792742729187, -1.8855005502700806, -0...."


In [None]:
# Scratch note left in an unexecuted cell. 1116163 matches the node total shown
# above; 5574903 is presumably meant to be the 55749031 shown earlier (one digit
# short) -- TODO confirm or delete this cell.
1116163, 5574903, 

In [15]:
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import euclidean_distances

In [43]:
# The 256 venue rows with the largest 'citation' value.
# NOTE: `topconf` is rebound to the top 100 by a later cell.
topconf = graph.node_feature['venue'].nlargest(256, 'citation')

In [42]:
# 'emb' values of every venue row named 'WWW', as a list (the query for the
# distance lookup in the next cell).
res = list(graph.node_feature['venue'][graph.node_feature['venue']['name'] == 'WWW']['emb'])

In [55]:
# Distances from the first 'WWW' embedding in `res` to every venue in `topconf`.
# NOTE: despite its name, `cos` holds Euclidean distances (smaller = closer).
cos = euclidean_distances(res, list(topconf['emb']))[0]
# Names of the 50 closest venues, nearest first.
for venue_name in topconf['name'].iloc[np.argsort(cos)[:50]]:
    print(venue_name)

world wide web
WSDM
CIKM
ICWSM
KDD
ieee transactions on knowledge and data engineering
SIGIR
AAAI
proceedings of the vldb endowment
HICSS
ISWC
information processing and management
ieee intelligent systems
lecture notes in computer science
IUI
SIGMOD
knowledge and information systems
ICDM
journal of web semantics
CHI
IJCAI
SAC
acm transactions on information systems
ICDE
journal of the association for information science and technology
EDBT
arxiv cryptography and security
international journal of human computer studies international journal of man machine studies
acm transactions on computer human interaction
sigkdd explorations
scientific programming
NDSS
decision support systems
FSE
computer supported cooperative work
journal of artificial intelligence research
autonomous agents and multi agent systems
information sciences
arxiv distributed parallel and cluster computing
CCS
personal and ubiquitous computing
evolutionary computation
NAACL
ACSAC
international journal of geographical i

In [238]:
gnn.to(device)
gnn.eval()
topconf = graph.node_feature['venue'].nlargest(100, 'citation')


def _print_neighbours(layer_emb, anchor='KDD', top_k=50):
    """Print the venues nearest to ``anchor``: by cosine similarity, then by Euclidean distance.

    Rows of ``layer_emb`` are matched to the rows of ``topconf`` by position.
    Returns the full similarity and distance matrices.
    """
    names = list(topconf['name'])
    sim = cosine_similarity(layer_emb)
    for li in np.argsort(-sim[names.index(anchor)])[:top_k]:
        print(names[li])
    print('-' * 100)
    dist = euclidean_distances(layer_emb)
    for li in np.argsort(dist[names.index(anchor)])[:top_k]:
        print(names[li])
    print('=' * 100)
    return sim, dist


with torch.no_grad():
    _time = 2000
    # One (venue index, time) seed row per top venue.
    vids = np.stack([topconf.index.values, np.repeat([_time], 100)]).T
    conf_emb1 = []   # venue embeddings after the first conv layer, one entry per sample
    conf_emb2 = []   # venue embeddings after the second conv layer
    for i in range(10):
        print(i)
        feature, times, edge_list, _ = sample_subgraph(
            graph, {t: True for t in graph.times if t != None},
            inp = {'venue': vids}, sampled_depth = 4, sampled_number = 256)
        node_feature, node_type, edge_time, edge_index, edge_type, node_dict, edge_dict = \
            to_torch(feature, times, edge_list, graph)
        # Positions of the seed venues once the 'venue' offset is applied.
        venue_ids = np.arange(len(vids)) + node_dict['venue'][0]

        node_feature = node_feature.to(device)
        node_type = node_type.to(device)
        edge_time = edge_time.to(device)
        edge_index = edge_index.to(device)
        edge_type = edge_type.to(device)

        # Same per-type input projection as GNN.forward, unrolled here so the
        # output of each conv layer can be inspected separately.
        res = torch.zeros(node_feature.size(0), gnn.n_hid).to(node_feature.device)
        for t_id in range(gnn.num_types):
            idx = (node_type == t_id)
            if idx.any():
                res[idx] = torch.tanh(gnn.aggregat_ws[t_id](node_feature[idx]))
        meta_xs = gnn.drop(res)
        del res

        # Apply the first two conv layers one at a time, recording and printing
        # the venue embeddings after each.
        for depth, store in enumerate((conf_emb1, conf_emb2)):
            meta_xs = gnn.gcs[depth](meta_xs, node_type, edge_index, edge_type, edge_time)
            emb = meta_xs[venue_ids].cpu().detach().numpy()
            store.append(emb)
            cos, dis = _print_neighbours(emb)

0
KDD
ICCV
CIKM
NSDI
ICDM
MM
ICCAD
ICDCS
GLOBECOM
VTC
CHI
MOBICOM
world wide web
MobiHoc
HPCA
SIGMOD
ieee transactions on computer aided design of integrated circuits and systems
acm sigarch computer architecture news
UIST
DAC
CCS
MobiSys
autonomous agents and multi agent systems
SIGGRAPH
ICSE
ICDE
sigmetrics performance evaluation review
SenSys
ICIP
ICC
NAACL
SIGIR
INFOCOM
ieee wireless communications
HICSS
operating systems review
DATE
IJCAI
ICML
NeurIPS
ECCV
ICASSP
ACL
EMNLP
CVPR
AAAI
SECURITY
journal of the association for information science and technology
CRYPTO
IPDPS
----------------------------------------------------------------------------------------------------
KDD
ICCV
CIKM
NSDI
ICDM
MM
ICCAD
ICDCS
GLOBECOM
VTC
CHI
MOBICOM
world wide web
MobiHoc
HPCA
SIGMOD
ieee transactions on computer aided design of integrated circuits and systems
acm sigarch computer architecture news
UIST
DAC
CCS
MobiSys
SIGGRAPH
autonomous agents and multi agent systems
SenSys
sigmetrics performance 

KeyboardInterrupt: 

In [None]:
# Pairwise Euclidean distances between venue embeddings after the first conv
# layer, averaged over the sampled sub-graphs. NOTE: `coss` / `cos` hold
# distances, not cosine similarities.
coss = list(map(euclidean_distances, conf_emb1))
cos = np.mean(coss, axis=0)
# The 50 venues closest to 'ACL', nearest first.
for venue_name in topconf['name'].iloc[np.argsort(cos[list(topconf['name']).index('ACL')])[:50]]:
    print(venue_name)
print('-' * 100)

In [232]:
# Pairwise Euclidean distances between venue embeddings after the second conv
# layer, averaged over the sampled sub-graphs. NOTE: `coss` / `cos` hold
# distances, not cosine similarities.
coss = list(map(euclidean_distances, conf_emb2))
cos = np.mean(coss, axis=0)
# The 50 venues closest to 'ACL', nearest first.
for venue_name in topconf['name'].iloc[np.argsort(cos[list(topconf['name']).index('ACL')])[:50]]:
    print(venue_name)
print('-' * 100)

ACL
INFOCOM
ECCV
GLOBECOM
ICSE
CVPR
ICIP
SIGMOD
EMNLP
NeurIPS
ICCV
NAACL
FSE
AAAI
VTC
CAV
SECURITY
SIGIR
ICDE
CCS
ICC
IJCAI
EDBT
MICCAI
CIKM
ECOOP
SIGGRAPH
COLING
ICRA
ICML
KDD
ICDCS
data and knowledge engineering
MM
GECCO
DAC
DATE
ICCAD
INTERSPEECH
ASILOMAR
MOBICOM
ICLR
IUI
NSDI
UAI
ICDM
HICSS
ACSAC
ISSTA
IPDPS
----------------------------------------------------------------------------------------------------


In [91]:
# Start a dict from time label to the [layer-1, layer-2] embedding lists.
# NOTE(review): the embedding cell shown in this notebook sets _time = 2000, so
# this entry presumably came from a run with _time = 2010 -- hidden state, confirm.
time_emb = {2010: [conf_emb1, conf_emb2]}

In [105]:
# Store the current embedding lists under 1990 (presumably after re-running the
# embedding cell with _time = 1990 -- confirm).
time_emb[1990] = [conf_emb1, conf_emb2]

In [116]:
# Store the current embedding lists under 2030 (presumably after re-running the
# embedding cell with _time = 2030 -- confirm).
time_emb[2030] = [conf_emb1, conf_emb2]

In [122]:
# Display the second list (conf_emb2) stored under 1990. This dumps every
# array in full; consider inspecting shapes or a slice instead.
time_emb[1990][1]

[array([[-0.44682163,  1.4439737 ,  1.0461857 , ...,  1.0276611 ,
         -0.5696738 , -0.7640809 ],
        [-0.44682163,  1.2189132 ,  0.80924225, ...,  1.4106448 ,
         -0.56967384, -0.7640809 ],
        [-0.44682163,  1.7692353 ,  0.7775029 , ...,  0.61222756,
         -0.49836862, -0.7178318 ],
        ...,
        [-0.44682163, -0.10202482,  0.07401577, ...,  0.37724358,
         -0.51938236, -0.7640809 ],
        [-0.44682163, -0.00676589, -0.32257503, ...,  0.21618128,
         -0.50861377, -0.76408094],
        [-0.44682163, -0.5664511 , -0.03994929, ..., -0.2181938 ,
         -0.5206225 , -0.7640809 ]], dtype=float32),
 array([[-0.4534359 ,  1.0715052 ,  0.79899216, ...,  0.41144317,
         -0.49090075, -0.7148243 ],
        [-0.4534359 ,  0.7887458 ,  0.21565858, ...,  1.555398  ,
         -0.4794289 , -0.7148243 ],
        [-0.4534359 ,  1.0116353 ,  1.191432  , ...,  0.8996734 ,
         -0.43086213, -0.7148243 ],
        ...,
        [-0.4534359 , -0.29021427, -0.1

In [143]:
# Map each node type to its position in graph.get_types().
# NOTE: this rebinds `node_dict`, which other cells bind to the mapping
# returned by to_torch.
node_dict = {type_name: type_idx for type_idx, type_name in enumerate(graph.get_types())}

In [141]:
# Map each relation name (third element of a meta-graph entry) to its position
# in graph.get_meta_graph(), then give 'self' the next free index.
edge_dict = {}
for rel_idx, meta_rel in enumerate(graph.get_meta_graph()):
    edge_dict[meta_rel[2]] = rel_idx
edge_dict['self'] = len(edge_dict)

In [222]:
# Sum of the lengths of every entry under author -> affiliation -> 'rev_in'.
_s = sum(
    len(graph.edge_list['author']['affiliation']['rev_in'][key])
    for key in graph.edge_list['author']['affiliation']['rev_in']
)

In [223]:
# Display the total computed in the previous cell.
_s

612872

In [224]:
# Number of top-level keys under author -> affiliation -> 'rev_in'.
len(graph.edge_list['author']['affiliation']['rev_in'])

510189

In [202]:
# Load a checkpoint from another experiment; the file holds a pair, and only
# its first element (bound to `gnn`) is kept.
# NOTE: this replaces the `gnn` loaded from './save/mt_model.pt'.
gnn, _ = torch.load('../paper-field-L2/save/rgt_1.pt')

In [192]:

# For every meta-graph entry whose first element is 'affiliation', print the
# entry and the mean of the matching first-layer relation weight.
for type_a, type_b, rel_name in graph.get_meta_graph():
    if type_a != 'affiliation':
        continue
    print(type_a, type_b, rel_name)
    rel_weight = gnn.gcs[0].relation_ws[node_dict[type_a]][edge_dict[rel_name]][node_dict[type_b]]
    print(rel_weight.mean())

affiliation author in
tensor(0.9698, device='cuda:1', grad_fn=<MeanBackward0>)
affiliation affiliation IPI_coauthor
tensor(0.9237, device='cuda:1', grad_fn=<MeanBackward0>)


In [None]:
paper venue rev_PV_Conference
tensor(1.0509, device='cuda:1', grad_fn=<MeanBackward0>)

