<h4>The GAT layer is defined as:</h4>

1$$ e_{ij}^{l} = LeakyReLU(a(Wh_{i}^{(l)} || Wh_{j}^{(l)}))$$ 
2$$ a_{ij}^{l} = softmax_{j}(e_{ij}^{l})= \dfrac{\exp(e_{ij}^{l})}{\sum_{k \in N(i)} \exp(e_{ik}^{l})}$$ 
3$$h_{i}^{(l+1)} = \sum_{j \in N(i)} a_{ij} W^{(l)} h_{j}^{(l)}$$ 

<h4>add Multi-Head Attention Mechanism, the last aggregate equation can be rewritten as follows:</h4> the dimensions is num_head $*$ hidden layer dimensions

$$h_{i}^{l+1} = ||_{k=1}^{K} \sigma (\sum_{j \in N(i)} a_{ij}^{k} W^{k} h_{j}^{l})$$

<h4>the last layer uses average instead of cat</h4>

$$h_{i}^{l+1} = \sigma (\dfrac{1}{K}\sum_{k=1}^{K}\sum_{j \in N(i)} a_{ij}^{k} W^{k} h_{j}^{l})$$

In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loops

class GATConv(MessagePassing):
    def __init__(self, in_feats, out_feats, alpha, drop_prob, num_heads):
        super().__init__(aggr="add")
        self.drop_prob = drop_prob
        self.num_heads = num_heads
        self.out_feats = out_feats // num_heads
        self.lin = nn.Linear(in_feats, self.out_feats*self.num_heads, bias=False)
        self.a = nn.Linear(2*self.out_feats, 1)
        self.leakrelu = nn.LeakyReLU(alpha)
    def forward(self, x, edge_index):
        edge_index, _ = add_remaining_self_loops(edge_index)
#         W H
        h = self.lin(x)
        h_p = self.propagate(edge_index, x=h)
        return h_p
    def message(self, x_i, x_j, edge_index_i):
        x_i = x_i.view(-1, self.num_heads, self.out_feats)
        x_j = x_j.view(-1, self.num_heads, self.out_feats)
#         a(wh_i || wh_j)
        e = self.a(torch.cat([x_i, x_j], dim=-1))
#         LeakReLU(a(Wh_i, Wh_j))

        e = self.leakrelu(e)
#         softmax(e_{ij})

        alpha = softmax(e, edge_index_i)

        alpha = F.dropout(alpha, self.drop_prob, self.training)
        return (x_j * alpha).view(x_j.size(0), -1)

    
class GAT(nn.Module):
    def __init__(self, in_feats, hidden_feats, y_num, alpha=0.2, drop_prob=0., num_heads=[1,1]):
        super().__init__()
        self.drop_prob = drop_prob
        self.gatconv1 = GATConv(in_feats, hidden_feats, alpha, drop_prob, num_heads[0])
        
        self.gatconv2 = GATConv(hidden_feats, y_num, alpha, drop_prob, num_heads[1])
        
    
    def forward(self, x, edge_index):
        x = self.gatconv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, self.drop_prob, self.training)
        out = self.gatconv2(x, edge_index)
        return F.log_softmax(out, dim=1)

In [67]:
# testing with a sample dataset
conv = GAT(in_feats=8,
            hidden_feats=64,
            y_num=4,
            drop_prob=0.2,
            num_heads=[8, 1])
x = torch.rand(4, 8)
edge_index = torch.tensor(
    [[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
x1 = conv(x, edge_index)
print(x1.shape)

torch.Size([4, 4])


In [56]:
# apply in a big sample dataset
from torch_geometric.datasets import Planetoid
from copy import deepcopy
from sklearn.metrics import accuracy_score

In [69]:
# load data
dataset = Planetoid(root='temp/', name='Cora')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device)
h = data.x
feature_dim = h.shape[1]
class_num = dataset.num_classes

# define model
model = GAT(in_feats=feature_dim,
            hidden_feats=64,
            y_num=class_num,
            drop_prob=0.6,
            num_heads=[8,1]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
best_acc, best_model = 0. , None

# training
model.train()
for epoch in range(300):
    out = model(h, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    valid_acc = accuracy_score(data.y[data.val_mask].cpu(),
                               out[data.val_mask].argmax(dim=1).cpu())
    if valid_acc > best_acc:
        best_acc = valid_acc
        best_model = deepcopy(model)
    if (epoch+1) % 25 == 0:
        print(f"Epoch {epoch + 1}: train_loss: {loss.item():.8f}")
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
# evaluation
best_model.eval()
pred = best_model(h, data.edge_index)
test_acc = accuracy_score(data.y[data.test_mask].cpu(),
                          pred[data.test_mask].argmax(dim=1).cpu())
print(f'testset accuracy:{test_acc:.4f}')

Epoch 25: train_loss: 0.43038774
Epoch 50: train_loss: 0.35651967
Epoch 75: train_loss: 0.29772645
Epoch 100: train_loss: 0.36352384
Epoch 125: train_loss: 0.28289831
Epoch 150: train_loss: 0.33076385
Epoch 175: train_loss: 0.30606464
Epoch 200: train_loss: 0.41505975
Epoch 225: train_loss: 0.38383773
Epoch 250: train_loss: 0.32141551
Epoch 275: train_loss: 0.32326403
Epoch 300: train_loss: 0.37980521
testset accuracy:0.8120
