# Molecule Property Prediction with Tox21 Dataset


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/jaxchem//blob/master/notebooks/gcn_property_prediction.ipynb)


## Install packages

First, we install DeepChem, which provides utilities for loading the Tox21 dataset. (This may take about 3 minutes.)


In [None]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')

TensorFlow 1.x selected.
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  3477  100  3477    0     0  17922      0 --:--:-- --:--:-- --:--:-- 17922


add /root/miniconda/lib/python3.6/site-packages to PYTHONPATH
deepchem is already installed


CPU times: user 2.05 ms, sys: 9 µs, total: 2.06 ms
Wall time: 1.69 ms


Next, we install JAXChem and its dependencies.  
**Caution**: After running the following command, you need to restart the session. Otherwise, you may encounter an error.

In [None]:
!pip install -q dm-haiku==0.0.1 typing-extensions==3.7.4.2  git+https://github.com/deepchem/jaxchem

  Building wheel for jaxchem (setup.py) ... done


## Import modules

If you encounter the error `ImportError: cannot import name 'Literal'`, restart the notebook session.

In [None]:
import os
import time
import random
import pickle
import argparse
from typing import Any, Tuple, List

import jax
import numpy as np
import haiku as hk
import jax.numpy as jnp
from jax.experimental import optix
from sklearn.metrics import roc_auc_score


from deepchem.molnet import load_tox21
from jaxchem.models import PadGCNPredicator as GCNPredicator
from jaxchem.loss import binary_cross_entropy_with_logits as bce_with_logits
from jaxchem.utils import EarlyStopping


# type definition
Batch = Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]
State, OptState = Any, Any



The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



## Download Tox21 dataset

We download the preprocessed Tox21 dataset. In this example, we use the `AdjacencyConv` featurizer because `PadGCNPredicator` is based on the padded-pattern GCN, which uses an adjacency matrix to represent node connections.



In [None]:
# load tox21 dataset
tox21_tasks, tox21_datasets, _ = load_tox21(featurizer='AdjacencyConv', reload=True)
train_dataset, valid_dataset, test_dataset = tox21_datasets

Loading dataset from disk.
Loading dataset from disk.
Loading dataset from disk.
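
To make the padded adjacency representation more concrete, here is a minimal NumPy sketch. It is not the `AdjacencyConv` featurizer itself: the helper `pad_graph`, the toy molecule, and the feature size are assumptions chosen only for illustration. The one thing taken from this notebook is the ordering of each sample, `(adjacency matrix, node features)`.

```python
import numpy as np


def pad_graph(node_feats, edges, max_atoms):
    """Pad one molecular graph to a fixed number of atoms (hypothetical helper).

    node_feats: array of shape (n_atoms, n_feats)
    edges: list of (i, j) index pairs, one per bond
    """
    n_atoms, n_feats = node_feats.shape
    # Rows beyond n_atoms stay zero and act as padding atoms.
    padded_feats = np.zeros((max_atoms, n_feats), dtype=np.float32)
    padded_feats[:n_atoms] = node_feats
    # Dense adjacency matrix; padding atoms have no bonds.
    adj = np.zeros((max_atoms, max_atoms), dtype=np.float32)
    for i, j in edges:
        adj[i, j] = adj[j, i] = 1.0  # bonds are undirected
    return padded_feats, adj


# Toy 3-atom chain (0-1-2) with 4 features per atom, padded to 5 atoms.
feats = np.ones((3, 4), dtype=np.float32)
padded_feats, adj = pad_graph(feats, [(0, 1), (1, 2)], max_atoms=5)
sample = (adj, padded_feats)  # same ordering as the samples in this notebook
print(padded_feats.shape, adj.shape)
```

Because every graph is padded to the same size, a batch can be stacked into dense arrays of shape `(batch, max_atoms, n_feats)` and `(batch, max_atoms, max_atoms)`.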


In [None]:
print(tox21_tasks)

['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']


## Define some utilities

In [None]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)


def collate_fn(original_batch: Any) -> Batch:
    """Make batch data as PadGCN model inputs."""
    inputs, targets, _, _ = original_batch
    node_feats = np.array([inputs[i][1] for i in range(len(inputs))])
    adj = np.array([inputs[i][0] for i in range(len(inputs))])
    return (node_feats, adj), np.array(targets)


def multi_task_roc_auc_score(y_true: np.ndarray, y_score: np.ndarray) -> Tuple[float, List[float]]:
    """Calculate the roc_auc_score of all tasks for Tox21."""
    num_tasks = y_true.shape[1]
    scores = []
    for i in range(num_tasks):
        scores.append(roc_auc_score(y_true[:, i], y_score[:, i]))
    return np.mean(scores), scores
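
As a quick check of what `collate_fn` produces, the sketch below runs it on a hand-made batch. The batch contents (2 molecules, 5 padded atoms, 4 features, 12 tasks, the ids) are made-up placeholders; only the 4-tuple layout `(inputs, targets, weights, ids)` and the `(adjacency, node features)` ordering follow the code above.

```python
import numpy as np


def collate_fn(original_batch):
    """Same logic as the notebook's collate_fn, repeated to keep this runnable."""
    inputs, targets, _, _ = original_batch
    node_feats = np.array([inputs[i][1] for i in range(len(inputs))])
    adj = np.array([inputs[i][0] for i in range(len(inputs))])
    return (node_feats, adj), np.array(targets)


# Hypothetical batch: 2 molecules, 5 padded atoms, 4 features, 12 tasks.
max_atoms, n_feats, n_tasks = 5, 4, 12
inputs = [
    (np.zeros((max_atoms, max_atoms)), np.zeros((max_atoms, n_feats)))
    for _ in range(2)
]
targets = np.zeros((2, n_tasks))
weights = np.ones((2, n_tasks))  # ignored by collate_fn
ids = np.array(["mol-a", "mol-b"])  # ignored by collate_fn

(node_feats, adj), y = collate_fn((inputs, targets, weights, ids))
print(node_feats.shape, adj.shape, y.shape)
```

Note that the sample weights are discarded here, so every label contributes equally to the loss and to the ROC-AUC computation.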

## Setup model and optimizer

We define the forward function using `GCNPredicator`, which JAXChem provides. Our task is classification, but note that the forward function below returns the raw outputs of `GCNPredicator` as logits; no sigmoid is applied here, and the logits are passed directly to `binary_cross_entropy_with_logits`. After defining the forward function, we create the model instance with `haiku.transform_with_state`.

In [None]:
rng_seq = hk.PRNGSequence(1234)

# model params
in_feats = train_dataset.X[0][1].shape[1]
hidden_feats = [64, 64, 32]
activation, batch_norm, dropout = None, None, None  # use default
predicator_hidden_feats = 32
pooling_method = 'mean'
predicator_dropout = 0.2
n_out = len(tox21_tasks)

def forward(node_feats: jnp.ndarray, adj: jnp.ndarray, is_training: bool) -> jnp.ndarray:
    """Forward application of the GCN."""
    model = GCNPredicator(in_feats=in_feats, hidden_feats=hidden_feats, activation=activation,
                          batch_norm=batch_norm, dropout=dropout, pooling_method=pooling_method,
                          predicator_hidden_feats=predicator_hidden_feats,
                          predicator_dropout=predicator_dropout, n_out=n_out)
    preds = model(node_feats, adj, is_training)
    return preds

# we use haiku
model = hk.transform_with_state(forward)
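
For intuition about what a padded GCN computes, here is a rough NumPy sketch of one graph convolution followed by mean pooling. It is an assumption-laden illustration, not JAXChem's implementation: the function names, the row normalization, the absence of a bias term, and the toy shapes are all choices made for this example.

```python
import numpy as np


def gcn_layer(node_feats, adj, weight):
    """One simplified graph convolution on a batch of padded graphs.

    node_feats: (batch, n_atoms, f_in), adj: (batch, n_atoms, n_atoms),
    weight: (f_in, f_out)
    """
    n_atoms = adj.shape[-1]
    adj_hat = adj + np.eye(n_atoms)             # add self-loops
    deg = adj_hat.sum(axis=-1, keepdims=True)   # (batch, n_atoms, 1), always >= 1
    messages = adj_hat @ node_feats / deg       # average over neighbours and self
    return np.maximum(messages @ weight, 0.0)   # linear map + ReLU


def mean_pool(node_feats):
    """Average node embeddings into one vector per graph (padding included)."""
    return node_feats.mean(axis=1)


rng = np.random.default_rng(0)
batch, n_atoms, f_in, f_out = 2, 5, 4, 8
node_feats = rng.normal(size=(batch, n_atoms, f_in))
node_feats[:, 3:] = 0.0                         # last two atoms are padding
adj = np.zeros((batch, n_atoms, n_atoms))
adj[:, 0, 1] = adj[:, 1, 0] = 1.0               # bond 0-1
adj[:, 1, 2] = adj[:, 2, 1] = 1.0               # bond 1-2
weight = rng.normal(size=(f_in, f_out))

hidden = gcn_layer(node_feats, adj, weight)
graph_repr = mean_pool(hidden)
print(hidden.shape, graph_repr.shape)
```

In this sketch the padding atoms keep all-zero embeddings because they have zero input features, no bonds, and no bias is added.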

Next, we create the optimizer instance.

In [None]:
# optimizer params
lr = 0.001
optimizer = optix.adam(learning_rate=lr)
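
For readers unfamiliar with Adam, the following NumPy sketch shows a single update step of the standard algorithm. It is written from the textbook update rule, not taken from `optix`; the function `adam_step` and the default values for `b1`, `b2`, and `eps` are assumptions for illustration.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update (textbook form). `t` is the 1-based step count."""
    m = b1 * m + (1 - b1) * grad        # running mean of gradients
    v = b2 * v + (1 - b2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - b1 ** t)           # bias correction for the zero init
    v_hat = v / (1 - b2 ** t)
    new_param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return new_param, m, v


param = np.array([1.0, -2.0])
grad = np.array([0.5, -4.0])
m = np.zeros_like(param)
v = np.zeros_like(param)
new_param, m, v = adam_step(param, grad, m, v, t=1)
print(new_param)
```

On the very first step the bias-corrected ratio is close to the sign of the gradient, so each parameter moves by roughly `lr` regardless of the gradient's magnitude.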

## Define loss, update and evaluate function

Using the model and optimizer instances, we define the following functions. The update and evaluation functions are called in the training loop, so we decorate them with `@jax.jit` to improve performance; the loss function is traced as part of the jitted update.

- The function which calculates a loss value
- The function which updates parameters
- The function which calculates metric values for the validation data
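
One property of `jax.jit` worth keeping in mind for the functions listed above: Python-level side effects inside a jitted function run only while JAX traces it, not on every call. The toy function and counter below are hypothetical and exist only to show this behavior.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often the Python body actually runs


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: executed only during tracing
    return jnp.sum(x) * 2.0


a = scaled_sum(jnp.ones(3))  # first call: traces and compiles
b = scaled_sum(jnp.ones(3))  # same shape and dtype: reuses the compiled code
c = scaled_sum(jnp.ones(4))  # new shape: traces again
print(float(a), float(b), float(c), trace_count)
```

The same applies to any Python call placed inside a jitted function, such as `next(rng_seq)`: it is evaluated at trace time rather than once per training step. Passing a fresh key in as an argument is the usual way to get new randomness on every call.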

In [None]:
# define training loss
def train_loss(params: hk.Params, state: State, batch: Batch) -> Tuple[jnp.ndarray, State]:
    """Compute the loss."""
    inputs, targets = batch
    preds, new_state = model.apply(params, state, next(rng_seq), *inputs, True)
    loss = bce_with_logits(preds, targets)
    return loss, new_state

# define training update
@jax.jit
def update(params: hk.Params, state: State, opt_state: OptState,
           batch: Batch) -> Tuple[hk.Params, State, OptState]:
    """Update the params."""
    (_, new_state), grads = jax.value_and_grad(train_loss, has_aux=True)(params, state, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optix.apply_updates(params, updates)
    return new_params, new_state, new_opt_state

# define evaluate metrics
@jax.jit
def evaluate(params: hk.Params, state: State,
             batch: Batch) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Compute evaluate metrics."""
    inputs, targets = batch
    preds, _ = model.apply(params, state, next(rng_seq), *inputs, False)
    loss = bce_with_logits(preds, targets)
    return preds, loss, targets
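
The loss above works on logits rather than probabilities. As a sketch of why that is useful, the NumPy code below compares a numerically stable logit-based binary cross entropy with the naive sigmoid-then-log version. This is not JAXChem's `binary_cross_entropy_with_logits`; the function here is an assumed stand-in, and averaging over all elements is also an assumption.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def bce_with_logits(logits, targets):
    """Stable binary cross entropy: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    per_element = (
        np.maximum(logits, 0.0)
        - logits * targets
        + np.log1p(np.exp(-np.abs(logits)))
    )
    return per_element.mean()  # assumption: mean over all samples and tasks


logits = np.array([[2.0, -1.0], [0.0, 3.0]])
targets = np.array([[1.0, 0.0], [1.0, 0.0]])

stable = bce_with_logits(logits, targets)
probs = sigmoid(logits)
naive = -(targets * np.log(probs) + (1 - targets) * np.log(1 - probs)).mean()
print(stable, naive)

# Very large logits stay finite in the stable form.
extreme = bce_with_logits(np.array([1000.0, -1000.0]), np.array([1.0, 0.0]))
print(extreme)
```

Since the sigmoid is monotonic, ranking metrics such as ROC-AUC give the same result on logits as on probabilities, which is why the evaluation code can score the raw model outputs.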

## Training

We set up the hyperparameters.

In [None]:
# training params
seed = 42
batch_size = 32
early_stop_patience = 15
num_epochs = 100

# fix seed
seed_everything(seed)

And then, we train our model!

In [None]:
# initialize some values 
early_stop = EarlyStopping(patience=early_stop_patience)
batch_init_data = (
    jnp.zeros((batch_size, *train_dataset.X[0][1].shape)),
    jnp.zeros((batch_size, *train_dataset.X[0][0].shape)),
    True
)
params, state = model.init(next(rng_seq), *batch_init_data)
opt_state = optimizer.init(params)

In [None]:
print("Starting training...")
for epoch in range(num_epochs):
    # train
    start_time = time.time()
    for original_batch in train_dataset.iterbatches(batch_size=batch_size):
        batch = collate_fn(original_batch)
        params, state, opt_state = update(params, state, opt_state, batch)
    epoch_time = time.time() - start_time

    # valid
    y_score, y_true, valid_loss = [], [], []
    for original_batch in valid_dataset.iterbatches(batch_size=batch_size):
        batch = collate_fn(original_batch)
        logits, loss, targets = evaluate(params, state, batch)
        y_score.extend(logits)
        valid_loss.append(loss)
        y_true.extend(targets)
    score, _ = multi_task_roc_auc_score(np.array(y_true), np.array(y_score))

    # log
    print(f"Iter {epoch}/{num_epochs} ({epoch_time:.4f} s) \
            valid loss: {np.mean(valid_loss):.4f} \
            valid roc_auc score: {score:.4f}")
    # check early stopping
    early_stop.update(score, (params, state))
    if early_stop.is_train_stop:
        print("Early stopping...")
        break

Starting training...
Iter 0/200 (5.7399 s)             valid loss: 0.2196             valid roc_auc score: 0.5792
Iter 1/200 (1.9963 s)             valid loss: 0.2172             valid roc_auc score: 0.5976
Iter 2/200 (1.9389 s)             valid loss: 0.2121             valid roc_auc score: 0.5977
Iter 3/200 (1.9314 s)             valid loss: 0.2141             valid roc_auc score: 0.6130
Iter 4/200 (1.9399 s)             valid loss: 0.2131             valid roc_auc score: 0.6260
Iter 5/200 (1.9310 s)             valid loss: 0.2084             valid roc_auc score: 0.6314
Iter 6/200 (1.9151 s)             valid loss: 0.2082             valid roc_auc score: 0.6503
Iter 7/200 (1.9233 s)             valid loss: 0.2096             valid roc_auc score: 0.6572
Iter 8/200 (1.8927 s)             valid loss: 0.2070             valid roc_auc score: 0.6594
Iter 9/200 (1.9201 s)             valid loss: 0.2087             valid roc_auc score: 0.6571
Iter 10/200 (1.9437 s)             valid loss: 0.
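
The training loop relies on `EarlyStopping` from `jaxchem.utils`. Its source is not shown in this notebook, so the class below is only a guess at the behavior, built around the three members the loop uses (`update`, `is_train_stop`, `best_checkpoints`). The name `SimpleEarlyStopping` and the "higher score is better" rule are assumptions.

```python
class SimpleEarlyStopping:
    """Hypothetical stand-in for an early-stopping helper (higher is better)."""

    def __init__(self, patience=15):
        self.patience = patience
        self.best_score = None
        self.best_checkpoints = None
        self.counter = 0            # epochs since the last improvement
        self.is_train_stop = False

    def update(self, score, checkpoints):
        if self.best_score is None or score > self.best_score:
            # New best: remember the score and the matching checkpoint.
            self.best_score = score
            self.best_checkpoints = checkpoints
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.is_train_stop = True


stopper = SimpleEarlyStopping(patience=2)
for epoch, score in enumerate([0.60, 0.65, 0.64, 0.63, 0.70]):
    stopper.update(score, {"epoch": epoch})
    if stopper.is_train_stop:
        break

print(epoch, stopper.best_score, stopper.best_checkpoints)
```

With a patience of 2, this toy run stops after two consecutive epochs without improvement and never sees the final, higher score, which illustrates the trade-off behind choosing `early_stop_patience`.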

## Testing

Finally, we evaluate the best checkpoint on the test dataset and save the best model parameters and states.

In [None]:
y_score, y_true = [], []
best_checkpoints = early_stop.best_checkpoints
for original_batch in test_dataset.iterbatches(batch_size=batch_size):
    batch = collate_fn(original_batch)
    logits, _, targets = evaluate(*best_checkpoints, batch)
    y_score.extend(logits)
    y_true.extend(targets)

score, scores = multi_task_roc_auc_score(np.array(y_true), np.array(y_score))
print(f'Test mean roc_auc score: {score:.4f}')

Test mean roc_auc score: 0.7799


In [None]:
print(f'Test all roc_auc score: {str(scores)}')

Test all roc_auc score: [0.8177938392703493, 0.9044473800088066, 0.8703401337160338, 0.7781960784313726, 0.6970420913325198, 0.764922568034882, 0.7323701410388718, 0.6976487819919827, 0.7766335892155026, 0.6964186061779865, 0.8184771680247303, 0.8043416858330309]


In [None]:
# save best checkpoints
with open('./best_checkpoints.pkl', 'wb') as f:
    pickle.dump(best_checkpoints, f)
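
To reuse the saved checkpoint later, it can be read back with `pickle.load`. The sketch below does a full round trip with a small stand-in checkpoint; the nested dictionary and the temporary path are placeholders, not the real parameter structure of the model.

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in for (params, state); the real objects are nested dicts of arrays.
best_checkpoints = ({"linear": {"w": np.ones((2, 3))}}, {})

path = os.path.join(tempfile.mkdtemp(), "best_checkpoints.pkl")
with open(path, "wb") as f:
    pickle.dump(best_checkpoints, f)

# Later, for example in a new session: restore parameters and state.
with open(path, "rb") as f:
    restored_params, restored_state = pickle.load(f)

print(restored_params["linear"]["w"].shape)
```

Only unpickle files from sources you trust, since loading a pickle can execute arbitrary code.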