In [1]:
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


一个图神经网络模型由三个重要的部分组成：
  - 消息传递机制
  - 聚合机制
  - 更新机制


### 消息函数 Message Function
消息传递是源节点$u$通过边$e_{u \rightarrow v}^{(l-1)}$向目标节点$v$发送消息，消息的内容由目标节点$v$和源节点$u$的特征表征以及该边的特征共同决定，消息的传递是通过消息函数$M^{(l)}$来实现的。

\begin{equation}
m_{u \rightarrow v}^{(l)}=M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u \rightarrow v}^{(l-1)}\right)
\end{equation}
其中，$h_v^{(l-1)}$和$h_u^{(l-1)}$分别是目标节点$v$和源节点$u$在第$l-1$层的**节点特征**，$e_{u \rightarrow v}^{(l-1)}$是节点$u$到节点$v$的边在第$l-1$层的**边特征**。

### 聚合函数 Aggregation Function (Reduce Function)
聚合函数是目标节点$v$将所有邻居节点$u$发送过来的消息进行聚合，得到目标节点$v$在第$l$层的消息表示。
\begin{equation}
m_v^{(l)}=\sum_{u \in \mathcal{N}(v)} m_{u \rightarrow v}^{(l)}
\end{equation}
> 注意：这里用求和符号表示聚合，实际上聚合函数可以是任意的函数，比如求平均、求最大等。

### 更新函数 Update Function
用当前节点的特征表示$h_v^{(l-1)}$和聚合得到的消息表示$m_v^{(l)}$，通过函数$U^{(l)}(\cdot)$来更新节点的特征表示$h_v^{(l)}$。
\begin{equation}
h_v^{(l)}=U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
\end{equation}

### 我们来手动实现一下GraphSAGE模型
仍然是第一个例子中的节点分类问题。其中，我们将聚合函数设置为：

\begin{equation}
h_{\mathcal{N}(v)}^k \leftarrow \text { Average }\left\{h_u^{k-1}, \forall u \in \mathcal{N}(v)\right\}
\end{equation}

更新函数设置为：
\begin{equation}
h_v^k \leftarrow \operatorname{ReLU}\left(W^k \cdot \operatorname{CONCAT}\left(h_v^{k-1}, h_{\mathcal{N}(v)}^k\right)\right)
\end{equation}

其中$k$表示图神经网络的第$k$层（层的序号）


In [None]:
'''
单层GraphSAGE模型
'''


class SAGEConv(nn.Module):
    """One GraphSAGE layer with mean aggregation over incoming messages.

    Every node receives the ``"h"`` feature of the source node of each of its
    incoming edges, averages those messages, concatenates the average with
    its own input feature and passes the result through a single linear
    layer. No activation is applied inside this module.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Linear layer of the update function. Its input is
        # CONCAT(own feature, aggregated neighbor feature), hence 2 * in_feat.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        """Run one round of message passing followed by the linear update.

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.

        Returns
        -------
        Tensor
            Output of the linear layer applied to ``cat([h, h_N], dim=1)``.
        """
        # local_scope keeps the temporary "h" / "h_N" fields from leaking
        # into the caller's graph.
        with g.local_scope():
            g.ndata["h"] = h
            # update_all runs the message function and the reduce function in
            # one call: copy the source node's "h" into message field "m",
            # then average the received "m" values into node field "h_N".
            g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
            neighbor_mean = g.ndata["h_N"]
        # The tensors are plain references, so they remain valid after the
        # scope has been left.
        combined = torch.cat((h, neighbor_mean), dim=1)
        return self.linear(combined)

In [3]:
'''
将两层GraphSAGE模型组合成一个简单的完整的模型

'''
class Model(nn.Module):
    """Two stacked ``SAGEConv`` layers with a ReLU in between.

    Parameters
    ----------
    in_feats : int
        Size of the input node features.
    h_feats : int
        Size of the hidden representation produced by the first layer.
    num_classes : int
        Size of the output produced by the second layer.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return the second layer's output for graph ``g`` and features ``in_feat``."""
        # The activation sits between the layers only; the final output is
        # returned without any activation.
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

In [4]:
import dgl.data

