In [1]:
from collections import defaultdict
import random

import math
import networkx as nx
import numpy as np
from six import iteritems
from gensim.models.keyedvectors import Vocab
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import tqdm
import dgl
import dgl.function as fn
from sklearn.metrics import (auc, f1_score, precision_recall_curve,
                             roc_auc_score)

In [2]:
class NSLoss(nn.Module):
    """Negative-sampling loss with its own learnable table of context vectors.

    For each input embedding the loss rewards a high dot product with the
    context vector of the positive ``label`` and a low dot product with the
    context vectors of ``num_sampled`` randomly drawn negative nodes.

    Args:
        num_nodes: number of rows in the context table (one per node index).
        num_sampled: number of negative samples drawn per input row.
        embedding_size: width of each context vector; must match ``embs``.
    """

    #                   511         5           200
    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        # Allocated uninitialised here; filled in by reset_parameters() below.
        self.weights = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Sampling weight of index k is (log(k + 2) - log(k + 1)) / log(num_nodes + 1),
        # so low indices are drawn more often. F.normalize rescales the vector to
        # unit L2 norm; torch.multinomial does not need the weights to sum to one.
        # NOTE: this is a plain tensor attribute, not a registered buffer, so
        # Module.to(device) does not move it.
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Draw the context table from N(0, 1 / embedding_size)."""
        self.weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        """Return the mean negative-sampling loss of a batch.

        Args:
            input: tensor whose first dimension is the batch size; only its
                shape is used.
            embs: (batch, embedding_size) embeddings of the input nodes.
            label: (batch,) indices of the positive context nodes.

        Returns:
            Scalar tensor: minus the summed log-likelihood divided by the batch size.
        """
        n = input.shape[0]
        # logsigmoid is used instead of log(sigmoid(x)): the latter underflows to
        # log(0) = -inf for strongly negative logits and poisons the loss/gradients.
        log_target = F.logsigmoid(
            torch.sum(torch.mul(embs, self.weights[label]), 1)
        )
        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        # Negating the sampled vectors turns sigmoid(-w.e) into the probability
        # that the sampled node is *not* a context of the input.
        noise = torch.neg(self.weights[negs])
        # bmm: (n, num_sampled, d) x (n, d, 1) -> (n, num_sampled, 1); summing over
        # the samples leaves (n, 1), and squeeze(1) keeps the batch axis even for n == 1.
        sum_log_sampled = torch.sum(
            F.logsigmoid(torch.bmm(noise, embs.unsqueeze(2))), 1
        ).squeeze(1)

        loss = log_target + sum_log_sampled
        return -loss.sum() / n

In [3]:
def load_testing_data(f_name):
    """Read a labelled edge file and split its edges by type and label.

    Each non-blank line is split on single spaces into
    ``edge_type node1 node2 label``; a label of ``1`` marks a true edge and
    any other integer a false edge.

    Args:
        f_name: path of the text file to read.

    Returns:
        ``(true_edge_data_by_type, false_edge_data_by_type)``: two dicts mapping
        the edge-type string to a list of ``(node1, node2)`` string tuples.
    """
    print('We are loading data from:', f_name)
    true_edge_data_by_type = dict()
    false_edge_data_by_type = dict()
    with open(f_name, 'r') as f:
        for line in f:
            # Strip only the line terminator. The previous `line[:-1]` removed a
            # real character when the last line had no trailing newline.
            line = line.rstrip('\r\n')
            if not line.strip():
                continue  # tolerate blank lines instead of failing on words[1]
            words = line.split(' ')
            x, y = words[1], words[2]
            if int(words[3]) == 1:
                true_edge_data_by_type.setdefault(words[0], list()).append((x, y))
            else:
                false_edge_data_by_type.setdefault(words[0], list()).append((x, y))
    return true_edge_data_by_type, false_edge_data_by_type

In [4]:
# Read the training edge list: each non-blank line is "edge_type node1 node2".
# Edges are grouped per edge type and every node id seen is collected.
edge_data_by_type = dict()
all_nodes = list()
with open("data/example/train.txt", 'r') as f:
    for line in f:
        # Strip only the line terminator; `line[:-1]` would cut a real character
        # off the last line when the file does not end with a newline.
        line = line.rstrip('\r\n')
        if not line.strip():
            continue  # skip blank lines
        words = line.split(' ')
        x, y = words[1], words[2]
        edge_data_by_type.setdefault(words[0], list()).append((x, y))
        all_nodes.append(x)
        all_nodes.append(y)
all_nodes = list(set(all_nodes))  # de-duplicate; resulting order is arbitrary
print('Total training nodes: ' + str(len(all_nodes)))
training_data_by_type = edge_data_by_type


# Validation / test edges, each split into true and false edges per edge type.
valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(
    "data/example/valid.txt"
)
testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(
    "data/example/test.txt"
)
    

Total training nodes: 511
We are loading data from: data/example/valid.txt
We are loading data from: data/example/test.txt


In [5]:
# Inline parse of the test file into true/false edges per edge type.
# NOTE(review): this repeats the body of load_testing_data() on a file that is
# already loaded above, and it overwrites the global `all_nodes` built from the
# training data with the nodes of the test file. Confirm the cell is still needed.
true_edge_data_by_type = dict()
false_edge_data_by_type = dict()
all_edges = list()
all_nodes = list()
with open("data/example/test.txt", 'r') as f:
    for line in f:
        # Strip only the line terminator; `line[:-1]` would cut a real character
        # off the last line when the file does not end with a newline.
        line = line.rstrip('\r\n')
        if not line.strip():
            continue  # skip blank lines
        words = line.split(' ')
        x, y = words[1], words[2]
        if int(words[3]) == 1:
            true_edge_data_by_type.setdefault(words[0], list()).append((x, y))
        else:
            false_edge_data_by_type.setdefault(words[0], list()).append((x, y))
        all_nodes.append(x)
        all_nodes.append(y)
all_nodes = list(set(all_nodes))

In [6]:
# get_G_from_edges(tmp_data)  tmp_data: [(node1, node2), (node, node)]
def get_G_from_edges(edges):
    """Build an undirected networkx graph from a list of node pairs.

    Node ids are converted to ``str``. Each distinct ordered pair becomes an
    edge whose ``weight`` attribute is the number of times that ordered pair
    occurs in ``edges``.

    Args:
        edges: iterable of ``(node1, node2)`` pairs.

    Returns:
        ``nx.Graph`` with string node ids and a ``weight`` on every edge.
    """
    # Count occurrences per ordered pair. Tuple keys are used instead of the
    # former "a_b" string key, which was split on '_' afterwards and therefore
    # corrupted any node id that itself contained an underscore.
    edge_dict = dict()
    for edge in edges:  # (node1, node2)
        edge_key = (str(edge[0]), str(edge[1]))
        edge_dict[edge_key] = edge_dict.get(edge_key, 0) + 1

    tmp_G = nx.Graph()
    for (x, y), weight in edge_dict.items():
        tmp_G.add_edge(x, y)
        # NOTE(review): the graph is undirected, so when both (a, b) and (b, a)
        # occur the later key overwrites the weight rather than adding to it.
        # This matches the previous behaviour.
        tmp_G[x][y]['weight'] = weight
    return tmp_G

In [7]:
class RWGraph():
    """Uniform random-walk generator, optionally constrained by a node-type schema.

    Args:
        nx_G: graph-like object; ``nx_G[node]`` must yield the neighbours of
            ``node`` and ``nx_G.nodes()`` all nodes.
        node_type: optional mapping from node to its type string. Required
            whenever a schema is passed to the walk methods.
    """

    # layer_walker = RWGraph(get_G_from_edges(tmp_data))
    def __init__(self, nx_G, node_type=None):
        self.G = nx_G
        self.node_type = node_type

    def walk(self, walk_length, start, schema=None):
        """Simulate one random walk starting from ``start``.

        Args:
            walk_length: maximum number of nodes in the walk, start included.
            start: node the walk begins at.
            schema: optional '-'-separated sequence of node types whose first
                and last entries are equal (e.g. ``"U-I-U"``). The node at
                position ``p`` must have type ``items[p % (len(items) - 1)]``.

        Returns:
            List of node ids as strings. The walk is shorter than
            ``walk_length`` when it reaches a node with no admissible neighbour.
        """
        G = self.G

        # A falsy schema (None or "") means an unconstrained walk. Previously an
        # empty string skipped the parsing below and then hit an unbound name.
        schema_items = None
        if schema:
            schema_items = schema.split('-')
            assert schema_items[0] == schema_items[-1]

        walk = [start]
        while len(walk) < walk_length:
            cur = walk[-1]
            if schema_items is None:
                candidates = list(G[cur].keys())
            else:
                wanted = schema_items[len(walk) % (len(schema_items) - 1)]
                candidates = [
                    node for node in G[cur].keys() if self.node_type[node] == wanted
                ]
            if not candidates:
                break
            # Use the module-level generator (as simulate_walks does for shuffle)
            # so random.seed() makes walks reproducible. A fresh random.Random()
            # per walk was seeded from OS entropy and ignored the global seed.
            walk.append(random.choice(candidates))
        return [str(node) for node in walk]

    # layer_walker = RWGraph(get_G_from_edges(tmp_data))
    # layer_walks = layer_walker.simulate_walks(num_walks=20, walk_length=10, schema=schema)
    def simulate_walks(self, num_walks, walk_length, schema=None):
        """Run ``num_walks`` passes of walks over all nodes in shuffled order.

        Args:
            num_walks: number of passes over the node set.
            walk_length: maximum length of each walk.
            schema: optional comma-separated list of schemas (see ``walk``). A
                node starts one walk per schema whose first type equals the
                node's type; without a schema every node starts one walk.

        Returns:
            List of walks, each a list of node-id strings.
        """
        G = self.G
        walks = []
        nodes = list(G.nodes())
        schema_list = schema.split(',') if schema is not None else None
        for walk_iter in range(num_walks):
            random.shuffle(nodes)
            for node in nodes:
                if schema_list is None:
                    walks.append(self.walk(walk_length=walk_length, start=node))
                    continue
                for schema_iter in schema_list:
                    if schema_iter.split('-')[0] == self.node_type[node]:
                        walks.append(
                            self.walk(walk_length=walk_length, start=node, schema=schema_iter)
                        )

        return walks


In [8]:
# generate_walks(training_data_by_type, 20, 10, None, "data/example")
file_name = "data/example"
def generate_walks(network_data, num_walks, walk_length, schema, file_name):
    """Generate random walks separately on every layer (edge type).

    Args:
        network_data: mapping from edge type to a list of ``(node1, node2)``
            edges.
        num_walks: number of walk passes per layer.
        walk_length: maximum length of each walk.
        schema: optional node-type schema string forwarded to the walker;
            ``None`` gives unconstrained walks.
        file_name: dataset directory. ``file_name + '/node_type.txt'`` is read
            only when a schema is given.

    Returns:
        List with one entry per layer, in ``network_data`` iteration order;
        each entry is that layer's list of walks.
    """
    if schema is not None:
        # NOTE(review): load_node_type is not defined in the notebook cells
        # visible here - confirm it is provided before using a schema.
        node_type = load_node_type(file_name + '/node_type.txt')
    else:
        node_type = None

    all_walks = []
    for layer_id in network_data:  # edge_type
        tmp_data = network_data[layer_id]  # edges of this layer (edge type)
        # start to do the random walk on a layer

        # node_type must reach the walker: schema-constrained walks index
        # self.node_type, which was previously always left as None.
        layer_walker = RWGraph(get_G_from_edges(tmp_data), node_type=node_type)
        layer_walks = layer_walker.simulate_walks(num_walks, walk_length, schema=schema)

        all_walks.append(layer_walks)

    print('Finish generating the walks')

    return all_walks

In [9]:
def generate_vocab(all_walks):  # (2, [8640, 4660], 10)
    """Build a frequency-ranked vocabulary from the walks of all layers.

    Args:
        all_walks: list (per layer) of walks, each walk a list of node ids.

    Returns:
        ``(vocab, index2word)``: ``vocab`` maps each node id to a ``Vocab``
        holding its occurrence ``count`` and its rank ``index``;
        ``index2word`` lists the node ids from most to least frequent, ties
        kept in first-seen order.
    """
    occurrences = defaultdict(int)
    for layer_walks in all_walks:
        for single_walk in layer_walks:
            for token in single_walk:
                occurrences[token] += 1

    # sorted() is stable, also with reverse=True, so equally frequent tokens
    # stay in the order in which they were first encountered.
    index2word = sorted(occurrences, key=occurrences.__getitem__, reverse=True)
    rank_of = {token: rank for rank, token in enumerate(index2word)}

    vocab = {
        token: Vocab(count=freq, index=rank_of[token])
        for token, freq in occurrences.items()
    }

    return vocab, index2word

In [10]:
def generate_pairs(all_walks, vocab, window_size):
    """Turn walks into skip-gram style ``(center, context, layer)`` triples.

    Args:
        all_walks: list (per layer) of walks, each walk a list of node ids.
        vocab: mapping from node id to an object with an ``index`` attribute.
        window_size: full window width; contexts are taken up to
            ``window_size // 2`` positions to either side of the center.

    Returns:
        List of ``(center_index, context_index, layer_id)`` tuples, where
        ``layer_id`` is the position of the walk's layer in ``all_walks``.
    """
    reach = window_size // 2
    triples = []
    for layer_id, layer_walks in enumerate(all_walks):
        for path in layer_walks:
            length = len(path)
            for center in range(length):
                for offset in range(1, reach + 1):
                    # left neighbour first, then right neighbour
                    for other in (center - offset, center + offset):
                        if 0 <= other < length:
                            triples.append(
                                (vocab[path[center]].index, vocab[path[other]].index, layer_id)
                            )
    return triples

In [11]:
# Seed every RNG the pipeline draws from (walk shuffling, neighbour sampling,
# parameter init, negative sampling) so a fresh run is repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Walks -> vocabulary -> (center, context, layer) training pairs.
window_size = 5
all_walks = generate_walks(training_data_by_type, 20, 10, None, file_name)
vocab, index2word = generate_vocab(all_walks)
train_pairs = generate_pairs(all_walks, vocab, window_size)

# Hyper-parameters and sizes derived from the data.
edge_types = list(training_data_by_type.keys())
eval_type = 'all'  # 'all' or a comma-separated list of edge types to evaluate
num_nodes = len(index2word)
edge_type_count = len(edge_types)
epochs = 100  # NOTE(review): the training cell assigns `epochs` again before looping
batch_size = 64
embedding_size = 200  # width of the final node embeddings
embedding_u_size = 15  # width of the per-edge-type embeddings
u_num = edge_type_count
num_sampled = 5  # negative samples per training pair
dim_a = 20  # hidden width of the attention transform
att_head = 1
neighbor_samples = 10  # fixed number of neighbours kept per node and edge type

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Finish generating the walks


In [12]:
# Build a fixed-size neighbour table: neighbors[node_index][edge_type_index] is a
# list of exactly `neighbor_samples` node indices (vocabulary indices).
network_data=training_data_by_type
neighbors = [[[] for __ in range(edge_type_count)] for _ in range(num_nodes)]
for r in range(edge_type_count):
    g = network_data[edge_types[r]]  # list of (node1, node2) edges of this type
    # Edges are treated as undirected: each endpoint records the other one.
    for (x, y) in g:
        ix = vocab[x].index
        iy = vocab[y].index
        neighbors[ix][r].append(iy)
        neighbors[iy][r].append(ix)
    # Pad or subsample every node's list to exactly `neighbor_samples` entries.
    for i in range(num_nodes):
        if len(neighbors[i][r]) == 0:  # no neighbor
            neighbors[i][r] = [i] * neighbor_samples  # regard itself as its neighbor
        elif len(neighbors[i][r]) < neighbor_samples:  # randomly repeat neighbors to reach neighbor_samples
            neighbors[i][r].extend(
                list(
                    np.random.choice(
                        neighbors[i][r],
                        size=neighbor_samples - len(neighbors[i][r]),
                    )
                )
            )
        elif len(neighbors[i][r]) > neighbor_samples:  # keep a random sample of neighbor_samples entries
            # np.random.choice samples with replacement by default, so the kept
            # list may contain the same neighbour more than once.
            neighbors[i][r] = list(
                np.random.choice(neighbors[i][r], size=neighbor_samples)
            )


In [42]:
# train_pairs: size: 452200, form: (node1, node2, layer_id)
# neighbors: [num_nodes=511, 2, 10]
def get_batches(pairs, neighbors, batch_size):
    """Yield mini-batches of training pairs as tensors.

    Args:
        pairs: sequence of ``(node1, node2, layer_id)`` index triples.
        neighbors: per-node neighbour table, indexed by ``node1``.
        batch_size: maximum number of pairs per batch; the last batch may be
            smaller.

    Yields:
        ``(x, y, t, neigh)`` tensors: the first nodes, second nodes, layer ids
        and the neighbour-table rows of the first nodes of one batch.
    """
    total = len(pairs)
    n_batches = (total + (batch_size - 1)) // batch_size

    for batch_no in range(n_batches):
        first = batch_no * batch_size
        last = min(first + batch_size, total)
        chunk = [pairs[pos] for pos in range(first, last)]
        sources = [triple[0] for triple in chunk]
        targets = [triple[1] for triple in chunk]
        layer_ids = [triple[2] for triple in chunk]
        rows = [neighbors[triple[0]] for triple in chunk]
        yield (
            torch.tensor(sources),
            torch.tensor(targets),
            torch.tensor(layer_ids),
            torch.tensor(rows),
        )

In [14]:
def get_graphs(layers, index2word, neighbors):
    """Create one DGL graph per layer from the neighbour table.

    Args:
        layers: number of graphs to build (one per edge type).
        index2word: vocabulary list; its length is the node count of each graph.
        neighbors: ``neighbors[n][layer]`` lists the node indices that node
            ``n`` gets an edge to in graph ``layer``.

    Returns:
        List of ``layers`` graphs. Every graph has ``len(index2word)`` nodes
        and, for each node ``n``, edges added via ``add_edges(n, neighbors[n][layer])``.
    """
    node_count = len(index2word)

    graphs = []
    for _ in range(layers):
        graph = dgl.DGLGraph()
        graph.add_nodes(node_count)
        graphs.append(graph)

    # Same insertion order as before: node by node, and within a node layer by layer.
    for node_id, per_layer in enumerate(neighbors):
        for layer, graph in enumerate(graphs):
            graph.add_edges(node_id, per_layer[layer])

    return graphs

In [51]:
class DGLGATNE(nn.Module):
    """GATNE-style embedding model that aggregates neighbours with DGL graphs.

    A node's output embedding is its base embedding plus an attention-weighted
    mix of per-edge-type aggregated embeddings, projected to ``embedding_size``
    and L2-normalised.

    Args:
        graphs: list with one DGL graph per edge type (its length must equal
            ``edge_type_count``).
        num_nodes: number of nodes, i.e. rows of the embedding tables.
        embedding_size: width of the base and output embeddings.
        embedding_u_size: width of the per-edge-type embeddings.
        edge_type_count: number of edge types (layers).
        dim_a: hidden width of the attention transform.
    """

    def __init__(
        self, graphs, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a
    ):
        super(DGLGATNE, self).__init__()
        assert len(graphs) == edge_type_count
        
        self.graphs = graphs
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a
        
#         for g in self.graphs:
#             g.ndata['node_type_embeddings'] = -2 * torch.rand(num_nodes, embedding_u_size) + 1

        # All tensors below are allocated uninitialised and filled by reset_parameters().
        # Base embedding: (num_nodes, embedding_size).
        self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
        # Per-edge-type embedding: (num_nodes, edge_type_count, embedding_u_size).
        self.node_type_embeddings = Parameter(
            torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
        )
        # Per-edge-type projection from embedding_u_size to embedding_size.
        self.trans_weights = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)
        )
        # Per-edge-type attention weights: embedding_u_size -> dim_a -> 1.
        self.trans_weights_s1 = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
        )
        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise embeddings uniformly in [-1, 1] and weights from N(0, 1 / embedding_size)."""
        self.node_embeddings.data.uniform_(-1.0, 1.0)
        self.node_type_embeddings.data.uniform_(-1.0, 1.0)
        # All three weight tensors use embedding_size for the standard deviation.
        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    # Called from the training loop as:
    # embs = model(data[0].to(device), data[2].to(device), data[3].to(device))
    def forward(self, train_inputs, train_types, node_neigh):
        """Compute normalised embeddings for a batch of (node, edge type) pairs.

        Args:
            train_inputs: node indices of the batch (assumed 1-D, one per row - TODO confirm).
            train_types: edge-type index per batch row; selects the per-type weights.
            node_neigh: neighbour indices, indexed as ``node_neigh[row][layer]``.

        Returns:
            Tensor of L2-normalised embeddings, one row per batch element.
        """
        # Attach the full per-type embedding tensor to every parent graph.
        for g in self.graphs:
            g.ndata['node_type_embeddings'] = self.node_type_embeddings
            
        # For every layer, collect the ids of the edges (input node -> its listed
        # neighbours) of all batch rows and cut out the matching edge subgraph.
        sub_graphs = []
        for layer in range(self.edge_type_count):
            edges = self.graphs[layer].edge_ids(train_inputs[0], node_neigh[0][layer])
            for node in range(1, train_inputs.shape[0]):
                e = self.graphs[layer].edge_ids(train_inputs[node], node_neigh[node][layer])
                edges = torch.cat([edges, e], dim = 0)
            graph = self.graphs[layer].edge_subgraph(edges, preserve_nodes=True)
            # The subgraph only sees this layer's slice of the per-type embeddings.
            graph.ndata['node_type_embeddings'] = self.node_type_embeddings[:, layer, :]
            sub_graphs.append(graph)
            
        node_embed = self.node_embeddings
        
        # Sum-aggregate the per-type embeddings over each subgraph and read the
        # result at the input nodes.
        # NOTE(review): edges are looked up as (input node -> neighbour) while
        # update_all reduces messages at edge destinations - confirm that 'neigh'
        # at train_inputs holds the intended neighbour sum.
        node_type_embed = []
        for layer in range(self.edge_type_count):
            graph = sub_graphs[layer]
            graph.update_all(fn.copy_src('node_type_embeddings', 'm'), fn.sum('m', 'neigh'))
            node_type_embed.append(graph.ndata['neigh'][train_inputs])
        node_type_embed = torch.stack(node_type_embed, 1)  # stacked per layer along dim 1
        
        # Per-row weights selected by edge type:
        # trans_w: (embedding_u_size, embedding_size) per row
        trans_w = self.trans_weights[train_types]
        # trans_w_s1: (embedding_u_size, dim_a) per row
        trans_w_s1 = self.trans_weights_s1[train_types]
        # trans_w_s2: (dim_a, 1) per row
        trans_w_s2 = self.trans_weights_s2[train_types]
        
        # Attention over layers: one score per layer via tanh(U @ W1) @ W2,
        # softmax across the layer axis, reshaped so it can left-multiply U.
        attention = F.softmax(
            torch.matmul(
                torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2
            ).squeeze(2),
            dim=1,
        ).unsqueeze(1)
                
        # Attention-weighted mix of the layer embeddings, projected and added to
        # the base embedding of each input node.
        node_type_embed = torch.matmul(attention, node_type_embed)
        node_embed = node_embed[train_inputs] + torch.matmul(node_type_embed, trans_w).squeeze(1)
        last_node_embed = F.normalize(node_embed, dim=1)
        
        return last_node_embed
        
