In [1]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.datasets import Amazon
from torch_geometric.data import DataLoader
from torch_geometric.nn import VGAE
from torch_geometric.utils import train_test_split_edges
from sklearn.metrics import f1_score as sk_f1_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed torch's global RNG before the split so that Restart & Run All
# reproduces the same train/val/test edges (and, downstream, the same
# weight initialisation and sampling) on every run.
# NOTE(review): this assumes train_test_split_edges draws its permutations
# from torch's global RNG -- confirm against the installed PyG version.
torch.manual_seed(42)

# Amazon "Computers" graph; files are stored under data/Amazon.
dataset = Amazon(root='data/Amazon', name='Computers')

# The dataset holds a single graph. Split its edges into train/val/test
# positive and negative sets for link prediction.
# NOTE(review): recent PyG releases mark train_test_split_edges as deprecated
# in favour of transforms.RandomLinkSplit -- consider migrating.
data = train_test_split_edges(dataset[0])



In [3]:
# Display the Data object returned by train_test_split_edges to inspect the
# tensors it carries (a bare last expression is rendered as the cell output).
data

Data(x=[13752, 767], y=[13752], val_pos_edge_index=[2, 12293], test_pos_edge_index=[2, 24586], train_pos_edge_index=[2, 417964], train_neg_adj_mask=[13752, 13752], val_neg_edge_index=[2, 12293], test_neg_edge_index=[2, 24586])

