In [2]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

In [3]:
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

In [7]:
class GAT(nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.gat1 = GATConv(dataset.num_features, 16, heads=4, dropout=0.6)
        # 使用四个注意力头 首先将输入映射为16
        
        self.gat2 = GATConv(16*4, 8, heads=4, dropout=0.6)
        # 将新的输入维度和注意力头进行相乘 再进行一次维度映射
        
        self.gat3 = GATConv(8*4, dataset.num_classes, heads=1, dropout=0.6)
        # 最后将维度映射成为输出类别数量
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        
        x = F.elu(self.gat1(x, edge_index))
        x = F.elu(self.gat2(x, edge_index))
        x = self.gat3(x, edge_index)
        # ELU比LeakyReLU更适合 因为它的负数部分输出趋于平稳 输出均值更接近0
        
        return F.log_softmax(x,dim=1)

In [10]:
model = GAT()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.3)

best_test_acc = 0.0
best_epoch = 0
model.train()
for epoch in range(0, 101):

    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    scheduler.step()
    # 训练模型
    
    model.eval()
    with torch.no_grad():

        _, pred = model(data).max(dim=1)
        correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
        test_acc = correct / data.test_mask.sum().item()
    # 评估模型
    
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Model Save: Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {test_acc*100:.2f}%')
    # 记录准确率最高的模型
    
    model.train()
    # 重新开始训练

Model Save: Epoch 1, Loss: 1.9616, Accuracy: 64.40%
Model Save: Epoch 2, Loss: 1.6264, Accuracy: 75.30%
Model Save: Epoch 4, Loss: 1.0617, Accuracy: 75.80%
Model Save: Epoch 5, Loss: 0.8893, Accuracy: 77.90%
Model Save: Epoch 6, Loss: 0.7824, Accuracy: 79.10%
Model Save: Epoch 7, Loss: 0.6042, Accuracy: 80.20%
Model Save: Epoch 8, Loss: 0.5558, Accuracy: 81.80%