# Load the Cora dataset through DGL's built-in dataset class.
dataset = dgl.data.CoraGraphDataset()
# Take the entry at index 0 as the graph `g` used by the cells below.
g = dataset[0]


def _masked_accuracy(pred, labels, mask):
    """Return the fraction of correct predictions among the nodes selected by ``mask``."""
    return (pred[mask] == labels[mask]).float().mean().item()


def train(g, model, num_epochs=200, lr=0.01, log_every=5):
    """Train ``model`` for node classification on ``g`` using Adam.

    Every epoch runs one forward pass over the whole graph, computes the
    cross-entropy loss on the training nodes only, and takes one optimizer
    step. The best validation accuracy seen so far, and the test accuracy
    measured in that same epoch, are tracked for the progress log.

    Parameters
    ----------
    g : Graph
        Graph whose ``ndata`` holds ``"feat"``, ``"label"``, ``"train_mask"``,
        ``"val_mask"`` and ``"test_mask"``.
    model : nn.Module
        Called as ``model(g, features)``; expected to return one row of class
        logits per node.
    num_epochs : int, optional
        Number of training epochs. Defaults to 200, the value that used to be
        hard-coded.
    lr : float, optional
        Adam learning rate. Defaults to 0.01, the value that used to be
        hard-coded.
    log_every : int, optional
        Print a progress line whenever ``epoch % log_every == 0``. Defaults
        to 5. A falsy value (``0`` or ``None``) disables logging.

    Returns
    -------
    None
        The model is updated in place; progress is reported via ``print``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]

    # NOTE: the previous version appended the detached logits of every epoch
    # to a list that was never read or returned; that list only cost memory
    # and has been removed.
    for e in range(num_epochs):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Only the nodes selected by train_mask contribute to the loss.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Accuracy on the validation / test splits, as plain Python floats.
        val_acc = _masked_accuracy(pred, labels, val_mask)
        test_acc = _masked_accuracy(pred, labels, test_mask)

        # Keep the best validation accuracy and the test accuracy measured in
        # the same epoch.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and e % log_every == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss.item(), val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


# Input size is the width of the node feature matrix, the hidden size is 16,
# and the output size is the number of classes reported by the dataset.
model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.949, val acc: 0.156 (best 0.156), test acc: 0.144 (best 0.144)
In epoch 5, loss: 1.868, val acc: 0.450 (best 0.450), test acc: 0.425 (best 0.425)
In epoch 10, loss: 1.706, val acc: 0.466 (best 0.466), test acc: 0.469 (best 0.469)
In epoch 15, loss: 1.457, val acc: 0.544 (best 0.544), test acc: 0.559 (best 0.559)
In epoch 20, loss: 1.135, val acc: 0.612 (best 0.612), test acc: 0.635 (best 0.635)
In epoch 25, loss: 0.788, val acc: 0.664 (best 0.664), test acc: 0.694 (best 0.694)
In epoch 30, loss: 0.483, val acc: 0.714 (best 0.714), test acc: 0.740 (best 0.740)
In epoch 35, loss: 0.269, val acc: 0.740 (best 0.740), test acc: 0.751 (best 0.751)
In epoch 40, loss: 0.145, val acc: 0.752 (best 0.752), test acc: 0.753 (best 0.753)
In epoch 45, loss: 0.080, val acc: 0.750 (best 0.754), test acc:

In [None]:
'''
定义一个在聚合函数中考虑边权重的GraphSAGE模型
'''

class WeightedSAGEConv(nn.Module):
    """GraphSAGE layer whose messages are scaled by an edge weight.

    Each message is the source node's ``"h"`` feature multiplied by the
    ``"w"`` feature of the edge it travels along. The received messages are
    then averaged with ``fn.mean`` and concatenated with the node's own
    feature before the linear layer.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Input is CONCAT(own feature, aggregated feature), hence 2 * in_feat.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h, w):
        """Apply the layer.

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge feature used as weight; stored on the graph as ``"w"``.

        Returns
        -------
        Tensor
            Output of the linear layer applied to ``cat([h, h_N], dim=1)``.
        """
        with g.local_scope():
            g.ndata["h"] = h
            g.edata["w"] = w
            # The weight is multiplied into the message itself, so the reduce
            # step can stay a plain mean over the received messages.
            g.update_all(fn.u_mul_e("h", "w", "m"), fn.mean("m", "h_N"))
            aggregated = g.ndata["h_N"]
        stacked = torch.cat((h, aggregated), dim=1)
        return self.linear(stacked)

