In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time, sys, os, re, pickle, itertools, multiprocessing
from pathlib import Path
import networkx as nx
import pandas as pd
import numpy as np
import scipy as sp
import seaborn as sns
import matplotlib.pyplot as plt
from os import path

from pysmiles import read_smiles
from copy import deepcopy
from collections import defaultdict
from multiprocessing import Pool
#from pandarallel import pandarallel
#pandarallel.initialize(nb_workers= 20 )

from scipy import stats
from math import log
from scipy.special import softmax
from numpy.random import choice
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier as RFC
from sklearn.ensemble import GradientBoostingClassifier, BaggingClassifier, AdaBoostClassifier
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_val_score, train_test_split, cross_validate
from sklearn.metrics import accuracy_score, balanced_accuracy_score, roc_auc_score, matthews_corrcoef, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.metrics.cluster import contingency_matrix
from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances

import rdkit
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors as rdm
from rdkit.Chem import AllChem, Draw, rdDepictor, rdchem
from rdkit.Chem.SaltRemover import SaltRemover
from rdkit.Chem.Draw import IPythonConsole, MolDrawing, rdMolDraw2D
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults
from IPython.display import SVG
remover = SaltRemover()
rdDepictor.SetPreferCoordGen(True)

# avoid print warnings in 'pysmiles'
import logging
logging.getLogger('pysmiles').setLevel(logging.CRITICAL)  # Anything higher than warning

In [3]:
from srw_class.set_device_seed import *   # seeding
from srw_class.mychem import *            # RW functions
from srw_class.mydata import *            # data
from srw_class.DILI import *              # DILI class object, classification 
from srw_class import graph_cook as gk    # graph cooking functions

# Load TDC Data

In [4]:
from tdc.benchmark_group import admet_group

# Fetch the 'dili' entry of the TDC ADMET benchmark group (files are kept under data/).
group = admet_group(path='data/')
benchmark = group.get('dili')

# Unpack the pieces of the benchmark entry that later cells use.
name = benchmark['name']
train_val = benchmark['train_val']
test = benchmark['test']

# Empty container; nothing is stored in it in this cell.
predictions = {}

Found local copy...


In [5]:
# Scaffold-based train/validation split of the benchmark selected above, with a fixed seed (3078).
train, valid = group.get_train_valid_split(benchmark = name, split_type = 'scaffold', seed = 3078)

generating training, validation splits...
100%|██████████| 379/379 [00:00<00:00, 1939.18it/s]


In [6]:
# Label counts (column 'Y') for each split, printed in train / valid / test order.
for split_frame in (train, valid, test):
    print(split_frame['Y'].value_counts(sort=False))

0.0    150
1.0    175
Name: Y, dtype: int64
0.0    43
1.0    11
Name: Y, dtype: int64
0.0    46
1.0    50
Name: Y, dtype: int64


In [7]:
# Work on copies: transform_dili edits the frame it receives in place,
# and the frames returned by TDC should stay as loaded.
tdc_dili_train, tdc_dili_valid, tdc_dili_test = (
    split_frame.copy() for split_frame in (train, valid, test)
)

def transform_dili(df):
    """Attach molecule objects/graphs to a TDC split and drop unusable rows.

    The frame is modified in place and the same object is returned.

    Steps:
      * add ``molobj`` (``Chem.MolFromSmiles``) and ``molgraph``
        (``read_smiles(..., reinterpret_aromatic=True)``) built from ``Drug``;
      * rename ``Drug`` -> ``smiles``, ``Y`` -> ``class``, ``Drug_ID`` -> ``ID``
        and index the frame by ``ID``;
      * drop rows whose SMILES contains a ``'.'`` or whose graph has fewer
        than two nodes, and print the dropped IDs.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``Drug``, ``Y`` and ``Drug_ID``.

    Returns
    -------
    pandas.DataFrame
        The same (mutated) frame, indexed by ``ID``.
    """
    df['molobj'] = [Chem.MolFromSmiles(smi) for smi in df.Drug]
    df['molgraph'] = [read_smiles(smi, reinterpret_aromatic=True) for smi in df.Drug]
    df.rename(columns={'Drug': 'smiles', 'Y': 'class', 'Drug_ID': 'ID'}, inplace=True)
    df.index = df['ID']

    # Decide row by row (positionally) what to drop. The previous per-label
    # lookup `df.smiles[ind]` returned a Series as soon as an ID was repeated,
    # and rows were removed one at a time while looping.
    has_dot = df['smiles'].map(lambda smi: '.' in smi)
    too_small = df['molgraph'].map(lambda graph: graph.number_of_nodes() < 2)
    drop_mask = (has_dot | too_small).to_numpy(dtype=bool)

    removed = df.index[drop_mask].tolist()
    # Single in-place drop keeps the "mutate and return the same frame" contract.
    df.drop(index=removed, inplace=True)
    print('Dropped: ', removed)
    return df

