In [3]:
import torch.nn as nn

In [7]:
class Generator(nn.Module):
    """MLP generator mapping a latent vector to node-type and edge-type probabilities.

    Two independent MLPs share the same hidden architecture
    (latent_dim -> 128 -> 256 -> 512, ReLU activations): one produces scores
    for every (node, node type) pair, the other for every (node, node, edge type)
    triple. Scores are turned into probabilities with a softmax over the last
    (category) axis, so each node and each node pair gets its own distribution.

    Args:
        latent_dim: size of the input noise vector ``z``.
        N: number of nodes per generated graph.
        T: number of node (atom) types.
        Y: number of edge (bond) types.
    """

    def __init__(self, latent_dim=32, N=9, T=5, Y=4):
        super().__init__()

        self.latent_dim = latent_dim
        self.N = N
        self.T = T
        self.Y = Y

        # Both MLPs end in a Linear layer and output raw logits; the softmax is
        # applied per category axis in forward(). Previously an argument-less
        # nn.Softmax() sat at the end of each Sequential: it warned about the
        # implicit dim and normalised over the whole flattened N*T (or N*N*Y)
        # vector instead of per node / per edge. The Linear layers keep their
        # Sequential indices (0, 2, 4, 6), so existing state_dicts still load.
        self.atom_mlp = self._build_mlp(latent_dim, N * T)
        self.adj_mlp = self._build_mlp(latent_dim, N * N * Y)

    @staticmethod
    def _build_mlp(in_dim, out_dim):
        """Return the shared 3-hidden-layer ReLU MLP ending in a plain Linear layer."""
        return nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim),
        )

    def forward(self, z):
        """Generate node-type and edge-type probabilities from latent codes.

        Args:
            z: latent batch; its last dimension must equal ``latent_dim``
               (assumed shape (batch_size, latent_dim)).

        Returns:
            x: (batch_size, N, T) tensor; each ``x[b, i]`` sums to 1 over T.
            a: (batch_size, N, N, Y) tensor; each ``a[b, i, j]`` sums to 1 over Y.
        """
        batch_size = z.size(0)

        x_logits = self.atom_mlp(z).view(batch_size, self.N, self.T)
        a_logits = self.adj_mlp(z).view(batch_size, self.N, self.N, self.Y)

        # Normalise over the category axis only, after reshaping.
        x = nn.functional.softmax(x_logits, dim=-1)
        a = nn.functional.softmax(a_logits, dim=-1)

        return x, a

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RGCNLayer(nn.Module):
    """Dense relational graph-convolution layer.

    Given node features ``h`` and a stack of per-relation adjacency matrices
    ``adj``, computes::

        relu( self_loop(h) + sum_r Linear_r(adj[..., r] @ h) )

    Every relation has its own ``nn.Linear`` (weight and bias), and the node's
    own features pass through a separate ``self_loop`` Linear.

    Args:
        in_feat: size of each input node feature vector.
        out_feat: size of each output node feature vector.
        num_rels: number of relation types; ``adj`` is indexed with
            ``r`` in ``range(num_rels)`` along its last axis.
    """

    def __init__(self, in_feat, out_feat, num_rels):
        super().__init__()
        self.num_rels = num_rels
        self.relation_weights = nn.ModuleList([nn.Linear(in_feat, out_feat) for _ in range(num_rels)])
        self.self_loop = nn.Linear(in_feat, out_feat)

    def forward(self, h, adj):
        """Apply one relational message-passing step.

        Args:
            h: node features, shape (batch_size, N, in_feat).
            adj: adjacency tensor, shape (batch_size, N, N, num_rels); it must
                share ``h``'s dtype because ``torch.bmm`` requires matching dtypes.

        Returns:
            Tensor of shape (batch_size, N, out_feat) with ReLU applied.
        """
        # Seed the sum with the self-connection rather than torch.zeros(...):
        # a fresh zeros tensor uses the global default dtype, so with
        # float64/float16 parameters the in-place accumulation silently
        # downcast the output and the next layer failed on a dtype mismatch.
        # This also removes the crash on relation_weights[0] when num_rels == 0.
        new_h = self.self_loop(h)

        for r, rel_linear in enumerate(self.relation_weights):
            neigh_h = torch.bmm(adj[..., r], h)  # (batch_size, N, in_feat)
            new_h = new_h + rel_linear(neigh_h)

        return F.relu(new_h)

class Discriminator(nn.Module):
    """Graph discriminator: two R-GCN layers, mean pooling over nodes, then an MLP head.

    Args:
        node_dim: size of each node's input feature vector.
        num_rels: number of relation (edge) types in the adjacency tensor.
        hidden_dim: output width of the first R-GCN layer.
        embed_dim: per-node embedding width produced by the second R-GCN layer.
        N: stored as ``self.N``; not read anywhere in this class.
    """

    def __init__(self, node_dim=5, num_rels=4, hidden_dim=64, embed_dim=32, N=50):
        super().__init__()
        # NOTE(review): unused here, and the default (50) differs from
        # Generator's N=9 — confirm whether anything outside reads it.
        self.N = N

        # Keep this construction order: it fixes parameter-initialisation
        # order (and thus results) under a given random seed.
        self.rgcn_layer1 = RGCNLayer(node_dim, hidden_dim, num_rels)
        self.rgcn_layer2 = RGCNLayer(hidden_dim, embed_dim, num_rels)

        # Head: embed_dim -> 128 -> 1.
        head_layers = [
            nn.Linear(embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Tanh(),  # score bounded to [-1, 1]; change if the loss expects [0, 1] or unbounded scores
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, x, a):
        """Score a batch of graphs.

        Args:
            x: node features, assumed (batch_size, N, node_dim).
            a: multi-relational adjacency, assumed (batch_size, N, N, num_rels).

        Returns:
            Tensor of shape (batch_size, 1).
        """
        node_embeddings = self.rgcn_layer2(self.rgcn_layer1(x, a), a)
        # Average the per-node embeddings into one embed_dim vector per graph.
        graph_vector = node_embeddings.mean(dim=1)
        return self.mlp(graph_vector)