In [1]:
%pip install PyTDC

Collecting PyTDC
  Downloading PyTDC-0.4.0.tar.gz (107 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/107.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.3/107.3 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting rdkit-pypi (from PyTDC)
  Downloading rdkit_pypi-2022.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m46.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting fuzzywuzzy (from PyTDC)
  Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)
Collecting huggingface_hub (from PyTDC)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dataclasses (from PyTDC)
  Downloading datac

In [2]:
%pip install rdkit

Collecting rdkit
  Downloading rdkit-2023.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.7/29.7 MB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.3.2


In [3]:
%pip install pandas-flavor

Collecting pandas-flavor
  Downloading pandas_flavor-0.6.0-py3-none-any.whl (7.2 kB)
Installing collected packages: pandas-flavor
Successfully installed pandas-flavor-0.6.0


In [4]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm #progress bar
import rdkit
from rdkit import Chem #Chemistry
from rdkit.Chem import rdMolDescriptors #molecular descriptors
from rdkit.Chem import PandasTools
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
rdkit.__version__

import xgboost as xgb
import sklearn
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import fbeta_score, make_scorer
from xgboost.sklearn import XGBRegressor
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import AdaBoostRegressor
from sklearn.tree import DecisionTreeRegressor
from keras.models import Sequential
from keras.layers import Dense, Conv1D, Flatten
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split #ML training
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import r2_score, mean_squared_error #ML stats
from yellowbrick.regressor import prediction_error, ResidualsPlot
from tdc.single_pred import ADME
from tdc.benchmark_group import admet_group
from tdc import BenchmarkGroup
import warnings
warnings.filterwarnings("ignore")

In [5]:
group = admet_group(path = 'data/')
predictions_list_xgb = []
best_params_list_xgb = []
predictions_list_rf = []
predictions_list_svm = []
best_params_list_svm = []
predictions_list_adb = []
predictions_list_cnn = []

