In [2]:
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn


$$
  h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{e_{ji}}{c_{ji}}h_j^{(l)}W^{(l)})
$$

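$$
\text{where}\quad c_{ji} = \sqrt{|\mathcal{N}(j)|}\sqrt{|\mathcal{N}(i)|}
$$

其中 $e_{ji}$ 是边 $j \to i$ 上的权重（代码中的 `edge_weight`）；$|\mathcal{N}(j)|$ 取源节点 $j$ 的出度，$|\mathcal{N}(i)|$ 取目标节点 $i$ 的入度，分别对应代码中的 `g.out_degrees()` 和 `g.in_degrees()`。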


In [3]:
# Toy directed graph given as parallel (src, dst) edge lists: node ids 0..4, 6 edges.
src = torch.tensor([0, 1, 1, 2, 2, 4])
dst = torch.tensor([2, 0, 2, 3, 4, 3])
g = dgl.graph((src, dst))

# All-ones 8-dim node features and a constant weight of 2 on every edge.
# Sizes are taken from the graph itself so they always match num_nodes / num_edges.
node_feat = torch.ones((g.num_nodes(), 8))
edge_weight = 2 * torch.ones(g.num_edges())

g.ndata["node_feat"] = node_feat
g.edata["edge_weight"] = edge_weight

# Projection matrix W (in_feat=8 -> out_feat=4); all ones so results are easy to check by hand.
weight = torch.ones(8, 4)
g



Graph(num_nodes=5, num_edges=6,
      ndata_schemes={'node_feat': Scheme(shape=(8,), dtype=torch.float32)}
      edata_schemes={'edge_weight': Scheme(shape=(), dtype=torch.float32)})

### 先W再e

In [12]:
# Variant 1: project with W first, then scale by the edge weight e during message passing.
feat_src, feat_dst = dgl.utils.expand_as_pair(node_feat, g)

# Symmetric normalisation: sources are scaled by out_degree^-1/2 before message passing,
# destinations by in_degree^-1/2 after aggregation. clamp(min=1) keeps zero-degree nodes
# from producing 0^-1/2.
norm_src = torch.pow(g.out_degrees().clamp(min=1), -0.5).unsqueeze(-1)
norm_dst = torch.pow(g.in_degrees().clamp(min=1), -0.5).unsqueeze(-1)

feat_src = feat_src * norm_src

with g.local_scope():
    # Shrink to out_feat before update_all so every message is already projected.
    feat_src = feat_src @ weight
    g.srcdata["u"] = feat_src
    # message: h_j * e_ji ; reduce: sum over incoming edges of each destination node
    g.update_all(fn.u_mul_e("u", "edge_weight", "ue"), fn.sum("ue", "ueN"))
    ueN = g.dstdata["ueN"]
    rst = norm_dst * ueN
    print(rst)

tensor([[11.3137, 11.3137, 11.3137, 11.3137],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [19.3137, 19.3137, 19.3137, 19.3137],
        [19.3137, 19.3137, 19.3137, 19.3137],
        [11.3137, 11.3137, 11.3137, 11.3137]])


### 先e再W

In [13]:
# Variant 2: aggregate the edge-weighted messages first, then project with W.
feat_src, feat_dst = dgl.utils.expand_as_pair(node_feat, g)

# Degrees as column vectors; clamp(min=1) guards nodes with no out-/in-edges.
out_deg = g.out_degrees().clamp(min=1).view(-1, 1)
in_deg = g.in_degrees().clamp(min=1).view(-1, 1)
norm_src = torch.pow(out_deg, -0.5)
norm_dst = torch.pow(in_deg, -0.5)

feat_src = feat_src * norm_src

with g.local_scope():
    g.srcdata["u"] = feat_src
    g.update_all(
        message_func=fn.u_mul_e("u", "edge_weight", "ue"),
        reduce_func=fn.sum("ue", "ueN"),
    )
    # W is applied only after summing over neighbours; destination norm comes last.
    ueN = torch.matmul(g.dstdata["ueN"], weight)
    rst = ueN * norm_dst
    print(rst)

tensor([[11.3137, 11.3137, 11.3137, 11.3137],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [19.3137, 19.3137, 19.3137, 19.3137],
        [19.3137, 19.3137, 19.3137, 19.3137],
        [11.3137, 11.3137, 11.3137, 11.3137]])



先乘 W 再乘 e 与先乘 e 再乘 W 的结果是一样的：W 作用在特征维度上，而 $e_{ji}$ 是每条边上的一个标量，所以两者的乘法顺序可以交换。

如果 out_feat 比 in_feat 小，可以先乘 W 把特征维度缩小，从而减小消息传递（update_all）的计算和内存开销。

边的个数和节点个数不一定相同，所以节点特征 h 和边权 e 不能直接逐元素相乘；edge_weight 的乘法只能放在 message_func 里（例如 `fn.u_mul_e`）。




In [15]:
# Variant 3: apply BOTH norms to the node features before update_all.
# This ordering does NOT reproduce the GCN formula (see the note below).
with g.local_scope():
    feat = g.ndata["node_feat"]
    # Per-node degree norms as column vectors; clamp(min=1) guards zero-degree nodes.
    norm_src = torch.pow(g.out_degrees().clamp(min=1).view(-1, 1), -0.5)
    norm_dst = torch.pow(g.in_degrees().clamp(min=1).view(-1, 1), -0.5)

    # NOTE: norm_dst is indexed by node id, so multiplying the *source* features by it
    # scales each message by in_degree(j)^-1/2 of the sender j, not of the receiver i.
    # That is why this output differs from the 先W再e / 先e再W results, which apply
    # norm_dst after aggregation, per destination node.
    feat = norm_src * feat * norm_dst
    feat = feat @ weight

    g.srcdata["h"] = feat
    g.update_all(fn.u_mul_e("h", "edge_weight", "e"), fn.sum("e", "rst"))
    print(g.dstdata["rst"])

tensor([[11.3137, 11.3137, 11.3137, 11.3137],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [27.3137, 27.3137, 27.3137, 27.3137],
        [24.0000, 24.0000, 24.0000, 24.0000],
        [ 8.0000,  8.0000,  8.0000,  8.0000]])


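GraphConv 源码中：

- norm='left'：在 update 之前，源节点特征 feat 除以源节点的出度；
- norm='right'：在 update 之后（聚合之后），结果再除以目标节点的入度；
- norm='both'（默认）：update 前乘以源节点出度的 $-1/2$ 次方，聚合后再乘以目标节点入度的 $-1/2$ 次方，即上面的对称归一化 $1/c_{ji}$。

因此入度的归一化必须在聚合之后按目标节点进行。像 In [15] 那样在 update 前就对源特征乘 norm_dst，用的是源节点自己的入度，所以结果不对。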


### 广播机制

In [9]:
# Seed the RNG so re-running the notebook reproduces the printed tensors.
torch.manual_seed(0)

w = torch.rand(5, 4)
# l and r are (5, 1) column vectors: each broadcasts a per-row scalar across w's columns,
# just like norm_src / norm_dst above.
l = torch.rand(5).reshape(-1, 1)
r = torch.rand(5).reshape(-1, 1)

products = [l * w * r, l * r * w, r * w * l, w * l * r, w * r * l]
for p in products:
    print(p)

# Every ordering gives the same result. Use allclose rather than exact equality because
# float multiplication is not associative, so the last bits may differ between orderings.
assert all(torch.allclose(products[0], p) for p in products)

tensor([[0.3448, 0.1653, 0.1128, 0.3524],
        [0.0084, 0.0015, 0.0096, 0.0123],
        [0.2084, 0.2699, 0.1647, 0.0399],
        [0.0158, 0.0080, 0.0060, 0.0206],
        [0.0060, 0.0685, 0.0762, 0.0077]])
tensor([[0.3448, 0.1653, 0.1128, 0.3524],
        [0.0084, 0.0015, 0.0096, 0.0123],
        [0.2084, 0.2699, 0.1647, 0.0399],
        [0.0158, 0.0080, 0.0060, 0.0206],
        [0.0060, 0.0685, 0.0762, 0.0077]])
tensor([[0.3448, 0.1653, 0.1128, 0.3524],
        [0.0084, 0.0015, 0.0096, 0.0123],
        [0.2084, 0.2699, 0.1647, 0.0399],
        [0.0158, 0.0080, 0.0060, 0.0206],
        [0.0060, 0.0685, 0.0762, 0.0077]])
tensor([[0.3448, 0.1653, 0.1128, 0.3524],
        [0.0084, 0.0015, 0.0096, 0.0123],
        [0.2084, 0.2699, 0.1647, 0.0399],
        [0.0158, 0.0080, 0.0060, 0.0206],
        [0.0060, 0.0685, 0.0762, 0.0077]])
tensor([[0.3448, 0.1653, 0.1128, 0.3524],
        [0.0084, 0.0015, 0.0096, 0.0123],
        [0.2084, 0.2699, 0.1647, 0.0399],
        [0.0158, 0.0080, 0.006