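This section walks through a GraphSAGE implementation built on the DGL framework, using link prediction as the example task. First, a brief introduction to link prediction: many applications, such as social recommendation, item recommendation, and knowledge graph completion, can be formulated as link prediction, that is, predicting whether an edge exists between two given nodes.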

In [4]:
# Import the required libraries
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp


In [8]:
# Load the Cora dataset
import dgl.data

dataset = dgl.data.CoraGraphDataset(raw_dir='./data/')
g = dataset[0]


Extracting file to ./data/cora_v2_d697a464
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.


In [10]:
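# Prepare the training set and the test set
u, v = g.edges()

eids = np.arange(g.number_of_edges())
eids = np.random.permutation(eids)
test_size = int(len(eids) * 0.1)
train_size = g.number_of_edges() - test_size
# Positive examples: hold out 10% of the existing edges for testing, train on the rest
test_pos_u, test_pos_v = u[eids[:test_size]], v[eids[:test_size]]
train_pos_u, train_pos_v = u[eids[test_size:]], v[eids[test_size:]]

# Negative examples: node pairs that are not connected by an edge
# Pass the shape explicitly so the matrix stays N x N even if the last nodes have no edges
num_nodes = g.number_of_nodes()
adj = sp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())), shape=(num_nodes, num_nodes))
# Entries equal to 1 mark pairs with no edge; subtracting the identity excludes self-loops
adj_neg = 1 - adj.todense() - np.eye(num_nodes)
neg_u, neg_v = np.where(adj_neg != 0)

# Sample as many negative pairs as there are edges, then split them like the positives
neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]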
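
The cell above builds a dense N x N matrix to enumerate every non-edge, which is fine for Cora (2708 nodes) but grows quadratically with the number of nodes. As a rough alternative, the sketch below draws negative pairs by rejection sampling, using plain NumPy so that it runs on its own. The function name `sample_negative_edges` and its `seed` parameter are illustrative assumptions, not part of DGL. Like `np.random.choice` above, it samples with replacement, so the same pair may appear more than once.

```python
import numpy as np

def sample_negative_edges(src, dst, num_nodes, num_samples, seed=0):
    """Sample node pairs (u, v), u != v, that are not edges in (src, dst)."""
    rng = np.random.default_rng(seed)
    # Store existing edges in a set for constant-time membership checks.
    existing = set(zip(src.tolist(), dst.tolist()))
    neg_u, neg_v = [], []
    while len(neg_u) < num_samples:
        u = int(rng.integers(num_nodes))
        v = int(rng.integers(num_nodes))
        # Reject self-loops and pairs that are already connected.
        if u == v or (u, v) in existing:
            continue
        neg_u.append(u)
        neg_v.append(v)
    return np.array(neg_u), np.array(neg_v)

# Toy directed graph with 5 nodes and 4 edges.
src = np.array([0, 1, 2, 3])
dst = np.array([1, 2, 3, 4])
neg_u, neg_v = sample_negative_edges(src, dst, num_nodes=5, num_samples=6)
```

Rejection sampling works well when the graph is sparse, because most random pairs are non-edges and few draws are thrown away.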


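During training, you need to remove the test-set edges from the original graph, which you can do with dgl.remove_edges. dgl.remove_edges works by creating a subgraph from the original graph, which produces a copy, so it can be slow on large graphs. If that is the case, you can save the training and test graphs to disk, as you would for any preprocessing step.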

In [11]:
train_g = dgl.remove_edges(g, eids[:test_size])

Next, we define the GraphSAGE model:

In [20]:
from dgl.nn import SAGEConv

# Build a two-layer GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, 'mean')
        self.conv2 = SAGEConv(h_feats, h_feats, 'mean')

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
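
To make the 'mean' aggregator less of a black box, here is a minimal NumPy sketch of what a single such layer computes, assuming the usual update rule: a node's new representation is a linear map of its own features plus a linear map of the mean of its in-neighbors' features, plus a bias. The function `sage_mean_layer` and the weight values are hypothetical and chosen only for readability; this is an illustration of the idea, not DGL's implementation.

```python
import numpy as np

def sage_mean_layer(src, dst, h, w_self, w_neigh, bias):
    """One mean-aggregator layer: h_v @ W_self + mean(h_u for u -> v) @ W_neigh + b."""
    num_nodes = h.shape[0]
    # Sum the features of each node's in-neighbors (messages flow src -> dst).
    neigh_sum = np.zeros_like(h)
    np.add.at(neigh_sum, dst, h[src])
    # Divide by the in-degree; nodes without in-neighbors keep a zero vector.
    in_deg = np.bincount(dst, minlength=num_nodes).reshape(-1, 1)
    neigh_mean = neigh_sum / np.maximum(in_deg, 1)
    return h @ w_self + neigh_mean @ w_neigh + bias

# Toy graph: edges 0 -> 2 and 1 -> 2, with 2-dimensional node features.
src = np.array([0, 1])
dst = np.array([2, 2])
h = np.array([[1.0, 0.0], [3.0, 2.0], [5.0, 5.0]])
w_self = np.eye(2)           # hypothetical weights, chosen for readability
w_neigh = 2.0 * np.eye(2)
bias = np.zeros(2)
out = sage_mean_layer(src, dst, h, w_self, w_neigh, bias)
```

Node 2 is the only node with in-neighbors here, so it is the only row that changes: its neighbor mean is [2, 1], giving [5, 5] + 2 * [2, 1] = [9, 7]. Stacking two such layers, as the model above does, lets each node see information from up to two hops away.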


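The model then predicts the probability that an edge exists by computing a score between the representations of its two endpoint nodes, typically with a small MLP or a plain dot product.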

DGL recommends treating the pairs of nodes as another graph, since a pair of nodes can be described by an edge. In link prediction, you will have a positive graph whose edges are all the positive examples, and a negative graph whose edges are all the negative examples. Both graphs contain the same set of nodes as the original graph, which makes it easier to pass node features among multiple graphs for computation. As you will see later, you can directly feed the node representations computed on the entire graph to the positive and the negative graphs to compute pair-wise scores.

In [13]:
# Build the graphs of positive and negative examples
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())

test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())


Next, build the scoring functions mentioned above, a dot product and an MLP, as DotPredictor and MLPPredictor:

In [14]:
import dgl.function as fn

class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            # Compute a new edge feature 'score' as the dot product of the two endpoint embeddings
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # u_dot_v returns a 1-element vector for each edge, so flatten it
            return g.edata['score'][:, 0]

class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """
        Computes a scalar score for each edge of the given graph.

        Parameters
        ----------
        edges :
            Has three members ``src``, ``dst`` and ``data``, each of
            which is a dictionary representing the features of the
            source nodes, the destination nodes, and the edges
            themselves.

        Returns
        -------
        dict
            A dictionary of new edge features.
        """
        h = torch.cat([edges.src['h'], edges.dst['h']], 1)
        return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']
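
For intuition, the dot-product score can also be written without any graph API: gather the embeddings of the two endpoints of each pair and sum their element-wise product. The NumPy sketch below assumes a hypothetical embedding matrix `h` and helper names `dot_scores` and `sigmoid`. The scores are raw logits, so a sigmoid is applied to read them as edge probabilities.

```python
import numpy as np

def dot_scores(h, src, dst):
    """Dot product of the two endpoint embeddings for every (src, dst) pair."""
    return (h[src] * h[dst]).sum(axis=1)

def sigmoid(x):
    """Map a raw score (logit) to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical 2-dimensional embeddings for three nodes.
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
# Score the candidate pairs (0, 1) and (0, 2).
src = np.array([0, 0])
dst = np.array([1, 2])
scores = dot_scores(h, src, dst)
probs = sigmoid(scores)
```

Nodes 0 and 1 have orthogonal embeddings, so their score is 0 and the predicted probability is exactly 0.5, while the pair (0, 2) scores higher because the embeddings overlap.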


The full training procedure for the task is shown below:

In [61]:
model = GraphSAGE(train_g.ndata['feat'].shape[1], 16)
# You can replace DotPredictor with MLPPredictor.
# pred = MLPPredictor(16)
pred = DotPredictor()

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    return F.binary_cross_entropy_with_logits(scores, labels)

from sklearn.metrics import roc_auc_score
def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)
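
The two helpers above can be unpacked with a small NumPy sketch. It assumes the standard definitions: binary cross-entropy computed on raw logits (written here in a numerically stable form), and AUC as the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counted as one half. The names `bce_with_logits` and `pairwise_auc` and the toy scores are illustrative only.

```python
import numpy as np

def bce_with_logits(scores, labels):
    """Mean binary cross-entropy on raw logits, in a numerically stable form."""
    # max(x, 0) - x*y + log(1 + exp(-|x|)) equals
    # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    per_example = np.maximum(scores, 0) - scores * labels + np.log1p(np.exp(-np.abs(scores)))
    return per_example.mean()

def pairwise_auc(pos_score, neg_score):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as 0.5."""
    diff = pos_score[:, None] - neg_score[None, :]
    return np.mean((diff > 0) + 0.5 * (diff == 0))

# Toy scores: three positive pairs and two negative pairs.
pos_score = np.array([2.0, 0.5, -1.0])
neg_score = np.array([0.0, -2.0])
scores = np.concatenate([pos_score, neg_score])
labels = np.concatenate([np.ones(3), np.zeros(2)])
loss = bce_with_logits(scores, labels)
auc = pairwise_auc(pos_score, neg_score)
```

Of the six (positive, negative) pairs in this toy example, only the pair (-1.0, 0.0) is ranked incorrectly, so the AUC is 5/6. An AUC of 0.5 corresponds to random ranking and 1.0 to a perfect one.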


In [62]:
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), pred.parameters()), lr=0.01)

# Training
for e in range(100):
    # Forward pass
    # Learn node representations with GraphSAGE
    h = model(train_g, train_g.ndata['feat'])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # Update the parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print('In epoch {}, loss: {}'.format(e, loss))

# Compute the AUC on the test pairs (h holds the embeddings from the last training pass)
with torch.no_grad():
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print('AUC', compute_auc(pos_score, neg_score))


In epoch 0, loss: 0.7080111503601074
In epoch 5, loss: 0.68918377161026
In epoch 10, loss: 0.6679327487945557
In epoch 15, loss: 0.6114687919616699
In epoch 20, loss: 0.5494130849838257
In epoch 25, loss: 0.5187898874282837
In epoch 30, loss: 0.4887141287326813
In epoch 35, loss: 0.4626503586769104
In epoch 40, loss: 0.43711423873901367
In epoch 45, loss: 0.41953739523887634
In epoch 50, loss: 0.40007540583610535
In epoch 55, loss: 0.3816624879837036
In epoch 60, loss: 0.3635600209236145
In epoch 65, loss: 0.345358282327652
In epoch 70, loss: 0.32713955640792847
In epoch 75, loss: 0.3090570867061615
In epoch 80, loss: 0.29081523418426514
In epoch 85, loss: 0.27249887585639954
In epoch 90, loss: 0.25416550040245056
In epoch 95, loss: 0.23578102886676788
AUC 0.8646912692886504
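
Once node embeddings are learned, applications such as the recommendation scenarios mentioned at the start can rank candidate links for a given node. The sketch below is one possible way to do that with dot-product scores in NumPy. The function `recommend_links` and the embedding values are hypothetical, and it scores every node, which is only practical for small graphs.

```python
import numpy as np

def recommend_links(h, node, known_neighbors, k):
    """Return the k highest-scoring candidate partners for `node`."""
    # Dot-product score between `node` and every node in the graph.
    scores = h @ h[node]
    # Exclude the node itself and the nodes it is already connected to.
    scores[node] = -np.inf
    scores[list(known_neighbors)] = -np.inf
    # Sort in descending order of score and keep the first k indices.
    return np.argsort(-scores)[:k]

# Hypothetical 2-dimensional embeddings for five nodes.
h = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.8, 0.3], [0.1, 0.9]])
top = recommend_links(h, node=0, known_neighbors={1}, k=2)
```

Here node 1 would score highest for node 0 but is already a neighbor, so the top two suggestions are nodes 3 and 4.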
