In [1]:
import math
from pathlib import Path
from zipfile import ZipFile
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd
import scipy
from rdkit.Chem import PandasTools
from sklearn import metrics
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import matplotlib as mpl
import seaborn as sns
from tdc.single_pred import HTS
from tdc.single_pred import ADME
from rdkit import DataStructs
from rdkit.Chem import AllChem, MolFromSmiles, Draw
import xgboost as xgb
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_curve, auc, precision_recall_curve, average_precision_score
from xgboost import XGBClassifier
import pickle

In [2]:

# list of cytochrome p-450 enzymes
cyps = ['CYP2C19', 'CYP2D6', 'CYP3A4', 'CYP1A2', 'CYP2C9', 'CYP23A4']
# fit ML Models for all the above enzymes and evaluate their results


In [3]:
def data_load_and_prep(cyp_name):
    """Load the data for given cyp enzyme and splits that data into train, validation and test

     Args:
         cyp_name: CPY 450 Enzyme Name.

     Returns:
         list of DataFrames: Train, Validation and Test Dataframes.
    """
    data = ADME(name = cyp_name+'_Veith')
    splits = data.get_split('scaffold')
    train_df, val_df, test_df = splits['train'], splits['valid'], splits['test']
    print(train_df.shape)
    return train_df, val_df, test_df


