In [26]:
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

In [27]:
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

In [28]:
class GAT(nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        self.gat1 = GATConv(dataset.num_features, 16, heads=4, dropout=0.6)
        # Using four attention heads to first map the input to 16 dimensions

        self.gat2 = GATConv(16*4, 8, heads=4, dropout=0.6)
        # Multiply the new input dimension with the number of attention heads, then perform another dimensional mapping

        self.gat3 = GATConv(8*4, dataset.num_classes, heads=1, dropout=0.6)
        # Finally, map the dimension to match the number of output classes

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = F.elu(self.gat1(x, edge_index))
        x = F.elu(self.gat2(x, edge_index))
        x = self.gat3(x, edge_index)
        # ELU is preferred over LeakyReLU because its negative part stabilizes better, 
        # and its output mean is closer to 0

        return F.log_softmax(x, dim=1)

In [32]:
model = GAT()
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.3)

best_test_acc = 0.0
best_epoch = 0
model.train()
for epoch in range(0, 101):

    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    scheduler.step()
    # Train the model

    model.eval()
    with torch.no_grad():

        _, pred = model(data).max(dim=1)
        correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
        test_acc = correct / data.test_mask.sum().item()
    # Evaluate the model

    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Model Save: Epoch {epoch+1}, Loss: {loss:.4f}, Accuracy: {test_acc*100:.2f}%')
    # Record the model with the highest accuracy

    model.train()
    # Resume training

Model Save: Epoch 1, Loss: 1.9733, Accuracy: 44.70%
Model Save: Epoch 2, Loss: 1.5667, Accuracy: 77.00%
Model Save: Epoch 3, Loss: 1.3158, Accuracy: 80.30%
Model Save: Epoch 4, Loss: 1.0460, Accuracy: 80.90%
