In [1]:
import argparse
import time

import numpy as np
import scipy.sparse as sp
import torch
import pandas as pd
import random

from model import GNP_Encoder, GNP_Decoder, InnerProductDecoder
from optimizer import loss_function3
from utils import load_data, preprocess_graph, get_roc_score, train_val_test_split, ct_split_nodes

import matplotlib.pyplot as plt

In [2]:
#%%

torch.set_default_dtype(torch.float32)

# Command-line hyper-parameters, one (flag, type, default, help) row per option.
_CLI_OPTIONS = (
    ('--seed', int, 0, 'Random seed.'),
    ('--epochs', int, 400, 'Number of epochs to train.'),
    ('--hiddenEnc', int, 32, 'Number of units in hidden layer of Encoder.'),
    ('--z_dim', int, 16, 'Dimension of latent code Z.'),
    ('--hiddenDec', int, 64, 'Number of units in hidden layer of Decoder.'),
    ('--outDimDec', int, 16, 'Output Dimension of the Decoder.'),
    ('--lr', float, 0.01, 'Initial learning rate.'),
    ('--dropout', float, 0.0, 'Dropout rate (1 - keep probability).'),
    ('--dataset_str', str, 'citeseer', 'type of dataset.'),
    ('--n_z_samples', int, 1, 'Number of Z samples'),
)

parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# parse_known_args (instead of parse_args) collects unrecognised command-line
# entries into `unknown` rather than erroring, so extra arguments handed to the
# interpreter/kernel do not abort the notebook.
args, unknown = parser.parse_known_args()

torch.manual_seed(args.seed)

<torch._C.Generator at 0x17ebafd3b30>

In [3]:
def sample_z(mu, std, n, z_dim=None):
    """Draw ``n`` latent samples with the reparameterisation trick: ``mu + std * eps``.

    ``eps`` is standard-normal noise of shape ``(n, z_dim)``, created with ``std``'s
    dtype and device and without gradient tracking, so gradients flow only into
    ``mu`` and ``std``.

    Args:
        mu: mean tensor; must broadcast against ``(n, z_dim)``.
        std: scale tensor; must broadcast against ``(n, z_dim)``.
        n: number of samples to draw.
        z_dim: latent dimension. Defaults to the global ``args.z_dim`` (the
            previous, implicit behaviour) when not given.

    Returns:
        Tensor equal to ``mu + std * eps``.
    """
    if z_dim is None:
        z_dim = args.z_dim
    # torch.autograd.Variable is a deprecated no-op wrapper; randn with an explicit
    # dtype/device yields the same detached N(0, 1) noise as std.data.new(...).normal_().
    eps = torch.randn(n, z_dim, dtype=std.dtype, device=std.device)
    return mu + std * eps

def KLD_gaussian(mu_q, std_q, mu_p, std_p):
    """Closed-form KL(q || p) between diagonal Gaussians, summed over all elements.

    Computes ``0.5 * sum(var_q/var_p + (mu_q-mu_p)^2/var_p + log(var_p/var_q) - 1)``.
    A tiny constant (1e-16) is added to each variance to avoid division by zero
    and log(0) when a standard deviation is exactly zero.

    Returns:
        A 0-dim tensor.
    """
    var_q = std_q.pow(2) + 1e-16
    var_p = std_p.pow(2) + 1e-16

    elementwise = var_q / var_p + (mu_q - mu_p).pow(2) / var_p + torch.log(var_p / var_q) - 1.0
    return elementwise.sum() * 0.5
    
def _decode_embeddings(decoder, features, zs):
    """Run the decoder in eval mode without autograd; return node embeddings as numpy.

    The decoder's first output is averaged over dim 1, the same reduction used
    during training.
    """
    decoder.eval()
    with torch.no_grad():
        mu, _ = decoder(features, zs)
        emb = torch.mean(mu, dim=1)
    return emb.cpu().numpy()


