In [6]:
from __future__ import division
from __future__ import print_function

import argparse
import time

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
from torch import optim
import networkx as nx

from gae.model import GCNModelVAE
from gae.optimizer import loss_function
from gae.utils import mask_test_edges, preprocess_graph, get_roc_score

In [7]:
def _build_parser():
    """Create the command-line parser that holds the experiment settings.

    Each entry of the table below is (flag, value type, default, help text);
    the entries are registered in the listed order.
    """
    option_table = (
        ('--model', str, 'gcn_vae', "models used"),
        ('--seed', int, 42, 'Random seed.'),
        ('--epochs', int, 200, 'Number of epochs to train.'),
        ('--hidden1', int, 32, 'Number of units in hidden layer 1.'),
        ('--hidden2', int, 16, 'Number of units in hidden layer 2.'),
        ('--lr', float, 0.01, 'Initial learning rate.'),
        ('--dropout', float, 0., 'Dropout rate (1 - keep probability).'),
        ('--dataset-str', str, 'wiki', 'type of dataset.'),
    )

    built = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        built.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return built


parser = _build_parser()

# parse_known_args() returns unrecognised argv entries instead of exiting,
# so extra process arguments (e.g. from a notebook kernel) do not abort the cell.
args, _ = parser.parse_known_args()

In [8]:
def load_data(adj_name):
    """Load a graph dataset and return its adjacency matrix and node features.

    Parameters
    ----------
    adj_name : str
        Dataset identifier. ``'wiki'`` reads a tab-separated, header-less edge
        list from ``datasets/wiki.txt``; ``'Citeseer'`` takes ``edge_index``
        from the first element of ``Planetoid('./datasets', adj_name)``.

    Returns
    -------
    adj : sparse adjacency matrix
        Symmetric adjacency without self-loops, as produced by
        ``nx.adjacency_matrix``.
    features : torch.Tensor
        Identity matrix with one row per node (``torch.eye``), i.e. the model
        is trained without real node attributes.

    Raises
    ------
    ValueError
        If ``adj_name`` is not one of the supported dataset names.
    """
    if adj_name == 'wiki':
        nodes_numbers = 2405
        raw_edges = pd.read_csv('datasets/wiki.txt', header=None, sep='\t')
    elif adj_name == 'Citeseer':
        nodes_numbers = 3327
        # NOTE(review): `Planetoid` is not imported in the visible cells --
        # presumably torch_geometric.datasets.Planetoid; it must be imported
        # before this branch can run.
        datasets = Planetoid('./datasets', adj_name)
        edges = datasets[0].edge_index
        raw_edges = pd.DataFrame([[edges[0, i].item(), edges[1, i].item()] for i in range(edges.shape[1])])
    else:
        # Fail loudly: previously this branch only printed a message and the
        # function then crashed on the undefined `raw_edges`.
        raise ValueError(
            "Unknown dataset {!r}; expected 'wiki' or 'Citeseer'.".format(adj_name)
        )

    # Self-loops are dropped so the adjacency has an empty diagonal.
    drop_self_loop = raw_edges[raw_edges[0] != raw_edges[1]]

    # Fill both (u, v) and (v, u) in one vectorised step instead of a per-row
    # Python loop; the resulting dense matrix is identical.
    sources = drop_self_loop[0].values
    targets = drop_self_loop[1].values
    graph_np = np.zeros((nodes_numbers, nodes_numbers))
    graph_np[sources, targets] = 1
    graph_np[targets, sources] = 1

    # from_numpy_array is available in networkx 2.x and 3.x, whereas
    # from_numpy_matrix was removed in networkx 3.0.
    adj = nx.adjacency_matrix(nx.from_numpy_array(graph_np))

    features = torch.eye(nodes_numbers)

    return adj, features

