In [None]:
import networkx as nx
import numpy as np
import scipy.sparse as sp
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import categorical_accuracy, binary_accuracy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from spektral.data import Dataset, DisjointLoader, Graph
from spektral.datasets import TUDataset
from spektral.layers import GCNConv
from spektral.layers import GCSConv, GlobalAvgPool
from spektral.models import GeneralGNN
from spektral.transforms import AdjToSpTensor, LayerPreprocess
from spektral.transforms.normalize_adj import NormalizeAdj

In [None]:
class Customdataset(Dataset):
    """Spektral ``Dataset`` built from NetworkX graphs and binary labels.

    Each (graph, label) pair becomes a ``Graph`` with:

    * ``x`` -- float array of shape ``(num_nodes, n_features)``. For every edge
      ``(u, v)`` yielded by ``gr.edges``, row ``u`` gets ``+1`` in column ``v``.
      Rows follow ``gr.nodes`` order, the same order used to build ``a``.
    * ``a`` -- adjacency matrix as a ``scipy.sparse.csr_matrix``.
    * ``y`` -- one-hot label: ``[1, 0]`` when the label equals 1, else ``[0, 1]``.

    Node identifiers are used directly as feature columns, so they are assumed
    to be integers in ``[0, n_features)`` -- TODO confirm for the source data.

    Args:
        graph: iterable of NetworkX graphs.
        labels: iterable of labels, paired positionally with ``graph``.
        n_features: width of the node-feature matrix (982 by default, the
            value this notebook originally hard-coded).
        **kwargs: forwarded to ``spektral.data.Dataset`` (e.g. ``transforms``).
    """

    def __init__(self, graph, labels, n_features=982, **kwargs):
        # Set before super().__init__(), which calls read() and needs them.
        self.graph = graph
        self.labels = labels
        self._n_features = n_features
        super().__init__(**kwargs)

    def read(self):
        """Return one spektral ``Graph`` per (graph, label) pair."""
        output = []
        for gr, lb in zip(self.graph, self.labels):
            # Index rows by node position in gr.nodes (the order
            # to_numpy_array also uses), so x[i] and a[i] describe the same
            # node even when some node is never the source of an edge, and an
            # edgeless graph simply gets an all-zero feature matrix.
            row_of = {node: i for i, node in enumerate(gr.nodes)}
            x = np.zeros((len(row_of), self._n_features))
            for src, dst in gr.edges:
                x[row_of[src], dst] += 1

            # nx.to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array
            # yields the same values as a plain ndarray.
            A = sp.csr_matrix(nx.to_numpy_array(gr))

            Y = np.array([1, 0]) if lb == 1 else np.array([0, 1])
            output.append(Graph(x=x, a=A, e=None, y=Y))
        return output

In [None]:
# NOTE(review): `listgraph` / `listlabels` are not defined anywhere in this
# notebook; they must come from an earlier (missing) cell.
dataset = Customdataset(graph=listgraph, labels=listlabels, transforms=NormalizeAdj())

# Reproducible 90/10 train/validation split over graph indices.
SPLIT_SEED = 42  # fixed so Restart & Run All produces the same split
idx = np.random.default_rng(SPLIT_SEED).permutation(len(dataset))
split_va = int(0.9*len(dataset))
idx_tr, idx_va = np.split(idx, [split_va])
data_tr = dataset[idx_tr]
data_va = dataset[idx_va]

# Training loader is limited to 25 epochs. NOTE(review): the config cell's
# `epochs = 10` is not used here -- confirm which value is intended.
loader_tr = DisjointLoader(data_tr, batch_size=32, epochs=25)
# No `epochs` given (spektral default: iterate indefinitely); evaluate() pulls
# `steps_per_epoch` batches from it on each call.
loader_va = DisjointLoader(data_va, batch_size=10)

In [None]:
# Config

learning_rate = 1e-3  # Learning rate
# NOTE(review): `epochs` and `batch_size` are not read below; the training
# loader was built with its own literals (epochs=25, batch_size=32), so
# changing these two values has no effect on training.
epochs = 10  # Number of training epochs
es_patience = 10  # Patience for early stopping
batch_size = 32  # Batch size

# Softmax output pairs with CategoricalCrossentropy's default from_logits=False.
model = GeneralGNN(dataset.n_labels, activation="softmax")
# `lr=` is deprecated: newer Keras optimizers ignore or reject it, so pass
# the rate through the supported `learning_rate` argument.
optimizer = Adam(learning_rate=learning_rate)
loss_fn = CategoricalCrossentropy()

In [None]:
# Fit model

