In [1]:
import deepchem as dc
import pandas as pd
import numpy as np
from rdkit import Chem

from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV,KFold

from sklearn.metrics import make_scorer
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import matthews_corrcoef

Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (D:\anaconda\envs\PI3K\lib\site-packages\deepchem\models\torch_models\__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'pytorch_lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
data = pd.read_csv('../../data/refined_gabaa.csv')

# Feature extraction & Data splitting

In [3]:
featurizer = dc.feat.MACCSKeysFingerprint()
features = featurizer.featurize(data['smiles'])
dataset = dc.data.NumpyDataset(features,data['label'])

In [4]:
splitter = dc.splits.RandomSplitter()
train_dataset, test_dataset = splitter.train_test_split(dataset=dataset,frac_train=0.7,seed=100)

# GridSearchCV

In [5]:
scoring = {
          'F1':make_scorer(f1_score),
          'AUC':make_scorer(roc_auc_score),
           'BA':make_scorer(balanced_accuracy_score),
           'MCC':make_scorer(matthews_corrcoef)
}

param_grid = {'C': [0.1, 1, 10,100], 
              'kernel': ['linear', 'rbf','poly'],
             'gamma':[0.1,1,10]}
    
svm_classifier = SVC(probability=True)

cv = KFold(n_splits=5, shuffle=True, random_state=100)


gs = GridSearchCV(
                    svm_classifier,
                   param_grid,
                 scoring = scoring,
                  cv = cv,
                 n_jobs = -1,
               refit = 'F1',
               return_train_score = True)

gs_fit = gs.fit(train_dataset.X, train_dataset.y.ravel())

In [6]:
val_F1 = gs.best_score_
val_AUC = gs.cv_results_['mean_test_AUC'][gs.best_index_]
val_BA = gs.cv_results_['mean_test_BA'][gs.best_index_]
val_MCC = gs.cv_results_['mean_test_MCC'][gs.best_index_]

# GridSearchCV_Result

In [7]:
print('Best parameters: ', gs.best_params_)
print('Best score (F1): ', gs.best_score_)
print('AUC: ',gs.cv_results_['mean_test_AUC'][gs.best_index_])
print('BA: ', gs.cv_results_['mean_test_BA'][gs.best_index_])
print('MCC: ',gs.cv_results_['mean_test_MCC'][gs.best_index_])

Best parameters:  {'C': 10, 'gamma': 0.1, 'kernel': 'rbf'}
Best score (F1):  0.8438847220510326
AUC:  0.8643529790516947
BA:  0.8643529790516947
MCC:  0.729185091694476


In [8]:
svm_model = gs_fit.best_estimator_
svm_model

SVC(C=10, gamma=0.1, probability=True)

# Evaluate model

In [9]:
y_test_pred_proba = svm_model.predict_proba(test_dataset.X)
y_test_pred = svm_model.predict(test_dataset.X)

In [10]:
tn, fp, fn, tp = confusion_matrix(test_dataset.y, y_test_pred).ravel()
print('TN:', tn)
print('FP:', fp)
print('FN:', fn)
print('TP:', tp)

TN: 284
FP: 29
FN: 39
TP: 221


In [11]:
test_pred_list = []
for test_score in y_test_pred_proba:
    test_score = test_score[1]
    test_pred_list.append(test_score)

In [12]:
test_pred_array = np.array(test_pred_list)

In [13]:
test_F1 = f1_score(test_dataset.y,np.round(test_pred_array))
test_AUC = roc_auc_score(test_dataset.y,test_pred_list)
test_BA =  balanced_accuracy_score(test_dataset.y,np.round(test_pred_array))
test_MCC = matthews_corrcoef(test_dataset.y,np.round(test_pred_array))

# Finall result

In [14]:
performance_dataset = {
    'F1':[val_F1,test_F1],
    'AUC':[val_AUC,test_AUC],
    'BA':[val_BA,test_BA],
    'MCC':[val_MCC,test_MCC],
}

In [16]:
performance = pd.DataFrame(performance_dataset,index=['val','test'])
performance

Unnamed: 0,F1,AUC,BA,MCC
val,0.843885,0.864353,0.864353,0.729185
test,0.868369,0.915858,0.880272,0.763854
