NOTE NB-models don't give us feature_importances_
https://stackoverflow.com/questions/41592661/determining-the-most-contributing-features-for-svm-classifier-in-sklearn

NOTE we can include ELI5 for explanation of predictors
https://github.com/TeamHG-Memex/eli5

NOTE There are other explanation oriented libraries as well
https://github.com/DistrictDataLabs/yellowbrick


In [1]:
# Import the usual suspects.

from __future__ import print_function
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

#Import scikit-learn metrics module for accuracy calculation
from sklearn import metrics

from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report, confusion_matrix




import warnings
warnings.filterwarnings('ignore')
sns.set_style('whitegrid')
sns.set_context('paper')



def print_ln():
    print('-' * 80, '\n')


In [2]:

def model_performance_metrics(model, X, X_test, X_train, y, y_test, y_pred, detailed= False, show_feature_importances= True):

    print("Accuracy:",metrics.accuracy_score(y_test, y_pred))
    print('Mean Absolute Error:', metrics.mean_absolute_error(y_test, y_pred))
    print('Mean Squared Error:', metrics.mean_squared_error(y_test, y_pred))
    print('Root Mean Squared Error:', np.sqrt(metrics.mean_squared_error(y_test, y_pred)))
    print_ln()
    

    if show_feature_importances:

        feature_importances = pd.DataFrame(model.feature_importances_,
                                               index =  X_train.columns,
                                               columns=['importance']).sort_values('importance', ascending=False)



        print("=== Feature Importances ===")
        print(feature_importances)


    
    if detailed:
        model_score = cross_val_score(model, X, y, cv=10)


        print("=== Confusion Matrix ===")
        print(confusion_matrix(y_test, y_pred))
        print_ln()

        print("=== Classification Report ===")
        print(classification_report(y_test, y_pred))
        print_ln()

        print("=== All AUC Scores ===")
        print(model_score)

        print_ln()

        print("=== Mean AUC Score ===")
        print(model_score.mean())
        print_ln()





In [3]:
drugs_column_names = ['rifampicin',
                      'isoniazid',
                      'pyrazinamide',
                      'ethambutol',
                      'streptomycin',
                      'fluoroquinolones',
                      'moxifloxacin',
                      'ofloxacin',
                      'levofloxacin',
                      'ciprofloxacin',
                      'aminoglycosides',
                      'amikacin',
                      'kanamycin',
                      'capreomycin',
                      'ethionamide',
                      'para-aminosalicylic_acid',
                      'cycloserine',
                      'linezolid',
                      'bedaquiline',
                      'clofazimine',
                      'delamanid']


lineage_column_names = [ 'main_lin', 'sublin' ]

resistance_status_column_names = [ 'drtype', 'MDR', 'XDR', 'Resistance_Status' ]


renamed_drug_columns_names = [                     'rifampicin_resistance',
                                                   'isoniazid_resistance',
                                                   'pyrazinamide_resistance',
                                                   'ethambutol_resistance',
                                                   'streptomycin_resistance',
                                                   'fluoroquinolones_resistance',
                                                   'moxifloxacin_resistance',
                                                   'ofloxacin_resistance',
                                                   'levofloxacin_resistance',
                                                   'ciprofloxacin_resistance',
                                                   'aminoglycosides_resistance',
                                                   'amikacin_resistance',
                                                   'kanamycin_resistance',
                                                   'capreomycin_resistance',
                                                   'ethionamide_resistance',
                                                   'para-aminosalicylic_acid_resistance',
                                                   'cycloserine_resistance',
                                                   'linezolid_resistance',
                                                   'bedaquiline_resistance',
                                                   'clofazimine_resistance',
                                                   'delamanid_resistance']


