In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F

# ---------------------------
# FiLM-style Relational GIN
# ---------------------------
class RelationalFiLMGINConv(MessagePassing):
    """
    Relation-aware GIN layer whose messages are FiLM-modulated per edge type:

      gamma_r, beta_r = MLP_rel(r_emb[r])
      m_{j->i} = gamma_r ⊙ (W x_j) + beta_r
      x_i' = MLP_node( (1 + eps) * x_i + sum_j m_{j->i} )

    Args:
        emb_dim: width d of node features, relation embeddings and outputs.
        num_relations: number of distinct edge types (rows of the relation
            embedding table).
        hidden_layers: depth of the relation MLP. Each unit above 1 adds one
            extra Linear(2d, 2d) + ReLU stage; values below 1 build the same
            MLP as 1.
        train_eps: if True, eps is a learnable parameter; otherwise it is a
            fixed buffer. It starts at zero in both cases.
    """
    def __init__(
            self,
            emb_dim: int,
            num_relations: int,
            hidden_layers: int = 1,   # depth of relation MLP (>=1)
            train_eps: bool = True
    ):
        super().__init__(aggr="add")  # GIN uses sum aggregation

        # relation embeddings (one per edge type)
        self.rel_emb = nn.Embedding(num_relations, emb_dim)
        nn.init.xavier_uniform_(self.rel_emb.weight)

        # relation MLP that outputs [gamma | beta] of size 2 * emb_dim
        rel_layers = [nn.Linear(emb_dim, emb_dim * 2), nn.ReLU()]
        for _ in range(hidden_layers - 1):
            rel_layers += [nn.Linear(emb_dim * 2, emb_dim * 2), nn.ReLU()]
        rel_layers += [nn.Linear(emb_dim * 2, emb_dim * 2)]
        self.rel_mlp = nn.Sequential(*rel_layers)

        # shared linear transform on node features
        self.W = nn.Linear(emb_dim, emb_dim)

        # node-side GIN MLP applied after aggregation
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim * 2),
            nn.ReLU(),
            nn.Linear(emb_dim * 2, emb_dim),
        )

        # learnable epsilon like classic GIN; a buffer keeps the fixed variant
        # in the state dict and on the module's device without training it
        if train_eps:
            self.eps = nn.Parameter(torch.zeros(1))
        else:
            self.register_buffer("eps", torch.zeros(1))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_type: torch.Tensor) -> torch.Tensor:
        """
        Run one round of relation-aware message passing.

        x: [N, d]
        edge_index: [2, E]
        edge_type: [E]  (relation id for each edge)

        Returns the updated node features, [N, d].
        """
        # propagate will call message(...) then aggregate, then update(...)
        return self.propagate(edge_index, x=x, edge_type=edge_type)

    def message(self, x_j: torch.Tensor, edge_type: torch.Tensor) -> torch.Tensor:
        """
        Build the FiLM-modulated message for every edge.

        x_j: [E, d] features of source nodes (neighbors)
        edge_type: [E] relation ids aligned with edges

        Returns the per-edge messages, [E, d].
        """
        # The relation MLP is purely row-wise (Linear/ReLU only), so running
        # it once over the R relation embeddings and gathering per edge yields
        # the same gamma/beta as running it on E looked-up rows, while doing
        # R instead of E MLP evaluations.
        film_table = self.rel_mlp(self.rel_emb.weight)   # [R, 2d]
        # F.embedding keeps the same index checking as the nn.Embedding lookup.
        gamma_beta = F.embedding(edge_type, film_table)  # [E, 2d]
        gamma, beta = gamma_beta.chunk(2, dim=-1)        # each [E, d]
        return gamma * self.W(x_j) + beta                # FiLM message

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Combine each node's own features with its aggregated messages.

        aggr_out: [N, d] summed messages
        x:         [N, d] original node features, passed from propagate via kwargs

        Returns MLP_node((1 + eps) * x + aggr_out), [N, d].
        """
        out = (1.0 + self.eps) * x + aggr_out
        return self.mlp(out)


class RelationalGINEncoder(nn.Module):
    """
    Stack of RelationalFiLMGINConv layers over a learnable node-embedding table.

    The initial node features are the full embedding table (one row per node);
    each layer's output is passed through dropout before feeding the next one.

    Args:
        num_nodes: number of rows in the node embedding table.
        num_relations: number of edge types, forwarded to every conv layer.
        emb_dim: feature width used throughout the stack.
        num_layers: how many conv layers to stack.
        hidden_layers: depth for the relation FiLM MLP inside each layer.
        dropout: dropout probability applied after every layer.
        train_eps: forwarded to every conv layer.
    """
    def __init__(
            self,
            num_nodes: int,
            num_relations: int,
            emb_dim: int = 128,
            num_layers: int = 3,
            hidden_layers: int = 1,   # depth for the relation FiLM MLP
            dropout: float = 0.1,
            train_eps: bool = True,
    ):
        super().__init__()

        # learnable initial features, Xavier-initialised
        self.embed = nn.Embedding(num_nodes, emb_dim)
        nn.init.xavier_uniform_(self.embed.weight)

        # every layer is built from the same configuration
        layer_cfg = dict(
            emb_dim=emb_dim,
            num_relations=num_relations,
            hidden_layers=hidden_layers,
            train_eps=train_eps,
        )
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(RelationalFiLMGINConv(**layer_cfg))

        self.dropout = nn.Dropout(dropout)

    def forward(self, edge_index: torch.Tensor, edge_type: torch.Tensor) -> torch.Tensor:
        """
        Encode all nodes of the graph.

        edge_index: [2, E]
        edge_type: [E]  (relation id for each edge)

        Returns node representations, [N, emb_dim].
        """
        h = self.embed.weight
        for layer in self.convs:
            h = self.dropout(layer(h, edge_index, edge_type))
        return h  # [N, emb_dim]