This notebook reproduces the GAT implementation in PyG, based on https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gat.py 

In [1]:
!pip install torch_geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric
  Downloading torch_geometric-2.3.0.tar.gz (616 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m616.2/616.2 KB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torch_geometric: filename=torch_geometric-2.3.0-py3-none-any.whl size=909897 sha256=518765f867c9c0b37620f4bc1ed8a555ec2a1b06517ee94ab0ddda08a2cc6ec4
  Stored in directory: /root/.cache/pip/wheels/cd/7d/6b/17150450b80b4a3656a84330e22709ccd8dc0f8f4773ba4133
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geomet

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

In [6]:
from torch_geometric.transforms import NormalizeFeatures

In [3]:
from torch_geometric.logging import init_wandb, log

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
init_wandb(name=f'GAT-planetoid_v1', heads=1, epochs=10,
           hidden_channels=100, lr=1e-3, device=device)

In [16]:
from torch_geometric.utils.convert import torch_geometric
dataset = Planetoid("./", name="Cora", transform = NormalizeFeatures())

In [17]:
data = dataset[0].to(device)

In [13]:
from torch_geometric.nn import GATConv

In [14]:
class GAT(nn.Module):
  def __init__(self, in_channels, out_channels, hidden_channels, heads):
    super().__init__()
    self.l1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)
    self.l2 = GATConv(hidden_channels, out_channels, heads, concat=False, dropout=0.6)

  def forward(self, x, edge_index):
    x = F.relu(self.l1(x, edge_index))
    x = self.l2(x, edge_index)
    return x

In [18]:
model = GAT(dataset.num_features, 100, dataset.num_classes,
            1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)


In [22]:
def train():
  optimizer.zero_grad()
  model.train()
  out = model(data.x, data.edge_index)
  loss = nn.CrossEntropyLoss()(out[data.train_mask], data.y[data.train_mask])
  loss.backward()
  optimizer.step()
  return loss

In [23]:
@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

In [24]:
best_val_acc = final_test_acc = 0
for epoch in range(1, 10 + 1):
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)

  return src.new_zeros(size).scatter_reduce_(


Epoch: 001, Loss: 4.604935646057129, Train: 0.1357, Val: 0.1020, Test: 0.1040
Epoch: 002, Loss: 4.594199180603027, Train: 0.1429, Val: 0.1020, Test: 0.1040
Epoch: 003, Loss: 4.582528114318848, Train: 0.1571, Val: 0.0920, Test: 0.1040
Epoch: 004, Loss: 4.569808483123779, Train: 0.1500, Val: 0.0800, Test: 0.1040
Epoch: 005, Loss: 4.555157661437988, Train: 0.1500, Val: 0.0780, Test: 0.1040
Epoch: 006, Loss: 4.541922569274902, Train: 0.1500, Val: 0.0780, Test: 0.1040
Epoch: 007, Loss: 4.528451442718506, Train: 0.1500, Val: 0.0760, Test: 0.1040
Epoch: 008, Loss: 4.510128021240234, Train: 0.1500, Val: 0.0760, Test: 0.1040
Epoch: 009, Loss: 4.4966230392456055, Train: 0.1500, Val: 0.0760, Test: 0.1040
Epoch: 010, Loss: 4.476688385009766, Train: 0.1500, Val: 0.0760, Test: 0.1040