renamed_drug_columns_names_dict = {
                         'rifampicin': 'rifampicin_resistance',
                         'isoniazid': 'isoniazid_resistance',
                         'pyrazinamide': 'pyrazinamide_resistance',
                         'ethambutol': 'ethambutol_resistance',
                         'streptomycin': 'streptomycin_resistance',
                         'fluoroquinolones': 'fluoroquinolones_resistance',
                         'moxifloxacin': 'moxifloxacin_resistance',
                         'ofloxacin': 'ofloxacin_resistance',
                         'levofloxacin': 'levofloxacin_resistance',
                         'ciprofloxacin': 'ciprofloxacin_resistance',
                         'aminoglycosides': 'aminoglycosides_resistance',
                         'amikacin': 'amikacin_resistance',
                         'kanamycin': 'kanamycin_resistance',
                         'capreomycin': 'capreomycin_resistance',
                         'ethionamide': 'ethionamide_resistance',
                         'para-aminosalicylic_acid': 'para-aminosalicylic_acid_resistance',
                         'cycloserine': 'cycloserine_resistance',
                         'linezolid': 'linezolid_resistance',
                         'bedaquiline': 'bedaquiline_resistance',
                         'clofazimine': 'clofazimine_resistance',
                         'delamanid': 'delamanid_resistance'
}

In [4]:
# mono_resistance_df_filledna = pd.read_csv("../data/processed/mono_resistance_df_filledna.csv").set_index('SampleID')

binarized_final_df = pd.read_csv("../data/processed/binarized_final_df.csv").set_index('SampleID')


binarized_final_df.head()

Unnamed: 0_level_0,rifampicin_resistance,isoniazid_resistance,pyrazinamide_resistance,ethambutol_resistance,streptomycin_resistance,fluoroquinolones_resistance,moxifloxacin_resistance,ofloxacin_resistance,levofloxacin_resistance,ciprofloxacin_resistance,...,NC000962_3.3890339,NC000962_3.3890344,NC000962_3.3890347,NC000962_3.3890356,NC000962_3.3890358,NC000962_3.3890363,NC000962_3.3890368,NC000962_3.3890774,NC000962_3.3890776,NC000962_3.3890781
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR760783,1,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR776661,1,1,1,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR11098556,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR760911,1,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR9224969,1,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
binarized_final_df.columns

Index(['rifampicin_resistance', 'isoniazid_resistance',
       'pyrazinamide_resistance', 'ethambutol_resistance',
       'streptomycin_resistance', 'fluoroquinolones_resistance',
       'moxifloxacin_resistance', 'ofloxacin_resistance',
       'levofloxacin_resistance', 'ciprofloxacin_resistance',
       ...
       'NC000962_3.3890339', 'NC000962_3.3890344', 'NC000962_3.3890347',
       'NC000962_3.3890356', 'NC000962_3.3890358', 'NC000962_3.3890363',
       'NC000962_3.3890368', 'NC000962_3.3890774', 'NC000962_3.3890776',
       'NC000962_3.3890781'],
      dtype='object', length=109705)

In [6]:
binarized_final_df= binarized_final_df.drop(columns=[*renamed_drug_columns_names, *lineage_column_names, 'drtype', 'MDR', 'XDR'], axis= 1)

binarized_final_df.head()

Unnamed: 0_level_0,Resistance_Status,NC000962_3.42,NC000962_3.78,NC000962_3.80,NC000962_3.102,NC000962_3.104,NC000962_3.117,NC000962_3.120,NC000962_3.135,NC000962_3.138,...,NC000962_3.3890339,NC000962_3.3890344,NC000962_3.3890347,NC000962_3.3890356,NC000962_3.3890358,NC000962_3.3890363,NC000962_3.3890368,NC000962_3.3890774,NC000962_3.3890776,NC000962_3.3890781
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR760783,Resistant,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR776661,Resistant,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR11098556,Sensitive,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR760911,Resistant,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR9224969,Resistant,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [17]:
binarized_final_df['Resistance_Status']= binarized_final_df['Resistance_Status'].apply(lambda resistance: 0.0 if resistance == 'Sensitive' else 1.0)
binarized_final_df.head()

Unnamed: 0_level_0,Resistance_Status,NC000962_3.42,NC000962_3.78,NC000962_3.80,NC000962_3.102,NC000962_3.104,NC000962_3.117,NC000962_3.120,NC000962_3.135,NC000962_3.138,...,NC000962_3.3890339,NC000962_3.3890344,NC000962_3.3890347,NC000962_3.3890356,NC000962_3.3890358,NC000962_3.3890363,NC000962_3.3890368,NC000962_3.3890774,NC000962_3.3890776,NC000962_3.3890781
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR760783,1.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR776661,1.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR11098556,0.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR760911,1.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR9224969,1.0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [18]:
X = binarized_final_df.loc[:, binarized_final_df.columns != 'Resistance_Status']
y = binarized_final_df.loc[:, 'Resistance_Status']