In [9]:
def gae_for(args):
    """Run one train/evaluate cycle of the graph VAE on ``args.dataset_str``.

    The graph is loaded with ``load_data``, split by ``mask_test_edges``, and a
    ``GCNModelVAE`` is optimised with Adam for ``args.epochs`` full-batch steps.
    The latent means of the last epoch are then scored on the test edges with
    ``get_roc_score``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``dataset_str``, ``hidden1``, ``hidden2``, ``dropout``,
        ``lr`` and ``epochs``.

    Returns
    -------
    tuple
        ``(ap_score, roc_score)`` -- note that AP comes first.
    """
    print("Using {} dataset".format(args.dataset_str))
    adj, features = load_data(args.dataset_str)

    n_nodes, feat_dim = features.shape

    # Diagonal-free copy of the full graph, kept for the final evaluation.
    diagonal_part = sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
    adj_orig = adj - diagonal_part
    adj_orig.eliminate_zeros()

    # From here on only the training graph is used for optimisation.
    adj_train, train_edges, test_edges, test_edges_false = mask_test_edges(adj)

    adj_norm = preprocess_graph(adj_train)
    # Reconstruction target: training adjacency plus the identity, densified.
    label_sparse = adj_train + sp.eye(adj_train.shape[0])
    adj_label = torch.FloatTensor(label_sparse.toarray())

    # Weighting terms passed to the loss, derived from how many of the
    # n * n adjacency entries are non-zero.
    total_entries = adj_train.shape[0] * adj_train.shape[0]
    edge_mass = adj_train.sum()
    pos_weight = torch.tensor(float(total_entries - edge_mass) / edge_mass)
    norm = total_entries / float((total_entries - edge_mass) * 2)

    model = GCNModelVAE(feat_dim, args.hidden1, args.hidden2, args.dropout)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    hidden_emb = None
    for epoch in range(args.epochs):
        started_at = time.time()

        model.train()
        optimizer.zero_grad()
        recovered, mu, logvar = model(features, adj_norm)
        loss = loss_function(
            preds=recovered,
            labels=adj_label,
            mu=mu,
            logvar=logvar,
            n_nodes=n_nodes,
            norm=norm,
            pos_weight=pos_weight,
        )
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()

        # Latent means of this epoch; the last ones are evaluated below.
        hidden_emb = mu.data.numpy()

        epoch_label = '%04d' % (epoch + 1)
        loss_label = "{:.5f}".format(cur_loss)
        print("Epoch:", epoch_label, "train_loss=", loss_label,
              "time=", "{:.5f}".format(time.time() - started_at))

    print("Optimization Finished!")

    roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))

    return ap_score, roc_score

In [10]:
if __name__ == '__main__':
    # Number of independent train/evaluate repetitions to average over.
    N_RUNS = 10

    # Apply the --seed option, which was parsed but never used, so that the
    # whole sequence of runs is reproducible. Seeding happens once, so the
    # individual runs still differ from each other.
    # NOTE(review): assumes the gae helpers draw from the global NumPy / torch
    # generators -- confirm against gae.utils and gae.model.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    all_ap = []
    all_auc = []
    for _run in range(N_RUNS):
        # gae_for returns (AP, ROC AUC), in that order.
        ap, auc = gae_for(args)
        all_ap.append(ap)
        all_auc.append(auc)

    # Build each array once and report mean / standard deviation over runs.
    ap_scores = np.array(all_ap)
    auc_scores = np.array(all_auc)

    print('AP MEAN : ', ap_scores.mean())
    print('AP STD : ', ap_scores.std())

    print('AUC MEAN : ', auc_scores.mean())
    print('AUC STD : ', auc_scores.std())

Using wiki dataset
Epoch: 0001 train_loss= 1.75667 time= 0.16700
Epoch: 0002 train_loss= 1.71062 time= 0.12495
Epoch: 0003 train_loss= 1.68586 time= 0.13063
Epoch: 0004 train_loss= 1.63165 time= 0.13504
Epoch: 0005 train_loss= 1.60422 time= 0.13850
Epoch: 0006 train_loss= 1.55333 time= 0.13500
Epoch: 0007 train_loss= 1.52418 time= 0.13367
Epoch: 0008 train_loss= 1.45356 time= 0.13900
Epoch: 0009 train_loss= 1.39788 time= 0.14300
Epoch: 0010 train_loss= 1.35814 time= 0.13800
Epoch: 0011 train_loss= 1.27232 time= 0.14000
Epoch: 0012 train_loss= 1.19995 time= 0.14258
Epoch: 0013 train_loss= 1.14696 time= 0.13542
Epoch: 0014 train_loss= 1.08224 time= 0.13574
Epoch: 0015 train_loss= 1.02097 time= 0.13557
Epoch: 0016 train_loss= 0.98534 time= 0.13700
Epoch: 0017 train_loss= 0.93892 time= 0.13177
Epoch: 0018 train_loss= 0.90351 time= 0.12394
Epoch: 0019 train_loss= 0.85272 time= 0.13985
Epoch: 0020 train_loss= 0.82014 time= 0.13604
Epoch: 0021 train_loss= 0.78925 time= 0.13857
Epoch: 0022 tra

