# A Tutorial on Quantitative Structure-Activity Relationships using Fully Expressive Graph Neural Networks

The data set used in this tutorial is Tox21, which contains approximately 8k chemical compounds tested on 12 different receptor types. In this tutorial, Tox21 is treated as a multi-label classification data set with missing labels (not every compound has a label for every receptor).

The tutorial was performed using:

- Ubuntu 22.04
    - Python 3.10 
        - Keras 2
        - TensorFlow 2.13
        - MolGraph 0.6.4

The tutorial has also been performed on Windows 10/11 and macOS 12.

The ROC-AUC score on the test data set should typically be in the range 0.84 to 0.85.


## 1. Import modules

In [None]:
from molgraph import chemistry
from molgraph import layers
from molgraph import losses

import keras

import tensorflow as tf
import pandas as pd

## 2. Read in the Tox21 data

In [None]:
# Load the Tox21 data set through MolGraph; it is indexed by split name
# ("train", "validation", "test"), and each split by field name.
tox21 = chemistry.datasets.get("tox21")

# For every split, unpack the inputs ("x"), the labels ("y") and the label
# mask ("y_mask"), in that order.
x_train, y_train, m_train = (
    tox21["train"][field] for field in ("x", "y", "y_mask")
)
x_val, y_val, m_val = (
    tox21["validation"][field] for field in ("x", "y", "y_mask")
)
x_test, y_test, m_test = (
    tox21["test"][field] for field in ("x", "y", "y_mask")
)

## 3. Construct molecular graphs from SMILES

In [None]:
# Atom (node) features; the list order fixes the layout of the encoded
# node feature vector.
atom_encoder = chemistry.Featurizer(
    [
        chemistry.features.Symbol(),
        chemistry.features.Hybridization(),
        chemistry.features.TotalValence(),
        chemistry.features.Hetero(),
        chemistry.features.HydrogenDonor(),
        chemistry.features.HydrogenAcceptor(),
    ]
)

# Bond (edge) features.
bond_encoder = chemistry.Featurizer(
    [
        chemistry.features.BondType(),
        chemistry.features.Rotatable(),
    ]
)

# positional_encoding_dim=None: presumably no positional encoding is
# attached to the graphs (confirm against the MolGraph docs).
mol_encoder = chemistry.MolecularGraphEncoder(
    atom_encoder,
    bond_encoder,
    positional_encoding_dim=None,
)

# Encode each split's inputs into a molecular graph, in train/val/test order.
train_graph, val_graph, test_graph = (
    mol_encoder(split_inputs) for split_inputs in (x_train, x_val, x_test)
)

## 4. Set up input pipelines from data

In [None]:
# Pipeline hyper-parameters shared by all three splits.
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1024

# Each element is a (graph, label, mask) triple. Keras' fit/evaluate treat the
# third tuple item as per-sample weights, so the mask (presumably 0 for missing
# labels) keeps unlabeled entries out of the loss and weighted metrics.
train_data = (train_graph, y_train, m_train)
val_data = (val_graph, y_val, m_val)
test_data = (test_graph, y_test, m_test)


def _make_dataset(data, shuffle=False):
    """Build a batched, prefetched ``tf.data`` pipeline from a tuple of arrays.

    Args:
        data: Tuple of (graph, labels, mask), sliced along the first axis.
        shuffle: Whether to shuffle examples; only the training split is
            shuffled.

    Returns:
        A ``tf.data.Dataset`` yielding batches of ``BATCH_SIZE`` examples.
    """
    ds = tf.data.Dataset.from_tensor_slices(data)
    if shuffle:
        ds = ds.shuffle(SHUFFLE_BUFFER_SIZE)
    # tf.data.AUTOTUNE (== -1) lets tf.data pick the prefetch buffer size
    # dynamically; the named constant replaces the bare magic number -1.
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_ds = _make_dataset(train_data, shuffle=True)
val_ds = _make_dataset(val_data)
test_ds = _make_dataset(test_data)

## 5. Build the QSAR model

In [None]:
# Symbolic model input whose type spec matches the encoded training graphs.
inputs = keras.layers.Input(type_spec=train_graph.spec)

# Adapted on the training graphs only, so validation/test data do not
# influence it. Presumably drops feature dimensions with (near-)zero variance;
# confirm against the MolGraph docs.
variance_threshold = layers.VarianceThreshold()
variance_threshold.adapt(train_graph)

h0 = variance_threshold(inputs)
# Project the filtered features to a 128-unit embedding.
h0 = layers.FeatureProjection(units=128)(h0)

# 1) Message passing (L = 4): four stacked GINConv layers, 128 units each,
#    with batch normalization. h1..h4 are kept separately for the readout.
h1 = layers.GINConv(units=128, normalization="batch_norm")(h0)
h2 = layers.GINConv(units=128, normalization="batch_norm")(h1)
h3 = layers.GINConv(units=128, normalization="batch_norm")(h2)
h4 = layers.GINConv(units=128, normalization="batch_norm")(h3)

