## PyG

In this assignment, you will use the APIs of PyG and DGL to implement some basic GNN operations.

Run the following commands to install the GNN libraries (CPU-only versions).

Many popular GNN layers can be written in the following form:

$$
h_i^{(l+1)}=\sigma(b^{(l)}+\sum_{j\in\mathcal{N}(i)}e_{ij}h_j^{(l)}W^{(l)})
$$

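where $h_i^{(l+1)}$ is the output feature of node $i$ at layer $l+1$, $h_j^{(l)}$ is the input feature of neighbor $j$, $\mathcal{N}(i)$ is the set of neighbors of node $i$ (the nodes with an edge pointing to $i$), $\sigma$ is the activation function, $e_{ij}$ is the weight of the edge from $j$ to $i$, $W^{(l)}$ is the learnable weight matrix, and $b^{(l)}$ is the learnable bias.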
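Stacking the node features $h_j^{(l)}$ as the rows of a matrix $H^{(l)}$ gives an equivalent matrix form of the same layer. The weighted adjacency matrix $A$ and the all-ones column vector $\mathbf{1}$ are introduced here only for this sketch; they are not part of the original formulation:

$$
H^{(l+1)}=\sigma\left(A\,H^{(l)}W^{(l)}+\mathbf{1}\,b^{(l)}\right),\qquad
A_{ij}=\begin{cases}e_{ij} & \text{if } j\in\mathcal{N}(i),\\ 0 & \text{otherwise.}\end{cases}
$$

Row $i$ of $A\,H^{(l)}W^{(l)}$ is $\sum_{j}A_{ij}\,h_j^{(l)}W^{(l)}=\sum_{j\in\mathcal{N}(i)}e_{ij}\,h_j^{(l)}W^{(l)}$, which recovers the per-node formula. It also shows that $H^{(l)}W^{(l)}$ can be computed once per node before any aggregation happens.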

First, you will use PyTorch Geometric (PyG) to implement this convolution layer by subclassing `MessagePassing`.

In [6]:
import torch
import torch.nn as nn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import scatter


class PyG_conv(MessagePassing):
    def __init__(self, in_channel, out_channel):
        # Default flow is "source_to_target": messages go from edge_index[0] to edge_index[1].
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        # All-ones initialization keeps the correctness check below deterministic.
        self.W = nn.Parameter(torch.ones((in_channel, out_channel)))
        self.b = nn.Parameter(torch.ones(out_channel))

    def forward(self, x, edge_index, edge_weight):
        num_nodes = x.shape[0]
        hw = x @ self.W  # h_j W for every node, computed once before message passing
        out = self.propagate(edge_index, hw=hw, edge_weight=edge_weight, num_nodes=num_nodes)
        return out + self.b  # sigma is the identity here

    def message(self, hw_j, edge_weight):
        # hw_j is hw gathered at each edge's source node j; scale it by e_ij.
        return edge_weight.view(-1, 1) * hw_j

    def aggregate(self, ehw, index, num_nodes):
        # index holds each edge's target node i (edge_index[1]); sum the messages per target.
        return scatter(ehw, index, dim=0, dim_size=num_nodes, reduce="sum")
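To make the `_j` suffix and the `index` argument less magical, here is a plain-PyTorch sketch of what `propagate` computes for this layer. It assumes PyG's default `flow="source_to_target"` (row 0 of `edge_index` is the source $j$, row 1 the target $i$) and takes $\sigma$ as the identity, as the class above does; the helper name `weighted_sum_conv` is hypothetical and not part of PyG.

```python
import torch


def weighted_sum_conv(x, edge_index, edge_weight, W, b):
    """Hypothetical helper: h_i' = b + sum_{j in N(i)} e_ij * (h_j W), sigma = identity."""
    hw = x @ W                                    # per-node linear transform, (num_nodes, out)
    src, dst = edge_index                         # row 0: source j, row 1: target i
    messages = edge_weight.view(-1, 1) * hw[src]  # gather source rows (PyG's hw_j) and weight them
    out = torch.zeros(x.size(0), W.size(1), dtype=hw.dtype)
    out.index_add_(0, dst, messages)              # scatter-sum each message into its target row
    return out + b


edge_index = torch.tensor([[0, 1, 1, 2, 2, 4], [2, 0, 2, 3, 4, 3]])
x = torch.ones((5, 8))
edge_weight = 2 * torch.ones(6)
out = weighted_sum_conv(x, edge_index, edge_weight, torch.ones((8, 4)), torch.ones(4))
print(out[:, 0].tolist())  # [17.0, 1.0, 33.0, 33.0, 17.0]
```

The gather (`hw[src]`) and scatter (`index_add_` over `dst`) steps correspond to `message` and `aggregate` in the `MessagePassing` subclass; applying `W` before the gather means the matrix product runs once per node instead of once per edge.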

Run the following code to check correctness.

In [9]:
import numpy as np

edge_index = torch.tensor([[0, 1, 1, 2, 2, 4], [2, 0, 2, 3, 4, 3]])
x = torch.ones((5, 8))
edge_weight = 2 * torch.ones(6)
conv = PyG_conv(8, 4)
output = conv(x, edge_index, edge_weight)
print(output)

tensor([[17., 17., 17., 17.],
        [ 1.,  1.,  1.,  1.],
        [33., 33., 33., 33.],
        [33., 33., 33., 33.],
        [17., 17., 17., 17.]], grad_fn=<AddBackward0>)
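The expected values can be derived by hand. With `x`, `W` and `b` all ones and every edge weight equal to 2, each message $e_{ij}h_j W$ is $2 \cdot 8 = 16$ in every column, so output row $i$ is $1 + 16 \cdot \text{indegree}(i)$. The pure-Python sketch below (no PyG required; the helper `degrees` is hypothetical) reproduces the column values and also shows what a reversed message direction would produce.

```python
# Hand check of the expected output (pure Python, no PyG required).
src = [0, 1, 1, 2, 2, 4]  # edge_index[0]: source nodes j
dst = [2, 0, 2, 3, 4, 3]  # edge_index[1]: target nodes i
num_nodes, in_features, edge_w, bias = 5, 8, 2.0, 1.0


def degrees(nodes):
    """Hypothetical helper: count how many edges list each node id."""
    counts = [0] * num_nodes
    for n in nodes:
        counts[n] += 1
    return counts


# Each incoming edge contributes edge_w * in_features = 16 per column.
in_degree = degrees(dst)
expected = [bias + edge_w * in_features * d for d in in_degree]
print(in_degree)  # [1, 0, 2, 2, 1]
print(expected)   # [17.0, 1.0, 33.0, 33.0, 17.0]

# Summing at the source node instead (reversed flow) gives a different column.
reversed_flow = [bias + edge_w * in_features * d for d in degrees(src)]
print(reversed_flow)  # [17.0, 33.0, 33.0, 1.0, 17.0]
```

Because the two orderings disagree on nodes 1 and 3, the assertion in the next cell would catch a layer that aggregated at `edge_index[0]` instead of `edge_index[1]`.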


In [17]:
assert np.allclose(
    output.detach().numpy(),
    [
        [17.0, 17.0, 17.0, 17.0],
        [1.0, 1.0, 1.0, 1.0],
        [33.0, 33.0, 33.0, 33.0],
        [33.0, 33.0, 33.0, 33.0],
        [17.0, 17.0, 17.0, 17.0],
    ],
), "Output does not match the expected values"

The assertion passes: the output matches exactly.