## DGL

The most popular GNN models can be written as follows:

$$
h_i^{(l+1)}=\sigma(b^{(l)}+\sum_{j\in\mathcal{N}(i)}e_{ij}h_j^{(l)}W^{(l)})
$$

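where $h_i^{(l+1)}$ is the output feature of node $i$, $\sigma$ is the activation function, $\mathcal{N}(i)$ is the set of neighbors of node $i$ (the sources of its incoming edges), $e_{ij}$ is the weight of the edge from $j$ to $i$, $W^{(l)}$ is the learnable weight matrix, and $b^{(l)}$ is the bias.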

First, you will use PyTorch Geometric (PyG) to implement this convolution layer.

Now, you will implement the same layer with DGL.

In [66]:
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn


class DGL_conv(nn.Module):
    """Weighted-sum graph convolution implemented with DGL message passing.

    For every node ``i`` it computes

        out_i = b + sum_{j in N(i)} e_ij * (h_j @ W)

    which is the formula above *without* the activation ``sigma``; apply an
    activation outside the layer if one is needed. Nodes with no incoming
    edges receive only the bias. ``W`` and ``b`` are initialised to ones.

    Args:
        in_channel: Size of each input node feature.
        out_channel: Size of each output node feature.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        # All-ones init makes outputs deterministic; the expected values in the
        # correctness check depend on it.
        self.W = nn.Parameter(torch.ones(in_channel, out_channel))
        self.b = nn.Parameter(torch.ones(out_channel))

    def forward(self, g, h, edge_weight=None):
        """Apply the convolution to node features ``h`` on graph ``g``.

        Args:
            g: DGL graph. When ``edge_weight`` is None, its edge feature
                ``"e"`` must already be set and is used as ``e_ij``.
            h: Node features, one row per node of ``g`` with ``in_channel``
                columns.
            edge_weight: Optional per-edge weights (one entry per edge of
                ``g``). When given, they are used instead of ``g.edata["e"]``
                and ``g`` itself is not modified.

        Returns:
            Tensor with one row per node and ``out_channel`` columns.
        """
        # Features written inside local_scope() are discarded on exit, so the
        # caller's graph keeps its original ndata/edata.
        with g.local_scope():
            if edge_weight is not None:
                g.edata["e"] = edge_weight
            # Project first so every message is already out_channel wide.
            g.ndata["hw"] = h @ self.W
            # Message: e_ij * hw_j on each edge j -> i; reduce: sum at node i.
            g.update_all(fn.u_mul_e("hw", "e", "ehw"), fn.sum("ehw", "ehw_N"))
            return g.ndata["ehw_N"] + self.b

You can run the code below to check correctness.

In [67]:
# Toy directed graph: edge k runs from src[k] to dst[k].
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
g = dgl.graph((src, dst))

# Every node starts with an all-ones 8-dimensional feature vector.
h = torch.ones(5, 8)

# Attach a constant weight of 2 to every edge under the name "e",
# the edge feature DGL_conv.forward multiplies messages by.
edge_weight = torch.full((6,), 2.0)
g.edata["e"] = edge_weight

In [72]:
conv = DGL_conv(8, 4)
# Run the layer on the node features h.
output = conv(g, h)
print(output)

# Each row is 1 (bias) + 2 (edge weight) * 8 (ones row @ ones column) per
# incoming edge: in-degrees are [1, 0, 2, 2, 1].
expected = np.repeat([[17.0], [1.0], [33.0], [33.0], [17.0]], 4, axis=1)
assert np.allclose(output.detach().numpy(), expected)

tensor([[17., 17., 17., 17.],
        [ 1.,  1.,  1.,  1.],
        [33., 33., 33., 33.],
        [33., 33., 33., 33.],
        [17., 17., 17., 17.]], grad_fn=<AddBackward0>)


In [None]:
# conv = DGL_conv(8, 4)
# output = conv(g, h, edge_weight)   ------------------------------<<<<<<<

# assert np.allclose(
#     output.detach().numpy(),
#     [
#         [17.0, 17.0, 17.0, 17.0],
#         [1.0, 1.0, 1.0, 1.0],
#         [33.0, 33.0, 33.0, 33.0],
#         [33.0, 33.0, 33.0, 33.0],
#         [17.0, 17.0, 17.0, 17.0],
#     ],
# )