Epoch: 0179 train_loss= 0.44699 time= 0.14496
Epoch: 0180 train_loss= 0.44632 time= 0.15600
Epoch: 0181 train_loss= 0.44584 time= 0.14400
Epoch: 0182 train_loss= 0.44596 time= 0.14724
Epoch: 0183 train_loss= 0.44528 time= 0.13331
Epoch: 0184 train_loss= 0.44507 time= 0.15152
Epoch: 0185 train_loss= 0.44458 time= 0.14507
Epoch: 0186 train_loss= 0.44420 time= 0.13861
Epoch: 0187 train_loss= 0.44403 time= 0.13689
Epoch: 0188 train_loss= 0.44358 time= 0.15300
Epoch: 0189 train_loss= 0.44350 time= 0.14700
Epoch: 0190 train_loss= 0.44270 time= 0.14600
Epoch: 0191 train_loss= 0.44262 time= 0.14600
Epoch: 0192 train_loss= 0.44252 time= 0.14608
Epoch: 0193 train_loss= 0.44170 time= 0.14600
Epoch: 0194 train_loss= 0.44169 time= 0.14253
Epoch: 0195 train_loss= 0.44102 time= 0.14142
Epoch: 0196 train_loss= 0.44085 time= 0.15100
Epoch: 0197 train_loss= 0.44082 time= 0.15800
Epoch: 0198 train_loss= 0.44041 time= 0.15053
Epoch: 0199 train_loss= 0.44056 time= 0.14239
Epoch: 0200 train_loss= 0.43985 ti

Epoch: 0155 train_loss= 0.45420 time= 0.14050
Epoch: 0156 train_loss= 0.45358 time= 0.13650
Epoch: 0157 train_loss= 0.45343 time= 0.13975
Epoch: 0158 train_loss= 0.45314 time= 0.14800
Epoch: 0159 train_loss= 0.45309 time= 0.14802
Epoch: 0160 train_loss= 0.45279 time= 0.14440
Epoch: 0161 train_loss= 0.45239 time= 0.13858
Epoch: 0162 train_loss= 0.45261 time= 0.13500
Epoch: 0163 train_loss= 0.45198 time= 0.14100
Epoch: 0164 train_loss= 0.45163 time= 0.13544
Epoch: 0165 train_loss= 0.45186 time= 0.13200
Epoch: 0166 train_loss= 0.45159 time= 0.14860
Epoch: 0167 train_loss= 0.45111 time= 0.14386
Epoch: 0168 train_loss= 0.45113 time= 0.14700
Epoch: 0169 train_loss= 0.45087 time= 0.14466
Epoch: 0170 train_loss= 0.45075 time= 0.12807
Epoch: 0171 train_loss= 0.45044 time= 0.13026
Epoch: 0172 train_loss= 0.45028 time= 0.13300
Epoch: 0173 train_loss= 0.45033 time= 0.13600
Epoch: 0174 train_loss= 0.44997 time= 0.14000
Epoch: 0175 train_loss= 0.44980 time= 0.13886
Epoch: 0176 train_loss= 0.44967 ti

Epoch: 0131 train_loss= 0.46134 time= 0.13300
Epoch: 0132 train_loss= 0.46035 time= 0.12010
Epoch: 0133 train_loss= 0.45954 time= 0.14947
Epoch: 0134 train_loss= 0.45930 time= 0.13564
Epoch: 0135 train_loss= 0.45892 time= 0.13900
Epoch: 0136 train_loss= 0.45833 time= 0.13400
Epoch: 0137 train_loss= 0.45802 time= 0.14557
Epoch: 0138 train_loss= 0.45714 time= 0.13167
Epoch: 0139 train_loss= 0.45755 time= 0.13500
Epoch: 0140 train_loss= 0.45661 time= 0.13001
Epoch: 0141 train_loss= 0.45632 time= 0.14699
Epoch: 0142 train_loss= 0.45561 time= 0.13400
Epoch: 0143 train_loss= 0.45511 time= 0.14000
Epoch: 0144 train_loss= 0.45446 time= 0.13613
Epoch: 0145 train_loss= 0.45390 time= 0.14100
Epoch: 0146 train_loss= 0.45359 time= 0.12955
Epoch: 0147 train_loss= 0.45300 time= 0.12948
Epoch: 0148 train_loss= 0.45244 time= 0.13479
Epoch: 0149 train_loss= 0.45228 time= 0.13699
Epoch: 0150 train_loss= 0.45137 time= 0.13100
Epoch: 0151 train_loss= 0.45094 time= 0.13701
Epoch: 0152 train_loss= 0.45071 ti

