## 📦 Packages and Basic Setup
---

In [None]:
%%capture
import os

import torch

torch_version = torch.__version__.split("+")
os.environ["TORCH"] = torch_version[0]
os.environ["CUDA"] = torch_version[1]

!pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
!pip install torch-geometric
!pip install -q --upgrade wandb

import torch.nn.functional as F
import torch_geometric.transforms as T
import wandb
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# @title ⚙ Configuration
import getpass

# Never hard-code the W&B API key in the notebook: it leaks through version
# control and shared copies. Reuse an already-exported WANDB_API_KEY, or
# prompt for it without echoing it to the output.
if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass.getpass("Paste your W&B API key: ")

wandb.init(project="GATv1", entity="graph-neural-networks")

# Attributes set on wandb.config are recorded with the run directly, so no
# separate wandb.config.update(...) call is needed.
config = wandb.config
config.lr = 0.01  # @param {type: "number"}
config.latent_dim = 16  # @param {type: "number"}
config.num_epochs = 50  # @param {type: "number"}
config.heads = 8  # @param {type: "number"}
config.seed = 42  # @param {type: "number"}

# Seed weight init and dropout so runs are repeatable (up to any
# nondeterministic GPU kernels).
torch.manual_seed(config.seed)

## 💿 The Dataset
---

In [None]:
# Cora citation graph, with node features normalised by NormalizeFeatures.
dataset = Planetoid(root="data/", name="Cora", transform=T.NormalizeFeatures())
# The dataset holds a single graph; take it for full-batch training.
data = dataset[0]

## ✍️ Model Architecture
---

In [None]:
class GATv1(torch.nn.Module):
    """Two-layer Graph Attention Network built from ``GATConv`` layers.

    The first layer uses ``heads`` attention heads whose outputs are
    concatenated, giving ``hidden_channels * heads`` features per node. The
    second layer uses a single, non-concatenated head that maps to
    ``out_channels`` outputs per node.

    Args:
        in_channels: Number of input node features.
        hidden_channels: Output features per head in the first layer.
        out_channels: Number of output features per node (e.g. classes).
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability passed to both ``GATConv`` layers
            (applied to their attention coefficients). Defaults to 0.6.
        feature_dropout: Dropout probability applied to node features
            before each layer. Defaults to 0.5.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        heads,
        dropout=0.6,
        feature_dropout=0.5,
    ):
        super().__init__()
        self.feature_dropout = feature_dropout
        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=dropout)
        # conv1 concatenates its heads, so conv2's input width is hidden * heads.
        self.conv2 = GATConv(
            hidden_channels * heads, out_channels, heads=1, concat=False, dropout=dropout
        )

    def forward(self, x, edge_index):
        """Return raw (un-softmaxed) per-node outputs of the second layer."""
        x = F.dropout(x, p=self.feature_dropout, training=self.training)
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=self.feature_dropout, training=self.training)
        return self.conv2(x, edge_index)


WEIGHT_DECAY = 5e-4  # L2 penalty applied by Adam

# Build the network from the dataset's dimensions and the run configuration,
# then place both the model and the graph on the selected device.
model = GATv1(
    in_channels=dataset.num_features,
    hidden_channels=config.latent_dim,
    out_channels=dataset.num_classes,
    heads=config.heads,
).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.lr,
    weight_decay=WEIGHT_DECAY,
)

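## Training
---

Full-batch training on Cora's labelled training nodes. Validation accuracy is used for model selection: the test accuracy is taken from the epoch with the highest validation accuracy.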

In [None]:
def train():
    """Run one full-batch optimisation step on the training nodes.

    Returns:
        The training cross-entropy loss for this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(x=data.x, edge_index=data.edge_index)
    train_nodes = data.train_mask
    batch_loss = F.cross_entropy(logits[train_nodes], data.y[train_nodes])

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()


@torch.no_grad()
def test():
    """Evaluate the model in eval mode on all three node splits.

    Returns:
        A list ``[train_acc, val_acc, test_acc]`` of accuracies in [0, 1].
    """
    model.eval()
    predictions = model(x=data.x, edge_index=data.edge_index).argmax(dim=-1)
    hits = predictions == data.y

    splits = (data.train_mask, data.val_mask, data.test_mask)
    return [int(hits[split].sum()) / int(split.sum()) for split in splits]

In [None]:
# Model selection: report the test accuracy from the epoch with the best
# validation accuracy, not from the last epoch.
best_val_acc = final_test_acc = 0.0

for epoch in range(1, config.num_epochs + 1):
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        final_test_acc = tmp_test_acc
    wandb.log(
        {
            "Epoch": epoch,
            "Training Loss": loss,
            "Training Accuracy": train_acc,
            "Validation Accuracy": val_acc,
            "Test Accuracy": tmp_test_acc,
        }
    )

# Last-epoch training metrics, plus the model-selection results.
wandb.run.summary["Training Loss"] = loss
wandb.run.summary["Training Accuracy"] = train_acc
wandb.run.summary["Best Validation Accuracy"] = best_val_acc
wandb.run.summary["Final Test Accuracy"] = final_test_acc
wandb.finish()