In [2]:
!pip install spektral -qq
!pip install --upgrade keras -qq
!pip install ogb -qq
!git clone https://github.com/anas-rz/k3-node.git

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.1/140.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.4 which is incompatible.[0m[31m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for littleutils (setup.py) ... [?25l[?25hdone


In [4]:
import os, sys
os.environ['KERAS_BACKEND'] = 'tensorflow'
sys.path.append('/content/k3-node')

In [5]:
import numpy as np
from ogb.nodeproppred import NodePropPredDataset
from keras.layers import BatchNormalization, Dropout, Input
from keras.losses import SparseCategoricalCrossentropy
from keras.models import Model
from keras.optimizers import Adam

from spektral.datasets.ogb import OGB
from spektral.transforms import AdjToSpTensor, GCNFilter

from k3_node.layers import ARMAConv

In [6]:
# Load data
dataset_name = "ogbn-arxiv"
ogb_dataset = NodePropPredDataset(dataset_name)
dataset = OGB(ogb_dataset, transforms=[GCNFilter(), AdjToSpTensor()])
graph = dataset[0]
x, adj, y = graph.x, graph.a, graph.y

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip


Downloaded 0.08 GB: 100%|██████████| 81/81 [00:07<00:00, 10.39it/s]


Extracting dataset/arxiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 1729.61it/s]

Saving...





In [7]:
# Parameters
channels = 256  # Number of channels for GCN layers
dropout = 0.5  # Dropout rate for the features
learning_rate = 1e-2  # Learning rate
epochs = 10  # Number of training epochs
N = dataset.n_nodes  # Number of nodes in the graph
F = dataset.n_node_features  # Original size of node features
n_out = ogb_dataset.num_classes  # OGB labels are sparse indices

In [8]:
# Data splits
idx = ogb_dataset.get_idx_split()
idx_tr, idx_va, idx_te = idx["train"], idx["valid"], idx["test"]
mask_tr = np.zeros(N, dtype=bool)
mask_va = np.zeros(N, dtype=bool)
mask_te = np.zeros(N, dtype=bool)
mask_tr[idx_tr] = True
mask_va[idx_va] = True
mask_te[idx_te] = True
masks = [mask_tr, mask_va, mask_te]

In [9]:
# Model definition
x_in = Input(shape=(F,))
a_in = Input((N,), sparse=True)
x_1 = ARMAConv(channels, activation="relu")([x_in, a_in])
x_1 = BatchNormalization()(x_1)
x_1 = Dropout(dropout)(x_1)
x_2 = ARMAConv(channels, activation="relu")([x_1, a_in])
x_2 = BatchNormalization()(x_2)
x_2 = Dropout(dropout)(x_2)
x_3 = ARMAConv(n_out, activation="softmax")([x_2, a_in])

In [10]:
# Build model
model = Model(inputs=[x_in, a_in], outputs=x_3)
optimizer = Adam(learning_rate=learning_rate)
loss_fn = SparseCategoricalCrossentropy()
model.summary()

In [11]:
import tensorflow as tf
# Training function
@tf.function
def train(inputs, target, mask):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(target[mask], predictions[mask]) + sum(model.losses)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

In [14]:
@tf.function
def evaluate(inputs, target, mask):
    predictions = model(inputs, training=True)
    loss = loss_fn(target[mask], predictions[mask]) + sum(model.losses)
    return loss

In [15]:
# Train model
for i in range(1, 1 + epochs):
    tr_loss = train([x, adj], y, mask_tr)
    eval_loss = evaluate([x, adj], y, mask_va) # TODO Add more metrics
    print(f"EPOCH {i}: Training Loss {tr_loss.numpy()}, Validation Loss: {eval_loss}")
test_loss = evaluate([x, adj], y, mask_te)
print(f"Test Loss: {test_loss}")

EPOCH 1: Training Loss 1.343968152999878, Evaluation Loss: 1.2462656497955322
EPOCH 2: Training Loss 1.3407444953918457, Evaluation Loss: 1.241466999053955
EPOCH 3: Training Loss 1.338334083557129, Evaluation Loss: 1.2379639148712158
EPOCH 4: Training Loss 1.3329962491989136, Evaluation Loss: 1.2300032377243042
EPOCH 5: Training Loss 1.3284363746643066, Evaluation Loss: 1.237213373184204
EPOCH 6: Training Loss 1.3322268724441528, Evaluation Loss: 1.2309643030166626
EPOCH 7: Training Loss 1.3203994035720825, Evaluation Loss: 1.2292847633361816
EPOCH 8: Training Loss 1.3217833042144775, Evaluation Loss: 1.224158525466919
EPOCH 9: Training Loss 1.3212112188339233, Evaluation Loss: 1.2227007150650024
EPOCH 10: Training Loss 1.3170732259750366, Evaluation Loss: 1.2200335264205933
EPOCH 11: Training Loss 1.3163007497787476, Evaluation Loss: 1.2167840003967285
EPOCH 12: Training Loss 1.3103758096694946, Evaluation Loss: 1.2180887460708618
EPOCH 13: Training Loss 1.3075199127197266, Evaluation