In [1]:
import tensorflow as tf
import deepchem as dc

import numpy as np

np.random.seed(123)

from deepchem.feat import Featurizer

In [2]:
import os
import sys

# Make the seq2seq-fingerprint checkout importable. Set SEQ2SEQ_FP_ROOT to point
# at a different checkout; the default is the original author's local path.
SEQ2SEQ_FP_ROOT = os.environ.get(
    "SEQ2SEQ_FP_ROOT", "/home/zhengxu/github/drug/seq2seq-fingerprint/")
sys.path.insert(0, SEQ2SEQ_FP_ROOT)

from unsupervised.seq2seq_model import FingerprintFetcher

In [3]:
# Define our seq2seq featurizer.

from rdkit import Chem

class Seq2seqFeaturizer(Featurizer):
    """DeepChem featurizer backed by a pretrained seq2seq fingerprint model.

    Each molecule is written back out as a SMILES string and decoded by a
    ``FingerprintFetcher``. The first element of the fetcher's ``decode``
    result is used as that molecule's feature vector.
    """

    def __init__(self, model_dir, vocab_dir):
        """Create the underlying fingerprint fetcher.

        Parameters
        ----------
        model_dir : str
            Seq2seq model directory, passed straight to ``FingerprintFetcher``.
        vocab_dir : str
            Vocabulary location, passed straight to ``FingerprintFetcher``.
        """
        fetcher = FingerprintFetcher(model_dir, vocab_dir)
        self.fetcher = fetcher

    def _featurize(self, mol):
        """
        Compute the seq2seq fingerprint of a single molecule.

        Parameters
        ----------
        mol : RDKit Mol
            Molecule to featurize.

        Returns
        -------
        fingerprint
            First element of ``self.fetcher.decode(smiles)``; its type and
            shape are whatever ``FingerprintFetcher`` produces.
        """
        # DeepChem hands this method an RDKit Mol, not the original SMILES
        # (NOTE(review): presumably the dataset loader parses SMILES before
        # featurizing -- confirm), so serialize it back to SMILES here.
        # NOTE(review): MolToSmiles emits RDKit's canonical form, which may
        # differ from the source string; confirm it matches what the seq2seq
        # model was trained on.
        fingerprint, _unused = self.fetcher.decode(Chem.MolToSmiles(mol))
        return fingerprint

In [4]:
# Initialize the seq2seq featurizer once and keep it for every dataset load below.
# Model/vocab locations can be overridden through environment variables; the
# defaults are the original author's local paths.
import os

SEQ2SEQ_MODEL_DIR = os.environ.get(
    "SEQ2SEQ_MODEL_DIR", "/home/zhengxu/expr/test/gru-4-256")
SEQ2SEQ_VOCAB_PATH = os.environ.get(
    "SEQ2SEQ_VOCAB_PATH", "/home/zhengxu/expr/test/pretrain/pm2.vocab")

# An InteractiveSession installs itself as the default TF session on creation.
# NOTE(review): presumably FingerprintFetcher relies on that default session --
# confirm. The session is closed in the notebook's final cell.
sess = tf.InteractiveSession()
featurizer = Seq2seqFeaturizer(SEQ2SEQ_MODEL_DIR, SEQ2SEQ_VOCAB_PATH)

Loading seq2seq model definition from /home/zhengxu/expr/test/gru-4-256/model.json...
Loading model weights from checkpoint_dir: /home/zhengxu/expr/test/gru-4-256/weights/


In [13]:
# Build up specific model builder.

from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC, NuSVC
# Build scikit-learn models (the mapping below currently uses gradient boosting for every split).

