In [70]:
import pandas as pd
from rdkit import Chem
# discussion of circular fingerprints: https://pubs.acs.org/doi/10.1021/ci100050t
from rdkit.Chem import AllChem
import numpy as np
from tqdm import tqdm
tqdm.pandas()
import os
import pickle
import pandas as pd

from sklearn.model_selection import train_test_split

#Load the Models:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.svm import SVC
from sklearn.dummy import DummyClassifier


#Load the Metrics:
from sklearn.model_selection import cross_validate
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import log_loss

#Other fingerprint types to explore? 
#useful example: https://medium.com/@gurkamaldeol/predicting-environmental-carcinogens-with-logistic-regression-knn-gradient-boosting-and-7973f88eb8b3

In [15]:
# pull in the datasets:
training_datasets=[i for i in os.listdir('data_split_cleaned') if 'train' in i]
validation_datasets=[i for i in os.listdir('data_split_cleaned') if 'validate' in i]

In [21]:
data_map={
    'HIV': {'target':'HIV_active','structure':'smiles'},
    'bace':{'target':'active','structure':'mol'},
    'tox21':{'target':'NR-AhR','structure':'smiles'},
    'clintox':{'target':'CT_TOX','structure':'smiles'},
    'sol_del':{'target':'binned_sol','structure':'smiles'},
    'deepchem_Lipophilicity':{'target':'drug_like','structure':'smiles'}   
}
model_save_path='Simple_Models'

In [9]:
def generate_fingerprint(smiles,radius,bits):
    try:
        mol=Chem.MolFromSmiles(smiles)
        fp=AllChem.GetMorganFingerprintAsBitVect(mol,radius,bits)
        return(np.array(fp))
    except:
        print(f'{smiles} failed in RDkit')
        return (np.nan)

In [87]:
# Load the scoring and model functions:
scoring = ['accuracy', 'f1','roc_auc','neg_log_loss']

models={'Logistic_Regression':LogisticRegression(random_state=0,solver='lbfgs',max_iter=1000,verbose=False),
        'Random_Forest':RandomForestClassifier(random_state=0,n_jobs=-1),
        'KNN': KNeighborsClassifier(n_neighbors=5, n_jobs=-1),
        'Gradient_Boosted_Tree': GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0),
        'SVM':SVC(C=1.0, kernel='linear', degree=3, gamma='scale',probability=True),
        'Dummy_Most_Frequent':DummyClassifier(strategy="most_frequent")
       }

In [88]:
# load the data set
# create finger prints on molecules
# extract X and Y for training
# Train the model on this data (cross val)
# save the scores to csv in folder
# save the model as a pickle 
# predict with the model using one trained on all. 
#save those scores to csv for datafile (train and val) 

In [89]:
def save_model(model,model_name):
    model_out=os.path.join('Simple_Models',model_name)

    if not os.path.exists(model_out):
        os.makedirs(model_out)
    with open(f'{model_name}_model.pkl','wb') as f:
        pickle.dump(clf,f)
    print(f'saved_{model_name}_to_{model_out}')

In [90]:
def featurize_dataset(dataset,radius,bits):
    df=pd.read_csv(os.path.join('data_split_cleaned',dataset))
    data_lookup=dataset.split('-')[0]
    target=data_map[data_lookup]['target']
    smiles=data_map[data_lookup]['structure']
    df['fp']=df[smiles].apply(lambda x: generate_fingerprint(x,radius,bits))
    df.dropna(subset=['fp',target],inplace=True) # Remove any failed fingerprints or missing targets
    X=df['fp'].to_list()
    y=df[target].to_list()
    return(X,y)

In [95]:
# Train the models:
for dataset in training_datasets:
    #extract the dataset name
    data_set_name=dataset.replace('.csv','').split('-')
    data_set_name=data_set_name[0]+'-'+data_set_name[1]
    
    X,y=featurize_dataset(dataset,2,1024) # featurize the dataset
    
    model_out=os.path.join('Simple_Models',data_set_name)
    if not os.path.exists(model_out):# make an output path
        os.makedirs(model_out)
    
    for model in tqdm(models):
        clf=models[model]
        # run cross val to estimate error:
        cv_result=cross_validate(clf , X, y, scoring=scoring, cv=5, return_estimator=False)
        pd.DataFrame(cv_result).describe().to_csv(os.path.join(model_out,model+'.csv'))
        
        clf.fit(X,y) # fit the model on the whole dataset:
        
        with open(os.path.join(model_out,f'{model}_model.pkl'),'wb') as f:
            pickle.dump(clf,f)
        print(f'saved_{model}_to_{model_out}')

 17%|███████▌                                     | 1/6 [00:00<00:00,  6.42it/s]

