In [1]:
import warnings
# Suppress every Python warning for the rest of the session (no category or
# module restriction is given, so the "ignore" action applies to all of them).
warnings.filterwarnings("ignore")

# Visualization of model performance - Multimodal AMR

In [2]:
import os
import json
import pandas as pd
import numpy as np

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob

from notebook_utils import *

In [None]:
# Root folder for the result files this notebook reads and writes.
# The directory is created if it does not exist yet; an existing one is left as is.
results_path = "../results/"
os.makedirs(results_path, exist_ok=True)

In [4]:
# Read the base-model test predictions (seed 0) for every drug representation.
# Entries are keyed as "DRIAMS-<dataset>_<method>_<drug_representation>".
results_dict = {}
for dataset in ["A"]:
    for method in ["spectra"]:
        for drug_representation in ["Morgan_1024", "Mole_BERT"]:
            test_set_csv = f"{results_path}/Stratified_by_species/{drug_representation}/ConformalAMR/Conformal_DRIAMS-{dataset}_{method}_results/test_set_seed0.csv"
            results = pd.read_csv(test_set_csv)
            results_dict[f"DRIAMS-{dataset}_{method}_{drug_representation}"] = results

In [5]:
# Merge the predictions of every fine-tuned species-drug model into one table
# per drug representation, keep the tables in `finetuned_results_dict`, and
# write them as CSV files next to the fine-tuned models.

# Name of the split sub-folder under `results_path`. No earlier cell defines it,
# so it is set here to the same folder that the base-results paths hard-code.
split = "Stratified_by_species"

finetuned_results_dict = {}
for dataset in ["A"]:
    for drug_representation in ["Morgan_1024", "Mole_BERT"]:
        finetuned_results = []
        finetuned_results_train = []

        finetuned_models_path = f"{results_path}/{split}/{drug_representation}/ConformalAMR/Conformal_DRIAMS-{dataset}_spectra_results/finetuned_models"
        finetuned_path = (
            f"{finetuned_models_path}/ConformalAMR/FinetuningResMLP/finetuning"
        )
        # os.listdir returns entries in arbitrary order; sorting keeps the row
        # order of the merged tables identical from run to run.
        for species_drug_combination in sorted(os.listdir(finetuned_path)):
            predictions_path = f"{finetuned_path}/{species_drug_combination}/predictions"
            try:
                finetuned_results.append(
                    pd.read_csv(f"{predictions_path}/split_None.csv")
                )
                # Training-set predictions are only collected for Mole_BERT.
                if drug_representation == "Mole_BERT":
                    finetuned_results_train.append(
                        pd.read_csv(f"{predictions_path}/split_None_train.csv")
                    )
            except (FileNotFoundError, NotADirectoryError):
                # Entry without saved predictions (or a stray file in the
                # folder): skip it and keep going.
                continue

        # Use the same prediction column name as the base results tables
        # (presumably what the notebook_utils helpers expect).
        finetuned_results = pd.concat(finetuned_results).rename(
            columns={"predicted_proba": "Predictions"}
        )
        finetuned_results_dict[f"DRIAMS-{dataset}_spectra_{drug_representation}"] = (
            finetuned_results
        )
        finetuned_results.to_csv(
            f"{finetuned_models_path}/finetuned_results_test.csv",
            index=False,
        )

        if drug_representation == "Mole_BERT":  # Update training set predictions
            finetuned_results_train = pd.concat(finetuned_results_train).rename(
                columns={"predicted_proba": "Predictions"}
            )
            finetuned_results_dict[
                f"DRIAMS-{dataset}_spectra_train_{drug_representation}"
            ] = finetuned_results_train
            finetuned_results_train.to_csv(
                f"{finetuned_models_path}/finetuned_results_train.csv",
                index=False,
            )

In [6]:
# Define relevant drugs
# Drugs evaluated for "Escherichia coli" in the tables and curves below.
coli_drugs = [
    "Ertapenem",
    "Amoxicillin-Clavulanic acid",
    "Piperacillin-Tazobactam",
    "Ceftriaxone",
    "Ceftazidime",
    "Cefepime",
    "Ciprofloxacin",
    "Levofloxacin",
]
# Drugs evaluated for "Klebsiella pneumoniae" below: the E. coli list plus
# Imipenem and Meropenem.
pneumoniae_drugs = [
    "Ertapenem",
    "Imipenem",
    "Meropenem",
    "Amoxicillin-Clavulanic acid",
    "Piperacillin-Tazobactam",
    "Ceftriaxone",
    "Ceftazidime",
    "Cefepime",
    "Ciprofloxacin",
    "Levofloxacin",
]

### Write base and fine-tuned table comparing performance per species and drug with baselines

In [8]:
from collections import defaultdict


