In [4]:
!pip install torch torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv


Collecting torch-geometric
  Using cached torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting torch-scatter
  Using cached torch_scatter-2.1.2-cp310-cp310-linux_x86_64.whl
Collecting torch-sparse
  Using cached torch_sparse-0.6.18.tar.gz (209 kB)
  Preparing metadata (setup.py) ... done
Collecting torch-cluster
  Using cached torch_cluster-1.6.3.tar.gz (54 kB)
  Preparing metadata (setup.py) ... done
Collecting torch-spline-conv
  Using cached torch_spline_conv-1.2.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... done
Using cached torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
Building wheels for collected packages: torch-sparse, torch-cluster, torch-spline-conv
  Building wheel for torch-sparse (setup.py) ... done
  Created wheel for torch-sparse: filename=torch_sparse-0.6.18-cp310-cp310-linux_x86_64.whl size=1104576 sha256=363e91c7e9ade04a3856f0de85bacdf96bf685f169303766fe07c37e5fffd9bc
  Stored in directory: /root/.

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import GCNConv, global_mean_pool # This import will now work
from torch_geometric.data import Data, DataLoader

# Define the Generator model
class Generator(nn.Module):
    """Turn per-node latent noise into synthetic node features.

    A linear layer followed by ReLU lifts every node's noise vector to
    ``hidden_dim``. A single GCN convolution over the supplied edges then mixes
    neighbouring nodes and projects to ``output_dim``. The graph structure is
    supplied by the caller, not generated.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.gcn = GCNConv(hidden_dim, output_dim)

    def forward(self, z, edge_index):
        """Return generated node features, ``output_dim`` wide, one row per node.

        Args:
            z: Latent noise with ``input_dim`` features per node; presumably
                shaped (num_nodes, input_dim) — TODO confirm for other callers.
            edge_index: Connectivity over which the GCN layer propagates.
        """
        lifted = self.fc1(z).relu()
        return self.gcn(lifted, edge_index)

# Define the Discriminator model
class Discriminator(nn.Module):
    """Score whole graphs with a probability of being real.

    One GCN layer (plus ReLU) embeds every node. Mean pooling collapses each
    graph's node embeddings into a single vector. A linear head followed by a
    sigmoid turns that vector into a score.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.gcn1 = GCNConv(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        """Return a (num_graphs, 1) tensor of sigmoid scores in [0, 1].

        Args:
            x: Node features; the feature width must equal ``input_dim``.
            edge_index: Graph connectivity passed straight to the GCN layer.
            batch: Graph-membership index of each node, used by
                ``global_mean_pool`` to average nodes graph by graph.
        """
        node_emb = self.gcn1(x, edge_index).relu()
        graph_emb = global_mean_pool(node_emb, batch)  # one row per graph
        logits = self.fc(graph_emb)
        return logits.sigmoid()

# Hyperparameters
seed = 0               # seeds torch's global RNG (model init, synthetic graphs, latent noise)
input_dim = 8          # width of the per-node latent noise fed to the generator
hidden_dim = 16        # hidden width used by both generator and discriminator
output_dim = 8         # width of generated node features; also the discriminator's input width
epochs = 20            # number of outer training-loop iterations
batch_size = 8         # number of random "real" graphs drawn per epoch
learning_rate = 0.001  # Adam step size shared by both optimizers

# Seed before building the models so that weight initialisation, the sampled
# "real" graphs and the latent noise are identical on every Restart & Run All.
torch.manual_seed(seed)

# Initialize Generator and Discriminator. The discriminator scores node features
# of width output_dim, whether real or generated.
generator = Generator(input_dim, hidden_dim, output_dim)
discriminator = Discriminator(output_dim, hidden_dim)

# Optimizers
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Create synthetic graph data (small graphs for faster processing)
def generate_real_data(num_nodes=10, feature_dim=8, num_edges=10):
    """Build one random graph to serve as a "real" training sample.

    Args:
        num_nodes: Number of nodes. Edge endpoints are drawn from
            [0, num_nodes).
        feature_dim: Width of each node's feature vector. For this notebook it
            must equal the discriminator's input width (``output_dim``).
        num_edges: Number of directed edges to sample. Defaults to 10, the
            value that was previously hard-coded, so existing callers are
            unaffected.

    Returns:
        A ``Data`` object containing:
        - ``edge_index``: an int64 tensor of shape (2, num_edges).
        - ``x``: standard-normal node features of shape (num_nodes, feature_dim).
        Edges are sampled uniformly with replacement, so self-loops and
        duplicate edges can occur.
    """
    # Keep the original RNG call order (edges first, then features) so seeded
    # runs reproduce the same graphs as before.
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    x = torch.randn((num_nodes, feature_dim))
    return Data(x=x, edge_index=edge_index)

# Training loop
# Each epoch draws a fresh set of random "real" graphs, takes one discriminator
# step and one generator step per mini-batch, and prints the losses.
for epoch in range(epochs):
    # Real node features must be output_dim wide. The discriminator was built as
    # Discriminator(output_dim, ...) and also scores the generator's output_dim-wide
    # features, so tie the two explicitly. Relying on the generate_real_data default
    # would break as soon as output_dim changes.
    real_data = [generate_real_data(feature_dim=output_dim) for _ in range(batch_size)]
    # batch_size graphs with batch_size per batch -> exactly one mini-batch per epoch.
    # NOTE(review): in PyG 2.x torch_geometric.data.DataLoader is a deprecated alias;
    # the import cell could use torch_geometric.loader.DataLoader instead.
    real_loader = DataLoader(real_data, batch_size=batch_size)

    for real_graph in real_loader:
        # --- Train Discriminator ---
        # One latent vector per node. The fake sample reuses the real batch's
        # topology, so real and fake differ only in their node features.
        z = torch.randn((real_graph.num_nodes, input_dim))
        fake_graph = generator(z, real_graph.edge_index)  # generated node features (a tensor, not a Data object)

        d_optimizer.zero_grad()
        real_pred = discriminator(real_graph.x, real_graph.edge_index, real_graph.batch)
        # detach() stops the discriminator loss from back-propagating into the generator.
        fake_pred = discriminator(fake_graph.detach(), real_graph.edge_index, real_graph.batch)

        # -[log D(real) + log(1 - D(fake))]; the 1e-8 guards log(0) when the sigmoid saturates.
        d_loss = -torch.mean(torch.log(real_pred + 1e-8) + torch.log(1 - fake_pred + 1e-8))
        d_loss.backward()
        d_optimizer.step()

        # --- Train Generator ---
        # Non-saturating generator loss -log D(fake). This backward pass also
        # writes into the discriminator's .grad. Those gradients are never applied,
        # because d_optimizer.zero_grad() runs before the next discriminator backward.
        g_optimizer.zero_grad()
        fake_pred = discriminator(fake_graph, real_graph.edge_index, real_graph.batch)
        g_loss = -torch.mean(torch.log(fake_pred + 1e-8))
        g_loss.backward()
        g_optimizer.step()

    # Log progress every epoch (losses of the last, here the only, mini-batch).
    print(f"Epoch {epoch+1}/{epochs}, D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")




Epoch 1/20, D Loss: 1.3847, G Loss: 0.5928
Epoch 2/20, D Loss: 1.3518, G Loss: 0.5767
Epoch 3/20, D Loss: 1.3922, G Loss: 0.5808
Epoch 4/20, D Loss: 1.3677, G Loss: 0.6020
Epoch 5/20, D Loss: 1.3434, G Loss: 0.5870
Epoch 6/20, D Loss: 1.3440, G Loss: 0.5864
Epoch 7/20, D Loss: 1.3936, G Loss: 0.5946
Epoch 8/20, D Loss: 1.3700, G Loss: 0.5951
Epoch 9/20, D Loss: 1.4053, G Loss: 0.5845
Epoch 10/20, D Loss: 1.3599, G Loss: 0.5953
Epoch 11/20, D Loss: 1.3347, G Loss: 0.5904
Epoch 12/20, D Loss: 1.3534, G Loss: 0.5961
Epoch 13/20, D Loss: 1.3510, G Loss: 0.5846
Epoch 14/20, D Loss: 1.3609, G Loss: 0.5894
Epoch 15/20, D Loss: 1.3301, G Loss: 0.5987
Epoch 16/20, D Loss: 1.3612, G Loss: 0.5604
Epoch 17/20, D Loss: 1.3727, G Loss: 0.5874
Epoch 18/20, D Loss: 1.3888, G Loss: 0.5946
Epoch 19/20, D Loss: 1.4372, G Loss: 0.5778
Epoch 20/20, D Loss: 1.3719, G Loss: 0.5967
