In [1]:
from __future__ import division
from __future__ import print_function

import argparse
import time

import numpy as np
import scipy.sparse as sp
import torch
from torch import optim

from gae.model import GCNModelVAE
from gae.optimizer import loss_function
from gae.utils import load_data, mask_test_edges, preprocess_graph, get_roc_score

In [2]:
def _build_arg_parser():
    """Assemble the argument parser holding the experiment's hyper-parameters."""
    option_table = (
        # (flag, type, default, help)
        ('--model', str, 'gcn_vae', "models used"),
        ('--seed', int, 42, 'Random seed.'),
        ('--epochs', int, 200, 'Number of epochs to train.'),
        ('--hidden1', int, 32, 'Number of units in hidden layer 1.'),
        ('--hidden2', int, 16, 'Number of units in hidden layer 2.'),
        ('--lr', float, 0.01, 'Initial learning rate.'),
        ('--dropout', float, 0., 'Dropout rate (1 - keep probability).'),
        ('--dataset-str', str, 'cora', 'type of dataset.'),
    )
    built = argparse.ArgumentParser()
    for flag, value_type, default_value, description in option_table:
        built.add_argument(flag, type=value_type, default=default_value, help=description)
    return built


parser = _build_arg_parser()

# parse_known_args returns unrecognised argv entries separately instead of
# erroring out on them; only the recognised options are kept in `args`.
args, _ = parser.parse_known_args()

In [3]:
def gae_for(args):
    """Train a variational graph auto-encoder once and score link prediction.

    The node features returned by ``load_data`` are not used: the model is fed
    an identity matrix (one one-hot row per node), sized from the loaded
    adjacency matrix so that any dataset works, not only one with 2708 nodes.

    Args:
        args: namespace providing ``dataset_str``, ``hidden1``, ``hidden2``,
            ``dropout``, ``lr`` and ``epochs``.

    Returns:
        Tuple ``(ap_score, roc_score)`` computed by ``get_roc_score`` on the
        held-out test edges. Note the order: average precision first.

    Raises:
        ValueError: if ``args.epochs`` is smaller than 1, because no embedding
            would exist to evaluate.
    """
    if args.epochs < 1:
        raise ValueError("args.epochs must be >= 1 to obtain embeddings, got {}".format(args.epochs))

    print("Using {} dataset".format(args.dataset_str))
    adj, _ = load_data(args.dataset_str)

    # Featureless setting: identity features, one row per node of the graph.
    n_nodes = adj.shape[0]
    features = torch.eye(n_nodes)
    feat_dim = features.shape[1]

    # Store original adjacency matrix (without diagonal entries) for later
    adj_orig = adj
    adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
    adj_orig.eliminate_zeros()

    # No validation split is used here: the result of mask_test_edges is
    # unpacked into the training graph plus train/test edge sets only.
    adj_train, train_edges, test_edges, test_edges_false = mask_test_edges(adj)
    adj = adj_train

    # Some preprocessing
    adj_norm = preprocess_graph(adj)
    # Reconstruction target: training adjacency with self-loops added.
    adj_label = adj_train + sp.eye(adj_train.shape[0])
    adj_label = torch.FloatTensor(adj_label.toarray())

    # Both weights are derived from how many of the N*N entries are set in the
    # training adjacency. NOTE(review): presumably loss_function uses them to
    # re-balance the sparse positive class — confirm against gae.optimizer.
    n_entries = adj.shape[0] * adj.shape[0]
    n_set = adj.sum()
    pos_weight = torch.tensor(float(n_entries - n_set) / n_set)
    norm = n_entries / float((n_entries - n_set) * 2)

    model = GCNModelVAE(feat_dim, args.hidden1, args.hidden2, args.dropout)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    hidden_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        recovered, mu, logvar = model(features, adj_norm)

        loss = loss_function(preds=recovered, labels=adj_label,
                             mu=mu, logvar=logvar, n_nodes=n_nodes,
                             norm=norm, pos_weight=pos_weight)
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()

        # Keep the latest mean embeddings, detached from the autograd graph.
        hidden_emb = mu.detach().numpy()

        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cur_loss),
              "time=", "{:.5f}".format(time.time() - t)
              )

    print("Optimization Finished!")

    roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))

    return ap_score, roc_score

