In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch.nn.modules.module import Module
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GAE, VGAE
from torch_geometric.utils import train_test_split_edges
import torch.optim as optim
import networkx as nx
from pathlib import Path
import sys

# Resolve the project root as the directory two levels above the current
# working directory, then derive the data and model folders from it.
project_path = Path.cwd().parent.parent
data_path = project_path / 'data'
model_path = project_path / 'model'

In [68]:
# Registry mapping the model name stored in `args.model` to the
# auto-encoder wrapper class that is instantiated around the encoder.
kwargs = dict(GAE=GAE, VGAE=VGAE)

class Args:
    """Minimal configuration holder for the experiment.

    Attributes:
        model: name of the auto-encoder variant to build.
        dataset: name of the dataset to load.
    """

    def __init__(self, model, dataset):
        self.model, self.dataset = model, dataset

class Encoder(torch.nn.Module):
    """Two-layer GCN encoder for a graph auto-encoder (GAE or VGAE).

    The first layer maps ``in_channels`` to ``2 * out_channels`` and is
    followed by a ReLU. The second stage depends on the model type:

    * ``'GAE'``  - a single ``conv2`` layer producing the embedding.
    * ``'VGAE'`` - two parallel layers, ``conv_mu`` and ``conv_logvar``,
      whose outputs are returned as a pair.

    Args:
        in_channels: size of the input node features.
        out_channels: size of the produced embedding.
        model_type: ``'GAE'`` or ``'VGAE'``. When omitted, the value of the
            notebook-level ``args.model`` at construction time is used, which
            keeps existing ``Encoder(in_channels, out_channels)`` calls
            working.

    Raises:
        ValueError: if the resolved model type is not supported. Previously
            this produced an encoder whose ``forward`` silently returned
            ``None``.
    """

    _SUPPORTED_MODELS = ('GAE', 'VGAE')

    def __init__(self, in_channels, out_channels, model_type=None):
        super(Encoder, self).__init__()
        if model_type is None:
            # Backward-compatible default: read the notebook configuration
            # once, here, instead of on every forward pass.
            model_type = args.model
        if model_type not in self._SUPPORTED_MODELS:
            raise ValueError(
                'Unsupported model type {!r}; expected one of {}'.format(
                    model_type, self._SUPPORTED_MODELS))
        # Remember the architecture that was actually built so forward()
        # cannot drift out of sync if the global `args` is changed later.
        self.model_type = model_type

        # NOTE(review): cached=True relies on the same edge_index being passed
        # on every call (see the GCNConv documentation). That holds in this
        # notebook, where encode() always receives train_pos_edge_index.
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
        if model_type == 'GAE':
            self.conv2 = GCNConv(2 * out_channels, out_channels, cached=True)
        else:  # 'VGAE'
            self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
            self.conv_logvar = GCNConv(2 * out_channels, out_channels,
                                       cached=True)

    def forward(self, x, edge_index):
        """Encode node features.

        Returns:
            For ``'GAE'`` the output of ``conv2``; for ``'VGAE'`` the tuple
            ``(conv_mu(h, edge_index), conv_logvar(h, edge_index))`` where
            ``h`` is the ReLU-activated output of ``conv1``.
        """
        x = F.relu(self.conv1(x, edge_index))
        if self.model_type == 'GAE':
            return self.conv2(x, edge_index)
        return self.conv_mu(x, edge_index), self.conv_logvar(x, edge_index)

# Experiment configuration: which auto-encoder variant to train and on
# which Planetoid dataset.
args = Args(model='VGAE', dataset='Cora')

# Load the dataset from `data_path`; NormalizeFeatures is applied as the
# dataset transform. The first (index 0) entry is the graph used below.
dataset = Planetoid(
    root=str(data_path),
    name=args.dataset,
    transform=T.NormalizeFeatures(),
)
data = dataset[0]

In [69]:
# Seed torch before anything stochastic in this cell (weight initialisation
# and the random edge split) so that Restart & Run All is reproducible.
SEED = 42
torch.manual_seed(SEED)

