In [1]:
from __future__ import division
from __future__ import print_function

import argparse
import time

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from torch import optim
import networkx as nx
from torch_geometric.datasets import Planetoid

from gae.model import GCNModelVAE
from gae.optimizer import loss_function
from gae.utils import mask_test_edges, preprocess_graph, get_roc_score

In [2]:
def _build_parser():
    """Return the argument parser holding every tunable hyperparameter of the run.

    Each spec row is ``(flag, type, default, help)``; flags with a dash such as
    ``--dataset-str`` surface on the namespace with an underscore (``dataset_str``).
    """
    specs = (
        ('--model', str, 'gcn_vae', "models used"),
        ('--seed', int, 42, 'Random seed.'),
        ('--epochs', int, 200, 'Number of epochs to train.'),
        ('--hidden1', int, 32, 'Number of units in hidden layer 1.'),
        ('--hidden2', int, 16, 'Number of units in hidden layer 2.'),
        ('--lr', float, 0.01, 'Initial learning rate.'),
        ('--dropout', float, 0., 'Dropout rate (1 - keep probability).'),
        ('--dataset-str', str, 'Citeseer', 'type of dataset.'),
    )
    built = argparse.ArgumentParser()
    for flag, kind, default_value, help_text in specs:
        built.add_argument(flag, type=kind, default=default_value, help=help_text)
    return built


parser = _build_parser()

# parse_known_args (not parse_args) so the extra arguments Jupyter puts on
# sys.argv (e.g. the kernel connection file) are ignored instead of fatal.
args, _ = parser.parse_known_args()

In [3]:
def load_data(adj_name):
    """Load a graph as a symmetric, binary, self-loop-free sparse adjacency matrix.

    Args:
        adj_name: ``'Citeseer'`` (read through torch_geometric's ``Planetoid``,
            cached under ``./datasets``) or ``'wiki'`` (tab-separated edge list
            at ``datasets/graph.txt``, one ``src<TAB>dst`` pair per line).

    Returns:
        adj: ``scipy.sparse.csr_matrix`` (float64) of shape ``(n, n)`` with 1.0
            at ``(u, v)`` and ``(v, u)`` for every listed edge with ``u != v``.
        features: ``torch.eye(n)`` -- identity (featureless) node features.

    Raises:
        ValueError: if ``adj_name`` is not a supported dataset.
    """
    if adj_name == 'Citeseer':
        nodes_numbers = 3327
        datasets = Planetoid('./datasets', adj_name)
        edge_index = datasets[0].edge_index
        # Vectorized extraction instead of one .item() call per edge.
        src = edge_index[0].cpu().numpy()
        dst = edge_index[1].cpu().numpy()
    elif adj_name == 'wiki':
        nodes_numbers = 2405
        raw_edges = pd.read_csv('datasets/graph.txt', header=None, sep='\t')
        src = raw_edges[0].to_numpy()
        dst = raw_edges[1].to_numpy()
    else:
        # Fail loudly: without edges there is nothing to build (printing and
        # continuing would crash later on an unbound variable).
        raise ValueError(
            "Dataset {!r} does not exist; expected 'Citeseer' or 'wiki'.".format(adj_name))

    # Drop self-loops, then add both directions so the matrix is symmetric.
    keep = src != dst
    src = src[keep].astype(np.int64)
    dst = dst[keep].astype(np.int64)
    rows = np.concatenate([src, dst])
    cols = np.concatenate([dst, src])

    # Build the sparse matrix directly: avoids a dense n x n array and
    # nx.from_numpy_matrix (removed in networkx 3.0).
    adj = sp.coo_matrix((np.ones(rows.shape[0]), (rows, cols)),
                        shape=(nodes_numbers, nodes_numbers)).tocsr()
    # tocsr() sums duplicate entries (repeated or reciprocal edges); clamp to binary.
    adj.data[:] = 1.0

    features = torch.eye(nodes_numbers)

    return adj, features

