In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class FuzzyRuleModule(nn.Module):
    """Trainable fuzzy rule reasoning (Sec. 3.5) [cite: 196].

    Each of ``num_rules`` rules owns one Gaussian membership function per
    input descriptor, parameterized by a trainable center and width. A rule's
    firing strength is the product of its per-descriptor memberships, and the
    module output is the firing strengths weighted by trainable rule weights
    and summed over rules.

    Args:
        input_dim: Number of topological descriptors per node. The default of
            3 corresponds to clustering coefficient, node degree and 2-hop
            label agreement [cite: 116, 119].
        num_rules: Number of fuzzy rules.
    """

    def __init__(self, input_dim=3, num_rules=3):
        super().__init__()
        # Trainable Gaussian centers (c) and widths (sigma), one per
        # (rule, descriptor) pair [cite: 208, 212].
        self.centers = nn.Parameter(torch.rand(num_rules, input_dim))
        self.sigmas = nn.Parameter(torch.ones(num_rules, input_dim))
        # Learned rule importance weights alpha_k [cite: 205].
        self.rule_weights = nn.Parameter(torch.ones(num_rules) * 0.5)

    def forward(self, f_u):
        """Compute the rule activation r(u).

        Args:
            f_u: Descriptor tensor of shape ``(..., input_dim)``. Any number
                of leading batch dimensions is supported.

        Returns:
            Tensor of shape ``(..., 1)`` holding the weighted rule activation.

        Raises:
            ValueError: If the last dimension of ``f_u`` is not ``input_dim``.
        """
        input_dim = self.centers.shape[-1]
        # Reject mismatched descriptor counts explicitly: a trailing dim of 1
        # would otherwise broadcast silently against the (R, D) centers.
        if f_u.dim() == 0 or f_u.shape[-1] != input_dim:
            raise ValueError(
                f"expected f_u with last dimension {input_dim}, "
                f"got shape {tuple(f_u.shape)}"
            )
        # Insert a rule axis right before the descriptor axis so that
        # (..., 1, D) broadcasts against the (R, D) centers for any batch rank.
        f_u_expanded = f_u.unsqueeze(-2)
        # Gaussian membership exponent [cite: 209, 210]; the 1e-6 keeps the
        # denominator strictly positive even if a sigma collapses to 0.
        diff = -((f_u_expanded - self.centers) ** 2) / (2 * self.sigmas**2 + 1e-6)
        # Firing strength of each rule: product over descriptors -> (..., R).
        membership_degrees = torch.exp(diff).prod(dim=-1)

        # Rule aggregation r(u) = sum_k alpha_k * mu_k(u) [cite: 202, 203].
        rule_activation = torch.matmul(membership_degrees, self.rule_weights)
        return rule_activation.unsqueeze(-1)

class GAFRNet(nn.Module):
    """GAFR-Net overall architecture [cite: 96, 101].

    The model runs in four stages:

    1. Two GAT layers produce node embeddings h_u.
    2. A ``FuzzyRuleModule`` turns per-node topological descriptors into a
       rule activation r(u).
    3. The two are fused additively.
    4. A linear layer maps the result to class logits.

    Args:
        feature_dim: Input node feature size.
        hidden_dim: Per-head output size of the first GAT layer and output
            size of the second GAT layer.
        num_classes: Number of output classes.
        heads: Attention heads in the first GAT layer. The default is 4
            [cite: 175, 347].
        num_rules: Number of fuzzy rules in the rule module.
        num_descriptors: Number of topological descriptors per node in ``f_u``.
    """

    def __init__(self, feature_dim, hidden_dim, num_classes, heads=4,
                 num_rules=3, num_descriptors=3):
        super().__init__()
        # 3.3 Graph attention message passing [cite: 175, 347]. Layer 1
        # concatenates its heads, so layer 2's input width must be
        # hidden_dim * heads. Both layers use the single `heads` value so the
        # two cannot drift apart.
        self.gat1 = GATConv(feature_dim, hidden_dim, heads=heads, concat=True)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=1, concat=False)

        self.fuzzy_module = FuzzyRuleModule(input_dim=num_descriptors,
                                            num_rules=num_rules)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, f_u):
        """Return class logits for every node.

        Args:
            x: Node features fed to the first GAT layer.
            edge_index: Graph connectivity fed to both GAT layers.
            f_u: Per-node topological descriptors for the fuzzy rule module.

        Returns:
            Unnormalized logits from ``self.classifier``. No softmax is
            applied here, so pair the output with ``nn.CrossEntropyLoss``
            during training.
        """
        # Graph embedding h_u: GAT -> ELU -> GAT [cite: 176, 177].
        h = F.elu(self.gat1(x, edge_index))
        h_u = self.gat2(h, edge_index)

        # Fuzzy rule activation r(u), with a trailing dimension of 1 [cite: 216].
        r_u = self.fuzzy_module(f_u)

        # 3.6 Fusion h_u' = h_u + r(u) [cite: 231, 232]. r(u) broadcasts over
        # every channel of h_u. This is a plain additive fusion with no
        # learned gate.
        h_prime = h_u + r_u

        # Final prediction [cite: 233, 234].
        return self.classifier(h_prime)