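# Benchmark a GAT model on the ESOL dataset

This notebook encodes the ESOL molecules as graphs and trains a three-layer `GATConv` regression model with a mean absolute error loss. It then reports the model's score on the held-out test split.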

In [6]:
import sys
sys.path.append('../../../../')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras


from molgraph.chemistry.benchmark import configs
from molgraph.chemistry.benchmark import tf_records
from molgraph.chemistry import datasets
from molgraph.losses import masked_losses
from molgraph.metrics import masked_metrics


### 1. Build **MolecularGraphEncoder**

In [2]:
from molgraph.chemistry import features
from molgraph.chemistry import AtomicFeaturizer
from molgraph.chemistry import MolecularGraphEncoder

# Atom (node) features. Each feature is listed exactly once: repeating a
# feature (e.g. ChiralCenter or Ring) only adds identical, redundant columns.
atom_encoder = AtomicFeaturizer([
    features.Symbol(),
    features.Hybridization(),
    features.FormalCharge(),
    features.TotalNumHs(),
    features.TotalValence(),
    features.NumRadicalElectrons(),
    features.Degree(),
    features.ChiralCenter(),
    features.Aromatic(),
    features.Ring(),
    features.Hetero(),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.CIPCode(),
    features.RingSize(),
    features.CrippenLogPContribution(),
    features.CrippenMolarRefractivityContribution(),
    features.TPSAContribution(),
    features.LabuteASAContribution(),
    features.GasteigerCharge(),
])

# Bond (edge) features; AtomicFeaturizer is used for the bond featurizer too.
bond_encoder = AtomicFeaturizer([
    features.BondType(),
    features.Conjugated(),
    features.Rotatable(),
    features.Ring(),
    features.Stereo(),
])

# Combines both featurizers into an encoder that turns molecules into graphs.
# NOTE(review): positional_encoding_dim=16 presumably sizes the positional
# encoding consumed by LaplacianPositionalEncoding in section 3 — confirm.
encoder = MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=16,
    self_loops=False,
)

### 2. Build **TF datasets** from the **MolecularGraphEncoder** output

In [3]:
esol = datasets.get('esol')

# Encode each split's molecules into graphs with the encoder from section 1;
# the targets are used unchanged.
x_train, y_train = encoder(esol['train']['x']), esol['train']['y']
x_val, y_val = encoder(esol['validation']['x']), esol['validation']['y']
x_test, y_test = encoder(esol['test']['x']), esol['test']['y']

# Keras Input spec, derived from the merged form of the training graphs
# (the tf.data pipelines below also `.merge()` every batch).
type_spec = x_train.merge().unspecific_spec

In [4]:
def make_dataset(x, y, shuffle=False, batch_size=32, shuffle_buffer_size=1024):
    """Build a batched, prefetched tf.data pipeline of (graph, target) pairs.

    Args:
        x: Encoded molecular graphs, one per example (output of `encoder`).
        y: Targets aligned with `x`.
        shuffle: Whether to shuffle examples before batching (use for training).
        batch_size: Number of molecules per batch.
        shuffle_buffer_size: Buffer size passed to `Dataset.shuffle`.

    Returns:
        A `tf.data.Dataset` yielding `(merged_graph_batch, target_batch)` tuples.
    """
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        ds = ds.shuffle(shuffle_buffer_size)
    return (
        ds.batch(batch_size)
        # Merge each batch of graphs so it matches `type_spec`, which was
        # derived from the merged training graphs.
        .map(lambda graph, *rest: (graph.merge(), *rest),
             num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )


train_ds = make_dataset(x_train, y_train, shuffle=True)
val_ds = make_dataset(x_val, y_val)
test_ds = make_dataset(x_test, y_test)

### 3. Modeling

In [5]:
from molgraph.layers import GATConv
from molgraph.layers import LaplacianPositionalEncoding
from molgraph.layers import Readout
from molgraph.layers import MinMaxScaling

# Training hyperparameters, collected here so they are easy to find and tune.
LEARNING_RATE = 1e-4
MIN_LEARNING_RATE = 1e-6
MAX_EPOCHS = 100
LR_PATIENCE = 10
EARLY_STOPPING_PATIENCE = 20

# Scale node and edge features to the (0, 1) range, fitted on training graphs
# only so no validation/test statistics leak into preprocessing.
# NOTE(review): threshold=True presumably clips values outside the fitted
# range — confirm against MinMaxScaling docs.
node_preprocessing = MinMaxScaling(
    feature='node_feature', feature_range=(0, 1), threshold=True)
edge_preprocessing = MinMaxScaling(
    feature='edge_feature', feature_range=(0, 1), threshold=True)

# `adapt` only needs the graphs, so drop the targets from each batch.
train_graphs = train_ds.map(lambda graph, *rest: graph)
node_preprocessing.adapt(train_graphs)
edge_preprocessing.adapt(train_graphs)

model = tf.keras.Sequential([
    keras.layers.Input(type_spec=type_spec),
    node_preprocessing,
    edge_preprocessing,
    LaplacianPositionalEncoding(),
    GATConv(),
    GATConv(),
    GATConv(),
    Readout(),
    keras.layers.Dense(1024, 'relu'),
    keras.layers.Dense(1024, 'relu'),
    # One output unit per target column.
    keras.layers.Dense(y_train.shape[-1])
])

optimizer = keras.optimizers.Adam(LEARNING_RATE)
loss = keras.losses.MeanAbsoluteError(name='mae')
callbacks = [
    # Cut the learning rate 10x when validation loss stalls...
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=LR_PATIENCE,
        min_lr=MIN_LEARNING_RATE,
        mode='min',
    ),
    # ...and stop (keeping the best weights) if it stalls for longer.
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=EARLY_STOPPING_PATIENCE,
        mode='min',
        restore_best_weights=True,
    ),
]

model.compile(optimizer=optimizer, loss=loss)
history = model.fit(
    train_ds,
    callbacks=callbacks,
    validation_data=val_ds,
    epochs=MAX_EPOCHS,
    verbose=2,
)

# Test-set loss (MAE) of the restored best model.
score = model.evaluate(test_ds)
print(score)

Epoch 1/100
29/29 - 6s - loss: 1.6104 - val_loss: 2.8727 - lr: 1.0000e-04 - 6s/epoch - 209ms/step
Epoch 2/100
29/29 - 0s - loss: 1.0733 - val_loss: 2.8306 - lr: 1.0000e-04 - 388ms/epoch - 13ms/step
Epoch 3/100
29/29 - 0s - loss: 0.9292 - val_loss: 2.7933 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step
Epoch 4/100
29/29 - 0s - loss: 0.7980 - val_loss: 2.6992 - lr: 1.0000e-04 - 383ms/epoch - 13ms/step
Epoch 5/100
29/29 - 0s - loss: 0.7400 - val_loss: 2.6170 - lr: 1.0000e-04 - 386ms/epoch - 13ms/step
Epoch 6/100
29/29 - 0s - loss: 0.6782 - val_loss: 2.5222 - lr: 1.0000e-04 - 390ms/epoch - 13ms/step
Epoch 7/100
29/29 - 0s - loss: 0.6756 - val_loss: 2.3882 - lr: 1.0000e-04 - 392ms/epoch - 14ms/step
Epoch 8/100
29/29 - 0s - loss: 0.6457 - val_loss: 2.2180 - lr: 1.0000e-04 - 380ms/epoch - 13ms/step
Epoch 9/100
29/29 - 0s - loss: 0.6356 - val_loss: 2.1035 - lr: 1.0000e-04 - 379ms/epoch - 13ms/step
Epoch 10/100
29/29 - 0s - loss: 0.5921 - val_loss: 2.0029 - lr: 1.0000e-04 - 397ms/epoch - 14ms/step
E