In [4]:
class Encoder(torch.nn.Module):
    """GCN encoder that emits two per-node outputs for a variational autoencoder.

    A shared GCN layer (followed by batch norm, ReLU and dropout) feeds two
    parallel GCN heads; each head's output passes through its own batch norm.

    Args:
        in_channels: Number of input features per node.
        out_channels: Width of each output head. The shared hidden layer is
            ``2 * out_channels`` wide.
        dropout_rate: Probability handed to ``torch.nn.Dropout`` for the
            hidden activations.

    NOTE(review): torch_geometric's VGAE documents the encoder's second
    output as a log standard deviation; confirm that the ``logvar`` naming
    used here matches the intended parameterisation.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.5):
        super().__init__()
        hidden_channels = 2 * out_channels
        gcn_layer = torch_geometric.nn.GCNConv

        # Graph convolutions: one shared layer, then one head per output.
        # NOTE(review): cached=True presumably makes each layer reuse the
        # normalised adjacency from its first call, which would require the
        # same edge_index on every forward pass -- confirm.
        self.conv1 = gcn_layer(in_channels, hidden_channels, cached=True)
        self.conv_mu = gcn_layer(hidden_channels, out_channels, cached=True)
        self.conv_logvar = gcn_layer(hidden_channels, out_channels, cached=True)

        # Regularisation / normalisation applied after the convolutions.
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.bn_mu = torch.nn.BatchNorm1d(out_channels)
        self.bn_logvar = torch.nn.BatchNorm1d(out_channels)

    def forward(self, x, edge_index):
        """Encode the nodes and return the ``(mu, logvar)`` tensors.

        Args:
            x: Node feature matrix passed to the first convolution.
            edge_index: Edge index forwarded to every convolution.

        Returns:
            Tuple ``(mu, logvar)``, one row per node in each tensor.
        """
        # Shared trunk: conv -> batch norm -> ReLU -> dropout.
        hidden = self.bn1(self.conv1(x, edge_index))
        hidden = self.dropout(F.relu(hidden))

        # Two heads on the same hidden representation, each batch-normalised.
        mu = self.bn_mu(self.conv_mu(hidden, edge_index))
        logvar = self.bn_logvar(self.conv_logvar(hidden, edge_index))
        return mu, logvar

In [None]:
# Latent size of each encoder head; the encoder's hidden layer is twice this.
out_channels = 16

# Build the encoder on the dataset's feature width and wrap it in a VGAE.
encoder = Encoder(in_channels=dataset.num_features, out_channels=out_channels)
model = VGAE(encoder=encoder)

# Adam over every model parameter with a fixed learning rate of 0.01.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [5]:
def train():
    """Run one full-batch optimisation step and return the loss as a float.

    Encodes the graph using the training positive edges, then minimises the
    reconstruction loss on those edges plus the model's KL term scaled by
    ``1 / data.num_nodes``. Uses the notebook-level ``model``, ``optimizer``
    and ``data`` objects.
    """
    model.train()
    optimizer.zero_grad()

    train_edges = data.train_pos_edge_index
    latent = model.encode(data.x, train_edges)

    # Reconstruction term first, then the node-count-normalised KL term.
    recon_term = model.recon_loss(latent, train_edges)
    kl_weight = 1 / data.num_nodes
    total_loss = recon_term + kl_weight * model.kl_loss()

    total_loss.backward()
    optimizer.step()
    return total_loss.item()

In [None]:
def test(pos_edge_index, neg_edge_index, threshold=0.5):
    """Evaluate the current model on a set of positive and negative edges.

    Node embeddings are computed from the training positive edges with the
    model in eval mode; the supplied edges are then scored.

    Args:
        pos_edge_index: Edge index of edges labelled 1.
        neg_edge_index: Edge index of edges labelled 0.
        threshold: Decoder outputs strictly greater than this value count as
            predicted edges when computing F1. Defaults to 0.5, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(auc, ap, f1)``: the two scores returned by ``model.test``
        and the F1 score of the thresholded decoder outputs.
    """
    model.eval()
    # Everything below is inference only, so keep it all out of autograd.
    with torch.no_grad():
        z = model.encode(data.x, data.train_pos_edge_index)
        auc, ap = model.test(z, pos_edge_index, neg_edge_index)

        pos_pred = model.decode(z, pos_edge_index).view(-1)
        neg_pred = model.decode(z, neg_edge_index).view(-1)

    # Scores and labels are concatenated in the same order: positives first.
    scores = torch.cat([pos_pred, neg_pred])
    labels = torch.cat([
        torch.ones(pos_edge_index.size(1)),
        torch.zeros(neg_edge_index.size(1)),
    ])

    # Binarise the scores at the chosen cutoff before handing them to sklearn.
    preds = (scores > threshold).float().cpu().numpy()
    labels = labels.cpu().numpy()

    f1 = sk_f1_score(labels, preds)
    return auc, ap, f1

In [6]:
# Train for 200 epochs. After every optimisation step, score the test edges
# and log the training loss next to AUC / AP / F1.
for epoch in range(1, 201):
    loss = train()
    auc, ap, f1 = test(
        pos_edge_index=data.test_pos_edge_index,
        neg_edge_index=data.test_neg_edge_index,
    )
    print(f'Epoch: {epoch}, Loss: {loss:.4f}, AUC: {auc:.4f}, AP: {ap:.4f}, F1: {f1:.4f}')

Epoch: 1, Loss: 128.2774, AUC: 0.6646, AP: 0.6034, F1: 0.6667
Epoch: 2, Loss: 20.7686, AUC: 0.5819, AP: 0.5454, F1: 0.6670
Epoch: 3, Loss: 17.5908, AUC: 0.5719, AP: 0.5393, F1: 0.6681
Epoch: 4, Loss: 14.2298, AUC: 0.5798, AP: 0.5442, F1: 0.6695
Epoch: 5, Loss: 18.0012, AUC: 0.5785, AP: 0.5434, F1: 0.6712
Epoch: 6, Loss: 16.4956, AUC: 0.5834, AP: 0.5464, F1: 0.6728
Epoch: 7, Loss: 14.1026, AUC: 0.5872, AP: 0.5488, F1: 0.6747
Epoch: 8, Loss: 11.4815, AUC: 0.5929, AP: 0.5524, F1: 0.6769
Epoch: 9, Loss: 11.2207, AUC: 0.5992, AP: 0.5564, F1: 0.6800
Epoch: 10, Loss: 9.9730, AUC: 0.6070, AP: 0.5616, F1: 0.6836
Epoch: 11, Loss: 11.9367, AUC: 0.6128, AP: 0.5654, F1: 0.6869
Epoch: 12, Loss: 9.1552, AUC: 0.6196, AP: 0.5699, F1: 0.6906
Epoch: 13, Loss: 10.0544, AUC: 0.6259, AP: 0.5743, F1: 0.6930
Epoch: 14, Loss: 10.4375, AUC: 0.6354, AP: 0.5810, F1: 0.6960
Epoch: 15, Loss: 8.7407, AUC: 0.6448, AP: 0.5878, F1: 0.6986
Epoch: 16, Loss: 8.5230, AUC: 0.6560, AP: 0.5962, F1: 0.7016
Epoch: 17, Loss: 10.