In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphAutoencoder(nn.Module):
    """Graph autoencoder that reconstructs both node attributes and structure.

    A two-layer GCN encoder maps node features to latent embeddings. Two
    decoders then read those same embeddings:

    * an MLP (``latent_dim -> hidden_dim -> num_features``) that rebuilds the
      node attribute matrix, and
    * a square linear projection followed by an inner product that yields a
      dense pairwise score matrix over the rows of the embedding.

    Args:
        num_features: Width of the input node feature vectors; also the width
            of the reconstructed attributes.
        hidden_dim: Width of the first GCN layer and of the attribute
            decoder's hidden layer.
        latent_dim: Width of the latent node embeddings.
    """

    def __init__(
        self,
        num_features: int,
        hidden_dim: int = 32,
        latent_dim: int = 16,
    ):
        super().__init__()

        # Keep the construction order below: it determines parameter
        # registration order and how the RNG is consumed during init.

        # Encoder: two stacked graph convolutions.
        self.encoder_conv1 = GCNConv(num_features, hidden_dim)
        self.encoder_conv2 = GCNConv(hidden_dim, latent_dim)

        # Attribute decoder: latent -> hidden -> original feature width.
        self.attr_decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_features),
        )

        # Structure decoder: projection applied before the inner product.
        self.struct_decoder = nn.Linear(latent_dim, latent_dim)

    def encode(self, x: torch.Tensor, edge_index) -> torch.Tensor:
        """Embed every node with the two-layer GCN encoder.

        Args:
            x: Node feature matrix; presumably ``(num_nodes, num_features)``
                — confirm against the caller.
            edge_index: Graph connectivity in whatever form ``GCNConv``
                accepts (typically a ``(2, num_edges)`` index tensor).

        Returns:
            Latent node embeddings. ReLU is applied after the first layer
            only; the second layer's output is returned without activation.
        """
        hidden = F.relu(self.encoder_conv1(x, edge_index))
        return self.encoder_conv2(hidden, edge_index)

    def decode_attributes(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent embeddings back to node-attribute space via the MLP."""
        reconstructed_x = self.attr_decoder(z)
        return reconstructed_x

    def decode_structure(self, z: torch.Tensor) -> torch.Tensor:
        """Score every pair of rows of ``z`` with a projected inner product.

        The embeddings are first passed through ``struct_decoder``. Entry
        ``(i, j)`` of the result is the dot product of projected rows ``i``
        and ``j``, so the matrix is (mathematically) symmetric. No sigmoid is
        applied: the values are raw scores, not probabilities.

        Note:
            The result is dense, with one row and one column per row of
            ``z``. Memory therefore grows quadratically with the node count.
        """
        projected = self.struct_decoder(z)
        return projected.mm(projected.t())

    def forward(self, x: torch.Tensor, edge_index):
        """Encode the graph, then run both decoders on the same embeddings.

        Returns:
            A ``(x_reconstructed, adj_reconstructed)`` tuple: the output of
            ``decode_attributes`` followed by that of ``decode_structure``.
        """
        latent = self.encode(x, edge_index)
        return self.decode_attributes(latent), self.decode_structure(latent)

ModuleNotFoundError: No module named 'torch'