In [15]:
import pandas as pd
import numpy as np
import os
from os.path import join, exists
import json
import matplotlib.pyplot as plt
import seaborn as sns


from sklearn.metrics import roc_curve, auc, precision_recall_curve


from sklearn.metrics import (
    matthews_corrcoef,
    balanced_accuracy_score,
    f1_score,
    average_precision_score,
)
from sklearn.metrics import precision_score, recall_score

In [16]:
# template = "embdim512_DRIAMS-{}_{}_sp{}_results"
template = "OneHot_noCNN_emb512_DRIAMS-{}_{}_sp0_results"

datasets = ["A", "B", "C", "D"]
splits = ["random", "partitioned"]


In [17]:
metrics_df = []

for dset in datasets:
    for sp in splits:
        for i in range(10):
            try:
                with open(join("outputs/Species1hot_ResAMR", template.format(dset, sp), f"test_metrics_{i}.json"), "r") as f:
                    met = json.load(f)
                met["dataset"] = dset
                met["split"] = sp
                met["seed"]=i
                # met["species_embedding_dim"]=dim
                metrics_df.append(met)
            except:
                continue
    
metrics_df = pd.DataFrame(metrics_df)
# metrics_df = metrics_df.drop(["test_loss", "seed"], axis=1)
metrics_df

Unnamed: 0,test_loss,test_mcc,test_balanced_accuracy,test_f1,test_AUPRC,test_precision,test_recall,dataset,split,seed
0,0.961193,0.078934,0.535002,0.232325,0.282550,0.281924,0.203708,A,random,0
1,1.031153,0.072346,0.529879,0.210987,0.280285,0.282727,0.173016,A,random,1
2,1.102658,0.024877,0.510615,0.178318,0.256765,0.231093,0.150080,A,random,2
3,0.906183,0.105122,0.539052,0.213301,0.332007,0.334992,0.160700,A,random,3
4,0.866487,0.119473,0.550584,0.253626,0.322033,0.329099,0.212201,A,random,5
...,...,...,...,...,...,...,...,...,...,...
65,0.389593,0.573866,0.730324,0.614031,0.716674,0.855022,0.486808,D,partitioned,2
66,0.473085,0.522854,0.707773,0.564792,0.662313,0.815090,0.441014,D,partitioned,3
67,0.597669,0.459294,0.687716,0.517882,0.608500,0.706408,0.423186,D,partitioned,4
68,0.400771,0.512843,0.709790,0.565758,0.704230,0.786510,0.452089,D,partitioned,8


In [18]:
mean_df = metrics_df.groupby(["split", "dataset"]).mean()
mean_df.columns = [c+"_average" for c in mean_df.columns]
std_df = metrics_df.groupby(["split", "dataset"]).std()
std_df.columns = [c+"_std" for c in std_df.columns]
std_df
joined_df = pd.merge(mean_df, std_df, left_index=True, right_index=True)
joined_df

Unnamed: 0_level_0,Unnamed: 1_level_0,test_loss_average,test_mcc_average,test_balanced_accuracy_average,test_f1_average,test_AUPRC_average,test_precision_average,test_recall_average,seed_average,test_loss_std,test_mcc_std,test_balanced_accuracy_std,test_f1_std,test_AUPRC_std,test_precision_std,test_recall_std,seed_std
split,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
partitioned,A,0.828167,0.093317,0.536869,0.178718,0.310555,0.290284,0.170126,4.5,0.16321,0.032177,0.015734,0.03837,0.042423,0.051823,0.057684,3.02765
partitioned,B,0.670067,0.29947,0.620924,0.359048,0.519858,0.551868,0.309405,4.5,0.09818,0.043974,0.019189,0.042401,0.041317,0.093476,0.066055,3.02765
partitioned,C,1.065853,0.28091,0.615753,0.373486,0.479417,0.548977,0.316651,4.0,0.118947,0.079728,0.03904,0.082898,0.050999,0.071345,0.084106,4.082483
partitioned,D,0.453226,0.512144,0.718123,0.573078,0.67261,0.742672,0.483558,3.857143,0.069874,0.03754,0.019158,0.03422,0.035134,0.079351,0.050033,3.436499
random,A,0.925082,0.083447,0.534876,0.221787,0.304711,0.293768,0.184588,4.555556,0.107448,0.052472,0.021627,0.045031,0.035773,0.05775,0.041073,3.205897
random,B,1.001896,0.209033,0.591855,0.306663,0.345148,0.382658,0.275245,4.5,0.094637,0.050628,0.030528,0.053398,0.038383,0.058028,0.085227,3.02765
random,C,1.786254,0.172251,0.588042,0.3901,0.406806,0.374568,0.418455,4.5,0.526832,0.129731,0.063774,0.090441,0.10845,0.102153,0.086517,3.02765
random,D,0.407821,0.465479,0.713993,0.525074,0.567484,0.607575,0.48311,4.5,0.06467,0.078395,0.031051,0.059193,0.070702,0.122506,0.062904,3.02765


