In [None]:
import torch_geometric
# Planetoid (the Cora / CiteSeer / PubMed citation graphs) is exposed by PyG's
# public `torch_geometric.datasets` package.
from torch_geometric.datasets import Planetoid

In [None]:
#########################
## Dataset preparation ##
#########################

import os

# Directory where Planetoid stores its raw and processed files.
# Set the GNN_DATA_DIR environment variable to relocate it instead of editing a
# machine-specific absolute path; the default keeps the author's original location.
root_path = os.environ.get(
    'GNN_DATA_DIR',
    '/home/longdpt/Documents/Long_AISDL/DeepLearning_PyTorch/05_GNN/data',
)

#---------
## Load the dataset
#---------

# Downloads and processes the raw Cora files under `root_path`
# (see the "Downloading ... / Processing..." output of this cell).
# The other two Planetoid graphs are available via name="CiteSeer" / name="PubMed".
dataset = Planetoid(root=root_path, name="Cora")

'''
Planetoid is not a single dataset, but rather a collection of three citation network datasets, 
commonly used for benchmarking Graph Neural Networks (GNNs).
'''

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [None]:
#---------
## Dataset properties
#---------

print(dataset) # Cora()
print("number of graphs:\t\t", len(dataset))                      # 1 (has only one graph)
print("number of classes:\t\t", dataset.num_classes)              # 7 (7 different node labels / research topics)
print("number of node features:\t", dataset.num_node_features)    # 1433 (each node has 1433 features)
print("number of edge features:\t", dataset.num_edge_features)    # 0 (edges carry no attributes)

Cora()
number of graphs:		 1
number of classes:		 7
number of node features:	 1433
number of edge features:	 0


In [None]:
#---------
## dataset._data
#---------

# Use the public indexing API instead of the private `dataset._data` attribute.
# Cora holds exactly one graph, so `dataset[0]` is that graph.
cora_graph = dataset[0]

print(cora_graph)
print("\n")
# Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
# 2,708 nodes (papers)
# 10,556 edges (citations)
# 1,433 features (words)
# 7 classes (research topics)


def _show_attr(label, tensor):
    """Print `label` with the tensor's shape, then the (truncated) tensor, then a blank separator."""
    print(f"{label}:\t\t", tensor.shape)
    print(tensor)
    print("\n")


_show_attr("edge_index", cora_graph.edge_index)
# edge_index:		 torch.Size([2, 10556])              10556 edges (citation relationships)
# tensor([[ 633, 1862, 2582,  ...,  598, 1473, 2706],    SOURCE NODES
#         [   0,    0,    0,  ..., 2707, 2707, 2707]])   TARGET NODES
# Example: 633 -> 0, 1862 -> 0

_show_attr("train_mask", cora_graph.train_mask)
# train_mask:		 torch.Size([2708])
# tensor([ True,  True,  True,  ..., False, False, False])
# True: this node is in the training set
# False: this node is NOT in the training set

_show_attr("X", cora_graph.x)
# X:		 torch.Size([2708, 1433])
# tensor([[0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         ...,
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.],
#         [0., 0., 0.,  ..., 0., 0., 0.]])
# Each row represents one node
# Each column represents one feature of a node

_show_attr("y", cora_graph.y)
# y:		 torch.Size([2708])
# tensor([3, 4, 4,  ..., 3, 3, 3])
# The output label of each node (in this example, we have 7 different classes)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


edge_index:		 torch.Size([2, 10556])
tensor([[ 633, 1862, 2582,  ...,  598, 1473, 2706],
        [   0,    0,    0,  ..., 2707, 2707, 2707]])


train_mask:		 torch.Size([2708])
tensor([ True,  True,  True,  ..., False, False, False])


X:		 torch.Size([2708, 1433])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


y:		 torch.Size([2708])
tensor([3, 4, 4,  ..., 3, 3, 3])




In [11]:
#----
## get the data
#----

# Cora contains exactly one graph, so index 0 retrieves the whole graph.
assert len(dataset) == 1, "expected a single-graph dataset"
data = dataset[0]

In [None]:
################
## Simple GNN ##
################

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv # GraphSAGE Convolution layer from PyTorch Geometric

'''
What is SAGEConv?
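GraphSAGE (SAmple and aggreGatE) is a type of graph convolution that:
# Aggregates the features of each node's neighbors
#   (neighbor *sampling* is done by a data loader such as NeighborLoader, not by the
#    SAGEConv layer itself; here the layer sees every neighbor in the full graph)
# Combines the aggregate with the node's own features

What happens inside SAGEConv (with aggr="max", as used in the model below):
For each node i:
# Gather features from neighbors: {x_j : j ∈ Neighbors(i)}
# Aggregate: h_neighbors = max(x_j for all neighbors j) (element-wise max)
# Combine: h_i = W * concat([x_i, h_neighbors]) + b
#   (PyG computes this as lin_l(h_neighbors) + lin_r(x_i), which is the same linear map)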

############ Example #############

x_i = [1, 0, 1]

neighbor_1 = [1, 2, 1]
neighbor_2 = [3, 4, 2]
neighbor_3 = [8, 1, 0]

aggregated = [max(1, 3, 8),   # position 0
              max(2, 4, 1),   # position 1
              max(1, 2, 0)]   # position 2

aggregated = [8, 4, 2]  ✅

=> concat([x_i, aggregated]) = [1, 0, 1, 8, 4, 2]

############ Visual Representation #############

Position:       0   1   2
              ┌───┬───┬───┐
Neighbor 1:   │ 1 │ 2 │ 1 │
              ├───┼───┼───┤
Neighbor 2:   │ 3 │ 4 │ 2 │
              ├───┼───┼───┤
Neighbor 3:   │ 8 │ 1 │ 0 │
              ├───┼───┼───┤
              │ ↓ │ ↓ │ ↓ │
              ├───┼───┼───┤
Max:          │ 8 │ 4 │ 2 │
              └───┴───┴───┘
              
Node's own features:     [1, 0, 1]
                             +
Aggregated neighbors:    [8, 4, 2]
                              ↓
Concatenated:            [1, 0, 1, 8, 4, 2]
                         └─────┘ └───────┘
                          self   neighbors
'''

#----
## build model
#----

class SimpleGNN(nn.Module):
    """One-layer GraphSAGE node classifier that outputs per-node log-probabilities.

    By default the layer is sized from the notebook-level `dataset`, and `forward()`
    runs on the notebook-level `data` graph, so the zero-argument call sites
    (`SimpleGNN()`, `model()`) keep working. Pass explicit sizes/tensors to use the
    model without relying on those globals.
    """

    def __init__(self, in_channels=None, out_channels=None, aggr="max"):
        """
        Args:
            in_channels: per-node input feature size; defaults to `dataset.num_features` (1433 for Cora).
            out_channels: number of output classes; defaults to `dataset.num_classes` (7 for Cora).
            aggr: neighbor aggregation passed to SAGEConv (e.g. "max", "mean", "add").
        """
        super().__init__()

        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes

        # The attribute keeps the name `cnn` so existing state_dict keys stay valid,
        # even though this is a graph convolution rather than a CNN.
        self.cnn = SAGEConv(
            in_channels=in_channels,
            out_channels=out_channels,
            aggr=aggr,
        )

    def forward(self, x=None, edge_index=None):
        """Return log-softmax class scores for every node (softmax over dim=1).

        Args:
            x: node feature matrix; defaults to the global `data.x`.
            edge_index: graph connectivity; defaults to the global `data.edge_index`.
        """
        if x is None:
            x = data.x
        if edge_index is None:
            edge_index = data.edge_index
        out = self.cnn(x, edge_index)
        # log_softmax pairs with F.nll_loss used in the training loop.
        return F.log_softmax(out, dim=1)

In [None]:
###############
## Optimizer ##
###############

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the global RNG before constructing the model so weight init is reproducible.
torch.manual_seed(42)

# Move both the model parameters and the graph tensors to the chosen device.
model = SimpleGNN().to(device)
data = data.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,            # learning rate
    weight_decay=5e-4,  # L2 penalty Adam adds to each gradient
)

In [None]:
##############
## Training ##
##############

from loguru import logger

NUM_EPOCHS = 100  # epochs are numbered 1..NUM_EPOCHS
LOG_EVERY = 10    # log on epoch 1 and on every LOG_EVERY-th epoch

best_val_acc = test_acc = 0

for epoch in range(1, NUM_EPOCHS + 1):
    #----TRAIN: one full-graph gradient step, loss computed on training nodes only
    _ = model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()

    #-----VAL - TEST
    _ = model.eval()
    # Evaluation needs no gradients; no_grad avoids building an autograd graph every epoch.
    with torch.no_grad():
        logits = model()
    accs = []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].argmax(dim=1)
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)

    _, val_acc, tmp_test_acc = accs
    # Model selection: keep the test accuracy from the epoch with the best validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc

    if (epoch % LOG_EVERY == 0) or (epoch == 1):
        logger.info("+"*50)
        logger.info(f"Epoch: {epoch}")
        logger.info(f"Val: {best_val_acc:.4f}")   # best validation accuracy seen so far
        logger.info(f"Test: {test_acc:.4f}")      # test accuracy at that best-validation epoch

[32m2026-01-09 13:56:53.950[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m++++++++++++++++++++++++++++++++++++++++++++++++++[0m
[32m2026-01-09 13:56:53.950[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mEpoch: 1[0m
[32m2026-01-09 13:56:53.950[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mVal: 0.4380[0m
[32m2026-01-09 13:56:53.951[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m29[0m - [1mTest: 0.4330[0m
[32m2026-01-09 13:56:53.969[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1m++++++++++++++++++++++++++++++++++++++++++++++++++[0m
[32m2026-01-09 13:56:53.969[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m27[0m - [1mEpoch: 10[0m
[32m2026-01-09 13:56:53.969[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m28[0m - [1mVal: 0.7260[0m
[32m2026-01-09 13:56:53.969[0m | [1mINFO    [0m | [36