# Compound representation learning and property prediction

In this tuorial, we will go through how to run a Graph Neural Network (GNN) model for compound property prediction. In particular, we will demonstrate how to pretrain and finetune the model in the downstream tasks. If you are intersted in more details, please refer to the README for "[info graph](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/info_graph)" and "[pretrained GNN](https://github.com/PaddlePaddle/PaddleHelix/apps/pretrained_compound/pretrain_gnns)".

#  Part I: Pretraining

In this part, we will show how to pretrain a compound GNN model. The pretraining skills here are adapted from the work of pretrain gnns, including attribute masking, context prediction and supervised pretraining.

Visit `pretrain_attrmask.py` and `pretrain_supervised.py` for more details.

In [1]:
import os
import numpy as np
import sys
sys.path.insert(0, os.getcwd() + "/..")
os.chdir("../apps/pretrained_compound/pretrain_gnns")

The Pahelix framework is build upon PaddlePaddle, which is a deep learning framework.


In [2]:
import paddle
import paddle.nn as nn
import pgl

from pahelix.model_zoo.pretrain_gnns_model import PretrainGNNModel, AttrmaskModel
from pahelix.datasets.zinc_dataset import load_zinc_dataset
from pahelix.utils.splitters import RandomSplitter
from pahelix.featurizers.pretrain_gnn_featurizer import AttrmaskTransformFn, AttrmaskCollateFn
from pahelix.utils import load_json_config

2021-05-08 15:04:44,729 - INFO - ujson not install, fail back to use json instead
2021-05-08 15:04:44,809 - INFO - Enabling RDKit 2020.09.1 jupyter extensions


## Load configurations
Here, we use `compound_encoder_config`,`model_config` to hold the compound encoder and model configurations. `PretrainGNNModel` is the basic GNN Model used in pretrain gnns,`AttrmaskModel` is an unsupervised pretraining model which randomly masks the atom type of a node and then tries to predict the masked atom type. In the meanwhile, we use the Adam optimizer and set the learning rate to be 0.001.

In [3]:
compound_encoder_config = load_json_config("model_configs/pregnn_paper.json")
model_config = load_json_config("model_configs/pre_Attrmask.json")

compound_encoder = PretrainGNNModel(compound_encoder_config)
model = AttrmaskModel(model_config, compound_encoder)
opt = paddle.optimizer.Adam(0.001, parameters=model.parameters())

  and should_run_async(code)


