In [None]:
!pip install torch_geometric
!pip install torch_sparse
!pip install torch_scatter

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch_geometric.utils import dense_to_sparse
from torch_geometric.nn.conv import TransformerConv

#print(torch.cuda.is_available())

In [None]:
#device = torch.device('cuda')

In [2]:
# Ground-truth adjacency matrix: entry (i, j) is 1 where nodes i and j are linked.
A = torch.tensor([
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,1,1,0,0,1,0,1,0,0,0,0,0],
[0,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0],
[0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0],
[0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0],
[0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0],
[0,0,0,1,1,0,1,0,0,0,1,0,0,0,0,0],
[0,0,1,0,0,1,0,0,0,1,0,0,0,1,0,1],
[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],
[0,0,1,1,1,0,0,1,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
[0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0]
],dtype=torch.float32)

# Number of nodes, derived from the matrix so the two can never disagree.
NUM_NODE = A.shape[0]

# Convert the dense matrix to sparse form: a tuple of (edge indices, edge attributes).
data = dense_to_sparse(A)

#print(data) # edge indices and edge attributes
#print(data[0])
#print(data[1])

# node feature (does not work)
#B = torch.tensor([[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0],[10.0],[11.0],[12.0],[13.0],[14.0],[15.0],[16.0]])
#print(B)
#B = torch.arange(0,16)
#print(B)

# test case for the reconstruction loss (written for MSE; the criterion below is BCELoss)
#C = torch.clone(A)
#C[0,1] += 0.5
#print(C[0,1])

#criterion = nn.BCELoss(reduction='mean')
#loss = criterion(C, A)

#print(loss)
#print(0.5**2)

In [3]:
# Featureless node inputs: an identity matrix with one row per node of A.
B = torch.eye(A.shape[0])

# Ground truth with ones on the diagonal. Clamping keeps every entry at most 1
# even when this cell is executed more than once; a plain `A + B` would push
# the diagonal to 2, 3, ... on re-runs, which is not a valid BCELoss target.
A = torch.clamp(A + B, max=1.0)

In [4]:
def network_stat(network_y):
  """Return a 2-element float32 tensor of summary statistics of `network_y`.

  Element 0 is the sum of network_y[i, j] over all index pairs i < j.
  Element 1 is the sum of network_y[i, k] * network_y[j, k] over all
  index triples i < j < k.
  """
  size = network_y.shape[0]
  stats = torch.zeros(2, dtype=torch.float32)
  for first in range(size):
    # Only indices strictly above `first` are visited.
    for second in range(first + 1, size):
      stats[0] += network_y[first, second]
      # Third index runs strictly above `second`.
      for third in range(second + 1, size):
        stats[1] += network_y[first, third] * network_y[second, third]
  return stats

# Compute and display the summary statistics of the ground-truth matrix A.
observed_network_stat = network_stat(A)
print(observed_network_stat)

tensor([15., 10.])


In [5]:
import torch
from torch_geometric.data import DataLoader
from tqdm import tqdm
import numpy as np
#import mlflow.pytorch
#from utils import (count_parameters, gvae_loss, slice_edge_type_from_edge_feats, slice_atom_type_from_node_feats)
#from gvae import GVAE
#from config import DEVICE as device

import torch
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import Set2Set
from torch_geometric.nn import BatchNorm
from tqdm import tqdm

from torch_geometric.nn import GAE, VGAE, GCNConv

