In [1]:
import argparse
import time

import numpy as np
import scipy.sparse as sp
import torch
import pandas as pd
import random

from model import GNP_Encoder, GNP_Decoder, InnerProductDecoder
from optimizer import loss_function3
from utils import load_data, mask_test_edges, preprocess_graph, get_roc_score, ct_split

import matplotlib.pyplot as plt

In [2]:
#%%

torch.set_default_dtype(torch.float32)

# Hyper-parameters are declared through argparse so the same code can be driven
# from a command line; inside a notebook every option simply takes its default.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='gcn_vae', help="models used")
parser.add_argument('--seed', type=int, default=1, help='Random seed.')
parser.add_argument('--epochs', type=int, default=500, help='Number of epochs to train.')
parser.add_argument('--hiddenEnc', type=int, default=32, help='Number of units in hidden layer of Encoder.')
parser.add_argument('--z_dim', type=int, default=32, help='Dimension of latent code Z.')
parser.add_argument('--hiddenDec', type=int, default=64, help='Number of units in hidden layer of Decoder.')
parser.add_argument('--outDimDec', type=int, default=32, help='Output Dimension of the Decoder.')
parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.')
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate (1 - keep probability).')
parser.add_argument('--dataset_str', type=str, default='cora', help='type of dataset.')
parser.add_argument('--n_z_samples', type=int, default=10, help='Number of Z samples')


# parse_known_args (instead of parse_args) ignores argv entries it does not
# recognise, such as the ones a Jupyter kernel launcher adds; they end up in
# `unknown` and are not used afterwards.
args,unknown = parser.parse_known_args()

# Seed every global RNG this notebook has imported, not only PyTorch's, so that
# code drawing from `random` or `np.random` is reproducible as well.
# torch.manual_seed stays last so the cell still displays the torch Generator.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

<torch._C.Generator at 0x1679c7abb10>

In [3]:
def sample_z(mu, std, n, z_dim=None):
    """Draw latent samples with the reparameterisation trick.

    Computes ``mu + std * eps`` where ``eps`` is standard-normal noise of shape
    ``(n, z_dim)``. The noise itself carries no gradient, so gradients reach
    ``mu`` and ``std`` only through the affine transform.

    Args:
        mu: Tensor of means; must broadcast against ``(n, z_dim)``.
        std: Tensor of scales; must broadcast against ``(n, z_dim)``. The noise
            is created with this tensor's dtype and on its device.
        n: Number of noise rows to draw.
        z_dim: Width of the noise. When ``None`` (the default) the value of the
            notebook-level ``args.z_dim`` is used, which is what this function
            always did before the parameter existed.

    Returns:
        The tensor ``mu + std * eps``.
    """
    if z_dim is None:
        # Fall back to the global configuration for existing callers.
        z_dim = args.z_dim
    # torch.randn replaces the deprecated Variable(std.data.new(...).normal_())
    # idiom; dtype/device are taken from std exactly as .new() did.
    eps = torch.randn(n, z_dim, dtype=std.dtype, device=std.device)
    return mu + std * eps

def KLD_gaussian(mu_q, std_q, mu_p, std_p):
    """Closed-form KL(q || p) for two diagonal Gaussians, summed over all elements.

    Args:
        mu_q, std_q: Mean and standard deviation tensors of q.
        mu_p, std_p: Mean and standard deviation tensors of p.

    Returns:
        A scalar tensor: 0.5 * sum(var_q/var_p + (mu_q-mu_p)^2/var_p
        + log(var_p/var_q) - 1).
    """
    # A tiny constant keeps both variances strictly positive so neither the
    # divisions nor the logarithm blow up when a std is exactly zero.
    jitter = 1e-16
    var_q = std_q**2 + jitter
    var_p = std_p**2 + jitter

    variance_ratio = var_q/var_p
    scaled_mean_gap = ((mu_q-mu_p)**2)/var_p
    log_variance_ratio = torch.log(var_p/var_q)

    per_element = variance_ratio + scaled_mean_gap + log_variance_ratio - 1.0
    return per_element.sum()*0.5
    
