## Tutorial 9: Recurrent GNNs
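In this tutorial we implement the original graph neural network (GNN) model, without enforcing the contraction-mapping constraint. We then analyze PyTorch Geometric's `GatedGraphConv` layer.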

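```python
import os.path as osp

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter as Param

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.loader import DataLoader  # lived in torch_geometric.data before PyG 2.x
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import uniform

try:
    # Only needed by message_and_aggregate() when the adjacency is a SparseTensor
    from torch_sparse import matmul
except ImportError:
    matmul = None

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')  # force CPU for this tutorial; delete this line to use a GPU
```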

### Loading the data

```python
dataset = 'Cora'
path = osp.join('data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
data = data.to(device)
```

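### Graph neural network
##### Multilayer perceptron
Key features:  
* Dynamic layer construction: the network is built from a list of layer widths, so its depth and widths are not hard-coded.  
* Activation placement: a Tanh activation follows every hidden layer, while the output layer stays linear. This suits regression and feature-transformation tasks.  
* Parameter initialization: weights are initialized with `xavier_normal_`, which helps mitigate vanishing and exploding gradients and makes training more stable.  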
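The first and third points can be illustrated without PyTorch. The sketch below rebuilds the layer plan that `MLP.__init__` produces from a width list, using the same `i + 2 < len(dims)` rule for placing Tanh. It also computes the standard deviation that Xavier-normal initialization uses, $\sqrt{2/(\text{fan\_in}+\text{fan\_out})}$ (with the default gain of 1). The helpers `layer_plan` and `xavier_normal_std` are hypothetical and exist only for this example.

```python
import math


def layer_plan(input_dim, hid_dims, out_dim):
    """Hypothetical helper: list the layers MLP(input_dim, hid_dims, out_dim) would build."""
    dims = [input_dim] + list(hid_dims) + [out_dim]
    plan = []
    for i in range(len(dims) - 1):
        plan.append(('linear', dims[i], dims[i + 1]))
        if i + 2 < len(dims):  # no activation after the last linear layer
            plan.append(('tanh',))
    return plan


def xavier_normal_std(fan_in, fan_out, gain=1.0):
    """Standard deviation used by Xavier (Glorot) normal initialization."""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


print(layer_plan(32, [64, 64], 7))
# [('linear', 32, 64), ('tanh',), ('linear', 64, 64), ('tanh',), ('linear', 64, 7)]
print(xavier_normal_std(64, 64))  # 0.125
```

With the widths used for `GNNM` later (`[64, 64, 64, 64, 64]`), the same rule yields six linear layers and five Tanh activations, which matches the printed model.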

```python
class MLP(nn.Module):
    def __init__(self, input_dim, hid_dim, out_dim):
        super().__init__()
        # Layer widths: [input, hidden_1, ..., hidden_n, output]; hid_dim is a list
        dims = [input_dim] + list(hid_dim) + [out_dim]
        self.mlp = nn.Sequential()

        # Build the layers dynamically
        for i in range(len(dims) - 1):
            # Linear (fully connected) layer
            self.mlp.add_module(f'lay_{i}', nn.Linear(dims[i], dims[i + 1]))
            # Tanh after every layer except the last one
            if i + 2 < len(dims):
                self.mlp.add_module(f'act_{i}', nn.Tanh())

    def reset_parameters(self):
        """Re-initialize the weights of all linear layers (Xavier normal).
        Biases keep PyTorch's default initialization."""
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        return self.mlp(x)
```

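##### Graph neural network model  
Key features:  
* Node-state management:
  1. `node_states` is a non-trainable tensor that holds the initial state of every node.
  2. The initial states are zero vectors. They are then updated by message passing.
* Two MLPs:
  1. `transition`: maps the aggregated messages to new node states of the same dimension, which propagates state information across the graph.
  2. `readout`: maps the final node states to the output (class) space, changing the dimension.
* Convergence: iteration stops once every node's state moves by less than `eps` (L2 distance) between two consecutive steps, or after at most `num_layers` steps. This avoids always running a fixed, possibly redundant number of layers.

The model follows the original GNN formulation. The state $x_v$ of node $v$ is computed from its label $l_v$, the labels $l_{co(v)}$ of its incident edges, and the states $x_{ne(v)}$ and labels $l_{ne(v)}$ of its neighbors. The output $o_v$ is computed from the state and the label:
$$
x_{v}^{t+1} = f_{w}(l_{v}, l_{co(v)}, x^{t}_{ne(v)}, l_{ne(v)})
$$
$$
o^{t}_{v} = g_w(x^{t}_{v}, l_{v})
$$
The implementation below simplifies both functions. `transition` receives only the sum of the neighbors' states, and `readout` receives only the final state. As a result, the node features `data.x` (the labels $l_v$) are never used. Because no contraction constraint is imposed on `transition`, the iteration is also not guaranteed to converge and may simply run for `num_layers` steps.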
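The fixed-point iteration behind these equations can be sketched in plain Python on a three-node path graph. This is a toy under stated assumptions, not the model below. States are scalars, and the hypothetical transition $f(x)_v = l_v + \alpha \cdot \operatorname{mean}_{u \in ne(v)} x_u$ with $\alpha = 0.5$ does use the node label $l_v$. Because the neighbor term is scaled by $\alpha < 1$, the map is a contraction, so the iteration is guaranteed to reach a unique fixed point. The stopping rule mirrors `GNNM.forward`: stop once every node moved by less than `eps`. The function name `gnn_fixed_point` is hypothetical.

```python
def gnn_fixed_point(labels, neighbors, eps=1e-3, max_iters=50, alpha=0.5):
    """Toy GNN state iteration: x_v <- l_v + alpha * mean of neighbor states.

    With 0 <= alpha < 1 the update is a contraction (in the max norm), so the
    iteration converges to a unique fixed point from any starting state.
    """
    states = [0.0] * len(labels)          # zero initial states, as in GNNM
    for step in range(1, max_iters + 1):
        new_states = [
            labels[v] + alpha * sum(states[u] for u in neighbors[v]) / len(neighbors[v])
            for v in range(len(labels))
        ]
        converged = all(abs(n - o) < eps for n, o in zip(new_states, states))
        states = new_states
        if converged:                      # every node moved less than eps
            return states, step
    return states, max_iters


# Path graph 0 - 1 - 2 with scalar node labels
labels = [1.0, 0.0, -1.0]
neighbors = {0: [1], 1: [0, 2], 2: [1]}
states, steps = gnn_fixed_point(labels, neighbors)
print(states, steps)  # [1.0, 0.0, -1.0] 2
```

Setting `alpha` to 1 or more removes the contraction guarantee. That is the situation `GNNM` is in: nothing constrains `transition`, so its loop may stop only at `num_layers`.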

```python
class GNNM(MessagePassing):
    def __init__(self, n_nodes, out_channels, features_dim, hid_dims,
                 num_layers=50, eps=1e-3, aggr='add', bias=True, **kwargs):
        super().__init__(aggr=aggr, **kwargs)

        # Initial node states (non-trainable, all zeros): one state vector per node
        self.register_buffer('node_states', torch.zeros(n_nodes, features_dim))
        self.out_channels = out_channels        # number of classes
        self.eps = eps                          # convergence threshold
        self.num_layers = num_layers            # maximum number of iterations

        # The two MLPs
        self.transition = MLP(features_dim, hid_dims, features_dim)     # state update f_w
        self.readout = MLP(features_dim, hid_dims, out_channels)        # output function g_w

        self.reset_parameters()
        print("================transition==================")
        print(self.transition)
        print("================readout==================")
        print(self.readout)

    def reset_parameters(self):
        """Re-initialize both MLPs."""
        self.transition.reset_parameters()
        self.readout.reset_parameters()

    def forward(self):
        # Uses the global `data` object loaded above
        edge_index = data.edge_index    # graph connectivity, shape (2, E)
        edge_weight = data.edge_attr    # edge weights, shape (E,); None for Cora
        node_states = self.node_states

        for i in range(self.num_layers):
            # 1. Message passing: aggregate the neighbors' states
            m = self.propagate(edge_index, x=node_states, edge_weight=edge_weight)
            # 2. State update through the transition MLP
            new_states = self.transition(m)
            # 3. Convergence check (no gradients needed)
            with torch.no_grad():
                distance = torch.norm(new_states - node_states, dim=1)  # per-node L2 distance
                converged = distance < self.eps
            node_states = new_states    # update the node states
            if converged.all():         # stop early once every node has converged
                break

        # 4. Class scores from the final states
        out = self.readout(node_states)
        return F.log_softmax(out, dim=-1)

    # Message sent from neighbor j to the center node, optionally scaled by the edge weight
    def message(self, x_j, edge_weight):
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    # Fused message + aggregation via sparse matrix multiplication
    # (only called when the adjacency is passed as a SparseTensor)
    def message_and_aggregate(self, adj_t, x):
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return '{}({}, num_layers={})'.format(self.__class__.__name__,
                                              self.out_channels, self.num_layers)
```
        

##### Training the model

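```python
model = GNNM(data.num_nodes, dataset.num_classes, 32, [64, 64, 64, 64, 64], eps=0.01).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# The model already returns log-probabilities, so NLLLoss is the matching loss
# (CrossEntropyLoss would apply log_softmax a second time, which yields the same value)
loss_fn = nn.NLLLoss()
```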

```text
MLP(
  (mlp): Sequential(
    (lay_0): Linear(in_features=32, out_features=64, bias=True)
    (act_0): Tanh()
    (lay_1): Linear(in_features=64, out_features=64, bias=True)
    (act_1): Tanh()
    (lay_2): Linear(in_features=64, out_features=64, bias=True)
    (act_2): Tanh()
    (lay_3): Linear(in_features=64, out_features=64, bias=True)
    (act_3): Tanh()
    (lay_4): Linear(in_features=64, out_features=64, bias=True)
    (act_4): Tanh()
    (lay_5): Linear(in_features=64, out_features=32, bias=True)
  )
)
MLP(
  (mlp): Sequential(
    (lay_0): Linear(in_features=32, out_features=64, bias=True)
    (act_0): Tanh()
    (lay_1): Linear(in_features=64, out_features=64, bias=True)
    (act_1): Tanh()
    (lay_2): Linear(in_features=64, out_features=64, bias=True)
    (act_2): Tanh()
    (lay_3): Linear(in_features=64, out_features=64, bias=True)
    (act_3): Tanh()
    (lay_4): Linear(in_features=64, out_features=64, bias=True)
    (act_4): Tanh()
    (lay_5): Linear(in_features=64, ou
```

```python
# Cora is a single graph, so training uses the node masks directly (no DataLoader split needed)
def train():
    model.train()
    optimizer.zero_grad()
    out = model()
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

@torch.no_grad()
def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].argmax(dim=1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

for epoch in range(1, 51):
    train()
    accs = test()
    print(f'Epoch:{epoch:03d}, Train Acc:{accs[0]:.4f}, Val Acc:{accs[1]:.4f}, Test Acc:{accs[2]:.4f}')
```



```text
Epoch:001, Train Acc:0.1286, Val Acc:0.1200, Test Acc:0.1290
Epoch:002, Train Acc:0.1571, Val Acc:0.1420, Test Acc:0.1380
Epoch:003, Train Acc:0.1429, Val Acc:0.1620, Test Acc:0.1490
Epoch:004, Train Acc:0.1571, Val Acc:0.1140, Test Acc:0.1050
Epoch:005, Train Acc:0.1643, Val Acc:0.0980, Test Acc:0.0950
Epoch:006, Train Acc:0.1429, Val Acc:0.0580, Test Acc:0.0640
Epoch:007, Train Acc:0.1429, Val Acc:0.0580, Test Acc:0.0640
Epoch:008, Train Acc:0.1429, Val Acc:0.0580, Test Acc:0.0640
Epoch:009, Train Acc:0.1500, Val Acc:0.0660, Test Acc:0.0690
Epoch:010, Train Acc:0.1143, Val Acc:0.0740, Test Acc:0.0900
Epoch:011, Train Acc:0.1357, Val Acc:0.0820, Test Acc:0.0920
Epoch:012, Train Acc:0.1429, Val Acc:0.0800, Test Acc:0.0940
Epoch:013, Train Acc:0.1214, Val Acc:0.0800, Test Acc:0.0990
Epoch:014, Train Acc:0.1143, Val Acc:0.1180, Test Acc:0.1040
Epoch:015, Train Acc:0.1571, Val Acc:0.1260, Test Acc:0.1340
Epoch:016, Train Acc:0.1429, Val Acc:0.1680, Test Acc:0.1860
Epoch:017, Train Acc:0.1571, Val Acc:0.2
```
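`test()` runs one full forward pass and scores it under each boolean mask. The plain-Python sketch below shows that masked accuracy, using a hypothetical `masked_accuracy` helper and made-up toy predictions. For reference, Cora has 7 classes, so a uniform random guess would be right about 1/7 ≈ 0.14 of the time. The GNN accuracies logged above stay in that neighborhood.

```python
def masked_accuracy(pred, labels, mask):
    """Fraction of correct predictions among the positions where mask is True."""
    selected = [(p, y) for p, y, m in zip(pred, labels, mask) if m]
    correct = sum(p == y for p, y in selected)
    return correct / len(selected)


pred = [0, 2, 1, 1, 3, 4]
labels = [0, 1, 1, 2, 3, 0]
train_mask = [True, True, True, False, False, False]
test_mask = [False, False, False, True, True, True]
print(masked_accuracy(pred, labels, train_mask))  # 2 of 3 correct
print(masked_accuracy(pred, labels, test_mask))   # 1 of 3 correct
```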

### Gated graph neural network
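The layer implemented below can be summarized by the following update, read directly off its `forward` method. Each step $l$ has its own weight matrix $\Theta^{(l)}$ (`self.weight[l]`), and $e_{j,i}$ is the optional edge weight (1 when absent):

$$
\begin{aligned}
\mathbf{h}_i^{(0)} &= \mathbf{x}_i \,\Vert\, \mathbf{0} \\
\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \Theta^{(l)} \mathbf{h}_j^{(l)} \\
\mathbf{h}_i^{(l+1)} &= \mathrm{GRU}\left(\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)}\right)
\end{aligned}
$$

Here $\Vert\,\mathbf{0}$ denotes zero-padding $\mathbf{x}_i$ up to `out_channels`, and $l = 0, \dots, L-1$ with $L$ = `num_layers`.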

```python
class GatedGraphConv(MessagePassing):
    def __init__(self, out_channels, num_layers, aggr='add', bias=True, **kwargs):
        # Initialize MessagePassing; aggr='add' means sum aggregation
        super().__init__(aggr=aggr, **kwargs)

        self.out_channels = out_channels    # output (and hidden state) dimension
        self.num_layers = num_layers        # number of propagation steps
        # One learnable weight matrix per step: [num_layers, out_channels, out_channels]
        self.weight = Param(torch.empty(num_layers, out_channels, out_channels))
        # GRU cell: input size = hidden size = out_channels
        self.rnn = torch.nn.GRUCell(out_channels, out_channels, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        # uniform(size, tensor) samples from U(-1/sqrt(size), 1/sqrt(size))
        uniform(self.out_channels, self.weight)
        self.rnn.reset_parameters()

    def forward(self, data):
        x = data.x                      # node features, [num_nodes, in_channels]
        edge_index = data.edge_index    # edge indices, [2, num_edges]
        edge_weight = data.edge_attr    # edge weights, [num_edges] (None for Cora)
        # 1. Validate the input size: features can be zero-padded but not shrunk
        if x.size(-1) > self.out_channels:
            raise ValueError('The number of input channels must not exceed out_channels')

        # 2. Zero-pad the features up to out_channels
        if x.size(-1) < self.out_channels:
            zero = x.new_zeros(x.size(0), self.out_channels - x.size(-1))
            x = torch.cat([x, zero], dim=1)     # [num_nodes, in_channels] -> [num_nodes, out_channels]

        # 3. num_layers rounds of message passing + GRU update
        for i in range(self.num_layers):
            # 3.1 Linear transform: [num_nodes, out_channels] @ [out_channels, out_channels] -> [num_nodes, out_channels]
            m = torch.matmul(x, self.weight[i])
            # 3.2 Message passing: message() builds edge messages, then they are summed per node
            m = self.propagate(edge_index, x=m, edge_weight=edge_weight, size=None)
            # 3.3 GRU update: input = aggregated message m, hidden state = x
            x = self.rnn(m, x)
        # 4. Final node representations, [num_nodes, out_channels]
        return x

    # Message construction
    def message(self, x_j, edge_weight):
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    # Fused message + aggregation (only called when the adjacency is a SparseTensor)
    def message_and_aggregate(self, adj_t, x):
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        """String representation of the layer."""
        return f'{self.__class__.__name__}({self.out_channels}, num_layers={self.num_layers})'


class GGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # out_channels must be >= the input feature size (1433 for Cora)
        self.conv = GatedGraphConv(dataset.num_features, 3)
        self.mlp = MLP(dataset.num_features, [32, 32, 32], dataset.num_classes)

    def forward(self):
        x = self.conv(data)    # uses the global `data` object
        x = self.mlp(x)
        return F.log_softmax(x, dim=-1)
```
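Steps 2 and 3.3 of `forward` can be illustrated without PyTorch. The sketch below zero-pads a feature vector as step 2 does. It then performs one scalar GRU update using the gate equations documented for `torch.nn.GRUCell`: reset gate $r$, update gate $z$, candidate $n$, then $h' = (1 - z)\,n + z\,h$. Scalars stand in for the weight matrices, and the helper names and the keys of `params` are hypothetical.

```python
import math


def pad_features(x, out_channels):
    """Zero-pad a feature vector to out_channels, as GatedGraphConv.forward does."""
    if len(x) > out_channels:
        raise ValueError('The number of input channels must not exceed out_channels')
    return list(x) + [0.0] * (out_channels - len(x))


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_step(m, h, p):
    """One scalar GRU update: m is the aggregated message (input), h the previous state."""
    r = sigmoid(p['w_ir'] * m + p['b_ir'] + p['w_hr'] * h + p['b_hr'])          # reset gate
    z = sigmoid(p['w_iz'] * m + p['b_iz'] + p['w_hz'] * h + p['b_hz'])          # update gate
    n = math.tanh(p['w_in'] * m + p['b_in'] + r * (p['w_hn'] * h + p['b_hn']))  # candidate
    return (1.0 - z) * n + z * h                                                # new state


print(pad_features([0.3, 0.7], 4))  # [0.3, 0.7, 0.0, 0.0]

# With every weight and bias at zero, both gates are 0.5 and the candidate is 0,
# so the update simply halves the previous state.
zero_params = {k: 0.0 for k in ['w_ir', 'b_ir', 'w_hr', 'b_hr', 'w_iz', 'b_iz',
                                'w_hz', 'b_hz', 'w_in', 'b_in', 'w_hn', 'b_hn']}
print(gru_step(m=1.0, h=2.0, p=zero_params))  # 1.0
```

The update gate $z$ controls how much of the previous state survives each step. Pushing it toward 1 (for example, with a large `b_iz`) makes the cell keep its old state almost unchanged.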

#### Training the model

```python
model = GGNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# GGNN returns log-probabilities, so NLLLoss is the matching loss
loss_fn = nn.NLLLoss()
```

```python
# Cora is a single graph, so training uses the node masks directly (no DataLoader split needed)
def train():
    model.train()
    optimizer.zero_grad()
    out = model()
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

@torch.no_grad()
def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].argmax(dim=1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

for epoch in range(1, 51):
    train()
    accs = test()
    print('Epoch: {:03d}, Train Acc:{:.5f}, Val Acc:{:.5f}, Test Acc:{:.5f}'.format(epoch, accs[0], accs[1], accs[2]))
```

```text
Epoch: 001, Train Acc:0.15000, Val Acc:0.16600, Test Acc:0.15200
Epoch: 002, Train Acc:0.15000, Val Acc:0.14800, Test Acc:0.14300
Epoch: 003, Train Acc:0.22857, Val Acc:0.21400, Test Acc:0.20400
Epoch: 004, Train Acc:0.27857, Val Acc:0.23800, Test Acc:0.22200
Epoch: 005, Train Acc:0.25714, Val Acc:0.24600, Test Acc:0.25100
Epoch: 006, Train Acc:0.30000, Val Acc:0.27200, Test Acc:0.26300
Epoch: 007, Train Acc:0.37143, Val Acc:0.32000, Test Acc:0.31500
Epoch: 008, Train Acc:0.43571, Val Acc:0.36800, Test Acc:0.34200
Epoch: 009, Train Acc:0.45000, Val Acc:0.41800, Test Acc:0.40100
Epoch: 010, Train Acc:0.52143, Val Acc:0.46400, Test Acc:0.42600
Epoch: 011, Train Acc:0.54286, Val Acc:0.47000, Test Acc:0.47800
Epoch: 012, Train Acc:0.56429, Val Acc:0.47800, Test Acc:0.48900
Epoch: 013, Train Acc:0.56429, Val Acc:0.53200, Test Acc:0.49600
Epoch: 014, Train Acc:0.62857, Val Acc:0.51800, Test Acc:0.50900
Epoch: 015, Train Acc:0.65000, Val Acc:0.52000, Test Acc:0.53000
Epoch: 016, Train Acc:0.6
```
