In [5]:
"""
This example shows how to perform regression of molecular properties with the
QM9 database, using a simple GNN in disjoint mode.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import Adam

from spektral.data import DisjointLoader
from spektral.datasets import QM9
from spektral.layers import ECCConv, GlobalSumPool

################################################################################
# Config
################################################################################
# Hyper-parameters consumed further down by the data loaders and the optimizer.
batch_size = 32  # Batch size given to both the training and the test loader
epochs = 1000  # Number of training epochs (passed to the training loader)
learning_rate = 1e-3  # Learning rate handed to the optimizer

################################################################################
# Load data
################################################################################
dataset = QM9(amount=1000)  # Set amount=None to train on whole dataset

# Parameters
F = dataset.n_node_features  # Dimension of node features
S = dataset.n_edge_features  # Dimension of edge features
n_out = dataset.n_labels  # Dimension of the target

# Train/test split
# A dedicated, seeded generator makes the 90/10 partition identical on every
# run, so test losses from different runs are computed on the same held-out
# graphs. Only the split is seeded here; weight initialisation and any
# loader-side shuffling are not controlled by this seed.
split_seed = 0
split_rng = np.random.default_rng(split_seed)
idxs = split_rng.permutation(len(dataset))
split = int(0.9 * len(dataset))  # First 90% of the shuffled indices -> training
idx_tr, idx_te = np.split(idxs, [split])
dataset_tr, dataset_te = dataset[idx_tr], dataset[idx_te]

# The training loader iterates for `epochs` passes over the data; the test
# loader makes a single pass.
loader_tr = DisjointLoader(dataset_tr, batch_size=batch_size, epochs=epochs)
loader_te = DisjointLoader(dataset_te, batch_size=batch_size, epochs=1)

################################################################################
# Build model
################################################################################
class Net(Model):
    """Graph-level regressor: two ECCConv layers, a global sum pool and a Dense head."""

    def __init__(self):
        """Build the layers.

        The width of the final Dense layer is read from the module-level
        ``n_out`` variable, so that variable must be defined before the model
        is instantiated.
        """
        super().__init__()
        self.conv1 = ECCConv(32, activation="relu")
        self.conv2 = ECCConv(32, activation="relu")
        self.global_pool = GlobalSumPool()
        self.dense = Dense(n_out)

    def call(self, inputs):
        """Forward pass.

        ``inputs`` must unpack into exactly four elements. They are presumably
        the node features, adjacency matrix, edge features and graph-membership
        index emitted by the disjoint loader (TODO confirm against the loader
        output). The first three feed both convolutions; the fourth is only
        used by the pooling layer.

        Returns the output of the final Dense layer.
        """
        nodes, adjacency, edges, graph_index = inputs
        hidden = self.conv1([nodes, adjacency, edges])
        hidden = self.conv2([hidden, adjacency, edges])
        pooled = self.global_pool([hidden, graph_index])
        return self.dense(pooled)


# Instantiate the network, then the optimizer and the objective used to train it.
model = Net()
optimizer = Adam(learning_rate=learning_rate)  # Step size comes from the config section
loss_fn = MeanSquaredError()  # Used for both the training and the test loss


################################################################################
# Fit model
################################################################################
@tf.function(input_signature=loader_tr.tf_signature(), experimental_relax_shapes=True)
def train_step(inputs, target):
    """Run one optimisation step on a single batch and return its loss.

    The loss is ``loss_fn(target, predictions)`` plus the sum of the entries
    in ``model.losses``. Gradients are taken with respect to
    ``model.trainable_variables`` and applied with the module-level
    ``optimizer``.
    """
    with tf.GradientTape() as grad_tape:
        outputs = model(inputs, training=True)
        # Add whatever extra loss terms the model has collected to the data term.
        total_loss = loss_fn(target, outputs) + sum(model.losses)
    variables = model.trainable_variables
    grads = grad_tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return total_loss


# Consume every batch the training loader yields. `step` counts batches within
# the current epoch and `loss` accumulates their losses; once a full epoch's
# worth of steps has been seen, the mean batch loss is printed and both
# counters start over.
step = loss = 0
for batch in loader_tr:
    step += 1
    loss += train_step(*batch)
    if step != loader_tr.steps_per_epoch:
        continue
    print("Loss: {}".format(loss / loader_tr.steps_per_epoch))
    step = loss = 0

################################################################################
# Evaluate model
################################################################################
# Evaluate on the held-out set.
#
# `loss_fn` returns a per-batch mean, and the last batch of the test loader can
# hold fewer graphs than `batch_size`. Averaging those means with equal weight
# (i.e. dividing their sum by `steps_per_epoch`) over-weights the graphs in a
# short final batch. Each batch loss is therefore weighted by the number of
# graphs it contains, which gives the mean over the whole test set.
print("Testing model")
loss = 0
n_graphs = 0
for batch in loader_te:
    inputs, target = batch
    predictions = model(inputs, training=False)
    # assumes `target` has one row per graph in the batch — TODO confirm
    batch_graphs = target.shape[0]
    loss += loss_fn(target, predictions) * batch_graphs
    n_graphs += batch_graphs
loss /= n_graphs
print("Done. Test loss: {}".format(loss))

Loading QM9 dataset.
Reading SDF


100%|█████████████████████████████████████| 1000/1000 [00:00<00:00, 1329.76it/s]
  return py_builtins.overload_of(f)(*args)


Loss: 22116324.0
Loss: 22056938.0
Loss: 21929188.0
Loss: 21826488.0
Loss: 21804524.0
Loss: 21799960.0
Loss: 21796658.0
Loss: 21793694.0
Loss: 21791300.0
Loss: 21789132.0
Loss: 21788770.0
Loss: 21785132.0
Loss: 21784742.0
Loss: 21781120.0
Loss: 21779498.0
Loss: 21776594.0
Loss: 21775530.0
Loss: 21772160.0
Loss: 21769722.0
Loss: 21765934.0
Loss: 21762802.0
Loss: 21760512.0
Loss: 21756050.0
Loss: 21753274.0
Loss: 21748452.0
Loss: 21744274.0
Loss: 21738692.0
Loss: 21734356.0
Loss: 21729190.0
Loss: 21721834.0
Loss: 21715632.0
Loss: 21710352.0
Loss: 21701684.0
Loss: 21694612.0
Loss: 21685052.0
Loss: 21675620.0
Loss: 21665428.0
Loss: 21653948.0
Loss: 21642970.0
Loss: 21630152.0
Loss: 21616880.0
Loss: 21607896.0
Loss: 21587878.0
Loss: 21575018.0
Loss: 21551318.0
Loss: 21535106.0
Loss: 21513436.0
Loss: 21492386.0
Loss: 21492268.0
Loss: 21451920.0
Loss: 21430900.0
Loss: 21415540.0
Loss: 21387972.0
Loss: 21379256.0
Loss: 21343576.0
Loss: 21329130.0
Loss: 21294850.0
Loss: 21272474.0
Loss: 21246632



In [38]:
# Inspect the `a` attribute (the adjacency matrix, judging by the sparse
# (row, col) entries it prints) of the first graph in the dataset.
first_graph = dataset[0]
print(first_graph.a)

  (0, 1)	1
  (0, 2)	1
  (0, 3)	1
  (0, 4)	1
  (1, 0)	1
  (2, 0)	1
  (3, 0)	1
  (4, 0)	1