for seed in [1, 2, 3, 4, 5]:
    predictions_xgb = {}
    predictions_rf = {}
    predictions_svm = {}
    predictions_adb = {}
    predictions_cnn = {}

    benchmark = group.get('PPBR_AZ')
    name = benchmark['name']
    train_val, test = benchmark['train_val'], benchmark['test']
    train, valid = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = seed)

    PandasTools.AddMoleculeColumnToFrame(train, smilesCol='Drug')
    radius=2
    nBits=1024
    ECFP6 = [AllChem.GetMorganFingerprintAsBitVect(x,radius=radius, nBits=nBits) for x in train['ROMol']]
    ecfp6_name = [f'Bit_{i}' for i in range(nBits)]
    ecfp6_bits = [list(l) for l in ECFP6]
    Y = train['Y']
    train = pd.DataFrame(ecfp6_bits, index = train.Drug, columns=ecfp6_name).reset_index(drop = False)
    train['Y'] = Y

    PandasTools.AddMoleculeColumnToFrame(benchmark['test'], smilesCol='Drug')
    radius=2
    nBits=1024
    ECFP6 = [AllChem.GetMorganFingerprintAsBitVect(x,radius=radius, nBits=nBits) for x in benchmark['test']['ROMol']]
    ecfp6_name = [f'Bit_{i}' for i in range(nBits)]
    ecfp6_bits = [list(l) for l in ECFP6]
    Y = benchmark['test']['Y']
    benchmark['test'] = pd.DataFrame(ecfp6_bits, index = benchmark['test'].Drug, columns=ecfp6_name).reset_index(drop = False)
    benchmark['test']['Y'] = Y

    train_X = train.drop(columns = ["Drug","Y"])
    train_y = train.Y
    test_X = benchmark['test'].drop(columns = ["Drug","Y"])
    test_y = benchmark['test'].Y


    #XGBoost + Morgan
    xgb_parameters = {'objective':['reg:squarederror'],
              'booster':['gbtree','gblinear'],
              'learning_rate': [0.1],
              'max_depth': [7,10,15,20],
              'min_child_weight': [10,15,20,25],
              'colsample_bytree': [0.8, 0.9, 1],
              'n_estimators': [300,400,500,600],
              "reg_alpha"   : [0.5,0.2,1],
              "reg_lambda"  : [2,3,5],
              "gamma"       : [1,2,3]}

    xgb_model = XGBRegressor()
    grid_obj_xgb = RandomizedSearchCV(xgb_model, xgb_parameters, cv=5, n_iter=15, scoring='neg_mean_absolute_error', verbose=5, n_jobs=1)
    grid_obj_xgb.fit(train_X, train_y, verbose = 1)
    y_pred_test_xgb = grid_obj_xgb.predict(test_X)
    bp_xgb = grid_obj_xgb.best_params_
    predictions_xgb[name] = y_pred_test_xgb
    predictions_list_xgb.append(predictions_xgb)
    best_params_list_xgb.append(bp_xgb)

    #Random Forest + Morgan
    rf_model = RandomForestRegressor()
    rf_model.fit(train_X, train_y)
    y_pred_test_rf = rf_model.predict(test_X)
    predictions_rf[name] = y_pred_test_rf
    predictions_list_rf.append(predictions_rf)

    #SVM + Morgan
    svm_parameters = {
        'C': [0.1, 1, 10, 100, 1000],
        'epsilon': [0.0001, 0.001, 0.01, 0.1, 0.5],
        'gamma': [0.0001, 0.001, 0.01, 0.1, 1]
    }
    svm_model = SVR(kernel="rbf")
    grid_obj_svm = RandomizedSearchCV(svm_model, svm_parameters, cv=5, n_iter=15, scoring='neg_mean_absolute_error', verbose=5, n_jobs=1)
    grid_obj_svm.fit(train_X, train_y)
    y_pred_test_svm = grid_obj_svm.predict(test_X)
    bp_svm = grid_obj_svm.best_params_
    predictions_svm[name] = y_pred_test_svm
    predictions_list_svm.append(predictions_svm)
    best_params_list_svm.append(bp_svm)

    # AdaBoost + Morgan
    DTR = DecisionTreeRegressor(max_depth=4)
    adb_model = AdaBoostRegressor(n_estimators=300, base_estimator=DTR, learning_rate=1)
    adb_model.fit(train_X, train_y)
    y_pred_test_adb = adb_model.predict(test_X)
    predictions_adb[name] = y_pred_test_adb
    predictions_list_adb.append(predictions_adb)

    #CNN + Morgan
    train_X = train_X.to_numpy()
    test_X = test_X.to_numpy()
    train_X = train_X.reshape(train_X.shape[0], train_X.shape[1], 1)
    test_X = test_X.reshape(test_X.shape[0], test_X.shape[1], 1)
    cnn_model = Sequential()
    cnn_model.add(Conv1D(32, 2, activation="relu", input_shape=(1024,1)))
    cnn_model.add(Flatten())
    cnn_model.add(Dense(64, activation="relu"))
    cnn_model.add(Dense(1))
    cnn_model.compile(loss="mse", optimizer="adam")
    cnn_model.fit(train_X, train_y, batch_size=12, epochs=10, verbose=0)
    y_pred_test_cnn = cnn_model.predict(test_X)
    predictions_cnn[name] = y_pred_test_cnn
    predictions_list_cnn.append(predictions_cnn)

xgb_results = group.evaluate_many(predictions_list_xgb)
rf_results = group.evaluate_many(predictions_list_rf)
svm_results = group.evaluate_many(predictions_list_svm)
adb_results = group.evaluate_many(predictions_list_adb)
cnn_results = group.evaluate_many(predictions_list_cnn)



Downloading Benchmark Group...
100%|██████████| 1.47M/1.47M [00:01<00:00, 1.33MiB/s]
Extracting zip file...
Done!
generating training, validation splits...
100%|██████████| 2231/2231 [00:01<00:00, 1719.47it/s]


