In [1]:
# Install core library (Fastest way)
!pip install -q torch-geometric

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Amazon
from torch_geometric.nn import SAGEConv
from torch_geometric.transforms import RandomNodeSplit
from sklearn.metrics import f1_score

# Load Dataset (Amazon Computers)
# RandomNodeSplit creates train/val/test masks for us automatically:
# 10% of the nodes go to validation, 20% to test, and the rest to training.
# The split is drawn from torch's global RNG, so seed it first; otherwise the
# masks (and the weight init / dropout that follow) differ on every run and
# the reported metrics cannot be reproduced.
torch.manual_seed(42)
dataset = Amazon(root='/tmp/Amazon', name='Computers', transform=RandomNodeSplit(num_val=0.1, num_test=0.2))
# NOTE: the split is passed as `transform`, which is applied on every
# `dataset[...]` access. Index the dataset exactly once and reuse `data`,
# otherwise each access would carry a freshly drawn split.
data = dataset[0]

# Define GraphSAGE Model
class SageNet(torch.nn.Module):
    """Two-layer GraphSAGE network for node classification.

    The first SAGEConv layer maps the input node features to a hidden
    representation; the second maps the hidden representation to one score
    per class. ReLU and dropout are applied between the two layers.
    """

    def __init__(self, in_channels=None, hidden_channels=64, out_channels=None, dropout=0.4):
        """Build the two convolution layers.

        Args:
            in_channels: Size of each input node feature vector. Defaults to
                ``dataset.num_features`` of the module-level ``dataset``.
            hidden_channels: Width of the hidden layer.
            out_channels: Number of output scores per node. Defaults to
                ``dataset.num_classes`` of the module-level ``dataset``.
            dropout: Dropout probability applied after the first layer.
        """
        super().__init__()
        # Fall back to the module-level dataset so that `SageNet()` keeps
        # building exactly the same network as before.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes
        self.dropout = dropout
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return one row of raw (un-normalised) class scores per node.

        Args:
            x: Node feature tensor passed to the first convolution.
            edge_index: Graph connectivity passed to both convolutions.
        """
        x = self.conv1(x, edge_index).relu()
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# The model parameters and the graph tensors must live on the same device.
model = SageNet()
model = model.to(device)
data = data.to(device)

# Adam optimizer over all model parameters, with weight decay as regularisation.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)

# Full-batch training for 50 epochs: every epoch runs one forward/backward
# pass over the whole graph, with the loss computed on the training nodes only.
train_mask = data.train_mask
model.train()  # training mode enables dropout in the forward pass
for epoch in range(1, 51):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()

    # Report the training loss every 10th epoch only.
    if epoch % 10 != 0:
        continue
    print(f'Epoch {epoch:02d} | Loss: {loss.item():.4f}')

# Advanced Evaluation (Accuracy + Macro F1-Score) on the held-out test nodes
model.eval()  # evaluation mode disables dropout in the forward pass
with torch.no_grad():
    logits = model(data.x, data.edge_index)
    preds = logits.argmax(dim=1)

    # Slice out the test nodes once and reuse them for both metrics.
    test_mask = data.test_mask
    y_true = data.y[test_mask].cpu()
    y_pred = preds[test_mask].cpu()

    # Accuracy: correctly classified test nodes / number of test nodes
    num_correct = (y_pred == y_true).sum().item()
    acc = num_correct / test_mask.sum().item()

    # Macro F1-Score: per-class F1 averaged without weighting, so rare
    # categories count as much as frequent ones.
    f1 = f1_score(y_true, y_pred, average='macro')

print(f'\n--- Amazon Computers Results ---')
print(f'Test Accuracy: {acc*100:.2f}%')
print(f'Macro F1-Score: {f1:.4f}')

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.7/63.7 kB 4.2 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 41.0 MB/s eta 0:00:00

Downloading https://github.com/shchur/gnn-benchmark/raw/master/data/npz/amazon_electronics_computers.npz
Processing...
Done!


Epoch 10 | Loss: 1.9235
Epoch 20 | Loss: 1.4563
Epoch 30 | Loss: 1.1847
Epoch 40 | Loss: 0.9721
Epoch 50 | Loss: 0.8131

--- Amazon Computers Results ---
Test Accuracy: 76.33%
Macro F1-Score: 0.5128
