<a href="https://colab.research.google.com/github/castudil/bacteria-multi-label/blob/main/multilabel_bac.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

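Libraries used in this notebook: data handling (pandas/NumPy), the label-powerset SVC model and its metrics (scikit-learn), Bayesian hyper-parameter search (scikit-optimize), and SHAP explanations.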

In [1]:
import os
import itertools

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import LabelEncoder
from sklearn.multioutput import ClassifierChain
from sklearn.metrics import (f1_score, multilabel_confusion_matrix,
                             accuracy_score, hamming_loss, jaccard_score, make_scorer)

from skopt import BayesSearchCV
from skopt.space import Real, Categorical, Integer
from sklearn.calibration import CalibratedClassifierCV

from sklearn.svm import SVC

import shap
import torch

from joblib import dump, load

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Report whether a CUDA GPU is visible to torch.
# NOTE(review): `device` is not referenced elsewhere in the visible notebook.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [3]:
# All data/model paths below are relative to the project root, two directories
# above the notebook's starting working directory (hence the two "..").
# The root is resolved only once per kernel, so re-running this cell is a no-op
# instead of climbing two more levels and breaking every relative path.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(PROJECT_ROOT)

In [4]:
# Training split: binned spectrum columns followed by the antibiotic label columns.
train_file = (
    "data/processed/binned/standard/"
    "train_s_aureus_driams_bin20.csv"
)
train_bac = pd.read_csv(filepath_or_buffer=train_file)
train_bac

Unnamed: 0,2000,2020,2040,2060,2080,2100,2120,2140,2160,2180,...,9860,9880,9900,9920,9940,9960,9980,Oxacillin,Clindamycin,Fusidic acid
0,-0.286217,-0.263798,-0.185108,-0.460499,-0.340950,-0.140792,-0.224080,-0.401850,-0.324227,0.033580,...,-0.523652,-0.499112,-0.496468,-0.558968,-0.476363,-0.511347,-0.534530,0.0,0.0,0.0
1,-0.551211,-0.571190,-0.567694,-0.682123,-0.626766,-0.625056,-0.681753,-0.698166,-0.694874,-0.655262,...,-0.843329,-0.786038,-0.779041,-0.730581,-0.736446,-0.665822,-0.654775,0.0,0.0,0.0
2,-0.286692,-0.298551,-0.273101,-0.149969,-0.227278,-0.219336,-0.295638,-0.411640,-0.279910,-0.341791,...,-0.728511,-0.710503,-0.635119,-0.617323,-0.605996,-0.633157,-0.635633,0.0,0.0,0.0
3,-0.317437,-0.007199,-0.186219,-0.218208,-0.295089,-0.320160,-0.353085,-0.465764,-0.178849,-0.214500,...,-0.487951,-0.449209,-0.286415,-0.273146,-0.305423,-0.280417,-0.299386,0.0,0.0,0.0
4,-0.510705,-0.506801,-0.561659,-0.685221,-0.596548,-0.506514,-0.510323,0.121535,-0.525663,-0.524299,...,1.018196,0.824586,1.432266,1.421755,1.191401,1.256895,1.746369,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2819,0.249296,0.248307,0.264162,0.193118,0.250260,0.321630,0.206245,0.214599,0.265692,0.190006,...,-0.772152,-0.467501,-0.589754,-0.633639,-0.698323,-0.675144,-0.704980,0.0,1.0,0.0
2820,1.785533,1.823061,2.178447,1.698854,1.434265,1.664576,1.464162,1.581346,1.956586,1.645609,...,-0.159264,0.090428,0.101984,0.266643,0.136130,0.155044,0.115017,0.0,1.0,0.0
2821,0.089459,0.270119,0.243205,0.641778,0.398895,0.409032,0.425259,0.431668,0.485349,0.216998,...,-0.945420,-0.887857,-0.932228,-0.931200,-0.937917,-0.937877,-0.928717,0.0,0.0,0.0
2822,-0.604482,-0.637415,-0.639497,-0.736734,-0.680152,-0.666202,-0.695422,-0.827097,-0.728076,-0.665091,...,-0.514248,-0.564228,-0.552775,-0.479549,-0.461669,-0.460116,-0.485656,1.0,1.0,0.0


In [5]:
# Held-out split, same column layout as the training file.
test_file = (
    "data/processed/binned/standard/"
    "test_s_aureus_driams_bin20.csv"
)
test_bac = pd.read_csv(filepath_or_buffer=test_file)
test_bac

Unnamed: 0,2000,2020,2040,2060,2080,2100,2120,2140,2160,2180,...,9860,9880,9900,9920,9940,9960,9980,Oxacillin,Clindamycin,Fusidic acid
0,0.104186,0.110324,0.045273,0.093256,0.033230,-0.017189,0.020650,0.021929,0.172607,0.072579,...,2.509083,2.029358,2.289490,2.349229,2.329039,2.450476,2.464367,0.0,0.0,0.0
1,-0.683329,-0.682614,-0.700279,-0.772893,-0.734363,-0.680305,-0.712202,-0.504519,-0.739549,-0.717521,...,-0.758693,-0.690376,-0.661395,-0.617496,-0.522159,-0.508605,-0.429725,0.0,0.0,0.0
2,-0.126579,-0.080111,-0.088407,-0.004778,-0.107682,-0.076610,-0.099526,-0.182381,-0.138457,-0.201429,...,-0.064662,-0.165242,0.094636,0.190332,0.271641,0.152851,0.115044,0.0,1.0,0.0
3,-0.393258,-0.391113,-0.381737,-0.364848,-0.337233,-0.363708,-0.367530,-0.309146,-0.398401,-0.417081,...,-0.897556,-0.906563,-0.933985,-0.920237,-0.925019,-0.922054,-0.932670,0.0,0.0,1.0
4,0.084063,-0.057955,-0.069780,-0.221597,-0.068780,-0.093904,-0.131015,-0.165511,-0.178512,-0.102924,...,-0.479673,-0.348515,-0.551896,-0.588232,-0.607792,-0.637152,-0.585738,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
702,-0.485712,-0.436828,-0.444883,-0.438499,-0.477204,-0.447962,-0.500632,-0.649616,-0.488410,-0.359514,...,-0.434061,-0.449801,-0.258621,-0.228516,-0.194811,-0.291070,-0.191354,0.0,0.0,0.0
703,-0.363197,-0.367166,-0.352612,-0.385939,-0.376585,-0.375591,-0.336656,-0.224120,-0.396143,-0.400442,...,-0.123281,-0.449327,-0.307500,-0.303360,-0.293913,-0.293342,-0.312114,1.0,0.0,0.0
704,0.238776,0.234833,0.245708,0.542942,0.288722,0.223289,0.272952,0.391740,0.246997,0.011889,...,0.380082,0.359241,0.284242,0.318265,0.313519,0.319468,0.226947,0.0,0.0,1.0
705,-0.480531,-0.494409,-0.504593,-0.593396,-0.539916,-0.530114,-0.595922,-0.720296,-0.577149,-0.533768,...,-0.424826,-0.123627,-0.056155,-0.176721,-0.259546,-0.239056,-0.207350,1.0,0.0,0.0


In [6]:
# Feature columns are the ones whose names consist only of digits (the spectrum
# bins, 2000 ... 9980 in the displayed frames); label columns contain letters and
# are therefore excluded.
train_x = train_bac.filter(regex=r"^[0-9]*\Z")
test_x = test_bac.filter(regex=r"^[0-9]*\Z")

In [7]:
# Every column that is not a spectral feature is treated as an antibiotic label.
antibiotics = train_bac.columns[~train_bac.columns.isin(train_x.columns)]

In [8]:
# Multi-label targets: the antibiotic columns, selected identically from both splits.
# NOTE(review): the LPS cell below re-selects from train_bac/test_bac directly
# rather than reusing these frames.
train_y = train_bac.loc[:, antibiotics]
test_y = test_bac.loc[:, antibiotics]

In [9]:
def _lps_class_frame(labels):
    """Collapse each row of 0/1 label columns into one label-powerset string.

    A row (1.0, 0.0, 1.0) becomes "101". The result is a one-column DataFrame
    named "Class" that keeps the index of ``labels``.
    """
    class_str = labels.astype(int).astype(str).agg("".join, axis=1)
    return pd.DataFrame({"Class": class_str.astype(str)})


train_y_lps = _lps_class_frame(train_bac[antibiotics])
test_y_lps = _lps_class_frame(test_bac[antibiotics])

In [10]:
# Map each label-powerset string to an integer class code. The encoder is fitted
# on training combinations only; sklearn raises ValueError if the test split
# contains a combination never seen in training.
# NOTE(review): this rebinds train_y_lps/test_y_lps from DataFrames to arrays,
# so re-running this cell without re-running the previous one fails.
lc = LabelEncoder()
train_y_lps = lc.fit_transform(train_y_lps.values.ravel())
test_y_lps = lc.transform(test_y_lps.values.ravel())

In [11]:
def lps_to_multilabel_instance(lps_num, encoder=None):
    """Decode one label-powerset class code into its list of 0/1 labels.

    Parameters
    ----------
    lps_num : int
        Integer class code produced by the label encoder.
    encoder : object with ``inverse_transform``, optional
        Fitted encoder used for decoding. Defaults to the notebook-level ``lc``.

    Returns
    -------
    list of int
        One int per character of the decoded class string, e.g. "010" -> [0, 1, 0].
    """
    encoder = lc if encoder is None else encoder
    (label_string,) = encoder.inverse_transform([lps_num])
    return [int(bit) for bit in label_string]


def lps_to_multilabel_list(lps_list, encoder=None):
    """Decode a sequence of label-powerset class codes into 0/1 label rows.

    Decodes all codes with a single ``inverse_transform`` call instead of one
    call per element.

    Parameters
    ----------
    lps_list : iterable of int
        Integer class codes (list, NumPy array, ...).
    encoder : object with ``inverse_transform``, optional
        Fitted encoder used for decoding. Defaults to the notebook-level ``lc``.

    Returns
    -------
    list of list of int
        One 0/1 label list per input code; ``[]`` for empty input.
    """
    encoder = lc if encoder is None else encoder
    codes = list(lps_list)
    if not codes:
        return []
    decoded = encoder.inverse_transform(codes)
    return [[int(bit) for bit in label_string] for label_string in decoded]

In [12]:
def multilabel_f1_wrapper(true, pred, average="weighted"):
    """Mean of per-label F1 scores for a multi-label problem.

    Each label column is scored independently with sklearn's ``f1_score`` using
    ``average``. The per-column scores are then averaged with equal weight.

    Parameters
    ----------
    true, pred : array-like of shape (n_samples, n_labels)
        0/1 label matrices (lists, NumPy arrays or DataFrames, each converted
        independently). A 1-D input is treated as a single label column.
    average : str
        Passed through to ``f1_score`` for every column.

    Returns
    -------
    float
        Unweighted mean of the per-column F1 scores.

    Raises
    ------
    ValueError
        If the inputs are empty or their shapes differ.
    """
    # np.asarray converts lists, ndarrays and DataFrames alike, and each argument
    # is converted on its own (a DataFrame `pred` must not depend on `true`'s type).
    true = np.asarray(true)
    pred = np.asarray(pred)
    if true.ndim == 1:
        true = true.reshape(-1, 1)
    if pred.ndim == 1:
        pred = pred.reshape(-1, 1)
    if true.shape != pred.shape:
        raise ValueError(f"true and pred shapes differ: {true.shape} vs {pred.shape}")
    if true.size == 0:
        raise ValueError("true and pred must contain at least one sample and one label")

    n_labels = true.shape[1]
    per_label = [f1_score(true[:, j], pred[:, j], average=average) for j in range(n_labels)]
    return sum(per_label) / n_labels

In [13]:
def lps_f1_wrapper(true, pred, average="weighted"):
    """F1 scorer for label-powerset predictions.

    Decodes the 1-D class codes in ``true`` and ``pred`` back into 0/1 label
    rows and scores them with ``multilabel_f1_wrapper`` (mean of per-label F1).
    Used as the ``make_scorer`` target of the hyper-parameter search.
    """
    return multilabel_f1_wrapper(
        lps_to_multilabel_list(true),
        lps_to_multilabel_list(pred),
        average=average,
    )

In [14]:
def report(true, pred):
    """Print evaluation metrics for one set of predictions and return the main three.

    Parameters
    ----------
    true, pred : array-like
        Either 1-D label-powerset class codes, which are decoded to 0/1 label
        rows with ``lps_to_multilabel_list``, or 2-D 0/1 label matrices (one
        column per antibiotic). Lists, NumPy arrays and DataFrames are accepted;
        which kind is assumed is decided from ``pred``'s dimensionality.

    Returns
    -------
    tuple
        ``(hamming_loss, accuracy, mean_per_label_weighted_f1)``.
    """
    # np.ndim (unlike pred.shape) also works for plain Python lists.
    if np.ndim(pred) < 2:
        true = lps_to_multilabel_list(true)
        pred = lps_to_multilabel_list(pred)
    # Normalise once so every metric below sees the same plain 2-D arrays.
    true = np.asarray(true)
    pred = np.asarray(pred)

    hl = hamming_loss(true, pred)
    f1w = multilabel_f1_wrapper(true, pred, "weighted")
    acc = accuracy_score(true, pred)  # exact match of all labels for 2-D input

    # "Unweighted" = per-label macro F1 averaged by multilabel_f1_wrapper. The
    # sklearn variants use multi-label averaging over the positive class of each
    # column, which is why their values differ.
    f1u = multilabel_f1_wrapper(true, pred, "macro")
    f1su = f1_score(true, pred, average="macro")
    f1sw = f1_score(true, pred, average="weighted")

    print("Main metrics:")
    print(" Hamming Loss:", hl)
    print(" Accuracy:", acc)
    print(" F1 Score (Weighted):", f1w)
    print("================================================")
    print("Other metrics:")
    print(" F1 Score (Unweighted):", f1u)
    print(" F1 Score (sklearn Unweighted):", f1su)
    print(" F1 Score (sklearn Weighted):", f1sw)
    return hl, acc, f1w

___


In [15]:
# opt_lps = BayesSearchCV(
#     SVC(),
#     {
#         "C": Real(1e-6, 1000, prior="log-uniform"),
#         "kernel": Categorical(["rbf"]),
#         "gamma": Real(1e-6, 1000, prior="log-uniform"),
#     },
#     n_iter=250,
#     cv=5,
#     random_state=0,
#     n_jobs=5,
#     n_points=2,
#     scoring=make_scorer(lps_f1_wrapper),
#     verbose=1
# )
# np.int = int
# opt_lps.fit(train_x, train_y_lps)

# print("Best score:", opt_lps.best_score_)
# print("Best parameter combination found:", opt_lps.best_params_)

In [16]:
# opt_lps.best_params_

In [17]:
# Serialized best estimator of the (commented-out) Bayesian search above.
model_file = (
    "modeling/models/"
    "s_aureus_driams_bin20_svc_standard_lps.joblib"
)

In [18]:
# dump(opt_lps.best_estimator_, model_file) 

In [19]:
# joblib.load unpickles the file, which can execute arbitrary code: only load
# model files produced by this project. Loading under a different scikit-learn
# version than the one used to save it may emit a persistence warning.
model = load(filename=model_file)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [20]:
# Evaluate the loaded label-powerset model on the held-out split; `report`
# decodes 1-D class codes back to per-antibiotic labels before scoring.
pred = model.predict(test_x)
(model_hl, model_acc, model_f1) = report(true=test_y_lps, pred=pred)

Main metrics:
 Hamming Loss: 0.10372465818010372
 Accuracy: 0.7454031117397454
 F1 Score (Weighted): 0.8840744114874974
Other metrics:
 F1 Score (Unweighted): 0.6892160547247682
 F1 Score (sklearn Unweighted): 0.43745935189614227
 F1 Score (sklearn Weighted): 0.5185201027817817


In [21]:
# KernelExplainer below needs predict_proba, so refit the loaded SVC with
# probability estimates enabled. set_params is the supported way to change a
# hyper-parameter. Per the sklearn docs, random_state only seeds the internal CV
# behind the probability estimates; fixing it makes predict_proba, and hence the
# SHAP values, reproducible (0 matches the search's random_state above).
# NOTE: this replaces `model` in place; the metrics printed above were computed
# with the model as loaded from disk, before this refit.
model.set_params(probability=True, random_state=0)
model.fit(train_x, train_y_lps);

In [22]:
# KernelExplainer is model-agnostic but very slow here (about 173 s per explained
# row in the recorded run, i.e. hours for 100 rows), so the SHAP values are
# cached on disk and reused on re-runs.
# NOTE: delete the cache file whenever the model, data or sample sizes change,
# otherwise stale values are loaded.
N_SHAP_BACKGROUND = 100  # training rows summarising the background distribution
N_SHAP_EXPLAIN = 100     # test rows to explain
shap_cache_file = "modeling/models/s_aureus_driams_bin20_svc_standard_lps_shap.joblib"

explainer = shap.KernelExplainer(model.predict_proba, shap.sample(train_x, N_SHAP_BACKGROUND))
# Keep the explained rows: plots must pair each SHAP row with these features,
# not with the full test_x.
test_x_shap = shap.sample(test_x, N_SHAP_EXPLAIN)

if os.path.exists(shap_cache_file):
    shap_values = load(shap_cache_file)
else:
    shap_values = explainer.shap_values(test_x_shap)
    dump(shap_values, shap_cache_file)

 31%|███       | 31/100 [1:35:57<3:19:09, 173.18s/it]

In [None]:
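# shap.initjs()
# NOTE: shap_values cover only the 100 rows drawn by shap.sample(test_x, 100), so the
# feature matrix given to force_plot must be those same sampled rows, not all of test_x.
# shap.force_plot(explainer.expected_value[0], shap_values[..., 0], test_x)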

: 

: 