In [None]:
!pip install torch_geometric

In [2]:
import torch
import torch_geometric
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score

In [3]:
# VGAE
class Encoder(torch.nn.Module):
  def __init__(self, node_dim : int, latent_var_dim : int = 16, latent_dim : int = 16):
    super(Encoder, self).__init__()
    self.node_dim = node_dim
    self.latent_dim = latent_dim
    self.latent_var_dim = latent_var_dim
    self.sharedConv = torch_geometric.nn.GCNConv(self.node_dim, self.latent_dim)
    self.avgConv = torch_geometric.nn.GCNConv(self.latent_dim, self.latent_var_dim)
    self.varConv = torch_geometric.nn.GCNConv(self.latent_dim, self.latent_var_dim)

  def forward(self, x, edge_index):
    x = self.sharedConv(x, edge_index)
    x = torch.relu(x)
    avg = self.avgConv(x, edge_index)
    avg = torch_geometric.nn.global_mean_pool(avg, batch=None)
    log_var = self.varConv(x, edge_index)
    log_var = torch_geometric.nn.pool.global_mean_pool(log_var, batch=None)
    eps = torch.normal(mean=torch.zeros(avg.shape[0]), std=torch.ones(avg.shape[0]))
    embd = avg + (eps.reshape(log_var.shape[0], 1) * torch.exp(0.5 * log_var))
    return embd, avg, log_var

class Decoder(torch.nn.Module):
  def __init__(self, num_nodes : int, latent_var_dim : int, latent_dim = 16):
    super(Decoder, self).__init__()
    self.num_nodes = num_nodes
    self.fc1 = torch.nn.Linear(latent_var_dim, latent_dim)
    self.fc2 = torch.nn.Linear(latent_dim, num_nodes * num_nodes)

  def forward(self, latent):
    x = self.fc1(latent)
    x = torch.relu(x)
    x = self.fc2(x)
    x = x.reshape(self.num_nodes, self.num_nodes)
    x = F.sigmoid(x)
    return x

class VGAE(torch.nn.Module):
  def __init__(self, node_dim : int, num_nodes : int, latent_var_dim : int  = 16, latent_dim : int = 16):
    super(VGAE, self).__init__()
    self.encoder = Encoder(node_dim, latent_var_dim, latent_dim)
    self.decoder = Decoder(num_nodes, latent_var_dim, latent_dim)

  def forward(self, x, edge_index):
    embd, avg, log_var = self.encoder(x, edge_index)
    print(embd.shape)
    x = self.decoder(embd)
    return x, avg, log_var

In [4]:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

In [None]:
model = VGAE(node_dim=data.x.shape[1], num_nodes=data.x.shape[0], latent_var_dim=128, latent_dim=16)
nodes = data.x.shape[0]
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
pos_weight = float(nodes * nodes - data.edge_index.shape[1]) / data.edge_index.shape[1]
lossfunc = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))

for i in range(200):
  optimizer.zero_grad()
  adj_model, avg, log_var = model(data.x, data.edge_index)
  adj_model = adj_model.flatten()
  adj = torch_geometric.utils.to_dense_adj(data.edge_index)[0]
  adj = adj.flatten()
  norm = nodes * nodes / float(2 * (nodes * nodes - adj.sum()))
  loss =  norm * lossfunc(adj_model, adj)
  KLD = (-0.5   * torch.sum(1 + log_var - avg.pow(2) - log_var.exp())) / (nodes)
  loss = loss + KLD
  loss.backward()
  optimizer.step()
  adj = adj.detach()
  adj_model = adj_model.detach()
  print('Epoch ', i, ' : ', KLD)
  print('ROC AUC : ', roc_auc_score(adj, adj_model))

torch.Size([1, 128])
Epoch  0  :  tensor(1.5029e-05, grad_fn=<DivBackward0>)
ROC AUC :  0.5047393142911228
torch.Size([1, 128])
Epoch  1  :  tensor(6.4246e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.5046181828226488
torch.Size([1, 128])
Epoch  2  :  tensor(5.0501e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.504362025968981
torch.Size([1, 128])
Epoch  3  :  tensor(6.9701e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.5099879343724412
torch.Size([1, 128])
Epoch  4  :  tensor(7.9722e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.510640523100737
torch.Size([1, 128])
Epoch  5  :  tensor(7.5232e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.49930506697963084
torch.Size([1, 128])
Epoch  6  :  tensor(6.8346e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.5140592644112534
torch.Size([1, 128])
Epoch  7  :  tensor(5.9771e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.5115737247052865
torch.Size([1, 128])
Epoch  8  :  tensor(5.1543e-06, grad_fn=<DivBackward0>)
ROC AUC :  0.5145239388729249
torch.Size([1, 128])
Epoch  9  :  tens

In [None]:
out, avg, log_var = model(data.x, data.edge_index)