In [1]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn

### 图

![image.png](attachment:ce26b6b7-0ec4-4575-9f52-2e31d49352b6.png)

In [2]:
# Node features: one 3-dimensional float32 row per node
# (5 users, 5 items, 2 firms).
user_feat = torch.tensor(
    [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]], dtype=torch.float32
)
item_feat = torch.tensor(
    [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]], dtype=torch.float32
)
firm_feat = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.float32)

# Edge features attached to the "follow", "like" and "belong" relations.
follow_feat = torch.ones(5)  # shape [5]
like_feat = torch.arange(6)  # shape [6], integer valued
belong_feat = torch.eye(5)  # shape [5, 5]

# Edge lists: row 0 holds the source node ids, row 1 the destination node ids.
follow_edge = [[0, 0, 1, 2, 2], [1, 2, 2, 3, 4]]
like_edge = [[0, 0, 0, 2, 3, 4], [0, 1, 2, 2, 3, 2]]
belong_edge = [[0, 1, 2, 3, 4], [1, 1, 0, 1, 0]]
fit_edge = [[0, 1, 3, 4], [1, 0, 4, 3]]

# Integer (int64) labels, one per node.
firm_label = torch.tensor([0, 1], dtype=torch.int64)
item_label = torch.tensor([1, 1, 0, 1, 0], dtype=torch.int64)
user_label = torch.tensor([3, 0, 1, 1, 1], dtype=torch.int64)

In [3]:
# Assemble the heterogeneous graph: each (src type, relation, dst type) key maps
# to a (source ids, destination ids) pair taken from the matching edge list.
relation_edges = {
    ("item", "belong", "firm"): belong_edge,
    ("user", "like", "item"): like_edge,
    ("item", "fit", "item"): fit_edge,
    ("user", "follow", "user"): follow_edge,
}
hg = dgl.heterograph(
    {relation: (edge[0], edge[1]) for relation, edge in relation_edges.items()}
)

# Attach the feature matrix and the labels to every node type.
for ntype, ntype_feat, ntype_label in (
    ("user", user_feat, user_label),
    ("item", item_feat, item_label),
    ("firm", firm_feat, firm_label),
):
    hg.nodes[ntype].data["feat"] = ntype_feat
    hg.nodes[ntype].data["label"] = ntype_label

# Attach edge features ("fit" edges carry none).
for edge_name, edge_feat in (
    ("follow", follow_feat),
    ("like", like_feat),
    ("belong", belong_feat),
):
    hg.edges[edge_name].data["feat"] = edge_feat

# Adding self-loops requires naming the edge type; fill_data selects how the
# features of the new edges are filled ("sum", "mean", ...).
hg = hg.add_self_loop(etype=("user", "follow", "user"), fill_data="sum")

mods针对不同边类型定义模型
```python
 mods : dict[str, nn.Module]
        Modules associated with every edge types. The forward function of each
        module must have a `DGLGraph` object as the first argument, and
        its second argument is either a tensor object representing the node
        features or a pair of tensor object representing the source and destination
        node features.
```

### 使用现成的层

In [4]:
# With weight=False (together with norm="none" and bias=False) each GraphConv
# reduces to a plain sum of the incoming source-node features.
graphconv_kwargs = dict(
    norm="none", bias=False, weight=False, allow_zero_in_degree=True
)

# out_feats passed to each relation's layer (kept exactly as originally written).
relation_out_feats = {
    ("user", "follow", "user"): 1,
    ("user", "like", "item"): 2,
    ("item", "fit", "item"): 2,
    ("item", "belong", "firm"): 3,
}

mods = {
    relation: dgl.nn.GraphConv(3, n_out, **graphconv_kwargs)
    for relation, n_out in relation_out_feats.items()
}

根据边关系来传递消息

In [5]:
# Per destination node type, the list of outputs produced by every relation
# that ends in that type.
outputs = {nty: [] for nty in hg.dsttypes}

