In [1]:
"""
This example shows how to perform regression of molecular properties with the
QM9 database, using a simple GNN in disjoint mode.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.optimizers import Adam

from spektral.data import DisjointLoader
from spektral.datasets import QM9
from spektral.layers import ECCConv, GlobalSumPool

################################################################################
# Config
################################################################################
learning_rate = 1e-3  # Learning rate
epochs = 100  # Number of training epochs
batch_size = 32  # Batch size

################################################################################
# Load data
################################################################################
dataset = QM9(amount=1000)  # Set amount=None to train on whole dataset

# Parameters
F = dataset.n_node_features  # Dimension of node features
S = dataset.n_edge_features  # Dimension of edge features
n_out = dataset.n_labels  # Dimension of the target

# Train/test split
idxs = np.random.permutation(len(dataset))
split = int(0.9 * len(dataset))
idx_tr, idx_te = np.split(idxs, [split])
dataset_tr, dataset_te = dataset[idx_tr], dataset[idx_te]

loader_tr = DisjointLoader(dataset_tr, batch_size=batch_size, epochs=epochs)
loader_te = DisjointLoader(dataset_te, batch_size=batch_size, epochs=1)

################################################################################
# Build model
################################################################################
class Net(Model):
    def __init__(self):
        super().__init__()
        self.conv1 = ECCConv(32, activation="relu")
        self.conv2 = ECCConv(32, activation="relu")
        self.global_pool = GlobalSumPool()
        self.dense = Dense(n_out)

    def call(self, inputs):
        x, a, e, i = inputs
        x = self.conv1([x, a, e])
        x = self.conv2([x, a, e])
        output = self.global_pool([x, i])
        output = self.dense(output)

        return output


model = Net()
optimizer = Adam(learning_rate)
loss_fn = MeanSquaredError()


################################################################################
# Fit model
################################################################################
@tf.function(input_signature=loader_tr.tf_signature(), experimental_relax_shapes=True)
def train_step(inputs, target):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(target, predictions) + sum(model.losses)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


step = loss = 0
for batch in loader_tr:
    step += 1
    loss += train_step(*batch)
    if step == loader_tr.steps_per_epoch:
        step = 0
        print("Loss: {}".format(loss / loader_tr.steps_per_epoch))
        loss = 0

################################################################################
# Evaluate model
################################################################################
print("Testing model")
loss = 0
for batch in loader_te:
    inputs, target = batch
    predictions = model(inputs, training=False)
    loss += loss_fn(target, predictions)
loss /= loader_te.steps_per_epoch
print("Done. Test loss: {}".format(loss))

Loading QM9 dataset.
Reading SDF


100%|█████████████████████████████████████| 1000/1000 [00:00<00:00, 1356.29it/s]
2022-05-10 17:26:15.746034: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-05-10 17:26:15.747822: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  return py_builtins.overload_of(f)(*args)
2022-05-10 17:26:16.593152: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)


Loss: 22110782.0
Loss: 22047232.0
Loss: 21904606.0
Loss: 21820092.0
Loss: 21806356.0
Loss: 21798554.0
Loss: 21795468.0
Loss: 21792614.0
Loss: 21790482.0
Loss: 21788376.0
Loss: 21786768.0
Loss: 21784460.0
Loss: 21783374.0
Loss: 21780754.0
Loss: 21778796.0
Loss: 21776766.0
Loss: 21774330.0
Loss: 21772702.0
Loss: 21769612.0
Loss: 21767686.0
Loss: 21764206.0
Loss: 21761484.0
Loss: 21758498.0
Loss: 21754880.0
Loss: 21750898.0
Loss: 21746402.0
Loss: 21742400.0
Loss: 21738032.0
Loss: 21732740.0
Loss: 21727154.0
Loss: 21722100.0
Loss: 21718780.0
Loss: 21708594.0
Loss: 21700858.0
Loss: 21693476.0
Loss: 21683352.0
Loss: 21678882.0
Loss: 21664588.0
Loss: 21653590.0
Loss: 21641578.0
Loss: 21628656.0
Loss: 21616408.0
Loss: 21602818.0
Loss: 21586540.0
Loss: 21572176.0
Loss: 21554942.0
Loss: 21537556.0
Loss: 21519446.0
Loss: 21501652.0
Loss: 21482152.0
Loss: 21468254.0
Loss: 21442914.0
Loss: 21422022.0
Loss: 21400422.0
Loss: 21376618.0
Loss: 21357556.0
Loss: 21333752.0
Loss: 21305578.0
Loss: 21282470



In [3]:
print(dataset[0].y)

[ 1.57711800e+02  1.57709970e+02  1.57706990e+02  0.00000000e+00
  1.32100000e+01 -3.87700000e-01  1.17100000e-01  5.04800000e-01
  3.53641000e+01  4.47490000e-02 -4.04789300e+01 -4.04760620e+01
 -4.04751170e+01 -4.04985970e+01  6.46900000e+00 -3.95999595e+02
 -3.98643290e+02 -4.01014647e+02 -3.72471772e+02]
