# Benchmark MPNN model on the Tox21 dataset (with Masked Loss)

In [4]:
import sys
sys.path.append('../../../../')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras


from molgraph.chemistry.benchmark import configs
from molgraph.chemistry.benchmark import tf_records
from molgraph.chemistry import datasets
from molgraph.losses import MaskedBinaryCrossentropy

### 1. Build **MolecularGraphEncoder**

In [2]:
from molgraph.chemistry import features
from molgraph.chemistry import Featurizer
from molgraph.chemistry import MolecularGraphEncoder

# Atom-level descriptors used to build the node features.
# ChiralCenter and Ring were previously listed twice; the repeats are dropped
# so that every descriptor appears exactly once.
# NOTE(review): removing the repeats is expected to narrow the node feature
# vector slightly, so metrics from a fresh run may differ marginally from
# previously logged numbers.
atom_encoder = Featurizer([
    features.Symbol(),
    features.Hybridization(),
    features.FormalCharge(),
    features.TotalNumHs(),
    features.TotalValence(),
    features.NumRadicalElectrons(),
    features.Degree(),
    features.ChiralCenter(),
    features.Aromatic(),
    features.Ring(),
    features.Hetero(),
    features.HydrogenDonor(),
    features.HydrogenAcceptor(),
    features.CIPCode(),
    features.RingSize(),
    features.CrippenLogPContribution(),
    features.CrippenMolarRefractivityContribution(),
    features.TPSAContribution(),
    features.LabuteASAContribution(),
    features.GasteigerCharge(),
])

# Bond-level descriptors used to build the edge features.
bond_features = [
    features.BondType(),
    features.Conjugated(),
    features.Rotatable(),
    features.Ring(),
    features.Stereo(),
]
bond_encoder = Featurizer(bond_features)

# Combine the atom and bond featurizers into a single molecular graph encoder.
encoder_options = dict(
    positional_encoding_dim=16,
    self_loops=False,
)
encoder = MolecularGraphEncoder(atom_encoder, bond_encoder, **encoder_options)

### 2. Build **TF dataset** from **MolecularGraphEncoder**

In [3]:
# Fetch the Tox21 benchmark; it is indexed first by split name, then by field
# ('x' inputs, 'y' labels, 'y_mask' label mask).
tox21 = datasets.get('tox21')
train_split = tox21['train']
val_split = tox21['validation']
test_split = tox21['test']

# Encode the inputs of each split; labels and masks are taken as they are.
x_train, y_train, y_mask_train = (
    encoder(train_split['x']), train_split['y'], train_split['y_mask'])

x_val, y_val, y_mask_val = (
    encoder(val_split['x']), val_split['y'], val_split['y_mask'])

x_test, y_test, y_mask_test = (
    encoder(test_split['x']), test_split['y'], test_split['y_mask'])

# Spec of the encoded training inputs, kept for the model's Input layer.
type_spec = x_train.spec

In [5]:
def _batch_and_prefetch(dataset):
    """Batch `dataset` into groups of 32 and prefetch with an autotuned buffer."""
    return dataset.batch(32).prefetch(tf.data.AUTOTUNE)


# Each dataset yields (inputs, labels, label mask) triples.
# Only the training data is shuffled.
train_ds = _batch_and_prefetch(
    tf.data.Dataset
    .from_tensor_slices((x_train, y_train, y_mask_train))
    .shuffle(1024)
)

val_ds = _batch_and_prefetch(
    tf.data.Dataset.from_tensor_slices((x_val, y_val, y_mask_val))
)

test_ds = _batch_and_prefetch(
    tf.data.Dataset.from_tensor_slices((x_test, y_test, y_mask_test))
)

### 3. Modeling

In [7]:
from molgraph.layers import MPNNConv
from molgraph.layers import LaplacianPositionalEncoding
from molgraph.layers import SetGatherReadout
from molgraph.layers import MinMaxScaling