with torch.no_grad():
    for relation in hg.canonical_etypes:
        srctype, etype, dsttype = relation
        rel_mod = mods[relation]
        print("in", relation)

        # Run this relation's module on its own sub-graph, feeding it the
        # source-node features.
        out = rel_mod(hg[relation], hg.nodes[srctype].data["feat"])
        print(out)
        # Append instead of assigning: a destination type can be reached by
        # several relations ("item" receives both "fit" and "like"), and a plain
        # assignment would keep only the last relation's result.
        outputs[dsttype].append(out)

in ('item', 'belong', 'firm')
tensor([[6., 6., 6.],
        [4., 4., 4.]])
in ('item', 'fit', 'item')
tensor([[1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [4., 4., 4.],
        [3., 3., 3.]])
in ('user', 'follow', 'user')
tensor([[0., 0., 0.],
        [1., 1., 1.],
        [3., 3., 3.],
        [5., 5., 5.],
        [6., 6., 6.]])
in ('user', 'like', 'item')
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [6., 6., 6.],
        [3., 3., 3.],
        [0., 0., 0.]])


### 自定义层、fn

In [6]:
# Relation (edge type) names; one Linear layer is created per name further down.
etypes = [
    "follow",
    "like",
    "fit",
    "belong",
]

In [7]:
# Input / output widths of the per-relation Linear projections.
in_feats, out_feats = 3, 4

In [8]:
# 定义每个关系类型的权重
weight = nn.ModuleDict(
    {name: nn.Linear(in_feats, out_feats, bias=False) for name in etypes}
)

nn.init.ones_(weight.belong.weight)
nn.init.ones_(weight.like.weight)
nn.init.ones_(weight.follow.weight)
nn.init.ones_(weight.fit.weight)


print(weight)

ModuleDict(
  (follow): Linear(in_features=3, out_features=4, bias=False)
  (like): Linear(in_features=3, out_features=4, bias=False)
  (fit): Linear(in_features=3, out_features=4, bias=False)
  (belong): Linear(in_features=3, out_features=4, bias=False)
)


In [9]:
# For every relation: project the source-node features with that relation's
# Linear layer, store the projection on the source nodes so it can be used for
# message passing, and register "copy source feature -> sum" for the relation.
funcs = {}
for srctype, etype, dsttype in hg.canonical_etypes:
    wh_key = "Wh_" + etype
    hg.srcnodes[srctype].data[wh_key] = weight[etype](hg.nodes[srctype].data["feat"])
    funcs[etype] = (fn.copy_u(wh_key, "m"), fn.sum("m", "h"))

# The second argument is the cross-relation reducer; it can be one of
# "sum", "min", "max", "mean" or "stack". "stack" is used here.
hg.multi_update_all(funcs, "stack")

srcdata
```python
user Wh_follow
 tensor([[ 0.,  0.,  0.,  0.],
        [ 3.,  3.,  3.,  3.],
        [ 6.,  6.,  6.,  6.],
        [ 9.,  9.,  9.,  9.],
        [12., 12., 12., 12.]])
user Wh_like
 tensor([[ 0.,  0.,  0.,  0.],
        [ 3.,  3.,  3.,  3.],
        [ 6.,  6.,  6.,  6.],
        [ 9.,  9.,  9.,  9.],
        [12., 12., 12., 12.]])
item Wh_fit
 tensor([[ 0.,  0.,  0.,  0.],
        [ 3.,  3.,  3.,  3.],
        [ 6.,  6.,  6.,  6.],
        [ 9.,  9.,  9.,  9.],
        [12., 12., 12., 12.]])
```