In [6]:
class GVAE(nn.Module):
    """Graph variational autoencoder: GCN encoder with an inner-product decoder.

    The encoder maps node features to a per-node latent mean and log standard
    deviation; the decoder turns sampled latent vectors into a matrix of edge
    probabilities via sigmoid(Z @ Z^T).

    Args:
        feature_size: number of input features per node.
    """

    def __init__(self, feature_size):
        super().__init__()
        self.encoder_embedding_size = 128
        self.latent_size = 10
        self.decoder_embedding_size = 128

        # Encoder: one shared GCN layer followed by two heads (mean and log-std).
        self.conv1 =       GCNConv(feature_size                   , 2 * self.encoder_embedding_size)

        self.conv_mean =   GCNConv(2 * self.encoder_embedding_size, self.latent_size)
        self.conv_logstd = GCNConv(2 * self.encoder_embedding_size, self.latent_size)

        # NOTE: decode() below does not use these two layers. They are kept so
        # the set of registered parameters stays the same as before.
        self.decode_conv1 = GCNConv(self.latent_size, 2 * self.decoder_embedding_size)
        self.decode_conv2 = GCNConv(2 * self.decoder_embedding_size, 2 * self.decoder_embedding_size)

    def reparameterize(self, mu, logstd):
        """Sample latent vectors as mu + eps * exp(logstd) with eps ~ N(0, 1).

        The noise is drawn with the same shape, dtype and device as `mu`, so
        the method works for any number of nodes and on any device.

        Returns:
            Tensor with the same shape as `mu` (nodes x latent_size).
        """
        gaussian_noise = torch.randn_like(mu)
        output = mu + gaussian_noise * torch.exp(logstd)
        return output

    def encode(self, x, edge_index):
        """Return the latent mean and log standard deviation for every node."""
        x = self.conv1(x, edge_index).tanh()
        mu = self.conv_mean(x, edge_index)
        logstd = self.conv_logstd(x, edge_index)
        return mu, logstd

    def decode(self, Z, edge_index):
        """Inner-product decoder: edge probabilities sigmoid(Z @ Z^T).

        `edge_index` is accepted for interface compatibility but is not used.
        """
        #x = self.decode_conv1(Z, edge_index).tanh()
        A_pred = torch.sigmoid(torch.matmul(Z, Z.t()))
        return A_pred

    def forward(self, x, edge_index):
        """Encode, sample and decode; returns (A_pred, mu, logstd)."""
        mu, logstd = self.encode(x, edge_index)
        Z = self.reparameterize(mu, logstd)
        A_pred = self.decode(Z, edge_index)
        return A_pred, mu, logstd

 

def kl_loss(mu, logstd, num_nodes=None):
    """
    Closed formula of the KL divergence for normal distributions.

    Computes -0.5 * mean_over_rows(sum_over_latent(1 + 2*logstd - mu^2 - exp(logstd)^2))
    and divides the result by the number of nodes.

    Args:
        mu: latent means, one row per node.
        logstd: latent log standard deviations, same shape as `mu`.
        num_nodes: normalisation constant; defaults to the number of rows of
            `mu`, so the function does not depend on any module-level global.

    Returns:
        Scalar tensor, clamped to at most 1000.
    """
    MAX_LOGSTD = 10
    if num_nodes is None:
        # max(..., 1) avoids a ZeroDivisionError for an empty input.
        num_nodes = max(mu.shape[0], 1)
    logstd = logstd.clamp(max=MAX_LOGSTD)
    kl_div = -0.5 / num_nodes * torch.mean(torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1))
    # Limit numeric errors
    kl_div = kl_div.clamp(max=1000)
    return kl_div


def gvae_loss(A_pred, mu, logstd, kl_beta, num_nodes, ground_truth, criterion, include_kl=False):
    """
    Calculates the loss for the graph variational autoencoder.

    By default only the reconstruction term `criterion(A_pred, ground_truth)`
    is returned, which matches the previous behaviour (the KL term was
    computed but never added). Pass `include_kl=True` to return
    `recon_loss + kl_beta * kl_loss(mu, logstd)` instead.

    Args:
        A_pred: predicted edge probabilities.
        mu, logstd: latent parameters; only read when `include_kl` is True.
        kl_beta: weight of the KL term; only read when `include_kl` is True.
        num_nodes: currently unused, kept for interface compatibility.
        ground_truth: target matrix passed to `criterion`.
        criterion: callable `(prediction, target) -> scalar loss`.
        include_kl: whether to add the weighted KL divergence.

    Returns:
        Scalar loss tensor.
    """
    #recon = triu_to_dense(triu_logits, num_nodes)
    recon_loss = criterion(A_pred, ground_truth)

    if not include_kl:
        # Skip the KL computation entirely when it would be discarded.
        return recon_loss

    # KL Divergence
    kl_divergence = kl_loss(mu, logstd)
    return recon_loss + kl_beta * kl_divergence

