In [42]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os

import json

from tqdm.notebook import tqdm
from sklearn.metrics import roc_curve, auc, precision_recall_curve


from sklearn.metrics import (
    matthews_corrcoef,
    balanced_accuracy_score,
    f1_score,
    average_precision_score,
)
from sklearn.metrics import precision_score, recall_score

In [32]:
drugs_df = pd.read_csv("../processed_data/drug_fingerprints.csv", index_col=0)
drugs_df

Unnamed: 0_level_0,MACCS_fp,morgan_512_fp,morgan_1024_fp,pubchem_fp
drug,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
5-Fluorocytosine,0000000000000000000000000000000000000110001100...,0000000000000000000000000000000000000000000000...,0000000000000000000000000000000000000000000000...,1101110001100000000110001100100001000000000000...
Amikacin,0000000000000000000000000000000000000000000000...,0100000000000100000001000000001001010000000000...,0100000000000100000000000000001000010000000000...,1101110001111100000111101110111100000000000000...
Amoxicillin,0000000010010000000100000000000000001000000000...,0100010000000000001000000000000001001000000000...,0100010000000000001000000000000001001000000000...,1101110001111000000111101100111000000000000100...
Amphotericin B,0000000000000000000000000000000000000000000000...,0100000000000000000001000010000001011011100000...,0100000000000000000000000010000001011000100000...,1101110001111100000111111000111110000000000000...
Ampicillin,0000000010010000000100000000000000001000000000...,0100010000000000001000000000000001001000000001...,0100010000000000001000000000000001001000000000...,1101110001111000000111101100111000000000000100...
...,...,...,...,...
Ticarcillin,0000000010010000000100000000000000001000000000...,0100010000000010001000000000000001001000100000...,0100010000000010001000000000000001001000000000...,1101110001111000000111001100111000000000000110...
Tigecycline,0000000000000000000000000010000000000000000000...,0000000100000100000000000000010001001100011001...,0000000000000000000000000000000001001100010000...,1101110001111100000111101110111100000000000000...
Tobramycin,0000000000000000000000000000000000000000000000...,0000000000010000000001000000000000010000000000...,0000000000010000000000000000000000010000000000...,1101110001111100000111101110111100000000000000...
Vancomycin,0000000000000000000000000000000000000000000000...,0101000000111000000010000000000001001001000010...,0101000000110000000000000000000001001000000010...,1101110001111100000111111111111110000000000000...


In [33]:
long_table = pd.read_csv("../processed_data/DRIAMS_combined_long_table.csv")
long_table = long_table[long_table["drug"].isin(drugs_df.index)]

long_table

Unnamed: 0,species,sample_id,drug,response,dataset
0,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Meropenem,1,A
1,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Ciprofloxacin,1,A
2,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Cefepime,1,A
3,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Cotrimoxazole,0,A
4,Staphylococcus epidermidis,e9adf43d-679b-497c-9849-1fa214838dd3,Imipenem,1,A
...,...,...,...,...,...
652766,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Linezolid,0,D
652767,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Rifampicin,0,D
652768,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Tetracycline,0,D
652769,Staphylococcus aureus,08bc8410-51ec-46d7-ac7b-afba9e6ba2cd_3313,Tigecycline,0,D


In [43]:
# folder_template = "outputs/ResAMR_DrugZeroShot/ZS_noCNN_emb512_DRIAMS-{}_drugs_zero_shot_sp0"
drug_zs_results = []
threshold = 0.5
folder = "outputs/ResAMR_DrugZeroShot/ZS_noCNN_emb512_results"

for fname in tqdm(os.listdir(folder)):
    if not fname.startswith("DRIAMS"):
        continue
        
    drug_name = fname.split("_")[1]
    dset = fname.split("_")[0].split("-")[1]
    
    dataset_long_table = long_table[long_table["dataset"]==dset]
    
    test_set = pd.read_csv(os.path.join(folder, fname))
    
    trainval_data = dataset_long_table[~dataset_long_table[["sample_id", "drug"]].apply(tuple,1).isin(test_set_comb)]
    
    species_majority_response = {}
    for sp in trainval_data["species"].unique():
        avg_species_response = trainval_data[trainval_data["species"]==sp]["response"].mean()
        species_majority_response[sp] = 0 if avg_species_response<0.5 else 1
        
        
    # break
    response_classes = test_set["response"]
    predicted_classes = []
    for i, row in test_set.iterrows():
        predicted_classes.append(species_majority_response[row.species])
    
#     response_classes = test_set["response"].values
# #     predictions = test_set["Predictions"].values
    
#     predicted_classes = (predictions>=threshold).astype(int)
    
    metrics =  {
            "mcc": matthews_corrcoef(response_classes, predicted_classes),
            "balanced_accuracy": balanced_accuracy_score(
                response_classes, predicted_classes
            ),
            "f1": f1_score(response_classes, predicted_classes, zero_division=0),
            "AUPRC": -1,
            "precision": precision_score(
                response_classes, predicted_classes, zero_division=0
            ),
            "recall": recall_score(
                response_classes, predicted_classes, zero_division=0
            ),
        }

    
    metrics["dataset"] = dset
    metrics["drug"]=drug_name
    drug_zs_results.append(metrics)




