<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [None]:
!pip install rdkit-pypi

In [2]:
import os

import pandas as pd
import numpy as np
from data_handler import *


from imblearn.over_sampling import RandomOverSampler


from sklearn.svm import SVC
from sklearn.model_selection import cross_validate
from sklearn.model_selection import GridSearchCV

import pickle

# Seed NumPy's global random number generator so repeated runs of the notebook
# draw the same pseudo-random numbers from it.
np.random.seed(7)

In [4]:
def _read_bindingdb(archive):
    """Read one zipped, tab-separated BindingDB export into a DataFrame."""
    return pd.read_csv(archive, sep='\t', compression='zip')


# Raw BindingDB tables, one per protein target.
VDR = _read_bindingdb("BindingDB_data/Vitamin-D3-Receptor.zip")
GABA = _read_bindingdb("BindingDB_data/Gamma-aminobutyric acid receptor subunit alpha-1.zip")
mTOR = _read_bindingdb("BindingDB_data/Serine-threonine-protein kinase mTOR.zip")

# Peek at a single row to check the column layout.
VDR.head(1)

Unnamed: 0,BindingDB Reactant_set_id,Ligand SMILES,Ligand InChI,Ligand InChI Key,BindingDB MonomerID,BindingDB Ligand Name,Target Name Assigned by Curator or DataSource,Target Source Organism According to Curator or DataSource,Ki (nM),IC50 (nM),...,UniProt (SwissProt) Recommended Name of Target Chain,UniProt (SwissProt) Entry Name of Target Chain,UniProt (SwissProt) Primary ID of Target Chain,UniProt (SwissProt) Secondary ID(s) of Target Chain,UniProt (SwissProt) Alternative ID(s) of Target Chain,UniProt (TrEMBL) Submitted Name of Target Chain,UniProt (TrEMBL) Entry Name of Target Chain,UniProt (TrEMBL) Primary ID of Target Chain,UniProt (TrEMBL) Secondary ID(s) of Target Chain,UniProt (TrEMBL) Alternative ID(s) of Target Chain
0,499761,O=C(N[C@H]1CC[C@H](CCN2CCN(CC2)c2nsc3ccccc23)C...,InChI=1S/C28H33N5OS/c34-28(25-19-21-5-1-3-7-24...,KXAOPEMSBPZNPQ-AQYVVDRMSA-N,50207116,"CHEMBL3905247::US9550741, I-4",Vitamin D3 receptor,Homo sapiens,0.029,,...,Vitamin D3 receptor,VDR_HUMAN,P11473,"B2R5Q1,G3V1V9,Q5PSV3",,,,F1D8P8,,


In [5]:
# Seeded oversampler handed to the project loader, which applies it while
# building the feature matrix / label vector for each target.
# NOTE(review): the meaning of the first positional argument (1) is defined in
# data_handler.LoadDBForSklearn and is not visible here — confirm.
rus = RandomOverSampler(random_state=7)
load = LoadDBForSklearn(1, rus)

# Prepare every target from a copy so the raw frames stay untouched.
(X_vdr, y_vdr), (X_gaba, y_gaba), (X_mtor, y_mtor) = [
    load.prepare(raw_frame.copy()) for raw_frame in (VDR, GABA, mTOR)
]

Num of active:  0.453 %
Num of active after oversampling:  0.5 %
Num of active:  0.743 %
Num of active after oversampling:  0.5 %
Num of active:  0.723 %
Num of active after oversampling:  0.5 %


In [6]:
# Baseline SVM with library-default hyper-parameters. probability=True is
# needed for the predict_proba calls made later in the notebook.
base_svm = SVC(class_weight='balanced', probability=True)

In [7]:
# Baseline 3-fold cross-validation on the VDR data (accuracy and F1).
# NOTE(review): X_vdr/y_vdr were already oversampled by load.prepare; if copies
# of one sample fall into both train and test folds these scores are
# optimistic — confirm.
cross_validate(estimator=base_svm, X=X_vdr, y=y_vdr, cv=3, scoring=['accuracy', 'f1'])

{'fit_time': array([1.16954684, 1.3992703 , 1.15316153]),
 'score_time': array([0.16447926, 0.18742514, 0.22457147]),
 'test_accuracy': array([0.94117647, 0.92016807, 0.91176471]),
 'test_f1': array([0.94166667, 0.92244898, 0.91139241])}

In [8]:
# Baseline 3-fold cross-validation on the GABA-A data (accuracy and F1).
# NOTE(review): X_gaba/y_gaba were already oversampled by load.prepare; if
# copies of one sample fall into both train and test folds these scores are
# optimistic — confirm.
cross_validate(estimator=base_svm, X=X_gaba, y=y_gaba, cv=3, scoring=['accuracy', 'f1'])

{'fit_time': array([0.97588015, 0.91832495, 0.88276505]),
 'score_time': array([0.1413281 , 0.15434074, 0.16423345]),
 'test_accuracy': array([0.9537037 , 0.91162791, 0.89302326]),
 'test_f1': array([0.9537037 , 0.90731707, 0.88442211])}

In [9]:
# Baseline 3-fold cross-validation on the mTOR data (accuracy and F1).
# This is by far the slowest of the three (see the fit_time values below).
# NOTE(review): X_mtor/y_mtor were already oversampled by load.prepare; if
# copies of one sample fall into both train and test folds these scores are
# optimistic — confirm.
cross_validate(estimator=base_svm, X=X_mtor, y=y_mtor, cv=3, scoring=['accuracy', 'f1'])

{'fit_time': array([103.57470155,  98.3940475 ,  97.89733291]),
 'score_time': array([10.17812681, 11.87894893, 10.39885306]),
 'test_accuracy': array([0.92756037, 0.93213988, 0.93838468]),
 'test_f1': array([0.92645816, 0.93148382, 0.93776283])}

In [10]:
# Hyper-parameter grid shared by all three searches. Only C actually varies
# (powers of two from 2**-2 to 2**8); every other entry is a single fixed value.
tuned_parameters = [
    dict(
        kernel=["rbf"],
        gamma=["scale"],
        C=[0.25, 0.5, 1, 2, 4, 16, 64, 256],
        class_weight=['balanced'],
        probability=[True],
        shrinking=[False],
        cache_size=[3000],
    )
]

# Search the grid for the VDR model (3-fold CV, selected on F1) and report
# the winning parameter combination.
clf_vdr = GridSearchCV(
    estimator=base_svm,
    param_grid=tuned_parameters,
    scoring='f1',
    cv=3,
).fit(X_vdr, y_vdr)
print(clf_vdr.best_params_)

# Re-score the selected configuration with both accuracy and F1.
svc_vdr = clf_vdr.best_estimator_
cross_validate(estimator=svc_vdr, X=X_vdr, y=y_vdr, cv=3, scoring=['accuracy', 'f1'])

{'C': 1, 'cache_size': 3000, 'class_weight': 'balanced', 'gamma': 'scale', 'kernel': 'rbf', 'probability': True, 'shrinking': False}


{'fit_time': array([0.65798783, 0.63946009, 0.62035394]),
 'score_time': array([0.09915972, 0.09668398, 0.0955348 ]),
 'test_accuracy': array([0.94117647, 0.92016807, 0.91176471]),
 'test_f1': array([0.94166667, 0.92244898, 0.91139241])}

In [11]:
# Search the grid for the GABA-A model (3-fold CV, selected on F1) and report
# the winning parameter combination.
clf_gaba = GridSearchCV(
    estimator=base_svm,
    param_grid=tuned_parameters,
    scoring='f1',
    cv=3,
).fit(X_gaba, y_gaba)
print(clf_gaba.best_params_)

# Re-score the selected configuration with both accuracy and F1.
svc_gaba = clf_gaba.best_estimator_
cross_validate(estimator=svc_gaba, X=X_gaba, y=y_gaba, cv=3, scoring=['accuracy', 'f1'])

{'C': 4, 'cache_size': 3000, 'class_weight': 'balanced', 'gamma': 'scale', 'kernel': 'rbf', 'probability': True, 'shrinking': False}


{'fit_time': array([0.48609138, 0.48164344, 0.46433401]),
 'score_time': array([0.07081509, 0.07026768, 0.06919861]),
 'test_accuracy': array([0.96759259, 0.9255814 , 0.9255814 ]),
 'test_f1': array([0.96744186, 0.92307692, 0.92156863])}

In [12]:
# Search the grid for the mTOR model (3-fold CV, selected on F1) and report
# the winning parameter combination.
clf_mtor = GridSearchCV(
    estimator=base_svm,
    param_grid=tuned_parameters,
    scoring='f1',
    cv=3,
).fit(X_mtor, y_mtor)
print(clf_mtor.best_params_)

# Re-score the selected configuration with both accuracy and F1.
svc_mtor = clf_mtor.best_estimator_
cross_validate(estimator=svc_mtor, X=X_mtor, y=y_mtor, cv=3, scoring=['accuracy', 'f1'])

{'C': 16, 'cache_size': 3000, 'class_weight': 'balanced', 'gamma': 'scale', 'kernel': 'rbf', 'probability': True, 'shrinking': False}


{'fit_time': array([79.73646307, 79.02252507, 85.77738929]),
 'score_time': array([8.71906757, 9.30167079, 9.14257884]),
 'test_accuracy': array([0.95212323, 0.96044963, 0.96586178]),
 'test_f1': array([0.95133305, 0.95986481, 0.96563286])}

In [13]:
# Fit each selected model on its complete dataset. fit() trains the estimator
# in place, so the svc_* names keep pointing at the same, now refitted, objects.
for _svc, _X, _y in (
    (svc_vdr, X_vdr, y_vdr),
    (svc_gaba, X_gaba, y_gaba),
    (svc_mtor, X_mtor, y_mtor),
):
    _svc.fit(_X, _y)

In [14]:
# Sanity check: score calcitriol with the VDR model.
calcitriol = load.make_rdkit_canonical("C=C1C(=CC=C2CCCC3(C)C2CCC3C(C)CCCC(C)(C)O)CC(O)CC1O")
calcitriol_fp = load.get_fp(load.get_mol(calcitriol))
# predict_proba expects a 2-D array, so present the fingerprint as one row.
svc_vdr.predict_proba(calcitriol_fp.reshape(1, -1))

array([[0.01455057, 0.98544943]])

In [15]:
# Sanity check: score diazepam with the GABA-A model.
diazepam = load.make_rdkit_canonical("CN1C(=O)CN=C(C2=C1C=CC(=C2)Cl)C3=CC=CC=C3")
diazepam_fp = load.get_fp(load.get_mol(diazepam))
# predict_proba expects a 2-D array, so present the fingerprint as one row.
svc_gaba.predict_proba(diazepam_fp.reshape(1, -1))

array([[0.00287045, 0.99712955]])

In [16]:
# Sanity check: score Torin 1 with the mTOR model.
torin1 = load.make_rdkit_canonical("CCC(=O)N1CCN(CC1)C2=C(C=C(C=C2)N3C(=O)C=CC4=CN=C5C=CC(=CC5=C43)C6=CC7=CC=CC=C7N=C6)C(F)(F)F")
torin1_fp = load.get_fp(load.get_mol(torin1))
# predict_proba expects a 2-D array, so present the fingerprint as one row.
svc_mtor.predict_proba(torin1_fp.reshape(1, -1))

array([[6.30844426e-07, 9.99999369e-01]])

In [17]:
# Create the output directory for the pickled models. exist_ok=True makes the
# cell safe to re-run and avoids the check-then-create race of testing
# os.path.exists() first.
os.makedirs('trained_svms', exist_ok=True)

In [18]:
# Persist the fitted models. Each file is opened in a `with` block so the
# handle is flushed and closed deterministically, rather than being left to
# the garbage collector.
for _pkl_path, _fitted_svc in (
    ('trained_svms/svc_vdr.pkl', svc_vdr),
    ('trained_svms/svc_gaba.pkl', svc_gaba),
    ('trained_svms/svc_mtor.pkl', svc_mtor),
):
    with open(_pkl_path, 'wb') as _pkl_file:
        pickle.dump(_fitted_svc, _pkl_file)