In [1]:
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SplineConv
from torch_geometric.typing import WITH_TORCH_SPLINE_CONV

In [2]:
if not WITH_TORCH_SPLINE_CONV:
    quit("This example requires 'torch-spline-conv'")

In [9]:
dataset = 'Cora'
transform = T.Compose([
    T.RandomNodeSplit(num_val=1080, num_test=540),
    T.TargetIndegree(),
])
path = ('./dataset/cora')
dataset = Planetoid(path, dataset, transform=transform)
data = dataset[0]
print("Number of graphs in the dataset:", data)

Number of graphs in the dataset: Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708], edge_attr=[10556, 1])


In [4]:
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SplineConv(dataset.num_features, 16, dim=1, kernel_size=2)
        self.conv2 = SplineConv(16, dataset.num_classes, dim=1, kernel_size=2)

    def forward(self):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.dropout(x, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_attr))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        return F.log_softmax(x, dim=1)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)

In [6]:
def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()


@torch.no_grad()
def test():
    model.eval()
    log_probs, accs = model(), []
    for _, mask in data('train_mask', 'test_mask'):
        pred = log_probs[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

In [7]:

for epoch in range(1, 201):
    train()
    train_acc, test_acc = test()
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Test: {test_acc:.4f}')

Epoch: 001, Train: 0.4982, Test: 0.4130
Epoch: 002, Train: 0.5129, Test: 0.4296
Epoch: 003, Train: 0.5708, Test: 0.4852
Epoch: 004, Train: 0.6507, Test: 0.5407
Epoch: 005, Train: 0.7408, Test: 0.6148
Epoch: 006, Train: 0.8309, Test: 0.7037
Epoch: 007, Train: 0.8796, Test: 0.7537
Epoch: 008, Train: 0.8980, Test: 0.7833
Epoch: 009, Train: 0.9026, Test: 0.8111
Epoch: 010, Train: 0.9099, Test: 0.8204
Epoch: 011, Train: 0.9154, Test: 0.8315
Epoch: 012, Train: 0.9182, Test: 0.8315
Epoch: 013, Train: 0.9256, Test: 0.8370
Epoch: 014, Train: 0.9403, Test: 0.8444
Epoch: 015, Train: 0.9439, Test: 0.8519
Epoch: 016, Train: 0.9494, Test: 0.8611
Epoch: 017, Train: 0.9504, Test: 0.8704
Epoch: 018, Train: 0.9559, Test: 0.8704
Epoch: 019, Train: 0.9577, Test: 0.8759
Epoch: 020, Train: 0.9596, Test: 0.8741
Epoch: 021, Train: 0.9642, Test: 0.8741
Epoch: 022, Train: 0.9678, Test: 0.8759
Epoch: 023, Train: 0.9724, Test: 0.8704
Epoch: 024, Train: 0.9761, Test: 0.8685
Epoch: 025, Train: 0.9789, Test: 0.8722