In [19]:
X

Unnamed: 0_level_0,NC000962_3.42,NC000962_3.78,NC000962_3.80,NC000962_3.102,NC000962_3.104,NC000962_3.117,NC000962_3.120,NC000962_3.135,NC000962_3.138,NC000962_3.150,...,NC000962_3.3890339,NC000962_3.3890344,NC000962_3.3890347,NC000962_3.3890356,NC000962_3.3890358,NC000962_3.3890363,NC000962_3.3890368,NC000962_3.3890774,NC000962_3.3890776,NC000962_3.3890781
SampleID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ERR760783,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR776661,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR11098556,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR760911,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR9224969,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SRR10851740,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR751560,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
SRR10851694,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
ERR757166,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [20]:
# NOTE: The dtype might be problematic but for now this works!

y.describe().T

count    1321.000000
mean        0.744890
std         0.436088
min         0.000000
25%         0.000000
50%         1.000000
75%         1.000000
max         1.000000
Name: Resistance_Status, dtype: float64

In [21]:
# NOTE: The dtype might be problematic but for now this works!

y

SampleID
ERR760783      1.0
ERR776661      1.0
SRR11098556    0.0
ERR760911      1.0
SRR9224969     1.0
              ... 
SRR10851740    1.0
ERR751560      1.0
SRR10851694    0.0
ERR757166      1.0
ERR751427      1.0
Name: Resistance_Status, Length: 1321, dtype: float64

In [22]:

X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    stratify=y,
                                                    train_size=0.7,
                                                    test_size=0.3,
                                                    random_state=100)


In [23]:
y_train

SampleID
ERR751400     1.0
ERR758379     1.0
SRR9224957    1.0
ERR779858     1.0
ERR760925     1.0
             ... 
ERR760763     1.0
ERR779909     1.0
ERR751399     1.0
SRR9224929    1.0
SRR3732651    1.0
Name: Resistance_Status, Length: 924, dtype: float64

In [24]:
from sklearn.svm import LinearSVC

model_svm= LinearSVC()

model_svm.fit(X_train, y_train)

y_pred= model_svm.predict(X_test)


model_performance_metrics(model_svm, X, X_test, X_train, y, y_test, y_pred, show_feature_importances= False)


Accuracy: 0.836272040302267
Mean Absolute Error: 0.163727959697733
Mean Squared Error: 0.163727959697733
Root Mean Squared Error: 0.40463311740110075
-------------------------------------------------------------------------------- 



In [25]:
from sklearn.ensemble import RandomForestClassifier



model_rf= RandomForestClassifier(n_estimators= 100,
                                  random_state = 100,
                                  max_depth=5,
                                  min_samples_leaf=50,
                                  min_samples_split=50)

model_rf.fit(X_train, y_train)

y_pred= model_rf.predict(X_test)

model_performance_metrics(model_rf, X, X_test, X_train, y, y_test, y_pred)

Accuracy: 0.7455919395465995
Mean Absolute Error: 0.25440806045340053
Mean Squared Error: 0.25440806045340053
Root Mean Squared Error: 0.5043887988976367
-------------------------------------------------------------------------------- 

=== Feature Importances ===
                    importance
NC000962_3.2631932    0.034049
NC000962_3.1637197    0.028205
NC000962_3.1634609    0.025967
NC000962_3.2867532    0.023453
NC000962_3.2866580    0.019393
...                        ...
NC000962_3.1223311    0.000000
NC000962_3.1223295    0.000000
NC000962_3.1223290    0.000000
NC000962_3.1223287    0.000000
NC000962_3.3890781    0.000000

[109678 rows x 1 columns]


In [26]:
from sklearn.ensemble import GradientBoostingClassifier

