In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
import torch
import torch_geometric

In [26]:
# VGAE
class distrGCN(torch.nn.Module):
  def __init__(self, node_dim : int, latent_dim : int = 16):
    super(distrGCN, self).__init__()
    self.node_dim = node_dim
    self.latent_dim = latent_dim
    self.sharedConv = torch_geometric.nn.GCNConv(self.node_dim, self.latent_dim)
    self.avgConv = torch_geometric.nn.GCNConv(self.latent_dim, self.node_dim)
    self.varConv = torch_geometric.nn.GCNConv(self.latent_dim, self.node_dim)

  def forward(self, x, edge_index):
    x = self.sharedConv(x, edge_index)
    x = torch.relu(x)
    avg = self.avgConv(x, edge_index)
    log_var = self.varConv(x, edge_index)
    eps = torch.normal(mean=torch.zeros(avg.shape[0]), std=torch.ones(avg.shape[0]))
    embd = avg + (eps.reshape(log_var.shape[0], 1) * torch.exp(0.5 * log_var))
    return embd, avg, log_var

In [4]:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [66]:
model = distrGCN(node_dim=data.x.shape[1])
nodes = data.x.shape[0]
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
pos_weight = float(nodes * nodes - data.edge_index.shape[1]) / data.edge_index.shape[1]
lossfunc = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
m = torch.nn.Sigmoid()
for i in range(200):
  optimizer.zero_grad()
  out, avg, log_var = model(data.x, data.edge_index)
  adj_model = out @ out.t()
  adj_model = adj_model.flatten()
  adj = torch_geometric.utils.to_dense_adj(data.edge_index)[0]
  adj = adj.flatten()
  norm = nodes * nodes / float(2 * (nodes * nodes - adj.sum()))
  loss =  lossfunc(adj_model, adj)
  KLD = (-0.5   * torch.sum(1 + log_var - avg.pow(2) - log_var.exp())) / (nodes * nodes)
  loss = loss + KLD
  loss.backward()
  optimizer.step()
  adj = adj.detach()
  adj_model = adj_model.detach()
  print('Epoch ', i, ' : ', KLD)
  print('ROC AUC : ', roc_auc_score(adj, adj_model))

Epoch  0  :  tensor(2.5555e-05, grad_fn=<DivBackward0>)
ROC AUC :  0.5040823860610839
Epoch  1  :  tensor(0.0005, grad_fn=<DivBackward0>)
ROC AUC :  0.4980659648236827
Epoch  2  :  tensor(0.0024, grad_fn=<DivBackward0>)
ROC AUC :  0.5040278869480066
Epoch  3  :  tensor(0.0055, grad_fn=<DivBackward0>)
ROC AUC :  0.5057491835226896
Epoch  4  :  tensor(0.0117, grad_fn=<DivBackward0>)
ROC AUC :  0.5058463457373666
Epoch  5  :  tensor(0.0248, grad_fn=<DivBackward0>)
ROC AUC :  0.5006766772987591
Epoch  6  :  tensor(0.0497, grad_fn=<DivBackward0>)
ROC AUC :  0.5161686561534755
Epoch  7  :  tensor(0.0867, grad_fn=<DivBackward0>)
ROC AUC :  0.5214554293462322
Epoch  8  :  tensor(0.1373, grad_fn=<DivBackward0>)
ROC AUC :  0.546985689135298
Epoch  9  :  tensor(0.1976, grad_fn=<DivBackward0>)
ROC AUC :  0.5456978051695544
Epoch  10  :  tensor(0.2684, grad_fn=<DivBackward0>)
ROC AUC :  0.5501992533422332
Epoch  11  :  tensor(0.3518, grad_fn=<DivBackward0>)
ROC AUC :  0.5326834216498901
Epoch  12  

In [28]:
out, avg, log_var = model(data.x, data.edge_index)

In [97]:
# GAE
# note the increased latent dimension.
from torch_geometric.nn import BatchNorm
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score
class GAE(torch.nn.Module):
    def __init__(self, node_dim: int, latent_dim: int = 64):
        super(GAE, self).__init__()
        self.conv1 = GCNConv(node_dim, latent_dim)
        self.conv2 = GCNConv(latent_dim, latent_dim)
        self.conv3 = GCNConv(latent_dim, node_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        ##x = self.conv2(x, edge_index)
        ##x = F.relu(x)
        embd = self.conv3(x, edge_index)
        x = F.relu(embd)
        return embd

In [None]:
model = GAE(node_dim=data.x.shape[1])
nodes = data.x.shape[0]
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
pos_weight = float(nodes * nodes - data.edge_index.shape[1]) / data.edge_index.shape[1]
lossfunc = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
for i in range(200):
  optimizer.zero_grad()
  out = model(data.x, data.edge_index)
  adj_model = out @ out.t()
  adj_model = adj_model.flatten()
  adj = torch_geometric.utils.to_dense_adj(data.edge_index)[0]
  adj = adj.flatten()
  #norm = nodes * nodes / float(2 * (nodes * nodes - adj.sum()))
  loss =  lossfunc(adj_model, adj)
  loss.backward()
  optimizer.step()
  with torch.no_grad():
    print('ROC AUC : ', roc_auc_score(adj, adj_model))

ROC AUC :  0.880918436341878
ROC AUC :  0.8492859534173747
ROC AUC :  0.7688261576572621
ROC AUC :  0.8039718674122641
ROC AUC :  0.9540884197534049
ROC AUC :  0.8953222402403664
ROC AUC :  0.7946112417257215
ROC AUC :  0.7510015060750661
ROC AUC :  0.735035041475864
ROC AUC :  0.7288434271621355
ROC AUC :  0.7257834510343031
ROC AUC :  0.7230532124063916
ROC AUC :  0.719777158453251
ROC AUC :  0.7162996657676344
ROC AUC :  0.71347041526445
ROC AUC :  0.7126673904649456
ROC AUC :  0.7145762046621031
ROC AUC :  0.7198685926787999
ROC AUC :  0.7275064452291259
ROC AUC :  0.7355174428714334
ROC AUC :  0.742427222705111
ROC AUC :  0.7475982057939721
ROC AUC :  0.7507004339029081
ROC AUC :  0.7534725784767734
ROC AUC :  0.7587836192473348
ROC AUC :  0.7680855868335429
ROC AUC :  0.7811970068604089
ROC AUC :  0.7953463836799372
ROC AUC :  0.8048807261206058
ROC AUC :  0.8104016797590032
ROC AUC :  0.8134133313433187
ROC AUC :  0.8151487241930958
ROC AUC :  0.8168418988657862
ROC AUC :  0.819