def gae_for(args,if_plot):
    """Train the encoder/decoder pair on one dataset and report link-prediction scores.

    Loads the graph named by ``args.dataset_str``, holds out validation and test
    edges with ``mask_test_edges``, then runs ``args.epochs`` full-graph
    optimisation steps. Each step encodes a context graph (from ``ct_split``)
    and the full training graph, samples ``args.n_z_samples`` latent codes from
    the full-graph encoder output, decodes them, and minimises
    ``loss_function3`` on the reconstructed adjacency plus the KL divergence
    between the two encoder outputs.

    The function reads the module-level ``device``, which must be defined
    before it is called.

    Args:
        args: Namespace providing ``dataset_str``, ``hiddenEnc``, ``z_dim``,
            ``hiddenDec``, ``outDimDec``, ``dropout``, ``lr``, ``epochs`` and
            ``n_z_samples``.
        if_plot: When truthy, the loss and validation-AP curves are plotted on
            every epoch whose 0-based index is a multiple of 49.

    Returns:
        ``(roc_score, ap_score)`` as returned by ``get_roc_score`` for the test
        edges, computed from the embedding of the last epoch's forward pass.

    Raises:
        ValueError: If ``args.epochs`` is smaller than 1, because no embedding
            would exist to evaluate.
    """
    if args.epochs < 1:
        # Previously this case fell through the loop and failed later with a
        # NameError on `hidden_emb`; fail early with a clear message instead.
        raise ValueError("args.epochs must be at least 1, got {}".format(args.epochs))

    adj, features = load_data(args.dataset_str)
    _, feat_dim = features.shape
    features = features.to(device)
    print("Using {} dataset".format(args.dataset_str))

    # Store original adjacency matrix (without diagonal entries) for later
    adj_orig = adj
    adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
    adj_orig.eliminate_zeros()
    adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
    adj = adj_train
    # NOTE(review): despite its name this is an alias of `adj`, not a copy.
    # That is only safe if ct_split does not modify its argument - confirm.
    adj_deep_copy = adj

    # Some preprocessing
    adj_norm = preprocess_graph(adj)
    adj_norm = adj_norm.to(device)

    # Reconstruction target: training adjacency with self-loops added.
    adj_label = adj_train + sp.eye(adj_train.shape[0])
    adj_label = torch.FloatTensor(adj_label.toarray())
    adj_label = adj_label.to(device)

    # Ratio of non-edges to edges, and a normalisation factor, both passed to the loss.
    pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
    norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
    # Loop-invariant: build the tensor once, on the same device as the predictions,
    # instead of re-creating a CPU tensor every epoch.
    pos_weight_tensor = torch.tensor(pos_weight, device=device)

    encoder = GNP_Encoder(feat_dim, args.hiddenEnc, args.z_dim, args.dropout)
    decoder = GNP_Decoder(feat_dim + args.z_dim, args.hiddenDec, args.outDimDec, args.dropout)
    innerDecoder = InnerProductDecoder(args.dropout, act=lambda x: x)

    encoder.to(device)
    decoder.to(device)
    innerDecoder.to(device)

    optimizer = torch.optim.Adam(list(decoder.parameters())+list(encoder.parameters()), args.lr)
    # gamma=1 leaves the learning rate unchanged; the scheduler is kept as a hook.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=1)

    train_loss = []
    val_ap = []
    for epoch in range(args.epochs):
        t = time.time()
        encoder.train()
        decoder.train()
        innerDecoder.train()

        # Re-seed NumPy's global RNG with the epoch index before building the
        # context graph, so the seed in effect at that point depends only on
        # the epoch (presumably ct_split draws from np.random - confirm).
        np.random.seed(epoch)
        adj_context  = ct_split(adj_deep_copy )
        adj_context_norm = preprocess_graph(adj_context)
        adj_context_norm = adj_context_norm.to(device)

        # Encoder output for the context graph and for the full training graph.
        c_z_mu, c_z_logvar = encoder(features, adj_context_norm)
        ct_z_mu, ct_z_logvar = encoder(features, adj_norm)

        # NOTE(review): the encoder outputs are named *logvar*, yet exp(.) of
        # them is used below wherever a standard deviation is expected; confirm
        # whether exp(0.5 * logvar) was intended. Behaviour is left unchanged.

        #Sample a batch of zs using reparam trick for MC estimation
        zs = sample_z(ct_z_mu, torch.exp(ct_z_logvar), args.n_z_samples)
        zs = zs.to(device)

        # Get the predictive distribution of y*
        mu, std = decoder(features, zs)

        # Average the decoder means over dim 1 and score node pairs from the result.
        emb = torch.mean(mu, dim = 1)
        pred_adj = innerDecoder(emb)

        #Compute loss and backprop
        reconstruction = loss_function3(preds=pred_adj, labels=adj_label, norm=norm, pos_weight=pos_weight_tensor)
        kl_term = KLD_gaussian(ct_z_mu, torch.exp(ct_z_logvar), c_z_mu, torch.exp(c_z_logvar))
        loss = reconstruction + kl_term

        optimizer.zero_grad()
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
        scheduler.step()

        # detach() is the supported way to drop the autograd graph (replaces .data).
        hidden_emb = emb.detach().cpu().numpy()
        roc_curr, ap_curr = get_roc_score(hidden_emb, adj_orig, val_edges, val_edges_false)

        train_loss.append(cur_loss)
        val_ap.append(ap_curr)
        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cur_loss),
              "val_ap=", "{:.5f}".format(ap_curr),
              "time=", "{:.5f}".format(time.time() - t)
              )

        if if_plot and (epoch%49 == 0):
            fig = plt.figure(figsize=(30,10))
            ax1 = fig.add_subplot(1,2,1)
            ax2 = fig.add_subplot(1,2,2)

            ax1.plot(train_loss, label='Training loss')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.legend(frameon=False)

            ax2.plot(val_ap, label='Validation Average Precision Score',color='Red')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('AP')
            ax2.legend(frameon=False)

            plt.show()
            # Release the figure so repeated plotting does not accumulate open figures.
            plt.close(fig)
    print("Optimization Finished!")

    # The test score uses the embedding from the final training forward pass.
    roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))
    return roc_score, ap_score

