<a href="https://colab.research.google.com/github/Bulat27/topological-features-influence-GNNs/blob/master/GAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install required packages.
import os
import torch
# Export the installed torch version so the shell commands below can expand ${TORCH}.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# torch-scatter / torch-sparse wheels are looked up on the PyG index page matching ${TORCH}.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# PyTorch Geometric is installed straight from the GitHub repository with no version pin.
# NOTE(review): the installed code can therefore differ between runs — consider pinning a release.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

import torch.nn.functional as F

2.1.0+cu121
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


# Load and inspect the dataset


In [2]:
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# Cora from the Planetoid collection, stored under data/Planetoid; the
# NormalizeFeatures transform is applied to each graph when it is accessed.
dataset = Planetoid(
    root='data/Planetoid',
    name='Cora',
    transform=NormalizeFeatures(),
)

In [3]:
# Summarise the dataset: graph count, feature dimensionality, number of classes.
print(
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
    sep='\n',
)

Number of graphs: 1
Number of features: 1433
Number of classes: 7


In [4]:
# Take the graph at index 0 as the working `data` object and print its summary.
data = dataset[0]
print(data)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


In [5]:
# Number of nodes selected by the train / validation / test masks, in that order.
for split_mask in (data.train_mask, data.val_mask, data.test_mask):
    print(split_mask.sum().item())

140
500
1000


In [8]:
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    """Two-layer graph attention network for node classification.

    The first ``GATConv`` layer maps the dataset's node features to
    ``hidden_channels`` features per attention head; the second maps the
    resulting ``hidden_channels * heads`` features to one score per class
    using a single head. Input and output sizes are read from the
    notebook-level ``dataset``.

    Args:
        hidden_channels: Output features per attention head in the first layer.
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability applied to the layer inputs in
            ``forward`` and passed to both ``GATConv`` layers. Defaults to 0.6.
    """

    def __init__(self, hidden_channels: int, heads: int, dropout: float = 0.6):
        super().__init__()
        # Seed before creating the layers so their initial weights are repeatable.
        torch.manual_seed(42)
        self.dropout = dropout
        self.conv1 = GATConv(
            in_channels=dataset.num_features,
            out_channels=hidden_channels,
            heads=heads,
            dropout=dropout,
        )
        # The second layer's input width is the first layer's per-head width
        # times its head count (8 * 8 = 64 for the configuration used here).
        self.conv2 = GATConv(
            in_channels=hidden_channels * heads,
            out_channels=dataset.num_classes,
            heads=1,
            dropout=dropout,
        )

    def forward(self, x, edge_index):
        """Return one row of unnormalised class scores per node.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed through to both layers.
        """
        # F.dropout is a no-op unless the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Instantiate the network with hidden_channels=8 and heads=8, then print its layer summary.
model = GAT(hidden_channels=8, heads=8)
print(model)

GAT(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=1)
)


In [14]:
# Cross-entropy loss on the model's class scores, optimised with Adam
# over all model parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=0.005,
    weight_decay=5e-4,
)

def train():
    """Run one full-graph optimisation step and return the training loss.

    Operates on the notebook-level ``model``, ``data``, ``optimizer`` and
    ``criterion``. The loss is computed only on nodes selected by
    ``data.train_mask``.

    Returns:
        The loss value produced by ``criterion`` for this step.
    """
    model.train()
    optimizer.zero_grad()  # Drop gradients left over from the previous step.
    logits = model(data.x, data.edge_index)  # Forward pass over the whole graph.
    labelled = data.train_mask
    loss = criterion(logits[labelled], data.y[labelled])  # Training nodes only.
    loss.backward()
    optimizer.step()
    return loss

def test(mask):
    """Return the classification accuracy on the nodes selected by ``mask``.

    Operates on the notebook-level ``model`` and ``data``. The model is put
    in evaluation mode and the forward pass runs without gradient tracking,
    since no backward pass follows.

    Args:
        mask: Boolean node mask (e.g. ``data.val_mask`` or ``data.test_mask``).

    Returns:
        Fraction of masked nodes whose predicted class equals ``data.y``.

    Raises:
        ZeroDivisionError: If ``mask`` selects no nodes.
    """
    model.eval()
    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)  # Predicted class = index of the highest score.
    correct = pred[mask] == data.y[mask]  # Compare against ground-truth labels.
    acc = int(correct.sum()) / int(mask.sum())  # Ratio of correct predictions.
    return acc


# Train for 200 epochs; after each step, measure accuracy on the validation
# split and then on the test split, and log all three numbers.
for epoch in range(1, 201):
    loss = train()
    val_acc, test_acc = (test(split) for split in (data.val_mask, data.test_mask))
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

Epoch: 001, Loss: 0.5754, Val: 0.7920, Test: 0.8270
Epoch: 002, Loss: 0.5966, Val: 0.7980, Test: 0.8290
Epoch: 003, Loss: 0.5782, Val: 0.7940, Test: 0.8350
Epoch: 004, Loss: 0.5676, Val: 0.7920, Test: 0.8280
Epoch: 005, Loss: 0.5680, Val: 0.7920, Test: 0.8290
Epoch: 006, Loss: 0.6449, Val: 0.7920, Test: 0.8270
Epoch: 007, Loss: 0.5564, Val: 0.7920, Test: 0.8270
Epoch: 008, Loss: 0.5315, Val: 0.7860, Test: 0.8260
Epoch: 009, Loss: 0.5500, Val: 0.7840, Test: 0.8270
Epoch: 010, Loss: 0.5899, Val: 0.7880, Test: 0.8240
Epoch: 011, Loss: 0.6687, Val: 0.7920, Test: 0.8240
Epoch: 012, Loss: 0.5291, Val: 0.7980, Test: 0.8240
Epoch: 013, Loss: 0.6857, Val: 0.8000, Test: 0.8240
Epoch: 014, Loss: 0.6313, Val: 0.8060, Test: 0.8220
Epoch: 015, Loss: 0.6207, Val: 0.8080, Test: 0.8250
Epoch: 016, Loss: 0.4922, Val: 0.8100, Test: 0.8250
Epoch: 017, Loss: 0.5372, Val: 0.8080, Test: 0.8240
Epoch: 018, Loss: 0.6990, Val: 0.8060, Test: 0.8210
Epoch: 019, Loss: 0.6590, Val: 0.8000, Test: 0.8210
Epoch: 020, 

In [16]:
# Evaluate the current model on the test split and report the accuracy.
test_acc = test(data.test_mask)
print(f'Test Accuracy: {test_acc:.4f}')

Test Accuracy: 0.8330