# Use this class to select a model per dataset/split; every task gets the same model type.
class SKLearnModelSelector(object):
    """Model builder that picks a scikit-learn estimator per dataset/split.

    Instances are passed as the ``model_builder`` of
    ``dc.models.multitask.SingletaskToMultitask``. Every call builds a fresh,
    unfitted estimator wrapped in a DeepChem ``SklearnModel``.
    """

    # dataset name -> split name -> (estimator class, constructor kwargs).
    DATASET_MAPPING = {
        "sider": {
            "index": (GradientBoostingClassifier, {}),
            "random": (GradientBoostingClassifier, {}),
            "scaffold": (GradientBoostingClassifier, {})
        }
    }

    def __init__(self, dataset, split):
        """Remember which dataset/split configuration to build models for.

        Parameters
        ----------
        dataset : str
            Key into ``DATASET_MAPPING`` (e.g. ``"sider"``).
        split : str
            Split name under that dataset (e.g. ``"index"``).
        """
        self.dataset = dataset
        self.split = split

    def _lookup_model_spec(self):
        """Return the ``(estimator_class, kwargs)`` pair for this dataset/split."""
        try:
            return self.DATASET_MAPPING[self.dataset][self.split]
        except KeyError:
            # A bare KeyError('random') raised from inside DeepChem's fit loop is
            # hard to trace back here, so report what is actually configured.
            available = sorted(
                "%s/%s" % (name, split_name)
                for name, by_split in self.DATASET_MAPPING.items()
                for split_name in by_split)
            raise KeyError(
                "No sklearn model configured for dataset=%r, split=%r; "
                "available: %s" % (self.dataset, self.split, ", ".join(available)))

    def __call__(self, task):
        """Build a fresh, unfitted DeepChem ``SklearnModel``.

        Parameters
        ----------
        task
            Forwarded unchanged as the second positional argument of
            ``SklearnModel``. NOTE(review): DeepChem's ``SingletaskToMultitask``
            appears to call its ``model_builder`` with a per-task model
            directory, so this is likely a directory path rather than a task
            name. Confirm against the installed DeepChem version.

        Returns
        -------
        dc.models.sklearn_models.SklearnModel
            Wrapper around a newly constructed estimator.

        Raises
        ------
        KeyError
            If ``DATASET_MAPPING`` has no entry for this dataset/split.
        """
        model_class, model_hparam = self._lookup_model_spec()
        sklearn_model = model_class(**model_hparam)
        return dc.models.sklearn_models.SklearnModel(sklearn_model, task)


In [14]:
from deepchem.molnet.run_benchmark import load_dataset, benchmark_model
from itertools import product

# Score with ROC-AUC, aggregated across the dataset's tasks with np.mean.
metric = [dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)]

# MoleculeNet datasets to benchmark. 'clintox' was tried and dropped because
# it performed worse.
datasets = ['sider']

# Splitters to evaluate. Each entry reloads the dataset and refits every model,
# so remove entries (e.g. "random") from this list to save time.
splits = ["index", "random", "scaffold"]

rule = "=" * 80
for dataset, split in product(datasets, splits):
    title = "Dataset: %s, split: %s" % (dataset, split)
    print(rule)
    print(title)

    tasks, all_datasets, transformers = load_dataset(dataset, featurizer, split)
    # Wrap one SKLearnModelSelector-built estimator per task. Despite the
    # variable name, these are classifiers (they are scored with ROC-AUC).
    reg_model = dc.models.multitask.SingletaskToMultitask(
        tasks, SKLearnModelSelector(dataset, split))
    # NOTE(review): `test=True` presumably adds test-split scoring, and `t` is
    # presumably the fitting time in seconds -- confirm with benchmark_model.
    train, val, test, t = benchmark_model(
        reg_model, all_datasets, transformers, metric, test=True)

    print("RESULT: " + rule)
    print(title)
    print(train, val, test)
    print("t = %.10f" % t)
    print(rule)

Dataset: sider, split: index
-------------------------------------
Loading dataset: sider
-------------------------------------
Splitting function: index
About to load MUV dataset.
Columns of dataset: ['smiles' 'Hepatobiliary disorders' 'Metabolism and nutrition disorders'
 'Product issues' 'Eye disorders' 'Investigations'
 'Musculoskeletal and connective tissue disorders'
 'Gastrointestinal disorders' 'Social circumstances'
 'Immune system disorders' 'Reproductive system and breast disorders'
 'Neoplasms benign, malignant and unspecified (incl cysts and polyps)'
 'General disorders and administration site conditions'
 'Endocrine disorders' 'Surgical and medical procedures'
 'Vascular disorders' 'Blood and lymphatic system disorders'
 'Skin and subcutaneous tissue disorders'
 'Congenital, familial and genetic disorders' 'Infections and infestations'
 'Respiratory, thoracic and mediastinal disorders' 'Psychiatric disorders'
 'Renal and urinary disorders'
 'Pregnancy, puerperium and peri

In [None]:
# Release the TensorFlow InteractiveSession opened in the featurizer setup cell.
sess.close()