Epoch: 0107 train_loss= 0.50532 time= 0.13400
Epoch: 0108 train_loss= 0.50432 time= 0.13400
Epoch: 0109 train_loss= 0.50363 time= 0.13300
Epoch: 0110 train_loss= 0.50286 time= 0.12585
Epoch: 0111 train_loss= 0.50139 time= 0.13514
Epoch: 0112 train_loss= 0.50162 time= 0.13470
Epoch: 0113 train_loss= 0.49999 time= 0.13958
Epoch: 0114 train_loss= 0.49925 time= 0.13600
Epoch: 0115 train_loss= 0.49868 time= 0.13979
Epoch: 0116 train_loss= 0.49854 time= 0.13600
Epoch: 0117 train_loss= 0.49750 time= 0.13500
Epoch: 0118 train_loss= 0.49708 time= 0.13600
Epoch: 0119 train_loss= 0.49607 time= 0.13500
Epoch: 0120 train_loss= 0.49475 time= 0.13728
Epoch: 0121 train_loss= 0.49497 time= 0.13672
Epoch: 0122 train_loss= 0.49356 time= 0.14100
Epoch: 0123 train_loss= 0.49341 time= 0.13610
Epoch: 0124 train_loss= 0.49262 time= 0.12101
Epoch: 0125 train_loss= 0.49129 time= 0.12841
Epoch: 0126 train_loss= 0.49089 time= 0.11156
Epoch: 0127 train_loss= 0.48995 time= 0.15536
Epoch: 0128 train_loss= 0.48944 ti

Epoch: 0083 train_loss= 0.50317 time= 0.14217
Epoch: 0084 train_loss= 0.50221 time= 0.13600
Epoch: 0085 train_loss= 0.50111 time= 0.13689
Epoch: 0086 train_loss= 0.50029 time= 0.12442
Epoch: 0087 train_loss= 0.49858 time= 0.14360
Epoch: 0088 train_loss= 0.49740 time= 0.13698
Epoch: 0089 train_loss= 0.49666 time= 0.14110
Epoch: 0090 train_loss= 0.49551 time= 0.13272
Epoch: 0091 train_loss= 0.49379 time= 0.14300
Epoch: 0092 train_loss= 0.49316 time= 0.13595
Epoch: 0093 train_loss= 0.49208 time= 0.13546
Epoch: 0094 train_loss= 0.49089 time= 0.13554
Epoch: 0095 train_loss= 0.48999 time= 0.13800
Epoch: 0096 train_loss= 0.48902 time= 0.13456
Epoch: 0097 train_loss= 0.48738 time= 0.13541
Epoch: 0098 train_loss= 0.48606 time= 0.13600
Epoch: 0099 train_loss= 0.48557 time= 0.13800
Epoch: 0100 train_loss= 0.48419 time= 0.13055
Epoch: 0101 train_loss= 0.48289 time= 0.11650
Epoch: 0102 train_loss= 0.48143 time= 0.13585
Epoch: 0103 train_loss= 0.48083 time= 0.13681
Epoch: 0104 train_loss= 0.47960 ti

