## 通用图神经网络堆栈
实现一个通用的GNN模块，该模块可以插入任何自定义的卷积组件，例如GraphSage、GAT等。

In [None]:
import torch
import torch_scatter
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

from torch import Tensor
from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)

from torch.nn import Parameter, Linear
#PRW：https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

class GNNStack(torch.nn.Module):
    """Generic GNN: a stack of message-passing layers followed by an MLP head.

    The convolution class is selected by ``args.model_type`` ('GraphSage' or
    'GAT'). Every convolution is followed by ReLU and dropout; the result is
    then passed through ``post_mp`` (Linear -> Dropout -> Linear).

    Args:
        input_dim: Size of the input node features.
        hidden_dim: Per-head output size of every convolution layer.
        output_dim: Size of the final output (number of classes).
        args: Configuration object exposing ``model_type``, ``num_layers``,
            ``heads`` and ``dropout``.
        emb: If truthy, ``forward`` returns the raw output of ``post_mp``
            instead of log-probabilities.

    Raises:
        ValueError: If ``args.num_layers`` is smaller than 1 or
            ``args.model_type`` is not a supported layer type.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, args, emb=False):
        super(GNNStack, self).__init__()

        # Validate the depth before any layer is built (an ``assert`` would be
        # stripped under ``python -O`` and used to run after the first layer
        # had already been constructed).
        if args.num_layers < 1:
            raise ValueError('Number of layers is not >=1')

        conv_model = self.build_conv_model(args.model_type)  # GraphSage / GAT

        # Only GAT concatenates the outputs of several heads, so only GAT
        # widens the features that flow between layers. Forwarding ``heads``
        # keeps the layer's real output width in sync with the widths below.
        if args.model_type == 'GAT':
            heads = args.heads
            conv_kwargs = {'heads': heads}
        else:
            heads = 1
            conv_kwargs = {}
        conv_width = heads * hidden_dim

        self.convs = nn.ModuleList()
        self.convs.append(conv_model(input_dim, hidden_dim, **conv_kwargs))
        for _ in range(args.num_layers - 1):
            self.convs.append(conv_model(conv_width, hidden_dim, **conv_kwargs))

        # post-message-passing
        self.post_mp = nn.Sequential(
            nn.Linear(conv_width, hidden_dim), nn.Dropout(args.dropout),
            nn.Linear(hidden_dim, output_dim))

        self.dropout = args.dropout
        self.num_layers = args.num_layers

        self.emb = emb

    def build_conv_model(self, model_type):
        """Return the convolution class that matches ``model_type``.

        Raises:
            ValueError: If ``model_type`` is neither 'GraphSage' nor 'GAT'.
        """
        if model_type == 'GraphSage':
            return GraphSage
        elif model_type == 'GAT':
            # With more than one attention head, the input width of every
            # layer after the first (and of the first Linear in post_mp) must
            # be num_heads * hidden_dim; __init__ takes care of that.
            return GAT
        # Fail loudly instead of returning None, which would only surface
        # later as an obscure "'NoneType' object is not callable".
        raise ValueError('Unknown model type: {0!r}'.format(model_type))

    def forward(self, data):
        """Run the stack on ``data`` (reads ``data.x`` and ``data.edge_index``).

        Returns:
            The output of ``post_mp`` if ``self.emb`` is truthy, otherwise its
            log-softmax over dimension 1.
        """
        x, edge_index = data.x, data.edge_index

        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.post_mp(x)

        if self.emb:
            return x

        # Applies a softmax followed by a logarithm.
        # https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.log_softmax
        return F.log_softmax(x, dim=1)

    def loss(self, pred, label):
        """Negative log-likelihood loss of log-probabilities ``pred``."""
        # https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#nll_loss
        return F.nll_loss(pred, label)

## 归纳式GCN GraphSAGE实现
通常，方法forward是进行实际消息传递的地方。每个迭代中的所有逻辑都发生在forward中，通过调用propagate方法将信息从相邻节点传播到中心节点。因此，一般的范式是pre-processing -> propagate -> post-processing。  
如前文介绍的消息传递过程，propagate会依次调用message、aggregate和update：message将邻居节点的信息转换为消息，aggregate将来自邻居节点的所有消息聚合为一个，update则为下一次迭代生成节点的嵌入。  
我们的实现与此略有不同，我们不会显式地实现update，而是将更新节点的逻辑放在forward方法中。更具体地说，信息传播后，可以进一步对propagate的输出进行一些操作。forward的输出就是当前迭代后的嵌入。  
此外，传递给**propagate()**的张量可以通过在变量名后面附加_i或_j来映射到相应的节点 i 和 j 。例如x_i和x_j。请注意，通常将i称为聚合信息的中心节点，并将j称为相邻节点，因为这是最常见的表示方法。  
<img src="../image/6.png">
<img src="../image/7.png">

In [None]:
class GraphSage(MessagePassing):
    """GraphSAGE-style layer: mean neighbour aggregation plus a skip connection.

    The output for a node is ``lin_l(x) + lin_r(mean of neighbour features)``,
    optionally L2-normalised.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the output node features.
        normalize: If true, ``F.normalize`` is applied to the output.
        bias: Accepted for interface compatibility; the linear layers below
            are built with their default settings and do not read it.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(self, in_channels, out_channels, normalize = True,
                 bias = False, **kwargs):
        super(GraphSage, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        # W_l: linear map applied to the central node's own embedding.
        self.lin_l = Linear(in_channels, out_channels)
        # W_r: linear map applied to the aggregated neighbour message.
        self.lin_r = Linear(in_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers."""
        for layer in (self.lin_l, self.lin_r):
            layer.reset_parameters()

    def forward(self, x, edge_index, size = None):
        """Propagate messages, then apply the update rule.

        The same tensor is used as the representation of both central and
        neighbour nodes, i.e. ``x=(x, x)`` is handed to ``propagate``.
        See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html
        """
        neighbourhood = self.propagate(edge_index, x=(x, x), size=size)

        # Skip connection: transformed self features + transformed neighbours.
        own_term = self.lin_l(x)
        neighbour_term = self.lin_r(neighbourhood)
        out = neighbour_term + own_term

        # F.normalize defaults to L2 normalisation along dim=1.
        return F.normalize(out) if self.normalize else out

    def message(self, x_j):
        """Pass neighbour features through unchanged."""
        return x_j

    def aggregate(self, inputs, index, dim_size = None):
        """Average the incoming messages of every target node.

        See https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html#torch_scatter.scatter
        """
        # self.node_dim is the dimension along which nodes are indexed.
        return torch_scatter.scatter(inputs, index, self.node_dim,
                                     dim_size=dim_size, reduce='mean')

## 图注意力网络 GAT实现
<img src="../image/8.png">

In [None]:
class GAT(MessagePassing):
    """Multi-head graph attention layer.

    Node features are projected to ``heads * out_channels`` values, viewed as
    ``[N, heads, out_channels]``, weighted by a per-edge, per-head attention
    coefficient, summed over incoming edges and finally flattened back to
    ``[N, heads * out_channels]``.

    Args:
        in_channels: Size of the input node features.
        out_channels: Output size of a single attention head.
        heads: Number of attention heads (their outputs are concatenated).
        negative_slope: Slope of the LeakyReLU applied to attention logits.
        dropout: Dropout probability applied to the attention coefficients.
        **kwargs: Forwarded to ``MessagePassing.__init__``.
    """

    def __init__(self, in_channels, out_channels, heads = 2,
                 negative_slope = 0.2, dropout = 0., **kwargs):
        super(GAT, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        # Linear projection applied before message passing. Its output holds
        # all heads side by side, hence heads * out_channels.
        self.lin_l = Linear(in_channels, heads * out_channels)

        # NOTE: lin_r is the very same module as lin_l, so source and target
        # nodes share one projection matrix (W_l == W_r).
        self.lin_r = self.lin_l

        # One attention vector per head for each side of an edge.
        self.att_l = Parameter(torch.empty(1, heads, out_channels))
        self.att_r = Parameter(torch.empty(1, heads, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the projection weights and attention vectors."""
        # lin_r aliases lin_l, so the second call re-initialises the same
        # weight; it is kept so the method stays correct if they are untied.
        nn.init.xavier_uniform_(self.lin_l.weight)
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)
        # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#xavier_uniform_

    def forward(self, x, edge_index, size = None):
        """Project, compute attention logits, propagate and flatten the heads.

        Returns:
            Tensor viewed as ``[-1, heads * out_channels]``.
        """
        H, C = self.heads, self.out_channels

        # 1. Project the node embeddings and split them into heads: [N, H, C].
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)

        # 2. One scalar logit per node and head: reduce over the channel axis
        #    (the last one), NOT over the head axis. Result: [N, H].
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)

        # 3. Message passing; propagate exposes the tuple entries to message()
        #    as x_j / alpha_j / alpha_i.
        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

        # 4. Concatenate the heads again.
        return out.view(-1, H * C)

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        """Weight each neighbour message by its attention coefficient.

        ``ptr`` is optional; when given, the softmax is computed from sorted
        inputs in CSR representation.
        """
        # Un-normalised attention per edge and head: [E, H].
        alpha = F.leaky_relu(alpha_i + alpha_j, self.negative_slope)

        # Normalise over the incoming edges of every target node. The
        # torch_geometric softmax groups by `index`, unlike torch's softmax.
        # https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch-geometric-utils
        alpha = softmax(alpha, index, ptr, size_i)

        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Broadcast each head's coefficient over that head's channels:
        # [E, H, 1] * [E, H, C] -> [E, H, C].
        return x_j * alpha.unsqueeze(-1)

    def aggregate(self, inputs, index, dim_size = None):
        """Sum the attention-weighted messages of every target node.

        The softmax-normalised weights already sum to one per node, so the
        reduction is 'sum' here rather than the 'mean' used by GraphSage.
        See https://pytorch-scatter.readthedocs.io/en/latest/_modules/torch_scatter/scatter.html
        """
        return torch_scatter.scatter(inputs, index, dim=self.node_dim,
                                     dim_size=dim_size, reduce='sum')