def _plot_learning_curves(eval_epochs, train_loss, val_ap):
    """Plot training loss and validation AP against the (1-based) epochs they were recorded at."""
    fig, (ax_loss, ax_ap) = plt.subplots(1, 2, figsize=(30, 10))

    ax_loss.plot(eval_epochs, train_loss, label='Training loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend(frameon=False)

    ax_ap.plot(eval_epochs, val_ap, label='Validation Average Precision Score', color='Red')
    ax_ap.set_xlabel('Epoch')
    ax_ap.set_ylabel('AP')
    ax_ap.legend(frameon=False)

    plt.show()
    # Release the figure so repeated plotting inside the training loop does not leak memory.
    plt.close(fig)


def NPGNN(args, if_plot, eval_every=19, plot_every=10):
    """Train the GNP encoder/decoder for link prediction and report test ROC / AP.

    Args:
        args: parsed options namespace; reads seed, epochs, hiddenEnc, z_dim,
            hiddenDec, outDimDec, lr, dropout, dataset_str and n_z_samples.
        if_plot: when truthy, plot the learning curves every ``plot_every`` epochs.
        eval_every: evaluate on the validation edges every this many epochs.
        plot_every: plotting period in epochs (only used when ``if_plot`` is truthy).

    Returns:
        ``(roc_score, ap_score)`` on the held-out test edges, computed from the
        final model parameters.

    Raises:
        ValueError: if ``args.epochs`` < 1 (nothing would be trained or evaluated).

    Relies on a module-level ``device`` (torch.device) defined before the call.
    """
    if args.epochs < 1:
        raise ValueError("args.epochs must be >= 1, got {}".format(args.epochs))

    adj, features = load_data(args.dataset_str)
    _, feat_dim = features.shape
    features = features.to(device)
    print("Using {} dataset".format(args.dataset_str))

    # Original adjacency matrix with its diagonal removed, kept for scoring.
    adj_orig = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
    adj_orig.eliminate_zeros()

    np.random.seed(args.seed)

    # The train/test node proportions for the few-shot inductive task are set in
    # train_val_test_split (utils.py). The original notes describe training on
    # ~30% of the nodes and ~10% of the links and predicting the remaining ~90%.
    # NOTE(review): the logged "training edges/total edges" ratio (~0.85 on
    # citeseer) does not obviously match the 10% figure; confirm against utils.py.
    features_train, adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = \
        train_val_test_split(features, adj)

    # Normalised training adjacency fed to the encoder.
    adj_norm = preprocess_graph(adj_train).to(device)

    # Reconstruction target: training adjacency plus self-loops, as a dense tensor.
    adj_label = torch.FloatTensor((adj_train + sp.eye(adj_train.shape[0])).toarray()).to(device)

    # Re-weighting for the sparse positive class in the reconstruction loss.
    n_sq = adj_train.shape[0] * adj_train.shape[0]
    pos_weight = float(n_sq - adj_train.sum()) / adj_train.sum()
    norm = n_sq / float((n_sq - adj_train.sum()) * 2)

    encoder = GNP_Encoder(feat_dim, args.hiddenEnc, args.z_dim, args.dropout)
    decoder = GNP_Decoder(feat_dim + args.z_dim, args.hiddenDec, args.outDimDec, args.dropout)
    innerDecoder = InnerProductDecoder(args.dropout, act=lambda x: x)
    encoder.to(device)
    decoder.to(device)
    innerDecoder.to(device)

    optimizer = torch.optim.Adam(list(decoder.parameters()) + list(encoder.parameters()), args.lr)
    # gamma=1 keeps the learning rate constant; the scheduler is effectively a no-op.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=1)

    train_loss = []
    val_ap = []
    eval_epochs = []
    t = time.time()
    for epoch in range(args.epochs):
        encoder.train()
        decoder.train()
        innerDecoder.train()

        # Resample the context subgraph each epoch; numpy is re-seeded with the
        # epoch index so the context sequence is reproducible.
        np.random.seed(epoch)
        features_context, adj_context = ct_split_nodes(features_train, adj_train)
        adj_context_norm = preprocess_graph(adj_context).to(device)

        c_z_mu, c_z_logvar = encoder(features_context, adj_context_norm)
        ct_z_mu, ct_z_logvar = encoder(features_train, adj_norm)

        # Sample a batch of zs using the reparameterisation trick.
        # NOTE(review): exp(logvar) is passed where a standard deviation is expected
        # (that would be exp(0.5 * logvar)); kept unchanged to preserve results —
        # confirm what the encoder's second output represents.
        zs = sample_z(ct_z_mu, torch.exp(ct_z_logvar), args.n_z_samples).to(device)

        # Predictive distribution over node embeddings, then edge logits.
        mu, _ = decoder(features_train, zs)
        emb = torch.mean(mu, dim=1)
        pred_adj = innerDecoder(emb)

        # Reconstruction loss + KL(target-set posterior || context-set posterior).
        recon_loss = loss_function3(preds=pred_adj, labels=adj_label, norm=norm,
                                    pos_weight=torch.tensor(pos_weight))
        kld = KLD_gaussian(ct_z_mu, torch.exp(ct_z_logvar), c_z_mu, torch.exp(c_z_logvar))
        loss = recon_loss + kld

        optimizer.zero_grad()
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
        scheduler.step()

        if epoch % eval_every == 0:
            hidden_emb = _decode_embeddings(decoder, features, zs)
            roc_curr, ap_curr = get_roc_score(hidden_emb, adj_orig, val_edges, val_edges_false)
            train_loss.append(cur_loss)
            val_ap.append(ap_curr)
            eval_epochs.append(epoch + 1)
            print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cur_loss),
                  "val_ap=", "{:.5f}".format(ap_curr),
                  "time=", "{:.5f}".format(time.time() - t))
            t = time.time()

        if if_plot and epoch % plot_every == 0:
            _plot_learning_curves(eval_epochs, train_loss, val_ap)
    print("Optimization Finished!")

    # Score the test edges with embeddings from the final parameters (and the last
    # sampled zs), not a possibly stale embedding from the last validation step.
    hidden_emb = _decode_embeddings(decoder, features, zs)
    roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))
    return roc_score, ap_score