#     def forward(self, train_inputs, train_types, node_neigh):
#         graphs = self.graphs
#         node_embed = self.node_embeddings[train_inputs]
#         node_embed_neighbors = self.node_type_embeddings[node_neigh]
#         node_embed_tmp = torch.cat(
#             [
#                 node_embed_neighbors[:, i, :, i, :].unsqueeze(1)
#                 for i in range(self.edge_type_count)
#             ],
#             dim=1,
#         )
#         node_type_embed = torch.sum(node_embed_tmp, dim=2)

#         trans_w = self.trans_weights[train_types]
#         trans_w_s1 = self.trans_weights_s1[train_types]
#         trans_w_s2 = self.trans_weights_s2[train_types]

#         attention = F.softmax(
#             torch.matmul(
#                 torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2
#             ).squeeze(2),
#             dim=1,
#         ).unsqueeze(1)
#         node_type_embed = torch.matmul(attention, node_type_embed)
#         node_embed = node_embed + torch.matmul(node_type_embed, trans_w).squeeze(1)

#         last_node_embed = F.normalize(node_embed, dim=1)

#         return last_node_embed

In [52]:
# One DGL graph per edge type, built from the fixed-size neighbour table.
graphs = get_graphs(edge_type_count, index2word, neighbors)

# The model keeps a reference to these graphs and aggregates over them in forward().
model = DGLGATNE(
    graphs, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a
)

