In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GAE(nn.Module):
    """Graph auto-encoder: two-layer GCN encoder, MLP + inner-product decoder.

    Args:
        input_dim: number of input node features.
        hidden_dim: width of the first GCN layer and of the decoder's hidden layer.
        output_dim: size of the latent node embeddings produced by ``encode``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        # Decoder MLP maps embeddings back up to input_dim before the inner product.
        self.fc1 = nn.Linear(output_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x, edge_index):
        """Embed every node with GCN -> ReLU -> GCN and return the latent matrix ``z``."""
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        return self.conv2(x, edge_index)

    def decode(self, z, sigmoid=True):
        """Score every pair of nodes from their embeddings.

        ``z`` is passed through fc1 -> ReLU -> fc2, and the result is multiplied
        by its own transpose. For a 2-D ``z`` with one row per node (assumed,
        as produced by ``encode``), this gives an N x N matrix of pairwise scores.

        Args:
            z: latent node embeddings.
            sigmoid: if True (default, original behaviour), return probabilities
                via ``torch.sigmoid``. If False, return the raw logits so the
                caller can use ``F.binary_cross_entropy_with_logits``, which
                stays numerically stable when the scores are large.

        Returns:
            Tensor of pairwise probabilities (or logits if ``sigmoid=False``).
        """
        h = self.fc2(F.relu(self.fc1(z)))
        logits = torch.matmul(h, h.t())
        return torch.sigmoid(logits) if sigmoid else logits

    def forward(self, x, edge_index):
        """Encode then decode; returns ``(pairwise_probabilities, z)``."""
        z = self.encode(x, edge_index)
        return self.decode(z), z


In [6]:
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

# Load the Cora citation dataset from a local root directory, applying the
# NormalizeFeatures transform to the node features.
# NOTE(review): the root is a hard-coded /tmp path; consider a DATA_DIR constant.
dataset = Planetoid(
    root='/tmp/Cora',
    name='Cora',
    transform=T.NormalizeFeatures(),
)
# The graph used for training is the first (presumably only) item in the dataset.
data = dataset[0]


In [8]:
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj

def train():
    """Run one full-batch optimisation step of the global ``model`` on ``data``.

    Reconstructs the full pairwise matrix from the node features and
    ``edge_index``, and compares it with the dense ground-truth adjacency using
    binary cross-entropy. It then back-propagates and takes one ``optimizer`` step.

    Returns:
        float: the reconstruction loss of this step.
    """
    # NOTE(review): `model` and `optimizer` are not created in any visible cell
    # (the execution counter jumps from In [6] to In [8]). Confirm the cell that
    # defines them runs before this one on a fresh kernel.
    model.train()
    optimizer.zero_grad()
    reconstructed, _ = model(data.x, data.edge_index)

    # Dense ground-truth adjacency. max_num_nodes pins the target to
    # num_nodes x num_nodes, matching `reconstructed`. Without it, to_dense_adj
    # sizes the matrix from the largest node index present in edge_index, which
    # would be too small if the highest-indexed nodes have no edges.
    adj_dense = to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes)[0]

    # NOTE(review): F.binary_cross_entropy clamps log() at -100. The 99.856 loss
    # logged for the first epochs is consistent with almost every probability
    # saturating at 1. The later plateau near ln 2 (~0.693) is consistent with
    # outputs collapsing toward 0.5. Consider training on logits with
    # binary_cross_entropy_with_logits, and a pos_weight for the very sparse
    # positive class.
    loss = F.binary_cross_entropy(reconstructed, adj_dense)
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(200):
    loss = train()
    print(f'Epoch {epoch}, Loss: {loss:.4f}')


Epoch 0, Loss: 99.8560
Epoch 1, Loss: 99.8560
Epoch 2, Loss: 15.9194
Epoch 3, Loss: 10.1356
Epoch 4, Loss: 5.7721
Epoch 5, Loss: 3.2164
Epoch 6, Loss: 1.9674
Epoch 7, Loss: 1.5347
Epoch 8, Loss: 1.4059
Epoch 9, Loss: 1.2767
Epoch 10, Loss: 1.1390
Epoch 11, Loss: 1.1061
Epoch 12, Loss: 1.1123
Epoch 13, Loss: 1.0155
Epoch 14, Loss: 0.9572
Epoch 15, Loss: 0.9630
Epoch 16, Loss: 0.9239
Epoch 17, Loss: 0.8567
Epoch 18, Loss: 0.8363
Epoch 19, Loss: 0.8307
Epoch 20, Loss: 0.7883
Epoch 21, Loss: 0.7604
Epoch 22, Loss: 0.7631
Epoch 23, Loss: 0.7468
Epoch 24, Loss: 0.7241
Epoch 25, Loss: 0.7319
Epoch 26, Loss: 0.7343
Epoch 27, Loss: 0.7187
Epoch 28, Loss: 0.7154
Epoch 29, Loss: 0.7154
Epoch 30, Loss: 0.7050
Epoch 31, Loss: 0.7007
Epoch 32, Loss: 0.7048
Epoch 33, Loss: 0.7031
Epoch 34, Loss: 0.6988
Epoch 35, Loss: 0.7002
Epoch 36, Loss: 0.6997
Epoch 37, Loss: 0.6955
Epoch 38, Loss: 0.6953
Epoch 39, Loss: 0.6966
Epoch 40, Loss: 0.6952
Epoch 41, Loss: 0.6943
Epoch 42, Loss: 0.6955
Epoch 43, Loss: 0