In [4]:
def gae_for(args):
    """Train one GCN-VAE on ``args.dataset_str`` and score held-out link prediction.

    Loads the graph with ``load_data``, holds out test edges with
    ``mask_test_edges``, trains ``GCNModelVAE`` for ``args.epochs`` full-batch
    Adam steps on the training adjacency, then scores the final latent means
    against the held-out edges with ``get_roc_score``.

    Args:
        args: namespace providing ``dataset_str``, ``epochs``, ``hidden1``,
            ``hidden2``, ``dropout`` and ``lr``.

    Returns:
        ``(ap_score, roc_score)`` on the test edges. The order is AP first,
        which is the reverse of ``get_roc_score``'s return order.

    Raises:
        ValueError: if ``args.epochs < 1`` (no embedding would be produced).
    """
    if args.epochs < 1:
        raise ValueError("args.epochs must be >= 1, got {}".format(args.epochs))

    print("Using {} dataset".format(args.dataset_str))
    adj, features = load_data(args.dataset_str)

    n_nodes, feat_dim = features.shape

    # Store original adjacency matrix (without diagonal entries) for scoring later.
    adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
    adj_orig.eliminate_zeros()

    # NOTE(review): this project's mask_test_edges returns 4 values (no
    # validation split), unlike upstream gae-pytorch -- confirm against gae.utils.
    adj_train, train_edges, test_edges, test_edges_false = mask_test_edges(adj)

    # Encoder input: normalized training graph. Reconstruction target: training
    # graph plus self-loops, as a dense float tensor.
    adj_norm = preprocess_graph(adj_train)
    adj_label = torch.FloatTensor((adj_train + sp.eye(adj_train.shape[0])).toarray())

    # Loss weights from the training graph: pos_weight up-weights the rare
    # positive entries, norm rescales the weighted reconstruction loss.
    n_entries = adj_train.shape[0] * adj_train.shape[0]
    n_positive = adj_train.sum()
    pos_weight = torch.tensor(float(n_entries - n_positive) / n_positive)
    norm = n_entries / float((n_entries - n_positive) * 2)

    model = GCNModelVAE(feat_dim, args.hidden1, args.hidden2, args.dropout)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        recovered, mu, logvar = model(features, adj_norm)

        loss = loss_function(preds=recovered, labels=adj_label,
                             mu=mu, logvar=logvar, n_nodes=n_nodes,
                             norm=norm, pos_weight=pos_weight)
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()

        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cur_loss),
              "time=", "{:.5f}".format(time.time() - t)
              )

    print("Optimization Finished!")

    # Embedding = latent means from the last epoch's forward pass (computed
    # before that epoch's optimizer step), snapshotted once rather than per epoch.
    hidden_emb = mu.detach().cpu().numpy()

    roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))

    return ap_score, roc_score

In [5]:
if __name__ == '__main__':
    # Seed once, before all runs, so Restart & Run All reproduces the reported
    # numbers. Each repetition still gets a different initialization and
    # (presumably numpy-driven -- confirm in gae.utils) edge split, so the
    # spread across runs stays meaningful. Previously --seed was parsed but unused.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    N_RUNS = 10  # number of independent train/evaluate repetitions
    all_ap = []
    all_auc = []
    for run_idx in range(N_RUNS):
        ap, auc = gae_for(args)
        all_ap.append(ap)
        all_auc.append(auc)

Using Citeseer dataset
Epoch: 0001 train_loss= 1.74101 time= 0.26613
Epoch: 0002 train_loss= 1.71212 time= 0.25900
Epoch: 0003 train_loss= 1.71588 time= 0.24676
Epoch: 0004 train_loss= 1.67775 time= 0.24304
Epoch: 0005 train_loss= 1.66441 time= 0.24448
Epoch: 0006 train_loss= 1.58269 time= 0.24700
Epoch: 0007 train_loss= 1.54663 time= 0.25199
Epoch: 0008 train_loss= 1.51801 time= 0.24500
Epoch: 0009 train_loss= 1.46603 time= 0.24775
Epoch: 0010 train_loss= 1.38706 time= 0.25076
Epoch: 0011 train_loss= 1.33222 time= 0.24328
Epoch: 0012 train_loss= 1.27194 time= 0.24834
Epoch: 0013 train_loss= 1.23723 time= 0.25322
Epoch: 0014 train_loss= 1.16156 time= 0.25312
Epoch: 0015 train_loss= 1.11263 time= 0.24600
Epoch: 0016 train_loss= 1.07372 time= 0.25400
Epoch: 0017 train_loss= 1.00034 time= 0.24900
Epoch: 0018 train_loss= 0.96299 time= 0.25578
Epoch: 0019 train_loss= 0.93171 time= 0.25100
Epoch: 0020 train_loss= 0.88694 time= 0.23879
Epoch: 0021 train_loss= 0.85013 time= 0.23876
Epoch: 0022