[PretrainGNNModel] embed_dim:300
[PretrainGNNModel] dropout_rate:0.5
[PretrainGNNModel] norm_type:batch_norm
[PretrainGNNModel] graph_norm:False
[PretrainGNNModel] residual:False
[PretrainGNNModel] layer_num:5
[PretrainGNNModel] gnn_type:gin
[PretrainGNNModel] JK:last
[PretrainGNNModel] readout:mean
[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']
[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']


## Dataset loading and feature extraction
### Download the dataset using wget
First, we need to download a small dataset for this demo. If you do not have `wget` on your machine, you could also
copy the url below into your web browser to download the data. But remember to copy the data manually to the
path "../apps/pretrained_compound/pretrain_gnns/".

In [4]:
### Download a toy dataset for demonstration:
!wget "https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz" --no-check-certificate
!tar -zxf "PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz"
!ls "./chem_dataset_small"
### Download the full dataset as you want:
# !wget "http://snap.stanford.edu/gnn-pretrain/data/chem_dataset.zip" --no-check-certificate
# !unzip "chem_dataset.zip"
# !ls "./chem_dataset"

--2021-05-08 15:05:21--  https://baidu-nlp.bj.bcebos.com/PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz
Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 10.70.0.165
Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|10.70.0.165|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 609563 (595K) [application/gzip]
Saving to: ‘PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.1’


2021-05-08 15:05:21 (13.9 MB/s) - ‘PaddleHelix%2Fdatasets%2Fcompound_datasets%2Fchem_dataset_small.tgz.1’ saved [609563/609563]

tox21  zinc_standard_agent


### Load the dataset  and generate features
The Zinc dataset is used as the pretraining dataset.Here we use a toy dataset for demonstration,you can load the full dataset as you want.
`AttrmaskTransformFn` is used along with `AttrmaskModel`.It is used to generate features. The original features are processed into features available on the network, such as smiles strings into node and edge features.

In [5]:
### Load the first 1000 of the toy dataset for speed up
dataset = load_zinc_dataset("./chem_dataset_small/zinc_standard_agent/")
dataset = dataset[:1000]
print("dataset num: %s" % (len(dataset)))

transform_fn = AttrmaskTransformFn()
dataset.transform(transform_fn, num_workers=2)

dataset num: 1000


## Start train
Now we train the attrmask model for 2 epochs for demostration purposes. Here we use `AttrmaskTransformFn` to aggregate multiple samples into a mini-batch.And the data loading process is accelerated with 4 processors.Then the pretrained model is saved to "./model/pretrain_attrmask", which will serve as the initial model of the downstream tasks.

In [6]:
def train(model, dataset, collate_fn, opt):
    data_gen = dataset.get_data_loader(
            batch_size=128, 
            num_workers=4, 
            shuffle=True,
            collate_fn=collate_fn)
    list_loss = []
    model.train()
    for graphs, masked_node_indice, masked_node_label in data_gen:
        graphs = graphs.tensor()
        masked_node_indice = paddle.to_tensor(masked_node_indice, 'int64')
        masked_node_label = paddle.to_tensor(masked_node_label, 'int64')
        loss = model(graphs, masked_node_indice, masked_node_label)
        loss.backward()
        opt.step()
        opt.clear_grad()
        list_loss.append(loss.numpy())
    return np.mean(list_loss)

collate_fn = AttrmaskCollateFn(
        atom_names=compound_encoder_config['atom_names'], 
        bond_names=compound_encoder_config['bond_names'],
        mask_ratio=0.15)

for epoch_id in range(2):
    train_loss = train(model, dataset, collate_fn, opt)
    print("epoch:%d train/loss:%s" % (epoch_id, train_loss))
paddle.save(compound_encoder.state_dict(), 
        './model/pretrain_attrmask/compound_encoder.pdparams')

  "When training, we now always track global mean and variance.")


epoch:0 train/loss:2.7467122
epoch:1 train/loss:0.93576944


The above is about the pretraining steps,you can adjust as needed.

# Part II: Downstream finetuning
Below we will introduce how to use the pretrained model for the finetuning of downstream tasks.

Visit `finetune.py` for more details.

In [7]:
from pahelix.utils.splitters import \
    RandomSplitter, IndexSplitter, ScaffoldSplitter
from pahelix.datasets import *

from src.model import DownstreamModel
from src.featurizer import DownstreamTransformFn, DownstreamCollateFn
from src.utils import calc_rocauc_score, exempt_parameters

The downstream datasets are usually small and have different tasks. For example, the BBBP dataset is used for the predictions of the Blood-brain barrier permeability. The Tox21 dataset is used for the predictions of toxicity of compounds. Here we use the Tox21 dataset for demonstrations.

In [8]:
task_names = get_default_tox21_task_names()
print(task_names)

['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']


## Load configurations
Here, we use `compound_encoder_config` and `model_config` to hold the model configurations. Note that the configurations of the model architecture should align with that of the pretraining model, otherwise the loading will fail. 
`DownstreamModel` is an supervised GNN model which predicts the tasks shown in `task_names`. Meanwhile, we use BCEloss as the criterion,Adam optimizer and set the lr to be 0.001.

In [9]:
compound_encoder_config = load_json_config("model_configs/pregnn_paper.json")
model_config = load_json_config("model_configs/down_linear.json")
model_config['num_tasks'] = len(task_names)

compound_encoder = PretrainGNNModel(compound_encoder_config)
model = DownstreamModel(model_config, compound_encoder)
criterion = nn.BCELoss(reduction='none')
opt = paddle.optimizer.Adam(0.001, parameters=model.parameters())

[PretrainGNNModel] embed_dim:300
[PretrainGNNModel] dropout_rate:0.5
[PretrainGNNModel] norm_type:batch_norm
[PretrainGNNModel] graph_norm:False
[PretrainGNNModel] residual:False
[PretrainGNNModel] layer_num:5
[PretrainGNNModel] gnn_type:gin
[PretrainGNNModel] JK:last
[PretrainGNNModel] readout:mean
[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']
[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']


## Load pretrained models
Load the pretrained model in the pretraining phase. Here we load the model "pretrain_attrmask" as an example.

In [10]:
compound_encoder.set_state_dict(paddle.load('./model/pretrain_attrmask/compound_encoder.pdparams'))

## Dataset loading and feature extraction
`DownstreamTransformFn` is used along with `DownstreamModel`.It is used to generate features. The original features are processed into features available on the network, such as smiles strings into node and edge features. 

The Tox21 dataset is used as the downstream dataset and we use `ScaffoldSplitter` to split the dataset into train/valid/test set. `ScaffoldSplitter` will firstly order the compounds according to Bemis-Murcko scaffold, then take the first `frac_train` proportion as the train set, the next `frac_valid` proportion as the valid set and the rest as the test set. `ScaffoldSplitter` can better evaluate the generalization ability of the model on out-of-distribution samples. Note that other splitters like `RandomSplitter`, `RandomScaffoldSplitter` and `IndexSplitter` is also available.

In [11]:
### Load the toy dataset:
dataset = load_tox21_dataset("./chem_dataset_small/tox21", task_names)
### Load the full dataset:
# dataset = load_tox21_dataset("./chem_dataset/tox21", task_names)
dataset.transform(DownstreamTransformFn(), num_workers=4)

# splitter = RandomSplitter()
splitter = ScaffoldSplitter()
train_dataset, valid_dataset, test_dataset = splitter.split(
        dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1)
print("Train/Valid/Test num: %s/%s/%s" % (
        len(train_dataset), len(valid_dataset), len(test_dataset)))




Train/Valid/Test num: 6264/783/784


## Start train
Now we train the attrmask model for 4 epochs for demostration purposes.Here we use `DownstreamCollateFn` to aggregate multiple samples data into a mini-batch. Since each downstream task will contain more than one sub-task, the performance of the model is evaluated by the average roc-auc on all sub-tasks.

In [12]:
def train(model, train_dataset, collate_fn, criterion, opt):
    data_gen = train_dataset.get_data_loader(
            batch_size=128, 
            num_workers=4, 
            shuffle=True,
            collate_fn=collate_fn)
    list_loss = []
    model.train()
    for graphs, valids, labels in data_gen:
        graphs = graphs.tensor()
        labels = paddle.to_tensor(labels, 'float32')
        valids = paddle.to_tensor(valids, 'float32')
        preds = model(graphs)
        loss = criterion(preds, labels)
        loss = paddle.sum(loss * valids) / paddle.sum(valids)
        loss.backward()
        opt.step()
        opt.clear_grad()
        list_loss.append(loss.numpy())
    return np.mean(list_loss)

def evaluate(model, test_dataset, collate_fn):
    data_gen = test_dataset.get_data_loader(
            batch_size=128, 
            num_workers=4, 
            shuffle=False,
            collate_fn=collate_fn)
    total_pred = []
    total_label = []
    total_valid = []
    model.eval()
    for graphs, valids, labels in data_gen:
        graphs = graphs.tensor()
        labels = paddle.to_tensor(labels, 'float32')
        valids = paddle.to_tensor(valids, 'float32')
        preds = model(graphs)
        total_pred.append(preds.numpy())
        total_valid.append(valids.numpy())
        total_label.append(labels.numpy())
    total_pred = np.concatenate(total_pred, 0)
    total_label = np.concatenate(total_label, 0)
    total_valid = np.concatenate(total_valid, 0)
    return calc_rocauc_score(total_label, total_pred, total_valid)

collate_fn = DownstreamCollateFn(
        atom_names=compound_encoder_config['atom_names'], 
        bond_names=compound_encoder_config['bond_names'])
for epoch_id in range(4):
    train_loss = train(model, train_dataset, collate_fn, criterion, opt)
    val_auc = evaluate(model, valid_dataset, collate_fn)
    test_auc = evaluate(model, test_dataset, collate_fn)
    print("epoch:%s train/loss:%s" % (epoch_id, train_loss))
    print("epoch:%s val/auc:%s" % (epoch_id, val_auc))
    print("epoch:%s test/auc:%s" % (epoch_id, test_auc))
paddle.save(model.state_dict(), './model/tox21/model.pdparams')

Valid ratio: 0.7603235
Task evaluated: 12/12
Valid ratio: 0.7513818
Task evaluated: 12/12
epoch:0 train/loss:0.2782306
epoch:0 val/auc:0.6327719308701215
epoch:0 test/auc:0.6230107310357754


  "When training, we now always track global mean and variance.")


Valid ratio: 0.7603235
Task evaluated: 12/12
Valid ratio: 0.7513818
Task evaluated: 12/12
epoch:1 train/loss:0.22611414
epoch:1 val/auc:0.6998613183307891
epoch:1 test/auc:0.661867930217498


  "When training, we now always track global mean and variance.")


Valid ratio: 0.7603235
Task evaluated: 12/12
Valid ratio: 0.7513818
Task evaluated: 12/12
epoch:2 train/loss:0.22136745
epoch:2 val/auc:0.6257492122355784
epoch:2 test/auc:0.6332904016792812


  "When training, we now always track global mean and variance.")