In [None]:
import torch.optim as optim

def build_optimizer(args, params):
    """Create the optimizer and optional LR scheduler described by ``args``.

    Args:
        args: Configuration object. Reads ``opt`` ('adam', 'sgd', 'rmsprop'
            or 'adagrad'), ``opt_scheduler`` ('none', 'step' or 'cos'),
            ``lr`` and ``weight_decay``; additionally ``opt_decay_step`` and
            ``opt_decay_rate`` for 'step', and ``opt_restart`` for 'cos'.
        params: Iterable of parameters; only those with ``requires_grad``
            set are handed to the optimizer.

    Returns:
        A ``(scheduler, optimizer)`` tuple; ``scheduler`` is ``None`` when
        ``args.opt_scheduler == 'none'``.

    Raises:
        ValueError: If ``args.opt`` or ``args.opt_scheduler`` is not one of
            the supported names.
    """
    weight_decay = args.weight_decay
    trainable = [p for p in params if p.requires_grad]

    if args.opt == 'adam':
        optimizer = optim.Adam(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(trainable, lr=args.lr, momentum=0.95, weight_decay=weight_decay)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(trainable, lr=args.lr, weight_decay=weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(trainable, lr=args.lr, weight_decay=weight_decay)
    else:
        # Without this branch an unknown name surfaced as UnboundLocalError.
        raise ValueError('Unknown optimizer: {0!r}'.format(args.opt))

    if args.opt_scheduler == 'none':
        return None, optimizer
    elif args.opt_scheduler == 'step':
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate)
    elif args.opt_scheduler == 'cos':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart)
    else:
        raise ValueError('Unknown scheduler: {0!r}'.format(args.opt_scheduler))
    return scheduler, optimizer