In [19]:
metrics_order = ["mcc", "f1", "precision", "recall", "AUPRC", "balanced_accuracy"]
cols = []
for m in metrics_order:
    cols.append(f"test_{m}_average")
    cols.append(f"test_{m}_std")
    
joined_df = joined_df[cols]
joined_df

Unnamed: 0_level_0,Unnamed: 1_level_0,test_mcc_average,test_mcc_std,test_f1_average,test_f1_std,test_precision_average,test_precision_std,test_recall_average,test_recall_std,test_AUPRC_average,test_AUPRC_std,test_balanced_accuracy_average,test_balanced_accuracy_std
split,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
partitioned,A,0.093317,0.032177,0.178718,0.03837,0.290284,0.051823,0.170126,0.057684,0.310555,0.042423,0.536869,0.015734
partitioned,B,0.29947,0.043974,0.359048,0.042401,0.551868,0.093476,0.309405,0.066055,0.519858,0.041317,0.620924,0.019189
partitioned,C,0.28091,0.079728,0.373486,0.082898,0.548977,0.071345,0.316651,0.084106,0.479417,0.050999,0.615753,0.03904
partitioned,D,0.512144,0.03754,0.573078,0.03422,0.742672,0.079351,0.483558,0.050033,0.67261,0.035134,0.718123,0.019158
random,A,0.083447,0.052472,0.221787,0.045031,0.293768,0.05775,0.184588,0.041073,0.304711,0.035773,0.534876,0.021627
random,B,0.209033,0.050628,0.306663,0.053398,0.382658,0.058028,0.275245,0.085227,0.345148,0.038383,0.591855,0.030528
random,C,0.172251,0.129731,0.3901,0.090441,0.374568,0.102153,0.418455,0.086517,0.406806,0.10845,0.588042,0.063774
random,D,0.465479,0.078395,0.525074,0.059193,0.607575,0.122506,0.48311,0.062904,0.567484,0.070702,0.713993,0.031051


In [20]:
## Zero shot

In [21]:
drugs_df = pd.read_csv("../processed_data/drug_fingerprints.csv", index_col=0)
long_table = pd.read_csv("../processed_data/DRIAMS_combined_long_table.csv")
long_table = long_table[long_table["drug"].isin(drugs_df.index)]

long_table

Unnamed: 0,species,sample_id,drug,response,dataset
0,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Meropenem,1,A
1,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Ciprofloxacin,1,A
2,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Cefepime,1,A
3,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Cotrimoxazole,0,A
4,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Imipenem,1,A
...,...,...,...,...,...
652766,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Linezolid,0,D
652767,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Rifampicin,0,D
652768,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Tetracycline,0,D
652769,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Tigecycline,0,D


In [22]:
# folder_template = "outputs/ResAMR_DrugZeroShot/ZS_noCNN_emb512_DRIAMS-{}_drugs_zero_shot_sp0"
drug_zs_results = []
threshold = 0.5
folder_template = "outputs/Species1hot_ResAMR/OneHot_noCNN_emb512_DRIAMS-{}_drugs_zero_shot_sp0_results"

for dset in datasets:
    folder = folder_template.format(dset)
    
    


    for fname in os.listdir(folder):
        if not fname.startswith("test_set"):
            continue
            
        test_set = pd.read_csv(join(folder, fname))
        drug_name = test_set["drug"].iloc[0]
    
        response_classes = test_set["response"].values
        predictions = test_set["Predictions"].values

        predicted_classes = (predictions>=threshold).astype(int)

        metrics =  {
                "mcc": matthews_corrcoef(response_classes, predicted_classes),
                "balanced_accuracy": balanced_accuracy_score(
                    response_classes, predicted_classes
                ),
                "f1": f1_score(response_classes, predicted_classes, zero_division=0),
                "AUPRC": average_precision_score(
                    response_classes, predictions
                ),
                "precision": precision_score(
                    response_classes, predicted_classes, zero_division=0
                ),
                "recall": recall_score(
                    response_classes, predicted_classes, zero_division=0
                ),
            }


    
        metrics["dataset"] = dset
        metrics["drug"]=drug_name
        drug_zs_results.append(metrics)