Epoch: 0179 train_loss= 0.42737 time= 0.23999
Epoch: 0180 train_loss= 0.42770 time= 0.25929
Epoch: 0181 train_loss= 0.42730 time= 0.24927
Epoch: 0182 train_loss= 0.42695 time= 0.24076
Epoch: 0183 train_loss= 0.42692 time= 0.24972
Epoch: 0184 train_loss= 0.42669 time= 0.25900
Epoch: 0185 train_loss= 0.42656 time= 0.24700
Epoch: 0186 train_loss= 0.42650 time= 0.24557
Epoch: 0187 train_loss= 0.42625 time= 0.24589
Epoch: 0188 train_loss= 0.42607 time= 0.24759
Epoch: 0189 train_loss= 0.42609 time= 0.24680
Epoch: 0190 train_loss= 0.42581 time= 0.25364
Epoch: 0191 train_loss= 0.42579 time= 0.25300
Epoch: 0192 train_loss= 0.42582 time= 0.25691
Epoch: 0193 train_loss= 0.42567 time= 0.25200
Epoch: 0194 train_loss= 0.42551 time= 0.25119
Epoch: 0195 train_loss= 0.42522 time= 0.24991
Epoch: 0196 train_loss= 0.42521 time= 0.25490
Epoch: 0197 train_loss= 0.42515 time= 0.24504
Epoch: 0198 train_loss= 0.42509 time= 0.24353
Epoch: 0199 train_loss= 0.42489 time= 0.25000
Epoch: 0200 train_loss= 0.42488 ti

Epoch: 0155 train_loss= 0.43119 time= 0.25600
Epoch: 0156 train_loss= 0.43084 time= 0.24773
Epoch: 0157 train_loss= 0.43080 time= 0.25700
Epoch: 0158 train_loss= 0.43045 time= 0.27787
Epoch: 0159 train_loss= 0.43016 time= 0.25400
Epoch: 0160 train_loss= 0.43000 time= 0.25100
Epoch: 0161 train_loss= 0.42999 time= 0.24420
Epoch: 0162 train_loss= 0.42988 time= 0.24617
Epoch: 0163 train_loss= 0.42969 time= 0.25269
Epoch: 0164 train_loss= 0.42969 time= 0.24870
Epoch: 0165 train_loss= 0.42961 time= 0.24400
Epoch: 0166 train_loss= 0.42925 time= 0.25100
Epoch: 0167 train_loss= 0.42937 time= 0.25200
Epoch: 0168 train_loss= 0.42900 time= 0.24529
Epoch: 0169 train_loss= 0.42905 time= 0.24814
Epoch: 0170 train_loss= 0.42883 time= 0.25020
Epoch: 0171 train_loss= 0.42884 time= 0.25699
Epoch: 0172 train_loss= 0.42842 time= 0.25600
Epoch: 0173 train_loss= 0.42818 time= 0.25500
Epoch: 0174 train_loss= 0.42818 time= 0.24700
Epoch: 0175 train_loss= 0.42829 time= 0.24900
Epoch: 0176 train_loss= 0.42788 ti

Epoch: 0131 train_loss= 0.44192 time= 0.25432
Epoch: 0132 train_loss= 0.44155 time= 0.25243
Epoch: 0133 train_loss= 0.44107 time= 0.24200
Epoch: 0134 train_loss= 0.44079 time= 0.24269
Epoch: 0135 train_loss= 0.44079 time= 0.23415
Epoch: 0136 train_loss= 0.44006 time= 0.24015
Epoch: 0137 train_loss= 0.44004 time= 0.25500
Epoch: 0138 train_loss= 0.43962 time= 0.24100
Epoch: 0139 train_loss= 0.43938 time= 0.24860
Epoch: 0140 train_loss= 0.43934 time= 0.25400
Epoch: 0141 train_loss= 0.43886 time= 0.25041
Epoch: 0142 train_loss= 0.43877 time= 0.24700
Epoch: 0143 train_loss= 0.43850 time= 0.23602
Epoch: 0144 train_loss= 0.43825 time= 0.23227
Epoch: 0145 train_loss= 0.43775 time= 0.25861
Epoch: 0146 train_loss= 0.43747 time= 0.24100
Epoch: 0147 train_loss= 0.43728 time= 0.23700
Epoch: 0148 train_loss= 0.43720 time= 0.25200
Epoch: 0149 train_loss= 0.43678 time= 0.24500
Epoch: 0150 train_loss= 0.43649 time= 0.25000
Epoch: 0151 train_loss= 0.43632 time= 0.23573
Epoch: 0152 train_loss= 0.43582 ti

