1. 超参数设置

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Learning rate handed to the Adam optimizer.
lr = 0.001
# Number of passes over the training dataloader.
epochs = 20
# Hidden width for the model -- presumably the `hidden_size` argument of SAGE; TODO confirm.
hidden_dim = 128
# Name passed to gb.BuiltinDataset when loading the data.
dataset_name = 'cora-seeds'

2. 数据集处理与加载

In [None]:
# Load the built-in dataset and move the graph and feature store to `device`.
dataset = gb.BuiltinDataset(dataset_name).load()
graph = dataset.graph.to(device)
feature = dataset.feature.to(device)

# Everything below uses the task at index 1.
# NOTE(review): presumably the link-prediction task -- confirm against the dataset metadata.
train_set, test_set, task_name = (
    dataset.tasks[1].train_set,
    dataset.tasks[1].test_set,
    dataset.tasks[1].metadata["name"],
)

# Mini-batch pipeline: batch the seeds, add 5 uniform negatives per seed,
# sample neighbours with fanout 5 per hop, drop the seed edges (and their
# reverses) from the sampled subgraphs, then attach the "feat" node features.
# NOTE(review): the fanout list has two entries while SAGE builds three
# SAGEConv layers -- confirm the intended depth.
datapipe = (
    gb.ItemSampler(train_set, batch_size=256, shuffle=True)
    .copy_to(device)
    .sample_uniform_negative(graph, 5)
    .sample_neighbor(graph, [5, 5])
    .transform(partial(gb.exclude_seed_edges, include_reverse_edges=True))
    .fetch_feature(feature, node_feature_keys=["feat"])
    .copy_to(device)
)
dataloader = gb.DataLoader(datapipe)

3. 构建卷积层和模型

In [None]:
class SAGEConv(nn.Module):
    """GraphSAGE convolution with mean neighbour aggregation.

    The output for each destination node is
    ``fc_self(own feature) + mean(fc_neigh(source features))``.

    Args:
        in_feats: Width of the input node features.
        out_feats: Width of the output node features.
    """

    def __init__(self, in_feats, out_feats):
        super().__init__()
        # One projection for the aggregated neighbours, one for the node itself.
        self.fc_neigh = nn.Linear(in_feats, out_feats)
        self.fc_self = nn.Linear(in_feats, out_feats)
        # Xavier-uniform initialisation scaled by the ReLU gain.
        relu_gain = nn.init.calculate_gain("relu")
        for linear in (self.fc_neigh, self.fc_self):
            nn.init.xavier_uniform_(linear.weight, gain=relu_gain)

    def forward(self, graph, feat):
        """Run one round of mean-aggregated message passing.

        Args:
            graph: Graph (or block) to propagate over.
            feat: Source-node features; the first ``number_of_dst_nodes()``
                rows are used as the destination nodes' own features.

        Returns:
            One output row per destination node.
        """
        with graph.local_scope():
            num_dst = graph.number_of_dst_nodes()
            # Project the source features, then average them over incoming edges.
            graph.srcdata["h"] = self.fc_neigh(feat)
            graph.update_all(fn.copy_u("h", "m"), fn.mean("m", "neigh"))
            neigh_term = graph.dstdata["neigh"]
            # Combine each destination node's own projection with its neighbourhood.
            self_term = self.fc_self(feat[:num_dst])
            return self_term + neigh_term

# Define the GraphSAGE model
class SAGE(nn.Module):
    """Stack of three SAGEConv layers plus an MLP pair-score predictor.

    ``forward`` only produces node embeddings. Callers apply
    ``self.predictor`` themselves to turn a ``hidden_size``-wide vector into
    a single logit.

    Args:
        in_size: Input feature width of the first layer.
        hidden_size: Width of every layer output and of the predictor MLP.
    """

    def __init__(self, in_size, hidden_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Add the GraphSAGE layers.
        self.layers.append(SAGEConv(in_size, hidden_size))
        self.layers.append(SAGEConv(hidden_size, hidden_size))
        self.layers.append(SAGEConv(hidden_size, hidden_size))
        self.hidden_size = hidden_size
        # Predictor: maps a hidden_size vector to one logit.
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, blocks, x):
        """Compute node embeddings by running one layer per block.

        Args:
            blocks: Sized sequence of graphs/blocks, consumed in order, one
                per layer. If it is shorter than ``self.layers`` only the
                leading layers run.
            x: Input node features for the first block.

        Returns:
            The output of the last layer that actually ran, without a
            trailing ReLU.
        """
        hidden_x = x
        # zip() stops at the shorter sequence, so the "last layer" is the last
        # one that really executes, not necessarily the last one constructed.
        # Comparing against len(self.layers) alone would apply ReLU to the
        # final embeddings whenever fewer blocks than layers are supplied.
        last_idx = min(len(self.layers), len(blocks)) - 1
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            if layer_idx != last_idx:
                hidden_x = F.relu(hidden_x)
        return hidden_x