@tf.function(input_signature=loader_tr.tf_signature(), experimental_relax_shapes=True)
def train_step(inputs, target):
    """Run one optimisation step on a single disjoint batch.

    The optimised loss is ``loss_fn`` on the predictions plus any extra losses
    registered on the model (``model.losses``).

    Returns:
        (total_loss, batch_accuracy), where accuracy is the mean of
        ``categorical_accuracy`` over the batch.
    """
    with tf.GradientTape() as tape:
        probs = model(inputs, training=True)
        extra_loss = sum(model.losses)  # 0 when the model registers none
        total_loss = loss_fn(target, probs) + extra_loss
    weights = model.trainable_variables
    grads = tape.gradient(total_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))
    batch_acc = tf.reduce_mean(categorical_accuracy(target, probs))
    return total_loss, batch_acc

def evaluate(loader):
    """Compute the batch-size-weighted mean loss and accuracy over one epoch.

    Pulls exactly ``loader.steps_per_epoch`` batches from ``loader``, runs the
    model in inference mode, and averages per-batch loss and accuracy with
    each batch weighted by its number of targets.

    Accuracy uses ``categorical_accuracy``, the same metric ``train_step``
    reports, so training and validation accuracy are directly comparable.

    Returns:
        np.ndarray ``[mean_loss, mean_accuracy]``.

    Raises:
        ValueError: if the loader has no batches per epoch.
    """
    n_steps = loader.steps_per_epoch
    if not n_steps:
        raise ValueError("loader has no batches to evaluate")

    output = []
    for _ in range(n_steps):
        inputs, target = next(loader)
        pred = model(inputs, training=False)
        output.append((
            float(loss_fn(target, pred)),
            float(tf.reduce_mean(categorical_accuracy(target, pred))),
            len(target),  # batch size, used as the averaging weight
        ))
    output = np.array(output)
    return np.average(output[:, :-1], 0, weights=output[:, -1])

def inference(loader, return_targets=False):
    """Run the model over one epoch of ``loader`` and collect its predictions.

    Pulls exactly ``loader.steps_per_epoch`` batches and concatenates the
    per-batch model outputs along axis 0 (presumably one row per graph,
    one column per label -- confirm against the model's output shape).

    Args:
        loader: a spektral loader yielding ``(inputs, target)`` batches.
        return_targets: when True, also return the concatenated targets in
            the same order as the predictions (useful because the loader may
            shuffle, so the order is not the dataset order).

    Returns:
        np.ndarray of predictions, or ``(predictions, targets)`` when
        ``return_targets`` is True.

    Raises:
        ValueError: if the loader has no batches per epoch.
    """
    n_steps = loader.steps_per_epoch
    if not n_steps:
        raise ValueError("loader has no batches to run inference on")

    preds, targets = [], []
    for _ in range(n_steps):
        inputs, target = next(loader)
        preds.append(np.asarray(model(inputs, training=False)))
        targets.append(np.asarray(target))

    # Batches can differ in size (last one is smaller), so concatenate rather
    # than stacking into a single rectangular array.
    predictions = np.concatenate(preds, axis=0)
    if return_targets:
        return predictions, np.concatenate(targets, axis=0)
    return predictions


# Bookkeeping state consumed and updated by the training loop.
epoch, step = 0, 0          # epochs completed; batches seen in current epoch
best_val_loss = np.inf      # lowest validation loss seen so far
best_weights = None         # model weights captured at best_val_loss
patience = es_patience      # epochs left before early stopping triggers
results = []                # per-batch (loss, acc) for the current epoch


In [None]:
# Train loop
# Each iteration consumes one batch; every `loader_tr.steps_per_epoch` batches
# an epoch ends, validation runs, and early stopping is checked.
# NOTE(review): loader_tr was created with a fixed number of epochs, so
# re-running this cell without rebuilding it presumably yields no batches.

for batch in loader_tr:
    step += 1
    loss, acc = train_step(*batch)
    results.append((loss, acc))
    if step == loader_tr.steps_per_epoch:
        step = 0
        epoch += 1

        # Compute validation loss and accuracy
        val_loss, val_acc = evaluate(loader_va)
        print(
            "Ep. {} - Loss: {:.3f} - Acc: {:.3f} - Val loss: {:.3f} - Val acc: {:.3f}".format(
                epoch, *np.mean(results, 0), val_loss, val_acc
            )
        )

        # Check if loss improved for early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = es_patience
            print("New best val_loss {:.3f}".format(val_loss))
            best_weights = model.get_weights()
        else:
            patience -= 1
            if patience == 0:
                print("Early stopping (best val_loss: {})".format(best_val_loss))
                break
        results = []

# The loop leaves the model with the weights from the last epoch it ran;
# restore the best-validation weights captured above. (best_weights stays None
# only if no validation loss ever compared below inf, e.g. all NaN.)
if best_weights is not None:
    model.set_weights(best_weights)