Epoch: 0107 train_loss= 0.44265 time= 0.24626
Epoch: 0108 train_loss= 0.44194 time= 0.24821
Epoch: 0109 train_loss= 0.44167 time= 0.24001
Epoch: 0110 train_loss= 0.44140 time= 0.23721
Epoch: 0111 train_loss= 0.44110 time= 0.24850
Epoch: 0112 train_loss= 0.44078 time= 0.24609
Epoch: 0113 train_loss= 0.44024 time= 0.25000
Epoch: 0114 train_loss= 0.43992 time= 0.24700
Epoch: 0115 train_loss= 0.43976 time= 0.24389
Epoch: 0116 train_loss= 0.43907 time= 0.23843
Epoch: 0117 train_loss= 0.43880 time= 0.24437
Epoch: 0118 train_loss= 0.43868 time= 0.24024
Epoch: 0119 train_loss= 0.43803 time= 0.24600
Epoch: 0120 train_loss= 0.43793 time= 0.24042
Epoch: 0121 train_loss= 0.43771 time= 0.24500
Epoch: 0122 train_loss= 0.43772 time= 0.24500
Epoch: 0123 train_loss= 0.43695 time= 0.24200
Epoch: 0124 train_loss= 0.43684 time= 0.25200
Epoch: 0125 train_loss= 0.43676 time= 0.25689
Epoch: 0126 train_loss= 0.43622 time= 0.24177
Epoch: 0127 train_loss= 0.43593 time= 0.24612
Epoch: 0128 train_loss= 0.43554 ti

Epoch: 0083 train_loss= 0.46213 time= 0.24168
Epoch: 0084 train_loss= 0.46072 time= 0.24652
Epoch: 0085 train_loss= 0.45944 time= 0.23982
Epoch: 0086 train_loss= 0.45891 time= 0.24600
Epoch: 0087 train_loss= 0.45742 time= 0.24600
Epoch: 0088 train_loss= 0.45692 time= 0.24600
Epoch: 0089 train_loss= 0.45591 time= 0.25200
Epoch: 0090 train_loss= 0.45505 time= 0.24195
Epoch: 0091 train_loss= 0.45448 time= 0.24500
Epoch: 0092 train_loss= 0.45399 time= 0.24555
Epoch: 0093 train_loss= 0.45237 time= 0.23262
Epoch: 0094 train_loss= 0.45187 time= 0.25008
Epoch: 0095 train_loss= 0.45179 time= 0.25600
Epoch: 0096 train_loss= 0.45077 time= 0.24700
Epoch: 0097 train_loss= 0.44977 time= 0.24251
Epoch: 0098 train_loss= 0.44976 time= 0.23648
Epoch: 0099 train_loss= 0.44865 time= 0.24800
Epoch: 0100 train_loss= 0.44808 time= 0.24900
Epoch: 0101 train_loss= 0.44725 time= 0.23177
Epoch: 0102 train_loss= 0.44713 time= 0.24569
Epoch: 0103 train_loss= 0.44613 time= 0.26000
Epoch: 0104 train_loss= 0.44517 ti

Epoch: 0059 train_loss= 0.50687 time= 0.23869
Epoch: 0060 train_loss= 0.50562 time= 0.23656
Epoch: 0061 train_loss= 0.50289 time= 0.24800
Epoch: 0062 train_loss= 0.50067 time= 0.24465
Epoch: 0063 train_loss= 0.49896 time= 0.24900
Epoch: 0064 train_loss= 0.49713 time= 0.24522
Epoch: 0065 train_loss= 0.49445 time= 0.24043
Epoch: 0066 train_loss= 0.49198 time= 0.24379
Epoch: 0067 train_loss= 0.49057 time= 0.23955
Epoch: 0068 train_loss= 0.48874 time= 0.23715
Epoch: 0069 train_loss= 0.48632 time= 0.25600
Epoch: 0070 train_loss= 0.48316 time= 0.24200
Epoch: 0071 train_loss= 0.48126 time= 0.24200
Epoch: 0072 train_loss= 0.48007 time= 0.24826
Epoch: 0073 train_loss= 0.47781 time= 0.24674
Epoch: 0074 train_loss= 0.47666 time= 0.24800
Epoch: 0075 train_loss= 0.47481 time= 0.24105
Epoch: 0076 train_loss= 0.47466 time= 0.23191
Epoch: 0077 train_loss= 0.47280 time= 0.24977
Epoch: 0078 train_loss= 0.47191 time= 0.24700
Epoch: 0079 train_loss= 0.46982 time= 0.24301
Epoch: 0080 train_loss= 0.46891 ti