# Min-max scaling layers for the node and edge features, both targeting [0, 1].
scaling_options = dict(feature_range=(0, 1), threshold=True)
node_preprocessing = MinMaxScaling(feature='node_feature', **scaling_options)
edge_preprocessing = MinMaxScaling(feature='edge_feature', **scaling_options)

# Fit both layers on the training inputs only (labels and masks are dropped).
train_inputs = train_ds.map(lambda graph, *_: graph)
node_preprocessing.adapt(train_inputs)
edge_preprocessing.adapt(train_inputs)

# One sigmoid output per label column of the training targets.
num_outputs = y_train.shape[-1]

# Preprocessing -> positional encoding -> 3 MPNN layers -> readout -> MLP head.
model_layers = [
    keras.layers.Input(type_spec=type_spec),
    node_preprocessing,
    edge_preprocessing,
    LaplacianPositionalEncoding(),
    *[MPNNConv(normalization='batch_norm') for _ in range(3)],
    SetGatherReadout(),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=1024, activation='relu'),
    keras.layers.Dense(units=num_outputs, activation='sigmoid'),
]
model = tf.keras.Sequential(model_layers)


optimizer = keras.optimizers.Adam(learning_rate=1e-4)
loss = MaskedBinaryCrossentropy(name='bce')

# The datasets yield (x, y, y_mask) triples; Keras treats the third element as
# sample weights, and this metric is registered through `weighted_metrics` at
# compile time, so the label mask reaches the AUC computation as its weights.
roc_auc = keras.metrics.AUC(name='roc_auc', multi_label=True)
metrics = [roc_auc]

# Both callbacks track validation ROC-AUC, where higher is better.
monitor_options = dict(monitor='val_roc_auc', mode='max')

# Cut the learning rate tenfold after 5 epochs without improvement,
# never going below 1e-6.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    factor=0.1,
    patience=5,
    min_lr=1e-6,
    **monitor_options,
)

# Stop after 10 epochs without improvement and roll back to the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    patience=10,
    restore_best_weights=True,
    **monitor_options,
)

callbacks = [reduce_lr, early_stopping]

# Metrics go in as `weighted_metrics` so they receive the per-label mask.
model.compile(optimizer=optimizer, loss=loss, weighted_metrics=metrics)

# Train for up to 100 epochs; the callbacks may stop the run earlier.
fit_options = dict(
    callbacks=callbacks,
    validation_data=val_ds,
    epochs=100,
    verbose=2,
)
history = model.fit(train_ds, **fit_options)

# Final evaluation on the held-out test split.
score = model.evaluate(test_ds)
print(score)

Epoch 1/100
196/196 - 20s - loss: 0.2953 - roc_auc: 0.5798 - val_loss: 0.2795 - val_roc_auc: 0.6407 - lr: 1.0000e-04 - 20s/epoch - 103ms/step
Epoch 2/100
196/196 - 13s - loss: 0.2739 - roc_auc: 0.6427 - val_loss: 0.2414 - val_roc_auc: 0.7301 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 3/100
196/196 - 13s - loss: 0.2616 - roc_auc: 0.6961 - val_loss: 0.2321 - val_roc_auc: 0.7534 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 4/100
196/196 - 13s - loss: 0.2544 - roc_auc: 0.7180 - val_loss: 0.2270 - val_roc_auc: 0.7654 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 5/100
196/196 - 13s - loss: 0.2513 - roc_auc: 0.7234 - val_loss: 0.2551 - val_roc_auc: 0.6889 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 6/100
196/196 - 13s - loss: 0.2476 - roc_auc: 0.7372 - val_loss: 0.2193 - val_roc_auc: 0.7653 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 7/100
196/196 - 13s - loss: 0.2396 - roc_auc: 0.7570 - val_loss: 0.2144 - val_roc_auc: 0.8030 - lr: 1.0000e-04 - 13s/epoch - 67ms/step
Epoch 8/100