drug_zs_results = pd.DataFrame(drug_zs_results)
# drug_zs_results.columns = ["dataset", "drug", "test_AUPRC", "n_test_samples", "n_resistant_test_samples"]
# drug_zs_results = drug_zs_results.drop("drug", axis=1)

drug_zs_results



Unnamed: 0,mcc,balanced_accuracy,f1,AUPRC,precision,recall,dataset,drug
0,-0.008985,0.499185,0.000000,0.047321,0.000000,0.000000,A,Rifampicin
1,-0.203020,0.396819,0.011034,0.136166,0.010809,0.011269,A,Meropenem
2,0.100393,0.557136,0.292264,0.227182,0.247927,0.355911,A,Cotrimoxazole
3,0.101753,0.513158,0.051282,0.841224,1.000000,0.026316,A,Ceftobiprole
4,0.131786,0.568011,0.449844,0.434846,0.314044,0.792574,A,Ceftriaxone
...,...,...,...,...,...,...,...,...
107,0.049299,0.509132,0.234875,0.133065,0.133065,1.000000,D,Piperacillin
108,0.000000,0.500000,0.000000,0.015711,0.000000,0.000000,D,Ertapenem
109,0.000000,0.500000,0.074074,0.038462,0.038462,1.000000,D,Chloramphenicol
110,-0.001591,0.499868,0.003534,0.254750,0.071429,0.001812,D,Imipenem


In [31]:
# folder_template = "outputs/ResAMR_DrugZeroShot/ZS_noCNN_emb512_DRIAMS-{}_drugs_zero_shot_sp0"
drug_zs_results = []
threshold = 0.5
folder_template = "outputs/Species1hot_ResAMR/OneHot_noCNN_emb512_DRIAMS-{}_drugs_zero_shot_sp0_results"

for dset in datasets:
    folder = folder_template.format(dset)
    
    


    for fname in os.listdir(folder):
        if not fname.startswith("test_metrics"):
            continue
        
#         test_set = pd.read_csv(join(folder, fname))
#         drug_name = test_set["drug"].iloc[0]
    
#         response_classes = test_set["response"].values
#         predictions = test_set["Predictions"].values

#         predicted_classes = (predictions>=threshold).astype(int)

#         metrics =  {
#                 "mcc": matthews_corrcoef(response_classes, predicted_classes),
#                 "balanced_accuracy": balanced_accuracy_score(
#                     response_classes, predicted_classes
#                 ),
#                 "f1": f1_score(response_classes, predicted_classes, zero_division=0),
#                 "AUPRC": average_precision_score(
#                     response_classes, predictions
#                 ),
#                 "precision": precision_score(
#                     response_classes, predicted_classes, zero_division=0
#                 ),
#                 "recall": recall_score(
#                     response_classes, predicted_classes, zero_division=0
#                 ),
#             }

        with open(join(folder, fname), "r") as f:
            metrics = json.load(f)
    
        metrics["dataset"] = dset
        metrics["drug"]=drug_name
        drug_zs_results.append(metrics)







drug_zs_results = pd.DataFrame(drug_zs_results)
# drug_zs_results.columns = ["dataset", "drug", "test_AUPRC", "n_test_samples", "n_resistant_test_samples"]
drug_zs_results = drug_zs_results.drop("test_loss", axis=1)

drug_zs_results

Unnamed: 0,test_mcc,test_balanced_accuracy,test_f1,test_AUPRC,test_precision,test_recall,dataset,drug
0,0.000000,0.500000,0.000000,0.221448,0.000000,0.000000,A,Erythromycin
1,-0.142324,0.470088,0.015035,0.364748,0.066748,0.008613,A,Erythromycin
2,0.125244,0.532975,0.186047,0.226295,0.125000,0.363636,A,Erythromycin
3,-0.027135,0.495146,0.017221,0.324175,0.107054,0.009698,A,Erythromycin
4,-0.093137,0.453431,0.135135,0.385308,0.073529,0.833333,A,Erythromycin
...,...,...,...,...,...,...,...,...
155,0.000000,0.500000,0.824882,0.703801,0.702003,1.000000,D,Erythromycin
156,0.000000,0.000000,0.000000,1.000000,0.000000,0.000000,D,Erythromycin
157,0.208734,0.625891,0.829276,0.870760,0.882767,0.784177,D,Erythromycin
158,0.000000,0.500000,0.074074,0.038462,0.038462,1.000000,D,Erythromycin


