In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3


In [2]:
import math
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.nn import GINConv, GCNConv, VGAE
from torch.nn import Linear, Sequential, ReLU

In [3]:
import torch_geometric.transforms as T

transform = T.Compose([
    T.NormalizeFeatures(),
    #T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      split_labels=True, add_negative_train_samples=False),
])

dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
train_data, val_data, test_data = dataset[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [19]:
PARAMS = {
    'hidden_dim': 16,
    'batch_size': 256,
    'epochs': 500,
    'lr': 0.01,
    'weight_decay': 0,
    'seed': 69
}

class Params:
    def __init__(self, obj):
        for k, v in obj.items():
            setattr(self, k, v)

params = Params(PARAMS)

In [21]:
torch.manual_seed(params.seed)

class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

class VariationalLinearEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_mu = GCNConv(in_channels, out_channels)
        self.conv_logstd = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

model = VGAE(VariationalGCNEncoder(dataset.num_features, params.hidden_dim))

In [22]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
train_data = train_data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=params.lr, weight_decay=params.weight_decay)

def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index)
    loss = loss + (1 / train_data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)


times = []
best_auc = 0.0
for epoch in range(params.epochs):
    loss = train()
    auc, ap = test(val_data)
    if auc > best_auc:
      best_auc = auc
      test_auc, test_ap = test(test_data)
      print(f'Epoch (test): {epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}')
    print(f'Epoch: {epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}')

Epoch (test): 000, AUC: 0.7185, AP: 0.7385
Epoch: 000, AUC: 0.7185, AP: 0.7385
Epoch: 001, AUC: 0.7067, AP: 0.7233
Epoch: 002, AUC: 0.7020, AP: 0.7177
Epoch: 003, AUC: 0.6991, AP: 0.7157
Epoch: 004, AUC: 0.6968, AP: 0.7140
Epoch: 005, AUC: 0.6952, AP: 0.7126
Epoch: 006, AUC: 0.6946, AP: 0.7121
Epoch: 007, AUC: 0.6943, AP: 0.7115
Epoch: 008, AUC: 0.6942, AP: 0.7112
Epoch: 009, AUC: 0.6933, AP: 0.7105
Epoch: 010, AUC: 0.6931, AP: 0.7104
Epoch: 011, AUC: 0.6927, AP: 0.7101
Epoch: 012, AUC: 0.6923, AP: 0.7100
Epoch: 013, AUC: 0.6916, AP: 0.7095
Epoch: 014, AUC: 0.6915, AP: 0.7093
Epoch: 015, AUC: 0.6915, AP: 0.7091
Epoch: 016, AUC: 0.6914, AP: 0.7088
Epoch: 017, AUC: 0.6912, AP: 0.7087
Epoch: 018, AUC: 0.6915, AP: 0.7088
Epoch: 019, AUC: 0.6916, AP: 0.7088
Epoch: 020, AUC: 0.6914, AP: 0.7085
Epoch: 021, AUC: 0.6917, AP: 0.7088
Epoch: 022, AUC: 0.6916, AP: 0.7087
Epoch: 023, AUC: 0.6912, AP: 0.7085
Epoch: 024, AUC: 0.6911, AP: 0.7082
Epoch: 025, AUC: 0.6906, AP: 0.7081
Epoch: 026, AUC: 0.69