In [1]:
!pip install torch-geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric


Collecting torch-scatter
  Using cached torch_scatter-2.1.2-cp310-cp310-linux_x86_64.whl
Collecting torch-sparse
  Using cached torch_sparse-0.6.18.tar.gz (209 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-cluster
  Using cached torch_cluster-1.6.3.tar.gz (54 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-spline-conv
  Using cached torch_spline_conv-1.2.2.tar.gz (25 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-sparse, torch-cluster, torch-spline-conv
  Building wheel for torch-sparse (setup.py) ... [?25l[?25hcanceled
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [3]:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

# Load the Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=T.NormalizeFeatures())

# Check the dataset
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_node_features}')
print(f'Number of classes: {dataset.num_classes}')


Dataset: Cora():
Number of graphs: 1
Number of features: 1433
Number of classes: 7


In [4]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

model = GCN()
print(model)


GCN(
  (conv1): GCNConv(1433, 16)
  (conv2): GCNConv(16, 7)
)


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = dataset[0].to(device)
criterion = nn.CrossEntropyLoss()

def train_model(model, criterion, optimizer, num_epochs=25):
    for epoch in range(num_epochs):
        model.train()

        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        _, pred = model(data).max(dim=1)
        correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
        acc = correct / int(data.test_mask.sum())

        print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {loss.item():.4f}, Test Accuracy: {acc:.4f}')

    return model

# Choose the optimizer you want to use
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Train the model
model = train_model(model, criterion, optimizer, num_epochs=200)


Epoch 0/199, Loss: 1.9455, Test Accuracy: 0.3970
Epoch 1/199, Loss: 1.9404, Test Accuracy: 0.5070
Epoch 2/199, Loss: 1.9336, Test Accuracy: 0.3730
Epoch 3/199, Loss: 1.9245, Test Accuracy: 0.3980
Epoch 4/199, Loss: 1.9145, Test Accuracy: 0.4370
Epoch 5/199, Loss: 1.9039, Test Accuracy: 0.4670
Epoch 6/199, Loss: 1.8930, Test Accuracy: 0.4880
Epoch 7/199, Loss: 1.8811, Test Accuracy: 0.4920
Epoch 8/199, Loss: 1.8683, Test Accuracy: 0.5090
Epoch 9/199, Loss: 1.8545, Test Accuracy: 0.5230
Epoch 10/199, Loss: 1.8398, Test Accuracy: 0.5520
Epoch 11/199, Loss: 1.8245, Test Accuracy: 0.5850
Epoch 12/199, Loss: 1.8084, Test Accuracy: 0.6080
Epoch 13/199, Loss: 1.7917, Test Accuracy: 0.6270
Epoch 14/199, Loss: 1.7743, Test Accuracy: 0.6490
Epoch 15/199, Loss: 1.7563, Test Accuracy: 0.6610
Epoch 16/199, Loss: 1.7376, Test Accuracy: 0.6650
Epoch 17/199, Loss: 1.7184, Test Accuracy: 0.6720
Epoch 18/199, Loss: 1.6985, Test Accuracy: 0.6720
Epoch 19/199, Loss: 1.6781, Test Accuracy: 0.6720
Epoch 20/1

In [6]:
model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')

Test Accuracy: 0.8080
