In [1]:
import pandas as pd
import numpy as np
import os
from os.path import join, exists
import json
import matplotlib.pyplot as plt
import seaborn as sns


from sklearn.metrics import roc_curve, auc, precision_recall_curve


from sklearn.metrics import (
    matthews_corrcoef,
    balanced_accuracy_score,
    f1_score,
    average_precision_score,
)
from sklearn.metrics import precision_score, recall_score

In [4]:
# template = "embdim512_DRIAMS-{}_{}_sp{}_results"
template = "rand_DRIAMS-B_{}_metrics"

dataset = "B"
splits = ["random", "partitioned"]


In [7]:
metrics_df = []

# for dset in datasets:
for sp in splits:
    for i in range(10):
        try:
            with open(join("outputs/RandomClassifier", template.format(sp), f"test_metrics_{i}.json"), "r") as f:
                met = json.load(f)
            met["dataset"] = dataset
            met["split"] = sp
            met["seed"]=i
            # met["species_embedding_dim"]=dim
            metrics_df.append(met)
        except:
            continue
    
metrics_df = pd.DataFrame(metrics_df)
metrics_df = metrics_df.drop(["seed"], axis=1)
metrics_df

Unnamed: 0,mcc,balanced_accuracy,f1,AUPRC,precision,recall,dataset,split
0,0.00075,0.500504,0.249185,-1,0.166149,0.498138,B,random
1,0.007405,0.504977,0.253247,-1,0.168623,0.50838,B,random
2,0.003111,0.502091,0.250755,-1,0.167028,0.502793,B,random
3,-0.00657,0.495584,0.245541,-1,0.163429,0.493482,B,random
4,-0.020199,0.486429,0.23599,-1,0.158162,0.464618,B,random
5,-0.011289,0.492412,0.242495,-1,0.161651,0.485102,B,random
6,-0.003895,0.497383,0.248405,-1,0.164454,0.507449,B,random
7,0.001203,0.500808,0.250517,-1,0.166311,0.507449,B,random
8,-0.012813,0.491389,0.24162,-1,0.16108,0.48324,B,random
9,0.00587,0.503945,0.252038,-1,0.168065,0.503724,B,random


In [8]:
mean_df = metrics_df.groupby(["split", "dataset"]).mean()
mean_df.columns = [c+"_average" for c in mean_df.columns]
std_df = metrics_df.groupby(["split", "dataset"]).std()
std_df.columns = [c+"_std" for c in std_df.columns]
std_df
joined_df = pd.merge(mean_df, std_df, left_index=True, right_index=True)
joined_df

Unnamed: 0_level_0,Unnamed: 1_level_0,mcc_average,balanced_accuracy_average,f1_average,AUPRC_average,precision_average,recall_average,mcc_std,balanced_accuracy_std,f1_std,AUPRC_std,precision_std,recall_std
split,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
partitioned,B,0.002015,0.501239,0.288501,-1.0,0.202281,0.503195,0.01042,0.006498,0.009259,0.0,0.008179,0.00913
random,B,-0.003643,0.497552,0.246979,-1.0,0.164495,0.495438,0.008983,0.006036,0.005479,0.0,0.003387,0.014124


In [10]:
metrics_order = ["mcc", "f1", "precision", "recall", "AUPRC", "balanced_accuracy"]
cols = []
for m in metrics_order:
    cols.append(f"{m}_average")
    cols.append(f"{m}_std")
    
joined_df = joined_df[cols]
joined_df

Unnamed: 0_level_0,Unnamed: 1_level_0,mcc_average,mcc_std,f1_average,f1_std,precision_average,precision_std,recall_average,recall_std,AUPRC_average,AUPRC_std,balanced_accuracy_average,balanced_accuracy_std
split,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
partitioned,B,0.002015,0.01042,0.288501,0.009259,0.202281,0.008179,0.503195,0.00913,-1.0,0.0,0.501239,0.006498
random,B,-0.003643,0.008983,0.246979,0.005479,0.164495,0.003387,0.495438,0.014124,-1.0,0.0,0.497552,0.006036


In [11]:
drugs_df = pd.read_csv("../processed_data/drug_fingerprints.csv", index_col=0)
long_table = pd.read_csv("../processed_data/DRIAMS_combined_long_table.csv")
long_table = long_table[long_table["drug"].isin(drugs_df.index)]

long_table

Unnamed: 0,species,sample_id,drug,response,dataset
0,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Meropenem,1,A
1,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Ciprofloxacin,1,A
2,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Cefepime,1,A
3,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Cotrimoxazole,0,A
4,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Imipenem,1,A
...,...,...,...,...,...
652766,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Linezolid,0,D
652767,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Rifampicin,0,D
652768,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Tetracycline,0,D
652769,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Tigecycline,0,D


In [17]:
# folder_template = "outputs/ResAMR_DrugZeroShot/ZS_noCNN_emb512_DRIAMS-{}_drugs_zero_shot_sp0"
drug_zs_results = []
threshold = 0.5
# folder_template = "outputs/Species1hot_ResAMR/OneHot_noCNN_emb512_DRIAMS-{}_drugs_zero_shot_sp0_results"