In [4]:
if __name__ == '__main__':
    # Apply the --seed option, which was parsed but never used. Seeding happens
    # once, before the loop, so the whole experiment can be repeated while the
    # individual runs still differ from one another (assuming the edge split
    # and weight initialisation draw from the NumPy / torch global RNGs —
    # TODO confirm in gae.utils and gae.model).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Number of independent train/evaluate repetitions to aggregate over.
    n_runs = 10

    all_ap = []
    all_auc = []
    for i in range(n_runs):
        # gae_for returns (average precision, ROC AUC), in that order.
        ap, auc = gae_for(args)
        all_ap.append(ap)
        all_auc.append(auc)

    ap_scores = np.array(all_ap)
    auc_scores = np.array(all_auc)

    print('AP MEAN : ', ap_scores.mean())
    print('AP STD : ', ap_scores.std())

    print('AUC MEAN : ', auc_scores.mean())
    print('AUC STD : ', auc_scores.std())

Using cora dataset
Epoch: 0001 train_loss= 1.72535 time= 0.19400
Epoch: 0002 train_loss= 1.72248 time= 0.16700
Epoch: 0003 train_loss= 1.68331 time= 0.16400
Epoch: 0004 train_loss= 1.66292 time= 0.16568
Epoch: 0005 train_loss= 1.66416 time= 0.17123
Epoch: 0006 train_loss= 1.58223 time= 0.16600
Epoch: 0007 train_loss= 1.54247 time= 0.16971
Epoch: 0008 train_loss= 1.51270 time= 0.17800
Epoch: 0009 train_loss= 1.46347 time= 0.16214
Epoch: 0010 train_loss= 1.42418 time= 0.14886
Epoch: 0011 train_loss= 1.36112 time= 0.17664
Epoch: 0012 train_loss= 1.28314 time= 0.16800
Epoch: 0013 train_loss= 1.23214 time= 0.17323
Epoch: 0014 train_loss= 1.16794 time= 0.18173
Epoch: 0015 train_loss= 1.09833 time= 0.18628
Epoch: 0016 train_loss= 1.07019 time= 0.17029
Epoch: 0017 train_loss= 1.02287 time= 0.17700
Epoch: 0018 train_loss= 0.95538 time= 0.17700
Epoch: 0019 train_loss= 0.90043 time= 0.17868
Epoch: 0020 train_loss= 0.87012 time= 0.17000
Epoch: 0021 train_loss= 0.83241 time= 0.17661
Epoch: 0022 tra

Epoch: 0180 train_loss= 0.42887 time= 0.16988
Epoch: 0181 train_loss= 0.42842 time= 0.17402
Epoch: 0182 train_loss= 0.42862 time= 0.18700
Epoch: 0183 train_loss= 0.42840 time= 0.17100
Epoch: 0184 train_loss= 0.42801 time= 0.15731
Epoch: 0185 train_loss= 0.42806 time= 0.15738
Epoch: 0186 train_loss= 0.42757 time= 0.18231
Epoch: 0187 train_loss= 0.42746 time= 0.17100
Epoch: 0188 train_loss= 0.42751 time= 0.17900
Epoch: 0189 train_loss= 0.42737 time= 0.17200
Epoch: 0190 train_loss= 0.42718 time= 0.17300
Epoch: 0191 train_loss= 0.42705 time= 0.16833
Epoch: 0192 train_loss= 0.42690 time= 0.16901
Epoch: 0193 train_loss= 0.42646 time= 0.18212
Epoch: 0194 train_loss= 0.42657 time= 0.18788
Epoch: 0195 train_loss= 0.42635 time= 0.17023
Epoch: 0196 train_loss= 0.42631 time= 0.17037
Epoch: 0197 train_loss= 0.42613 time= 0.17392
Epoch: 0198 train_loss= 0.42583 time= 0.18500
Epoch: 0199 train_loss= 0.42573 time= 0.17100
Epoch: 0200 train_loss= 0.42571 time= 0.17200
Optimization Finished!
Test ROC sc