Epoch: 0059 train_loss= 0.56849 time= 0.14047
Epoch: 0060 train_loss= 0.56480 time= 0.13993
Epoch: 0061 train_loss= 0.56096 time= 0.14400
Epoch: 0062 train_loss= 0.56008 time= 0.13712
Epoch: 0063 train_loss= 0.55705 time= 0.13039
Epoch: 0064 train_loss= 0.55425 time= 0.13513
Epoch: 0065 train_loss= 0.55108 time= 0.14659
Epoch: 0066 train_loss= 0.54859 time= 0.13472
Epoch: 0067 train_loss= 0.54516 time= 0.13633
Epoch: 0068 train_loss= 0.54199 time= 0.14172
Epoch: 0069 train_loss= 0.53870 time= 0.13765
Epoch: 0070 train_loss= 0.53545 time= 0.13400
Epoch: 0071 train_loss= 0.53266 time= 0.13600
Epoch: 0072 train_loss= 0.52781 time= 0.13849
Epoch: 0073 train_loss= 0.52513 time= 0.13600
Epoch: 0074 train_loss= 0.52197 time= 0.13725
Epoch: 0075 train_loss= 0.51911 time= 0.13400
Epoch: 0076 train_loss= 0.51774 time= 0.14000
Epoch: 0077 train_loss= 0.51421 time= 0.13100
Epoch: 0078 train_loss= 0.51192 time= 0.13472
Epoch: 0079 train_loss= 0.50937 time= 0.13239
Epoch: 0080 train_loss= 0.50769 ti

Epoch: 0035 train_loss= 0.65050 time= 0.14200
Epoch: 0036 train_loss= 0.64529 time= 0.13836
Epoch: 0037 train_loss= 0.64125 time= 0.14000
Epoch: 0038 train_loss= 0.63700 time= 0.14100
Epoch: 0039 train_loss= 0.63214 time= 0.13414
Epoch: 0040 train_loss= 0.63031 time= 0.13132
Epoch: 0041 train_loss= 0.62458 time= 0.14716
Epoch: 0042 train_loss= 0.62096 time= 0.14200
Epoch: 0043 train_loss= 0.61532 time= 0.13800
Epoch: 0044 train_loss= 0.60929 time= 0.13401
Epoch: 0045 train_loss= 0.60208 time= 0.14199
Epoch: 0046 train_loss= 0.59843 time= 0.13342
Epoch: 0047 train_loss= 0.59243 time= 0.13658
Epoch: 0048 train_loss= 0.58694 time= 0.14400
Epoch: 0049 train_loss= 0.58351 time= 0.14500
Epoch: 0050 train_loss= 0.57810 time= 0.13500
Epoch: 0051 train_loss= 0.57494 time= 0.13900
Epoch: 0052 train_loss= 0.57203 time= 0.13800
Epoch: 0053 train_loss= 0.56974 time= 0.13338
Epoch: 0054 train_loss= 0.56780 time= 0.13072
Epoch: 0055 train_loss= 0.56591 time= 0.13033
Epoch: 0056 train_loss= 0.56346 ti

Epoch: 0011 train_loss= 1.27469 time= 0.14100
Epoch: 0012 train_loss= 1.20772 time= 0.13896
Epoch: 0013 train_loss= 1.14679 time= 0.15100
Epoch: 0014 train_loss= 1.08777 time= 0.13500
Epoch: 0015 train_loss= 1.03303 time= 0.13200
Epoch: 0016 train_loss= 0.99961 time= 0.12791
Epoch: 0017 train_loss= 0.94741 time= 0.13365
Epoch: 0018 train_loss= 0.90961 time= 0.13900
Epoch: 0019 train_loss= 0.87186 time= 0.13611
Epoch: 0020 train_loss= 0.83654 time= 0.13900
Epoch: 0021 train_loss= 0.80106 time= 0.13600
Epoch: 0022 train_loss= 0.78205 time= 0.13557
Epoch: 0023 train_loss= 0.76276 time= 0.13600
Epoch: 0024 train_loss= 0.74585 time= 0.13539
Epoch: 0025 train_loss= 0.72210 time= 0.13900
Epoch: 0026 train_loss= 0.72121 time= 0.13000
Epoch: 0027 train_loss= 0.71186 time= 0.15277
Epoch: 0028 train_loss= 0.70377 time= 0.14000
Epoch: 0029 train_loss= 0.69748 time= 0.13900
Epoch: 0030 train_loss= 0.69344 time= 0.12129
Epoch: 0031 train_loss= 0.68647 time= 0.13188
Epoch: 0032 train_loss= 0.68394 ti