In [32]:
metrics

{'test_loss': 0.5384594202041626,
 'test_mcc': -0.03365667195815865,
 'test_balanced_accuracy': 0.49492063492063487,
 'test_f1': 0.05555555555555556,
 'test_AUPRC': 0.26177048344231,
 'test_precision': 0.125,
 'test_recall': 0.03571428571428571,
 'dataset': 'D',
 'drug': 'Erythromycin'}

In [33]:
test_set

Unnamed: 0,species,sample_id,drug,response,dataset,Predictions
0,Staphylococcus epidermidis,1fe16795-6257-470e-bc6d-58e952e72f00_3312,Erythromycin,0,D,0.019228
1,Staphylococcus epidermidis,e1e3ba11-532b-40fd-90dc-edec6590f5c2_3312,Erythromycin,1,D,0.019228
2,Staphylococcus aureus,7e0c8c2c-02d3-4f29-8485-e767dacfa506_3313,Erythromycin,0,D,0.111055
3,Staphylococcus aureus,3dc1659b-1586-412f-952b-d64982cff915_3312,Erythromycin,0,D,0.111055
4,Staphylococcus lugdunensis,c552a40d-2ffe-47ca-a989-63b035008fdc_3313,Erythromycin,0,D,0.020281
...,...,...,...,...,...,...
2426,Staphylococcus aureus,a28bd5c4-a668-4228-85f6-aa7a0efa037d_3313,Erythromycin,0,D,0.111055
2427,Staphylococcus aureus,4267daeb-4f1a-4597-a801-f3c402f86c27_3312,Erythromycin,0,D,0.111055
2428,Staphylococcus epidermidis,006d0547-154d-4f80-a51f-52333bc18e2f_3313,Erythromycin,0,D,0.019228
2429,Staphylococcus epidermidis,9708ff16-8025-4308-8cee-1b9b07c0e689_3312,Erythromycin,1,D,0.019228


In [41]:

mean_df = drug_zs_results.groupby("dataset").mean()
mean_df.columns = [c+"_average" for c in mean_df.columns]
std_df = drug_zs_results.groupby("dataset").std()
std_df.columns = [c+"_std" for c in std_df.columns]
std_df
zs_joined_df = pd.merge(mean_df, std_df, left_index=True, right_index=True)
# zs_joined_df.columns = ["test_" + c for c in zs_joined_df.columns]
zs_joined_df
metrics_order = ["mcc", "f1", "precision", "recall", "AUPRC", "balanced_accuracy"]
cols = []
for m in metrics_order:
    cols.append(f"test_{m}_average")
    cols.append(f"test_{m}_std")
    
zs_joined_df = zs_joined_df[cols]
zs_joined_df.insert(0, "split", "drugs_zero_shot")#reset_index()
zs_joined_df
# mean_df

  mean_df = drug_zs_results.groupby("dataset").mean()
  std_df = drug_zs_results.groupby("dataset").std()


Unnamed: 0_level_0,split,test_mcc_average,test_mcc_std,test_f1_average,test_f1_std,test_precision_average,test_precision_std,test_recall_average,test_recall_std,test_AUPRC_average,test_AUPRC_std,test_balanced_accuracy_average,test_balanced_accuracy_std
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
A,drugs_zero_shot,0.026418,0.211242,0.151581,0.212151,0.217694,0.277029,0.214619,0.29094,0.309393,0.243217,0.544162,0.157115
B,drugs_zero_shot,-0.022596,0.159399,0.108499,0.156309,0.126306,0.176441,0.150947,0.206316,0.184429,0.170173,0.552898,0.15736
C,drugs_zero_shot,-0.048951,0.254314,0.116366,0.243599,0.137968,0.26711,0.190229,0.314043,0.226284,0.27162,0.486109,0.250772
D,drugs_zero_shot,0.009202,0.050046,0.092286,0.238551,0.086711,0.230604,0.153363,0.338558,0.19276,0.273279,0.572428,0.185996


In [40]:
std_df

Unnamed: 0_level_0,test_mcc_std,test_balanced_accuracy_std,test_f1_std,test_AUPRC_std,test_precision_std,test_recall_std
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
A,0.211242,0.157115,0.212151,0.243217,0.277029,0.29094
B,0.159399,0.15736,0.156309,0.170173,0.176441,0.206316
C,0.254314,0.250772,0.243599,0.27162,0.26711,0.314043
D,0.050046,0.185996,0.238551,0.273279,0.230604,0.338558


