# Benchmark DTNN model on the QM7 dataset

In [7]:
import sys
sys.path.append('../../../../')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras


from molgraph.chemistry.benchmark import configs
from molgraph.chemistry.benchmark import tf_records
from molgraph.chemistry import datasets

### 1. Build **MolecularGraphEncoder3D**

In [2]:
from molgraph.chemistry import features
from molgraph.chemistry import Featurizer
from molgraph.chemistry import MolecularGraphEncoder3D

# Atom-level featurizer handed to MolecularGraphEncoder3D below.
# Every descriptor appears exactly once: the previous list repeated
# ChiralCenter() and Ring(), which only added redundant duplicate entries.
atom_encoder = Featurizer([
    features.Symbol(),
    features.Hybridization(),
    features.FormalCharge(),
    features.TotalNumHs(),
    features.TotalValence(),
    features.NumRadicalElectrons(),
    features.Degree(),
    features.ChiralCenter(),
    features.Aromatic(),
    features.Ring(),
    features.Hetero(),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.CIPCode(),
    features.RingSize(),
    features.CrippenLogPContribution(),
    features.CrippenMolarRefractivityContribution(),
    features.TPSAContribution(),
    features.LabuteASAContribution(),
    features.GasteigerCharge(),
])

# 3D molecular graph encoder; atoms are featurized with `atom_encoder`.
# It is applied to the QM7 train/validation/test inputs in the next cell.
encoder = MolecularGraphEncoder3D(
    atom_encoder,
    # No conformer generator is supplied; per the original note the QM7
    # molecules already come with conformers.
    conformer_generator=None, # qm7 encodes conformers
    # NOTE(review): None is annotated as "max radius" -- presumably no distance
    # cutoff is applied when forming edges; confirm against the molgraph docs.
    edge_radius=None, # max radius
    # NOTE(review): presumably switches on Coulomb-based edge information;
    # exact semantics are not visible here -- confirm against the molgraph docs.
    coulomb=True,
)

### 2. Build **TF dataset** from **MolecularGraphEncoder3D**

In [3]:
# Load QM7 and push the inputs of every split through the 3D encoder;
# the targets are taken from the dataset unchanged.
qm7 = datasets.get('qm7')

split_names = ('train', 'validation', 'test')
x_train, x_val, x_test = [encoder(qm7[name]['x']) for name in split_names]
y_train, y_val, y_test = [qm7[name]['y'] for name in split_names]

# Spec of the merged training graphs; consumed later by the model's Input layer.
type_spec = x_train.merge().unspecific_spec

In [4]:
def _build_dataset(x, y, batch_size=32, shuffle_buffer=None):
    """Build a batched, prefetched ``tf.data`` pipeline of (graph, label) pairs.

    Args:
        x: Encoded graphs, sliced example-wise by
            ``tf.data.Dataset.from_tensor_slices``.
        y: Labels aligned with ``x``.
        batch_size: Number of examples per batch.
        shuffle_buffer: Size of the shuffle buffer applied before batching.
            ``None`` (the default) disables shuffling.

    Returns:
        A ``tf.data.Dataset`` yielding ``(x_batch.merge(), y_batch)`` tuples.
    """
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer)
    return (
        dataset
        .batch(batch_size)
        # Each batch of graphs is passed through its own ``merge()`` before
        # reaching the model; labels (and any extra elements) pass through.
        .map(
            lambda x, *args: (x.merge(), *args),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .prefetch(tf.data.AUTOTUNE)
    )


# Only the training split is shuffled; validation and test keep their order.
train_ds = _build_dataset(x_train, y_train, shuffle_buffer=1024)
val_ds = _build_dataset(x_val, y_val)
test_ds = _build_dataset(x_test, y_test)

### 3. Modeling

In [6]:
from molgraph.layers import DTNNConv
from molgraph.layers import Readout
from molgraph.layers import MinMaxScaling

# One min-max scaler per graph field, both configured identically.
node_preprocessing, edge_preprocessing = (
    MinMaxScaling(feature=field, feature_range=(0, 1), threshold=True)
    for field in ('node_feature', 'edge_feature')
)

# The scalers are adapted on the training graphs only, so the labels are
# dropped from each (graph, label) element first.
train_graphs = train_ds.map(lambda x, *args: x)
node_preprocessing.adapt(train_graphs)
edge_preprocessing.adapt(train_graphs)

# Sequential DTNN regressor: input -> feature scaling -> three DTNNConv
# layers -> readout -> two-layer dense head -> linear output whose width
# equals the last dimension of the training targets.
model = keras.Sequential([
    keras.layers.Input(type_spec=type_spec),
    node_preprocessing,
    edge_preprocessing,
    *[DTNNConv() for _ in range(3)],
    Readout(),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.Dense(y_train.shape[-1]),
])


# Optimisation setup: Adam at 1e-4, trained and evaluated on mean absolute error.
optimizer = keras.optimizers.Adam(learning_rate=1e-4)
loss = keras.losses.MeanAbsoluteError(name='mae')

# Cut the learning rate by 10x (down to at least 1e-6) once the validation
# loss has not improved for 10 epochs.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=10,
    min_lr=1e-6,
    mode='min',
)
# Stop after 20 epochs without validation improvement and roll back to the
# best weights seen so far.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=20,
    mode='min',
    restore_best_weights=True,
)
callbacks = [reduce_lr, early_stopping]

