In [1]:
import argparse
import time

import numpy as np
import scipy.sparse as sp
import torch
import pandas as pd
import random

from model import GNP_Encoder, GNP_Decoder, InnerProductDecoder
from optimizer import loss_function3
from utils import load_data, preprocess_graph, get_roc_score, train_val_test_split, ct_split_nodes

import matplotlib.pyplot as plt

In [2]:
#%%

def _build_parser():
    """Create the argument parser that holds every tunable hyperparameter."""
    # (flag, type, default, help) for each supported option, in display order.
    option_table = (
        ('--seed', int, 1, 'Random seed.'),
        ('--epochs', int, 500, 'Number of epochs to train.'),
        ('--hiddenEnc', int, 32, 'Number of units in hidden layer of Encoder.'),
        ('--z_dim', int, 32, 'Dimension of latent code Z.'),
        ('--hiddenDec', int, 64, 'Number of units in hidden layer of Decoder.'),
        ('--outDimDec', int, 32, 'Output Dimension of the Decoder.'),
        ('--lr', float, 0.01, 'Initial learning rate.'),
        ('--dropout', float, 0.0, 'Dropout rate (1 - keep probability).'),
        ('--dataset_str', str, 'cora', 'type of dataset.'),
        ('--n_z_samples', int, 10, 'Number of Z samples'),
    )
    built = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        built.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return built


torch.set_default_dtype(torch.float32)

parser = _build_parser()

# parse_known_args tolerates the extra arguments a Jupyter kernel is launched
# with; anything unrecognised ends up in `unknown` instead of raising.
args, unknown = parser.parse_known_args()

torch.manual_seed(args.seed)

<torch._C.Generator at 0x1e9e410db10>

In [3]:
def sample_z(mu, std, n):
    """Reparameterisation trick: draw ``n`` samples ``mu + std * eps``.

    ``eps`` is standard-normal noise of shape ``(n, latent_dim)`` allocated
    with the same dtype and device as ``std``. The noise itself carries no
    gradient, so gradients reach ``mu`` and ``std`` only through the affine
    transform.

    Args:
        mu: Tensor of means. Its trailing dimension (after broadcasting with
            ``std``) is taken as the latent size.
        std: Tensor of scales, broadcastable against ``mu``.
        n: Number of samples to draw.

    Returns:
        Tensor ``mu + std * eps``; for 1-D ``mu``/``std`` of length ``d`` the
        result has shape ``(n, d)``.
    """
    # Derive the latent size from the inputs so the function does not depend
    # on notebook-global configuration.
    latent_dim = torch.broadcast_shapes(mu.shape, std.shape)[-1]
    eps = std.new_empty(n, latent_dim).normal_()
    return mu + std * eps

def KLD_gaussian(mu_q, std_q, mu_p, std_p):
    """Closed-form KL(q || p) between two diagonal Gaussians.

    Both variances are offset by 1e-16 so the ratio and the logarithm stay
    finite when a standard deviation is exactly zero. The element-wise
    divergences are summed over every entry and returned as a 0-dim tensor.
    """
    var_q = std_q**2 + 1e-16
    var_p = std_p**2 + 1e-16

    variance_ratio = var_q / var_p
    scaled_mean_gap = ((mu_q - mu_p)**2) / var_p
    log_variance_ratio = torch.log(var_p / var_q)

    elementwise = variance_ratio + scaled_mean_gap + log_variance_ratio - 1.0
    return elementwise.sum() * 0.5
    