In [6]:
class Model2(nn.Module):
    """Two stacked ``WeightedSAGEConv`` layers with a ReLU in between.

    Every edge is given the weight 1. The weight tensors are created from
    the node feature tensors with ``new_ones``, so they share the features'
    dtype and device. The previous ``torch.ones(...)`` call always produced
    a float32 CPU tensor, which mismatches features that live on another
    device or use another dtype.

    Parameters
    ----------
    in_feats : int
        Size of the input node features.
    h_feats : int
        Size of the hidden representation produced by the first layer.
    num_classes : int
        Size of the output produced by the second layer.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(Model2, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return the second layer's output for graph ``g`` and features ``in_feat``."""
        weight_shape = (g.num_edges(), 1)
        # Unit edge weights, one per edge, matching the dtype/device of the
        # features they are multiplied with in each layer.
        h = self.conv1(g, in_feat, in_feat.new_ones(weight_shape))
        h = F.relu(h)
        h = self.conv2(g, h, h.new_ones(weight_shape))
        return h

In [7]:
# Same sizes as the unweighted run: feature width in, 16 hidden units, one
# output per class. Only the layer type differs (Model2 instead of Model).
model = Model2(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)

In epoch 0, loss: 1.951, val acc: 0.122 (best 0.122), test acc: 0.130 (best 0.130)
In epoch 5, loss: 1.888, val acc: 0.284 (best 0.284), test acc: 0.276 (best 0.276)
In epoch 10, loss: 1.772, val acc: 0.228 (best 0.284), test acc: 0.221 (best 0.276)
In epoch 15, loss: 1.601, val acc: 0.272 (best 0.284), test acc: 0.291 (best 0.276)
In epoch 20, loss: 1.385, val acc: 0.336 (best 0.336), test acc: 0.359 (best 0.359)
In epoch 25, loss: 1.141, val acc: 0.434 (best 0.434), test acc: 0.444 (best 0.444)
In epoch 30, loss: 0.892, val acc: 0.508 (best 0.508), test acc: 0.514 (best 0.514)
In epoch 35, loss: 0.662, val acc: 0.546 (best 0.546), test acc: 0.554 (best 0.554)
In epoch 40, loss: 0.470, val acc: 0.574 (best 0.574), test acc: 0.581 (best 0.581)
In epoch 45, loss: 0.324, val acc: 0.614 (best 0.614), test acc: 0.619 (best 0.619)
In epoch 50, loss: 0.220, val acc: 0.630 (best 0.630), test acc: 0.642 (best 0.641)
In epoch 55, loss: 0.149, val acc: 0.638 (best 0.638), test acc: 0.658 (best 0

In [None]:
'''
也可以自定义消息传递函数
- 消息传递函数总是接受edges作为输入，该输入包含src, dst, data三个属性.
- 分别对应边的源节点属性，目标节点属性，以及边的属性
- 返回的是一个字典，包含了要发送的消息

'''
def u_mul_e_udf(edges):
    """User-defined message function: source feature times edge weight.

    Parameters
    ----------
    edges
        Edge batch providing ``edges.src["h"]`` (feature of each edge's
        source node) and ``edges.data["w"]`` (feature of the edge itself).

    Returns
    -------
    dict
        ``{"m": ...}``, the message sent along each edge to its destination.
    """
    source_feat = edges.src["h"]
    edge_weight = edges.data["w"]
    message = source_feat * edge_weight
    return {"m": message}


'''
同样的，自定义聚合函数
'''

def mean_udf(nodes):
    """User-defined reduce function: average the received ``"m"`` messages.

    Parameters
    ----------
    nodes
        Node batch providing ``nodes.mailbox["m"]``, the messages received by
        each node.

    Returns
    -------
    dict
        ``{"h_N": ...}``, the mean of the mailbox taken along dimension 1.
    """
    received = nodes.mailbox["m"]
    averaged = received.mean(dim=1)
    return {"h_N": averaged}