model.compile(optimizer=optimizer, loss=loss)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=callbacks,
    verbose=2,
)

# Final held-out performance (test-set MAE, since no extra metrics are compiled).
score = model.evaluate(test_ds)
print(score)

Epoch 1/100
180/180 - 3s - loss: 1257.5872 - val_loss: 903.2003 - lr: 1.0000e-04 - 3s/epoch - 18ms/step
Epoch 2/100
180/180 - 1s - loss: 545.0237 - val_loss: 442.4483 - lr: 1.0000e-04 - 937ms/epoch - 5ms/step
Epoch 3/100
180/180 - 1s - loss: 163.0737 - val_loss: 219.5774 - lr: 1.0000e-04 - 961ms/epoch - 5ms/step
Epoch 4/100
180/180 - 1s - loss: 127.5755 - val_loss: 100.0004 - lr: 1.0000e-04 - 952ms/epoch - 5ms/step
Epoch 5/100
180/180 - 1s - loss: 120.0276 - val_loss: 72.0875 - lr: 1.0000e-04 - 924ms/epoch - 5ms/step
Epoch 6/100
180/180 - 1s - loss: 108.8816 - val_loss: 179.9970 - lr: 1.0000e-04 - 940ms/epoch - 5ms/step
Epoch 7/100
180/180 - 1s - loss: 101.8370 - val_loss: 80.3074 - lr: 1.0000e-04 - 950ms/epoch - 5ms/step
Epoch 8/100
180/180 - 1s - loss: 97.8691 - val_loss: 66.1778 - lr: 1.0000e-04 - 920ms/epoch - 5ms/step
Epoch 9/100
180/180 - 1s - loss: 92.5363 - val_loss: 60.9984 - lr: 1.0000e-04 - 939ms/epoch - 5ms/step
Epoch 10/100
180/180 - 1s - loss: 87.6198 - val_loss: 69.3062 

Epoch 80/100
180/180 - 1s - loss: 49.7845 - val_loss: 27.6885 - lr: 1.0000e-06 - 927ms/epoch - 5ms/step
Epoch 81/100
180/180 - 1s - loss: 50.6684 - val_loss: 28.2240 - lr: 1.0000e-06 - 961ms/epoch - 5ms/step
Epoch 82/100
180/180 - 1s - loss: 48.3892 - val_loss: 27.5653 - lr: 1.0000e-06 - 932ms/epoch - 5ms/step
Epoch 83/100
180/180 - 1s - loss: 50.5615 - val_loss: 27.0092 - lr: 1.0000e-06 - 919ms/epoch - 5ms/step
Epoch 84/100
180/180 - 1s - loss: 51.8766 - val_loss: 27.6167 - lr: 1.0000e-06 - 959ms/epoch - 5ms/step
Epoch 85/100
180/180 - 1s - loss: 49.8121 - val_loss: 27.5435 - lr: 1.0000e-06 - 918ms/epoch - 5ms/step
Epoch 86/100
180/180 - 1s - loss: 49.2227 - val_loss: 27.5950 - lr: 1.0000e-06 - 927ms/epoch - 5ms/step
Epoch 87/100
180/180 - 1s - loss: 48.8999 - val_loss: 27.2340 - lr: 1.0000e-06 - 953ms/epoch - 5ms/step
Epoch 88/100
180/180 - 1s - loss: 52.5623 - val_loss: 27.4560 - lr: 1.0000e-06 - 926ms/epoch - 5ms/step
Epoch 89/100
180/180 - 1s - loss: 48.5449 - val_loss: 27.5492 - 