4. 定义损失函数

In [None]:
# Binary cross-entropy computed directly on raw logits (the sigmoid is applied
# inside the loss); the training loop feeds it the predictor outputs and the
# mini-batch labels.
criterion = nn.BCEWithLogitsLoss()

5. 设置优化器

In [None]:
# No visible cell instantiates `model`, so build it here when it is missing;
# a model created by an earlier cell is left untouched.
# The input width is the length of one node's "feat" vector.
if "model" not in globals():
    in_size = feature.size("node", None, "feat")[0]
    model = SAGE(in_size, hidden_dim).to(device)

# Adam over every model parameter (GraphSAGE layers and predictor).
optimizer = optim.Adam(model.parameters(), lr=lr)

6. 模型评估函数

In [None]:
def evaluate(model, dataloader):
    """Return the binary classification accuracy of ``model`` on ``dataloader``.

    Args:
        model: Callable as ``model(blocks, node_feature)`` to get node
            embeddings; must also expose ``eval()`` and a ``predictor``.
        dataloader: Iterable of mini-batches providing ``compacted_seeds``,
            ``labels``, ``node_features["feat"]`` and ``blocks``.

    Returns:
        Fraction of seed pairs whose rounded sigmoid score equals the label,
        or ``0.0`` when the dataloader yields no samples.

    Note:
        The model is left in eval mode; callers that keep training must
        switch back with ``model.train()``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            # Transpose so that row 0 / row 1 hold the two endpoints of each pair.
            compacted_seeds = data.compacted_seeds.T
            labels = data.labels
            node_feature = data.node_features["feat"]
            blocks = data.blocks
            # Node embeddings.
            y = model(blocks, node_feature)
            # Score each pair from the element-wise product of its endpoints.
            # squeeze(-1) drops only the trailing logit dimension, so a batch
            # with a single pair keeps its batch dimension.
            logits = model.predictor(
                y[compacted_seeds[0]] * y[compacted_seeds[1]]
            ).squeeze(-1)
            # Hard predictions.
            preds = torch.round(torch.sigmoid(logits))
            # Number of correct predictions.
            correct += (preds == labels).sum().item()
            # Number of samples seen.
            total += labels.size(0)
    # Avoid ZeroDivisionError on an empty dataloader.
    if total == 0:
        return 0.0
    return correct / total


7. 模型训练流程

In [None]:
# Train for `epochs` passes over the dataloader, reporting accuracy every 5 epochs.
for epoch in range(epochs):
    # evaluate() leaves the model in eval mode, so re-enable training mode each epoch.
    model.train()
    # Sum (not mean) of the per-batch losses over the epoch.
    total_loss = 0
    start_epoch_time = time.time()
    for step, data in enumerate(dataloader):
        # Seed pairs of the mini-batch, transposed so that row 0 / row 1 hold
        # the two endpoints of each pair.
        compacted_seeds = data.compacted_seeds.T
        # Labels, input features and blocks of the mini-batch.
        labels = data.labels
        node_feature = data.node_features["feat"]
        blocks = data.blocks
        # Node embeddings.
        y = model(blocks, node_feature)
        # squeeze(-1) drops only the trailing logit dimension, so logits keep
        # a batch dimension that matches `labels` even for a single pair.
        logits = model.predictor(
            y[compacted_seeds[0]] * y[compacted_seeds[1]]
        ).squeeze(-1)
        # Compute the loss and back-propagate.
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    end_epoch_time = time.time()
    print(f"Epoch {epoch} | Loss: {total_loss:.4f} | Time: {end_epoch_time - start_epoch_time:.2f}s")
    # Evaluate every 5 epochs.
    # NOTE(review): this measures accuracy on the training dataloader;
    # `test_set` is loaded earlier but never used -- confirm whether a
    # held-out dataloader was intended.
    if (epoch+1) % 5 == 0:
        accuracy = evaluate(model, dataloader)
        print(f"Accuracy: {accuracy:.4f}")
