In [49]:
# 代码和说明来源：
# https://www.jiqizhixin.com/articles/2019-02-19-7

In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # 公式 (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # 公式 (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # 公式 (2) 所需，边上的用户定义函数
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e' : F.leaky_relu(a)}

    def message_func(self, edges):
        # 公式 (3), (4)所需，传递消息用的用户定义函数
        return {'z' : edges.src['z'], 'e' : edges.data['e']}

    def reduce_func(self, nodes):
        # 公式 (3), (4)所需, 归约用的用户定义函数
        # 公式 (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # 公式 (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h' : h}

    def forward(self, h):
        # 公式 (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # 公式 (2)
        self.g.apply_edges(self.edge_attention)
        # 公式 (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

## 多头注意力 (Multi-head attention)

神似卷积神经网络里的多通道，GAT 引入了多头注意力来丰富模型的能力和稳定训练的过程。每一个注意力的头都有它自己的参数。如何整合多个注意力机制的输出结果一般有两种方式：

In [33]:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # 对输出特征维度（第1维）做拼接
            return torch.cat(head_outs, dim=1)
        else:
            # 用求平均整合多头结果
            return torch.mean(torch.stack(head_outs))

## 在 Cora 数据集上训练一个 GAT 模型

In [34]:
# 定义一个两层的 GAT 模型：
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # 注意输入的维度是 hidden_dim * num_heads 因为多头的结果都被拼接在了
        # 一起。 此外输出层只有一个头。
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

In [50]:
# 我们使用 DGL 自带的数据模块加载 Cora 数据集。
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh

def load_cora_data():
    data = citegrh.load_cora()                     # num_nodes=2708, num_edges=10556
    features = torch.FloatTensor(data.features)    #2708 , 1433
    labels = torch.LongTensor(data.labels)         #2708  [int]==》目标变量的取值
    mask = torch.ByteTensor(data.train_mask)       #2708  [1,0]===>标签？
    g = DGLGraph(data.graph)
    return g, features, labels, mask

In [59]:
g, features, labels, mask = load_cora_data()
mask.shape

torch.Size([2708])

## 模型训练的流程和 GCN 教程里的一样。（样例数据）

In [60]:
import requests
import time
import numpy as np
g, features, labels, mask = load_cora_data()

# 创建模型
net = GAT(g, 
          in_dim=features.size()[1], 
          hidden_dim=8, 
          out_dim=7, 
          num_heads=8)
#print(net)

# 创建优化器
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# 主流程
dur = []
for epoch in range(30):
    if epoch >=3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >=3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), np.mean(dur)))

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Epoch 00000 | Loss 1.9462 | Time(s) nan
Epoch 00001 | Loss 1.9445 | Time(s) nan
Epoch 00002 | Loss 1.9429 | Time(s) nan
Epoch 00003 | Loss 1.9412 | Time(s) 0.2891
Epoch 00004 | Loss 1.9395 | Time(s) 0.2896
Epoch 00005 | Loss 1.9377 | Time(s) 0.2900
Epoch 00006 | Loss 1.9359 | Time(s) 0.2925
Epoch 00007 | Loss 1.9340 | Time(s) 0.2944
Epoch 00008 | Loss 1.9321 | Time(s) 0.2937
Epoch 00009 | Loss 1.9302 | Time(s) 0.2929
Epoch 00010 | Loss 1.9282 | Time(s) 0.2939
Epoch 00011 | Loss 1.9261 | Time(s) 0.2947
Epoch 00012 | Loss 1.9240 | Time(s) 0.2969
Epoch 00013 | Loss 1.9218 | Time(s) 0.2967
Epoch 00014 | Loss 1.9196 | Time(s) 0.2961
Epoch 00015 | Loss 1.9173 | Time(s) 0.2957
Epoch 00016 | Loss 1.9149 | Time(s) 0.2956
Epoch 00017 | Loss 1.9125 | Time(s) 0.2957
Epoch 00018 | Loss 1.9100 | Time(s) 0.2955
Epoch 00019 | Loss 1.9074 | Time(s) 0.2951
Epoch 00020 | Loss 1.9048 | Time(s) 0.2946
Epoch 00021 | Loss 1.9020 | Time(s) 0.2946
Epoch 00022 | Loss 1.8993 | Time(s) 0.2951
Epoch 00023 | Loss 1

In [61]:
## 兰州数据