<a href="https://colab.research.google.com/github/castudil/bacteria-multi-label/blob/main/multilabel_bac.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Libraries used

In [1]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.multioutput import ClassifierChain
from sklearn.metrics import (f1_score, multilabel_confusion_matrix,
                             accuracy_score, hamming_loss, jaccard_score, make_scorer)

from skopt import BayesSearchCV
from skopt.space import Real, Categorical, Integer

import xgboost as xgb

from joblib import dump, load
import joblib

In [2]:
# Move the working directory one level up; the relative "data/..." paths in the
# following cells are presumably meant to resolve from there.
# NOTE(review): not idempotent — every re-run of this cell climbs another level.
os.chdir(os.pardir)

In [3]:
# Training split: one row per sample, read from the path below.
train_file = "data/processed/binned/train_s_aureus_driams_bin5.csv"
train_bac = pd.read_csv(filepath_or_buffer=train_file)
train_bac

Unnamed: 0,2000,2020,2040,2060,2080,2100,2120,2140,2160,2180,...,9860,9880,9900,9920,9940,9960,9980,Oxacillin,Clindamycin,Fusidic acid
0,0.021271,0.021774,0.025388,0.021310,0.022024,0.029203,0.028496,0.030084,0.024553,0.040691,...,0.037734,0.034542,0.031264,0.027489,0.035240,0.035080,0.034169,0.0,0.0,0.0
1,0.008401,0.007471,0.007531,0.007448,0.007080,0.006454,0.006307,0.014959,0.006481,0.005659,...,0.015979,0.015110,0.014384,0.016972,0.018075,0.024027,0.025398,0.0,0.0,0.0
2,0.021248,0.020157,0.021281,0.040733,0.027967,0.025514,0.025027,0.029584,0.026714,0.021601,...,0.023793,0.020226,0.022981,0.023913,0.026685,0.026365,0.026794,0.0,0.0,0.0
3,0.019755,0.033714,0.025336,0.036465,0.024422,0.020777,0.022242,0.026822,0.031641,0.028074,...,0.040163,0.037922,0.043811,0.045006,0.046522,0.051603,0.051320,0.0,0.0,0.0
4,0.010368,0.010467,0.007812,0.007254,0.008660,0.012023,0.014619,0.056800,0.014732,0.012319,...,0.142658,0.124189,0.146476,0.148880,0.145308,0.161595,0.200539,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2819,0.047280,0.045602,0.046357,0.062193,0.052936,0.050927,0.049359,0.061550,0.053316,0.048646,...,0.020823,0.036683,0.025691,0.022913,0.020591,0.023360,0.021736,0.0,1.0,0.0
2820,0.121892,0.118875,0.135706,0.156375,0.114842,0.114014,0.110345,0.131316,0.135759,0.122672,...,0.062531,0.074468,0.067012,0.078088,0.075663,0.082759,0.081547,0.0,1.0,0.0
2821,0.039517,0.046617,0.045379,0.090256,0.060707,0.055033,0.059977,0.072631,0.064026,0.050019,...,0.009032,0.008215,0.005234,0.004677,0.004779,0.004562,0.005416,0.0,0.0,0.0
2822,0.005814,0.004390,0.004179,0.004032,0.004289,0.004521,0.005645,0.008377,0.004863,0.005159,...,0.038374,0.030132,0.027900,0.032357,0.036210,0.038745,0.037734,1.0,1.0,0.0


In [4]:
# Test split: same layout as the training file, read from the path below.
test_file = "data/processed/binned/test_s_aureus_driams_bin5.csv"
test_bac = pd.read_csv(filepath_or_buffer=test_file)
test_bac