# transform_dili edits its argument in place and returns it, so each *_molinfo_df
# is the same object as the corresponding tdc_dili_* copy.
train_molinfo_df = transform_dili(tdc_dili_train)
valid_molinfo_df = transform_dili(tdc_dili_valid)
test_molinfo_df = transform_dili(tdc_dili_test)

Dropped:  [441281.0, 60714.0, 9887054.0, 444013.0, 25517.0, 28486.0, 5234.0, 6433328.0, 6433516.0, 11963622.0, 23663956.0, 23673837.0, 3000502.0]
Dropped:  []
Dropped:  [114965.0, 11806.0, 10429215.0, 27350.0, 9915926.0, 60168.0]


In [8]:
# Fold the validation split into the training frame.
# The cell overwrites train_molinfo_df with a value derived from itself, so a plain
# concat would append the validation rows again on every re-run. Dropping rows whose
# ID already belongs to the validation frame first makes the cell idempotent
# (on a first run nothing matches and the result is the plain concat).
already_merged = train_molinfo_df.index.isin(valid_molinfo_df.index)
train_molinfo_df = pd.concat([train_molinfo_df[~already_merged], valid_molinfo_df])
print(train_molinfo_df.shape)

(366, 5)


# Set Parameters for Classification

In [10]:
# Parameter grid for the runs below.
# (reference setting: alpha = 0.1, rule = "random", walkers = 10)

lChemistry = ['graph', 'atom']
lAlphas    = [0.1, 0.5]
lwalkers   = [1, 3, 5, 10, 20]
lrules     = ['random']
lRW        = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
lPruning   = [False, 'pure']
# lPruning   = [False]

lTrainData = ["tdc_dili_train"] 
lvalidnames = ['tdc_dili_test']

iterations = 20
directory = './results/'

# `os.chdir('../')` is relative to the *current* directory, so re-running this cell
# used to climb one level higher each time. Remember where the kernel was on the
# first execution and always move to that directory's parent instead.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '..'))  # bin
print(os.getcwd())

# Update method used for this run; its name becomes part of every output file name.
dUpdateMethodDict = {1 : 'Method1RatioUpdate', 2 : 'Method2LRTrainEntropyUpdate', 3 : 'Method3EntropyUpdate'}
nMethod = 2
project_name = dUpdateMethodDict[nMethod]

# Results directory is created relative to the directory chosen above.
os.makedirs(os.path.dirname(directory), exist_ok=True)

# One tuple per run; the positions in each tuple follow `paramnames`.
paramnames = ['rw', 'alpha', 'rule', 'walkers', 'pruning', 'chemistry']
paramlist = list(itertools.product(lRW, lAlphas, lrules, lwalkers, lPruning, lChemistry))

/data/project/sslim/drug/DILI/bin
168


In [12]:
def run_ml(train_mat, train_label, test_mat, test_label,
           classifier='RF', count=False, message=False):
    """Fit a random forest on the training data and score it on the test data.

    Parameters
    ----------
    train_mat, train_label
        Feature matrix and labels handed to ``fit``.
    test_mat, test_label
        Feature matrix and labels used for evaluation.
    classifier, count, message
        Labels supplied by callers ("RF"; "count"/"binary"; "original"/"pruned").
        They are not read in this function: a ``RandomForestClassifier`` with
        ``random_state=3078`` is always used.

    Returns
    -------
    tuple
        ``(accuracy, balanced_accuracy, auc, mcc, precision, recall, f1,
        tn, fp, fn, tp)``. The four confusion-matrix entries are all 0 when the
        confusion matrix does not have exactly four cells.
    """
    forest = RFC(random_state=3078)
    forest.fit(train_mat, train_label)

    predicted = forest.predict(test_mat)
    # Row 1 of the transposed probability matrix, i.e. the second class column.
    class1_score = forest.predict_proba(test_mat).T[1]

    accuracy = accuracy_score(test_label, predicted)
    balanced = balanced_accuracy_score(test_label, predicted)
    auc = roc_auc_score(test_label, class1_score, average="weighted")
    mcc = matthews_corrcoef(test_label, predicted)
    precision = precision_score(test_label, predicted, average="weighted")
    recall = recall_score(test_label, predicted, average="weighted")
    f1 = f1_score(test_label, predicted, average="weighted")

    # Flattened confusion matrix; fall back to zeros when it is not 2x2.
    cells = confusion_matrix(test_label, predicted).ravel()
    if len(cells) != 4:
        cells = (0, 0, 0, 0)
    tn, fp, fn, tp = cells

    return (accuracy, balanced, auc, mcc, precision, recall, f1,
            tn, fp, fn, tp)

