# Graph Auto-Encoder (GAE) with PyTorch Geometric (PyG)

In [None]:
import argparse
import os
import time

import torch

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

from torch_geometric.nn import GAE, GCNConv

In [None]:
# Device on which the model is placed; this notebook runs on the CPU.
device = torch.device("cpu")

In [None]:
# Name of the Planetoid dataset to load.
DATASET_NAME = "Cora"

In [None]:
# Pre-processing pipeline applied to the graph when it is read from the dataset:
#   1. NormalizeFeatures - normalises the node feature matrix.
#   2. ToDevice          - places the graph tensors on `device`, i.e. on the same
#                          device the model is moved to, so that changing `device`
#                          does not cause a device mismatch.
#   3. RandomLinkSplit   - holds out 10% of the edges for testing and none for
#                          validation (num_val=0.); positive and negative edge
#                          labels are stored in separate attributes
#                          (split_labels=True) and no negative training edges are
#                          pre-sampled (add_negative_train_samples=False).
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0., num_test=0.1, is_undirected=True,
                      split_labels=True, add_negative_train_samples=False),
])
# The dataset is stored under ./data relative to the current working directory.
path = os.path.join(os.getcwd(), 'data')
dataset = Planetoid(path, DATASET_NAME, transform=transform)
# NOTE: the transform (including the random split) is applied on each access to
# `dataset[0]`, so the graph is split exactly once here and the result reused.
# With num_val=0. the validation split has no held-out edges.
train_data, val_data, test_data = dataset[0]

In [None]:
# The second dimension of each edge-label index tensor is reported as the
# number of edges in that split.
n_train_pos = train_data.pos_edge_label_index.size(1)
n_test_pos = test_data.pos_edge_label_index.size(1)
n_test_neg = test_data.neg_edge_label_index.size(1)

print(f"Train edges (positive): {n_train_pos}")
print(f"Test edges (positive): {n_test_pos}")
print(f"Test edges (negative): {n_test_neg}")

In [None]:
class GCNEncoder(torch.nn.Module):
    """Two-layer GCN encoder.

    The first convolution maps ``num_node_features`` input channels to
    ``2 * num_embedding`` channels; the second maps those to
    ``num_embedding`` output channels.
    """

    def __init__(self, num_node_features, num_embedding):
        super().__init__()
        # The intermediate layer is twice as wide as the final embedding.
        hidden_channels = 2 * num_embedding
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_embedding)

    def forward(self, x, edge_index):
        """Apply conv1 followed by ReLU, then conv2, and return the result."""
        hidden = torch.relu(self.conv1(x, edge_index))
        embeddings = self.conv2(hidden, edge_index)
        return embeddings

In [None]:
# Number of channels of the node embeddings the encoder is asked to produce.
n_embeddings = 20
# Number of input features, as reported by the dataset.
n_features = dataset.num_features

In [None]:
# Wrap the GCN encoder in a graph auto-encoder; no decoder is passed, so GAE
# falls back to its default one.
model = GAE(
    encoder=GCNEncoder(
        num_node_features=n_features,
        num_embedding=n_embeddings,
    ),
)

In [None]:
# Place the model on the selected device before handing its parameters to the
# optimizer.
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
# Full-batch training: each epoch performs exactly one optimisation step on the
# training graph, followed by an evaluation on the held-out test edges.
for epoch in range(20):

    model.train()

    # Clear the gradients left over from the previous step.
    optimizer.zero_grad()

    # Encode the nodes using train_data.edge_index for message passing, then
    # compute the reconstruction loss on the positive training edges (negative
    # edges are not passed explicitly here).
    z = model.encode(train_data.x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index)

    loss.backward()
    optimizer.step()

    # Evaluation does not need gradients, so skip building the autograd graph.
    model.eval()
    with torch.no_grad():
        z = model.encode(test_data.x, test_data.edge_index)
        auc, ap = model.test(z, test_data.pos_edge_label_index,
                             test_data.neg_edge_label_index)

    # These metrics are computed on test_data, i.e. the test split.
    print(f"Epoch {epoch + 1:02d} | Performance on test set => AUC: {auc} AP: {ap}")