# Toxicity Prediction using DeepChem
#### Models: GCN, GAT and GGCN

In [50]:
# Installing conda
!curl -Lo conda_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import conda_installer
conda_installer.install()
!/root/miniconda/bin/conda info -e

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  3457  100  3457    0     0  34919      0 --:--:-- --:--:-- --:--:-- 35275


all packages are already installed
INFO:conda_installer:all packages are already installed


# conda environments:
#
base                     /root/miniconda



In [51]:
# Installing Deepchem
!pip install --pre deepchem
# Import deepchem just to check the version
import deepchem
deepchem.__version__

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


'2.7.2.dev'

In [52]:
# Importing required libraries and their utilities
import random

import numpy as np
import tensorflow as tf

import deepchem as dc
from deepchem.models.graph_models import GraphConvModel
from deepchem.molnet import load_tox21

# Seed every RNG the run may draw from, after all imports, so repeated
# Restart & Run All executions are as repeatable as possible
# (GPU kernels may still introduce some nondeterminism).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [53]:
# Tox21 ships with DeepChem's MoleculeNet loaders, so we can conveniently
# download it with the load_tox21 function. featurizer='GraphConv' produces the
# molecule representation consumed by GraphConvModel in the cells below.
# load_tox21 returns the task names, the (train, valid, test) datasets, and the
# transformers applied to the data; the transformers are passed to evaluate() later.
# NOTE(review): test_dataset is not scored by any of the evaluation cells that follow.
tox21_tasks, tox21_datasets, transformers = load_tox21(featurizer='GraphConv')
train_dataset, valid_dataset, test_dataset = tox21_datasets

In [54]:
# Every model is scored with ROC AUC: computed per Tox21 task, then the
# per-task scores are averaged with np.mean into a single number.
metric = dc.metrics.Metric(
    dc.metrics.roc_auc_score,
    task_averager=np.mean,
    mode="classification",
)


### GCN model

In [64]:
# Define and fit the baseline graph-convolution (GCN) model: one output per
# Tox21 task, ReLU activations, and no dropout argument (library default).
# model_dir is set explicitly, like the other model cells, so this model's
# checkpoints go to a named directory instead of an auto-generated location.
model = GraphConvModel(len(tox21_tasks),
                       batch_size=32,
                       activation_fn=tf.nn.relu,
                       mode='classification',
                       model_dir='gcn_model')
print("Fitting the model")
model.fit(train_dataset, nb_epoch=20)

Fitting the model


0.545823049545288

In [65]:
print("Evaluating model with ROC AUC")
# Score the fitted model on the train split, then the validation split, passing
# the transformers returned by load_tox21 through to evaluate().
train_scores, valid_scores = (
    model.evaluate(split, [metric], transformers)
    for split in (train_dataset, valid_dataset)
)

Evaluating model with ROC AUC


In [66]:
# Show the mean ROC AUC for each split, each preceded by a heading line.
for heading, scores in (("Train scores", train_scores),
                        ("Validation scores", valid_scores)):
    print(heading)
    print(scores)

Train scores
{'mean-roc_auc_score': 0.934546761353627}
Validation scores
{'mean-roc_auc_score': 0.7449231964974411}


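### GAT model

*Note: this section trains a `GraphConvModel` with `dropout=0.2`, not a graph attention network.*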

In [67]:
# NOTE(review): despite the "GAT model" heading, this cell builds another
# GraphConvModel (graph convolutions, not graph attention). Compared with the
# GCN cell it only adds dropout=0.2 and uses its own model_dir.
model = GraphConvModel(
    n_tasks=len(tox21_tasks),
    mode='classification',
    batch_size=32,
    activation_fn=tf.nn.relu,
    dropout=0.2,
    model_dir='gat_model',
)
print("Fitting the model")
model.fit(train_dataset, nb_epoch=20)

Fitting the model


0.6991569519042968

In [68]:
print("Evaluating model with ROC AUC")
# Score the fitted model on the train split, then the validation split, passing
# the transformers returned by load_tox21 through to evaluate().
train_scores, valid_scores = (
    model.evaluate(split, [metric], transformers)
    for split in (train_dataset, valid_dataset)
)

Evaluating model with ROC AUC


In [69]:
# Show the mean ROC AUC for each split, each preceded by a heading line.
for heading, scores in (("Train scores", train_scores),
                        ("Validation scores", valid_scores)):
    print(heading)
    print(scores)

Train scores
{'mean-roc_auc_score': 0.8821193898922894}
Validation scores
{'mean-roc_auc_score': 0.7452863334290091}


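### Gated GCN model

*Note: this section trains a `GraphConvModel` with the same settings as the GAT section (`dropout=0.2`), not a gated graph network.*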

In [70]:
# NOTE(review): despite the "Gated GCN model" heading, this cell trains a plain
# GraphConvModel with exactly the same hyperparameters as the "GAT" cell
# (batch_size=32, dropout=0.2, ReLU); only model_dir differs. Its scores are
# therefore a re-run of that configuration, not a gated graph network.
model = GraphConvModel(
    len(tox21_tasks),
    batch_size=32,
    mode="classification",
    dropout=0.2,
    activation_fn=tf.nn.relu,
    model_dir='ggcn_model',
)
print("Fitting the model")
model.fit(train_dataset, nb_epoch=20)

Fitting the model


0.7176002502441406

In [71]:
print("Evaluating model with ROC AUC")
# Score the fitted model on the train split, then the validation split, passing
# the transformers returned by load_tox21 through to evaluate().
train_scores, valid_scores = (
    model.evaluate(split, [metric], transformers)
    for split in (train_dataset, valid_dataset)
)

Evaluating model with ROC AUC


In [72]:
# Show the mean ROC AUC for each split, each preceded by a heading line.
for heading, scores in (("Train scores", train_scores),
                        ("Validation scores", valid_scores)):
    print(heading)
    print(scores)

Train scores
{'mean-roc_auc_score': 0.8827387471016305}
Validation scores
{'mean-roc_auc_score': 0.747146528696915}
