In [94]:
import numpy as np
import os
import pandas as pd
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict
from ipywidgets import interact

## Load baseline and new model results

In [95]:
# Root folder holding one sub-directory of saved evaluation artefacts per model type.
data_path = "./lm_experiments/DRIAMS_A/"

# Each mapping is: model type -> {species/drug setting -> loaded artefact}.
model_metrics = defaultdict(dict)  # metrics tables (DataFrames read from CSV)
model_roc = defaultdict(dict)      # unpickled ROC objects, stored as-is
model_pr = defaultdict(dict)       # unpickled PR objects, stored as-is


def _setting_key(filename, tag):
    """Return the setting name encoded in ``filename``, or None if it is not a ``tag`` file.

    A file matches when its name without extension ends in ``tag``; the key is
    everything before the tag (``"A_B_roc.pkl"`` with tag ``"_roc"`` gives ``"A_B"``).

    NOTE(review): assumes artefacts are named ``<setting>_metrics.<ext>``,
    ``<setting>_roc.<ext>`` and ``<setting>_pr.<ext>`` with a three-letter
    extension, which is what the previous fixed-length slices (12 / 8 / 7
    characters) implied. Confirm against the script that writes these files.
    """
    stem, _ext = os.path.splitext(filename)
    if len(stem) <= len(tag) or not stem.endswith(tag):
        return None
    return stem[: -len(tag)]


for model_type in ["GBM", "MLP", "LogisticRegression"]:
    model_dir = os.path.join(data_path, model_type)

    # List the directory once; sort because os.listdir order is arbitrary and the
    # insertion order of these dicts decides the order of the dropdown options.
    for filename in sorted(os.listdir(model_dir)):
        file_path = os.path.join(model_dir, filename)
        if not os.path.isfile(file_path):
            continue

        # Matching on the exact suffix (instead of a substring anywhere in the
        # name) keeps the file filter and the derived key in agreement, so a
        # stray file can no longer be parsed with the wrong loader or stored
        # under a mangled key.
        metrics_key = _setting_key(filename, "_metrics")
        if metrics_key is not None:
            model_metrics[model_type][metrics_key] = pd.read_csv(file_path, index_col=0)
            continue

        roc_key = _setting_key(filename, "_roc")
        pr_key = _setting_key(filename, "_pr")
        if roc_key is None and pr_key is None:
            continue

        # SECURITY: pickle.load can execute arbitrary code. Only point data_path
        # at artefacts produced by a trusted pipeline.
        with open(file_path, "rb") as handle:
            curve = pickle.load(handle)

        if roc_key is not None:
            model_roc[model_type][roc_key] = curve
        else:
            model_pr[model_type][pr_key] = curve

## Plot performance across models

In [117]:
# Columns that plot_performance selects from every per-model metrics table.
# The list order is the column order passed on to the melted plotting frame.
metrics = [
    "precision", "recall", "specificity",
    "accuracy", "balanced_accuracy",
    "f1", "mcc",
    "roc_auc", "auprc",
]

In [118]:
@interact()
def plot_performance(
    species_drug_setting=model_metrics["GBM"].keys()
):
    """Draw a grouped bar chart comparing the model types on one species/drug setting.

    Parameters
    ----------
    species_drug_setting : str
        Key into ``model_metrics[<model type>]``. The default is the key view of
        the GBM tables, which ``interact`` turns into a dropdown; the function
        body receives the selected key.

    Raises
    ------
    KeyError
        If one of the model types has no table for the selected setting, or a
        table lacks a column named in the notebook-level ``metrics`` list.

    Notes
    -----
    Reads the notebook-level ``model_metrics`` and ``metrics`` at call time and
    leaves both untouched.
    """
    # ``assign`` returns a tagged copy, so the cached tables in ``model_metrics``
    # do not silently gain a "model" column on every widget callback.
    tagged_tables = [
        model_metrics[model_type][species_drug_setting].assign(model=model_type)
        for model_type in ["GBM", "MLP", "LogisticRegression"]
    ]

    # Concatenate once instead of growing a frame from an empty DataFrame.
    metrics_table = pd.concat(tagged_tables).loc[:, metrics + ["model"]]

    # The context must be set before the axes exist so its font sizes apply to them.
    sns.set_context("talk")
    fig, ax = plt.subplots(figsize=(8, 5))

    sns.barplot(
        data=metrics_table.melt(id_vars="model"),
        x="variable",
        y="value",
        hue="model",
        palette="Set2",
        ax=ax,
    )

    # Name the selected setting so the figure is readable without the dropdown.
    ax.set_title(species_drug_setting)
    ax.set_xlabel("")
    ax.set_ylabel("metric value")
    ax.tick_params(axis="x", labelrotation=45)

    plt.show()


interactive(children=(Dropdown(description='species_drug_setting', options=('Staphylococcus_aureus_Oxacillin',…

<Figure size 576x360 with 0 Axes>