In [7]:
# Fix the RNG seed so weight initialisation and the sampled latent noise are
# reproducible across "Restart Kernel -> Run All".
torch.manual_seed(0)

# Feature size is taken from the feature matrix instead of being hard-coded.
model = GVAE(feature_size=B.shape[1])#.to(device)
criterion = nn.BCELoss(reduction='sum')

loss_fn = gvae_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.0001)
kl_beta = 0.2

# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = data[0]#.to(device) ##
gt = A#.to(device) ##
feature = B#.to(device)

In [8]:
for epoch in range(10000):
    # Clear gradients left over from the previous iteration.
    optimizer.zero_grad()

    # Forward pass and loss on the full graph.
    A_pred, mu, logstd = model(feature, input)
    loss = loss_fn(A_pred, mu, logstd, kl_beta, NUM_NODE, gt, criterion)

    # Report the loss every 1000 epochs.
    if epoch % 1000 == 0:
        print(loss)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

tensor(330.4344, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(120.7233, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(40.4593, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(22.5825, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(12.2676, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(5.6960, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.0680, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.0171, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.0071, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.0039, grad_fn=<BinaryCrossEntropyBackward0>)


In [10]:
# Show the encoder outputs of the last training iteration: the latent means
# and the latent standard deviations (exponential of the log-std).
print(mu)
print(logstd.exp())

tensor([[  2.4740,  -6.5603,   5.7162,  -2.5699,  -3.1541,  -4.7499,  -1.9895,
          14.3072,   9.1428,   9.9966],
        [ -2.6300, -15.0624,   0.0530,  -8.0496,   8.6574,   4.5892,  -4.9303,
          -7.6423,  -3.4009,   5.2160],
        [ -0.6463,  -2.8073,   1.7221,   4.9953,   1.9110,  -3.9735,   2.9545,
           1.8041,  -5.3203,  -1.1885],
        [ -4.6889,  -0.1911,   1.1308,  -4.3093, -12.3762,   2.1420,  -3.3471,
          -3.1099,   0.1136,  -1.4036],
        [ -4.1473,  -3.7976,   1.1838,   2.8799,   0.9668,  -6.4445,   7.7516,
           0.7787,  -4.0972,  -1.3141],
        [  2.6597,  -0.3638,   1.8245,   4.1688,   0.1132,   0.5494,  -3.9931,
           1.4450,  -4.2477,  -0.8104],
        [ -3.6965,   1.1316,   1.0813,  -5.6725, -14.5061,   4.6445,  -6.8686,
          -3.7793,   1.2154,  -1.2450],
        [ -5.5698,  -1.6489,   1.1464,  -2.6278,  -9.4131,  -0.8015,   0.7271,
          -2.2426,  -1.0886,  -1.5252],
        [  2.4965,   1.9766,   0.4488,   1.3810,

In [20]:
#sampled_Z = torch.randn(16, 10)
# Sample a network from the fitted latent distribution using the trained
# `model` rather than building fresh, untrained GVAE instances.
# No gradients are needed for sampling, so tracking is disabled.
with torch.no_grad():
    sampled_Z = model.reparameterize(mu, logstd)
    sampled_A = model.decode(sampled_Z, input)
print(network_stat(sampled_A))

tensor([15.0008, 10.0009], grad_fn=<CopySlices>)


In [17]:
# Binarise the predicted edge probabilities at 0.5, then print the summary
# statistics of the prediction followed by those of the ground truth.
A_pred = A_pred.gt(0.5).to(torch.float32)
#A_pred.fill_diagonal_(0)
print(network_stat(A_pred))
print(network_stat(A))

tensor([15., 10.])
tensor([15., 10.])
