In [None]:
import os

os.environ["DGLBACKEND"] = "pytorch" 
import dgl
from dgl.data import CoraGraphDataset
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
def load_data_dgl(normalize=False):
    """Load the Cora citation graph with DGL and attach synthetic multi-channel edge features.

    Args:
        normalize: if True, row-normalize the node features with
            ``normalize_features``; the default (False) keeps DGL's features as-is.

    Returns:
        graph: the DGL graph; ``graph.edata['feat']`` holds per-edge features, shape [E, P].
        edge_attr: dense DSN-normalized edge features, shape [P, N, N] (P = 3 here).
        features: float32 node features, one row per node.
        labels: int64 node labels.
        train_mask, val_mask, test_mask: the dataset's node split masks.
    """
    print('Loading Cora dataset...')

    # Load the Cora dataset with DGL (a single graph).
    dataset = CoraGraphDataset()
    graph = dataset[0]

    # Node features and labels.
    features = graph.ndata['feat']
    if normalize:
        features = torch.from_numpy(normalize_features(features.numpy()))
    features = features.float()
    labels = graph.ndata['label'].long()

    # Train / validation / test split masks.
    train_mask = graph.ndata['train_mask']
    val_mask = graph.ndata['val_mask']
    test_mask = graph.ndata['test_mask']

    # Dense adjacency matrix, already a torch tensor (no numpy round-trip needed).
    adj = graph.adjacency_matrix().to_dense().float()
    # NOTE: build a synthetic 3-channel edge attribute [A, A^T, A + A^T] -> [P, N, N].
    edge_attr = torch.stack([adj, adj.t(), adj + adj.t()], dim=0)
    edge_attr = DSN(edge_attr)  # doubly stochastic normalization, per channel

    # Gather each edge's P channel values: [P, E] -> [E, P].
    src, dst = graph.edges()
    graph.edata['feat'] = edge_attr[:, src, dst].t()

    return graph, edge_attr, features, labels, train_mask, val_mask, test_mask

def DSN2(t):
    """Doubly stochastic normalization of one square, non-negative matrix, keeping its sparsity.

    With row sums a_i, column sums b_j, lamb = max(all a_i, all b_j) and
    r = N * lamb - sum(t), the additive correction (lamb - a_i)(lamb - b_j) / r
    makes every row and column sum to lamb; dividing by lamb then yields a
    doubly stochastic matrix. Finally every entry where ``t`` is not positive is
    reset to its original value, so the result keeps the sparsity pattern of
    ``t`` (and is therefore no longer exactly doubly stochastic).

    Args:
        t: square tensor [N, N], assumed non-negative.

    Returns:
        Tensor with the same shape as ``t``. An all-zero input yields all zeros.
    """
    row_sums = t.sum(dim=1, keepdim=True)  # [N, 1]
    col_sums = t.sum(dim=0, keepdim=True)  # [1, N]
    # flatten() rather than squeeze(): a 1x1 input would otherwise give 0-d
    # tensors, which torch.cat refuses to concatenate.
    lamb = torch.cat([row_sums.flatten(), col_sums.flatten()], dim=0).max()
    r = t.shape[0] * lamb - t.sum()

    if r > 0:
        # Broadcast [N, 1] * [1, N] -> [N, N].
        tt = t + (lamb - row_sums) * (lamb - col_sums) / r
    else:
        # r == 0 means every row and column already sums to lamb, so the
        # correction is zero; computing it would be 0/0 = NaN.
        tt = t

    # Every column of tt sums to lamb, so this equals dividing by any column sum.
    ttmatrix = tt / lamb
    # Keep the original sparsity: non-positive entries of t are copied back
    # (this also discards the 0/0 produced for an all-zero input).
    return torch.where(t > 0, ttmatrix, t)


def DSN(x):
    """Doubly stochastic normalization, applied channel by channel.

    Each slice ``x[i]`` along dim 0 is normalized independently with ``DSN2``;
    the normalized slices are stacked back along dim 0, so the leading size of
    the output matches that of ``x``.
    """
    return torch.stack([DSN2(channel) for channel in x.unbind(0)], dim=0)

def normalize_features(mx):
    """Row-normalize a dense 2-D feature matrix so that each row sums to 1.

    Args:
        mx: 2-D array-like (e.g. a numpy array). Integer input is promoted to
            float64; floating-point input keeps its dtype.

    Returns:
        numpy array of the same shape. Rows whose sum is 0 become all zeros.
    """
    mx = np.asarray(mx)
    if not np.issubdtype(mx.dtype, np.floating):
        # np.power(int, -1) raises ValueError, so work in floating point.
        mx = mx.astype(np.float64)
    rowsum = mx.sum(axis=1)
    # 1/0 -> inf is expected for empty rows and handled right below.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.0
    # Scale rows by broadcasting instead of multiplying by a dense N x N diagonal matrix.
    return mx * r_inv[:, None]

In [None]:
# Unpack everything load_data_dgl returns; later cells use these names.
(g, edge_attr, features, labels,
 train_mask, val_mask, test_mask) = load_data_dgl()

In [None]:
for heading, feature_view in (("Node features", g.ndata), ("Edge features", g.edata)):
    print(heading)
    print(feature_view)

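\begin{equation}
X^l=\sigma\left[\Big\|_{p=1}^{P}\left(\alpha_{\cdot \cdot p}^l\left(X^{l-1}, E_{\cdot \cdot p}^{l-1}\right) g^l\left(X^{l-1}\right)\right)\right] .
\end{equation}

\begin{align}
g^l\left(X^{l-1}\right) &= W^l X^{l-1}, \\
f^l\left(X_{i \cdot}^{l-1}, X_{j \cdot}^{l-1}\right) &= \exp\left\{\mathrm{L}\left(a^T\left[W^l X_{i \cdot}^{l-1} \,\big\|\, W^l X_{j \cdot}^{l-1}\right]\right)\right\}, \\
\alpha_{\cdot \cdot p}^l &= \operatorname{DS}\left(\hat{\alpha}_{\cdot \cdot p}^l\right), \\
\hat{\alpha}_{i j p}^l &= f^l\left(X_{i \cdot}^{l-1}, X_{j \cdot}^{l-1}\right) E_{i j p}^{l-1}, \\
E^l &= \alpha^l
\end{align}

Here $\|$ denotes concatenation over the $P$ edge-feature channels, $\mathrm{L}$ is LeakyReLU, $\sigma$ is a non-linearity (ELU in the code below) and $\operatorname{DS}$ is doubly stochastic normalization. Note: `EGNNA_Conv` below normalizes the masked attention scores with a softmax rather than $\operatorname{DS}$.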

In [None]:
class EGNNA_Conv(nn.Module):
    """One EGNN(A)-style attention layer over dense, multi-channel edge features.

    For every edge channel p, attention scores are computed from the projected
    node features, scaled by that channel's edge weights, masked to existing
    edges, normalized over each node's neighbours, and used to aggregate the
    projected neighbour features.

    Args:
        dim_in: size of the input node features h.
        dim_h: size of each node's output features per edge channel.
        dropout: dropout probability applied to the attention matrix in training.
        node_att_agger: if False (default), per-channel outputs are concatenated
            to [N, dim_h * P], ELU is applied, and the new edge features are
            returned as well; if True, the channels are summed to [N, dim_h] and
            only that tensor is returned (no activation).
    """

    def __init__(self,
                 dim_in: int,  # input node-feature size
                 dim_h: int,  # output node-feature size per edge channel
                 dropout: float,
                 node_att_agger: bool = False  # final node features N x (F*P), or summed to N x F
                 ):
        super().__init__()
        self.dropout = dropout
        self.FC1 = nn.Linear(dim_in, dim_h)  # W: shared node projection
        self.FC2 = nn.Linear(dim_h, 1)  # a1: source-node half of the attention vector a
        self.FC3 = nn.Linear(dim_h, 1)  # a2: neighbour-node half of the attention vector a
        self.leakyrelu = nn.LeakyReLU(0.2)  # NOTE: negative_slope=0.2
        self.node_att_agger = node_att_agger

    def forward(self, h, e):
        """
        Args:
            h: node features, shape [N, dim_in].
            e: edge features, shape [P, N, N]; entries > 0 mark existing edges.

        Returns:
            If node_att_agger is False: (new_h, alpha), where new_h has shape
            [N, dim_h * P] (ELU-activated; column index = f * P + p) and alpha,
            shape [P, N, N], is the row-normalized attention that serves as the
            next layer's edge features.
            If node_att_agger is True: new_h of shape [N, dim_h] (channels summed).
        """
        Wh = self.FC1(h)  # [N, dim_h]
        # a^T [W h_i || W h_j] = a1^T W h_i + a2^T W h_j, broadcast to all (i, j): [N, N].
        fXX = self.leakyrelu(self.FC2(Wh) + self.FC3(Wh).t())
        # Scale by each edge channel: [N, N] * [P, N, N] -> [P, N, N].
        alpha = fXX * e
        # Non-edges get a huge negative score so the softmax gives them zero weight.
        alpha = torch.where(e > 0, alpha, torch.full_like(alpha, -9e15))
        # Normalize over each node's neighbours j (the last dim) so every row of
        # alpha[p] sums to 1, matching the row-wise aggregation below.
        alpha = F.softmax(alpha, dim=-1)

        # NOTE: alpha (before dropout) becomes the new edge features (E^l = alpha^l).
        attention = F.dropout(alpha, self.dropout, training=self.training)

        # Aggregate all channels at once: [P, N, N] @ [N, F] -> [P, N, F].
        new_h = torch.matmul(attention, Wh)

        if self.node_att_agger:
            return torch.sum(new_h, dim=0)  # [N, F]

        # [P, N, F] -> [N, F, P] -> [N, F * P]
        new_h = new_h.permute(1, 2, 0).reshape(new_h.shape[1], -1)
        return F.elu(new_h), alpha

In [None]:
class MultiHead_EGNNA_Classifier(nn.Module):
    """Multi-head EGNN(A) node classifier: input layer -> hidden layer -> output layer.

    Args:
        dim_nfeat: size of the input node features.
        dim_efeat: number of edge-feature channels P.
        dim_hidden: per-head output sizes ``[input_layer, hidden_layer]``.
        dim_out: output size of the final layer (number of classes).
        dropout: dropout probability for node features and attention.
        n_heads: number of heads ``[input_layer, hidden_layer]``.
    """

    def __init__(self,
                 dim_nfeat: int,
                 dim_efeat: int,  # P
                 dim_hidden: list,  # [input_layer, hidden_layer]
                 dim_out: int,  # output_layer
                 dropout: float,
                 n_heads: list  # [input_layer, hidden_layer]
                 ):
        super().__init__()
        self.dropout = dropout

        # Multi-head attention mechanism.
        # 1. Input layer: each head outputs [N, P * dim_hidden[0]].
        self.attentions = [EGNNA_Conv(dim_nfeat, dim_hidden[0], dropout) for _ in range(n_heads[0])]
        for i, attention in enumerate(self.attentions):
            # A plain list is not tracked by nn.Module, so register each head explicitly.
            self.add_module('attention_{}'.format(i), attention)

        # 2. Hidden layer: its input is the concatenation of all input-layer heads.
        dim_after_input = n_heads[0] * dim_efeat * dim_hidden[0]
        self.hidden_atts = [EGNNA_Conv(dim_after_input, dim_hidden[1], dropout) for _ in range(n_heads[1])]
        for i, hidden_att in enumerate(self.hidden_atts):
            self.add_module('hidden_att_{}'.format(i), hidden_att)

        # 3. Output layer: its input is the concatenation of all hidden-layer heads
        #    (n_heads[1] heads of size P * dim_hidden[1] each).
        dim_after_hidden = n_heads[1] * dim_efeat * dim_hidden[1]
        self.out_att = EGNNA_Conv(dim_after_hidden, dim_out, dropout, node_att_agger=True)

    @staticmethod
    def _run_heads(heads, h, e):
        """Run parallel heads on the same (h, e): concat node outputs, average attentions.

        Every head receives the layer's input edge features ``e``. The heads'
        attention matrices are averaged to form the edge features passed to the
        next layer.
        """
        node_outputs = []
        edge_sum = None
        for head in heads:
            head_h, head_e = head(h, e)
            node_outputs.append(head_h)
            edge_sum = head_e if edge_sum is None else edge_sum + head_e
        return torch.cat(node_outputs, dim=1), edge_sum / len(heads)

    def forward(self, h, e):
        """
        Args:
            h: node features [N, dim_nfeat].
            e: edge features [P, N, N] with P == dim_efeat.

        Returns:
            Per-class scores [N, dim_out] (ELU-activated).
        """
        # Input layer
        h = F.dropout(h, self.dropout, training=self.training)
        h, e = self._run_heads(self.attentions, h, e)

        # Hidden layer
        h = F.dropout(h, self.dropout, training=self.training)
        h, e = self._run_heads(self.hidden_atts, h, e)

        # Output layer: per-class scores
        h = F.dropout(h, self.dropout, training=self.training)
        return F.elu(self.out_att(h, e))
        

In [None]:
# Seed every torch RNG (CPU and CUDA) so weight init and dropout are reproducible.
torch.manual_seed(42)

model = MultiHead_EGNNA_Classifier(dim_nfeat=features.shape[1],
                                   dim_efeat=edge_attr.shape[0],
                                   dim_hidden=[64, 8],
                                   dim_out=int(labels.max()) + 1,
                                   dropout=0.5,
                                   n_heads=[1, 8]
)

# Put the model and every tensor used in training on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
features = features.to(device)
edge_attr = edge_attr.to(device)
labels = labels.to(device)
train_mask = train_mask.to(device)
val_mask = val_mask.to(device)
test_mask = test_mask.to(device)

NUM_EPOCHS = 100
LOG_EVERY = 5
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
best_val_acc = 0.0
best_test_acc = 0.0


def masked_accuracy(pred, labels, mask):
    """Fraction of the nodes selected by ``mask`` whose prediction equals the label."""
    return (pred[mask] == labels[mask]).float().mean().item()


for epoch in range(NUM_EPOCHS):
    # One optimization step with dropout active.
    model.train()
    logits = model(features, edge_attr)
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate all splits on the same, updated weights with dropout off, so model
    # selection is not driven by dropout noise. no_grad avoids building an
    # autograd graph over the large dense N x N attention tensors.
    model.eval()
    with torch.no_grad():
        logits = model(features, edge_attr)
    pred = logits.argmax(1)  # index of the highest class score
    train_acc = masked_accuracy(pred, labels, train_mask)
    val_acc = masked_accuracy(pred, labels, val_mask)
    test_acc = masked_accuracy(pred, labels, test_mask)

    # Model selection: report the test accuracy from the best-validation epoch.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc

    if epoch % LOG_EVERY == 0:
        print(
            f"In epoch {epoch}, loss: {loss.item():.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})"
        )