This notebook creates a MessagePassing based implementation of GAT on PyG.

In [1]:
!pip install torch_geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric
  Downloading torch_geometric-2.3.0.tar.gz (616 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m616.2/616.2 KB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: torch_geometric
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
  Created wheel for torch_geometric: filename=torch_geometric-2.3.0-py3-none-any.whl size=909897 sha256=4c084297d054cf56a8a106dd9874c3eb61e6f9b695128dae87029c08615dca2f
  Stored in directory: /root/.cache/pip/wheels/cd/7d/6b/17150450b80b4a3656a84330e22709ccd8dc0f8f4773ba4133
Successfully built torch_geometric
Installing collected packages: torch_geometric
Successfully installed torch_geomet

In [52]:
!pip install torch_scatter

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_scatter
  Downloading torch_scatter-2.1.1.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.6/107.6 KB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch_scatter
  Building wheel for torch_scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch_scatter: filename=torch_scatter-2.1.1-cp39-cp39-linux_x86_64.whl size=484407 sha256=595435b502cd918e5f0c95508f99286d770881f874d639a20647529c930f032d
  Stored in directory: /root/.cache/pip/wheels/d5/0c/18/11b4cf31446c5d460543b0fff930fcac3a3f8a785e5c73fb15
Successfully built torch_scatter
Installing collected packages: torch_scatter
Successfully installed torch_scatter-2.1.1


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.logging import init_wandb, log
from torch_geometric.utils.convert import torch_geometric


In [53]:
import torch_scatter

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
init_wandb(name=f'GAT-planetoid_mp', heads=1, epochs=10,
           hidden_channels=100, lr=1e-3, device=device)

In [5]:
dataset = Planetoid("./", name="Cora", transform = NormalizeFeatures())

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [6]:
data = dataset[0].to(device)

In [7]:
# Define GATConv - singlehead using MessagePassing based on https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


In [12]:
from torch_geometric.utils import (
    add_self_loops,
    is_torch_sparse_tensor,
    remove_self_loops,
    softmax,
)

In [66]:
class MyGAT(MessagePassing):
  # Based on https://github.com/njmarko/machine-learning-with-graphs/blob/master/MyAttempts/CS224W_Colab4.ipynb
  def __init__(self, in_channels, out_channels, head=1):
    super().__init__(node_dim = 0)
    self.lin1 = nn.Linear(in_channels, out_channels*head)
    # self.lin2 = nn.Linear(2 * out_channels, 1)
    self.a = nn.Parameter(torch.Tensor(1, head, out_channels))
    self.head = head
    self.out_channels = out_channels

  def forward(self, x, edge_index):
    # x is of the shape (num_nodes X in_channels)
    # edge_index is of the shape (2 X E)
    # 1. We need to add self loops for the nodes
    # edge_index, _ = add_self_loops(edge_index, num_nodes = x.size(0))
    x_new = self.lin1(x).view(-1, self.head, self.out_channels)
    alpha = x_new * self.a
    return self.propagate(edge_index, x=x_new, alpha=alpha).view(-1, self.head * self.out_channels)

  def message(self, x_j, alpha_i, alpha_j, index, ptr):
    # tmp = torch.cat([self.lin1(x_i), self.lin1(x_j)], dim = 1)
    # alpha_ij = (self.lin1(x_i) * self.a) + (self.lin1(x_j) * self.a)
    # Based on https://github.com/pyg-team/pytorch_geometric/discussions/6163
    alpha = alpha_i + alpha_j
    e_ij = nn.LeakyReLU(negative_slope=0.2)(alpha)
    a_ij = softmax(e_ij, ptr if ptr else index)
    print(a_ij.shape)
    print(x_j.shape)
    # can add dropout here
    out = x_j * a_ij
    return out

  # def aggregate(self, inputs, index, dim_size = None):
  #       out = torch_scatter.scatter(inputs, index = index, dim = self.node_dim, dim_size = dim_size, reduce = "sum")
  #       return out


In [67]:
class GAT(nn.Module):
  def __init__(self, in_channels, out_channels, hidden_channels, heads):
    super().__init__()
    self.l1 = MyGAT(in_channels, hidden_channels)
    self.l2 = MyGAT(hidden_channels, out_channels)

  def forward(self, x, edge_index):
    x = F.relu(self.l1(x, edge_index))
    x = self.l2(x, edge_index)
    return x

In [72]:
model = GAT(dataset.num_features, 100, dataset.num_classes,
            1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05, weight_decay=5e-4)


In [73]:
def train():
  optimizer.zero_grad()
  model.train()
  out = model(data.x, data.edge_index)
  loss = nn.CrossEntropyLoss()(out[data.train_mask], data.y[data.train_mask])
  loss.backward()
  optimizer.step()
  return loss

In [74]:
@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs

In [75]:
best_val_acc = final_test_acc = 0
for epoch in range(1, 10 + 1):
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)

torch.Size([10556, 1, 7])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 100])
Epoch: 001, Loss: 4.756466388702393, Train: 0.0000, Val: 0.0000, Test: 0.1300
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 100])
Epoch: 002, Loss: 4.670750141143799, Train: 0.0000, Val: 0.0000, Test: 0.1300
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 100])
Epoch: 003, Loss: 4.56876277923584, Train: 0.1429, Val: 0.0580, Test: 0.0640
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 7])
torch.Size([10556, 1, 100])
torch.Size([10556, 1, 100])
torch.Size(