Epoch: 0035 train_loss= 0.65619 time= 0.24554
Epoch: 0036 train_loss= 0.64555 time= 0.24600
Epoch: 0037 train_loss= 0.64030 time= 0.24015
Epoch: 0038 train_loss= 0.62907 time= 0.25100
Epoch: 0039 train_loss= 0.62356 time= 0.24054
Epoch: 0040 train_loss= 0.61485 time= 0.23838
Epoch: 0041 train_loss= 0.60432 time= 0.24400
Epoch: 0042 train_loss= 0.59640 time= 0.24200
Epoch: 0043 train_loss= 0.58751 time= 0.24396
Epoch: 0044 train_loss= 0.57976 time= 0.24735
Epoch: 0045 train_loss= 0.57104 time= 0.24438
Epoch: 0046 train_loss= 0.56387 time= 0.24136
Epoch: 0047 train_loss= 0.56093 time= 0.23593
Epoch: 0048 train_loss= 0.55375 time= 0.23502
Epoch: 0049 train_loss= 0.54806 time= 0.23998
Epoch: 0050 train_loss= 0.54503 time= 0.22789
Epoch: 0051 train_loss= 0.53877 time= 0.24194
Epoch: 0052 train_loss= 0.53441 time= 0.24153
Epoch: 0053 train_loss= 0.52983 time= 0.23980
Epoch: 0054 train_loss= 0.52334 time= 0.23603
Epoch: 0055 train_loss= 0.51720 time= 0.23807
Epoch: 0056 train_loss= 0.51524 ti

Epoch: 0011 train_loss= 1.47999 time= 0.23764
Epoch: 0012 train_loss= 1.38476 time= 0.24438
Epoch: 0013 train_loss= 1.34153 time= 0.23963
Epoch: 0014 train_loss= 1.28602 time= 0.24449
Epoch: 0015 train_loss= 1.23328 time= 0.23975
Epoch: 0016 train_loss= 1.16161 time= 0.23802
Epoch: 0017 train_loss= 1.12385 time= 0.23525
Epoch: 0018 train_loss= 1.04563 time= 0.23566
Epoch: 0019 train_loss= 1.01215 time= 0.24253
Epoch: 0020 train_loss= 0.94716 time= 0.24452
Epoch: 0021 train_loss= 0.91233 time= 0.23409
Epoch: 0022 train_loss= 0.87317 time= 0.23883
Epoch: 0023 train_loss= 0.83433 time= 0.23969
Epoch: 0024 train_loss= 0.81264 time= 0.23268
Epoch: 0025 train_loss= 0.78742 time= 0.23226
Epoch: 0026 train_loss= 0.76835 time= 0.22837
Epoch: 0027 train_loss= 0.74405 time= 0.23584
Epoch: 0028 train_loss= 0.73117 time= 0.23831
Epoch: 0029 train_loss= 0.71535 time= 0.24165
Epoch: 0030 train_loss= 0.70889 time= 0.24929
Epoch: 0031 train_loss= 0.69323 time= 0.23630
Epoch: 0032 train_loss= 0.68779 ti