Fitting 5 folds for each of 15 candidates, totalling 75 fits
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 1/5] END booster=gblinear, colsample_bytree=1, gamma=3, learning_rate=0.1, max_depth=10, min_child_weight=25, n_estimators=300, objective=reg:squarederror, reg_alpha=1, reg_lambda=5;, score=-14.711 total time=  14.2s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 2/5] END booster=gblinear, colsample_bytree=1, gamma=3, learning_rate=0.1, max_depth=10, min_child_weight=25, n_estimators=300, objective=reg:squarederror, reg_alpha=1, reg_lambda=5;, score=-13.476 total time=   4.6s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 3/5] END booster=gblinear, colsample_bytree=1, gamma=3, learning_rate=0.1, max_depth=10, min_child_weight=25, n_estimators=300, objective=reg:squarederror, reg_alpha=1, reg_lambda=5;, score=-13.624 total time=   5.6s


generating training, validation splits...
100%|██████████| 2231/2231 [00:01<00:00, 1690.71it/s]


Fitting 5 folds for each of 15 candidates, totalling 75 fits
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 1/5] END booster=gblinear, colsample_bytree=0.9, gamma=2, learning_rate=0.1, max_depth=20, min_child_weight=10, n_estimators=600, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=3;, score=-12.457 total time=  10.0s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 2/5] END booster=gblinear, colsample_bytree=0.9, gamma=2, learning_rate=0.1, max_depth=20, min_child_weight=10, n_estimators=600, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=3;, score=-12.374 total time=   9.2s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 3/5] END booster=gblinear, colsample_bytree=0.9, gamma=2, learning_rate=0.1, max_depth=20, min_child_weight=10, n_estimators=600, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=3;, score=-16.307 total t

generating training, validation splits...
100%|██████████| 2231/2231 [00:01<00:00, 1720.14it/s]


Fitting 5 folds for each of 15 candidates, totalling 75 fits
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 1/5] END booster=gblinear, colsample_bytree=1, gamma=2, learning_rate=0.1, max_depth=15, min_child_weight=15, n_estimators=400, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=2;, score=-13.136 total time=   7.4s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 2/5] END booster=gblinear, colsample_bytree=1, gamma=2, learning_rate=0.1, max_depth=15, min_child_weight=15, n_estimators=400, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=2;, score=-12.576 total time=   5.4s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 3/5] END booster=gblinear, colsample_bytree=1, gamma=2, learning_rate=0.1, max_depth=15, min_child_weight=15, n_estimators=400, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=2;, score=-12.727 total time=  

generating training, validation splits...
100%|██████████| 2231/2231 [00:02<00:00, 1028.51it/s]


Fitting 5 folds for each of 15 candidates, totalling 75 fits
[CV 1/5] END booster=gbtree, colsample_bytree=0.8, gamma=2, learning_rate=0.1, max_depth=20, min_child_weight=25, n_estimators=400, objective=reg:squarederror, reg_alpha=1, reg_lambda=2;, score=-13.261 total time= 1.1min
[CV 2/5] END booster=gbtree, colsample_bytree=0.8, gamma=2, learning_rate=0.1, max_depth=20, min_child_weight=25, n_estimators=400, objective=reg:squarederror, reg_alpha=1, reg_lambda=2;, score=-11.526 total time= 1.2min
[CV 3/5] END booster=gbtree, colsample_bytree=0.8, gamma=2, learning_rate=0.1, max_depth=20, min_child_weight=25, n_estimators=400, objective=reg:squarederror, reg_alpha=1, reg_lambda=2;, score=-10.984 total time= 1.2min
[CV 4/5] END booster=gbtree, colsample_bytree=0.8, gamma=2, learning_rate=0.1, max_depth=20, min_child_weight=25, n_estimators=400, objective=reg:squarederror, reg_alpha=1, reg_lambda=2;, score=-10.596 total time= 1.2min
[CV 5/5] END booster=gbtree, colsample_bytree=0.8, gamm