multi_update_all 使用 sum 整合
```python
user    torch.Size([5, 4])
tensor([[ 0.,  0.,  0.,  0.],
        [ 3.,  3.,  3.,  3.],
        [ 9.,  9.,  9.,  9.],
        [15., 15., 15., 15.],
        [18., 18., 18., 18.]])
item    torch.Size([5, 4])
tensor([[ 3.,  3.,  3.,  3.],
        [ 0.,  0.,  0.,  0.],
        [18., 18., 18., 18.],
        [21., 21., 21., 21.],
        [ 9.,  9.,  9.,  9.]])
firm    torch.Size([2, 4])
tensor([[18., 18., 18., 18.],
        [12., 12., 12., 12.]])
```

multi_update_all 使用 stack 整合：item 能接收两种节点的信息（来自 item 的 fit 和来自 user 的 like），共五个节点，所以 shape 是 [5,2,4]；对结果做 .sum(dim=1) 就相当于 sum 整合
```python
user    torch.Size([5, 1, 4])
tensor([[[ 0.,  0.,  0.,  0.]],

        [[ 3.,  3.,  3.,  3.]],

        [[ 9.,  9.,  9.,  9.]],

        [[15., 15., 15., 15.]],

        [[18., 18., 18., 18.]]])
item    torch.Size([5, 2, 4])
tensor([[[ 3.,  3.,  3.,  3.],
         [ 0.,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.],
         [18., 18., 18., 18.]],

        [[12., 12., 12., 12.],
         [ 9.,  9.,  9.,  9.]],

        [[ 9.,  9.,  9.,  9.],
         [ 0.,  0.,  0.,  0.]]])
firm    torch.Size([2, 1, 4])
tensor([[[18., 18., 18., 18.]],

        [[12., 12., 12., 12.]]])
```

In [10]:
# Show the aggregated "h" feature of each node type: its shape, then its values.
with torch.no_grad():
    for ntype in ("user", "item", "firm"):
        h = hg.ndata["h"][ntype]
        print(ntype, "  ", h.shape)
        print(h.detach())

user    torch.Size([5, 1, 4])
tensor([[[ 0.,  0.,  0.,  0.]],

        [[ 3.,  3.,  3.,  3.]],

        [[ 9.,  9.,  9.,  9.]],

        [[15., 15., 15., 15.]],

        [[18., 18., 18., 18.]]])
item    torch.Size([5, 2, 4])
tensor([[[ 3.,  3.,  3.,  3.],
         [ 0.,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.],
         [ 0.,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.],
         [18., 18., 18., 18.]],

        [[12., 12., 12., 12.],
         [ 9.,  9.,  9.,  9.]],

        [[ 9.,  9.,  9.,  9.],
         [ 0.,  0.,  0.,  0.]]])
firm    torch.Size([2, 1, 4])
tensor([[[18., 18., 18., 18.]],

        [[12., 12., 12., 12.]]])


### 自定义message、reduce
#### Message

![image.png](attachment:18ec55a9-642a-4f4f-9978-55f31c03fd25.png)

In [67]:
def myMessage(edges: dgl.udf.EdgeBatch):
    """Build the message "m" for a batch of edges, depending on the edge type.

    * "belong" / "like": the message is the source node's projected feature.
    * "follow" / "fit": the message is the sum of the source and destination
      projected features.

    The edge type is printed each time the function runs. For any other edge
    type nothing is returned (implicit ``None``).
    """
    _, etype, _ = edges.canonical_etype
    wh_key = "Wh_%s" % etype

    if etype in ("belong", "like"):
        print(etype)
        return {"m": edges.src[wh_key]}

    if etype in ("follow", "fit"):
        print(etype)
        from_src = edges.src[wh_key]
        from_dst = edges.dst[wh_key]
        return {"m": from_src + from_dst}


def myReduce(nodes: dgl.udf.NodeBatch):
    """Reduce the mailbox by summing the received messages "m" along dim 1.

    Returns the result under the key "h".
    """
    received = nodes.mailbox['m']
    return {'h': received.sum(1)}