saved_Logistic_Regression_to_Simple_Models/sol_del-cluster


 33%|███████████████                              | 2/6 [00:02<00:05,  1.48s/it]

saved_Random_Forest_to_Simple_Models/sol_del-cluster


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:02<00:02,  1.01it/s]

saved_KNN_to_Simple_Models/sol_del-cluster


 67%|██████████████████████████████               | 4/6 [00:04<00:02,  1.28s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/sol_del-cluster


100%|█████████████████████████████████████████████| 6/6 [00:06<00:00,  1.06s/it]

saved_SVM_to_Simple_Models/sol_del-cluster
saved_Dummy_Most_Frequent_to_Simple_Models/sol_del-cluster



 17%|███████▌                                     | 1/6 [00:00<00:01,  3.46it/s]

saved_Logistic_Regression_to_Simple_Models/clintox-random


 33%|███████████████                              | 2/6 [00:01<00:03,  1.03it/s]

saved_Random_Forest_to_Simple_Models/clintox-random


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:02<00:02,  1.24it/s]

saved_KNN_to_Simple_Models/clintox-random


 67%|██████████████████████████████               | 4/6 [00:04<00:02,  1.40s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/clintox-random


100%|█████████████████████████████████████████████| 6/6 [00:07<00:00,  1.21s/it]

saved_SVM_to_Simple_Models/clintox-random
saved_Dummy_Most_Frequent_to_Simple_Models/clintox-random



 17%|███████▌                                     | 1/6 [00:00<00:03,  1.28it/s]

saved_Logistic_Regression_to_Simple_Models/bace-cluster


 33%|███████████████                              | 2/6 [00:02<00:04,  1.17s/it]

saved_Random_Forest_to_Simple_Models/bace-cluster


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:02<00:02,  1.10it/s]

saved_KNN_to_Simple_Models/bace-cluster


 67%|██████████████████████████████               | 4/6 [00:05<00:03,  1.51s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/bace-cluster


100%|█████████████████████████████████████████████| 6/6 [00:08<00:00,  1.47s/it]

saved_SVM_to_Simple_Models/bace-cluster
saved_Dummy_Most_Frequent_to_Simple_Models/bace-cluster



 17%|███████▌                                     | 1/6 [00:00<00:04,  1.13it/s]

saved_Logistic_Regression_to_Simple_Models/deepchem_Lipophilicity-random


 33%|███████████████                              | 2/6 [00:03<00:06,  1.64s/it]

saved_Random_Forest_to_Simple_Models/deepchem_Lipophilicity-random


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:05<00:06,  2.18s/it]

saved_KNN_to_Simple_Models/deepchem_Lipophilicity-random


 67%|██████████████████████████████               | 4/6 [00:14<00:09,  4.90s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/deepchem_Lipophilicity-random


100%|█████████████████████████████████████████████| 6/6 [00:29<00:00,  4.90s/it]

saved_SVM_to_Simple_Models/deepchem_Lipophilicity-random
saved_Dummy_Most_Frequent_to_Simple_Models/deepchem_Lipophilicity-random



 17%|███████▌                                     | 1/6 [00:01<00:05,  1.10s/it]

saved_Logistic_Regression_to_Simple_Models/deepchem_Lipophilicity-cluster


 33%|███████████████                              | 2/6 [00:03<00:07,  1.83s/it]

saved_Random_Forest_to_Simple_Models/deepchem_Lipophilicity-cluster


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:06<00:06,  2.33s/it]

saved_KNN_to_Simple_Models/deepchem_Lipophilicity-cluster


 67%|██████████████████████████████               | 4/6 [00:14<00:09,  4.80s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/deepchem_Lipophilicity-cluster


100%|█████████████████████████████████████████████| 6/6 [00:29<00:00,  4.95s/it]

saved_SVM_to_Simple_Models/deepchem_Lipophilicity-cluster
saved_Dummy_Most_Frequent_to_Simple_Models/deepchem_Lipophilicity-cluster



 17%|███████▌                                     | 1/6 [00:00<00:00,  6.39it/s]

saved_Logistic_Regression_to_Simple_Models/sol_del-random


 33%|███████████████                              | 2/6 [00:01<00:03,  1.14it/s]

saved_Random_Forest_to_Simple_Models/sol_del-random


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:01<00:01,  1.52it/s]

saved_KNN_to_Simple_Models/sol_del-random


 67%|██████████████████████████████               | 4/6 [00:03<00:02,  1.06s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/sol_del-random


100%|█████████████████████████████████████████████| 6/6 [00:05<00:00,  1.11it/s]

saved_SVM_to_Simple_Models/sol_del-random
saved_Dummy_Most_Frequent_to_Simple_Models/sol_del-random



 17%|███████▌                                     | 1/6 [00:01<00:08,  1.72s/it]

saved_Logistic_Regression_to_Simple_Models/tox21-random


 33%|███████████████                              | 2/6 [00:05<00:11,  2.76s/it]

saved_Random_Forest_to_Simple_Models/tox21-random


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:09<00:10,  3.57s/it]