def _plot_learning_curves(train_loss, val_ap):
    """Show the training-loss and validation-AP histories side by side."""
    fig = plt.figure(figsize=(30,10))
    ax1 = fig.add_subplot(1,2,1)
    ax2 = fig.add_subplot(1,2,2)

    ax1.plot(train_loss, label='Training loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend(frameon=False)

    ax2.plot(val_ap, label='Validation Average Precision Score',color='Red')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('AP')
    ax2.legend(frameon=False)

    plt.show()
    # This is called repeatedly during training; release the figure so they
    # do not accumulate in memory.
    plt.close(fig)


def GNP(args,if_plot):
    """Train the GNP encoder/decoder for link prediction and score it on test edges.

    Each epoch a context sub-graph is drawn from the training graph, the
    encoder is run on both the context and the full training graph, latent
    samples are decoded into node embeddings, and the loss is the
    reconstruction term from ``loss_function3`` plus the KL divergence between
    the two encoder outputs. After the optimiser step, embeddings for all
    nodes are scored on the validation edges. The embeddings of the final
    epoch are scored on the test edges.

    Relies on the module-level ``device`` being defined before the call.

    Args:
        args: Namespace providing ``dataset_str``, ``seed``, ``epochs``,
            ``hiddenEnc``, ``z_dim``, ``hiddenDec``, ``outDimDec``, ``dropout``,
            ``lr`` and ``n_z_samples``.
        if_plot: When truthy, plot the learning curves every 10 epochs.

    Returns:
        Tuple ``(roc_score, ap_score)`` computed on the test edges.

    Raises:
        ValueError: If ``args.epochs`` is less than 1, since no embedding
            would exist to evaluate.
    """
    if args.epochs < 1:
        raise ValueError("args.epochs must be at least 1, got {}".format(args.epochs))

    adj, features = load_data(args.dataset_str)
    feat_dim = features.shape[1]
    features = features.to(device)
    print("Using {} dataset".format(args.dataset_str))

    # Store original adjacency matrix (without diagonal entries) for later
    adj_orig = adj
    adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
    adj_orig.eliminate_zeros()

    np.random.seed(args.seed)

    # You can change the proportion of nodes used for training and testing for
    # different few-shot inductive tasks by changing the testing-node proportion
    # in train_val_test_split (utils.py).
    # This notebook uses an example of training with 30% of the nodes and around
    # 10% of the links, to predict the remaining 90% of the links at test time.
    features_train, adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false \
    = train_val_test_split(features,adj)
    adj = adj_train

    # Some preprocessing
    adj_norm = preprocess_graph(adj)
    adj_norm = adj_norm.to(device)

    adj_label = adj + sp.eye(adj.shape[0])
    adj_label = torch.FloatTensor(adj_label.toarray())
    adj_label = adj_label.to(device)

    # Class re-weighting: ratio of absent to present entries in the training adjacency.
    pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
    norm = adj.shape[0] * adj.shape[0] / float((adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
    # Loop-invariant, so build the tensor once instead of once per epoch.
    pos_weight_tensor = torch.tensor(pos_weight)

    encoder = GNP_Encoder(feat_dim, args.hiddenEnc, args.z_dim, args.dropout)
    decoder = GNP_Decoder(feat_dim + args.z_dim, args.hiddenDec, args.outDimDec, args.dropout)
    innerDecoder = InnerProductDecoder(args.dropout, act=lambda x: x)
    encoder.to(device)
    decoder.to(device)
    innerDecoder.to(device)

    optimizer = torch.optim.Adam(list(decoder.parameters())+list(encoder.parameters()), args.lr)
    # gamma=1 keeps the learning rate constant; the scheduler is a placeholder.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=1)

    train_loss = []
    val_ap = []
    hidden_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        encoder.train()
        decoder.train()
        innerDecoder.train()

        # The context split is seeded by the epoch index, so the sequence of
        # context graphs is the same for every run regardless of args.seed.
        np.random.seed(epoch)
        features_context, adj_context  = ct_split_nodes(features_train, adj_train)
        adj_context_norm = preprocess_graph(adj_context)
        adj_context_norm = adj_context_norm.to(device)

        c_z_mu, c_z_logvar = encoder(features_context, adj_context_norm)
        ct_z_mu, ct_z_logvar = encoder(features_train, adj_norm)

        # NOTE(review): the encoder outputs are named *_logvar, yet exp() of
        # them is used directly as a standard deviation below (both for
        # sampling and for the KL term). Confirm against GNP_Encoder whether
        # exp(0.5 * logvar) was intended.
        ct_z_std = torch.exp(ct_z_logvar)
        c_z_std = torch.exp(c_z_logvar)

        #Sample a batch of zs using reparam trick.
        zs = sample_z(ct_z_mu, ct_z_std, args.n_z_samples)
        zs = zs.to(device)

        # Get the predictive distribution of y*
        mu, _ = decoder(features_train, zs)

        emb = torch.mean(mu, dim = 1)
        pred_adj = innerDecoder(emb)

        #Compute loss and backprop
        loss = loss_function3(preds=pred_adj, labels=adj_label, norm=norm, pos_weight=pos_weight_tensor) + KLD_gaussian(ct_z_mu, ct_z_std, c_z_mu, c_z_std)

        optimizer.zero_grad()
        loss.backward()
        cur_loss = loss.item()
        optimizer.step()
        scheduler.step()

        # Validation: embed all nodes with the updated decoder. No gradients
        # are needed here, so skip building an autograd graph.
        decoder.eval()
        with torch.no_grad():
            mu, _ = decoder(features, zs)
            emb = torch.mean(mu, dim = 1)
        hidden_emb = emb.cpu().numpy()
        roc_curr, ap_curr = get_roc_score(hidden_emb, adj_orig, val_edges, val_edges_false)

        train_loss.append(cur_loss)
        val_ap.append(ap_curr)
        print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(cur_loss),
              "val_ap=", "{:.5f}".format(ap_curr),
              "time=", "{:.5f}".format(time.time() - t))

        if if_plot and (epoch%10 == 0):
            _plot_learning_curves(train_loss, val_ap)
    print("Optimization Finished!")

    # The test score uses the embeddings produced in the final epoch.
    roc_score, ap_score = get_roc_score(hidden_emb, adj_orig, test_edges, test_edges_false)
    print('Test ROC score: ' + str(roc_score))
    print('Test AP score: ' + str(ap_score))
    return roc_score, ap_score


In [4]:
# If once is True, run the model a single time using args.seed.
# If once is False, run the model 10 times with random seeds 0..9 and report
# the mean and standard deviation of the test scores.
# If plot is True, plot the learning curves every 10 epochs.
once = True
plot = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
if __name__ == '__main__':
    if once:
        # Re-seed here so that re-running this cell reproduces the same run
        # instead of continuing from whatever RNG state earlier executions left.
        torch.manual_seed(args.seed)
        GNP(args,plot)
    else:
        test_roc = []
        test_ap = []
        for seed in range(10):
            print('Seed',seed)
            # Work on a per-run copy so the shared `args` keeps its configured
            # seed after the sweep finishes.
            run_args = argparse.Namespace(**{**vars(args), 'seed': seed})
            torch.manual_seed(run_args.seed)
            roc_score, ap_score = GNP(run_args,plot)
            test_roc.append(roc_score)
            test_ap.append(ap_score)
        print(test_roc)
        print('mean test AUC is',np.mean(test_roc),' std ', np.std(test_roc))
        print(test_ap)
        print('mean test AP is ',np.mean(test_ap), ' std ', np.std(test_ap))

cuda
Using cora dataset
number of training edges: 415
number of total edges: 5278
training edges/total edges: 0.0786282682834407
Epoch: 0001 train_loss= 6.38263 val_ap= 0.57465 time= 0.55137
Epoch: 0002 train_loss= 1.88367 val_ap= 0.59217 time= 0.10671
Epoch: 0003 train_loss= 0.95612 val_ap= 0.55384 time= 0.13763
Epoch: 0004 train_loss= 1.05494 val_ap= 0.56176 time= 0.13863
Epoch: 0005 train_loss= 1.33812 val_ap= 0.53963 time= 0.10574
Epoch: 0006 train_loss= 1.30587 val_ap= 0.52352 time= 0.11669
Epoch: 0007 train_loss= 1.10163 val_ap= 0.49325 time= 0.12766
Epoch: 0008 train_loss= 1.04751 val_ap= 0.49824 time= 0.10173
Epoch: 0009 train_loss= 1.04071 val_ap= 0.49741 time= 0.11968
Epoch: 0010 train_loss= 0.95398 val_ap= 0.49138 time= 0.11569
Epoch: 0011 train_loss= 0.94534 val_ap= 0.50172 time= 0.11170
Epoch: 0012 train_loss= 0.93209 val_ap= 0.52793 time= 0.13564
Epoch: 0013 train_loss= 0.93427 val_ap= 0.53982 time= 0.12566
Epoch: 0014 train_loss= 0.93473 val_ap= 0.59143 time= 0.12965
Epo

Epoch: 0133 train_loss= 0.46638 val_ap= 0.67260 time= 0.14262
Epoch: 0134 train_loss= 0.47267 val_ap= 0.68709 time= 0.14162
Epoch: 0135 train_loss= 0.46846 val_ap= 0.67940 time= 0.11968
Epoch: 0136 train_loss= 0.47996 val_ap= 0.67504 time= 0.14262
Epoch: 0137 train_loss= 0.46320 val_ap= 0.68959 time= 0.10871
Epoch: 0138 train_loss= 0.46209 val_ap= 0.68631 time= 0.11368
Epoch: 0139 train_loss= 0.46414 val_ap= 0.69383 time= 0.11868
Epoch: 0140 train_loss= 0.46129 val_ap= 0.69959 time= 0.12267
Epoch: 0141 train_loss= 0.46730 val_ap= 0.71059 time= 0.15459
Epoch: 0142 train_loss= 0.46279 val_ap= 0.71610 time= 0.11768
Epoch: 0143 train_loss= 0.46080 val_ap= 0.69592 time= 0.11270
Epoch: 0144 train_loss= 0.46093 val_ap= 0.69356 time= 0.10472
Epoch: 0145 train_loss= 0.45870 val_ap= 0.69042 time= 0.12367
Epoch: 0146 train_loss= 0.46244 val_ap= 0.67858 time= 0.11868
Epoch: 0147 train_loss= 0.46147 val_ap= 0.68068 time= 0.10472
Epoch: 0148 train_loss= 0.45646 val_ap= 0.68373 time= 0.11769
Epoch: 0

Epoch: 0266 train_loss= 0.42344 val_ap= 0.76745 time= 0.10572
Epoch: 0267 train_loss= 0.42132 val_ap= 0.74959 time= 0.12467
Epoch: 0268 train_loss= 0.42129 val_ap= 0.74540 time= 0.12666
Epoch: 0269 train_loss= 0.41981 val_ap= 0.74434 time= 0.14062
Epoch: 0270 train_loss= 0.42301 val_ap= 0.73446 time= 0.13564
Epoch: 0271 train_loss= 0.41941 val_ap= 0.73519 time= 0.13763
Epoch: 0272 train_loss= 0.42170 val_ap= 0.74738 time= 0.12666
Epoch: 0273 train_loss= 0.41920 val_ap= 0.76745 time= 0.11768
Epoch: 0274 train_loss= 0.42092 val_ap= 0.75413 time= 0.10671
Epoch: 0275 train_loss= 0.42313 val_ap= 0.76713 time= 0.09874
Epoch: 0276 train_loss= 0.41912 val_ap= 0.75952 time= 0.12167
Epoch: 0277 train_loss= 0.41951 val_ap= 0.74410 time= 0.12267
Epoch: 0278 train_loss= 0.41853 val_ap= 0.75710 time= 0.11569
Epoch: 0279 train_loss= 0.41925 val_ap= 0.75998 time= 0.14660
Epoch: 0280 train_loss= 0.41869 val_ap= 0.74410 time= 0.11968
Epoch: 0281 train_loss= 0.41785 val_ap= 0.74070 time= 0.11270
Epoch: 0

Epoch: 0400 train_loss= 0.40853 val_ap= 0.78711 time= 0.10971
Epoch: 0401 train_loss= 0.40780 val_ap= 0.79218 time= 0.10971
Epoch: 0402 train_loss= 0.40734 val_ap= 0.79157 time= 0.11569
Epoch: 0403 train_loss= 0.40717 val_ap= 0.79463 time= 0.10472
Epoch: 0404 train_loss= 0.40784 val_ap= 0.79938 time= 0.11868
Epoch: 0405 train_loss= 0.40798 val_ap= 0.79829 time= 0.10871
Epoch: 0406 train_loss= 0.40776 val_ap= 0.79829 time= 0.15060
Epoch: 0407 train_loss= 0.40702 val_ap= 0.79218 time= 0.10272
Epoch: 0408 train_loss= 0.40714 val_ap= 0.79463 time= 0.12965
Epoch: 0409 train_loss= 0.40800 val_ap= 0.79157 time= 0.14461
Epoch: 0410 train_loss= 0.40690 val_ap= 0.79463 time= 0.12666
Epoch: 0411 train_loss= 0.40695 val_ap= 0.79463 time= 0.12367
Epoch: 0412 train_loss= 0.40947 val_ap= 0.79077 time= 0.11669
Epoch: 0413 train_loss= 0.40979 val_ap= 0.79334 time= 0.12566
Epoch: 0414 train_loss= 0.40677 val_ap= 0.80483 time= 0.12965
Epoch: 0415 train_loss= 0.40712 val_ap= 0.80483 time= 0.13265
Epoch: 0