In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm.notebook import tqdm
from os.path import join


from joblib import Parallel, delayed

In [2]:
datasets = ["A", "B", "C", "D"]
cutoffs = list(range(1, 6))
folder_template = "outputs/ResAMR_Recommender/REC_noCNN_emb512_DRIAMS-{}_rec_sp0_recommendations"

In [3]:
def precision_at_K(df, K=1, target_response=0):
    if K>len(df):
        return np.nan
    
    n_test = 100# (df["response"]==target_response).sum()
    if target_response==0:
        responses = df["response"].iloc[:K]
    elif target_response==1:
        responses = df["response"].iloc[-K:]
        responses = responses[::-1]
    # print(responses)
    return np.sum(responses==target_response)/np.min([n_test, K]) 

In [4]:
def parallel_process_sample(sid, df):
    sample_sensitive_results = {"sample_id": sid, "seed": seed, "dataset": dset, "target": "sensitive"}
    sample_resistant_results = {"sample_id": sid, "seed": seed, "dataset": dset, "target": "resistant"}
    sample_df = df[df["sample_id"]==sid]
    sample_df = sample_df[sample_df["response"]!=-1]

    for co in cutoffs:
        sample_sensitive_results["cutoff_at_{}".format(co)] = precision_at_K(sample_df, K=co, target_response=0)
        sample_resistant_results["cutoff_at_{}".format(co)] = precision_at_K(sample_df, K=co, target_response=1)
    return sample_sensitive_results, sample_resistant_results

In [5]:
sensitive_drugs_pak = []

resistant_drugs_pak = []

for dset in datasets:
    folder = folder_template.format(dset)
    files = os.listdir(folder)
    for fname in tqdm(files):
        if not "prediction_set" in fname:
            continue
        seed = int(fname.split("seed")[-1].split(".")[0])
        df = pd.read_csv(join(folder, fname))
        df = df[df["response"]!=-1]
        df = df.sort_values(by="Predictions", ascending=True)
        samples_set = df["sample_id"].unique()
        
        # results = Parallel(n_jobs=3)(delayed(parallel_process_sample)(s, df) for s in tqdm(samples_set))
        for sid in tqdm(samples_set):
            sample_sensitive_results = {"sample_id": sid, "seed": seed, "dataset": dset, "target": "sensitive"}
            sample_resistant_results = {"sample_id": sid, "seed": seed, "dataset": dset, "target": "resistant"}
            sample_df = df[df["sample_id"]==sid]
            # sample_df = sample_df[sample_df["response"]!=-1]

            for co in cutoffs:
                sample_sensitive_results["cutoff_at_{}".format(co)] = precision_at_K(sample_df, K=co, target_response=0)
                sample_resistant_results["cutoff_at_{}".format(co)] = precision_at_K(sample_df, K=co, target_response=1)  
            sensitive_drugs_pak.append(sample_sensitive_results)
            resistant_drugs_pak.append(sample_resistant_results)
    
sensitive_drugs_pak = pd.DataFrame(sensitive_drugs_pak)
resistant_drugs_pak = pd.DataFrame(resistant_drugs_pak)


  0%|          | 0/22 [00:00<?, ?it/s]

  0%|          | 0/7662 [00:00<?, ?it/s]


KeyboardInterrupt



In [54]:
sensitive_drugs_pak.to_csv("outputs/RecommendationMetrics/ResAMR_recommendations_sensitive_precision.csv")

In [55]:
resistant_drugs_pak.to_csv("outputs/RecommendationMetrics/ResAMR_recommendations_resistant_precision.csv")

In [14]:
samples_set = df["sample_id"].unique()
for sid in samples_set:
    sample_df = df[df["sample_id"]==sid]
    break
sample_df

Unnamed: 0,species,sample_id,drug,response,dataset,Predictions
83623,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Isavuconazole,-1,A,2.28751e-25
83618,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Fluconazole,-1,A,3.378028e-24
83643,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Tobramycin,-1,A,7.691784e-24
83593,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Amikacin,-1,A,1.056013e-23
83633,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Nitrofurantoin,0,A,3.5576080000000005e-23
83627,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Meropenem,-1,A,1.953474e-22
83635,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Oxacillin,-1,A,4.281568e-22
83592,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,5-Fluorocytosine,-1,A,1.739678e-21
83639,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Rifampicin,-1,A,1.911078e-21
83606,Enterococcus casseliflavus,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,Ceftriaxone,-1,A,8.579476e-21


In [None]:
pos_ranks = []
neg_ranks = []


In [7]:
sensitive_drugs_pak = pd.read_csv("outputs/RecommendationMetrics/ResAMR_recommendations_resistant_precision.csv", index_col=0)
sensitive_drugs_pak

Unnamed: 0,sample_id,seed,dataset,target,cutoff_at_1,cutoff_at_2,cutoff_at_3,cutoff_at_4,cutoff_at_5
0,bcbd9c5b-12ba-4d01-b597-02d53c4dbe75_MALDI1,17,A,resistant,1.0,0.5,0.333333,0.25,0.2
1,cce1168a-242f-4ff6-9b84-0fc9189c99f6_MALDI2,17,A,resistant,1.0,0.5,0.333333,0.25,0.4
2,1fbd05cf-c813-46c4-9af4-a08fe91a9fb9_MALDI2,17,A,resistant,1.0,0.5,0.333333,0.25,0.2
3,1e919d81-f6b4-49ab-b332-d71a7b7a8081_MALDI2,17,A,resistant,1.0,1.0,0.666667,0.50,0.4
4,74e6c209-9f16-4473-95a7-9002c4b08181_MALDI1,17,A,resistant,0.0,0.0,0.000000,0.00,0.2
...,...,...,...,...,...,...,...,...,...
222975,3fc51645-abbf-4bc8-b328-8d828ae68fad_3313,11,D,resistant,1.0,0.5,0.666667,0.50,0.4
222976,9ac55e54-1734-48f9-9431-15edf9200d1f_3312,11,D,resistant,0.0,0.5,0.333333,0.25,0.2
222977,60e65a02-3733-4ce4-ae8f-6161d4344acf_3312,11,D,resistant,1.0,1.0,1.000000,0.75,0.6
222978,108023e5-6499-44fd-a507-8331b54d51da_3313,11,D,resistant,1.0,1.0,0.666667,0.50,0.4