def get_performance_table(results_dicts, species, drugs, bootstrap=1000):
    """Tabulate AUROC and AUPRC per drug for each table of predictions.

    Parameters
    ----------
    results_dicts : dict
        Maps a row label to a prediction table with at least a ``species``
        and a ``drug`` column.
    species : str
        Only rows whose ``species`` equals this value are scored.
    drugs : iterable of str
        Drugs to score; each one becomes a column group in the output.
    bootstrap : int, default 1000
        Forwarded to ``get_metrics`` as ``bootstraps``.

    Returns
    -------
    pandas.DataFrame
        Rows are indexed by ``(label, "AUROC" | "AUPRC")`` and columns by
        ``(drug, "Median" | "IC95_low" | "IC95_high")``. When ``get_metrics``
        raises ``IndexError`` or ``ValueError`` for a drug, its entries are NaN.
    """
    metric_names = ["AUROC", "AUPRC"]
    stat_names = ["Median", "IC95_low", "IC95_high"]

    # First pass: one small (metric x statistic) frame per label and drug.
    tables_by_label = {}
    for label, predictions in results_dicts.items():
        species_rows = predictions.loc[predictions["species"] == species]
        per_drug = tables_by_label.setdefault(label, {})
        for drug in drugs:
            try:
                scores = get_metrics(
                    species_rows.loc[species_rows.drug == drug],
                    species=species,
                    bootstraps=bootstrap,
                )
                summary = pd.DataFrame(
                    {
                        "Median": [scores["roc_auc"], scores["average_precision"]],
                        "IC95_low": [
                            scores["roc_auc_low"],
                            scores["average_precision_low"],
                        ],
                        "IC95_high": [
                            scores["roc_auc_high"],
                            scores["average_precision_high"],
                        ],
                    },
                    index=metric_names,
                )
            except (IndexError, ValueError):
                # Scoring failed for this drug: keep the slot, filled with NaN.
                summary = pd.DataFrame(np.nan, index=metric_names, columns=stat_names)

            per_drug[drug] = summary

    # Second pass: put the drugs side by side, then stack the labels.
    stacked = []
    for label in results_dicts.keys():
        row_index = pd.MultiIndex.from_product([[label], metric_names])
        wide = pd.concat(tables_by_label[label], axis=1)
        stacked.append(wide.set_index(row_index))

    return pd.concat(stacked, axis=0)

In [9]:
# E. coli performance comparison
# Both tables are scored with the same species, drug list and bootstrap count.
coli_table_kwargs = dict(
    species="Escherichia coli",
    drugs=coli_drugs,
    bootstrap=1000,
)
coli_base_df = get_performance_table(
    {
        "Morgan-1024": results_dict["DRIAMS-A_spectra_Morgan_1024"],
        "Mole-BERT": results_dict["DRIAMS-A_spectra_Mole_BERT"],
    },
    **coli_table_kwargs,
)
coli_finetuned_df = get_performance_table(
    {
        "Morgan-1024": finetuned_results_dict["DRIAMS-A_spectra_Morgan_1024"],
        "Mole-BERT": finetuned_results_dict["DRIAMS-A_spectra_Mole_BERT"],
    },
    **coli_table_kwargs,
)

In [10]:
# K. pneumoniae performance comparison
# Both tables are scored with the same species, drug list and bootstrap count.
pneumoniae_table_kwargs = dict(
    species="Klebsiella pneumoniae",
    drugs=pneumoniae_drugs,
    bootstrap=1000,
)
pneumoniae_base_df = get_performance_table(
    {
        "Morgan-1024": results_dict["DRIAMS-A_spectra_Morgan_1024"],
        "Mole-BERT": results_dict["DRIAMS-A_spectra_Mole_BERT"],
    },
    **pneumoniae_table_kwargs,
)
pneumoniae_finetuned_df = get_performance_table(
    {
        "Morgan-1024": finetuned_results_dict["DRIAMS-A_spectra_Morgan_1024"],
        "Mole-BERT": finetuned_results_dict["DRIAMS-A_spectra_Mole_BERT"],
    },
    **pneumoniae_table_kwargs,
)

In [None]:
# AUROC of the base and fine-tuned models on E. coli, one row per
# (drug, statistic). The base table is `coli_base_df`, computed above.
# Both tables use the same model labels, so the outer "Base" / "Fine-tuned"
# keys are needed to tell their columns apart after transposing.
pd.concat({"Base": coli_base_df, "Fine-tuned": coli_finetuned_df}).xs(
    "AUROC", level=2
).round(3).T.rename(
    columns={"Mole-BERT": "ResMLP-GNN", "Morgan-1024": "ResMLP"}, level=1
)

In [None]:
# AUROC of the base and fine-tuned models on K. pneumoniae, one row per
# (drug, statistic). The base table is `pneumoniae_base_df`, computed above.
# Both tables use the same model labels, so the outer "Base" / "Fine-tuned"
# keys are needed to tell their columns apart after transposing.
pd.concat(
    {"Base": pneumoniae_base_df, "Fine-tuned": pneumoniae_finetuned_df}
).xs("AUROC", level=2).round(3).T.rename(
    columns={"Mole-BERT": "ResMLP-GNN", "Morgan-1024": "ResMLP"}, level=1
)

### Plot base and fine-tuned performance on relevant species-drug combinations

In [None]:
# Select which fine-tuned prediction table is plotted.
dataset = "A"
method = "spectra"
drug_representation = "Mole_BERT"

# Both figures read the same entry of `finetuned_results_dict`.
finetuned_key = f"DRIAMS-{dataset}_{method}_{drug_representation}"

# E. coli curves, saved as a PDF under `results_path`.
plot_curves(
    results_dict={"DRIAMS-A_spectra": finetuned_results_dict[finetuned_key]},
    species="Escherichia coli",
    drugs=coli_drugs,
    save=f"{results_path}/finetuned_curves_ecoli.pdf",
)
# K. pneumoniae curves, saved as a PDF under `results_path`.
plot_curves(
    results_dict={"DRIAMS-A_spectra": finetuned_results_dict[finetuned_key]},
    species="Klebsiella pneumoniae",
    drugs=pneumoniae_drugs,
    save=f"{results_path}/finetuned_curves_kpneu.pdf",
)