drug_zs_results = pd.DataFrame(drug_zs_results)
# drug_zs_results.columns = ["dataset", "drug", "test_AUPRC", "n_test_samples", "n_resistant_test_samples"]
drug_zs_results = drug_zs_results.drop("drug", axis=1)

drug_zs_results

  0%|          | 0/163 [00:00<?, ?it/s]



Unnamed: 0,mcc,balanced_accuracy,f1,AUPRC,precision,recall,dataset
0,0.389526,0.576923,0.266667,-1,1.000000,0.153846,B
1,0.000000,0.500000,0.000000,-1,0.000000,0.000000,D
2,-0.018399,0.497774,0.000000,-1,0.000000,0.000000,B
3,0.057497,0.503501,0.013908,-1,1.000000,0.007003,B
4,0.000000,0.500000,0.000000,-1,0.000000,0.000000,B
...,...,...,...,...,...,...,...
158,0.000000,0.500000,0.000000,-1,0.000000,0.000000,D
159,0.026597,0.504817,0.055322,-1,0.808571,0.028641,A
160,0.000000,1.000000,0.000000,-1,0.000000,0.000000,D
161,0.048850,0.501988,0.007921,-1,1.000000,0.003976,C


In [47]:
species_majority_response

{'Staphylococcus epidermidis': 0,
 'Enterococcus faecalis': 0,
 'Enterococcus faecium': 0,
 'Klebsiella oxytoca': 0,
 'Pseudomonas aeruginosa': 0,
 'Streptococcus oralis': 0,
 'Streptococcus equinus': 0,
 'Escherichia coli': 0,
 'Proteus mirabilis': 0,
 'Enterobacter asburiae': 0,
 'Streptococcus agalactiae': 0,
 'Enterobacter ludwigii': 0,
 'Staphylococcus hominis': 0,
 'Lactobacillus rhamnosus': 0,
 'Enterobacter cloacae': 0,
 'Serratia marcescens': 0,
 'Klebsiella pneumoniae': 0,
 'Staphylococcus saprophyticus': 0,
 'Stenotrophomonas maltophilia': 0,
 'Staphylococcus aureus': 0,
 'Candida glabrata': 0,
 'Citrobacter koseri': 0,
 'Enterococcus dispar': 0,
 'Burkholderia ambifaria': 0,
 'Staphylococcus caprae': 0,
 'Campylobacter jejuni': 0,
 'Staphylococcus capitis': 0,
 'Neisseria gonorrhoeae': 0,
 'Haemophilus parainfluenzae': 0,
 'Propionibacterium acnes': 0,
 'Haemophilus influenzae': 0,
 'Aerococcus sanguinicola': 0,
 'Staphylococcus pettenkoferi': 0,
 'Clostridium difficile': 0

In [46]:
drug_zs_results.groupby("dataset").std()

Unnamed: 0_level_0,mcc,balanced_accuracy,f1,AUPRC,precision,recall
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
A,0.126789,0.075894,0.111406,0.0,0.396952,0.080664
B,0.094867,0.120369,0.057827,0.0,0.489296,0.032855
C,0.277865,0.240346,0.306975,0.0,0.428847,0.326633
D,0.0,0.154731,0.0,0.0,0.0,0.0


In [22]:
test_set_comb = set(map(tuple, test_set[["sample_id", "drug"]].values.tolist()))

In [24]:
dataset_long_table[~dataset_long_table[["sample_id", "drug"]].apply(tuple,1).isin(test_set_comb)]

Unnamed: 0,species,sample_id,drug,response,dataset
474238,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Amikacin,0,B
474239,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Ampicillin,1,B
474240,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Cefepime,0,B
474241,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Cefoxitin,0,B
474242,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Ceftazidime,0,B
...,...,...,...,...,...
506609,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Oxacillin,0,B
506611,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Teicoplanin,0,B
506612,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Tetracycline,0,B
506613,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Tigecycline,0,B


In [25]:
dataset_long_table

Unnamed: 0,species,sample_id,drug,response,dataset
474238,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Amikacin,0,B
474239,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Ampicillin,1,B
474240,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Cefepime,0,B
474241,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Cefoxitin,0,B
474242,Klebsiella pneumoniae,ca568529-351a-43af-8cec-7175488f66ea,Ceftazidime,0,B
...,...,...,...,...,...
506610,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Rifampicin,0,B
506611,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Teicoplanin,0,B
506612,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Tetracycline,0,B
506613,Staphylococcus aureus,57043b13-3ba8-4f30-83ac-2416c23cec3a,Tigecycline,0,B


In [26]:
32377-31575

802