generating training, validation splits...
100%|██████████| 2231/2231 [00:01<00:00, 1701.70it/s]


Fitting 5 folds for each of 15 candidates, totalling 75 fits
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 1/5] END booster=gblinear, colsample_bytree=0.9, gamma=3, learning_rate=0.1, max_depth=7, min_child_weight=10, n_estimators=500, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=3;, score=-12.647 total time=   6.6s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 2/5] END booster=gblinear, colsample_bytree=0.9, gamma=3, learning_rate=0.1, max_depth=7, min_child_weight=10, n_estimators=500, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=3;, score=-15.202 total time=   8.6s
Parameters: { "colsample_bytree", "gamma", "max_depth", "min_child_weight" } are not used.

[CV 3/5] END booster=gblinear, colsample_bytree=0.9, gamma=3, learning_rate=0.1, max_depth=7, min_child_weight=10, n_estimators=500, objective=reg:squarederror, reg_alpha=0.2, reg_lambda=3;, score=-13.412 total time

In [6]:
xgb_results

{'ppbr_az': [9.995, 0.148]}

In [7]:
rf_results

{'ppbr_az': [10.385, 0.184]}

In [8]:
svm_results

{'ppbr_az': [9.108, 0.328]}

In [9]:
adb_results

{'ppbr_az': [17.474, 3.328]}

In [10]:
cnn_results

{'ppbr_az': [12.45, 0.577]}

In [11]:
predictions_list_xgb