Epoch: 0157 train_loss= 0.43448 time= 0.14749
Epoch: 0158 train_loss= 0.43407 time= 0.17699
Epoch: 0159 train_loss= 0.43402 time= 0.16200
Epoch: 0160 train_loss= 0.43361 time= 0.16600
Epoch: 0161 train_loss= 0.43348 time= 0.16400
Epoch: 0162 train_loss= 0.43332 time= 0.16012
Epoch: 0163 train_loss= 0.43312 time= 0.16900
Epoch: 0164 train_loss= 0.43297 time= 0.16200
Epoch: 0165 train_loss= 0.43282 time= 0.16300
Epoch: 0166 train_loss= 0.43262 time= 0.16200
Epoch: 0167 train_loss= 0.43239 time= 0.16600
Epoch: 0168 train_loss= 0.43218 time= 0.15654
Epoch: 0169 train_loss= 0.43199 time= 0.16650
Epoch: 0170 train_loss= 0.43191 time= 0.16175
Epoch: 0171 train_loss= 0.43174 time= 0.16500
Epoch: 0172 train_loss= 0.43156 time= 0.16228
Epoch: 0173 train_loss= 0.43141 time= 0.16772
Epoch: 0174 train_loss= 0.43115 time= 0.16600
Epoch: 0175 train_loss= 0.43105 time= 0.16500
Epoch: 0176 train_loss= 0.43079 time= 0.15902
Epoch: 0177 train_loss= 0.43069 time= 0.16453
Epoch: 0178 train_loss= 0.43070 ti

Epoch: 0133 train_loss= 0.44154 time= 0.16500
Epoch: 0134 train_loss= 0.44108 time= 0.16801
Epoch: 0135 train_loss= 0.44073 time= 0.16850
Epoch: 0136 train_loss= 0.44053 time= 0.14682
Epoch: 0137 train_loss= 0.44016 time= 0.16616
Epoch: 0138 train_loss= 0.44004 time= 0.16509
Epoch: 0139 train_loss= 0.43955 time= 0.16400
Epoch: 0140 train_loss= 0.43936 time= 0.16600
Epoch: 0141 train_loss= 0.43918 time= 0.16371
Epoch: 0142 train_loss= 0.43857 time= 0.16951
Epoch: 0143 train_loss= 0.43843 time= 0.16975
Epoch: 0144 train_loss= 0.43808 time= 0.16199
Epoch: 0145 train_loss= 0.43784 time= 0.16400
Epoch: 0146 train_loss= 0.43778 time= 0.16300
Epoch: 0147 train_loss= 0.43741 time= 0.16500
Epoch: 0148 train_loss= 0.43740 time= 0.14607
Epoch: 0149 train_loss= 0.43693 time= 0.15809
Epoch: 0150 train_loss= 0.43657 time= 0.16152
Epoch: 0151 train_loss= 0.43634 time= 0.16100
Epoch: 0152 train_loss= 0.43617 time= 0.16558
Epoch: 0153 train_loss= 0.43603 time= 0.16539
Epoch: 0154 train_loss= 0.43581 ti