Valid ratio: 0.7603235
Task evaluated: 12/12
Valid ratio: 0.7513818
Task evaluated: 12/12
epoch:3 train/loss:0.21559708
epoch:3 val/auc:0.6922950038299994
epoch:3 test/auc:0.6663004542741277


# Part III: Downstream Inference
In this part, we will briefly introduce how to use the trained downstream model to do inference on the given SMILES sequences.

## Load configurations
This part is the basically the same as the part II.

In [13]:
compound_encoder_config = load_json_config("model_configs/pregnn_paper.json")
model_config = load_json_config("model_configs/down_linear.json")
model_config['num_tasks'] = len(task_names)

compound_encoder = PretrainGNNModel(compound_encoder_config)
model = DownstreamModel(model_config, compound_encoder)

[PretrainGNNModel] embed_dim:300
[PretrainGNNModel] dropout_rate:0.5
[PretrainGNNModel] norm_type:batch_norm
[PretrainGNNModel] graph_norm:False
[PretrainGNNModel] residual:False
[PretrainGNNModel] layer_num:5
[PretrainGNNModel] gnn_type:gin
[PretrainGNNModel] JK:last
[PretrainGNNModel] readout:mean
[PretrainGNNModel] atom_names:['atomic_num', 'chiral_tag']
[PretrainGNNModel] bond_names:['bond_dir', 'bond_type']


  and should_run_async(code)