In [4]:
# If once is False, run the model 10 times with different random seeds.
# If plot is True, plot the learning curves on every epoch whose 0-based index
# is a multiple of 49 (the condition used inside gae_for).
once = True
plot = False

# gae_for reads this module-level device, so it must be set before the call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same value."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


if __name__ == '__main__':
    if once:
        # Re-seed right before the run so re-executing this cell is repeatable.
        _seed_everything(args.seed)
        gae_for(args, plot)
    else:
        test_roc = []
        test_ap = []
        for seed in range(10):
            print('Seed',seed)
            args.seed = seed
            # Seed all RNGs, not only torch, so that randomness drawn from
            # NumPy or `random` also follows the run's seed.
            _seed_everything(args.seed)
            roc_score, ap_score = gae_for(args,plot)
            test_roc.append(roc_score)
            test_ap.append(ap_score)
        print(test_roc)
        print('mean test AUC is',np.mean(test_roc),' std ', np.std(test_roc))
        print(test_ap)
        print('mean test AP is ',np.mean(test_ap), ' std ', np.std(test_ap))

cuda
Using cora dataset
Epoch: 0001 train_loss= 6.38497 val_ap= 0.50947 time= 0.51023
Epoch: 0002 train_loss= 1.87491 val_ap= 0.52450 time= 0.08577
Epoch: 0003 train_loss= 0.85954 val_ap= 0.55594 time= 0.07679
Epoch: 0004 train_loss= 0.94885 val_ap= 0.54348 time= 0.09774
Epoch: 0005 train_loss= 1.22695 val_ap= 0.54390 time= 0.07480
Epoch: 0006 train_loss= 1.24099 val_ap= 0.54373 time= 0.07779
Epoch: 0007 train_loss= 1.07747 val_ap= 0.53466 time= 0.08577
Epoch: 0008 train_loss= 0.97753 val_ap= 0.51889 time= 0.07679
Epoch: 0009 train_loss= 0.92885 val_ap= 0.51553 time= 0.07979
Epoch: 0010 train_loss= 0.83266 val_ap= 0.51335 time= 0.07879
Epoch: 0011 train_loss= 0.81389 val_ap= 0.52249 time= 0.08677
Epoch: 0012 train_loss= 0.80165 val_ap= 0.54310 time= 0.08278
Epoch: 0013 train_loss= 0.79805 val_ap= 0.55640 time= 0.09375
Epoch: 0014 train_loss= 0.79643 val_ap= 0.58052 time= 0.08178
Epoch: 0015 train_loss= 0.79125 val_ap= 0.58110 time= 0.07081
Epoch: 0016 train_loss= 0.78987 val_ap= 0.5950