In [None]:
# Seed forwarded to classification() by run_srw.
seed = 3078
# (datatype, matrix variant) pairs to evaluate; joined with '_' into output file names.
lperform_params = [
    (datatype, variant)
    for datatype in ['count']
    for variant in ['original']
]

def _load_pickle_or_none(filename):
    """Return the object pickled in `filename`, or None if the file is missing or unreadable."""
    if not os.path.isfile(filename):
        return None
    try:
        with open(filename, 'rb') as handle:
            # SECURITY NOTE: pickle.load can execute arbitrary code. Only point this
            # at cache files written by this notebook, never at untrusted files.
            return pickle.load(handle)
    except Exception as err:  # unreadable cache: recompute, but let Ctrl-C / SystemExit through
        print(f'Ignoring unreadable cache {filename}: {err!r}')
        return None


def _dump_pickle(obj, filename):
    """Pickle `obj` to `filename`, closing (and thereby flushing) the file afterwards."""
    with open(filename, 'wb') as handle:
        pickle.dump(obj, handle, pickle.HIGHEST_PROTOCOL)


def run_srw(param):
    """Run one point of the parameter grid, caching every stage under `directory`.

    Parameters
    ----------
    param : tuple
        ``(rw, alpha, rule, walkers, pruning, chemistry)``, i.e. one element of
        ``paramlist`` in the order given by ``paramnames``.

    Notes
    -----
    Reads the notebook globals ``directory``, ``project_name``, ``iterations``,
    ``nMethod``, ``train_molinfo_df``, ``test_molinfo_df``, ``lperform_params``
    and ``seed``. Nothing is returned; results are written to files:

    * ``tdc_dili_train_<tag>.pickle`` / ``tdc_dili_test_<tag>.pickle`` - the
      ``DILInew`` objects,
    * ``..._matrices.pickle`` - per-iteration matrices plus the labels,
    * ``..._RF_..._performance.tsv`` / ``..._RF_..._confusion_matrix.tsv``.
    """
    nRW, nAlpha, sRule, nWalker, sPruning, sChemistry = param

    def run_tag(project):
        # Part of every output file name that encodes the parameter combination.
        return (f"{project}_{sChemistry}_rw{nRW}_alpha{int(nAlpha * 10)}"
                f"_Path{sRule}_Prune{sPruning}_walkers{nWalker}_iteration{iterations}")

    def classification(train_obj, valid_obj, train_df, valid_df, valid_name, project_name, _message, seed=3078, num_cores = 10): # SRWtrain, SRWvalid
        """Build (or load) the per-iteration matrices and write the RF metrics.

        ``seed`` and ``num_cores`` are accepted but not used in this function.
        """
        prefix = directory + str(valid_name) + '_' + run_tag(project_name)
        message_tag = '_'.join(_message)
        perf_path = prefix + '_RF_' + message_tag + '_performance.tsv'
        confusion_path = prefix + '_RF_' + message_tag + '_confusion_matrix.tsv'
        matrices_path = prefix + '_' + message_tag + '_matrices.pickle'

        if os.path.isfile(perf_path):
            return  # metrics for this configuration already exist
        _datatype, _pruned = _message

        # The labels are stored next to the matrices under this extra key. Previously
        # they only existed as locals of the rebuild branch, so a cache hit ended in a
        # NameError when run_ml was called. Caches written without the key are rebuilt.
        labels_key = 'labels'
        cached = _load_pickle_or_none(matrices_path)
        if isinstance(cached, dict) and labels_key in cached:
            dMatDict = cached
            train_y, valid_y = dMatDict[labels_key]
        else:
            dMatDict = {}
            train_y, valid_y = None, None
            train_agg_X, valid_agg_X = pd.DataFrame(), pd.DataFrame()
            for nIter in range(iterations):
                train_X, valid_X, train_y, valid_y, n_union, n_train, n_valid = GetSubgraphMatrix(train_obj, valid_obj, train_df, valid_df, nIter)
                if nIter == 0:
                    train_agg_X, valid_agg_X = train_X.copy(), valid_X.copy()
                else:
                    # Combine this iteration with everything aggregated so far
                    # (GetUpdateUnion comes from srw_class; semantics not visible here).
                    train_X = GetUpdateUnion(train_agg_X, train_X, useupdate=False)
                    valid_X = GetUpdateUnion(valid_agg_X, valid_X, useupdate=False)
                    train_agg_X, valid_agg_X = train_X.copy(), valid_X.copy()
                dMatDict[nIter] = train_X, valid_X, n_union, n_train, n_valid
                print(f'Classification starts with Training/validation data shape: {train_X.shape} / {valid_X.shape}')
            # As before, the labels returned by the last iteration are used for every iteration.
            dMatDict[labels_key] = (train_y, valid_y)
            _dump_pickle(dMatDict, matrices_path)

        # Run classification (RF), unless the metrics appeared in the meantime.
        if os.path.exists(perf_path):
            return
        pd_result =  pd.DataFrame( 0, index = range(iterations),  columns = ['n_union_subgraphs', 'n_train_subgraphs', 'n_valid_subgraphs', 'n_updated_subgraphs', 'Accuracy', 'BAcc', 'Precision', 'Recall', 'F1_score', 'AUC', 'MCC'], dtype=np.float64)
        pd_confusion = pd.DataFrame(0, index = range(iterations), columns = ["tn", "fp", "fn", "tp"])
        for nIter in range(iterations):
            train_X, valid_X, n_union, n_train, n_valid = dMatDict[nIter]
            nAccuracy, nBalAcc, nAUC, nMCC, n_precision, n_recall, n_f1score, tn, fp, fn, tp = run_ml(train_X, train_y, valid_X, valid_y, 
                   classifier = "RF", count = _datatype, message = _pruned)
            n_final = train_X.shape[1]
            pd_result.iloc[nIter] = [n_union, n_train, n_valid, n_final, nAccuracy, nBalAcc, n_precision, n_recall, n_f1score, nAUC, nMCC]
            pd_confusion.iloc[nIter] = [tn, fp, fn, tp]
        pd_result.to_csv(perf_path, sep="\t")
        pd_confusion.to_csv(confusion_path, sep="\t")
    # END of definition: classification

    tag = run_tag(project_name)

    # train: tdc_dili_train
    train_filename = directory + 'tdc_dili_train' + '_' + tag + '.pickle'
    SRWtrain = _load_pickle_or_none(train_filename)
    if SRWtrain is None:
        print('New run starts')
        SRWtrain = DILInew(chemistry = sChemistry, n_rw = nRW, n_alpha = nAlpha, iteration = iterations, pruning = sPruning, n_walker = nWalker , rw_mode = sRule, update_method = nMethod, ml_mode = 'train')
        _ = SRWtrain.train(train_molinfo_df)
        _dump_pickle(SRWtrain, train_filename)
        print('train run done')

    # test: tdc_dili_test
    # NOTE(review): this guard looks for an '_XGB_binary_pruned_' file, while
    # classification() only writes '_RF_<message>_' files - so the guard is presumably
    # stale and never skips anything. Left unchanged; confirm the intended file name.
    fname = directory + 'tdc_dili_test_' + tag + '_XGB_binary_pruned_performance.tsv'
    if not os.path.exists(fname):
        test_filename = directory + 'tdc_dili_test' + '_' + tag + '.pickle'
        SRWtest = _load_pickle_or_none(test_filename)
        if SRWtest is None:
            SRWtest = DILInew(chemistry = sChemistry, n_rw = nRW, n_alpha = nAlpha, iteration = iterations, pruning = sPruning, n_walker = nWalker , rw_mode = sRule, update_method = nMethod, ml_mode = 'test')
            SRWtest.valid(test_molinfo_df, train_molinfo_df, SRWtrain)
            _dump_pickle(SRWtest, test_filename)
        for message in lperform_params:
            # message = (datatype, matrix variant), e.g. ('count', 'original')
            classification(SRWtrain, SRWtest, train_molinfo_df, test_molinfo_df, 'tdc_dili_test', project_name, message, seed=seed)
        print('test run done')

