## 教程5:聚合
在本教程中，将重写Pytorch Geometric 中 GIN卷积模块的聚合方法，实现以下方法：  
* 主邻域聚合（Principal Neighbourhood Aggregation，PNA）  
* 可学习聚合函数（Learning Aggregation Functions，LAF）

In [4]:
import os
import torch
import torch_scatter
import torch.nn.functional as F
torch.manual_seed(42)

OSError: dlopen(/Volumes/Data/software/conda/miniconda3/envs/dl_env/lib/python3.10/site-packages/torch_scatter/_version_cpu.so, 0x0006): symbol not found in flat namespace '__ZN5torch3jit17parseSchemaOrNameERKNSt3__112basic_stringIcNS1_11char_traitsIcEENS1_9allocatorIcEEEEb'

### 消息传递类

In [2]:
from torch_geometric.nn import MessagePassing

In [3]:
dir(MessagePassing)

['SUPPORTS_FUSED_EDGE_INDEX',
 'T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_call_impl',
 '_check_input',
 '_collect',
 '_compiled_call_impl',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_edge_updater_signature',
 '_get_name',
 '_get_propagate_signature',
 '_index_select',
 '_index_select_safe',
 '_lift',
 '_load_from_state_dict',
 '_maybe_warn_non_full_backward_hook',
 '_named_members',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_set_jittable_templates',
 '_

我们关注的是 `aggregate` 方法；如果使用稀疏邻接矩阵，则关注 `message_and_aggregate` 方法。为此，我们构建了扩展 GINConv 的自定义卷积类。

In [1]:
from torch.nn import Parameter, Module, Sigmoid
import torch    # 用于图数据的散列聚合操作
import torch.nn.functional as F

### 定义LAFLayer的基础结构，处理权重初始化、设备管理、关键张量的预定义
class AbstractLAFLayer(Module):
    """Base module that owns the parameters of a LAF aggregation layer.

    All configuration is passed through keyword arguments:

    Keyword Args:
        units: Number of LAF units. Required unless ``weights`` is given.
        weights: Tensor used directly as the trainable weight parameter;
            ``units`` is then read from its second dimension.
        device: Device on which every tensor is created. When omitted, CUDA
            is chosen if it is available and the CPU otherwise.
        kernel_initializer: Name of the scheme used to fill freshly created
            weights. One of ``random_normal`` (default), ``glorot_normal``,
            ``he_normal``, ``random_uniform``, ``glorot_uniform`` or
            ``he_uniform``.

    Raises:
        AssertionError: If neither ``units`` nor ``weights`` is supplied, or
            if ``kernel_initializer`` is not one of the names listed above.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Either a unit count or a ready-made weight tensor is mandatory.
        assert 'units' in kwargs or 'weights' in kwargs

        # --- device handling ---
        if 'device' not in kwargs:
            fallback = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.device = torch.device(fallback)
        else:
            self.device = kwargs['device']
        self.ngpus = torch.cuda.device_count()

        # --- initializer selection ---
        # Maps every accepted scheme name to the in-place init routine.
        fillers = {
            'random_normal': torch.nn.init.normal_,
            'glorot_normal': torch.nn.init.xavier_normal_,
            'he_normal': torch.nn.init.kaiming_normal_,
            'random_uniform': torch.nn.init.uniform_,
            'glorot_uniform': torch.nn.init.xavier_uniform_,
            'he_uniform': torch.nn.init.kaiming_uniform_,
        }
        if 'kernel_initializer' in kwargs:
            scheme = kwargs['kernel_initializer']
            assert scheme in list(fillers)
        else:
            scheme = 'random_normal'
        self.kernel_initializer = scheme

        # --- weight creation ---
        if 'weights' not in kwargs:
            # Build a fresh 12 x units matrix and fill it with the chosen scheme.
            self.units = kwargs['units']
            fresh = torch.empty(12, self.units, device=self.device)
            fillers[scheme](fresh)
            self.weights = Parameter(fresh, requires_grad=True)
        else:
            # Adopt the caller's tensor; the unit count follows from its shape.
            self.weights = Parameter(kwargs['weights'].to(self.device), requires_grad=True)
            self.units = self.weights.shape[1]

        # --- fixed helper tensors (registered as non-trainable parameters) ---
        # e:       alternating signs applied to the four LAF terms
        # num_idx: mask selecting terms 0 and 1 (numerator)
        # den_idx: mask selecting terms 2 and 3 (denominator)
        constants = (
            ('e', [1, -1, 1, -1], (4,)),
            ('num_idx', [1, 1, 0, 0], (1, 1, 4, 1)),
            ('den_idx', [0, 0, 1, 1], (1, 1, 4, 1)),
        )
        for attr, values, shape in constants:
            const = torch.tensor(values, dtype=torch.float32, device=self.device).view(shape)
            setattr(self, attr, Parameter(const, requires_grad=False))


In [5]:
### 实现核心的LAF操作，用于图数据的特征聚合
class LAFLayer(AbstractLAFLayer):
    """Learnable aggregation function (LAF) applied to groups of rows.

    Rows of the input that share the same ``index`` value are reduced to a
    single output row as a ratio of two weighted sums of generalized power
    means, all of whose exponents and coefficients are learned.

    Args:
        eps: Small positive constant used to keep bases of powers and the
            denominator away from zero.
        **kwargs: Forwarded unchanged to ``AbstractLAFLayer``.
    """

    def __init__(self, eps=1e-7, **kwargs):
        super(LAFLayer, self).__init__(**kwargs)
        self.eps = eps

    def forward(self, data, index, dim=0, **kwargs):
        """Aggregate ``data`` over the groups given by ``index``.

        Args:
            data: Values to aggregate. The broadcasting below assumes a 2-D
                tensor ``[rows, F]`` whose entries lie in ``[0, 1]``; values
                outside ``[eps, 1 - eps]`` are clamped into that range.
            index: Group id of every row of ``data`` (flattened before use).
            dim: Dimension along which rows are scattered. Defaults to 0.
            **kwargs: ``dim_size`` (optional int) fixes the number of output
                groups along ``dim``; without it the scatter derives the
                size from ``index``. Other keys are ignored.

        Returns:
            Tensor of shape ``[groups, F, units]`` for 2-D ``data``.
        """
        eps = self.eps
        sup = 1.0 - eps
        dim_size = kwargs.get('dim_size')

        # --- preprocessing ---
        # Both x and 1 - x are raised to learned, possibly fractional powers,
        # so x must stay strictly inside (0, 1): a negative base yields NaN
        # and a zero base yields infinite gradients.
        x = torch.clamp(data, min=eps, max=sup)
        x = torch.unsqueeze(x, -1)          # [rows, F, 1]
        e = self.e.view(1, 1, -1)           # [1, 1, 4]

        # --- inner powers ---
        # With e = [1, -1, 1, -1] this evaluates to [x, 1-x, x, 1-x].
        exps = (1. - e) / 2 + x * e         # [rows, F, 4]
        exps = torch.unsqueeze(exps, -1)    # [rows, F, 4, 1]
        # First weight group (rows 0-3): non-negative inner exponents.
        exps = torch.pow(exps, torch.relu(self.weights[0:4]))  # [rows, F, 4, units]

        # --- aggregation ---
        # Sum the powered terms of all rows that share a group id.
        scatter = torch_scatter.scatter_add(
            exps, index.view(-1), dim=dim, dim_size=dim_size
        )
        # Keep sums strictly positive before the outer power; groups that
        # received no rows end up at eps.
        scatter = torch.clamp(scatter, eps)

        # --- outer powers and coefficients ---
        # Second weight group (rows 4-7): non-negative outer exponents.
        sqrt = torch.pow(scatter, torch.relu(self.weights[4:8]))  # [groups, F, 4, units]
        # Third weight group (rows 8-11): linear coefficients of the terms.
        alpha_beta = self.weights[8:12].view(1, 1, 4, -1)
        terms = sqrt * alpha_beta

        # --- numerator / denominator ---
        num = torch.sum(terms * self.num_idx, dim=2)  # terms 0 and 1
        den = torch.sum(terms * self.den_idx, dim=2)  # terms 2 and 3

        # Replace denominators inside (-eps, eps) by +/-eps, keeping the sign;
        # an exact zero is mapped to -eps because sign(0) clamps to 0.
        multiplier = 2.0 * torch.clamp(torch.sign(den), min=0.0) - 1.0
        den = torch.where((den < eps) & (den > -eps), multiplier * eps, den)

        res = num / den
        return res
    
        

In [6]:
### 定义GINLAFConv 图卷积层，结合GIN和LAF机制
from torch_geometric.nn import GINConv
from torch.nn import Linear

class GINLAFConv(GINConv):
    """GIN convolution whose neighbourhood sum is replaced by a LAF aggregation.

    Args:
        nn: Update network handed to ``GINConv``.
        units: Number of LAF units; each input feature yields ``units``
            aggregated values.
        node_dim: Width of the features being aggregated. The LAF output of
            width ``node_dim * units`` is projected back to ``node_dim``.
        **kwargs: Forwarded unchanged to ``GINConv``.
    """

    def __init__(self, nn, units=1, node_dim=32, **kwargs):
        super(GINLAFConv, self).__init__(nn, **kwargs)
        self.laf = LAFLayer(units=units, kernel_initializer='random_normal')
        self.mlp = Linear(node_dim*units, node_dim)
        self.dim = node_dim
        self.units = units

    def aggregate(self, inputs, index, dim_size=None):
        """Aggregate incoming messages with LAF and project them back.

        Args:
            inputs: Messages, one row per edge.
            index: Target node id of every message.
            dim_size: Number of target nodes. The message-passing framework
                supplies it when the signature asks for it. It is forwarded
                to the LAF layer as a keyword argument so that the result
                can have one row per node even when the highest-numbered
                nodes receive no message; the LAF layer has to honour
                ``dim_size`` for this to take effect.

        Returns:
            Tensor with ``node_dim`` columns and one row per aggregated group.
        """
        x = torch.sigmoid(inputs)                    # squash into (0, 1) for LAF
        x = self.laf(x, index, dim_size=dim_size)    # learned aggregation
        x = x.view((-1, self.dim * self.units))      # merge feature and unit axes
        return self.mlp(x)                           # project back to node_dim
    
    

In [7]:
# PNA聚合
class GINPNAConv(GINConv):
    """GIN convolution using a PNA-style multi-aggregator, degree-scaled reduction.

    Four aggregators (sum, max, mean, variance) are each combined with three
    degree scalers (identity, amplification, attenuation); the resulting 12
    feature blocks are concatenated and projected back to ``node_dim``.

    Args:
        nn: Update network handed to ``GINConv``.
        node_dim: Width of the features being aggregated.
        **kwargs: Forwarded unchanged to ``GINConv``.
    """

    def __init__(self, nn, node_dim=32, **kwargs):
        super(GINPNAConv, self).__init__(nn, **kwargs)
        self.mlp = torch.nn.Linear(node_dim*12, node_dim)
        # Normalisation constant of the degree scalers.
        # NOTE(review): presumably the mean of log(degree + 1) over the
        # training graphs, as in PNA - confirm for the dataset in use.
        self.delta = 2.5749

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce incoming messages with 4 aggregators x 3 degree scalers.

        Args:
            inputs: Messages, one row per edge.
            index: Target node id of every message.
            dim_size: Number of target nodes. Passing it keeps one output
                row per node even when the highest-numbered nodes receive no
                message; when ``None`` the size is derived from ``index``.

        Returns:
            Tensor with ``node_dim`` columns and one row per target node.
        """
        # --- step 1: base aggregators ---
        sums = torch_scatter.scatter_add(inputs, index, dim=0, dim_size=dim_size)
        maxs = torch_scatter.scatter_max(inputs, index, dim=0, dim_size=dim_size)[0]
        means = torch_scatter.scatter_mean(inputs, index, dim=0, dim_size=dim_size)
        sq_means = torch_scatter.scatter_mean(inputs ** 2, index, dim=0, dim_size=dim_size)
        # E[x^2] - E[x]^2; relu removes tiny negative values from round-off.
        var = torch.relu(sq_means - means ** 2)
        aggrs = [sums, maxs, means, var]

        # --- step 2: per-node degree ---
        # minlength keeps the degree vector aligned with the aggregated rows.
        # Degrees are clamped to >= 1 so that nodes without incoming messages
        # do not produce delta / log(1) = inf (and inf * 0 = NaN) below.
        deg = index.bincount(minlength=sums.size(0)).to(sums.dtype)
        deg = deg.clamp(min=1).view(-1, 1)
        log_deg = torch.log(deg + 1.)

        # --- step 3: logarithmic degree scalers ---
        amplification_scaler = [log_deg / self.delta * a for a in aggrs]
        attenuation_scaler = [self.delta / log_deg * a for a in aggrs]

        # --- step 4: concatenate 4 identity + 4 amplified + 4 attenuated ---
        combinations = torch.cat(
            aggrs + amplification_scaler + attenuation_scaler,
            dim=1,
        )

        # --- step 5: project back to node_dim ---
        return self.mlp(combinations)

## 数据导入

In [None]:
from torch_geometric.nn import MessagePassing, SAGEConv, GINConv, global_add_pool
import torch_scatter
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
import os.path as osp


#### LAFNET模型
核心功能：  
一个基于GINLAFConv的5层图神经网络，专为图分类任务设计：  
* 使用LAF聚合替代传统求和操作  
* 通过BatchNorm和Dropout防止过拟合  
* 最终用全局加和池化将节点特征聚合为图级表示  
____  
核心设计原理：  
* 层次化特征提取：  
每层捕获不同跳数的邻居信息（1层=1跳邻居，5层=5跳邻居）
* 全局池化的选择-global_add_pool  
1. 保持置换不变性：图中节点无序，求和操作满足对称性
2. 比 global_mean_pool 保留更多信息  
3. 比 global_max_pool 保留更多细节(避免弱信号丢失)  

In [9]:
class LAFNET(torch.nn.Module):
    """Five-layer GINLAFConv network for graph classification.

    Every layer is a ``GINLAFConv`` followed by ReLU and batch
    normalisation. Node embeddings are then summed per graph and passed
    through a two-layer classifier with dropout.

    Args:
        num_features: Width of the input node features.
        num_class: Number of output classes.
        **kargus: Accepted and ignored.
    """

    def __init__(self, num_features, num_class, **kargus):
        super().__init__()

        # --- hyper-parameters ---
        self.num_features = num_features    # input feature width
        self.num_class = num_class          # number of classes
        self.dim = 32                       # hidden width
        self.units = 3                      # LAF units per feature

        # --- conv1..conv5 with bn1..bn5 ---
        # The first layer aggregates raw input features; every later layer
        # aggregates hidden features of width self.dim.
        in_dim = self.num_features
        for layer_no in range(1, 6):
            update_net = Sequential(
                Linear(in_dim, self.dim), ReLU(), Linear(self.dim, self.dim)
            )
            setattr(
                self,
                f'conv{layer_no}',
                GINLAFConv(update_net, units=self.units, node_dim=in_dim),
            )
            setattr(self, f'bn{layer_no}', torch.nn.BatchNorm1d(self.dim))
            in_dim = self.dim

        # --- graph-level classifier ---
        self.fc1 = Linear(self.dim, self.dim)
        self.fc2 = Linear(self.dim, self.num_class)

    def forward(self, x, edge_index, batch):
        """Compute per-graph class log-probabilities.

        Args:
            x: Node features, ``[num_nodes, num_features]``.
            edge_index: Edge list, ``[2, num_edges]``.
            batch: Graph id of every node, ``[num_nodes]``.

        Returns:
            Log-softmax scores, ``[num_graphs, num_class]``.
        """
        # Convolution -> ReLU -> batch norm, five times over.
        for layer_no in range(1, 6):
            conv = getattr(self, f'conv{layer_no}')
            norm = getattr(self, f'bn{layer_no}')
            x = norm(F.relu(conv(x, edge_index)))

        # Sum node embeddings of each graph: [num_nodes, dim] -> [num_graphs, dim].
        pooled = global_add_pool(x, batch)

        # Classifier head.
        hidden = F.relu(self.fc1(pooled))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=-1)

