In [3]:
import dgl
import dgl.function as fn
import dgl.nn as dglnn
import openhgnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

$$
h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
\sum_{j\in\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})
$$

In [4]:
class RGCNLayer(nn.Module):
    """Relational graph convolution (R-GCN) layer.

    For every edge, the source node feature is transformed by a
    relation-specific linear map (``dglnn.TypedLinear``) selected by the edge
    type, optionally scaled by a per-edge normaliser, and summed at the
    destination node.  The aggregated result then goes through, in this
    order and each only if enabled: layer normalisation, bias, self-loop
    term, activation, dropout.

    Parameters
    ----------
    in_feat : int
        Size of the input node features.
    out_feat : int
        Size of the output node features.
    num_rels : int
        Number of relation (edge) types.
    regularizer : str or None, optional
        Weight regulariser forwarded unchanged to ``dglnn.TypedLinear``.
    num_bases : int or None, optional
        Number of bases forwarded to ``dglnn.TypedLinear``.  ``None``, a
        non-positive value, or a value larger than ``num_rels`` falls back to
        ``num_rels``.
    bias : bool, optional
        If true, add a learnable bias (initialised to zero) to the output.
    activation : callable or None, optional
        Applied to the output when given.
    self_loop : bool, optional
        If true, add ``feat @ loop_weight`` for the destination nodes.
    dropout : float, optional
        Dropout probability applied to the final output.
    layer_norm : bool, optional
        If true, apply ``nn.LayerNorm`` to the aggregated messages.
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        num_rels,
        regularizer=None,
        num_bases=None,
        bias=True,
        activation=None,
        self_loop=True,
        dropout=0.0,
        layer_norm=False,
    ):
        super().__init__()

        # Normalise ``num_bases`` from the constructor arguments *before*
        # storing it; the attributes do not exist yet at this point and the
        # default ``None`` cannot be compared with integers.
        if num_bases is None or num_bases <= 0 or num_bases > num_rels:
            num_bases = num_rels
        self.num_rels = num_rels
        self.num_bases = num_bases

        self.linear_r = dglnn.TypedLinear(
            in_feat, out_feat, num_rels, regularizer, num_bases
        )
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.layer_norm = layer_norm

        # bias
        if self.bias:
            self.h_bias = nn.Parameter(torch.Tensor(out_feat))
            nn.init.zeros_(self.h_bias)

        # layer norm
        if self.layer_norm:
            self.layer_norm_weight = nn.LayerNorm(out_feat, elementwise_affine=True)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(
                self.loop_weight, gain=nn.init.calculate_gain("relu")
            )

        self.dropout = nn.Dropout(dropout)

    def message(self, edges: "dgl.udf.EdgeBatch"):
        """Build the per-edge message ``m``.

        Reads the source feature ``h`` and the edge field ``etype`` (both set
        by :meth:`forward`), applies the relation-specific linear map, and
        multiplies by the edge field ``norm`` when it is present.
        """
        m = self.linear_r(edges.src["h"], edges.data["etype"])
        if "norm" in edges.data:
            m = m * edges.data["norm"]
        return {"m": m}

    def forward(self, g: "dgl.DGLGraph", feat, etypes, norm=None):
        """Run one round of relational message passing.

        Parameters
        ----------
        g : dgl.DGLGraph
            Graph (or block) to propagate on.
        feat : torch.Tensor
            Source node features; the first ``g.num_dst_nodes()`` rows are
            used for the self-loop term.
        etypes : torch.Tensor
            Relation type of every edge, passed to the typed linear layer.
        norm : torch.Tensor or None, optional
            Per-edge factor multiplied into the messages; must broadcast
            against the message tensor (presumably shape ``(E, 1)``).

        Returns
        -------
        torch.Tensor
            Updated destination node features.
        """
        # local_scope keeps the temporary node/edge fields from leaking
        # into the caller's graph.
        with g.local_scope():
            g.srcdata["h"] = feat
            # ``message`` looks the edge types up in edge data, so they must
            # be stored on the graph before propagation.
            g.edata["etype"] = etypes
            if norm is not None:
                g.edata["norm"] = norm
            g.update_all(self.message, fn.sum("m", "h"))
            h = g.dstdata["h"]
            if self.layer_norm:
                h = self.layer_norm_weight(h)
            if self.bias:
                h = h + self.h_bias
            if self.self_loop:
                h = h + feat[: g.num_dst_nodes()] @ self.loop_weight
            if self.activation is not None:
                h = self.activation(h)
            h = self.dropout(h)
            return h