saved_KNN_to_Simple_Models/tox21-random


 67%|██████████████████████████████               | 4/6 [00:25<00:17,  8.53s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/tox21-random


100%|█████████████████████████████████████████████| 6/6 [01:48<00:00, 18.13s/it]

saved_SVM_to_Simple_Models/tox21-random
saved_Dummy_Most_Frequent_to_Simple_Models/tox21-random



 17%|███████▌                                     | 1/6 [00:13<01:07, 13.54s/it]

saved_Logistic_Regression_to_Simple_Models/HIV-cluster


 33%|███████████████                              | 2/6 [00:53<01:56, 29.09s/it]

saved_Random_Forest_to_Simple_Models/HIV-cluster


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [02:39<03:12, 64.01s/it]

saved_KNN_to_Simple_Models/HIV-cluster


 67%|██████████████████████████████               | 4/6 [04:53<03:03, 91.62s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/HIV-cluster


100%|████████████████████████████████████████| 6/6 [10:04:35<00:00, 6046.00s/it]

saved_SVM_to_Simple_Models/HIV-cluster
saved_Dummy_Most_Frequent_to_Simple_Models/HIV-cluster



 17%|███████▌                                     | 1/6 [00:00<00:01,  3.09it/s]

saved_Logistic_Regression_to_Simple_Models/clintox-cluster


 33%|██████████████▋                             | 2/6 [15:33<36:35, 548.82s/it]

saved_Random_Forest_to_Simple_Models/clintox-cluster


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████                      | 3/6 [15:34<14:56, 298.70s/it]

saved_KNN_to_Simple_Models/clintox-cluster


 67%|█████████████████████████████▎              | 4/6 [15:36<06:03, 181.66s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/clintox-cluster


100%|████████████████████████████████████████████| 6/6 [33:24<00:00, 334.12s/it]

saved_SVM_to_Simple_Models/clintox-cluster
saved_Dummy_Most_Frequent_to_Simple_Models/clintox-cluster



 17%|███████▌                                     | 1/6 [00:02<00:10,  2.11s/it]

saved_Logistic_Regression_to_Simple_Models/tox21-cluster


 33%|██████████████▋                             | 2/6 [15:24<36:13, 543.35s/it]

saved_Random_Forest_to_Simple_Models/tox21-cluster


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████                      | 3/6 [15:29<14:53, 297.72s/it]

saved_KNN_to_Simple_Models/tox21-cluster


 67%|█████████████████████████████▎              | 4/6 [49:41<33:00, 990.26s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/tox21-cluster


100%|█████████████████████████████████████████| 6/6 [3:46:35<00:00, 2265.99s/it]

saved_SVM_to_Simple_Models/tox21-cluster
saved_Dummy_Most_Frequent_to_Simple_Models/tox21-cluster



 17%|██████▊                                  | 1/6 [48:07<4:00:36, 2887.39s/it]

saved_Logistic_Regression_to_Simple_Models/HIV-random


 33%|█████████████                          | 2/6 [2:46:03<5:56:46, 5351.63s/it]

saved_Random_Forest_to_Simple_Models/HIV-random


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|███████████████████▌                   | 3/6 [5:48:55<6:35:53, 7917.83s/it]

saved_KNN_to_Simple_Models/HIV-random


 67%|██████████████████████████             | 4/6 [5:51:09<2:41:29, 4844.86s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/HIV-random


100%|█████████████████████████████████████████| 6/6 [6:48:31<00:00, 4085.29s/it]

saved_SVM_to_Simple_Models/HIV-random
saved_Dummy_Most_Frequent_to_Simple_Models/HIV-random



 17%|███████▌                                     | 1/6 [00:00<00:03,  1.56it/s]

saved_Logistic_Regression_to_Simple_Models/bace-random


 33%|███████████████                              | 2/6 [00:03<00:07,  1.88s/it]

saved_Random_Forest_to_Simple_Models/bace-random


  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
  mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
 50%|██████████████████████▌                      | 3/6 [00:04<00:03,  1.30s/it]

saved_KNN_to_Simple_Models/bace-random


 67%|██████████████████████████████               | 4/6 [00:06<00:03,  1.76s/it]

saved_Gradient_Boosted_Tree_to_Simple_Models/bace-random


100%|█████████████████████████████████████████████| 6/6 [00:10<00:00,  1.71s/it]

saved_SVM_to_Simple_Models/bace-random
saved_Dummy_Most_Frequent_to_Simple_Models/bace-random