# Run every parameter combination through run_srw in a worker pool (one worker process).
pool = multiprocessing.Pool(processes = 1)
try:
    pool.map(run_srw, paramlist)
finally:
    # Always release the worker processes, even when a run raises;
    # previously an exception in pool.map skipped close()/join().
    pool.close()
    pool.join()

New run startsNew run startsNew run startsNew run starts


New run startsNew run starts
New run startsNew run starts



New run startsNew run startsNew run startsNew run startsNew run starts
New run starts
New run starts

New run starts



The Number of allowed walks: 5
The Number of allowed walks: 15The Number of allowed walks: 20The Number of allowed walks: 1The Number of allowed walks: 40


 
The Number of allowed walks: 20The Number of allowed walks: 300The Number of allowed walks: 4loop starts
00
  
  00loop startsloop starts0loop startsloop starts  The Number of allowed walks: 8 The Number of allowed walks: 1The Number of allowed walks: 15loop starts
loop startsThe Number of allowed walks: 15loop starts
The Number of allowed walks: 1The Number of allowed walks: 5
0The Number of allowed walks: 2

The Number of allowed walks: 70 
0
00
 0loop starts 0  0 loop startsloop starts loop startsloop starts loop startsloop startsloop startstest run done
New run starts
New run starts
The Num