Epoch: 0110 train_loss= 0.45063 time= 0.16400
Epoch: 0111 train_loss= 0.45014 time= 0.16000
Epoch: 0112 train_loss= 0.44992 time= 0.16800
Epoch: 0113 train_loss= 0.44945 time= 0.16845
Epoch: 0114 train_loss= 0.44944 time= 0.16655
Epoch: 0115 train_loss= 0.44878 time= 0.16669
Epoch: 0116 train_loss= 0.44852 time= 0.15229
Epoch: 0117 train_loss= 0.44788 time= 0.16865
Epoch: 0118 train_loss= 0.44729 time= 0.16100
Epoch: 0119 train_loss= 0.44699 time= 0.16721
Epoch: 0120 train_loss= 0.44688 time= 0.16674
Epoch: 0121 train_loss= 0.44674 time= 0.16200
Epoch: 0122 train_loss= 0.44632 time= 0.15900
Epoch: 0123 train_loss= 0.44597 time= 0.16200
Epoch: 0124 train_loss= 0.44541 time= 0.16400
Epoch: 0125 train_loss= 0.44565 time= 0.16500
Epoch: 0126 train_loss= 0.44511 time= 0.16439
Epoch: 0127 train_loss= 0.44455 time= 0.16100
Epoch: 0128 train_loss= 0.44424 time= 0.14940
Epoch: 0129 train_loss= 0.44399 time= 0.13838
Epoch: 0130 train_loss= 0.44367 time= 0.16555
Epoch: 0131 train_loss= 0.44367 ti

Epoch: 0087 train_loss= 0.47352 time= 0.16600
Epoch: 0088 train_loss= 0.47275 time= 0.16700
Epoch: 0089 train_loss= 0.47061 time= 0.16700
Epoch: 0090 train_loss= 0.47013 time= 0.17800
Epoch: 0091 train_loss= 0.46805 time= 0.16800
Epoch: 0092 train_loss= 0.46775 time= 0.16500
Epoch: 0093 train_loss= 0.46646 time= 0.16500
Epoch: 0094 train_loss= 0.46517 time= 0.16400
Epoch: 0095 train_loss= 0.46414 time= 0.16451
Epoch: 0096 train_loss= 0.46322 time= 0.14646
Epoch: 0097 train_loss= 0.46208 time= 0.16773
Epoch: 0098 train_loss= 0.46150 time= 0.16701
Epoch: 0099 train_loss= 0.46047 time= 0.17080
Epoch: 0100 train_loss= 0.45965 time= 0.15692
Epoch: 0101 train_loss= 0.45929 time= 0.16200
Epoch: 0102 train_loss= 0.45792 time= 0.16200
Epoch: 0103 train_loss= 0.45696 time= 0.16500
Epoch: 0104 train_loss= 0.45683 time= 0.15900
Epoch: 0105 train_loss= 0.45567 time= 0.17400
Epoch: 0106 train_loss= 0.45490 time= 0.16300
Epoch: 0107 train_loss= 0.45412 time= 0.16200
Epoch: 0108 train_loss= 0.45371 ti

Epoch: 0063 train_loss= 0.48670 time= 0.15800
Epoch: 0064 train_loss= 0.48557 time= 0.15624
Epoch: 0065 train_loss= 0.48465 time= 0.15846
Epoch: 0066 train_loss= 0.48346 time= 0.17000
Epoch: 0067 train_loss= 0.48125 time= 0.16700
Epoch: 0068 train_loss= 0.48006 time= 0.16800
Epoch: 0069 train_loss= 0.47821 time= 0.16580
Epoch: 0070 train_loss= 0.47636 time= 0.16526
Epoch: 0071 train_loss= 0.47470 time= 0.16450
Epoch: 0072 train_loss= 0.47358 time= 0.16673
Epoch: 0073 train_loss= 0.47142 time= 0.16480
Epoch: 0074 train_loss= 0.47027 time= 0.16740
Epoch: 0075 train_loss= 0.46898 time= 0.16276
Epoch: 0076 train_loss= 0.46885 time= 0.14294
Epoch: 0077 train_loss= 0.46674 time= 0.15997
Epoch: 0078 train_loss= 0.46581 time= 0.17319
Epoch: 0079 train_loss= 0.46505 time= 0.16539
Epoch: 0080 train_loss= 0.46444 time= 0.16961
Epoch: 0081 train_loss= 0.46366 time= 0.16143
Epoch: 0082 train_loss= 0.46263 time= 0.16108
Epoch: 0083 train_loss= 0.46199 time= 0.16400
Epoch: 0084 train_loss= 0.46147 ti