# Size of the embedding produced by the encoder.
channels = 2
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Wrap the encoder in the auto-encoder class selected by `args.model`.
model = kwargs[args.model](Encoder(dataset.num_features, channels)).to(dev)

# Re-fetch the graph here rather than reusing the `data` left over from an
# earlier cell: the split below rebinds `data`, so re-running this cell would
# otherwise feed an already-split object back into train_test_split_edges.
data = dataset[0]

# Node-level masks and labels are not used for link prediction; drop them
# before splitting the edges.
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)

# Only the features and the training edges are moved to the device here.
x, train_pos_edge_index = data.x.to(dev), data.train_pos_edge_index.to(dev)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [70]:
# Display the assembled model: the bare last expression of the cell is
# rendered as the cell output, listing the encoder layers and the decoder.
model

VGAE(
  (encoder): Encoder(
    (conv1): GCNConv(1433, 4)
    (conv_mu): GCNConv(4, 2)
    (conv_logvar): GCNConv(4, 2)
  )
  (decoder): InnerProductDecoder()
)

In [71]:
def train():
    """Run a single optimisation step on the training edges.

    Encodes the node features with the training edges, computes the
    reconstruction loss on those same edges and, when ``args.model`` is
    ``'VGAE'``, adds the model's KL term scaled by ``1 / data.num_nodes``.
    The gradients are then back-propagated and the optimizer is stepped.

    Uses the notebook-level ``model``, ``optimizer``, ``x``,
    ``train_pos_edge_index``, ``args`` and ``data``. Returns nothing.
    """
    model.train()
    optimizer.zero_grad()

    latent = model.encode(x, train_pos_edge_index)
    objective = model.recon_loss(latent, train_pos_edge_index)
    if args.model == 'VGAE':
        kl_weight = 1 / data.num_nodes
        objective = objective + kl_weight * model.kl_loss()

    objective.backward()
    optimizer.step()


def test(pos_edge_index, neg_edge_index):
    """Score the model on a set of positive and negative edges.

    The embeddings are computed in eval mode, without gradient tracking,
    from the node features and the *training* edges; the given positive and
    negative edges are only used for scoring.

    Args:
        pos_edge_index: edges that exist in the graph.
        neg_edge_index: edges that do not exist in the graph.

    Returns:
        Whatever ``model.test`` returns; the caller in this notebook unpacks
        it as ``(auc, ap)``.
    """
    model.eval()
    with torch.no_grad():
        embeddings = model.encode(x, train_pos_edge_index)
    scores = model.test(embeddings, pos_edge_index, neg_edge_index)
    return scores


# Held-out edges used for the per-epoch evaluation; they do not change
# during training, so look them up once.
test_pos, test_neg = data.test_pos_edge_index, data.test_neg_edge_index

# Epochs are numbered 1..14. After each training step the model is scored
# on the held-out test edges and the metrics are printed.
for epoch in range(1, 15):
    train()
    auc, ap = test(test_pos, test_neg)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

Epoch: 001, AUC: 0.6082, AP: 0.6339
Epoch: 002, AUC: 0.6492, AP: 0.6757
Epoch: 003, AUC: 0.6611, AP: 0.6850
Epoch: 004, AUC: 0.6667, AP: 0.6896
Epoch: 005, AUC: 0.6692, AP: 0.6915
Epoch: 006, AUC: 0.6706, AP: 0.6931
Epoch: 007, AUC: 0.6707, AP: 0.6939
Epoch: 008, AUC: 0.6692, AP: 0.6943
Epoch: 009, AUC: 0.6666, AP: 0.6940
Epoch: 010, AUC: 0.6625, AP: 0.6921
Epoch: 011, AUC: 0.6555, AP: 0.6888
Epoch: 012, AUC: 0.6495, AP: 0.6855
Epoch: 013, AUC: 0.6446, AP: 0.6822
Epoch: 014, AUC: 0.6427, AP: 0.6805