model_gb= GradientBoostingClassifier(
                                     n_estimators= 100,
                                     random_state = 100,
                                     max_depth=5
                                    )

model_gb.fit(X_train, y_train)

y_pred= model_gb.predict(X_test)


model_performance_metrics(model_gb, X, X_test, X_train, y, y_test, y_pred)

Accuracy: 0.853904282115869
Mean Absolute Error: 0.14609571788413098
Mean Squared Error: 0.14609571788413098
Root Mean Squared Error: 0.3822246955445592
-------------------------------------------------------------------------------- 

=== Feature Importances ===
                    importance
NC000962_3.1637204    0.067002
NC000962_3.2866938    0.043465
NC000962_3.839534     0.036705
NC000962_3.338020     0.033225
NC000962_3.3248075    0.030296
...                        ...
NC000962_3.1223851    0.000000
NC000962_3.1223848    0.000000
NC000962_3.1223845    0.000000
NC000962_3.1223842    0.000000
NC000962_3.1796945    0.000000

[109678 rows x 1 columns]


In [27]:
from sklearn.naive_bayes import GaussianNB
# from sklearn.naive_bayes import BernoulliNB

model_nb= GaussianNB()
# model_nb= BernoulliNB()

model_nb.fit(X_train, y_train)

y_pred= model_nb.predict(X_test)


model_performance_metrics(model_nb, X, X_test, X_train, y, y_test, y_pred, show_feature_importances= False)

Accuracy: 0.7884130982367759
Mean Absolute Error: 0.21158690176322417
Mean Squared Error: 0.21158690176322417
Root Mean Squared Error: 0.4599857625657822
-------------------------------------------------------------------------------- 



In [28]:
from sklearn.neural_network import MLPClassifier

model_mlp = MLPClassifier(
                          solver='lbfgs',
                          alpha=1e-5,
                          hidden_layer_sizes=(5, 2),
                          random_state=1
)



model_mlp.fit(X_train, y_train)

y_pred= model_mlp.predict(X_test)

model_performance_metrics(model_mlp, X, X_test, X_train, y, y_test, y_pred, show_feature_importances= False)

Accuracy: 0.7455919395465995
Mean Absolute Error: 0.25440806045340053
Mean Squared Error: 0.25440806045340053
Root Mean Squared Error: 0.5043887988976367
-------------------------------------------------------------------------------- 



In [29]:
from xgboost import XGBClassifier

model_xgb = XGBClassifier(
                          learning_rate= 0.01,
                          random_state= 1
)



model_xgb.fit(X_train, y_train)

y_pred= model_xgb.predict(X_test)

model_performance_metrics(model_xgb, X, X_test, X_train, y, y_test, y_pred)

Accuracy: 0.8488664987405542
Mean Absolute Error: 0.15113350125944586
Mean Squared Error: 0.15113350125944586
Root Mean Squared Error: 0.3887589243470121
-------------------------------------------------------------------------------- 

=== Feature Importances ===
                    importance
NC000962_3.1637197    0.044943
NC000962_3.3248075    0.039473
NC000962_3.2301075    0.028397
NC000962_3.1637214    0.026396
NC000962_3.332759     0.022807
...                        ...
NC000962_3.1223443    0.000000
NC000962_3.1223434    0.000000
NC000962_3.1223410    0.000000
NC000962_3.1223392    0.000000
NC000962_3.3890781    0.000000

[109678 rows x 1 columns]


In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import LinearSVC
from sklearn.ensemble import StackingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.naive_bayes import GaussianNB
from xgboost import XGBClassifier
from sklearn.ensemble import GradientBoostingClassifier


estimators = [
    ('rf', RandomForestClassifier()),
    ('gb', GradientBoostingClassifier()),
    ('svm', LinearSVC()),
    ('mlp', MLPClassifier()),
    ('nb', GaussianNB()),
    ('xgb', XGBClassifier())
]

model_se = StackingClassifier(
    estimators=estimators,
    final_estimator=RandomForestClassifier()
)

model_se.fit(X_train, y_train)

y_pred = model_se.predict(X_test)

model_performance_metrics(model_se, X, X_test, X_train, y, y_test, y_pred, show_feature_importances= False)

## Grid search for each model