In [1]:
import deepchem as dc
import pandas as pd
import numpy as np

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV,KFold

from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import matthews_corrcoef

import warnings
warnings.filterwarnings('ignore')

Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'pytorch_lightning'
Skipped loading some Jax models, missing a dependency. jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.


In [2]:
data = pd.read_csv('../../data/refined_gabaa.csv')

# Feature extraction & Data splitting

In [3]:
featurizer = dc.feat.CircularFingerprint(size=1024,radius=4)
features = featurizer.featurize(data['smiles'])
dataset = dc.data.NumpyDataset(features,data['label'])

In [4]:
splitter = dc.splits.RandomSplitter()
train_dataset, test_dataset = splitter.train_test_split(dataset=dataset,frac_train=0.7,seed=300)

# GridSearchCV

In [5]:
scoring = {
        'F1':make_scorer(f1_score),
        'AUC':make_scorer(roc_auc_score),
        'BA':make_scorer(balanced_accuracy_score),
        'MCC':make_scorer(matthews_corrcoef)
}

param_grid = {'learning_rate': [0.1,0.3,0.5,0.7,0.9],
              'n_estimators':[50,100,150,200],
              'max_depth': [ 4, 5, 6, 7, 8]}


gbdt_classifier = GradientBoostingClassifier(random_state=42)

cv = KFold(n_splits=5, shuffle=True, random_state=300)

gs = GridSearchCV(
                    gbdt_classifier,
                   param_grid,
                 scoring = scoring,
                  cv = cv,
                 n_jobs = -1,
               refit = 'F1',
               return_train_score = True)

gs_fit = gs.fit(train_dataset.X, train_dataset.y.ravel())

In [6]:
val_F1 = gs.best_score_
val_AUC = gs.cv_results_['mean_test_AUC'][gs.best_index_]
val_BA = gs.cv_results_['mean_test_BA'][gs.best_index_]
val_MCC = gs.cv_results_['mean_test_MCC'][gs.best_index_]

# GridSearchCV_Result

In [7]:
print('Best parameters: ', gs.best_params_)
print('Best score (F1): ', gs.best_score_)
print('AUC: ',gs.cv_results_['mean_test_AUC'][gs.best_index_])
print('BA: ', gs.cv_results_['mean_test_BA'][gs.best_index_])
print('MCC: ',gs.cv_results_['mean_test_MCC'][gs.best_index_])

Best parameters:  {'learning_rate': 0.1, 'max_depth': 8, 'n_estimators': 200}
Best score (F1):  0.8155921124583294
AUC:  0.8386811862810617
BA:  0.8386811862810617
MCC:  0.6785838804983614


In [8]:
gbdt_model = gs_fit.best_estimator_
gbdt_model

GradientBoostingClassifier(max_depth=8, n_estimators=200, random_state=42)

# Evaluate model

In [9]:
y_test_pred_proba = gbdt_model.predict_proba(test_dataset.X)
y_test_pred = gbdt_model.predict(test_dataset.X)

In [10]:
tn, fp, fn, tp = confusion_matrix(test_dataset.y, y_test_pred).ravel()
print('TN:', tn)
print('FP:', fp)
print('FN:', fn)
print('TP:', tp)

TN: 288
FP: 34
FN: 36
TP: 215


In [11]:
test_pred_list = []
for test_score in y_test_pred_proba:
    test_score = test_score[1]
    test_pred_list.append(test_score)

In [12]:
test_pred_array = np.array(test_pred_list)

In [13]:
test_F1 = f1_score(test_dataset.y,np.round(test_pred_array))
test_AUC = roc_auc_score(test_dataset.y,test_pred_list)
test_BA = balanced_accuracy_score(test_dataset.y,np.round(test_pred_array))
test_MCC = matthews_corrcoef(test_dataset.y,np.round(test_pred_array))

# Finall result

In [14]:
performance_dataset = {
        'F1':[val_F1,test_F1],
        'AUC':[val_AUC,test_AUC],
        'BA':[val_BA,test_BA],
        'MCC':[val_MCC,test_MCC],
}

In [15]:
performance = pd.DataFrame(performance_dataset,index=['val','test'])
performance

Unnamed: 0,F1,AUC,BA,MCC
val,0.815592,0.838681,0.838681,0.678584
test,0.86,0.925472,0.875492,0.751663
