In [1]:
# Model definitions: EdgeConv-based conditional generator and discriminator for the GNN GAN

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import EdgeConv, global_mean_pool

# Custom EdgeConv wrapper for conditional GNN
class ConditionalEdgeConv(nn.Module):
    """PyG ``EdgeConv`` whose edge function is a two-layer MLP.

    Despite its name, this layer receives no condition of its own: any
    conditioning must already be part of the node features ``x`` passed in.

    The MLP's first layer takes ``2 * in_channels`` inputs because EdgeConv
    feeds it a pair of node-feature vectors per edge (per the PyG docs:
    ``[x_i, x_j - x_i]``, aggregated with ``max`` by default).

    Args:
        in_channels: width of the incoming node features.
        out_channels: width of the outgoing node features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Linear(in_channels * 2, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        ]
        # Kept as a named attribute (not only inside ``conv``) so existing
        # state_dict keys under ``edge_mlp.`` remain valid.
        self.edge_mlp = nn.Sequential(*layers)
        self.conv = EdgeConv(nn=self.edge_mlp)

    def forward(self, x, edge_index):
        """Apply the EdgeConv to node features ``x`` over ``edge_index``."""
        out = self.conv(x, edge_index)
        return out

# Generator: takes random noise and a condition (target energy), outputs rechits
class GNNGenerator(nn.Module):
    """Conditional generator mapping per-node noise plus a condition to node features.

    Noise and condition are concatenated per node, lifted by a two-layer MLP,
    passed through two ``ConditionalEdgeConv`` layers over ``edge_index``
    (ReLU after each), and linearly projected to ``out_dim`` features per node.

    Args:
        noise_dim: width of the per-node noise vector ``z``.
        condition_dim: width of the condition (the target energy in this notebook).
        hidden_dim: width of the hidden node embeddings.
        out_dim: number of generated features per node.
    """

    def __init__(self, noise_dim=16, condition_dim=1, hidden_dim=64, out_dim=4):
        super().__init__()
        # Kept as attributes so callers can sample noise / shape conditions to match.
        self.noise_dim = noise_dim
        self.condition_dim = condition_dim
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + condition_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.gnn1 = ConditionalEdgeConv(hidden_dim, hidden_dim)
        self.gnn2 = ConditionalEdgeConv(hidden_dim, hidden_dim)
        self.out_layer = nn.Linear(hidden_dim, out_dim)  # Output: [E, x, y, z]

    def forward(self, z, condition, edge_index, batch=None):
        """Generate node features for the graph(s) described by ``edge_index``.

        Args:
            z: noise, shape [num_nodes, noise_dim].
            condition: one of the following:
                - a single condition shared by all nodes (0-d, 1-d, or
                  [1, condition_dim]);
                - one row per node ([num_nodes, condition_dim]);
                - one row per graph ([num_graphs, condition_dim]) together
                  with ``batch``.
                A 1-d tensor is read as rows of width ``condition_dim``.
            edge_index: connectivity passed to both EdgeConv layers.
            batch: optional node-to-graph index vector. It is needed to
                broadcast per-graph conditions when a batch holds more than
                one graph.

        Returns:
            Tensor of shape [num_nodes, out_dim].
        """
        num_nodes = z.size(0)
        cond = condition if condition.dim() == 2 else condition.reshape(-1, self.condition_dim)
        if batch is not None and cond.size(0) not in (1, num_nodes):
            # One row per graph: give each node its graph's condition. A plain
            # expand(num_nodes, ...) cannot do this and raises for >1 graph.
            cond = cond[batch]
        cond = cond.expand(num_nodes, -1)
        x = torch.cat([z, cond], dim=-1)
        x = self.fc(x)
        x = F.relu(self.gnn1(x, edge_index))
        x = F.relu(self.gnn2(x, edge_index))
        return self.out_layer(x)

# Discriminator: classifies graphs (real or fake) conditioned on target energy
class GNNDiscriminator(nn.Module):
    """Conditional graph discriminator producing one real/fake logit per graph.

    The condition is concatenated onto every node's features. The result is
    passed through two ``ConditionalEdgeConv`` layers (ReLU after each),
    mean-pooled per graph via ``batch``, and scored by a two-layer MLP.

    Args:
        in_dim: width of the node features ``x``.
        condition_dim: width of the condition (the target energy in this notebook).
        hidden_dim: width of the hidden node / graph embeddings.
    """

    def __init__(self, in_dim=4, condition_dim=1, hidden_dim=64):
        super().__init__()
        self.condition_dim = condition_dim
        self.gnn1 = ConditionalEdgeConv(in_dim + condition_dim, hidden_dim)
        self.gnn2 = ConditionalEdgeConv(hidden_dim, hidden_dim)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, edge_index, batch, condition):
        """Score each graph in the batch.

        Args:
            x: node features, shape [num_nodes, in_dim].
            edge_index: connectivity passed to both EdgeConv layers.
            batch: node-to-graph index vector used for pooling. It is also
                used to broadcast per-graph conditions.
            condition: one of the following:
                - a single condition shared by all nodes (0-d, 1-d, or
                  [1, condition_dim]);
                - one row per node;
                - one row per graph.
                A 1-d tensor is read as rows of width ``condition_dim``.

        Returns:
            Logits of shape [num_graphs] (no sigmoid applied).
        """
        num_nodes = x.size(0)
        cond = condition if condition.dim() == 2 else condition.reshape(-1, self.condition_dim)
        if batch is not None and cond.size(0) not in (1, num_nodes):
            # One row per graph: give each node its graph's condition. A plain
            # expand(num_nodes, ...) cannot do this and raises for >1 graph.
            cond = cond[batch]
        cond = cond.expand(num_nodes, -1)
        x = torch.cat([x, cond], dim=-1)
        x = F.relu(self.gnn1(x, edge_index))
        x = F.relu(self.gnn2(x, edge_index))
        x = global_mean_pool(x, batch)  # [num_graphs, hidden_dim]
        out = self.fc(x)  # [num_graphs, 1]
        return out.view(-1)  # [num_graphs]



In [2]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader

def train_gnn_gan(generator, discriminator, dataloader, epochs=50, device='cuda', noise_dim=None):
    """Train a conditional GNN GAN with the non-saturating BCE objective.

    Each batch gets two steps:
      1. A discriminator step: real graphs are labelled 1 and generated graphs 0.
      2. A generator step: generated graphs are labelled 1.
    Fake graphs reuse the real batch's ``edge_index`` and ``batch`` vectors,
    so only node features are generated.

    Args:
        generator: module called as ``generator(z, condition, edge_index)``.
            It must return fake node features with one row per row of ``z``.
        discriminator: module called as
            ``discriminator(x, edge_index, batch, condition)``. It must return
            one logit per graph.
        dataloader: iterable of batched graphs exposing ``x``, ``y``,
            ``edge_index``, ``batch`` and ``.to(device)`` (e.g. a PyG
            ``DataLoader``). ``y`` is the conditioning target energy.
            Presumably there is one value per graph; per-node ``y`` is also
            accepted.
        epochs: number of passes over ``dataloader``.
        device: device the models and batches are moved to.
        noise_dim: width of the per-node noise. Defaults to
            ``generator.noise_dim`` if that attribute exists, otherwise 16.

    Returns:
        List of ``(mean_d_loss, mean_g_loss)`` tuples, one per epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if noise_dim is None:
        noise_dim = getattr(generator, "noise_dim", 16)

    generator.to(device)
    discriminator.to(device)
    generator.train()
    discriminator.train()

    g_opt = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    bce_loss = nn.BCEWithLogitsLoss()

    history = []
    for epoch in range(epochs):
        d_total, g_total, n_batches = 0.0, 0.0, 0
        for real_data in dataloader:
            real_data = real_data.to(device)
            real_x = real_data.x
            real_edge_index = real_data.edge_index
            real_batch = real_data.batch
            num_nodes = real_x.size(0)

            # Both models concatenate the condition onto every node. So broadcast
            # per-graph targets to one row per node: expanding [num_graphs, 1] to
            # [num_nodes, 1] would fail for any batch with more than one graph.
            cond = real_data.y.view(-1, 1)
            if cond.size(0) != num_nodes:
                cond = cond[real_batch]

            # ============ Train Discriminator ============
            d_opt.zero_grad()

            real_pred = discriminator(real_x, real_edge_index, real_batch, cond)
            loss_real = bce_loss(real_pred, torch.ones_like(real_pred))

            # Fakes are constants for this step, so skip building the generator's graph.
            with torch.no_grad():
                z = torch.randn(num_nodes, noise_dim, device=device)
                fake_x = generator(z, cond, real_edge_index)
            fake_pred = discriminator(fake_x, real_edge_index, real_batch, cond)
            loss_fake = bce_loss(fake_pred, torch.zeros_like(fake_pred))

            d_loss = loss_real + loss_fake
            d_loss.backward()
            d_opt.step()

            # ============ Train Generator ============
            # The discriminator also accumulates grads here. They are cleared by
            # d_opt.zero_grad() before its next step.
            g_opt.zero_grad()
            z = torch.randn(num_nodes, noise_dim, device=device)
            fake_x = generator(z, cond, real_edge_index)
            gen_pred = discriminator(fake_x, real_edge_index, real_batch, cond)
            g_loss = bce_loss(gen_pred, torch.ones_like(gen_pred))
            g_loss.backward()
            g_opt.step()

            d_total += d_loss.item()
            g_total += g_loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")
        d_mean = d_total / n_batches
        g_mean = g_total / n_batches
        history.append((d_mean, g_mean))
        print(f"Epoch {epoch+1}/{epochs} | D Loss: {d_mean:.4f} | G Loss: {g_mean:.4f}")

    return history
