In [None]:
# TODO: Get Plots for ROC
# TODO: Step through Code
# TODO: Look at embeddings
# TODO: Graph kernel similarity in loss

In [3]:
import os.path as osp

import argparse

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from vgae import GAE, VGAE

In [4]:
# --- Dataset ---------------------------------------------------------------
# Keep the dataset's name in its own variable so that `dataset` only ever
# refers to the loaded dataset object.
dataset_name = 'Cora'
path = osp.join('..', 'data', 'geometric', dataset_name.upper())

# Pass the transform by keyword: in torch_geometric releases where
# `Planetoid` takes `split` as its third positional parameter, a positional
# transform would be interpreted as the split name.
dataset = Planetoid(root=path, name=dataset_name,
                    transform=T.NormalizeFeatures())
data = dataset[0]  # first (index 0) graph of the dataset

# --- Model selection -------------------------------------------------------
# parse_known_args() (rather than parse_args()) returns unrecognised
# arguments in `unknown` instead of exiting, so arguments the notebook
# kernel was launched with do not abort the cell.
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='VGAE')
args, unknown = parser.parse_known_args()

# Maps the --model name to the class imported from `vgae`.
# NOTE(review): the Encoder and train() defined later in this notebook are
# written for the variational case (two encoder outputs, KL term) -- confirm
# whether 'GAE' is still meant to be a supported choice.
assert args.model in ['GAE', 'VGAE']
kwargs = {'GAE': GAE, 'VGAE': VGAE}

In [7]:
class Encoder(torch.nn.Module):
    """GCN encoder with one shared layer and two output heads.

    A first ``GCNConv`` maps ``in_channels`` to ``2 * out_channels`` and is
    followed by a ReLU. Two further ``GCNConv`` layers (``conv_mu`` and
    ``conv_logvar``) each map that hidden representation to ``out_channels``.

    Args:
        in_channels: number of input channels of the first layer.
        out_channels: number of output channels of each head; the hidden
            layer is twice this wide.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels

        # NOTE(review): every layer is built with cached=True -- presumably
        # this is only valid while the same edge_index is passed on every
        # call; confirm against the GCNConv documentation.
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=True)
        self.conv_mu = GCNConv(hidden_channels, out_channels, cached=True)
        self.conv_logvar = GCNConv(hidden_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        """Return the pair ``(conv_mu(h), conv_logvar(h))``.

        ``h`` is ``relu(conv1(x, edge_index))``; both heads receive the same
        ``h`` and the same ``edge_index``.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        mu = self.conv_mu(hidden, edge_index)
        logvar = self.conv_logvar(hidden, edge_index)
        return mu, logvar


channels = 16  # output width of each encoder head (hidden layer is 2x this)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNG before anything random happens (weight initialisation
# here, and presumably the edge split and any sampling during training) so
# that a fresh "Restart & Run All" reproduces the same numbers.
SEED = 12345
torch.manual_seed(SEED)

# `kwargs` maps the --model name to the model class; the chosen class wraps
# the encoder.
model = kwargs[args.model](Encoder(dataset.num_features, channels)).to(device)

# Clear the labels and node masks; the code below only reads `data.x` and the
# edge-index attributes.
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = model.split_edges(data)

# NOTE(review): the encoder is fed `data.edge_index`. If `split_edges` leaves
# that attribute holding the full graph, the validation/test edges are visible
# to the encoder (label leakage) -- confirm, and consider encoding with
# `data.train_pos_edge_index` instead.
x, edge_index = data.x.to(device), data.edge_index.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(kl_weight=0.001):
    """Run one full-batch optimisation step and return its loss.

    The loss is ``model.recon_loss`` evaluated on the training edges plus
    ``kl_weight`` times ``model.kl_loss()``. Uses the notebook-level
    ``model``, ``optimizer``, ``x``, ``edge_index`` and ``data``.

    Args:
        kl_weight: multiplier applied to the KL term. Defaults to 0.001,
            the value that used to be hard-coded.

    Returns:
        float: the total loss of this step, computed before the parameter
        update, so that callers can log or plot it.
    """
    model.train()
    optimizer.zero_grad()

    z = model.encode(x, edge_index)

    loss = model.recon_loss(z, data.train_pos_edge_index)
    # kl_loss() is called on every step, so `model` must provide it.
    loss = loss + kl_weight * model.kl_loss()

    loss.backward()
    optimizer.step()

    return loss.item()


def test(pos_edge_index, neg_edge_index):
    """Evaluate the current model on the given positive/negative edges.

    Switches the model to eval mode, encodes the notebook-level ``x`` and
    ``edge_index`` with gradient tracking disabled, and passes the result
    together with both edge sets to ``model.test``.

    Returns:
        Whatever ``model.test`` returns; the callers in this notebook unpack
        it as ``(auc, ap)``.
    """
    model.eval()
    with torch.no_grad():
        embeddings = model.encode(x, edge_index)

    # Scoring happens outside the no_grad block, as before.
    metrics = model.test(embeddings, pos_edge_index, neg_edge_index)
    return metrics


NUM_EPOCHS = 200  # total number of training steps
LOG_EVERY = 10    # report validation metrics every this many epochs

for epoch in range(1, NUM_EPOCHS + 1):
    train()
    # Validation metrics are only reported every LOG_EVERY epochs, so they
    # are only computed then rather than after every single step.
    if epoch % LOG_EVERY == 0:
        auc, ap = test(data.val_pos_edge_index, data.val_neg_edge_index)
        print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

# Final evaluation on the held-out test edges.
auc, ap = test(data.test_pos_edge_index, data.test_neg_edge_index)
print('Test AUC: {:.4f}, Test AP: {:.4f}'.format(auc, ap))

Epoch: 010, AUC: 0.7823, AP: 0.7837
Epoch: 020, AUC: 0.7806, AP: 0.7810
Epoch: 030, AUC: 0.7789, AP: 0.7774
Epoch: 040, AUC: 0.7751, AP: 0.7749
Epoch: 050, AUC: 0.7712, AP: 0.7758
Epoch: 060, AUC: 0.7642, AP: 0.7765
Epoch: 070, AUC: 0.7663, AP: 0.7801
Epoch: 080, AUC: 0.7957, AP: 0.8023
Epoch: 090, AUC: 0.8113, AP: 0.8103
Epoch: 100, AUC: 0.8171, AP: 0.8161
Epoch: 110, AUC: 0.8243, AP: 0.8305
Epoch: 120, AUC: 0.8254, AP: 0.8324
Epoch: 130, AUC: 0.8322, AP: 0.8366
Epoch: 140, AUC: 0.8507, AP: 0.8512
Epoch: 150, AUC: 0.8840, AP: 0.8892
Epoch: 160, AUC: 0.9040, AP: 0.9079
Epoch: 170, AUC: 0.9096, AP: 0.9131
Epoch: 180, AUC: 0.9131, AP: 0.9176
Epoch: 190, AUC: 0.9162, AP: 0.9212
Epoch: 200, AUC: 0.9177, AP: 0.9226
Test AUC: 0.9073, Test AP: 0.9021