[{'ppbr_az': array([ 77.36366 ,  77.36366 ,  96.97196 ,  96.97196 ,  67.62856 ,
          68.18905 ,  67.62856 ,  68.18905 ,  95.286575,  92.00646 ,
          65.26235 ,  82.42039 ,  65.26235 ,  82.42039 ,  82.42039 ,
          55.287865,  65.26235 ,  87.22262 ,  96.35991 ,  68.72035 ,
          90.16987 ,  97.121216,  85.48101 ,  77.720406,  77.720406,
          77.720406, 104.780815, 104.66869 , 104.66869 , 104.780815,
         103.317665, 104.32786 ,  85.19259 ,  85.19259 ,  93.30288 ,
          93.30288 ,  67.84095 ,  67.84095 ,  67.84095 ,  67.84095 ,
          67.84095 ,  67.84095 ,  67.84095 ,  67.84095 ,  91.34934 ,
          93.06278 , 102.62187 ,  81.82919 ,  81.82919 ,  81.82919 ,
          81.82919 ,  81.82919 ,  81.82919 ,  96.200806,  86.711655,
          96.311264,  94.27203 ,  86.711655,  81.12942 ,  97.22133 ,
          81.12942 ,  91.83851 ,  97.22133 ,  85.7769  ,  83.52164 ,
          99.06116 , 101.92035 ,  82.10071 ,  82.10071 ,  97.79934 ,
          97.79934 ,  7

In [12]:
predictions_list_rf

[{'ppbr_az': array([75.3270525 , 75.3270525 , 98.62843325, 98.62843325, 79.90876119,
         77.21012452, 79.90876119, 77.21012452, 94.75714167, 95.11465012,
         64.21687857, 72.87745333, 64.21687857, 72.87745333, 72.87745333,
         64.48110833, 64.21687857, 88.49235   , 93.62555095, 78.91079667,
         89.07752595, 85.27908333, 81.06526262, 81.755845  , 81.755845  ,
         81.755845  , 96.3767244 , 84.63920667, 84.63920667, 96.3767244 ,
         91.75302083, 94.34849   , 88.67644167, 88.67644167, 92.43014619,
         92.43014619, 93.14427   , 93.14427   , 93.14427   , 93.14427   ,
         93.14427   , 93.14427   , 93.14427   , 93.14427   , 79.80967258,
         88.4268125 , 91.75051333, 87.53546833, 87.53546833, 87.53546833,
         87.53546833, 87.53546833, 87.53546833, 90.99953143, 80.58535167,
         79.01920667, 87.828125  , 80.58535167, 86.59113833, 89.90668333,
         86.59113833, 85.47722333, 89.90668333, 84.120015  , 88.51342167,
         75.92954571, 88.45

In [13]:
predictions_list_svm

[{'ppbr_az': array([ 73.38122745,  73.38122745,  95.66482396,  95.66482396,
          78.1848542 ,  84.13035066,  78.1848542 ,  84.13035066,
          93.69044525,  95.38183231,  82.84221378,  88.10177163,
          82.84221378,  88.10177163,  88.10177163,  78.79671866,
          82.84221378,  86.66729679,  90.76574508,  88.6318732 ,
          87.68756884,  94.20930801,  89.35634846,  86.90221778,
          86.90221778,  86.90221778, 102.09293784,  98.98355599,
          98.98355599, 102.09293784,  99.32330868,  98.87357589,
          90.56037818,  90.56037818,  92.33686569,  92.33686569,
          85.84056036,  85.84056036,  85.84056036,  85.84056036,
          85.84056036,  85.84056036,  85.84056036,  85.84056036,
          92.0808349 , 100.02990985, 100.95707765,  84.36724244,
          84.36724244,  84.36724244,  84.36724244,  84.36724244,
          84.36724244,  93.14520816,  89.06078249,  92.99712013,
          96.69378327,  89.06078249,  90.8534654 ,  96.80824448,
          90.8

In [14]:
predictions_list_adb

[{'ppbr_az': array([75.15213877, 75.15213877, 83.25211454, 83.25211454, 79.00320312,
         67.99108035, 79.00320312, 67.99108035, 75.15213877, 85.1001005 ,
         67.99108035, 72.10450704, 67.99108035, 72.10450704, 72.10450704,
         64.55      , 67.99108035, 78.4093844 , 92.8015942 , 72.10450704,
         64.00345625, 82.29651908, 79.98285714, 81.76021828, 81.76021828,
         81.76021828, 82.62773333, 82.62773333, 82.62773333, 82.62773333,
         75.15213877, 83.04715789, 78.4093844 , 78.4093844 , 81.76021828,
         81.76021828, 64.00345625, 64.00345625, 64.00345625, 64.00345625,
         64.00345625, 64.00345625, 64.00345625, 64.00345625, 85.1001005 ,
         81.76021828, 79.98285714, 71.74708178, 71.74708178, 71.74708178,
         71.74708178, 71.74708178, 71.74708178, 83.25211454, 78.57352542,
         78.57352542, 78.57352542, 78.57352542, 78.4093844 , 78.57352542,
         78.4093844 , 82.29651908, 78.57352542, 78.57352542, 78.4093844 ,
         81.76021828, 81.76

In [15]:
predictions_list_cnn

[{'ppbr_az': array([[ 72.70573 ],
         [ 72.70573 ],
         [ 92.53755 ],
         [ 92.53755 ],
         [ 58.826862],
         [ 72.686134],
         [ 58.826862],
         [ 72.686134],
         [ 92.76341 ],
         [ 96.62754 ],
         [ 69.26505 ],
         [ 92.87788 ],
         [ 69.26505 ],
         [ 92.87788 ],
         [ 92.87788 ],
         [ 57.15791 ],
         [ 69.26505 ],
         [ 71.58049 ],
         [ 79.74839 ],
         [111.84225 ],
         [ 65.71501 ],
         [107.95927 ],
         [ 90.44085 ],
         [ 70.30283 ],
         [ 70.30283 ],
         [ 70.30283 ],
         [104.00534 ],
         [100.73825 ],
         [100.73825 ],
         [104.00534 ],
         [104.38159 ],
         [105.63886 ],
         [ 76.00536 ],
         [ 76.00536 ],
         [ 95.79365 ],
         [ 95.79365 ],
         [ 69.476204],
         [ 69.476204],
         [ 69.476204],
         [ 69.476204],
         [ 69.476204],
         [ 69.476204],
         [ 69.476204],
