In [60]:
!pip install spektral
import spektral
import numpy as np
import scipy.sparse as sp

In [None]:
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

from spektral.data.loaders import SingleLoader
from spektral.layers import GCNConv


# Placeholder single-graph dataset; replace with real data.
# Arguments of spektral.data.Graph used here:
#   x: np.array, the node features (shape (n_nodes, n_node_features));
#   a: np.array or scipy.sparse matrix, the adjacency matrix (shape (n_nodes, n_nodes));
#   e: np.array, the edge features (shape (n_nodes, n_nodes, n_edge_features) or (n_edges, n_edge_features)).
# NOTE(review): no `y` labels are attached to this graph; training with a loss
# presumably needs them -- confirm before replacing the placeholder.
# NOTE(review): Spektral loaders are documented as taking a
# `spektral.data.Dataset`; confirm that a plain list is accepted.
dataset = [
    spektral.data.Graph(
        # randint's upper bound is exclusive, so (0, 1) would only ever yield
        # zeros; (0, 2) gives random binary node features.
        x=np.random.randint(0, 2, size=(51, 5)),
        a=sp.csr_matrix((51, 51)),  # empty sparse matrix: no edges stored
        e=None,
    )
]

# Parameters
channels = 16          # Number of channels in the first layer
dropout = 0.5          # Dropout rate for the features
l2_reg = 5e-4 / 2      # L2 regularization rate
learning_rate = 1e-2   # Learning rate
epochs = 200           # Number of training epochs
patience = 10          # Patience for early stopping
a_dtype = dataset[0].a.dtype  # Only needed for TF 2.1

# Graph dimensions are read from the node-feature matrix so they cannot drift
# out of sync with the data when the placeholder dataset is replaced.
N, F = dataset[0].x.shape  # N: number of nodes, F: size of node features
y = 5           # Number of output channels of the final GCN layer (softmax width)


# Model definition
# Two inputs: the node-feature matrix and the sparse adjacency matrix.
x_in = Input(shape=(F,))
a_in = Input(shape=(N,), dtype=a_dtype, sparse=True)

# Stage 1: dropout on the raw features, then a ReLU graph convolution with an
# L2 penalty on its kernel.
do_1 = Dropout(rate=dropout)(x_in)
hidden_conv = GCNConv(
    channels,
    use_bias=False,
    kernel_regularizer=l2(l2_reg),
    activation='relu',
)
gc_1 = hidden_conv([do_1, a_in])

# Stage 2: dropout on the hidden representation, then a softmax graph
# convolution with `y` output channels.
do_2 = Dropout(rate=dropout)(gc_1)
output_conv = GCNConv(y, use_bias=False, activation='softmax')
gc_2 = output_conv([do_2, a_in])

# Build model
# Wire the two inputs to the softmax output of the second GCN layer.
model = Model(inputs=[x_in, a_in], outputs=gc_2)
# `lr` is a deprecated alias that newer Keras releases reject;
# `learning_rate` is the supported keyword.
optimizer = Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])
model.summary()

# Train model
loader = SingleLoader(dataset)
# No validation data is passed to fit(), so EarlyStopping's default monitor
# ('val_loss') would never be available and the callback would never trigger.
# Monitor the training loss instead.
early_stopping = EarlyStopping(monitor='loss',
                               patience=patience,
                               restore_best_weights=True)
model.fit(loader.load(),
          steps_per_epoch=loader.steps_per_epoch,
          epochs=epochs,
          callbacks=[early_stopping])
