# Message Passing

This chapter introduces how to build message passing and GNN layers in DGL.

As before, we first check the CUDA version and install the DGL library.

In [1]:
!nvidia-smi

Thu Jul 13 19:37:48 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   54C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install  dgl -q -f https://data.dgl.ai/wheels/cu118/repo.html
!pip install  dglgo -q -f https://data.dgl.ai/wheels-test/repo.html


In [3]:
import os
os.environ["DGLBACKEND"] = "pytorch"

import dgl
import torch
import torch.nn.functional as F

## What is Message Passing?

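Message passing is a mechanism for propagating information through a neural network: it lets different parts of the input share information, and similar ideas appear frequently in CV and NLP. Message passing consists of two main steps: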

1. Compute messages from the surrounding (neighboring) nodes
2. Aggregate the computed messages

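In CV, a CNN layer can also be viewed as message passing. For each pixel, a 3×3 convolution kernel computes a weighted sum of the features of that pixel and its eight neighbors, using the filter weights, and uses the result as the pixel's new feature. The new feature therefore aggregates information from all of the pixel's neighbors.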

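The same idea works on graphs: we aggregate the information (feature vectors) of a node's neighbors. The difference is that on a graph the aggregation is usually a mean, sum, or max rather than a learned, position-specific weighting. So **message passing on a graph can be viewed as a CNN-style operation** in which every neighbor is treated with the same weight.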

![](https://hackmd.io/_uploads/HJR15lSSh.png)
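To make the analogy concrete, here is a small sketch in plain PyTorch (not DGL): each pixel of a 4×4 image is treated as a node that receives messages from every pixel in its 3×3 window, including itself. Summing the copied neighbor features gives exactly the output of a 3×3 convolution whose weights are all 1, i.e., a convolution in which every neighbor has the same weight. The image size and the all-ones kernel are illustrative choices.

```python
import torch
import torch.nn.functional as F

H, W = 4, 4
x = torch.arange(H * W, dtype=torch.float32)  # one scalar feature per pixel (node)

# Edge list: pixel (i, j) receives a message from every pixel in its 3x3 window
src, dst = [], []
for i in range(H):
    for j in range(W):
        for di in (-1, 0, 1):
            for dj in (-1, 0, 1):
                ni, nj = i + di, j + dj
                if 0 <= ni < H and 0 <= nj < W:
                    src.append(ni * W + nj)
                    dst.append(i * W + j)
src, dst = torch.tensor(src), torch.tensor(dst)

# Message passing: copy each source feature, then sum at the destination
h_graph = torch.zeros(H * W).index_add_(0, dst, x[src])

# 3x3 convolution in which every neighbor has the same weight (1)
kernel = torch.ones(1, 1, 3, 3)
h_conv = F.conv2d(x.view(1, 1, H, W), kernel, padding=1).view(-1)

print(torch.allclose(h_graph, h_conv))  # True
```

Zero padding in the convolution plays the role of the missing neighbors at the image border.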

### Graph Embeddings

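Before moving on to message passing proper, let's discuss how to compute graph embeddings. Recall that we store node, edge, and graph information as matrices and store the graph's connectivity as adjacency lists. Naturally, we can apply parameter matrices to this information to obtain embeddings.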

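We discuss GNNs from three aspects: nodes, edges, and the whole graph. Suppose that at layer $t$ the parameter matrices for nodes, edges, and the graph are $(W^{(t)}_v, W^{(t)}_e, W^{(t)}_g)$ and the embedding vectors are $(v^{(t)}, e^{(t)}, g^{(t)})$. Then a simple linear transformation gives the embeddings at the next layer: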

$$
\begin{align*}
v^{(t+1)}&=W^{(t)}_vv^{(t)} \\
e^{(t+1)}&=W^{(t)}_ee^{(t)} \\
g^{(t+1)}&=W^{(t)}_gg^{(t)} \\
\end{align*}
$$

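where $(W^{(t)}_v, W^{(t)}_e, W^{(t)}_g)$ are learned by gradient descent. This is the most basic GNN module: it produces latent representations of the graph, much like a CNN backbone extracts image features in CV or a BERT backbone extracts text features in NLP. Once a GNN layer has extracted features from the graph, we can use them for the **downstream tasks** we are interested in.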
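A minimal PyTorch sketch of this per-layer update, assuming hypothetical sizes for the node, edge, and graph embeddings. `nn.Linear(..., bias=False)` computes $Wx$ for each row, so every node, edge, and graph vector is transformed on its own; the last lines check that changing one node's input leaves the other nodes' outputs untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, num_edges = 5, 7
d_v, d_e, d_g = 8, 4, 16             # hypothetical embedding sizes

V = torch.randn(num_nodes, d_v)      # v^(t), one row per node
E = torch.randn(num_edges, d_e)      # e^(t), one row per edge
G = torch.randn(1, d_g)              # g^(t), one vector for the whole graph

# W_v^(t), W_e^(t), W_g^(t): learnable matrices, no bias to match the formula
W_v = nn.Linear(d_v, d_v, bias=False)
W_e = nn.Linear(d_e, d_e, bias=False)
W_g = nn.Linear(d_g, d_g, bias=False)

V_next, E_next, G_next = W_v(V), W_e(E), W_g(G)

# No information sharing: perturbing node 0 leaves every other node's output unchanged
V_perturbed = V.clone()
V_perturbed[0] += 1.0
print(torch.allclose(W_v(V_perturbed)[1:], V_next[1:]))  # True
```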

#### Information sharing

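Notice that the discussion so far makes no use of the graph structure at all: $(v^{(t)}, e^{(t)}, g^{(t)})$ never exchange any information. How can they interact? This is where message passing comes in. Suppose we want the embedding $v^{(t+1)}$. We take the nodes adjacent to $v$, denoted $\mathcal{N}(v)$, combine their information with an aggregation function $\rho$, and then update the original $v^{(t)}$ with an update function $\psi$: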

$$
\begin{align*}
m_v^{(t+1)}&=\{W^{(t)}_v u^{(t)} \mid u\in \mathcal{N}(v)\}\\
v^{(t+1)}&=\psi(v^{(t)}, \rho(m_v^{(t+1)}))
\end{align*}
$$

$\rho$ is usually a mean, sum, or max. For convenience, we denote this whole process by $\rho_{V\to V}$ (node information aggregated onto nodes).
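Below is a plain-PyTorch sketch of $\rho_{V\to V}$ on a hypothetical 4-node graph stored as two index tensors `src` and `dst` (edge $k$ goes from `src[k]` to `dst[k]`). It uses the mean as $\rho$; the choice $\psi(v, a) = \mathrm{ReLU}(W_{\text{self}}v + a)$ and the extra matrix `W_self` are assumptions for illustration, since the text leaves $\psi$ open.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 4
x = torch.randn(4, d)                       # v^(t) for 4 nodes
src = torch.tensor([1, 2, 3, 0, 2])         # hypothetical edges u -> v
dst = torch.tensor([0, 0, 0, 1, 3])         # node 2 has no incoming edge

W_v = nn.Linear(d, d, bias=False)           # W_v^(t)
W_self = nn.Linear(d, d, bias=False)        # hypothetical weight used inside psi

# m_v: the set {W_v u : u in N(v)}, one message per edge
msg = W_v(x)[src]

# rho = mean of the incoming messages (0 for nodes without neighbors)
summed = torch.zeros(x.shape[0], d).index_add_(0, dst, msg)
deg = torch.bincount(dst, minlength=x.shape[0]).clamp(min=1).unsqueeze(-1)
agg = summed / deg

# psi: combine the node's own embedding with the aggregated messages
x_next = torch.relu(W_self(x) + agg)
```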

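This formula may look complicated, but it is still a basic form of aggregation, and much more complex cases are possible. The example above aggregates a node's neighboring nodes; in the same way, we can aggregate information from edges and from the graph. That sounds easy, but what if the node embedding dimension differs from the edge embedding dimension?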

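We have met this problem many times before in CV and NLP, and the solution is simple. One option is vector concatenation; here we introduce another: add one more parameter matrix, denoted $W_{e\to v}^{(t)}$, that linearly projects the edge embeddings into the node embedding space: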

$$
\begin{align*}
\tilde{e}^{(t)}&=W_{e\to v}^{(t)}W^{(t)}_e e^{(t)} \\
m_v^{(t+1)}&=\{W^{(t)}_v u^{(t)} \mid u\in \mathcal{N}(v)\}\cup\{\tilde{e}^{(t)} \mid e\in\mathcal{N}(v)\}\\
v^{(t+1)}&=\psi(v^{(t)}, \rho(m_v^{(t+1)}))
\end{align*}
$$

Here $e\in\mathcal{N}(v)$ denotes the edges incident to $v$.

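We denote this whole process by $\rho_{E\to V}$, meaning that edge information is aggregated onto nodes. Likewise we can have $\rho_{V\to V}, \rho_{V\to E}, \rho_{V\to G}, \rho_{E\to E}, \rho_{E\to G}, \rho_{G\to V}, \rho_{G\to E}, \rho_{G\to G}$, and so on. Going further, a single GNN layer can contain several message-passing steps, possibly performed in a particular order, as shown below: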

![](https://hackmd.io/_uploads/ry2BmGBSh.png)

Information can be passed between nodes, edges, and the graph in any combination and order we choose.
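A sketch of $\rho_{E\to V}$ in plain PyTorch, assuming edges stored as `src`/`dst` index tensors, sum as $\rho$, and a hypothetical $\psi(v, a) = v + a$. The edge embeddings have 3 dimensions and the node embeddings 6, so `W_e2v`, playing the role of $W_{e\to v}^{(t)}$, projects the edge messages into the node space before they are pooled together with the node messages.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_v, d_e = 6, 3                             # node and edge embedding sizes differ
x = torch.randn(3, d_v)                     # v^(t)
w = torch.randn(2, d_e)                     # e^(t); edge 0: 0 -> 2, edge 1: 1 -> 2
src = torch.tensor([0, 1])
dst = torch.tensor([2, 2])

W_v = nn.Linear(d_v, d_v, bias=False)       # W_v^(t)
W_e = nn.Linear(d_e, d_e, bias=False)       # W_e^(t)
W_e2v = nn.Linear(d_e, d_v, bias=False)     # W_{e->v}^(t): edge space -> node space

node_msg = W_v(x)[src]                      # {W_v u : u in N(v)}
edge_msg = W_e2v(W_e(w))                    # {W_{e->v} W_e e : e incident to v}, now d_v wide
msgs = torch.cat([node_msg, edge_msg])      # union of both message sets
msg_dst = torch.cat([dst, dst])             # each edge delivers both kinds of message to its dst

# rho = sum; psi(v, a) = v + a is a hypothetical choice
agg = torch.zeros_like(x).index_add_(0, msg_dst, msgs)
x_next = x + agg
```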

## Message Passing in DGL

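Now let's put message passing into practice in DGL. Suppose we want the embedding of node $v$ at layer $t+1$, where the node features are $x_v\in \mathbb{R}^{d_1}$ and the feature of edge $(u, v)$ is $w_e\in\mathbb{R}^{d_2}$. DGL defines the message passing paradigm as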

$$
\begin{align*}
m_e^{(t+1)}&=\phi(x_v^{(t)}, x_u^{(t)}, w_e^{(t)}),\quad (u,v,e)\in\mathcal{E}\\
x_v^{(t+1)}&=\psi\left(x_v^{(t)},\rho\left(\{m_e^{(t+1)}:(u,v,e)\in\mathcal{E}\}\right)\right)
\end{align*}
$$

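In the formulas above:

- $\phi:$ the message function, defined on each edge; it combines the features of the edge's endpoints (and of the edge itself) to generate a message
  - its argument `edges` is an `EdgeBatch` instance
  - `edges` has three attributes, `src`, `dst`, and `data`, giving the source node features, destination node features, and edge features respectively
- $\rho:$ the reduce function, which aggregates the messages each node receives
  - its argument `nodes` is a `NodeBatch` instance
  - `nodes.mailbox` stores the messages each node has received
- $\psi:$ the update function, which combines the aggregated messages with the node's own features
  - its argument `nodes` is the same as above; it is applied after $\rho$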
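To see what the three functions operate on, the sketch below mimics them with plain tensors instead of DGL's `EdgeBatch`/`NodeBatch` objects: indexing the node features with the source and destination IDs gives what `edges.src[...]` and `edges.dst[...]` would hold for those edges. The message $m_e = x_u \odot w_e + x_v$ and the update $\psi(x_v, a) = x_v + a$ are hypothetical choices.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 2)                 # node features x
w = torch.randn(3, 2)                 # edge features w_e
src = torch.tensor([0, 1, 2])         # edges: 0 -> 3, 1 -> 3, 2 -> 0
dst = torch.tensor([3, 3, 0])

# What a message function sees for these edges (plain tensors instead of an EdgeBatch)
edges_src = x[src]                    # plays the role of edges.src['x']
edges_dst = x[dst]                    # plays the role of edges.dst['x']
edges_data = w                        # plays the role of edges.data['w']

# phi: a hypothetical message m_e = x_u * w_e + x_v
m = edges_src * edges_data + edges_dst

# rho: sum the messages arriving at each node (the mailbox reduction)
agg = torch.zeros_like(x).index_add_(0, dst, m)

# psi: combine with the node's own feature (hypothetical choice)
x_next = x + agg
```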

We can of course define our own $\phi$, for example:

In [4]:
def message_func(edges):
    return {'m': edges.src['hu'] + edges.dst['hv']}

This function adds each edge's source feature `hu` and destination feature `hv` and stores the result as the message `m`.

Similarly, we can define $\rho$, here summing the messages in each node's mailbox:

In [5]:
def reduce_func(nodes):
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

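With these two functions we can write the whole message passing step using the `update_all` API. DGL does not recommend passing a custom $\psi$ to `update_all`; instead, apply the update as ordinary tensor operations after the call: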

In [6]:
import dgl.function as fn
def update_all_example(graph):
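    # Assumes graph.ndata['hu'] and graph.ndata['hv'] are set; the result is stored in graph.ndata['h']
    graph.update_all(message_func, reduce_func)
    # Call the update function (ψ) outside update_all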
    final_ft = graph.ndata['h'] * 2
    return final_ft

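After `update_all` finishes, DGL discards the intermediate messages `m` instead of storing them on the graph. If we do not need full message passing and only want to compute new edge features (without aggregating them onto nodes), we can use `apply_edges(func)`, where `func` is a $\phi$-style function describing how to compute the feature of each edge. By default this API updates all edges. For example: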

In [7]:
# g.apply_edges(fn.u_add_v('el', 'er', 'e'))
# It is equivalent to
# def message_func(edges):
#      return {'e': edges.src['el'] + edges.dst['er']}

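In the names of DGL's built-in functions:

- `u:` the source node (src)
- `v:` the destination node (dst)
- `e:` the edge (src, dst)

`u_add_v` adds the source node feature `el` to the destination node feature `er` and stores the result on the edge as `e`.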

Note: prefer built-in APIs of the `u_func_v` form (such as `u_add_v`); they are more efficient than equivalent user-defined functions.
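The built-in `u_add_v` amounts to a per-edge gather followed by an addition. A plain-PyTorch sketch of what `g.apply_edges(fn.u_add_v('el', 'er', 'e'))` computes on a hypothetical 3-edge graph:

```python
import torch

el = torch.tensor([[1.0], [2.0], [3.0]])     # feature 'el' on nodes 0..2
er = torch.tensor([[10.0], [20.0], [30.0]])  # feature 'er' on nodes 0..2
src = torch.tensor([0, 1, 2])                # edges: 0 -> 1, 1 -> 2, 2 -> 0
dst = torch.tensor([1, 2, 0])

# e[k] = el[src[k]] + er[dst[k]] for every edge k
e = el[src] + er[dst]
print(e.squeeze(-1))  # tensor([21., 32., 13.])
```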

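Let's look at a simple example to better understand DGL's message passing mechanism. With a user-defined reduce function, DGL groups the destination nodes by in-degree and calls the function once per group, so each call sees a mailbox whose second dimension equals that in-degree; nodes with no incoming edges are not passed to the reducer and end up with zeros.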

In [41]:
g = dgl.graph((
    [1, 3, 5, 0, 4, 2, 3, 3, 4, 5],
    [1, 1, 0, 0, 1, 2, 2, 0, 3, 3]
  ))
g.edata['eid'] = torch.arange(10)
def reducer(nodes):
    print(nodes.nodes())
    print(nodes.mailbox['eid'])
    return {'n': nodes.mailbox['eid'].sum(1)}
g.update_all(fn.copy_e('eid', 'eid'), reducer)
g.ndata

tensor([2, 3])
tensor([[5, 6],
        [8, 9]])
tensor([0, 1])
tensor([[2, 3, 7],
        [0, 1, 4]])


{'n': tensor([12,  5, 11, 17,  0,  0])}
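The output above can be reproduced with a small plain-PyTorch sketch of degree bucketing: group the destination nodes by in-degree, build a mailbox of shape (nodes in group, in-degree) for each group, and reduce it. Ordering the messages within a mailbox by edge ID is an assumption of this sketch that happens to match the printed mailboxes.

```python
import torch

src = torch.tensor([1, 3, 5, 0, 4, 2, 3, 3, 4, 5])
dst = torch.tensor([1, 1, 0, 0, 1, 2, 2, 0, 3, 3])
eid = torch.arange(10)
num_nodes = 6

in_deg = torch.bincount(dst, minlength=num_nodes)
n = torch.zeros(num_nodes, dtype=eid.dtype)       # zero-in-degree nodes keep 0

for d in sorted(set(in_deg.tolist()) - {0}):
    nodes = (in_deg == d).nonzero(as_tuple=True)[0]
    # Mailbox of shape (len(nodes), d): incoming edge IDs of each node, in edge-ID order
    mailbox = torch.stack([eid[dst == v] for v in nodes])
    print(nodes.tolist(), mailbox.tolist())
    n[nodes] = mailbox.sum(1)

print(n)  # tensor([12,  5, 11, 17,  0,  0])
```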

In [39]:
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['x'] = torch.ones(5, 2)
g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
g.ndata['h']

tensor([[0., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])
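`fn.copy_u` followed by `fn.sum` amounts to a gather over source nodes and a scatter-add over destination nodes. A plain-PyTorch sketch of the cell above, using `Tensor.index_add_`, gives the same result; node 0 has no incoming edge and therefore stays at zero.

```python
import torch

src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 4])
x = torch.ones(5, 2)

m = x[src]                                     # fn.copy_u('x', 'm'): message = source feature
h = torch.zeros_like(x).index_add_(0, dst, m)  # fn.sum('m', 'h'): add messages per destination
print(h)
```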

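In summary, DGL splits message passing into three steps:

1. Define $\phi$ (message function)
2. Define $\rho$ (reduce function)
3. Define $\psi$ (update function)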

In [9]:
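# Define ϕ: the message is the source node's feature
def message_func(edges):
    return {'msg': edges.src['feat']}

# Define ρ: sum the messages in each node's mailbox
def reduce_func(nodes):
    return {'agg_feat': torch.sum(nodes.mailbox['msg'], dim=1)}

# Define the message passing function
def message_passing(g, node_feats):
    # Store the node features as a node attribute of the graph
    g.ndata['feat'] = node_feats

    # Run message passing
    g.update_all(message_func, reduce_func)

    # ψ: here simply take the aggregated features (identity update)
    updated_feats = g.ndata['agg_feat']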

    return updated_feats


### Efficient way to pass messages

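We now know how to pass messages on a graph, but the approach above is often not the most efficient one.

Suppose we have the following task:

- $\phi(u,v): \text{cat}(u,v)$, i.e., concatenate the src and dst features
- compute only on the edges
- perform edge classification with 3 target classes

This task can be implemented as follows: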

In [10]:
import torch
import torch.nn as nn

node_feat_dim = 3
out_dim = 3
g = dgl.graph([(0,1), (1,2), (3,5)])
g.ndata["feat"] = torch.randn(6, 3)

# Weight applied to the concatenated (src, dst) features; randn avoids uninitialized memory
linear = nn.Parameter(torch.randn(node_feat_dim * 2, out_dim))
def concat_message_function(edges):
    return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] @ linear

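However, this is not efficient: it materializes a concatenated feature of size $2d$ for every edge and then performs the matrix multiplication once per edge. A better approach is to split the weight matrix: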

$$
W\times\text{cat}(u,v) = W_l\times u+W_r\times v
$$

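where $W_l$ and $W_r$ are the left and right halves of $W$, i.e., the columns that multiply $u$ and $v$ respectively. (The code uses row vectors, `x @ W`, so there they correspond to the top and bottom halves of the weight.)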

In [11]:
linear_src = nn.Parameter(torch.randn(node_feat_dim, out_dim))  # plays the role of W_l
linear_dst = nn.Parameter(torch.randn(node_feat_dim, out_dim))  # plays the role of W_r
out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

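The two approaches are mathematically equivalent when `linear_src` and `linear_dst` are the two halves of `linear`. The second one is more efficient: the matrix multiplications are computed per node rather than per edge, and no concatenated per-edge features are materialized, so it also uses less memory.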
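A quick numerical check of the equivalence in plain PyTorch, on hypothetical random features and the same three edges as above. Because the code uses row vectors (`x @ W`), the halves of `W` are its top and bottom rows; `atol` absorbs floating-point differences from summing in a different order.

```python
import torch

torch.manual_seed(0)
node_feat_dim, out_dim = 3, 3
feat = torch.randn(6, node_feat_dim)
src = torch.tensor([0, 1, 3])                   # edges (0,1), (1,2), (3,5)
dst = torch.tensor([1, 2, 5])
W = torch.randn(2 * node_feat_dim, out_dim)     # weight applied to cat(u, v)

# Version 1: concatenate per edge, then multiply (E x 2d intermediate)
out_cat = torch.cat([feat[src], feat[dst]], dim=1) @ W

# Version 2: multiply once per node with each half, then add per edge
W_src, W_dst = W[:node_feat_dim], W[node_feat_dim:]
out_split = (feat @ W_src)[src] + (feat @ W_dst)[dst]

print(torch.allclose(out_cat, out_split, atol=1e-6))  # True
```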

### DGL on Heterograph

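Next we introduce message passing on heterogeneous graphs. It consists of two steps:

1. For each edge type, compute and aggregate the messages
2. For each node, aggregate the messages coming from the different edge types

On heterogeneous graphs we use the `multi_update_all` API. Its logic is almost the same as `update_all`, except that the first input is a dictionary that gives each edge type its own (message function, reduce function) pair; the second argument (e.g. `'sum'`) is the cross-type reducer that combines the results of different edge types.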

In [12]:
import dgl
import dgl.function as fn
import torch

def message_passing(g, node_feats, weight):
    funcs = {}
    for srctype, etype, dsttype in g.canonical_etypes:
        Wh = weight[etype](node_feats[srctype])
        g.nodes[srctype].data['Wh_%s' % etype] = Wh
        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))

    g.multi_update_all(funcs, 'sum')

    updated_feats = {ntype: g.nodes[ntype].data['h'] for ntype in g.ntypes}
    return updated_feats

# Create a heterogeneous graph
g = dgl.heterograph({
    ('user', 'follows', 'user'): ([0, 1, 2, 3], [1, 2, 3, 4]),
    ('user', 'likes', 'item'): ([0, 1, 2, 3], [0, 1, 2, 3])
})

# Create node features and per-relation weights
node_feats = {
    'user': torch.tensor([[0.1], [0.2], [0.3], [0.4], [0.5]]),
    'item': torch.tensor([[1.0], [2.0], [3.0], [4.0]])
}

weight = {
    'follows': torch.nn.Linear(1, 1),
    'likes': torch.nn.Linear(1, 1)
}

# Run message passing
updated_feats = message_passing(g, node_feats, weight)

# Print the updated node features
for ntype, feats in updated_feats.items():
    print(f"{ntype} features:")
    print(feats)


item features:
tensor([[-0.8167],
        [-0.7368],
        [-0.6568],
        [-0.5768]], grad_fn=<DivBackward0>)
user features:
tensor([[0.0000],
        [0.4501],
        [0.3697],
        [0.2893],
        [0.2089]], grad_fn=<DivBackward0>)
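The semantics of `multi_update_all(funcs, 'sum')` can be sketched in plain PyTorch: for each edge type, copy the source features and take the per-destination mean; then sum the per-type results for each node type. The relation `('item', 'liked-by', 'user')` is hypothetical, added so that the `'user'` nodes receive messages from two edge types and the cross-type sum actually matters. This illustrates the computation; it is not DGL's implementation.

```python
import torch

def mean_by_dst(msg, dst, num_dst):
    """Per-relation reduce, like fn.mean: average incoming messages (0 if none)."""
    total = torch.zeros(num_dst, msg.shape[1]).index_add_(0, dst, msg)
    count = torch.bincount(dst, minlength=num_dst).clamp(min=1).unsqueeze(-1)
    return total / count

num_nodes = {'user': 3, 'item': 2}
feats = {
    'user': torch.tensor([[0.1], [0.2], [0.3]]),
    'item': torch.tensor([[1.0], [2.0]]),
}
# (src type, edge type, dst type) -> (src IDs, dst IDs); 'liked-by' is hypothetical
relations = {
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('item', 'liked-by', 'user'): (torch.tensor([0, 1, 1]), torch.tensor([2, 2, 0])),
}

# Step 1: per edge type, copy source features (like fn.copy_u) and average them per destination
per_type = {ntype: [] for ntype in num_nodes}
for (stype, etype, dtype), (src, dst) in relations.items():
    per_type[dtype].append(mean_by_dst(feats[stype][src], dst, num_nodes[dtype]))

# Step 2: cross-type reducer 'sum' over the edge types that reach each node type
# (node types that receive no messages get no output in this sketch)
h = {ntype: torch.stack(outs).sum(0) for ntype, outs in per_type.items() if outs}
print(h['user'].squeeze(-1))
```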


In the following sections, we show how to build GNN modules on top of message passing.

## Building GNN Modules in DGL

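DGL GNN modules can be built on different backends. Apart from message passing, building a DGL module is essentially the same as building an ordinary module in the underlying deep learning framework. In this article we use PyTorch as the backend.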

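To build your own GNN module, you need to define the following methods:

- `__init__:` Python's built-in magic method that runs when the class is instantiated; it is usually used to define the GNN model architecture (layers and parameters)
- `forward(self, graph, feat):` the forward computation of the module, which takes two inputs by default
  - `graph:` a `dgl.DGLGraph` instance describing the graph structure
  - `feat:` a `tensor` holding the node features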



In [13]:
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.utils import expand_as_pair, check_eq_shape

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize parameters.
        gain is the recommended scaling factor for the given activation function.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """LSTM reducer
        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m'] # (B, L, D)
        batch_size = m.shape[0]
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat):
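        """Forward pass of the SAGE module.
        graph: DGLGraph instance
        feat: PyTorch tensor (or a (src, dst) pair of tensors) holding node features
        """
        # local_var returns a local copy of the graph: features written to it do
        # not affect the original (global) graph.
        # `with graph.local_scope():` can be used instead.
        graph = graph.local_var()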

        if isinstance(feat, tuple):
            # feat is a (src, dst) pair: apply dropout to each part separately
            feat_src = self.feat_drop(feat[0])
            feat_dst = self.feat_drop(feat[1])
        else:
            # A single tensor serves as both source and destination features
            feat_src = feat_dst = self.feat_drop(feat)

        h_self = feat_dst

        # fn.copy_u copies the source node feature onto each edge as the message
        # (it replaces the deprecated fn.copy_src)
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst     # same as above if homogeneous
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # divide in_degrees
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'lstm':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE GCN does not require fc_self.
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
        # activation
        if self.activation is not None:
            rst = self.activation(rst)
        # normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
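For reference, a plain-PyTorch sketch of what the `'mean'` and `'gcn'` branches of `SAGEConv.forward` compute on a hypothetical path graph, leaving out dropout, activation, and normalization. `fc_self` and `fc_neigh` mirror the layers of the same name above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, in_feats, out_feats = 5, 4, 2
x = torch.randn(num_nodes, in_feats)
src = torch.tensor([0, 1, 2, 3])          # hypothetical path graph 0 -> 1 -> 2 -> 3 -> 4
dst = torch.tensor([1, 2, 3, 4])

fc_self = nn.Linear(in_feats, out_feats)
fc_neigh = nn.Linear(in_feats, out_feats)

total = torch.zeros(num_nodes, in_feats).index_add_(0, dst, x[src])  # sum of neighbor features
in_deg = torch.bincount(dst, minlength=num_nodes).unsqueeze(-1)

# 'mean': h_neigh = average of neighbor features (0 without neighbors), then fc_self + fc_neigh
h_neigh = total / in_deg.clamp(min=1)
rst_mean = fc_self(x) + fc_neigh(h_neigh)

# 'gcn': include the node itself, divide by in-degree + 1, and use only fc_neigh
h_gcn = (total + x) / (in_deg + 1)
rst_gcn = fc_neigh(h_gcn)
```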

In [14]:
class GraphSAGE(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator_type):
        super(GraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.g = g
        # input layer
        self.layers.append(SAGEConv(in_feats, n_hidden, aggregator_type,
                                    feat_drop=dropout, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(SAGEConv(n_hidden, n_hidden, aggregator_type,
                                        feat_drop=dropout, activation=activation))
        # output layer
        self.layers.append(SAGEConv(n_hidden, n_classes, aggregator_type,
                                    feat_drop=dropout, activation=None)) # activation None

    def forward(self, features):
        h = features
        for layer in self.layers:
            h = layer(self.g, h)
        return h

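Next, let's see how to define GNN modules on heterogeneous graphs. The process has the following steps:

1. Define a separate GNN module for each relation
2. Aggregate the results of the different relations

The cell below reproduces DGL's own implementation of these heterograph modules (`HeteroGraphConv`, `HeteroLinear`, `HeteroEmbedding`).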

In [15]:
"""Heterograph NN modules"""
from functools import partial

import torch as th
import torch.nn as nn

from dgl.base import DGLError

__all__ = ["HeteroGraphConv", "HeteroLinear", "HeteroEmbedding"]

class HeteroGraphConv(nn.Module):
    r"""A generic module for computing convolution on heterogeneous graphs.

    The heterograph convolution applies sub-modules on their associating
    relation graphs, which reads the features from source nodes and writes the
    updated ones to destination nodes. If multiple relations have the same
    destination node types, their results are aggregated by the specified method.
    If the relation graph has no edge, the corresponding module will not be called.

    Pseudo-code:

    .. code::

        outputs = {nty : [] for nty in g.dsttypes}
        # Apply sub-modules on their associating relation graphs in parallel
        for relation in g.canonical_etypes:
            stype, etype, dtype = relation
            dstdata = relation_submodule(g[relation], ...)
            outputs[dtype].append(dstdata)

        # Aggregate the results for each destination node type
        rsts = {}
        for ntype, ntype_outputs in outputs.items():
            if len(ntype_outputs) != 0:
                rsts[ntype] = aggregate(ntype_outputs)
        return rsts

    Examples
    --------

    Create a heterograph with three types of relations and nodes.

    >>> import dgl
    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user') : edges1,
    ...     ('user', 'plays', 'game') : edges2,
    ...     ('store', 'sells', 'game')  : edges3})

    Create a ``HeteroGraphConv`` that applies different convolution modules to
    different relations. Note that the modules for ``'follows'`` and ``'plays'``
    do not share weights.

    >>> import dgl.nn.pytorch as dglnn
    >>> conv = dglnn.HeteroGraphConv({
    ...     'follows' : dglnn.GraphConv(...),
    ...     'plays' : dglnn.GraphConv(...),
    ...     'sells' : dglnn.SAGEConv(...)},
    ...     aggregate='sum')

    Call forward with some ``'user'`` features. This computes new features for both
    ``'user'`` and ``'game'`` nodes.

    >>> import torch as th
    >>> h1 = {'user' : th.randn((g.num_nodes('user'), 5))}
    >>> h2 = conv(g, h1)
    >>> print(h2.keys())
    dict_keys(['user', 'game'])

    Call forward with both ``'user'`` and ``'store'`` features. Because both the
    ``'plays'`` and ``'sells'`` relations will update the ``'game'`` features,
    their results are aggregated by the specified method (i.e., summation here).

    >>> f1 = {'user' : ..., 'store' : ...}
    >>> f2 = conv(g, f1)
    >>> print(f2.keys())
    dict_keys(['user', 'game'])

    Call forward with some ``'store'`` features. This only computes new features
    for ``'game'`` nodes.

    >>> g1 = {'store' : ...}
    >>> g2 = conv(g, g1)
    >>> print(g2.keys())
    dict_keys(['game'])

    Call forward with a pair of inputs is allowed and each submodule will also
    be invoked with a pair of inputs.

    >>> x_src = {'user' : ..., 'store' : ...}
    >>> x_dst = {'user' : ..., 'game' : ...}
    >>> y_dst = conv(g, (x_src, x_dst))
    >>> print(y_dst.keys())
    dict_keys(['user', 'game'])

    Parameters
    ----------
    mods : dict[str, nn.Module]
        Modules associated with every edge types. The forward function of each
        module must have a `DGLGraph` object as the first argument, and
        its second argument is either a tensor object representing the node
        features or a pair of tensor object representing the source and destination
        node features.
    aggregate : str, callable, optional
        Method for aggregating node features generated by different relations.
        Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.
        The 'stack' aggregation is performed along the second dimension, whose order
        is deterministic.
        User can also customize the aggregator by providing a callable instance.
        For example, aggregation by summation is equivalent to the follows:

        .. code::

            def my_agg_func(tensors, dsttype):
                # tensors: is a list of tensors to aggregate
                # dsttype: string name of the destination node type for which the
                #          aggregation is performed
                stacked = torch.stack(tensors, dim=0)
                return torch.sum(stacked, dim=0)

    Attributes
    ----------
    mods : dict[str, nn.Module]
        Modules associated with every edge types.
    """

    def __init__(self, mods, aggregate="sum"):
        super(HeteroGraphConv, self).__init__()
        self.mod_dict = mods
        mods = {str(k): v for k, v in mods.items()}
        # Register as child modules
        self.mods = nn.ModuleDict(mods)
        # PyTorch ModuleDict doesn't have get() method, so I have to store two
        # dictionaries so that I can index with both canonical edge type and
        # edge type with the get() method.
        # Do not break if graph has 0-in-degree nodes.
        # Because there is no general rule to add self-loop for heterograph.
        for _, v in self.mods.items():
            set_allow_zero_in_degree_fn = getattr(
                v, "set_allow_zero_in_degree", None
            )
            if callable(set_allow_zero_in_degree_fn):
                set_allow_zero_in_degree_fn(True)
        if isinstance(aggregate, str):
            self.agg_fn = get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

    def _get_module(self, etype):
        mod = self.mod_dict.get(etype, None)
        if mod is not None:
            return mod
        if isinstance(etype, tuple):
            # etype is canonical
            _, etype, _ = etype
            return self.mod_dict[etype]
        raise KeyError("Cannot find module with edge type %s" % etype)

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        """Forward computation

        Invoke the forward function with each module and aggregate their results.

        Parameters
        ----------
        g : DGLGraph
            Graph data.
        inputs : dict[str, Tensor] or pair of dict[str, Tensor]
            Input node features.
        mod_args : dict[str, tuple[any]], optional
            Extra positional arguments for the sub-modules.
        mod_kwargs : dict[str, dict[str, any]], optional
            Extra key-word arguments for the sub-modules.

        Returns
        -------
        dict[str, Tensor]
            Output representations for every types of nodes.
        """
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty: [] for nty in g.dsttypes}
        if isinstance(inputs, tuple) or g.is_block:
            if isinstance(inputs, tuple):
                src_inputs, dst_inputs = inputs
            else:
                src_inputs = inputs
                dst_inputs = {
                    k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
                }

            for stype, etype, dtype in g.canonical_etypes:
                rel_graph = g[stype, etype, dtype]
                if stype not in src_inputs or dtype not in dst_inputs:
                    continue
                dstdata = self._get_module((stype, etype, dtype))(
                    rel_graph,
                    (src_inputs[stype], dst_inputs[dtype]),
                    *mod_args.get(etype, ()),
                    **mod_kwargs.get(etype, {})
                )
                outputs[dtype].append(dstdata)
        else:
            for stype, etype, dtype in g.canonical_etypes:
                rel_graph = g[stype, etype, dtype]
                if stype not in inputs:
                    continue
                dstdata = self._get_module((stype, etype, dtype))(
                    rel_graph,
                    (inputs[stype], inputs[dtype]),
                    *mod_args.get(etype, ()),
                    **mod_kwargs.get(etype, {})
                )
                outputs[dtype].append(dstdata)
        rsts = {}
        for nty, alist in outputs.items():
            if len(alist) != 0:
                rsts[nty] = self.agg_fn(alist, nty)
        return rsts



def _max_reduce_func(inputs, dim):
    return th.max(inputs, dim=dim)[0]


def _min_reduce_func(inputs, dim):
    return th.min(inputs, dim=dim)[0]


def _sum_reduce_func(inputs, dim):
    return th.sum(inputs, dim=dim)


def _mean_reduce_func(inputs, dim):
    return th.mean(inputs, dim=dim)


def _stack_agg_func(inputs, dsttype):  # pylint: disable=unused-argument
    if len(inputs) == 0:
        return None
    return th.stack(inputs, dim=1)


def _agg_func(inputs, dsttype, fn):  # pylint: disable=unused-argument
    if len(inputs) == 0:
        return None
    stacked = th.stack(inputs, dim=0)
    return fn(stacked, dim=0)


def get_aggregate_fn(agg):
    """Internal function to get the aggregation function for node data
    generated from different relations.

    Parameters
    ----------
    agg : str
        Method for aggregating node features generated by different relations.
        Allowed values are 'sum', 'max', 'min', 'mean', 'stack'.

    Returns
    -------
    callable
        Aggregator function that takes a list of tensors to aggregate
        and returns one aggregated tensor.
    """
    if agg == "sum":
        fn = _sum_reduce_func
    elif agg == "max":
        fn = _max_reduce_func
    elif agg == "min":
        fn = _min_reduce_func
    elif agg == "mean":
        fn = _mean_reduce_func
    elif agg == "stack":
        fn = None  # will not be called
    else:
        raise DGLError(
            "Invalid cross type aggregator. Must be one of "
            '"sum", "max", "min", "mean" or "stack". But got "%s"' % agg
        )
    if agg == "stack":
        return _stack_agg_func
    else:
        return partial(_agg_func, fn=fn)

class HeteroLinear(nn.Module):
    """Apply linear transformations on heterogeneous inputs.

    Parameters
    ----------
    in_size : dict[key, int]
        Input feature size for heterogeneous inputs. A key can be a string or a tuple of strings.
    out_size : int
        Output feature size.
    bias : bool, optional
        If True, learns a bias term. Defaults: ``True``.

    Examples
    --------

    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroLinear

    >>> layer = HeteroLinear({'user': 1, ('user', 'follows', 'user'): 2}, 3)
    >>> in_feats = {'user': torch.randn(2, 1), ('user', 'follows', 'user'): torch.randn(3, 2)}
    >>> out_feats = layer(in_feats)
    >>> print(out_feats['user'].shape)
    torch.Size([2, 3])
    >>> print(out_feats[('user', 'follows', 'user')].shape)
    torch.Size([3, 3])
    """

    def __init__(self, in_size, out_size, bias=True):
        super(HeteroLinear, self).__init__()

        self.linears = nn.ModuleDict()
        for typ, typ_in_size in in_size.items():
            self.linears[str(typ)] = nn.Linear(typ_in_size, out_size, bias=bias)

    def forward(self, feat):
        """Forward function

        Parameters
        ----------
        feat : dict[key, Tensor]
            Heterogeneous input features. It maps keys to features.

        Returns
        -------
        dict[key, Tensor]
            Transformed features.
        """
        out_feat = dict()
        for typ, typ_feat in feat.items():
            out_feat[typ] = self.linears[str(typ)](typ_feat)

        return out_feat

class HeteroEmbedding(nn.Module):
    """Create a heterogeneous embedding table.

    It internally contains multiple ``torch.nn.Embedding`` with different dictionary sizes.

    Parameters
    ----------
    num_embeddings : dict[key, int]
        Size of the dictionaries. A key can be a string or a tuple of strings.
    embedding_dim : int
        Size of each embedding vector.

    Examples
    --------

    >>> import dgl
    >>> import torch
    >>> from dgl.nn import HeteroEmbedding

    >>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
    >>> # Get the heterogeneous embedding table
    >>> embeds = layer.weight
    >>> print(embeds['user'].shape)
    torch.Size([2, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([3, 4])

    >>> # Get the embeddings for a subset
    >>> input_ids = {'user': torch.LongTensor([0]),
    ...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
    >>> embeds = layer(input_ids)
    >>> print(embeds['user'].shape)
    torch.Size([1, 4])
    >>> print(embeds[('user', 'follows', 'user')].shape)
    torch.Size([2, 4])
    """

    def __init__(self, num_embeddings, embedding_dim):
        super(HeteroEmbedding, self).__init__()

        self.embeds = nn.ModuleDict()
        self.raw_keys = dict()
        for typ, typ_num_rows in num_embeddings.items():
            self.embeds[str(typ)] = nn.Embedding(typ_num_rows, embedding_dim)
            self.raw_keys[str(typ)] = typ

    @property
    def weight(self):
        """Get the heterogeneous embedding table

        Returns
        -------
        dict[key, Tensor]
            Heterogeneous embedding table
        """
        return {
            self.raw_keys[typ]: emb.weight for typ, emb in self.embeds.items()
        }

    def reset_parameters(self):
        """
        Use the xavier method in nn.init module to make the parameters uniformly distributed
        """
        for typ in self.embeds.keys():
            nn.init.xavier_uniform_(self.embeds[typ].weight)

    def forward(self, input_ids):
        """Forward function

        Parameters
        ----------
        input_ids : dict[key, Tensor]
            The row IDs to retrieve embeddings. It maps a key to key-specific IDs.

        Returns
        -------
        dict[key, Tensor]
            The retrieved embeddings.
        """
        embeds = dict()
        for typ, typ_ids in input_ids.items():
            embeds[typ] = self.embeds[str(typ)](typ_ids)

        return embeds
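The cross-relation aggregators returned by `get_aggregate_fn` reduce a list of per-relation outputs. The plain-PyTorch sketch below mirrors `_agg_func` (stack along a new dim 0, then reduce) and `_stack_agg_func` (stack along dim 1) for two hypothetical relations writing to the same node type, showing the resulting shapes.

```python
import torch

# Outputs of two hypothetical relations that write to the same destination node type:
# 3 destination nodes with 4-dimensional features each
out_rel1 = torch.ones(3, 4)
out_rel2 = 2 * torch.ones(3, 4)
outputs = [out_rel1, out_rel2]

summed = torch.stack(outputs, dim=0).sum(dim=0)      # like aggregate='sum'
maxed = torch.stack(outputs, dim=0).max(dim=0)[0]    # like aggregate='max'
stacked = torch.stack(outputs, dim=1)                # like aggregate='stack'

print(summed.shape, stacked.shape)  # torch.Size([3, 4]) torch.Size([3, 2, 4])
```

`'stack'` keeps one slot per relation (the second dimension), so a later layer can still tell the relations apart, whereas `'sum'`/`'max'` merge them into a single vector per node.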