def compute_fingerprint(mol, r, nBits) -> np.ndarray:
    """Featurize Rdkit Mol into Morgan Fingerprint bit vectors.

     Args:
         mol: an Rdkit Mol object (representing a molecule).
         r: radius
         nBits: Bit Vector Length

     Returns:
         numpy.ndarray: A 1-D numpy array of the featurized molecule.
    """
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, r, nBits=nBits)
    arr = np.zeros((1,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr

def process_df(df):
    """Process the Dataframe.

     Args:
         df: Dataframe to be processed

     Returns:
         df: Processed df.
         fingerprints: fingerprints array for all the molecules in the dataframe 
    """
    df = df.rename(columns={'Y':'Inhibition'})
    PandasTools.AddMoleculeColumnToFrame(df, "Drug", "Molecule")
    df = df.dropna()
    fingerprints = np.stack([compute_fingerprint(mol, 2, 2048) for mol in df.Molecule])
    df['fingerprints'] = df['Molecule'].apply(compute_fingerprint, r=2, nBits=2048)

    return df, fingerprints
    
    
def fit_xgb_model(fingerprints,
                  target,
                  max_depth = 50,
                  n_estimators = 500
                 ):
    """Process the Dataframe.

     Args:
         fingerprints: Morgan Fingerprints ndarray for all the molecules in a dataframe
         target: A binary label vector
         max_depth: MAxmimum depth of a tree
         n_estimators: number of rounds for boosting 

     Returns:
         An XGBoost Binary Classification Model that provides the probability of a moelcules ability to inhibit CYP450 
    """
    # declare parameters
    params = {
                'objective':'binary:logistic',
                'max_depth': max_depth,
                'colsample_bytree' : 0.7,
                'sub_sample' : 0.5,
                'learning_rate': 0.05,
                'n_estimators':n_estimators,
                'eval_metric' : 'logloss'
            }
            
    # instantiate the classifier 
    xgb_clf = XGBClassifier(**params)
    # fit the classifier to the training data
    xgb_clf.fit(fingerprints, target)
    return xgb_clf


def validate_model(xgb_model,
                   fingerprints,
                   target,
                   name = 'test'):
    
    # make predictions on test data
    predictions = xgb_model.predict(fingerprints)
    accuracy = accuracy_score(target, predictions)
    precision = precision_score(target, predictions)
    recall = recall_score(target, predictions)
    prevalence = sum(target)/len(target)
    print(name+' dataset\n')
    print('Accuracy : {0:0.4f}'. format(accuracy))
    print(f"Precision: {precision:.2f}")               
    print(f"Recall: {recall:.2f}")
    print(f"Prevalence: {prevalence:.2f}")
    metrics_dict = {'accuracy' : accuracy,
                   'precision' : precision,
                   'recall' : recall,
                   'prevalence' : prevalence}
                   
    return predictions, metrics_dict

def roc_pr_curve(xgb_model,
              test_data_fp,
              target,
                 cyp_name,
              curve = "roc",
              label = 'ROC curve (area = %0.2f)',
              x_label = 'False Positive Rate',
              y_label = 'True Positive Rate',
              title = 'Receiver Operating Characteristic (ROC)'
                
             ):
    probs = xgb_model.predict_proba(test_data_fp)[:, 1]
    plt.figure(figsize=(5, 3))
    if curve == "roc":
        x, y, thresholds = roc_curve(target, probs)
        plot_data = auc(x, y)
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    else:
        y, x, _ = precision_recall_curve(target, probs)
        plot_data = average_precision_score(target, probs)
    
    plt.plot(x, y, color='darkorange', lw=2, label= label % plot_data)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title+" "+cyp_name)
    plt.legend(loc="lower right")
    # plt.show()
    plt.savefig('./ML/'+curve+'_curve_'+cyp_name+'.png', bbox_inches='tight')
    plt.close()


In [5]:
# CYP2C19
def from_load_to_model(cyp_name):
    train_data, val_data, test_data = data_load_and_prep(cyp_name)
    train_data, train_fingerprints = process_df(train_data)
    val_data, val_fingerprints = process_df(val_data)
    test_data, test_fingerprints = process_df(test_data)
    
    # fit XGBoost model
    xgb_model = fit_xgb_model(train_fingerprints,
                            train_data.Inhibition,
                            max_depth = 15,
                            n_estimators = 2000)
    
    # Save the model using pickle
    with open('./ML/model_'+cyp_name, 'wb') as file:
        pickle.dump(xgb_model, file)

    # Evaluate the model
    metrics_val = validate_model(xgb_model,
                       val_fingerprints,
                       val_data.Inhibition,
                       name = 'val'+"_"+cyp_name)
    metrics_test = validate_model(xgb_model,
                       test_fingerprints,
                       test_data.Inhibition,
                       name = 'test'+"_"+cyp_name)

    # plot ROC curve
    roc_pr_curve(xgb_model,
                  test_fingerprints,
                  test_data.Inhibition,
                  cyp_name = cyp_name,
                  curve = "roc",
                  label = 'ROC curve (area = %0.2f)',
                  x_label = 'False Positive Rate',
                  y_label = 'True Positive Rate',
                  title = 'Receiver Operating Characteristic (ROC)'
                 )
    # plot PR Curve
    roc_pr_curve(xgb_model,
                 test_fingerprints,
                 test_data.Inhibition,
                 cyp_name = cyp_name,
                 curve = "pr",
                 label = 'PRC curve (area = %0.2f)',
                 x_label = 'Recall',
                 y_label = 'Precision',
                 title = 'Precision-Recall Curve'
             )
    return test_data, metrics_val, metrics_test, xgb_model
    

In [6]:
cyps

['CYP2C19', 'CYP2D6', 'CYP3A4', 'CYP1A2', 'CYP2C9', 'CYP23A4']

In [7]:
CYP2C19 = from_load_to_model('CYP2C19')

Downloading...
100%|███████████████████████████████████████| 771k/771k [00:00<00:00, 9.92MiB/s]
Loading...
Done!
100%|███████████████████████████████████| 12665/12665 [00:01<00:00, 7763.40it/s]


(8865, 3)
val_CYP2C19 dataset

Accuracy : 0.7891
Precision: 0.78
Recall: 0.78
Prevalence: 0.48
test_CYP2C19 dataset

Accuracy : 0.7952
Precision: 0.78
Recall: 0.81
Prevalence: 0.49


In [8]:
CYP2D6 = from_load_to_model('CYP2D6')
CYP3A4 = from_load_to_model('CYP3A4')
CYP1A2 = from_load_to_model('CYP1A2')
CYP2C9 = from_load_to_model('CYP2C9')
CYP23A4 = from_load_to_model('CYP23A4')

Downloading...
100%|███████████████████████████████████████| 800k/800k [00:00<00:00, 5.54MiB/s]
Loading...
Done!
100%|███████████████████████████████████| 13130/13130 [00:01<00:00, 7567.51it/s]


(9191, 3)
val_CYP2D6 dataset

Accuracy : 0.8751
Precision: 0.71
Recall: 0.54
Prevalence: 0.18
test_CYP2D6 dataset

Accuracy : 0.8736
Precision: 0.67
Recall: 0.49
Prevalence: 0.17


Downloading...
100%|███████████████████████████████████████| 746k/746k [00:00<00:00, 6.18MiB/s]
Loading...
Done!
100%|███████████████████████████████████| 12328/12328 [00:01<00:00, 7763.33it/s]


(8629, 3)
val_CYP3A4 dataset

Accuracy : 0.7922
Precision: 0.78
Recall: 0.74
Prevalence: 0.45
test_CYP3A4 dataset

Accuracy : 0.7864
Precision: 0.76
Recall: 0.76
Prevalence: 0.44


Downloading...
100%|███████████████████████████████████████| 760k/760k [00:00<00:00, 3.10MiB/s]
Loading...
Done!
100%|███████████████████████████████████| 12579/12579 [00:01<00:00, 7735.94it/s]


(8805, 3)
val_CYP1A2 dataset

Accuracy : 0.8258
Precision: 0.83
Recall: 0.82
Prevalence: 0.50
test_CYP1A2 dataset

Accuracy : 0.8498
Precision: 0.87
Recall: 0.83
Prevalence: 0.50


Downloading...
100%|███████████████████████████████████████| 740k/740k [00:00<00:00, 4.82MiB/s]
Loading...
Done!
100%|███████████████████████████████████| 12092/12092 [00:01<00:00, 7686.22it/s]


(8464, 3)
val_CYP2C9 dataset

Accuracy : 0.7990
Precision: 0.71
Recall: 0.68
Prevalence: 0.33
test_CYP2C9 dataset

Accuracy : 0.8123
Precision: 0.73
Recall: 0.65
Prevalence: 0.32


Found local copy...
Loading...
Done!
100%|███████████████████████████████████| 12328/12328 [00:01<00:00, 7718.37it/s]


(8629, 3)
val_CYP23A4 dataset

Accuracy : 0.7922
Precision: 0.78
Recall: 0.74
Prevalence: 0.45
test_CYP23A4 dataset

Accuracy : 0.7864
Precision: 0.76
Recall: 0.76
Prevalence: 0.44