In [4]:
# Run configuration:
#   once = True  -> train and evaluate a single model with args.seed.
#   once = False -> repeat for seeds 0..N_SEEDS-1 and report mean/std of the test scores.
#   plot = True  -> plot the learning curves periodically during training.
once = True
plot = False
N_SEEDS = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
if __name__ == '__main__':
    if once:
        # Re-seed here so re-running this cell on its own reproduces the same result.
        torch.manual_seed(args.seed)
        NPGNN(args, plot)
    else:
        base_seed = args.seed
        test_roc = []
        test_ap = []
        try:
            for seed in range(N_SEEDS):
                print('Seed', seed)
                args.seed = seed
                torch.manual_seed(args.seed)
                roc_score, ap_score = NPGNN(args, plot)
                test_roc.append(roc_score)
                test_ap.append(ap_score)
        finally:
            # Don't leave the shared args namespace pointing at the last seed tried.
            args.seed = base_seed
        print(test_roc)
        print('mean test AUC is', np.mean(test_roc), ' std ', np.std(test_roc))
        print(test_ap)
        print('mean test AP is ', np.mean(test_ap), ' std ', np.std(test_ap))

cuda
Using citeseer dataset
number of training edges: 3855
number of total edges: 4552
training edges/total edges: 0.8468804920913884
Epoch: 0001 train_loss= 2.21067 val_ap= 0.61941 time= 2.61572
Epoch: 0020 train_loss= 0.77462 val_ap= 0.74007 time= 40.01886
Epoch: 0039 train_loss= 0.63013 val_ap= 0.84761 time= 39.92859
Epoch: 0058 train_loss= 0.52169 val_ap= 0.92229 time= 39.20084
Epoch: 0077 train_loss= 0.48297 val_ap= 0.92837 time= 41.23771
Epoch: 0096 train_loss= 0.46567 val_ap= 0.92489 time= 41.76829
Epoch: 0115 train_loss= 0.44809 val_ap= 0.92707 time= 42.32794
Epoch: 0134 train_loss= 0.44442 val_ap= 0.92787 time= 41.06981
Epoch: 0153 train_loss= 0.43625 val_ap= 0.92756 time= 41.25296
Epoch: 0172 train_loss= 0.43462 val_ap= 0.92883 time= 41.53292
Epoch: 0191 train_loss= 0.43038 val_ap= 0.93125 time= 41.91390
Epoch: 0210 train_loss= 0.43461 val_ap= 0.92568 time= 42.46313
Epoch: 0229 train_loss= 0.42687 val_ap= 0.93181 time= 41.08811
Epoch: 0248 train_loss= 0.42192 val_ap= 0.93036 