#### PNANET模型
GINPNA卷积提取的特征:  
* 主邻域聚合：可能结合多种聚合函数(mean, max, sum等)  
* 度感知放缩：考虑节点度进行特征放缩  
* 多尺度信息：捕获不同范围的邻域特征  

In [None]:

class PNANet(torch.nn.Module):
    """Graph network built on GINPNAConv layers (unfinished).

    Only the first convolution is defined so far; the remaining layers and
    the classifier head are still to be written.

    Args:
        num_features: Width of the input node features.
        num_class: Number of output classes.
        **kargus: Accepted and ignored.
    """

    def __init__(self, num_features, num_class, **kargus):
        super(PNANet, self).__init__()
        ## --- model hyper-parameters ---
        self.num_features = num_features    # input feature width
        self.num_class = num_class          # number of classes
        self.dim = 32                       # hidden width
        ## --- layer 1: GINPNA ---
        nn1 = Sequential(Linear(self.num_features, self.dim), ReLU(), Linear(self.dim, self.dim))
        # GINPNAConv requires the update network as its first argument. Its
        # internal projection is sized by node_dim, which must equal the width
        # of the features it aggregates - the raw inputs for the first layer.
        self.conv1 = GINPNAConv(nn1, node_dim=self.num_features)
        ## --- layers 2-5: GINPNA --- (TODO: not implemented yet)
        ## --- fully connected head --- (TODO: not implemented yet)