Epoch: 0039 train_loss= 0.66843 time= 0.16600
Epoch: 0040 train_loss= 0.66328 time= 0.16500
Epoch: 0041 train_loss= 0.65584 time= 0.16500
Epoch: 0042 train_loss= 0.64724 time= 0.15661
Epoch: 0043 train_loss= 0.64087 time= 0.16243
Epoch: 0044 train_loss= 0.63057 time= 0.15662
Epoch: 0045 train_loss= 0.62239 time= 0.16846
Epoch: 0046 train_loss= 0.61040 time= 0.16300
Epoch: 0047 train_loss= 0.60018 time= 0.16900
Epoch: 0048 train_loss= 0.58657 time= 0.16100
Epoch: 0049 train_loss= 0.57737 time= 0.16800
Epoch: 0050 train_loss= 0.56500 time= 0.16418
Epoch: 0051 train_loss= 0.55372 time= 0.16200
Epoch: 0052 train_loss= 0.54667 time= 0.17376
Epoch: 0053 train_loss= 0.53888 time= 0.16600
Epoch: 0054 train_loss= 0.53324 time= 0.15983
Epoch: 0055 train_loss= 0.52968 time= 0.16641
Epoch: 0056 train_loss= 0.52465 time= 0.15338
Epoch: 0057 train_loss= 0.52323 time= 0.16126
Epoch: 0058 train_loss= 0.51967 time= 0.16400
Epoch: 0059 train_loss= 0.51868 time= 0.16626
Epoch: 0060 train_loss= 0.51654 ti

Epoch: 0015 train_loss= 1.23594 time= 0.15522
Epoch: 0016 train_loss= 1.18385 time= 0.16468
Epoch: 0017 train_loss= 1.13141 time= 0.15145
Epoch: 0018 train_loss= 1.08197 time= 0.16713
Epoch: 0019 train_loss= 1.02929 time= 0.14870
Epoch: 0020 train_loss= 0.98541 time= 0.16802
Epoch: 0021 train_loss= 0.93364 time= 0.16542
Epoch: 0022 train_loss= 0.89963 time= 0.15974
Epoch: 0023 train_loss= 0.86875 time= 0.15554
Epoch: 0024 train_loss= 0.82753 time= 0.15771
Epoch: 0025 train_loss= 0.80154 time= 0.15815
Epoch: 0026 train_loss= 0.78155 time= 0.14922
Epoch: 0027 train_loss= 0.75879 time= 0.15436
Epoch: 0028 train_loss= 0.74983 time= 0.14601
Epoch: 0029 train_loss= 0.73587 time= 0.17987
Epoch: 0030 train_loss= 0.72538 time= 0.15520
Epoch: 0031 train_loss= 0.71696 time= 0.14433
Epoch: 0032 train_loss= 0.71407 time= 0.16994
Epoch: 0033 train_loss= 0.70578 time= 0.14682
Epoch: 0034 train_loss= 0.69849 time= 0.15502
Epoch: 0035 train_loss= 0.69536 time= 0.15130
Epoch: 0036 train_loss= 0.68988 ti