In [None]:
import time

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.nn as pyg_nn

import matplotlib.pyplot as plt


def train(dataset, args):
    """Train a ``GNNStack`` on ``dataset`` for node classification.

    Args:
        dataset: Graph dataset; ``num_node_features`` and ``num_classes`` size
            the model, and each batch must expose ``y`` and ``train_mask``.
        args: Configuration object. Reads ``batch_size``, ``hidden_dim`` and
            ``epochs`` here and is forwarded to ``GNNStack`` and
            ``build_optimizer``.

    Returns:
        ``(test_accs, losses)``: one entry per epoch. Accuracy is re-evaluated
        every 10th epoch and carried forward in between. It comes from
        ``test(test_loader, model)``, whose default ``is_validation=True``
        selects ``val_mask``.
    """
    # Prints the sum of the first graph's train_mask (message kept verbatim).
    print("节点分类任务，数据集大小:", np.sum(dataset[0]['train_mask'].numpy()))
    # The same loader object is used for training and for evaluation.
    test_loader = loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # build model
    model = GNNStack(dataset.num_node_features, args.hidden_dim, dataset.num_classes,
                     args)
    scheduler, opt = build_optimizer(args, model.parameters())

    # train
    losses = []
    test_accs = []
    for epoch in range(args.epochs):
        total_loss = 0
        model.train()
        for batch in loader:
            opt.zero_grad()
            pred = model(batch)
            label = batch.y
            # Only the training nodes contribute to the loss.
            pred = pred[batch.train_mask]
            label = label[batch.train_mask]
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item() * batch.num_graphs
        total_loss /= len(loader.dataset)
        losses.append(total_loss)

        # Advance the LR schedule once per epoch. Previously the scheduler was
        # created but never stepped, so 'step' / 'cos' had no effect at all.
        if scheduler is not None:
            scheduler.step()

        if epoch % 10 == 0:
            test_acc = test(test_loader, model)
            test_accs.append(test_acc)
        else:
            # Epoch 0 always evaluates, so test_accs is never empty here.
            test_accs.append(test_accs[-1])
    return test_accs, losses