## Load finetuned models
Load the finetuned model from part II.

In [14]:
model.set_state_dict(paddle.load('./model/tox21/model.pdparams'))

## Start Inference
Do inference on the given SMILES sequence. We use directly call `DownstreamTransformFn` and `DownstreamCollateFn` to convert the raw SMILES sequence to the model input. 

Using Tox21 dataset as an example, our finetuned downstream model can make predictions over 12 sub-tasks on Tox21.

In [15]:
SMILES="O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1"
transform_fn = DownstreamTransformFn(is_inference=True)
collate_fn = DownstreamCollateFn(
        atom_names=compound_encoder_config['atom_names'], 
        bond_names=compound_encoder_config['bond_names'],
        is_inference=True)
graph = collate_fn([transform_fn({'smiles': SMILES})])
preds = model(graph.tensor()).numpy()[0]
print('SMILES:%s' % SMILES)
print('Predictions:')
for name, prob in zip(task_names, preds):
    print("  %s:\t%s" % (name, prob))

SMILES:O=C1c2ccccc2C(=O)C1c1ccc2cc(S(=O)(=O)[O-])cc(S(=O)(=O)[O-])c2n1
Predictions:
  NR-AR:	0.40785
  NR-AR-LBD:	0.3486675
  NR-AhR:	0.36817572
  NR-Aromatase:	0.40480733
  NR-ER:	0.40827876
  NR-ER-LBD:	0.34133768
  NR-PPAR-gamma:	0.3827442
  SR-ARE:	0.37666
  SR-ATAD5:	0.3464901
  SR-HSE:	0.37402567
  SR-MMP:	0.47734728
  SR-p53:	0.39040926


  "When training, we now always track global mean and variance.")