In [53]:
# model = GATNEModel(
#     num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a
# )
# Negative-sampling loss; it owns its own learnable context-vector table.
nsloss = NSLoss(num_nodes, num_sampled, embedding_size)

model.to(device)
nsloss.to(device)

# A single Adam optimiser updates both the model and the loss parameters,
# each as its own parameter group.
optimizer = torch.optim.Adam(
    [{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4
)

# Early-stopping bookkeeping: best validation score so far and the number of
# evaluations since it last improved.
best_score = 0
patience = 0
# batches = get_batches(train_pairs, neighbors)

In [54]:
# Training loop: reshuffle the pairs every epoch, stream them in mini-batches and
# take one optimiser step per batch. A running summary is written every 5000 steps.
epochs = 1
n_batches = (len(train_pairs) + (batch_size - 1)) // batch_size  # e.g. 7066 batches
for epoch in range(epochs):
    random.shuffle(train_pairs)
    batches = get_batches(train_pairs, neighbors, batch_size)

    data_iter = tqdm.tqdm(
        batches,
        desc="epoch %d" % (epoch),
        total=n_batches,
        bar_format="{l_bar}{r_bar}",
    )
    avg_loss = 0.0  # running sum of batch losses; divided by the step count when reported

    for i, data in enumerate(data_iter):
        # data: (node1, node2, layer_id, neighbours of node1), one row per pair
        batch_nodes = data[0].to(device)
        batch_labels = data[1].to(device)
        batch_types = data[2].to(device)
        batch_neigh = data[3].to(device)

        optimizer.zero_grad()
        # embs: one embedding row per pair of the batch
        embs = model(batch_nodes, batch_types, batch_neigh)
        loss = nsloss(batch_nodes, embs, batch_labels)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        avg_loss += loss_value

        if i % 5000 != 0:
            continue
        post_fix = {
            "epoch": epoch,
            "iter": i,
            "avg_loss": avg_loss / (i + 1),
            "loss": loss_value,
        }
        data_iter.write(str(post_fix))





                                       [A[A[A[A
[A                                    


[A[A[A                              

[A[A                                 



epoch 0:   0%|| 0/7066 [30:33<?, ?it/s]
epoch 0:   0%|| 0/7066 [23:27<?, ?it/s][A


epoch 0:   0%|| 0/7066 [05:24<?, ?it/s][A[A[A

epoch 0:   0%|| 0/7066 [15:20<?, ?it/s][A[A



epoch 0:   0%|| 0/7066 [00:00<?, ?it/s][A[A[A[A



epoch 0:   0%|| 3/7066 [00:00<04:01, 29.30it/s][A[A[A[A



epoch 0:   0%|| 8/7066 [00:00<03:31, 33.31it/s][A[A[A[A

{'epoch': 0, 'iter': 0, 'avg_loss': 4.163259983062744, 'loss': 4.163259983062744}






epoch 0:   0%|| 13/7066 [00:00<03:11, 36.76it/s][A[A[A[A



epoch 0:   0%|| 19/7066 [00:00<02:56, 39.99it/s][A[A[A[A



epoch 0:   0%|| 24/7066 [00:00<02:45, 42.53it/s][A[A[A[A



epoch 0:   0%|| 30/7066 [00:00<02:37, 44.62it/s][A[A[A[A



epoch 0:   0%|| 35/7066 [00:00<02:34, 45.43it/s][A[A[A[A



epoch 0:   1%|| 40/7066 [00:00<02:34, 45.60it/s][A[A[A[A



epoch 0:   1%|| 46/7066 [00:00<02:28, 47.21it/s][A[A[A[A



epoch 0:   1%|| 52/7066 [00:01<02:25, 48.28it/s][A[A[A[A



epoch 0:   1%|| 57/7066 [00:01<02:24, 48.38it/s][A[A[A[A



epoch 0:   1%|| 62/7066 [00:01<02:24, 48.33it/s][A[A[A[A



epoch 0:   1%|| 67/7066 [00:01<02:25, 48.26it/s][A[A[A[A



epoch 0:   1%|| 73/7066 [00:01<02:22, 48.96it/s][A[A[A[A



epoch 0:   1%|| 78/7066 [00:01<02:22, 48.93it/s][A[A[A[A



epoch 0:   1%|| 83/7066 [00:01<02:22, 48.91it/s][A[A[A[A



epoch 0:   1%|| 88/7066 [00:01<02:22, 48.97it/s][A[A[A[A



epoch 0:   1%|| 94/7066 [00:01<02:21

epoch 0:  10%|| 710/7066 [00:14<02:17, 46.29it/s][A[A[A[A



epoch 0:  10%|| 715/7066 [00:14<02:15, 46.72it/s][A[A[A[A



epoch 0:  10%|| 720/7066 [00:14<02:14, 47.09it/s][A[A[A[A



epoch 0:  10%|| 725/7066 [00:15<02:13, 47.55it/s][A[A[A[A



epoch 0:  10%|| 730/7066 [00:15<02:14, 47.05it/s][A[A[A[A



epoch 0:  10%|| 735/7066 [00:15<02:14, 46.96it/s][A[A[A[A



epoch 0:  10%|| 740/7066 [00:15<02:16, 46.23it/s][A[A[A[A



epoch 0:  11%|| 745/7066 [00:15<02:17, 46.12it/s][A[A[A[A



epoch 0:  11%|| 750/7066 [00:15<02:15, 46.68it/s][A[A[A[A



epoch 0:  11%|| 755/7066 [00:15<02:15, 46.64it/s][A[A[A[A



epoch 0:  11%|| 760/7066 [00:15<02:16, 46.24it/s][A[A[A[A



epoch 0:  11%|| 765/7066 [00:15<02:15, 46.62it/s][A[A[A[A



epoch 0:  11%|| 770/7066 [00:16<02:14, 46.71it/s][A[A[A[A



epoch 0:  11%|| 775/7066 [00:16<02:15, 46.35it/s][A[A[A[A



epoch 0:  11%|| 780/7066 [00:16<02:15, 46.27it/s][A[A[A[A



epoch 0:  11%|| 785/7066 

epoch 0:  19%|| 1336/7066 [00:28<02:11, 43.55it/s][A[A[A[A



epoch 0:  19%|| 1341/7066 [00:28<02:11, 43.37it/s][A[A[A[A



epoch 0:  19%|| 1346/7066 [00:28<02:10, 43.98it/s][A[A[A[A



epoch 0:  19%|| 1351/7066 [00:28<02:10, 43.71it/s][A[A[A[A



epoch 0:  19%|| 1356/7066 [00:28<02:11, 43.36it/s][A[A[A[A



epoch 0:  19%|| 1361/7066 [00:29<02:11, 43.35it/s][A[A[A[A



epoch 0:  19%|| 1366/7066 [00:29<02:10, 43.79it/s][A[A[A[A



epoch 0:  19%|| 1371/7066 [00:29<02:09, 43.84it/s][A[A[A[A



epoch 0:  19%|| 1376/7066 [00:29<02:10, 43.55it/s][A[A[A[A



epoch 0:  20%|| 1381/7066 [00:29<02:10, 43.43it/s][A[A[A[A



epoch 0:  20%|| 1386/7066 [00:29<02:12, 42.97it/s][A[A[A[A



epoch 0:  20%|| 1391/7066 [00:29<02:09, 43.86it/s][A[A[A[A



epoch 0:  20%|| 1396/7066 [00:29<02:06, 44.73it/s][A[A[A[A



epoch 0:  20%|| 1401/7066 [00:29<02:06, 44.74it/s][A[A[A[A



epoch 0:  20%|| 1406/7066 [00:30<02:05, 45.15it/s][A[A[A[A



epoch 0:  

epoch 0:  28%|| 1957/7066 [00:41<01:48, 47.13it/s][A[A[A[A



epoch 0:  28%|| 1962/7066 [00:42<01:48, 46.88it/s][A[A[A[A



epoch 0:  28%|| 1967/7066 [00:42<01:47, 47.22it/s][A[A[A[A



epoch 0:  28%|| 1972/7066 [00:42<01:46, 47.66it/s][A[A[A[A



epoch 0:  28%|| 1977/7066 [00:42<01:46, 47.80it/s][A[A[A[A



epoch 0:  28%|| 1982/7066 [00:42<01:47, 47.12it/s][A[A[A[A



epoch 0:  28%|| 1987/7066 [00:42<01:48, 46.91it/s][A[A[A[A



epoch 0:  28%|| 1992/7066 [00:42<01:47, 47.05it/s][A[A[A[A



epoch 0:  28%|| 1997/7066 [00:42<01:48, 46.90it/s][A[A[A[A



epoch 0:  28%|| 2002/7066 [00:42<01:47, 47.17it/s][A[A[A[A



epoch 0:  28%|| 2007/7066 [00:42<01:47, 46.94it/s][A[A[A[A



epoch 0:  28%|| 2012/7066 [00:43<01:49, 46.05it/s][A[A[A[A



epoch 0:  29%|| 2017/7066 [00:43<01:47, 46.80it/s][A[A[A[A



epoch 0:  29%|| 2022/7066 [00:43<01:47, 46.81it/s][A[A[A[A



epoch 0:  29%|| 2027/7066 [00:43<01:49, 45.96it/s][A[A[A[A



epoch 0:  

epoch 0:  36%|| 2579/7066 [00:55<01:36, 46.64it/s][A[A[A[A



epoch 0:  37%|| 2584/7066 [00:55<01:35, 46.90it/s][A[A[A[A



epoch 0:  37%|| 2589/7066 [00:55<01:36, 46.23it/s][A[A[A[A



epoch 0:  37%|| 2594/7066 [00:55<01:37, 45.96it/s][A[A[A[A



epoch 0:  37%|| 2599/7066 [00:55<01:37, 45.97it/s][A[A[A[A



epoch 0:  37%|| 2604/7066 [00:55<01:36, 46.37it/s][A[A[A[A



epoch 0:  37%|| 2609/7066 [00:55<01:36, 46.18it/s][A[A[A[A



epoch 0:  37%|| 2614/7066 [00:56<01:37, 45.54it/s][A[A[A[A



epoch 0:  37%|| 2619/7066 [00:56<01:39, 44.71it/s][A[A[A[A



epoch 0:  37%|| 2624/7066 [00:56<01:39, 44.43it/s][A[A[A[A



epoch 0:  37%|| 2629/7066 [00:56<01:39, 44.69it/s][A[A[A[A



epoch 0:  37%|| 2634/7066 [00:56<01:39, 44.57it/s][A[A[A[A



epoch 0:  37%|| 2639/7066 [00:56<01:38, 44.95it/s][A[A[A[A



epoch 0:  37%|| 2644/7066 [00:56<01:38, 44.81it/s][A[A[A[A



epoch 0:  37%|| 2649/7066 [00:56<01:38, 44.90it/s][A[A[A[A



epoch 0:  

epoch 0:  45%|| 3201/7066 [01:09<01:20, 47.84it/s][A[A[A[A



epoch 0:  45%|| 3206/7066 [01:09<01:20, 48.02it/s][A[A[A[A



epoch 0:  45%|| 3211/7066 [01:09<01:20, 47.71it/s][A[A[A[A



epoch 0:  46%|| 3216/7066 [01:09<01:20, 47.71it/s][A[A[A[A



epoch 0:  46%|| 3221/7066 [01:09<01:20, 47.93it/s][A[A[A[A



epoch 0:  46%|| 3226/7066 [01:09<01:20, 47.70it/s][A[A[A[A



epoch 0:  46%|| 3231/7066 [01:09<01:20, 47.58it/s][A[A[A[A



epoch 0:  46%|| 3236/7066 [01:09<01:20, 47.82it/s][A[A[A[A



epoch 0:  46%|| 3241/7066 [01:10<01:20, 47.72it/s][A[A[A[A



epoch 0:  46%|| 3246/7066 [01:10<01:20, 47.40it/s][A[A[A[A



epoch 0:  46%|| 3251/7066 [01:10<01:21, 46.94it/s][A[A[A[A



epoch 0:  46%|| 3256/7066 [01:10<01:20, 47.10it/s][A[A[A[A



epoch 0:  46%|| 3261/7066 [01:10<01:21, 46.93it/s][A[A[A[A



epoch 0:  46%|| 3266/7066 [01:10<01:21, 46.62it/s][A[A[A[A



epoch 0:  46%|| 3271/7066 [01:10<01:21, 46.43it/s][A[A[A[A



epoch 0:  

epoch 0:  54%|| 3821/7066 [01:22<01:12, 44.54it/s][A[A[A[A



epoch 0:  54%|| 3826/7066 [01:22<01:16, 42.39it/s][A[A[A[A



epoch 0:  54%|| 3831/7066 [01:23<01:14, 43.25it/s][A[A[A[A



epoch 0:  54%|| 3836/7066 [01:23<01:14, 43.62it/s][A[A[A[A



epoch 0:  54%|| 3841/7066 [01:23<01:13, 43.90it/s][A[A[A[A



epoch 0:  54%|| 3846/7066 [01:23<01:13, 44.07it/s][A[A[A[A



epoch 0:  55%|| 3851/7066 [01:23<01:12, 44.50it/s][A[A[A[A



epoch 0:  55%|| 3856/7066 [01:23<01:12, 44.47it/s][A[A[A[A



epoch 0:  55%|| 3861/7066 [01:23<01:13, 43.46it/s][A[A[A[A



epoch 0:  55%|| 3866/7066 [01:23<01:13, 43.51it/s][A[A[A[A



epoch 0:  55%|| 3871/7066 [01:23<01:12, 43.92it/s][A[A[A[A



epoch 0:  55%|| 3876/7066 [01:24<01:14, 42.77it/s][A[A[A[A



epoch 0:  55%|| 3881/7066 [01:24<01:15, 42.35it/s][A[A[A[A



epoch 0:  55%|| 3886/7066 [01:24<01:15, 41.90it/s][A[A[A[A



epoch 0:  55%|| 3891/7066 [01:24<01:15, 42.07it/s][A[A[A[A



epoch 0:  

epoch 0:  63%|| 4442/7066 [01:36<00:58, 44.73it/s][A[A[A[A



epoch 0:  63%|| 4447/7066 [01:37<00:59, 44.18it/s][A[A[A[A



epoch 0:  63%|| 4452/7066 [01:37<00:58, 44.81it/s][A[A[A[A



epoch 0:  63%|| 4457/7066 [01:37<00:57, 45.46it/s][A[A[A[A



epoch 0:  63%|| 4462/7066 [01:37<00:56, 45.90it/s][A[A[A[A



epoch 0:  63%|| 4467/7066 [01:37<00:56, 45.73it/s][A[A[A[A



epoch 0:  63%|| 4472/7066 [01:37<00:56, 46.00it/s][A[A[A[A



epoch 0:  63%|| 4477/7066 [01:37<00:56, 45.44it/s][A[A[A[A



epoch 0:  63%|| 4482/7066 [01:37<00:57, 44.74it/s][A[A[A[A



epoch 0:  64%|| 4487/7066 [01:37<00:57, 44.67it/s][A[A[A[A



epoch 0:  64%|| 4492/7066 [01:38<00:58, 44.28it/s][A[A[A[A



epoch 0:  64%|| 4497/7066 [01:38<00:57, 44.35it/s][A[A[A[A



epoch 0:  64%|| 4502/7066 [01:38<00:58, 44.12it/s][A[A[A[A



epoch 0:  64%|| 4507/7066 [01:38<00:57, 44.42it/s][A[A[A[A



epoch 0:  64%|| 4512/7066 [01:38<00:57, 44.53it/s][A[A[A[A



epoch 0:  

{'epoch': 0, 'iter': 5000, 'avg_loss': 3.5635554121151327, 'loss': 2.8564629554748535}






epoch 0:  71%|| 5012/7066 [01:49<00:47, 42.89it/s][A[A[A[A



epoch 0:  71%|| 5017/7066 [01:49<00:46, 43.86it/s][A[A[A[A



epoch 0:  71%|| 5022/7066 [01:49<00:45, 44.94it/s][A[A[A[A



epoch 0:  71%|| 5027/7066 [01:50<00:44, 45.48it/s][A[A[A[A



epoch 0:  71%|| 5032/7066 [01:50<00:44, 45.53it/s][A[A[A[A



epoch 0:  71%|| 5037/7066 [01:50<00:44, 45.57it/s][A[A[A[A



epoch 0:  71%|| 5042/7066 [01:50<00:44, 45.62it/s][A[A[A[A



epoch 0:  71%|| 5047/7066 [01:50<00:43, 46.45it/s][A[A[A[A



epoch 0:  71%|| 5052/7066 [01:50<00:43, 46.33it/s][A[A[A[A



epoch 0:  72%|| 5057/7066 [01:50<00:43, 46.55it/s][A[A[A[A



epoch 0:  72%|| 5062/7066 [01:50<00:42, 46.66it/s][A[A[A[A



epoch 0:  72%|| 5067/7066 [01:50<00:42, 46.82it/s][A[A[A[A



epoch 0:  72%|| 5072/7066 [01:50<00:42, 46.71it/s][A[A[A[A



epoch 0:  72%|| 5077/7066 [01:51<00:42, 46.35it/s][A[A[A[A



epoch 0:  72%|| 5082/7066 [01:51<00:43, 46.13it/s][A[A[A[A



epoch 

epoch 0:  80%|| 5632/7066 [02:03<00:30, 46.60it/s][A[A[A[A



epoch 0:  80%|| 5637/7066 [02:03<00:30, 46.34it/s][A[A[A[A



epoch 0:  80%|| 5642/7066 [02:03<00:30, 46.51it/s][A[A[A[A



epoch 0:  80%|| 5647/7066 [02:03<00:30, 47.07it/s][A[A[A[A



epoch 0:  80%|| 5652/7066 [02:03<00:30, 46.52it/s][A[A[A[A



epoch 0:  80%|| 5657/7066 [02:03<00:30, 45.75it/s][A[A[A[A



epoch 0:  80%|| 5662/7066 [02:03<00:30, 46.03it/s][A[A[A[A



epoch 0:  80%|| 5667/7066 [02:03<00:30, 46.57it/s][A[A[A[A



epoch 0:  80%|| 5672/7066 [02:03<00:30, 46.19it/s][A[A[A[A



epoch 0:  80%|| 5677/7066 [02:04<00:30, 45.83it/s][A[A[A[A



epoch 0:  80%|| 5682/7066 [02:04<00:30, 46.05it/s][A[A[A[A



epoch 0:  80%|| 5687/7066 [02:04<00:30, 45.63it/s][A[A[A[A



epoch 0:  81%|| 5692/7066 [02:04<00:30, 45.19it/s][A[A[A[A



epoch 0:  81%|| 5697/7066 [02:04<00:29, 46.14it/s][A[A[A[A



epoch 0:  81%|| 5702/7066 [02:04<00:29, 46.52it/s][A[A[A[A



epoch 0:  

epoch 0:  88%|| 6252/7066 [02:16<00:17, 46.85it/s][A[A[A[A



epoch 0:  89%|| 6257/7066 [02:16<00:17, 47.31it/s][A[A[A[A



epoch 0:  89%|| 6262/7066 [02:16<00:17, 47.13it/s][A[A[A[A



epoch 0:  89%|| 6267/7066 [02:16<00:16, 47.35it/s][A[A[A[A



epoch 0:  89%|| 6272/7066 [02:16<00:16, 48.08it/s][A[A[A[A



epoch 0:  89%|| 6278/7066 [02:17<00:16, 49.09it/s][A[A[A[A



epoch 0:  89%|| 6283/7066 [02:17<00:16, 48.77it/s][A[A[A[A



epoch 0:  89%|| 6288/7066 [02:17<00:15, 48.92it/s][A[A[A[A



epoch 0:  89%|| 6293/7066 [02:17<00:15, 48.64it/s][A[A[A[A



epoch 0:  89%|| 6298/7066 [02:17<00:15, 48.81it/s][A[A[A[A



epoch 0:  89%|| 6303/7066 [02:17<00:15, 48.78it/s][A[A[A[A



epoch 0:  89%|| 6308/7066 [02:17<00:15, 48.90it/s][A[A[A[A



epoch 0:  89%|| 6313/7066 [02:17<00:15, 48.84it/s][A[A[A[A



epoch 0:  89%|| 6318/7066 [02:17<00:15, 48.91it/s][A[A[A[A



epoch 0:  89%|| 6323/7066 [02:17<00:15, 48.59it/s][A[A[A[A



epoch 0:  

epoch 0:  97%|| 6876/7066 [02:29<00:04, 45.59it/s][A[A[A[A



epoch 0:  97%|| 6881/7066 [02:30<00:04, 45.32it/s][A[A[A[A



epoch 0:  97%|| 6886/7066 [02:30<00:03, 45.17it/s][A[A[A[A



epoch 0:  98%|| 6891/7066 [02:30<00:03, 45.35it/s][A[A[A[A



epoch 0:  98%|| 6896/7066 [02:30<00:03, 45.18it/s][A[A[A[A



epoch 0:  98%|| 6901/7066 [02:30<00:03, 45.49it/s][A[A[A[A



epoch 0:  98%|| 6906/7066 [02:30<00:03, 45.86it/s][A[A[A[A



epoch 0:  98%|| 6911/7066 [02:30<00:03, 45.63it/s][A[A[A[A



epoch 0:  98%|| 6916/7066 [02:30<00:03, 45.04it/s][A[A[A[A



epoch 0:  98%|| 6921/7066 [02:30<00:03, 45.20it/s][A[A[A[A



epoch 0:  98%|| 6926/7066 [02:31<00:03, 45.01it/s][A[A[A[A



epoch 0:  98%|| 6931/7066 [02:31<00:03, 44.75it/s][A[A[A[A



epoch 0:  98%|| 6936/7066 [02:31<00:02, 44.59it/s][A[A[A[A



epoch 0:  98%|| 6941/7066 [02:31<00:02, 44.68it/s][A[A[A[A



epoch 0:  98%|| 6946/7066 [02:31<00:02, 44.74it/s][A[A[A[A



epoch 0:  

In [278]:
# Export the trained embeddings: final_model[edge_type][node_id] is the numpy
# embedding of that node under that edge type.
final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
for i in range(num_nodes):
    # Query node i once per edge type in a single batch.
    train_inputs = torch.tensor([i for _ in range(edge_type_count)]).to(device)  # [i, i]
    train_types = torch.tensor(list(range(edge_type_count))).to(device)  # [0, 1]
    node_neigh = torch.tensor(
        [neighbors[i] for _ in range(edge_type_count)]  # [2, 2, 10]
    ).to(device)
    # The call previously passed subgraph=True, which the forward() signature
    # (train_inputs, train_types, node_neigh) does not accept (TypeError).
    node_emb = model(train_inputs, train_types, node_neigh)  # one row per edge type
    for j in range(edge_type_count):
        final_model[edge_types[j]][index2word[i]] = (
            node_emb[j].cpu().detach().numpy()
        )

In [290]:
def evaluate(model, true_edges, false_edges):
    """Score link prediction for one edge type.

    Every edge is scored with ``get_score``; edges for which it returns
    ``None`` are skipped. The F1 threshold is chosen so that the number of
    predicted positives matches the number of scored true edges (ties at the
    threshold are all predicted positive).

    Args:
        model: mapping from node id to embedding vector.
        true_edges: iterable of ``(node1, node2)`` positive edges.
        false_edges: iterable of ``(node1, node2)`` negative edges.

    Returns:
        ``(roc_auc, f1, pr_auc)``.
    """
    labels = []
    scores = []
    # Positives first, then negatives, so the result order is unchanged.
    for label, edge_group in ((1, true_edges), (0, false_edges)):
        for edge in edge_group:
            score = get_score(model, str(edge[0]), str(edge[1]))
            if score is None:
                continue
            labels.append(label)
            scores.append(score)

    positives = sum(labels)  # number of true edges that received a score
    threshold = sorted(scores)[-positives]

    y_true = np.array(labels)
    y_scores = np.array(scores)
    y_pred = (y_scores >= threshold).astype(np.int32)

    ps, rs, _ = precision_recall_curve(y_true, y_scores)
    return roc_auc_score(y_true, y_scores), f1_score(y_true, y_pred), auc(rs, ps)

def get_score(local_model, node1, node2):
    """Return the cosine similarity of two nodes' embeddings.

    Args:
        local_model: mapping from node id to embedding vector.
        node1: id of the first node.
        node2: id of the second node.

    Returns:
        The cosine similarity, or ``None`` when either node has no embedding
        in ``local_model``.
    """
    try:
        vector1 = local_model[node1]
        vector2 = local_model[node2]
    except KeyError:
        # Nodes that never appeared in training have no embedding; callers skip
        # them. Only this case is swallowed - other errors (e.g. mismatched
        # vector shapes) now surface instead of silently dropping the edge.
        return None
    return np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))

In [197]:
# train_pairs: size: 452200, form: (node1, node2, layer_id)
# neighbors: [num_nodes=511, 2, 10]
def get_batches(pairs, neighbors, batch_size=None):
    """Collate training pairs into tensors, either all at once or in batches.

    This definition reuses the name of the batched generator defined earlier
    in the notebook. ``batch_size`` is optional so that both call forms keep
    working whichever definition was executed last.

    Args:
        pairs: sequence of ``(node1, node2, layer_id)`` index triples.
        neighbors: per-node neighbour table, indexed by ``node1``.
        batch_size: when ``None`` (default) all pairs are collated into one
            tuple of tensors; otherwise an iterator of such tuples is
            returned, each covering at most ``batch_size`` pairs.

    Returns:
        ``(x, y, t, neigh)`` tensors, or an iterator over such tuples when
        ``batch_size`` is given.
    """

    def _collate(chunk):
        # Split the triples into columns and look up the neighbours of node1.
        x, y, t, neigh = [], [], [], []
        for pair in chunk:
            x.append(pair[0])
            y.append(pair[1])
            t.append(pair[2])
            neigh.append(neighbors[pair[0]])
        return torch.tensor(x), torch.tensor(y), torch.tensor(t), torch.tensor(neigh)

    if batch_size is None:
        return _collate(pairs)
    return (
        _collate(pairs[start:start + batch_size])
        for start in range(0, len(pairs), batch_size)
    )

In [293]:
# Export the trained embeddings: final_model[edge_type][node_id] is the numpy
# embedding of that node under that edge type.
final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
for i in range(num_nodes):
    # Query node i once per edge type in a single batch.
    train_inputs = torch.tensor([i for _ in range(edge_type_count)]).to(device)  # [i, i]
    train_types = torch.tensor(list(range(edge_type_count))).to(device)  # [0, 1]
    node_neigh = torch.tensor(
        [neighbors[i] for _ in range(edge_type_count)]  # [2, 2, 10]
    ).to(device)
    # The call previously passed subgraph=True, which the forward() signature
    # (train_inputs, train_types, node_neigh) does not accept (TypeError).
    node_emb = model(train_inputs, train_types, node_neigh)  # one row per edge type
    for j in range(edge_type_count):
        final_model[edge_types[j]][index2word[i]] = (
            node_emb[j].cpu().detach().numpy()
        )

# Link-prediction metrics per selected edge type, on the validation and test splits.
valid_aucs, valid_f1s, valid_prs = [], [], []
test_aucs, test_f1s, test_prs = [], [], []
for i in range(edge_type_count):
    # eval_type is either "all" or a comma-separated list of edge types.
    if eval_type == "all" or edge_types[i] in eval_type.split(","):
        tmp_auc, tmp_f1, tmp_pr = evaluate(
            final_model[edge_types[i]],
            valid_true_data_by_edge[edge_types[i]],
            valid_false_data_by_edge[edge_types[i]],
        )
        valid_aucs.append(tmp_auc)
        valid_f1s.append(tmp_f1)
        valid_prs.append(tmp_pr)

        tmp_auc, tmp_f1, tmp_pr = evaluate(
            final_model[edge_types[i]],
            testing_true_data_by_edge[edge_types[i]],
            testing_false_data_by_edge[edge_types[i]],
        )
        test_aucs.append(tmp_auc)
        test_f1s.append(tmp_f1)
        test_prs.append(tmp_pr)
# Only the validation averages are printed; the test metrics stay in the
# test_* lists for the following cell.
print("valid auc:", np.mean(valid_aucs))
print("valid pr:", np.mean(valid_prs))
print("valid f1:", np.mean(valid_f1s))

valid auc: 0.5426267135215557
valid pr: 0.7175621239469707
valid f1: 0.6383369137128536


In [296]:
# Average the test metrics over the evaluated edge types.
average_auc = np.mean(test_aucs)
average_f1 = np.mean(test_f1s)
average_pr = np.mean(test_prs)

# `args` is not defined in the notebook cells visible here, so `args.patience`
# raised NameError on the first non-improving evaluation. Use args.patience when
# an `args` object exists; otherwise fall back to a default.
# NOTE(review): the fallback of 5 is a placeholder - confirm the intended limit.
max_patience = getattr(globals().get("args"), "patience", 5)

# Early-stopping bookkeeping, driven by the mean validation ROC-AUC.
cur_score = np.mean(valid_aucs)
if cur_score > best_score:
    best_score = cur_score
    patience = 0
else:
    patience += 1
    if patience > max_patience:
        # NOTE(review): this cell is not inside the training loop, so the message
        # is informational only - nothing is actually stopped here.
        print("Early Stopping")