# Recompute the per-relation projections, store them on the source nodes for
# message passing, and pair the custom message function with the built-in sum
# reducer for every relation.
funcs = {}
for srctype, etype, dsttype in hg.canonical_etypes:
    wh_key = "Wh_" + etype
    hg.srcnodes[srctype].data[wh_key] = weight[etype](hg.nodes[srctype].data["feat"])
    funcs[etype] = (myMessage, fn.sum("m", "h"))

# Results of the different relations are stacked per destination node type.
hg.multi_update_all(funcs, "stack")

belong
fit
follow
like


In [68]:
# Inspect the projected features stored on the nodes.
user_wh_like = hg.nodes["user"].data["Wh_like"]
item_wh_fit = hg.nodes["item"].data["Wh_fit"]

print("user Wh_like\n", user_wh_like.detach())
print("item Wh_fit\n", item_wh_fit.detach())

user Wh_like
 tensor([[ 0.,  0.,  0.,  0.],
        [ 3.,  3.,  3.,  3.],
        [ 6.,  6.,  6.,  6.],
        [ 9.,  9.,  9.,  9.],
        [12., 12., 12., 12.]])
item Wh_fit
 tensor([[ 0.,  0.,  0.,  0.],
        [ 3.,  3.,  3.,  3.],
        [ 6.,  6.,  6.,  6.],
        [ 9.,  9.,  9.,  9.],
        [12., 12., 12., 12.]])


In [69]:
# With the "stack" reducer the item nodes hold one slot per incoming relation;
# print each slot (index along dim 1) under its caption.
item_h = hg.nodes["item"].data["h"]
slot_captions = (
    'hg.nodes["item"].data["h"][:,0,:]， 即 fit 的 message策略，两端相加',
    'hg.nodes["item"].data["h"][:,1,:]， 即 like 的 message 策略，从源点相加',
)

print('stack')
for slot, caption in enumerate(slot_captions):
    print(caption)
    print(item_h[:, slot, :].detach())

stack
hg.nodes["item"].data["h"][:,0,:]， 即 fit 的 message策略，两端相加
tensor([[ 3.,  3.,  3.,  3.],
        [ 3.,  3.,  3.,  3.],
        [ 0.,  0.,  0.,  0.],
        [21., 21., 21., 21.],
        [21., 21., 21., 21.]])
