# Molecule Property Prediction with Tox21 Dataset


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/jaxchem//blob/master/notebooks/tox21_exmaple.ipynb)


## Install packages

First, we install DeepChem, which provides utilities for loading the Tox21 dataset. (This may take about 3 minutes.)


In [1]:
%tensorflow_version 1.x
!curl -Lo deepchem_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import deepchem_installer
%time deepchem_installer.install(version='2.3.0')

TensorFlow 1.x selected.
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  3477  100  3477    0     0  25014      0 --:--:-- --:--:-- --:--:-- 25014


add /root/miniconda/lib/python3.6/site-packages to PYTHONPATH
deepchem is already installed


CPU times: user 1.65 ms, sys: 920 µs, total: 2.57 ms
Wall time: 2.18 ms


Next, we install JAXChem and its dependencies.  
**Caution**: After running the following command, you need to restart the session. If you don't, you may encounter an error.

In [2]:
!pip install -q dm-haiku==0.0.1 typing-extensions==3.7.4.2  git+https://github.com/deepchem/jaxchem

  Building wheel for jaxchem (setup.py) ... done


## Import modules

If we hit the error `ImportError: cannot import name 'Literal'`, we should restart the session of this notebook.

In [3]:
import os
import time
import random
import pickle
import argparse
from typing import Any, Tuple

import jax
import numpy as np
import haiku as hk
import jax.numpy as jnp
from jax.experimental import optix
from sklearn.metrics import roc_auc_score


from deepchem.molnet import load_tox21
from jaxchem.models import PadGCNPredicator, clipped_sigmoid
from jaxchem.utils import EarlyStopping


# type definition
Batch = Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]
State, OptState = Any, Any

# tox21 tasks
task_names = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
              'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']



The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



## Download Tox21 dataset

We download the preprocessed Tox21 dataset. In this example, we use the `AdjacencyConv` featurizer because `PadGCNPredicator` is a pad-pattern GCN, which uses an adjacency matrix to represent the connections between nodes.
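To make the padded adjacency representation concrete, here is a minimal NumPy sketch. It is not DeepChem's featurizer: the helper `pad_graph`, the toy three-atom graph, and the feature width of 4 are assumptions chosen for illustration. The idea is that every molecule is zero-padded to the same number of atoms, so a batch can be stacked into fixed-shape arrays.

```python
import numpy as np


def pad_graph(adj: np.ndarray, node_feats: np.ndarray, max_atoms: int):
    """Zero-pad one graph to `max_atoms` nodes (hypothetical helper)."""
    n_atoms = adj.shape[0]
    padded_adj = np.zeros((max_atoms, max_atoms), dtype=adj.dtype)
    padded_adj[:n_atoms, :n_atoms] = adj
    padded_feats = np.zeros((max_atoms, node_feats.shape[1]), dtype=node_feats.dtype)
    padded_feats[:n_atoms] = node_feats
    return padded_adj, padded_feats


# Toy chain of three atoms: 0 - 1 - 2
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=np.float32)
node_feats = np.ones((3, 4), dtype=np.float32)  # 4 made-up features per atom

padded_adj, padded_feats = pad_graph(adj, node_feats, max_atoms=5)
print(padded_adj.shape, padded_feats.shape)
```

Padded rows and columns are all zero, so padded atoms have no neighbours and carry no features.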



In [4]:
# load tox21 dataset
tox21_tasks, tox21_datasets, _ = load_tox21(featurizer='AdjacencyConv', reload=True)
train_dataset, valid_dataset, test_dataset = tox21_datasets

Loading dataset from disk.
Loading dataset from disk.
Loading dataset from disk.


## Define some utilities

In [5]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)


def collate_fn(original_batch: Any, task_index: int) -> Batch:
    """Make batch data for PadGCN model inputs."""
    inputs, targets, _, _ = original_batch
    node_feats = np.array([inputs[i][1] for i in range(len(inputs))])
    adj = np.array([inputs[i][0] for i in range(len(inputs))])
    targets = targets[:, task_index]
    return ((node_feats, adj), targets)

def binary_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Compute binary cross entropy loss.

    Note: `logits` holds probabilities in (0, 1) (sigmoid outputs), not raw scores.
    """
    return -jnp.mean(targets * jnp.log(logits) + (1.0 - targets) * jnp.log(1.0 - logits))
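
The formula in `binary_cross_entropy` takes `log` of its input directly, so an input of exactly 0 or 1 would produce an infinite loss. The NumPy sketch below shows why a clipped sigmoid helps. It assumes that clipping simply bounds the sigmoid output away from 0 and 1; the function `clipped_sigmoid_sketch` and the bound `eps` are illustrative and not taken from JAXChem.

```python
import numpy as np


def clipped_sigmoid_sketch(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Sigmoid squeezed into [eps, 1 - eps]; `eps` is an assumed value."""
    return np.clip(1.0 / (1.0 + np.exp(-x)), eps, 1.0 - eps)


def bce(probs: np.ndarray, targets: np.ndarray) -> float:
    """Same formula as binary_cross_entropy above, written with NumPy."""
    return float(-np.mean(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs)))


raw = np.array([-50.0, 0.0, 50.0])    # raw model outputs
targets = np.array([1.0, 1.0, 0.0])   # first and last predictions are confidently wrong

loss = bce(clipped_sigmoid_sketch(raw), targets)
print(np.isfinite(loss))
```

Without the clip, the last prediction rounds to exactly 1.0 in floating point and `log(1 - 1.0)` is `-inf`.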

## Setup model and optimizer

We define the forward function using `PadGCNPredicator`, which JAXChem provides. Because our task is binary classification, we pass the output of `PadGCNPredicator` through a sigmoid function. After defining the forward function, we create the model instance by using `haiku.transform_with_state`.

In [6]:
rng_seq = hk.PRNGSequence(1234)

# model params
in_feats = train_dataset.X[0][1].shape[1]
hidden_feats = [64, 64, 64]
activation, batch_norm, dropout = None, None, None  # use default
predicator_hidden_feats = 32
pooling_method = 'mean'
predicator_dropout = None  # use default
n_out = 1  # binary classification

def forward(node_feats: jnp.ndarray, adj: jnp.ndarray, is_training: bool) -> jnp.ndarray:
    """Forward application of the GCN."""
    model = PadGCNPredicator(in_feats=in_feats, hidden_feats=hidden_feats, activation=activation,
                             batch_norm=batch_norm, dropout=dropout, pooling_method=pooling_method,
                             predicator_hidden_feats=predicator_hidden_feats,
                             predicator_dropout=predicator_dropout, n_out=n_out)
    preds = model(node_feats, adj, is_training)
    logits = clipped_sigmoid(preds)
    return logits

# we use haiku
model = hk.transform_with_state(forward)
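
For intuition about what a graph convolution with `'mean'` pooling computes on a padded graph, here is a single-graph NumPy sketch. It is a generic layer of the form ReLU((A + I) H W), not JAXChem's `PadGCNPredicator`, which may normalize the adjacency matrix, add biases, or handle padding differently. The names `gcn_layer` and `mean_pool` and all shapes are assumptions.

```python
import numpy as np


def gcn_layer(node_feats: np.ndarray, adj: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """One generic graph convolution on a zero-padded graph: ReLU((A + I) H W)."""
    adj_self = adj + np.eye(adj.shape[0], dtype=adj.dtype)  # add self-loops
    return np.maximum(adj_self @ node_feats @ weight, 0.0)


def mean_pool(node_feats: np.ndarray, n_atoms: int) -> np.ndarray:
    """Average over the real atoms only, ignoring padded rows."""
    return node_feats[:n_atoms].mean(axis=0)


rng = np.random.default_rng(0)
adj = np.zeros((4, 4))
adj[0, 1] = adj[1, 0] = 1.0                  # 2 bonded atoms, 2 padded slots
feats = np.zeros((4, 3))
feats[:2] = rng.normal(size=(2, 3))          # made-up atom features
weight = rng.normal(size=(3, 8))             # made-up layer weights

hidden = gcn_layer(feats, adj, weight)
graph_vec = mean_pool(hidden, n_atoms=2)
print(hidden.shape, graph_vec.shape)
```

Each atom's new features mix its own features with those of its neighbours, and pooling collapses the per-atom features into one vector per molecule that a small predictor head can map to a single output.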

Next, we create the optimizer instance.

In [7]:
# optimizer params
lr = 0.001
optimizer = optix.adam(learning_rate=lr)
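
As background on what an Adam optimizer does with each gradient, the sketch below writes out the standard Adam update rule for a single array in NumPy. It is not the `optix` implementation: the function `adam_step`, the toy objective, and the values of `b1`, `b2` and `eps` (commonly cited defaults) are assumptions for illustration.

```python
import numpy as np


def adam_step(param, grad, m, v, t, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update for a single array; returns (new_param, m, v)."""
    m = b1 * m + (1.0 - b1) * grad           # running mean of gradients
    v = b2 * v + (1.0 - b2) * grad ** 2      # running mean of squared gradients
    m_hat = m / (1.0 - b1 ** t)              # bias correction for step t >= 1
    v_hat = v / (1.0 - b2 ** t)
    return param - lr * m_hat / (np.sqrt(v_hat) + eps), m, v


# Minimise f(w) = (w - 3)^2 starting from w = 0
w, m, v = np.array(0.0), 0.0, 0.0
for t in range(1, 5001):
    grad = 2.0 * (w - 3.0)
    w, m, v = adam_step(w, grad, m, v, t, lr=0.01)
print(float(w))
```

The moment estimates `m` and `v` play the role of the optimizer state that is carried from one update to the next.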

## Define loss, update and evaluate function

Using the model and optimizer instance, we define the following functions. The update and evaluation functions are called in the training loop, so we decorate them with `@jax.jit` to improve performance. The loss function is called inside the update function, so it is compiled along with it.

- A function that calculates the loss value
- A function that updates the parameters
- A function that calculates the outputs needed for the validation metrics

In [8]:
# define training loss
def train_loss(params: hk.Params, state: State, batch: Batch) -> Tuple[jnp.ndarray, State]:
    """Compute the loss."""
    inputs, targets = batch
    logits, new_state = model.apply(params, state, next(rng_seq), *inputs, True)
    loss = binary_cross_entropy(logits, targets)
    return loss, new_state

# define training update
@jax.jit
def update(params: hk.Params, state: State, opt_state: OptState,
           batch: Batch) -> Tuple[hk.Params, State, OptState]:
    """Update the params."""
    (_, new_state), grads = jax.value_and_grad(train_loss, has_aux=True)(params, state, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optix.apply_updates(params, updates)
    return new_params, new_state, new_opt_state

# define evaluate metrics
@jax.jit
def evaluate(params: hk.Params, state: State,
             batch: Batch) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Return predictions, loss and targets for one batch."""
    inputs, targets = batch
    logits, _ = model.apply(params, state, next(rng_seq), *inputs, False)
    loss = binary_cross_entropy(logits, targets)
    return logits, loss, targets

## Training

First, we set the training hyperparameters.

In [9]:
# training params
seed = 42
batch_size = 32
early_stop_patience = 15
num_epochs = 50
task = 'NR-AR'

# fix seed
seed_everything(seed)
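
The `early_stop_patience` value sets how many epochs without improvement are tolerated before training stops. The class below is a hypothetical stand-in that illustrates the idea for a higher-is-better metric such as ROC-AUC. It is not JAXChem's `EarlyStopping`, whose interface and stopping rule may differ, and the scores are made up.

```python
class PatienceStopper:
    """Hypothetical patience-based stopper for a higher-is-better metric."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_checkpoint = None
        self.bad_epochs = 0

    def update(self, score, checkpoint) -> bool:
        """Record one epoch; return True when training should stop."""
        if score > self.best_score:
            self.best_score, self.best_checkpoint = score, checkpoint
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = PatienceStopper(patience=2)
scores = [0.68, 0.70, 0.69, 0.66, 0.72]   # made-up validation scores
stopped_at = None
for epoch, score in enumerate(scores):
    if stopper.update(score, checkpoint=f"params@{epoch}"):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_checkpoint)
```

With a patience of 2, the loop stops at epoch 3 and keeps the checkpoint from epoch 1, so the later score of 0.72 is never seen. A larger patience trades longer training for robustness to noisy validation scores.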

And then, we train our model!

In [10]:
# initialize some values 
task_index = tox21_tasks.index(task)
early_stop = EarlyStopping(patience=early_stop_patience)
batch_init_data = (
    np.zeros((batch_size, *train_dataset.X[0][1].shape)),
    np.zeros((batch_size, *train_dataset.X[0][0].shape)),
    True
)
params, state = model.init(next(rng_seq), *batch_init_data)
opt_state = optimizer.init(params)

In [11]:
print("Starting training...")
for epoch in range(num_epochs):
    # train
    start_time = time.time()
    for original_batch in train_dataset.iterbatches(batch_size=batch_size):
        batch = collate_fn(original_batch, task_index)
        params, state, opt_state = update(params, state, opt_state, batch)
    epoch_time = time.time() - start_time

    # valid
    y_score, y_true, valid_loss = [], [], []
    for original_batch in valid_dataset.iterbatches(batch_size=batch_size):
        batch = collate_fn(original_batch, task_index)
        logits, loss, targets = evaluate(params, state, batch)
        y_score.extend(logits)
        valid_loss.append(loss)
        y_true.extend(targets)
    score = roc_auc_score(y_true, y_score)
    # log
    print(f"Iter {epoch}/{num_epochs} ({epoch_time:.4f} s) valid loss: {np.mean(valid_loss):.4f} \
        valid roc_auc score: {score:.4f}")

    # check early stopping
    early_stop.update(score, (params, state))
    if early_stop.is_train_stop:
        print("Early stopping...")
        break

Starting training...
Iter 0/50 (4.5257 s) valid loss: 0.1609         valid roc_auc score: 0.6886
Iter 1/50 (1.8490 s) valid loss: 0.1611         valid roc_auc score: 0.6843
Iter 2/50 (1.6758 s) valid loss: 0.1568         valid roc_auc score: 0.6875
Iter 3/50 (1.6675 s) valid loss: 0.1616         valid roc_auc score: 0.6901
Iter 4/50 (1.6652 s) valid loss: 0.1560         valid roc_auc score: 0.6476
Iter 5/50 (1.6780 s) valid loss: 0.1607         valid roc_auc score: 0.7002
Iter 6/50 (1.6586 s) valid loss: 0.1590         valid roc_auc score: 0.6935
Iter 7/50 (1.6643 s) valid loss: 0.1713         valid roc_auc score: 0.6992
Iter 8/50 (1.6635 s) valid loss: 0.1559         valid roc_auc score: 0.7040
Iter 9/50 (1.6957 s) valid loss: 0.1562         valid roc_auc score: 0.7161
Iter 10/50 (1.6514 s) valid loss: 0.1612         valid roc_auc score: 0.6659
Iter 11/50 (1.6497 s) valid loss: 0.1569         valid roc_auc score: 0.7249
Iter 12/50 (1.6658 s) valid loss: 0.1563         valid roc_auc sc

## Testing

Finally, we evaluate the best model on the test dataset and save its parameters and state.

In [12]:
y_score, y_true = [], []
best_checkpoints = early_stop.best_checkpoints
for original_batch in test_dataset.iterbatches(batch_size=batch_size):
    batch = collate_fn(original_batch, task_index)
    logits, _, targets = evaluate(*best_checkpoints, batch)
    y_score.extend(logits)
    y_true.extend(targets)

score = roc_auc_score(y_true, y_score)
print(f'Test roc_auc score: {score:.4f}')

Test roc_auc score: 0.7349
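
As a reminder of what the ROC-AUC score measures, it equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one. The sketch below computes this by comparing every positive/negative pair. The function `roc_auc_pairwise` and the toy labels are illustrative assumptions, and this quadratic-memory version is only meant for small inputs, not as a replacement for `roc_auc_score`.

```python
import numpy as np


def roc_auc_pairwise(y_true, y_score) -> float:
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly.

    Tied pairs count as half a correct pair.
    """
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    diff = pos[:, None] - neg[None, :]       # all positive-minus-negative gaps
    return float((np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size)


y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
auc = roc_auc_pairwise(y_true, y_score)
print(auc)  # 3 of the 4 pairs are ordered correctly -> 0.75
```

A score of 0.5 corresponds to random ranking and 1.0 to a perfect ranking, which is why this metric suits imbalanced label sets better than accuracy.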


In [13]:
# save best checkpoints
with open('./best_checkpoints.pkl', 'wb') as f:
    pickle.dump(best_checkpoints, f)
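
To reuse the saved checkpoint later, it can be read back with `pickle.load`. The sketch below shows the round trip with a stand-in object written to a temporary directory; the nested dictionary and its key names are made up, and the real checkpoint is whatever `early_stop.best_checkpoints` contains.

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in for a (params, state) pair; the keys are hypothetical
checkpoint = ({"linear": {"w": np.ones((2, 2))}}, {})

path = os.path.join(tempfile.mkdtemp(), "best_checkpoints.pkl")
with open(path, "wb") as f:
    pickle.dump(checkpoint, f)

# Later, for example in a new session: restore the pair
with open(path, "rb") as f:
    params, state = pickle.load(f)
print(params["linear"]["w"].shape)
```

Only unpickle files from a source you trust, since `pickle.load` can execute arbitrary code.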