Epoch: 0194 train_loss= 0.42929 time= 0.16197
Epoch: 0195 train_loss= 0.42894 time= 0.14842
Epoch: 0196 train_loss= 0.42902 time= 0.15956
Epoch: 0197 train_loss= 0.42864 time= 0.16385
Epoch: 0198 train_loss= 0.42846 time= 0.16218
Epoch: 0199 train_loss= 0.42838 time= 0.15663
Epoch: 0200 train_loss= 0.42824 time= 0.16491
Optimization Finished!
Test ROC score: 0.8590028408988619
Test AP score: 0.893500137890329
Using cora dataset
Epoch: 0001 train_loss= 1.73657 time= 0.15928
Epoch: 0002 train_loss= 1.71445 time= 0.15906
Epoch: 0003 train_loss= 1.67617 time= 0.14936
Epoch: 0004 train_loss= 1.64920 time= 0.14657
Epoch: 0005 train_loss= 1.59628 time= 0.15063
Epoch: 0006 train_loss= 1.57956 time= 0.13635
Epoch: 0007 train_loss= 1.48756 time= 0.15342
Epoch: 0008 train_loss= 1.49454 time= 0.15848
Epoch: 0009 train_loss= 1.43198 time= 0.14651
Epoch: 0010 train_loss= 1.35352 time= 0.15307
Epoch: 0011 train_loss= 1.29785 time= 0.15835
Epoch: 0012 train_loss= 1.23482 time= 0.14677
Epoch: 0013 trai

Epoch: 0171 train_loss= 0.42920 time= 0.16501
Epoch: 0172 train_loss= 0.42926 time= 0.16941
Epoch: 0173 train_loss= 0.42897 time= 0.16379
Epoch: 0174 train_loss= 0.42882 time= 0.16789
Epoch: 0175 train_loss= 0.42867 time= 0.16557
Epoch: 0176 train_loss= 0.42846 time= 0.15246
Epoch: 0177 train_loss= 0.42829 time= 0.16791
Epoch: 0178 train_loss= 0.42810 time= 0.16500
Epoch: 0179 train_loss= 0.42792 time= 0.16199
Epoch: 0180 train_loss= 0.42774 time= 0.16300
Epoch: 0181 train_loss= 0.42765 time= 0.17271
Epoch: 0182 train_loss= 0.42752 time= 0.16606
Epoch: 0183 train_loss= 0.42753 time= 0.16796
Epoch: 0184 train_loss= 0.42728 time= 0.16599
Epoch: 0185 train_loss= 0.42697 time= 0.15800
Epoch: 0186 train_loss= 0.42694 time= 0.17600
Epoch: 0187 train_loss= 0.42665 time= 0.16228
Epoch: 0188 train_loss= 0.42656 time= 0.15232
Epoch: 0189 train_loss= 0.42636 time= 0.15700
Epoch: 0190 train_loss= 0.42638 time= 0.16855
Epoch: 0191 train_loss= 0.42632 time= 0.16635
Epoch: 0192 train_loss= 0.42610 ti

Epoch: 0147 train_loss= 0.43898 time= 0.16500
Epoch: 0148 train_loss= 0.43835 time= 0.16600
Epoch: 0149 train_loss= 0.43834 time= 0.16200
Epoch: 0150 train_loss= 0.43782 time= 0.17700
Epoch: 0151 train_loss= 0.43749 time= 0.15600
Epoch: 0152 train_loss= 0.43720 time= 0.17900
Epoch: 0153 train_loss= 0.43699 time= 0.16000
Epoch: 0154 train_loss= 0.43657 time= 0.17600
Epoch: 0155 train_loss= 0.43638 time= 0.16210
Epoch: 0156 train_loss= 0.43614 time= 0.14906
Epoch: 0157 train_loss= 0.43596 time= 0.17043
Epoch: 0158 train_loss= 0.43574 time= 0.16278
Epoch: 0159 train_loss= 0.43523 time= 0.16282
Epoch: 0160 train_loss= 0.43490 time= 0.16219
Epoch: 0161 train_loss= 0.43480 time= 0.16647
Epoch: 0162 train_loss= 0.43459 time= 0.16592
Epoch: 0163 train_loss= 0.43429 time= 0.16525
Epoch: 0164 train_loss= 0.43433 time= 0.16323
Epoch: 0165 train_loss= 0.43374 time= 0.16722
Epoch: 0166 train_loss= 0.43389 time= 0.16082
Epoch: 0167 train_loss= 0.43323 time= 0.14597
Epoch: 0168 train_loss= 0.43338 ti