Unnamed: 0,2000,2020,2040,2060,2080,2100,2120,2140,2160,2180,...,9860,9880,9900,9920,9940,9960,9980,Oxacillin,Clindamycin,Fusidic acid
0,0.040233,0.039182,0.036141,0.055947,0.041588,0.035010,0.040361,0.051716,0.048777,0.042674,...,0.244114,0.205781,0.197681,0.205721,0.220389,0.246994,0.252911,0.0,0.0,0.0
1,0.001984,0.002287,0.001342,0.001770,0.001454,0.003859,0.004831,0.024843,0.004303,0.002493,...,0.021739,0.021589,0.021412,0.023902,0.032218,0.035276,0.041813,0.0,0.0,0.0
2,0.029025,0.030321,0.029901,0.049815,0.034221,0.032219,0.034535,0.041287,0.033611,0.028739,...,0.068968,0.057153,0.066573,0.073411,0.084606,0.082602,0.081549,0.0,1.0,0.0
3,0.016073,0.015850,0.016210,0.027293,0.022218,0.018732,0.021542,0.034816,0.020937,0.017772,...,0.012289,0.006948,0.005129,0.005349,0.005630,0.005694,0.005128,0.0,0.0,1.0
4,0.039255,0.031352,0.030771,0.036253,0.036255,0.031406,0.033008,0.042148,0.031658,0.033749,...,0.040727,0.044741,0.027953,0.025696,0.026566,0.026079,0.030433,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
702,0.011582,0.013723,0.013263,0.022686,0.014900,0.014774,0.015089,0.017437,0.016548,0.020700,...,0.043830,0.037882,0.045471,0.047741,0.053822,0.050840,0.059200,0.0,0.0,0.0
703,0.017533,0.016964,0.017570,0.025974,0.020161,0.018173,0.023038,0.039156,0.021047,0.018618,...,0.064979,0.037914,0.042551,0.043155,0.047281,0.050678,0.050392,1.0,0.0,0.0
704,0.046769,0.044975,0.045496,0.084074,0.054947,0.046307,0.052593,0.070593,0.052405,0.039588,...,0.099234,0.092673,0.077899,0.081252,0.087370,0.094524,0.089711,0.0,0.0,1.0
705,0.011834,0.011044,0.010476,0.012997,0.011621,0.010914,0.010469,0.013829,0.012221,0.011838,...,0.044459,0.059972,0.057565,0.050916,0.049549,0.054562,0.058033,1.0,0.0,0.0


In [5]:
def _digit_named_columns(frame):
    """Return ``frame`` without any column whose name contains a character other than 0-9."""
    non_numeric_names = list(frame.filter(regex='[^0-9]'))
    return frame[frame.columns.drop(non_numeric_names)]


# Feature matrices: only the purely numeric column names are kept.
train_x = _digit_named_columns(train_bac)
test_x = _digit_named_columns(test_bac)

In [6]:
# Every column that is not a feature column is treated as a label column.
feature_names = train_x.columns
antibiotics = train_bac.columns.drop(feature_names)

In [7]:
# Label matrices for both splits, restricted to the same label columns.
train_y, test_y = (split[antibiotics] for split in (train_bac, test_bac))

In [8]:
def multilabel_f1_wrapper(true, pred, average="weighted"):
    """Average the per-label F1 score over every label column.

    ``f1_score`` is evaluated independently on each column of the label
    matrices and the resulting column scores are averaged with equal weight.

    Parameters
    ----------
    true : list, numpy.ndarray or pandas.DataFrame
        Ground-truth labels, two-dimensional: one row per sample, one column
        per label.
    pred : list, numpy.ndarray or pandas.DataFrame
        Predicted labels with the same layout as ``true``.
    average : str, default "weighted"
        Forwarded unchanged to ``f1_score`` for every column.

    Returns
    -------
    float
        Mean of the per-column F1 scores.
    """
    # np.asarray covers lists, DataFrames and arrays uniformly. The earlier
    # version re-tested `true` where it meant `pred`, so predictions passed as
    # a DataFrame were never converted and broke on positional slicing below.
    true = np.asarray(true)
    pred = np.asarray(pred)

    n_labels = true.shape[1]
    column_scores = [
        f1_score(true[:, column], pred[:, column], average=average)
        for column in range(n_labels)
    ]
    return sum(column_scores) / n_labels

