1. 超参数设置：

In [None]:
# Imported here because this cell calls tlx.set_device before any other cell
# has brought tensorlayerx into scope.
import tensorlayerx as tlx

# Training hyperparameters.
lr = 0.01
n_epoch = 200
hidden_dim = 16
l2_coef = 5e-4

# Dataset name and file-system locations.
dataset = 'cora'
dataset_path = './examples/gcn/'
best_model_path = './'
self_loops = 1

# Device selection: a non-negative value is passed to tlx.set_device as the
# GPU index; any negative value selects the CPU.
gpu = -1
if gpu >= 0:
    tlx.set_device("GPU", gpu)
else:
    tlx.set_device("CPU")

2. 数据集处理与加载

In [None]:
from gammagl.loader import DataLoader
from gammagl.datasets import TUDataset

# Load the dataset. `path` and `args` are not defined in this cell and must
# already exist when it runs.
dataset = TUDataset(path, name=args.dataset)

# One unit is a tenth of the dataset (integer division). The first unit is the
# validation split, the second unit the test split, and everything after that
# the training split.
dataset_unit = len(dataset) // 10
train_dataset, val_dataset, test_dataset = (
    dataset[2 * dataset_unit:],
    dataset[:dataset_unit],
    dataset[dataset_unit: 2 * dataset_unit],
)

# Only the training loader shuffles its samples.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=args.batch_size, shuffle=is_train)
    for split, is_train in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


3. 构建卷积层与模型

In [None]:
import tensorlayerx as tlx
from .mlp import MLP
from gammagl.layers.conv import MessagePassing
from gammagl.layers.pool.glob import global_sum_pool

class GINConv(MessagePassing):
    """Message-passing layer that adds each node's own features to the
    propagated result and feeds the sum through a user-supplied network.

    Args:
        nn: callable applied to the combined features in ``forward``.
        **kwargs: forwarded unchanged to ``MessagePassing.__init__``.
    """

    def __init__(self, nn, **kwargs):
        # NOTE(review): every keyword is handed to the base class as-is
        # (GINModel passes train_eps=False) - confirm MessagePassing accepts it.
        super().__init__(**kwargs)
        self.nn = nn

    def forward(self, x, edge_index, size=None):
        """Propagate features along ``edge_index`` and apply ``self.nn``.

        Args:
            x: a single feature tensor, or a list/tuple whose element 0 is
                propagated and whose element 1 is added to the result.
            edge_index: edge connectivity passed to ``propagate``.
            size: forwarded to ``propagate`` unchanged.

        Returns:
            The output of ``self.nn`` on the combined features.
        """
        # A bare tensor plays both roles: propagated input and added term.
        pair = x if isinstance(x, (list, tuple)) else (x, x)
        aggregated = self.propagate(x=pair[0], edge_index=edge_index, size=size)
        # Add the node's own features on top of the propagated ones.
        aggregated += pair[1]
        return self.nn(aggregated)

    def message(self, x, edge_index):
        """Gather the rows of ``x`` selected by the first row of ``edge_index``
        (presumably the source node of each edge - confirm the convention)."""
        selected = edge_index[0, :]
        return tlx.gather(x, selected)