def test(loader, model, is_validation=True):
    """Return the classification accuracy of ``model`` on the masked nodes.

    Args:
        loader: Iterable of batches exposing ``y``, ``val_mask`` and
            ``test_mask``; its ``dataset`` attribute is iterated to count the
            evaluated nodes.
        model: Callable module whose output is reduced with ``max(dim=1)``.
        is_validation: Evaluate on ``val_mask`` if true, else on ``test_mask``.

    Returns:
        Number of correct predictions divided by the number of masked nodes.
    """
    model.eval()

    def select_mask(graph):
        # Node classification: only score the nodes of the requested split.
        return graph.val_mask if is_validation else graph.test_mask

    correct = 0
    for data in loader:
        with torch.no_grad():
            # max(dim=1) returns a (values, indices) tuple; only need indices
            _, pred = model(data).max(dim=1)

        mask = select_mask(data)
        correct += pred[mask].eq(data.y[mask]).sum().item()

    total = sum(torch.sum(select_mask(graph)).item() for graph in loader.dataset)
    return correct / total
  
class objectview:
    """Expose the keys of a dict as attributes.

    The instance uses the given dict itself as its ``__dict__`` (no copy), so
    attribute writes are visible in the dict and vice versa.
    """

    def __init__(self, d):
        object.__setattr__(self, '__dict__', d)

In [None]:
%%time
def main():
    """Train GraphSage and GAT on Cora and plot loss and accuracy curves.

    For every configuration both model types are trained in turn; their
    curves are drawn into the same figure, which is shown once per
    configuration.
    """
    configurations = [
        {
            'model_type': 'GraphSage',
            'dataset': 'cora',
            'num_layers': 2,
            'heads': 1,
            'batch_size': 32,
            'hidden_dim': 32,
            'dropout': 0.5,
            'epochs': 500,
            'opt': 'adam',
            'opt_scheduler': 'none',
            'opt_restart': 0,
            'weight_decay': 5e-3,
            'lr': 0.01,
        },
    ]
    for settings in configurations:
        args = objectview(settings)
        for model in ('GraphSage', 'GAT'):
            args.model_type = model

            # Match the dimension.
            args.heads = 2 if model == 'GAT' else 1

            if args.dataset != 'cora':
                raise NotImplementedError("Unknown dataset")
            dataset = Planetoid(root='../data/Cora', name='Cora')

            test_accs, losses = train(dataset, args)

            print("Maximum accuracy: {0}".format(max(test_accs)))
            print("Minimum loss: {0}".format(min(losses)))

            plt.title(dataset.name)
            plt.plot(losses, label="training loss" + " - " + args.model_type)
            plt.plot(test_accs, label="test accuracy" + " - " + args.model_type)
        plt.legend()
        plt.show()


if __name__ == '__main__':
    main()