folder = join("outputs/RandomClassifier", template.format("drugs_zero_shot"))




for fname in os.listdir(folder):
    if not fname.startswith("test"):
        continue

    drug_name = fname.split("_")[-1].split(".")[0]
    with open(join(folder, fname), "r") as f:
        metrics = json.load(f) 


    metrics["dataset"] = dataset
    metrics["drug"]= drug_name
    drug_zs_results.append(metrics)







drug_zs_results = pd.DataFrame(drug_zs_results)
# drug_zs_results.columns = ["dataset", "drug", "test_AUPRC", "n_test_samples", "n_resistant_test_samples"]
# drug_zs_results = drug_zs_results.drop("drug", axis=1)

drug_zs_results

Unnamed: 0,mcc,balanced_accuracy,f1,AUPRC,precision,recall,dataset,drug
0,-0.015829,0.491842,0.424754,-1,0.371007,0.496711,B,Cefuroxime
1,0.0,0.495622,0.0,-1,0.0,0.0,B,Linezolid
2,-0.005029,0.496565,0.237443,-1,0.157385,0.483271,B,Cefepime
3,0.170542,0.643678,0.25,-1,0.153846,0.666667,B,Metronidazole
4,0.037749,0.52463,0.285192,-1,0.193069,0.545455,B,Norfloxacin
5,0.049592,0.609426,0.035503,-1,0.018182,0.75,B,Vancomycin
6,-0.009301,0.478654,0.021459,-1,0.010989,0.454545,B,Meropenem
7,0.045487,0.52277,0.529583,-1,0.549296,0.511236,B,Nitrofurantoin
8,0.028467,0.51664,0.332689,-1,0.252199,0.488636,B,Fusidic acid
9,-0.023617,0.477685,0.121127,-1,0.069579,0.467391,B,Imipenem


In [19]:

mean_df = drug_zs_results.groupby("dataset").mean()
mean_df.columns = [c+"_average" for c in mean_df.columns]
std_df = drug_zs_results.groupby("dataset").std()
std_df.columns = [c+"_std" for c in std_df.columns]
std_df
zs_joined_df = pd.merge(mean_df, std_df, left_index=True, right_index=True)
zs_joined_df.columns = [ c for c in zs_joined_df.columns]
zs_joined_df
metrics_order = ["mcc", "f1", "precision", "recall", "AUPRC", "balanced_accuracy"]
cols = []
for m in metrics_order:
    cols.append(f"{m}_average")
    cols.append(f"{m}_std")
    
zs_joined_df = zs_joined_df[cols]
zs_joined_df.insert(0, "split", "drugs_zero_shot")#reset_index()
zs_joined_df
# mean_df

  mean_df = drug_zs_results.groupby("dataset").mean()
  std_df = drug_zs_results.groupby("dataset").std()


Unnamed: 0_level_0,split,mcc_average,mcc_std,f1_average,f1_std,precision_average,precision_std,recall_average,recall_std,AUPRC_average,AUPRC_std,balanced_accuracy_average,balanced_accuracy_std
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
B,drugs_zero_shot,0.007346,0.040097,0.204489,0.161926,0.158532,0.159286,0.459992,0.188498,-1.0,0.0,0.481913,0.115272


In [20]:
results_df = pd.concat((joined_df.reset_index(), zs_joined_df.reset_index())).set_index(["split", "dataset"])
results_df = results_df[cols]
results_df

Unnamed: 0_level_0,Unnamed: 1_level_0,mcc_average,mcc_std,f1_average,f1_std,precision_average,precision_std,recall_average,recall_std,AUPRC_average,AUPRC_std,balanced_accuracy_average,balanced_accuracy_std
split,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
partitioned,B,0.002015,0.01042,0.288501,0.009259,0.202281,0.008179,0.503195,0.00913,-1.0,0.0,0.501239,0.006498
random,B,-0.003643,0.008983,0.246979,0.005479,0.164495,0.003387,0.495438,0.014124,-1.0,0.0,0.497552,0.006036
drugs_zero_shot,B,0.007346,0.040097,0.204489,0.161926,0.158532,0.159286,0.459992,0.188498,-1.0,0.0,0.481913,0.115272


In [21]:
results_df.to_csv("outputs/aggregate_resuts/RandomClassifier_DRIAMS-B_metrics.csv")

In [23]:
metrics = ["AUPRC", "balanced_accuracy", "mcc"]

for i, row in results_df.iterrows():
    
    print(i)
    for m in metrics:
        print("{:.2f} ({:.2f}) ".format(row[f"{m}_average"], row[f"{m}_std"]), end="")
        if m!="mcc":
            print(" & ", end="")
    print(" \\\ ")

('partitioned', 'B')
-1.00 (0.00)  & 0.50 (0.01)  & 0.00 (0.01)  \\ 
('random', 'B')
-1.00 (0.00)  & 0.50 (0.01)  & -0.00 (0.01)  \\ 
('drugs_zero_shot', 'B')
-1.00 (0.00)  & 0.48 (0.12)  & 0.01 (0.04)  \\ 
