In [5]:
import dgl
import dgl.nn as dglnn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from openhgnn.models import BaseModel, register_model

In [None]:
# TODO: placeholder for the OpenHGNN-registered model class (not implemented yet).
# @register_model('myRGCN')
# class RGCN(BaseModel):

In [3]:
class RGCNlayer(nn.Module):
    """Relational graph convolution layer for heterogeneous graphs.

    One weight-free ``dglnn.GraphConv`` (``norm="right"``) is created per
    relation and wrapped in a ``dglnn.HeteroGraphConv``. The per-relation
    weight matrices are owned by this layer and handed to the convolutions at
    call time through ``mod_kwargs``; they are either free parameters or
    generated by ``dglnn.WeightBasis``.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    rel_names : sequence of str
        Relation (edge type) names. Must be indexable: position ``i`` selects
        the ``i``-th weight slice for that relation.
    num_bases : int or None, optional
        Number of bases. If ``None``, the number of relations is used, which
        means no basis decomposition. Default: ``None``.
    weight : bool, optional
        If True, apply a learnable per-relation linear transform.
    bias : bool, optional
        If True, add a learnable bias of size ``out_feat``.
    activation : callable or None, optional
        Applied to the output after the bias when truthy.
    self_loop : bool, optional
        If True, add ``inputs_dst @ loop_weight`` to the aggregated message.
    dropout : float, optional
        Dropout probability applied to the final output.
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        rel_names,
        num_bases=None,
        weight=True,
        bias=True,
        activation=None,
        self_loop=False,
        dropout=0.0,
    ):
        super().__init__()
        num_rels = len(rel_names)
        # ``None`` means "one basis per relation", i.e. no decomposition.
        if num_bases is None:
            num_bases = num_rels

        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        self.num_bases = num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        # Batch normalisation is hard-wired off; the branches are kept so it
        # can be enabled by flipping this flag.
        self.batchnorm = False

        # The convolutions carry no weight/bias of their own: the relation
        # weights are supplied per call from this layer (see ``forward``).
        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

        self.use_weight = weight
        # Basis decomposition only pays off with fewer bases than relations.
        self.use_basis = bool(weight) and num_bases < num_rels
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis(
                    (in_feat, out_feat), num_bases, num_rels
                )
            else:
                self.weight = nn.Parameter(th.empty(num_rels, in_feat, out_feat))
                nn.init.xavier_uniform_(
                    self.weight, gain=nn.init.calculate_gain("relu")
                )

        if bias:
            self.h_bias = nn.Parameter(th.empty(out_feat))
            nn.init.zeros_(self.h_bias)

        if self.self_loop:
            self.loop_weight = nn.Parameter(th.empty(in_feat, out_feat))
            nn.init.xavier_uniform_(
                self.loop_weight, gain=nn.init.calculate_gain("relu")
            )

        if self.batchnorm:
            self.bn = nn.BatchNorm1d(out_feat)

        self.dropout = nn.Dropout(dropout)

    def _relation_weights(self):
        """Build the ``mod_kwargs`` dict mapping each relation to its weight."""
        if not self.use_weight:
            return {}
        weight = self.basis() if self.use_basis else self.weight
        # Index the leading (relation) axis only. A blanket ``squeeze()``
        # would also drop a feature axis of size 1 and change the rank.
        return {
            rel: {"weight": weight[i]} for i, rel in enumerate(self.rel_names)
        }

    def forward(self, g: "dgl.DGLGraph", inputs):
        """Run the layer on graph ``g``.

        Parameters
        ----------
        g : dgl.DGLGraph
            Input graph, or a block when used with neighbour sampling.
        inputs : dict
            Maps node type to that type's input feature tensor.

        Returns
        -------
        dict
            Maps each node type produced by the convolution to its new
            feature tensor.
        """
        g = g.local_var()
        wdict = self._relation_weights()

        if g.is_block:
            # In a block, the destination nodes are the leading rows of the
            # source features.
            inputs_dst = {
                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
            }
        else:
            inputs_dst = inputs

        hs = self.conv(g, inputs, mod_kwargs=wdict)

        def _apply(ntype, h):
            # ``inputs_dst`` is read only when the self-loop is enabled, so
            # output node types without input features do not raise.
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            if self.batchnorm:
                h = self.bn(h)
            return self.dropout(h)

        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}

In [13]:
# Scratch check of th.split with an integer chunk size: splitting 5 rows
# into chunks of 2 along dim 0 gives pieces of 2, 2 and 1 rows.
e = th.rand((5, 4))
print(e)

e.split(2, dim=0)

tensor([[0.1403, 0.0395, 0.5835, 0.8673],
        [0.4261, 0.6162, 0.7110, 0.0673],
        [0.9705, 0.0353, 0.9909, 0.3213],
        [0.7699, 0.5226, 0.2869, 0.1250],
        [0.7725, 0.2516, 0.8530, 0.6372]])


(tensor([[0.1403, 0.0395, 0.5835, 0.8673],
         [0.4261, 0.6162, 0.7110, 0.0673]]),
 tensor([[0.9705, 0.0353, 0.9909, 0.3213],
         [0.7699, 0.5226, 0.2869, 0.1250]]),
 tensor([[0.7725, 0.2516, 0.8530, 0.6372]]))