hg.nodes["item"].data["h"][:,1,:]， 即 like 的 message 策略，从源点相加
tensor([[ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [18., 18., 18., 18.],
        [ 9.,  9.,  9.,  9.],
        [ 0.,  0.,  0.,  0.]])


In [70]:
# Print shape and values of the stacked "h" feature for every node type.
print("stack")
for ntype in ("user", "item", "firm"):
    h = hg.nodes[ntype].data["h"]
    print(ntype + '\n', h.shape, ' \n', h.detach())

stack
user
 torch.Size([5, 1, 4])  
 tensor([[[ 0.,  0.,  0.,  0.]],

        [[ 9.,  9.,  9.,  9.]],

        [[27., 27., 27., 27.]],

        [[33., 33., 33., 33.]],

        [[42., 42., 42., 42.]]])
item
 torch.Size([5, 2, 4])  
 tensor([[[ 3.,  3.,  3.,  3.],
         [ 0.,  0.,  0.,  0.]],

        [[ 3.,  3.,  3.,  3.],
         [ 0.,  0.,  0.,  0.]],

        [[ 0.,  0.,  0.,  0.],
         [18., 18., 18., 18.]],

        [[21., 21., 21., 21.],
         [ 9.,  9.,  9.,  9.]],

        [[21., 21., 21., 21.],
         [ 0.,  0.,  0.,  0.]]])
firm
 torch.Size([2, 1, 4])  
 tensor([[[18., 18., 18., 18.]],

        [[12., 12., 12., 12.]]])


#### Reduce

In [87]:
def myMessage(edges: dgl.udf.EdgeBatch):
    """Build the message "m" for a batch of edges, depending on the edge type.

    * "belong" / "like": the message is the source node's projected feature.
    * "follow" / "fit": the message is the sum of the source and destination
      projected features.

    A separator line and the edge type are printed each time the function runs.
    For any other edge type nothing is printed or returned (implicit ``None``).
    """
    _, etype, _ = edges.canonical_etype
    wh_key = "Wh_%s" % etype
    source_only = etype in ("belong", "like")
    both_ends = etype in ("follow", "fit")

    if source_only or both_ends:
        print('======================================')
        print(etype)

    if source_only:
        return {"m": edges.src[wh_key]}

    if both_ends:
        from_src = edges.src[wh_key]
        from_dst = edges.dst[wh_key]
        return {"m": from_src + from_dst}


def myReduce(nodes: dgl.udf.NodeBatch):
    """Sum the received messages "m" along dim 1 and return them as "h".

    Also prints the node type, node ids and batch size of the current batch,
    the mailbox shape and the summed result.
    """
    print('node type:',nodes.ntype,'   ids:',nodes.nodes().numpy(),'   batch size:',nodes.batch_size())
    received = nodes.mailbox['m']
    print('shape:',received.shape)
    summed = received.sum(1)
    print('sum(1):\n',summed.detach())
    print('--------------------------------')
    return {'h':summed}


# Recompute the per-relation projections, store them on the source nodes for
# message passing, and use the custom message and reduce functions for every
# relation.
funcs = {}
for srctype, etype, dsttype in hg.canonical_etypes:
    wh_key = "Wh_" + etype
    hg.srcnodes[srctype].data[wh_key] = weight[etype](hg.nodes[srctype].data["feat"])
    funcs[etype] = (myMessage, myReduce)

# Results of the different relations are summed per destination node type.
hg.multi_update_all(funcs, "sum")

belong
node type: firm    ids: [0]    batch size: 1
shape: torch.Size([1, 2, 4])
sum(1):
 tensor([[18., 18., 18., 18.]])
--------------------------------
node type: firm    ids: [1]    batch size: 1
shape: torch.Size([1, 3, 4])
sum(1):
 tensor([[12., 12., 12., 12.]])
--------------------------------
fit
node type: item    ids: [0 1 3 4]    batch size: 4
shape: torch.Size([4, 1, 4])
sum(1):
 tensor([[ 3.,  3.,  3.,  3.],
        [ 3.,  3.,  3.,  3.],
        [21., 21., 21., 21.],
        [21., 21., 21., 21.]])
--------------------------------
follow
node type: user    ids: [0]    batch size: 1
shape: torch.Size([1, 1, 4])
sum(1):
 tensor([[0., 0., 0., 0.]])
--------------------------------
node type: user    ids: [1 3 4]    batch size: 3
shape: torch.Size([3, 2, 4])
sum(1):
 tensor([[ 9.,  9.,  9.,  9.],
        [33., 33., 33., 33.],
        [42., 42., 42., 42.]])
--------------------------------
node type: user    ids: [2]    batch size: 1
shape: torch.Size([1, 3, 4])
sum(1):
 tensor([

In [83]:
# Print the "h" feature of every node type after the "sum" cross-relation reduce.
print("sum")
for ntype in ("user", "item", "firm"):
    print(hg.nodes[ntype].data["h"].detach())


sum
tensor([[ 0.,  0.,  0.,  0.],
        [ 9.,  9.,  9.,  9.],
        [27., 27., 27., 27.],
        [33., 33., 33., 33.],
        [42., 42., 42., 42.]])
tensor([[ 3.,  3.,  3.,  3.],
        [ 3.,  3.,  3.,  3.],
        [18., 18., 18., 18.],
        [30., 30., 30., 30.],
        [21., 21., 21., 21.]])
tensor([[18., 18., 18., 18.],
        [12., 12., 12., 12.]])