Epoch: 0191 train_loss= 0.44838 time= 0.13980
Epoch: 0192 train_loss= 0.44807 time= 0.12934
Epoch: 0193 train_loss= 0.44770 time= 0.13714
Epoch: 0194 train_loss= 0.44745 time= 0.13540
Epoch: 0195 train_loss= 0.44746 time= 0.13504
Epoch: 0196 train_loss= 0.44680 time= 0.13634
Epoch: 0197 train_loss= 0.44678 time= 0.13801
Epoch: 0198 train_loss= 0.44628 time= 0.13800
Epoch: 0199 train_loss= 0.44599 time= 0.13581
Epoch: 0200 train_loss= 0.44586 time= 0.14000
Optimization Finished!
Test ROC score: 0.9130985996228638
Test AP score: 0.9340347226963028
Using wiki dataset
Epoch: 0001 train_loss= 1.74965 time= 0.14000
Epoch: 0002 train_loss= 1.71054 time= 0.15000
Epoch: 0003 train_loss= 1.67907 time= 0.15500
Epoch: 0004 train_loss= 1.65861 time= 0.14599
Epoch: 0005 train_loss= 1.60465 time= 0.15200
Epoch: 0006 train_loss= 1.56539 time= 0.15100
Epoch: 0007 train_loss= 1.51364 time= 0.15100
Epoch: 0008 train_loss= 1.47338 time= 0.13918
Epoch: 0009 train_loss= 1.41252 time= 0.14800
Epoch: 0010 tra

Epoch: 0167 train_loss= 0.45181 time= 0.13800
Epoch: 0168 train_loss= 0.45121 time= 0.13100
Epoch: 0169 train_loss= 0.45082 time= 0.13400
Epoch: 0170 train_loss= 0.45042 time= 0.13582
Epoch: 0171 train_loss= 0.44992 time= 0.13800
Epoch: 0172 train_loss= 0.44955 time= 0.14023
Epoch: 0173 train_loss= 0.44915 time= 0.13500
Epoch: 0174 train_loss= 0.44889 time= 0.14198
Epoch: 0175 train_loss= 0.44853 time= 0.13996
Epoch: 0176 train_loss= 0.44834 time= 0.14429
Epoch: 0177 train_loss= 0.44751 time= 0.14814
Epoch: 0178 train_loss= 0.44722 time= 0.13886
Epoch: 0179 train_loss= 0.44696 time= 0.13700
Epoch: 0180 train_loss= 0.44652 time= 0.13400
Epoch: 0181 train_loss= 0.44623 time= 0.13691
Epoch: 0182 train_loss= 0.44586 time= 0.14048
Epoch: 0183 train_loss= 0.44561 time= 0.13652
Epoch: 0184 train_loss= 0.44527 time= 0.13423
Epoch: 0185 train_loss= 0.44482 time= 0.13800
Epoch: 0186 train_loss= 0.44446 time= 0.14098
Epoch: 0187 train_loss= 0.44437 time= 0.14020
Epoch: 0188 train_loss= 0.44424 ti

Epoch: 0143 train_loss= 0.46233 time= 0.13600
Epoch: 0144 train_loss= 0.46153 time= 0.14464
Epoch: 0145 train_loss= 0.46143 time= 0.13247
Epoch: 0146 train_loss= 0.46083 time= 0.13453
Epoch: 0147 train_loss= 0.46031 time= 0.13300
Epoch: 0148 train_loss= 0.45987 time= 0.13500
Epoch: 0149 train_loss= 0.45974 time= 0.13400
Epoch: 0150 train_loss= 0.45915 time= 0.13900
Epoch: 0151 train_loss= 0.45873 time= 0.13280
Epoch: 0152 train_loss= 0.45868 time= 0.12355
Epoch: 0153 train_loss= 0.45828 time= 0.13068
Epoch: 0154 train_loss= 0.45830 time= 0.13200
Epoch: 0155 train_loss= 0.45763 time= 0.13400
Epoch: 0156 train_loss= 0.45735 time= 0.14400
Epoch: 0157 train_loss= 0.45671 time= 0.13400
Epoch: 0158 train_loss= 0.45670 time= 0.13400
Epoch: 0159 train_loss= 0.45626 time= 0.13400
Epoch: 0160 train_loss= 0.45592 time= 0.13995
Epoch: 0161 train_loss= 0.45589 time= 0.13805
Epoch: 0162 train_loss= 0.45559 time= 0.14241
Epoch: 0163 train_loss= 0.45544 time= 0.13700
Epoch: 0164 train_loss= 0.45538 ti