In [9]:
def report(true, pred):
    """Print a summary of multilabel metrics and return the three main ones.

    Parameters
    ----------
    true : array-like
        Ground-truth label matrix.
    pred : array-like
        Predicted label matrix.

    Returns
    -------
    tuple
        ``(hamming loss, accuracy, weighted F1 from multilabel_f1_wrapper)``.
    """
    hamming = hamming_loss(true, pred)
    wrapper_weighted = multilabel_f1_wrapper(true, pred, "weighted")
    exact_match = accuracy_score(true, pred)

    wrapper_macro = multilabel_f1_wrapper(true, pred, "macro")
    sklearn_macro = f1_score(true, pred, average="macro")
    sklearn_weighted = f1_score(true, pred, average="weighted")

    main_rows = (
        (" Hamming Loss:", hamming),
        (" Accuracy:", exact_match),
        (" F1 Score (Weighted):", wrapper_weighted),
    )
    other_rows = (
        (" F1 Score (Unweighted):", wrapper_macro),
        (" F1 Score (sklearn Unweighted):", sklearn_macro),
        (" F1 Score (sklearn Weighted):", sklearn_weighted),
    )

    print("Main metrics:")
    for caption, value in main_rows:
        print(caption, value)
    print("================================================")
    print("Other metrics:")
    for caption, value in other_rows:
        print(caption, value)

    # The returned triple is exactly the "main metrics" section, in print order.
    return tuple(value for _, value in main_rows)

___


In [10]:
# Estimator being tuned: a classifier chain built on a default XGBoost classifier.
chain = ClassifierChain(xgb.XGBClassifier(), random_state=0)

# Search space; the "base_estimator__" prefix routes each entry to the XGBoost
# classifier wrapped by the chain.
search_space = {
    "base_estimator__objective": Categorical(["binary:logistic"]),
    "base_estimator__max_depth": Integer(1, 10),
    "base_estimator__min_child_weight": Real(1e-6, 10, prior="log-uniform"),
    "base_estimator__max_delta_step": Real(1e-6, 10, prior="log-uniform"),
    "base_estimator__subsample": Real(1e-6, 1, prior="log-uniform"),
    "base_estimator__tree_method": Categorical(["exact", "approx", "hist"]),
    "base_estimator__scale_pos_weight": Real(1e-6, 10, prior="log-uniform"),
    "base_estimator__gamma": Real(1e-6, 10, prior="log-uniform"),
    "base_estimator__eta": Real(1e-6, 1, prior="log-uniform"),
}

# Candidates are ranked by the column-averaged F1 defined earlier in the notebook.
f1_scorer = make_scorer(multilabel_f1_wrapper)

bayesopt = BayesSearchCV(
    estimator=chain,
    search_spaces=search_space,
    scoring=f1_scorer,
    n_iter=250,
    n_points=2,
    cv=5,
    n_jobs=10,
    random_state=0,
    verbose=1,
)

bayesopt.fit(train_x, train_y)

Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates, totalling 10 fits
Fitting 5 folds for each of 2 candidates

In [11]:
# Row of cv_results_ behind best_params_ / best_score_. Reading best_index_
# removes the hard-coded iteration count and the float-equality scan, and
# guarantees that the split scores printed below belong to the reported
# parameter combination even when several candidates tie on the mean score.
best_iteration = bayesopt.best_index_
print("Best iteration:", best_iteration)
print("Split scores:")
# n_splits_ is the number of CV folds actually used by the fitted search.
for i in range(bayesopt.n_splits_):
    print("", i, bayesopt.cv_results_["split"+str(i)+"_test_score"][best_iteration])

print("Mean score:", bayesopt.best_score_)
print("Best parameter combination found:", bayesopt.best_params_)

Best iteration: 242
Split scores:
 0 0.8485167786564469
 1 0.8616573057610184
 2 0.8509100794256234
 3 0.8552151198095618
 4 0.857133845686513