Epoch: 0133 train_loss= 0.53687 val_ap= 0.82670 time= 0.07480
Epoch: 0134 train_loss= 0.53547 val_ap= 0.83052 time= 0.08677
Epoch: 0135 train_loss= 0.53554 val_ap= 0.83603 time= 0.08677
Epoch: 0136 train_loss= 0.54499 val_ap= 0.83852 time= 0.09176
Epoch: 0137 train_loss= 0.52824 val_ap= 0.83756 time= 0.07779
Epoch: 0138 train_loss= 0.52807 val_ap= 0.83614 time= 0.07081
Epoch: 0139 train_loss= 0.52605 val_ap= 0.83857 time= 0.07480
Epoch: 0140 train_loss= 0.52382 val_ap= 0.84039 time= 0.07879
Epoch: 0141 train_loss= 0.52255 val_ap= 0.84110 time= 0.08078
Epoch: 0142 train_loss= 0.51974 val_ap= 0.84525 time= 0.07879
Epoch: 0143 train_loss= 0.51805 val_ap= 0.84972 time= 0.08078
Epoch: 0144 train_loss= 0.51686 val_ap= 0.85296 time= 0.08477
Epoch: 0145 train_loss= 0.51417 val_ap= 0.85450 time= 0.08078
Epoch: 0146 train_loss= 0.51595 val_ap= 0.85793 time= 0.07679
Epoch: 0147 train_loss= 0.51131 val_ap= 0.85708 time= 0.07779
Epoch: 0148 train_loss= 0.51061 val_ap= 0.85465 time= 0.09475
Epoch: 0

Epoch: 0268 train_loss= 0.44698 val_ap= 0.91225 time= 0.08477
Epoch: 0269 train_loss= 0.44674 val_ap= 0.91263 time= 0.08677
Epoch: 0270 train_loss= 0.44746 val_ap= 0.91326 time= 0.07879
Epoch: 0271 train_loss= 0.44616 val_ap= 0.91292 time= 0.07779
Epoch: 0272 train_loss= 0.44673 val_ap= 0.91289 time= 0.07779
Epoch: 0273 train_loss= 0.44573 val_ap= 0.91345 time= 0.07779
Epoch: 0274 train_loss= 0.44585 val_ap= 0.91329 time= 0.09274
Epoch: 0275 train_loss= 0.44608 val_ap= 0.91426 time= 0.07580
Epoch: 0276 train_loss= 0.44524 val_ap= 0.91394 time= 0.07281
Epoch: 0277 train_loss= 0.44481 val_ap= 0.91379 time= 0.09275
Epoch: 0278 train_loss= 0.44455 val_ap= 0.91448 time= 0.07679
Epoch: 0279 train_loss= 0.44445 val_ap= 0.91473 time= 0.07380
Epoch: 0280 train_loss= 0.44463 val_ap= 0.91454 time= 0.08078
Epoch: 0281 train_loss= 0.44384 val_ap= 0.91499 time= 0.07480
Epoch: 0282 train_loss= 0.44376 val_ap= 0.91513 time= 0.07380
Epoch: 0283 train_loss= 0.44345 val_ap= 0.91512 time= 0.07979
Epoch: 0

Epoch: 0402 train_loss= 0.42615 val_ap= 0.92923 time= 0.09873
Epoch: 0403 train_loss= 0.42591 val_ap= 0.92928 time= 0.07380
Epoch: 0404 train_loss= 0.42586 val_ap= 0.92933 time= 0.07679
Epoch: 0405 train_loss= 0.42572 val_ap= 0.92933 time= 0.07480
Epoch: 0406 train_loss= 0.42566 val_ap= 0.92945 time= 0.07181
Epoch: 0407 train_loss= 0.42556 val_ap= 0.92961 time= 0.08078
Epoch: 0408 train_loss= 0.42547 val_ap= 0.92992 time= 0.08078
Epoch: 0409 train_loss= 0.42560 val_ap= 0.93020 time= 0.09275
Epoch: 0410 train_loss= 0.42523 val_ap= 0.92990 time= 0.08078
Epoch: 0411 train_loss= 0.42521 val_ap= 0.92968 time= 0.07480
Epoch: 0412 train_loss= 0.42562 val_ap= 0.92949 time= 0.07679
Epoch: 0413 train_loss= 0.42543 val_ap= 0.92931 time= 0.08278
Epoch: 0414 train_loss= 0.42510 val_ap= 0.93025 time= 0.08178
Epoch: 0415 train_loss= 0.42485 val_ap= 0.93034 time= 0.07380
Epoch: 0416 train_loss= 0.42521 val_ap= 0.93074 time= 0.07181
Epoch: 0417 train_loss= 0.42469 val_ap= 0.93027 time= 0.07580
Epoch: 0