Epoch: 0190 train_loss= 0.42672 time= 0.24117
Epoch: 0191 train_loss= 0.42665 time= 0.22410
Epoch: 0192 train_loss= 0.42642 time= 0.23150
Epoch: 0193 train_loss= 0.42632 time= 0.22407
Epoch: 0194 train_loss= 0.42612 time= 0.23778
Epoch: 0195 train_loss= 0.42602 time= 0.22299
Epoch: 0196 train_loss= 0.42592 time= 0.23282
Epoch: 0197 train_loss= 0.42565 time= 0.22971
Epoch: 0198 train_loss= 0.42566 time= 0.23220
Epoch: 0199 train_loss= 0.42541 time= 0.23133
Epoch: 0200 train_loss= 0.42531 time= 0.23890
Optimization Finished!
Test ROC score: 0.7990435937688685
Test AP score: 0.8509787490893921
Using Citeseer dataset
Epoch: 0001 train_loss= 1.71796 time= 0.23684
Epoch: 0002 train_loss= 1.73496 time= 0.24928
Epoch: 0003 train_loss= 1.69387 time= 0.23697
Epoch: 0004 train_loss= 1.67338 time= 0.22946
Epoch: 0005 train_loss= 1.65285 time= 0.23088
Epoch: 0006 train_loss= 1.61031 time= 0.24285
Epoch: 0007 train_loss= 1.60126 time= 0.24703
Epoch: 0008 train_loss= 1.53305 time= 0.23503
Epoch: 0009

Epoch: 0166 train_loss= 0.42875 time= 0.24892
Epoch: 0167 train_loss= 0.42885 time= 0.24482
Epoch: 0168 train_loss= 0.42851 time= 0.24217
Epoch: 0169 train_loss= 0.42823 time= 0.23883
Epoch: 0170 train_loss= 0.42790 time= 0.24252
Epoch: 0171 train_loss= 0.42772 time= 0.23655
Epoch: 0172 train_loss= 0.42752 time= 0.23714
Epoch: 0173 train_loss= 0.42748 time= 0.23445
Epoch: 0174 train_loss= 0.42720 time= 0.24800
Epoch: 0175 train_loss= 0.42706 time= 0.24600
Epoch: 0176 train_loss= 0.42676 time= 0.24345
Epoch: 0177 train_loss= 0.42667 time= 0.24500
Epoch: 0178 train_loss= 0.42634 time= 0.23962
Epoch: 0179 train_loss= 0.42621 time= 0.24194
Epoch: 0180 train_loss= 0.42601 time= 0.23186
Epoch: 0181 train_loss= 0.42599 time= 0.23328
Epoch: 0182 train_loss= 0.42580 time= 0.23911
Epoch: 0183 train_loss= 0.42536 time= 0.24189
Epoch: 0184 train_loss= 0.42545 time= 0.24100
Epoch: 0185 train_loss= 0.42515 time= 0.23272
Epoch: 0186 train_loss= 0.42506 time= 0.24148
Epoch: 0187 train_loss= 0.42483 ti

Epoch: 0142 train_loss= 0.43328 time= 0.23725
Epoch: 0143 train_loss= 0.43300 time= 0.24968
Epoch: 0144 train_loss= 0.43306 time= 0.23751
Epoch: 0145 train_loss= 0.43258 time= 0.24170
Epoch: 0146 train_loss= 0.43227 time= 0.23458
Epoch: 0147 train_loss= 0.43236 time= 0.24188
Epoch: 0148 train_loss= 0.43204 time= 0.23575
Epoch: 0149 train_loss= 0.43175 time= 0.23796
Epoch: 0150 train_loss= 0.43171 time= 0.22881
Epoch: 0151 train_loss= 0.43133 time= 0.23694
Epoch: 0152 train_loss= 0.43115 time= 0.25094
Epoch: 0153 train_loss= 0.43072 time= 0.24539
Epoch: 0154 train_loss= 0.43075 time= 0.24447
Epoch: 0155 train_loss= 0.43047 time= 0.24514
Epoch: 0156 train_loss= 0.43065 time= 0.24047
Epoch: 0157 train_loss= 0.43021 time= 0.24559
Epoch: 0158 train_loss= 0.43005 time= 0.25910
Epoch: 0159 train_loss= 0.42981 time= 0.24242
Epoch: 0160 train_loss= 0.42957 time= 0.23959
Epoch: 0161 train_loss= 0.42957 time= 0.24215
Epoch: 0162 train_loss= 0.42947 time= 0.25306
Epoch: 0163 train_loss= 0.42888 ti

In [6]:
# Summarise the repeated runs: mean and population (ddof=0) std of each metric.
for mean_label, std_label, scores in (('AP MEAN : ', 'AP STD : ', all_ap),
                                      ('AUC MEAN : ', 'AUC STD : ', all_auc)):
    score_arr = np.array(scores)
    print(mean_label, score_arr.mean())
    print(std_label, score_arr.std())

AP MEAN :  0.8438632087684479
AP STD :  0.013407738950997339
AUC MEAN :  0.7939017026929115
AUC STD :  0.016032884118874362