Mean score: 0.8546866258678326
Best parameter combination found: OrderedDict([('base_estimator__eta', 1.0), ('base_estimator__gamma', 1e-06), ('base_estimator__max_delta_step', 0.7910985489927498), ('base_estimator__max_depth', 10), ('base_estimator__min_child_weight', 0.07488040746686085), ('base_estimator__objective', 'binary:logistic'), ('base_estimator__scale_pos_weight', 1.1587577248521648), ('base_estimator__subsample', 1.0), ('base_estimator__tree_method', 'exact')])


In [15]:
# File name (relative to the current working directory) that the tuned model is
# written to and later read back from.
model_file = "xgb_s_aureus_raw_bin5.joblib"

In [16]:
# Persist the best estimator found by the search to `model_file`.
joblib.dump(value=bayesopt.best_estimator_, filename=model_file)

['xgb_s_aureus_raw_bin20.joblib']

In [18]:
# Read the saved estimator back from disk.
# NOTE: joblib files are pickles — only load files you created or otherwise trust.
model = joblib.load(filename=model_file)

In [None]:
# Fit on the training split, predict the test split, then print and keep the
# three main metrics returned by report().
pred = model.fit(train_x, train_y).predict(test_x)
model_hl, model_acc, model_f1 = report(test_y, pred)

In [None]:
# One confusion-matrix heatmap per label column, laid out side by side.
n_labels = len(antibiotics)

# squeeze=False always yields a 2-D array of Axes; without it a single label
# would give back a bare Axes object and the indexing below would fail.
# ravel() restores the flat, per-label indexing used in the loop.
fig, axes = plt.subplots(1, n_labels, figsize=(n_labels*5, 5), squeeze=False)
axes = axes.ravel()
fig.supxlabel("Predicted Label")
fig.supylabel("True Label")

# One 2x2 matrix per label. The variable name is kept as-is in case later cells
# refer to it, although the model evaluated here is the XGBoost chain.
cm_svm_c = multilabel_confusion_matrix(test_y, (pred > 0.5))

for i, name in enumerate(antibiotics):
    ax = sns.heatmap(ax=axes[i], data=cm_svm_c[i], annot=True, fmt='d', cbar=None, cmap="Blues", xticklabels=["S", "R"], yticklabels=["S", "R"])
    ax.set(title=name)

In [None]:
# Predicted probabilities for the test split; the cells below index this array
# as [sample, label].
proba = model.predict_proba(X=test_x)
proba

In [None]:
def _mean_proba_by_outcome(label_proba, label_truth, threshold=0.5):
    """Mean predicted probability within each confusion-matrix outcome of one label.

    Parameters
    ----------
    label_proba : array-like
        Predicted probabilities for a single label, one entry per sample.
    label_truth : array-like
        True values for the same label (1 for positive, 0 for negative).
    threshold : float, default 0.5
        A sample is predicted positive when its probability is strictly above
        this value.

    Returns
    -------
    dict
        Keys "TP", "TN", "FP", "FN"; each value is the mean probability of the
        samples with that outcome, or None when no sample has it.
    """
    label_proba = np.asarray(label_proba, dtype=float)
    label_truth = np.asarray(label_truth)
    predicted_positive = label_proba > threshold

    # A positive prediction is a TP only when the truth equals 1 and an FP
    # otherwise; a negative prediction is a TN only when the truth equals 0 and
    # an FN otherwise. Truth values other than 0/1 therefore count as errors.
    outcome_masks = {
        "TP": predicted_positive & (label_truth == 1),
        "TN": ~predicted_positive & (label_truth == 0),
        "FP": predicted_positive & (label_truth != 1),
        "FN": ~predicted_positive & (label_truth != 0),
    }
    return {
        outcome: (label_proba[mask].mean() if mask.any() else None)
        for outcome, mask in outcome_masks.items()
    }


for antibiotic, antibiotic_name in enumerate(antibiotics):
    outcome_means = _mean_proba_by_outcome(proba[:, antibiotic], test_y.iloc[:, antibiotic])
    print("Results for antibiotic", antibiotic_name)
    for outcome in ("TP", "TN", "FP", "FN"):
        # An empty outcome prints as "None", matching the other lines' layout.
        print(" Mean " + outcome + ":", outcome_means[outcome])