# 2) Readout: one Readout per layer output (including the projected input h0),
#    concatenated so the prediction head sees every message-passing depth.
#    The pooling mode Readout uses by default is not visible here.
z0 = layers.Readout()(h0)
z1 = layers.Readout()(h1)
z2 = layers.Readout()(h2)
z3 = layers.Readout()(h3)
z4 = layers.Readout()(h4)

z = keras.layers.Concatenate()([z0, z1, z2, z3, z4])

# 3) Prediction: two 1024-unit ReLU layers, then 12 sigmoid outputs
#    (one independent probability per label).
z = keras.layers.Dense(units=1024, activation="relu")(z)
z = keras.layers.Dense(units=1024, activation="relu")(z)
outputs = keras.layers.Dense(units=12, activation="sigmoid")(z)

# Create model
qsar_model = keras.Model(inputs, outputs)

## 6. Compile the QSAR model

In [None]:
# Stochastic gradient descent with momentum.
optimizer = keras.optimizers.SGD(momentum=0.5, learning_rate=0.005)

# MolGraph's masked binary cross-entropy for the multi-label targets.
loss = losses.MaskedBinaryCrossentropy()

# multi_label=True: ROC-AUC is computed per label and then averaged. The
# metric is named "auc" so callbacks can monitor "val_auc". It is passed as a
# weighted metric, so Keras weights it by the per-sample mask from the dataset.
metrics = [keras.metrics.AUC(name="auc", multi_label=True)]

qsar_model.compile(optimizer=optimizer, loss=loss, weighted_metrics=metrics)

## 7. Train, validate, and evaluate the QSAR model

In [None]:
callbacks = [
    # Reduce the learning rate after 10 epochs without a higher validation AUC.
    keras.callbacks.ReduceLROnPlateau(
        monitor="val_auc", mode="max", patience=10
    ),
    # Stop after 20 epochs without a higher validation AUC.
    # NOTE(review): in Keras 2 the best weights appear to be restored only when
    # early stopping actually triggers; if all 100 epochs run, the final-epoch
    # weights are what get evaluated below. Confirm for the installed version.
    keras.callbacks.EarlyStopping(
        monitor="val_auc",
        mode="max",
        patience=20,
        restore_best_weights=True,
    ),
    # Optional TensorBoard logging; writing weight histograms every epoch
    # may slow down training.
    keras.callbacks.TensorBoard(log_dir="./logs", histogram_freq=1),
]

qsar_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=callbacks,
)

# evaluate() returns the loss followed by the compiled metrics (here: AUC).
bce_loss, auc_score = qsar_model.evaluate(test_ds)


## 8. Predict probability of activity with the QSAR model

In [None]:
# Sigmoid outputs of the model for every test example (one column per label).
y_test_pred = qsar_model.predict(test_ds)

# Label names; the order must match the label columns of y_test / y_test_pred.
receptor_names = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

def scores(receptor_names, trues, preds, masks, threshold=0.5):
    """Yield per-label evaluation scores for multi-label predictions.

    The label arrays are transposed and zipped with ``receptor_names``, so
    column ``i`` of each array is scored under ``receptor_names[i]``. Every
    metric is a fresh Keras metric object, weighted by the mask column (so
    presumably masked-out labels do not count).

    Args:
        receptor_names: One name per label column.
        trues: 2-D array of true labels, one column per label.
        preds: 2-D array of predicted probabilities, same layout as ``trues``.
        masks: 2-D array of sample weights / label mask, same layout.
        threshold: Decision threshold for the confusion-matrix counts
            (TP/FP/TN/FN). Defaults to 0.5, the Keras default, so existing
            callers see unchanged results.

    Yields:
        dict with keys 'Receptor', 'AUC', 'TP', 'FP', 'TN', 'FN', in that
        order (it becomes the DataFrame column order).
    """
    # Confusion-count metrics, in output-column order; they differ only in class.
    confusion_metrics = {
        'TP': keras.metrics.TruePositives,
        'FP': keras.metrics.FalsePositives,
        'TN': keras.metrics.TrueNegatives,
        'FN': keras.metrics.FalseNegatives,
    }
    for name, true, pred, mask in zip(
        receptor_names, trues.T, preds.T, masks.T
    ):
        row = {
            'Receptor': name,
            'AUC': keras.metrics.AUC()(true, pred, mask).numpy(),
        }
        for key, metric_cls in confusion_metrics.items():
            # A new metric object per label, so no state leaks between labels.
            metric = metric_cls(thresholds=threshold)
            row[key] = metric(true, pred, mask).numpy().astype(int)
        yield row
        
# One row per receptor. pd.DataFrame accepts the list of dicts directly, so the
# identity comprehension is unnecessary. Column order follows the dict keys
# yielded by scores().
results_table = pd.DataFrame(
    list(scores(receptor_names, y_test, y_test_pred, m_test))
)

In [None]:
# A bare expression as the last line of a notebook cell is displayed as the
# cell output: the per-receptor results table.
results_table