In [42]:
results_df = pd.concat((joined_df.reset_index(), zs_joined_df.reset_index())).set_index(["split", "dataset"])
results_df = results_df[cols]
results_df

Unnamed: 0_level_0,Unnamed: 1_level_0,test_mcc_average,test_mcc_std,test_f1_average,test_f1_std,test_precision_average,test_precision_std,test_recall_average,test_recall_std,test_AUPRC_average,test_AUPRC_std,test_balanced_accuracy_average,test_balanced_accuracy_std
split,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
partitioned,A,0.093317,0.032177,0.178718,0.03837,0.290284,0.051823,0.170126,0.057684,0.310555,0.042423,0.536869,0.015734
partitioned,B,0.29947,0.043974,0.359048,0.042401,0.551868,0.093476,0.309405,0.066055,0.519858,0.041317,0.620924,0.019189
partitioned,C,0.28091,0.079728,0.373486,0.082898,0.548977,0.071345,0.316651,0.084106,0.479417,0.050999,0.615753,0.03904
partitioned,D,0.512144,0.03754,0.573078,0.03422,0.742672,0.079351,0.483558,0.050033,0.67261,0.035134,0.718123,0.019158
random,A,0.083447,0.052472,0.221787,0.045031,0.293768,0.05775,0.184588,0.041073,0.304711,0.035773,0.534876,0.021627
random,B,0.209033,0.050628,0.306663,0.053398,0.382658,0.058028,0.275245,0.085227,0.345148,0.038383,0.591855,0.030528
random,C,0.172251,0.129731,0.3901,0.090441,0.374568,0.102153,0.418455,0.086517,0.406806,0.10845,0.588042,0.063774
random,D,0.465479,0.078395,0.525074,0.059193,0.607575,0.122506,0.48311,0.062904,0.567484,0.070702,0.713993,0.031051
drugs_zero_shot,A,0.026418,0.211242,0.151581,0.212151,0.217694,0.277029,0.214619,0.29094,0.309393,0.243217,0.544162,0.157115
drugs_zero_shot,B,-0.022596,0.159399,0.108499,0.156309,0.126306,0.176441,0.150947,0.206316,0.184429,0.170173,0.552898,0.15736


In [13]:
results_df.to_csv("outputs/aggregate_resuts/Species1Hot_ResMLP_metrics.csv")

In [14]:
metrics = ["AUPRC", "balanced_accuracy", "mcc"]

for i, row in results_df.iterrows():
    
    print(i)
    for m in metrics:
        print("{:.2f} ({:.2f}) ".format(row[f"test_{m}_average"], row[f"test_{m}_std"]), end="")
        if m!="mcc":
            print(" & ", end="")
    print(" \\\ ")

('partitioned', 'A')
0.31 (0.04)  & 0.54 (0.02)  & 0.09 (0.03)  \\ 
('partitioned', 'B')
0.52 (0.04)  & 0.62 (0.02)  & 0.30 (0.04)  \\ 
('partitioned', 'C')
0.48 (0.05)  & 0.62 (0.04)  & 0.28 (0.08)  \\ 
('partitioned', 'D')
0.67 (0.04)  & 0.72 (0.02)  & 0.51 (0.04)  \\ 
('random', 'A')
0.30 (0.04)  & 0.53 (0.02)  & 0.08 (0.05)  \\ 
('random', 'B')
0.35 (0.04)  & 0.59 (0.03)  & 0.21 (0.05)  \\ 
('random', 'C')
0.41 (0.11)  & 0.59 (0.06)  & 0.17 (0.13)  \\ 
('random', 'D')
0.57 (0.07)  & 0.71 (0.03)  & 0.47 (0.08)  \\ 
('drugs_zero_shot', 'A')
0.34 (0.32)  & 0.51 (0.07)  & 0.02 (0.14)  \\ 
('drugs_zero_shot', 'B')
0.17 (0.16)  & 0.50 (0.12)  & -0.01 (0.17)  \\ 
('drugs_zero_shot', 'C')
0.22 (0.27)  & 0.47 (0.23)  & -0.05 (0.26)  \\ 
('drugs_zero_shot', 'D')
0.18 (0.28)  & 0.52 (0.16)  & 0.01 (0.05)  \\ 