class GINModel(tlx.nn.Module):
    """GIN model: one GINConv layer, global sum pooling, then an MLP head.

    Args:
        in_channels: width of the input node features (first MLP input size).
        hidden_channels: width used inside the convolution MLP and the head.
        out_channels: width of the final output of the head MLP.
        name: module name forwarded to ``tlx.nn.Module``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, name="GIN"):
        super(GINModel, self).__init__(name=name)
        self.convs = tlx.nn.ModuleList()
        # MLP applied inside the single GINConv layer.
        mlp = MLP([in_channels, hidden_channels, hidden_channels])
        self.convs.append(GINConv(nn=mlp, train_eps=False))
        # Head applied after pooling.
        self.mlp = MLP([hidden_channels, hidden_channels, out_channels],
                       norm=None, dropout=0.5)
        self.relu = tlx.ReLU()

    def forward(self, x, edge_index, batch):
        """Run the convolution(s), pool with ``global_sum_pool`` and classify.

        Args:
            x: node feature tensor, or ``None`` when no features are available.
            edge_index: edge connectivity passed to every convolution.
            batch: tensor passed to ``global_sum_pool``; its first dimension
                is also used as the row count of the placeholder features
                (presumably one entry per node - confirm against the loader).

        Returns:
            The output of the head MLP on the pooled features.
        """
        # Only synthesise placeholder features when none were supplied.
        # Overwriting real features would discard them and break the first
        # MLP whenever in_channels is larger than 1.
        if x is None:
            x = tlx.random_normal((batch.shape[0], 1), dtype=tlx.float32)
        for conv in self.convs:
            x = self.relu(conv(x, edge_index))
        x = global_sum_pool(x, batch)
        return self.mlp(x)

# Instantiate the model; the input width is never allowed to drop below 1.
net = GINModel(
    in_channels=max(dataset.num_features, 1),
    hidden_channels=args.hidden_dim,
    out_channels=dataset.num_classes,
    name="GIN",
)


4. 定义损失函数：

In [None]:
from tensorlayerx.model import WithLoss, TrainOneStep 
class SemiSpvzLoss(WithLoss):
    """Loss wrapper that evaluates the loss only on the training indices.

    Args:
        net: backbone network to wrap.
        loss_fn: callable taking ``(logits, labels)`` and returning the loss.
    """

    def __init__(self, net, loss_fn):
        super().__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y):
        """Compute the loss on the entries selected by ``data['train_idx']``.

        Args:
            data: mapping providing 'x', 'edge_index', 'num_nodes',
                'train_idx' and 'y'.
            y: unused here; labels are read from ``data['y']`` instead.

        Returns:
            The value returned by the wrapped loss function.
        """
        # NOTE(review): the backbone is called with four positional arguments,
        # while GINModel.forward in this file takes (x, edge_index, batch) -
        # confirm which model this wrapper is meant for.
        logits = self.backbone_network(
            data['x'], data['edge_index'], None, data['num_nodes']
        )
        # Keep only the predictions and labels that belong to the training set.
        train_idx = data['train_idx']
        picked_logits = tlx.gather(logits, train_idx)
        picked_labels = tlx.gather(data['y'], train_idx)
        return self._loss_fn(picked_logits, picked_labels)

# Weights to be optimised, and the loss wrapper around the network.
train_weights = net.trainable_weights
loss_func = SemiSpvzLoss(
    net=net,
    loss_fn=tlx.losses.softmax_cross_entropy_with_logits,
)

5. 设置优化器

In [None]:
# Adam optimiser; learning rate and weight decay are both read from `args`.
optimizer = tlx.optimizers.Adam(
    lr=args.lr,
    weight_decay=args.l2_coef,
)

6. 设置模型评测指标

In [None]:
def calculate_acc(logits, y, metrics):
    """Score one batch of predictions with a stateful metric object.

    The metric is updated with ``(logits, y)``, its result is read, and it is
    then reset so the next call starts from a clean state.

    Args:
        logits: predictions handed to ``metrics.update``.
        y: labels handed to ``metrics.update``.
        metrics: object exposing ``update``, ``result`` and ``reset``.

    Returns:
        Whatever ``metrics.result()`` returned before the reset.
    """
    metrics.update(logits, y)
    accuracy = metrics.result()
    # Clear the accumulated state so calls do not influence each other.
    metrics.reset()
    return accuracy
# Accuracy metric instance; it is passed to calculate_acc as the `metrics` argument.
metrics = tlx.metrics.Accuracy()


7. 定义模型训练、推理流程

In [None]:
# Training / evaluation script: one training step per epoch, validation after
# every step, checkpoint on each new best validation accuracy, then a final
# test pass with the best checkpoint.
# NOTE(review): `data`, `graph` and `args` are not defined in the cells shown
# here - confirm they are created before this cell runs.
# NOTE(review): GINModel.forward in this file takes (x, edge_index, batch),
# but `net` is called below with four positional arguments - confirm which
# model this loop is meant for.
train_one_step = TrainOneStep(loss_func, optimizer, train_weights)
best_val_acc = 0
for epoch in range(args.n_epoch):
    net.set_train()
    # Run the training step for this epoch.
    train_loss = train_one_step(data, graph.y)
    net.set_eval()
    # Forward pass of the model, producing the predicted outputs (logits).
    logits = net(data['x'], data['edge_index'], None, data['num_nodes'])
    val_logits = tlx.gather(logits, data['val_idx'])
    val_y = tlx.gather(data['y'], data['val_idx'])
    # Accuracy on the validation set: val_logits are the predictions, val_y
    # the true labels, and metrics the evaluation metric.
    val_acc = calculate_acc(val_logits, val_y, metrics)
    print("Epoch [{:0>3d}] ".format(epoch+1)\
        + " train loss: {:.4f}".format(train_loss.item())\
        + " val acc: {:.4f}".format(val_acc))
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Keep the parameters that perform best on the validation set; they
        # are the ones used for the test set below.
        net.save_weights(args.best_model_path+net.name+".npz", format='npz_dict')

# Reload the best checkpoint and evaluate it on the test indices.
net.load_weights(args.best_model_path+net.name+".npz", format='npz_dict')
net.set_eval()
logits = net(data['x'], data['edge_index'], None, data['num_nodes'])
test_logits = tlx.gather(logits, data['test_idx'])
test_y = tlx.gather(data['y'], data['test_idx'])
test_acc = calculate_acc(test_logits, test_y, metrics)
print("Test acc: {:.4f}".format(test_acc))
