In [1]:
import os
import glob
import pandas as pd
from tqdm import tqdm

In [2]:
# Disable truncation when DataFrames are displayed: show every column and every row.
pd.set_option("display.max_columns", None, "display.max_rows", None)

### Read Files

In [3]:
# Read the gzip-compressed CSV into `app_df`. The path is relative, so it is
# resolved against the kernel's current working directory.
app_df = pd.read_csv("../mimic-iv-3.1/app_df_2025_03_22.csv.gz", compression="gzip")

In [10]:
import pickle 

# Load the pickled object into `trained_models`; later cells look models up in it
# with keys of the form "rf_{bacteria}_{antibiotic}".
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so only
# load this file when it comes from a trusted source.
with open("../mimic-iv-3.1/rf_models_25_03_22.pkl", "rb") as f:
    trained_models = pickle.load(f)

In [20]:
# Build the model input frame by removing the two columns that aren't modeled
model_df = app_df.drop(['Unnamed: 0', 'charttime'], axis="columns")

In [7]:
# Final bacteria-antibiotic pairs, grouped by organism.
# Both the organism order and the antibiotic order within each organism are kept
# exactly, because `pairs` is flattened from this mapping in insertion order.
_antibiotics_by_bacteria = {
    'ESCHERICHIA COLI': [
        'AMPICILLIN',
        'AMPICILLIN/SULBACTAM',
        'CEFAZOLIN',
        'CEFEPIME',
        'CEFTAZIDIME',
        'CEFTRIAXONE',
        'CIPROFLOXACIN',
        'GENTAMICIN',
        'MEROPENEM',
        'NITROFURANTOIN',
        'TOBRAMYCIN',
        'TRIMETHOPRIM/SULFA',
        'PIPERACILLIN/TAZO',
    ],
    'PSEUDOMONAS AERUGINOSA': [
        'CEFEPIME',
        'CEFTAZIDIME',
        'CIPROFLOXACIN',
        'GENTAMICIN',
        'MEROPENEM',
        'PIPERACILLIN/TAZO',
        'TOBRAMYCIN',
    ],
    'KLEBSIELLA PNEUMONIAE': [
        'AMPICILLIN/SULBACTAM',
        'CEFAZOLIN',
        'CEFEPIME',
        'CEFTAZIDIME',
        'CEFTRIAXONE',
        'CIPROFLOXACIN',
        'GENTAMICIN',
        'MEROPENEM',
        'NITROFURANTOIN',
        'PIPERACILLIN/TAZO',
        'TOBRAMYCIN',
        'TRIMETHOPRIM/SULFA',
    ],
}

# Flatten to the list of (bacteria, antibiotic) tuples used by the functions below
pairs = [
    (bacteria, antibiotic)
    for bacteria, antibiotics in _antibiotics_by_bacteria.items()
    for antibiotic in antibiotics
]

### Run Models and get Probabilities

In [24]:
# Preview the first five subject_id values to pick example subjects from
model_df.loc[:, 'subject_id'].head(5)

0    10002013.0
1    10002557.0
2    10002557.0
3    10003400.0
4    10003400.0
Name: subject_id, dtype: float64

In [25]:
# Example subjects: the three distinct subject_id values in the preview above
# (kept as floats to match the column's float64 dtype).
subj1, subj2, subj3 = 10002013.0, 10002557.0, 10003400.0

In [26]:
def predict_antibiotic_probs(subject_id, bacteria_name, app_df, trained_models, pairs):
    """
    Predicts antibiotic resistance probabilities for a given subject and bacteria.

    Parameters:
    - subject_id (int or float): Value matched against app_df['subject_id']
    - bacteria_name (str): Name of the bacteria (e.g., "ESCHERICHIA COLI")
    - app_df (pd.DataFrame): DataFrame with a 'subject_id' column; every other
      column is passed to the models as a feature
    - trained_models (dict): Dictionary of trained models with keys like "rf_BACTERIA_ANTIBIOTIC"
    - pairs (list): List of (bacteria, antibiotic) pairs

    Returns:
    - dict: {antibiotic: probability_of_resistance}, with None for every
      antibiotic whose model is not present in trained_models

    Raises:
    - ValueError: if no row of app_df matches subject_id
    """
    # Antibiotics paired with this bacteria, in the order they appear in `pairs`
    relevant_antibiotics = [antibiotic for bact, antibiotic in pairs if bact == bacteria_name]

    subject_rows = app_df[app_df['subject_id'] == subject_id]
    if subject_rows.empty:
        raise ValueError(f"No data found for subject_id {subject_id}")

    # A subject can match several rows. Only the first one is scored (unchanged
    # behavior), so slice it out once instead of predicting on every row and
    # throwing the rest away. subject_id itself is not a feature.
    subject_features = subject_rows.iloc[[0]].drop(columns=['subject_id'])

    results = {}
    for antibiotic in relevant_antibiotics:
        model = trained_models.get(f"rf_{bacteria_name}_{antibiotic}")

        # Compare with None explicitly: truth-testing an estimator goes through
        # __len__ when it defines one (scikit-learn ensembles do), which can
        # misreport a present model as missing or raise on an unfitted one.
        if model is None:
            results[antibiotic] = None
            continue

        # Column 1 of predict_proba is read as the "resistant" probability.
        # NOTE(review): assumes the model's second class is the resistant label - confirm.
        results[antibiotic] = float(model.predict_proba(subject_features)[0][1])

    return results


In [27]:
bacteria_name = "ESCHERICHIA COLI"

# Score every antibiotic model of the selected bacteria for the first example subject
predicted_probs = predict_antibiotic_probs(
    subject_id=subj1,
    bacteria_name=bacteria_name,
    app_df=model_df,
    trained_models=trained_models,
    pairs=pairs,
)
predicted_probs

{'AMPICILLIN': 0.25733237547892723,
 'AMPICILLIN/SULBACTAM': 0.4661553817254408,
 'CEFAZOLIN': 0.7313697318763108,
 'CEFEPIME': 0.8025651292413909,
 'CEFTAZIDIME': 0.9255983751373836,
 'CEFTRIAXONE': 0.8745250247215905,
 'CIPROFLOXACIN': 0.7336086462988456,
 'GENTAMICIN': 0.42087285319929213,
 'MEROPENEM': 0.9990333770533246,
 'NITROFURANTOIN': 0.9525936786340333,
 'TOBRAMYCIN': 0.9081741712263035,
 'TRIMETHOPRIM/SULFA': 0.3651066070298138,
 'PIPERACILLIN/TAZO': 0.9755698624578442}

### Get SHAP values

In [43]:
def get_shap_values_for_model(subject_id, model_name, model_df, trained_models):
    """
    Returns the class-1 SHAP value of every feature for one subject and one model.

    Parameters:
    - subject_id (int or float): Value matched against model_df['subject_id']
    - model_name (str): Key of the model in trained_models
    - model_df (pd.DataFrame): DataFrame with a 'subject_id' column; every other
      column is treated as a model feature
    - trained_models (dict): Dictionary of trained models

    Returns:
    - dict: {feature_name: shap_value}, ordered by descending absolute SHAP value

    Raises:
    - ValueError: if the model or the subject is not found, or if the SHAP
      output does not have a recognized layout
    """
    # Kept function-local, as in the original cell, so shap is only needed when called
    import shap
    import numpy as np

    model = trained_models.get(model_name)
    if model is None:
        raise ValueError(f"Model '{model_name}' not found in trained_models.")

    subject_rows = model_df[model_df['subject_id'] == subject_id]
    if subject_rows.empty:
        raise ValueError(f"No data found for subject_id {subject_id}")

    # Only the first matching row is explained (unchanged behavior); slicing it
    # here avoids computing SHAP values for rows that would be discarded.
    X_subject = subject_rows.iloc[[0]].drop(columns=['subject_id'])

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_subject)

    if isinstance(shap_values, (list, tuple)):
        # Per-class layout returned by older shap releases: one
        # (n_samples, n_features) array per class.
        if len(shap_values) < 2:
            raise ValueError(
                f"Unexpected SHAP output: expected one array per class, got {len(shap_values)}"
            )
        class1_values = np.asarray(shap_values[1])
        if class1_values.ndim != 2:
            raise ValueError(f"Unexpected SHAP output shape: {class1_values.shape}")
        shap_row = class1_values[0]
    elif isinstance(shap_values, np.ndarray):
        # Array layout: (n_samples, n_features, n_classes)
        if shap_values.ndim != 3:
            raise ValueError(f"Unexpected SHAP output shape: {shap_values.shape}")
        shap_row = shap_values[0, :, 1]  # first sample, class 1
    else:
        # Anything else has no reliable .shape, so report the type instead
        raise ValueError(f"Unexpected SHAP output type: {type(shap_values).__name__}")

    feature_names = X_subject.columns.tolist()

    if len(feature_names) != len(shap_row):
        raise ValueError(f"Feature count ({len(feature_names)}) and SHAP value count ({len(shap_row)}) do not match.")

    # Sort by importance (absolute value of SHAP); plain floats keep the result
    # independent of numpy scalar types.
    ranked = sorted(zip(feature_names, shap_row), key=lambda item: abs(item[1]), reverse=True)
    return {name: float(value) for name, value in ranked}


In [45]:
# SHAP values of one model for the first example subject
shap_vals = get_shap_values_for_model(
    subject_id=subj1,
    model_name="rf_ESCHERICHIA COLI_AMPICILLIN",
    model_df=model_df,
    trained_models=trained_models,
)
shap_vals


{'SKN004': -0.032163525621358005,
 'M03BX': -0.030162918888260143,
 'N02BF': -0.025820242835705772,
 'EXT027': -0.021657510197842685,
 'CIR028': -0.02083403128775645,
 'J01CR': -0.019577963137232757,
 'C07AG': -0.01762848869781685,
 'MUS002': 0.013910462276088038,
 'CIR031': 0.012937463375419964,
 'Blood - Alkaline Phosphatase': 0.01136688273307464,
 'B01AB': 0.011299777087352128,
 'Blood - Hematocrit': -0.010924142349540793,
 'FAC015': -0.010472793891433903,
 'A07AA': 0.01027594558469156,
 'B05BA': 0.00990029187818088,
 'Blood - Hemoglobin': -0.009285144555957137,
 'B05AA': 0.007455627893106689,
 'J01EC': 0.007451888577747106,
 'N02AA': -0.006774586367784406,
 'Urine - Protein': 0.0064496731253173055,
 'END005': -0.0064078644924012084,
 'FAC009': -0.006219956461814838,
 'Blood - Asparate Aminotransferase (AST)': -0.005896144282983302,
 'SKN003': -0.005720062397761761,
 'N02BA': -0.005671039100712286,
 'N02BE': 0.005303270427787009,
 'NVS017': 0.005217460409210828,
 'FAC022': 0.0051561

### Get SHAP values for all models for a selected bacterium

In [46]:
def get_all_shap_values_for_subject(subject_id, bacteria_name, model_df, trained_models, pairs):
    """
    Returns sorted SHAP values for all antibiotic models for a given subject and bacteria.

    Parameters:
    - subject_id (int or float): Value matched against model_df['subject_id']
    - bacteria_name (str): Name of the bacteria
    - model_df (pd.DataFrame): DataFrame containing model input features
    - trained_models (dict): Dictionary of trained models with keys like "rf_BACTERIA_ANTIBIOTIC"
    - pairs (list): List of (bacteria, antibiotic) pairs

    Returns:
    - dict: one entry per antibiotic paired with bacteria_name, holding either
      {feature_name: shap_value} as returned by get_shap_values_for_model,
      {"error": message} when computing the SHAP values raised, or
      None when no model is stored for that antibiotic
    """
    relevant_antibiotics = [antibiotic for bact, antibiotic in pairs if bact == bacteria_name]
    all_shap_results = {}

    for antibiotic in relevant_antibiotics:
        model_key = f"rf_{bacteria_name}_{antibiotic}"

        # Compare with None explicitly: truth-testing an estimator goes through
        # __len__ when it defines one (scikit-learn ensembles do), which can
        # misreport a present model as missing or raise on an unfitted one.
        if trained_models.get(model_key) is None:
            all_shap_results[antibiotic] = None  # or {"error": "Model not found"}
            continue

        try:
            all_shap_results[antibiotic] = get_shap_values_for_model(
                subject_id, model_key, model_df, trained_models
            )
        except Exception as e:
            # Deliberately best-effort: one failing model is recorded under the
            # "error" key so the remaining antibiotics are still explained.
            all_shap_results[antibiotic] = {"error": str(e)}

    return all_shap_results


In [52]:
# Print Top 5 SHAP Values for every antibiotic
def print_top_shap_features(shap_explanations, top_n=5):
    """
    Prints, for each antibiotic, the first N entries of its SHAP dictionary.

    Parameters:
    - shap_explanations (dict): Output from get_all_shap_values_for_subject, i.e.
      {antibiotic: {feature_name: shap_value}}, where a value may also be None
      or a dict holding an "error" message
    - top_n (int): Number of leading features to display per antibiotic

    Returns:
    - None; everything is written with print()
    """
    for drug_name in shap_explanations:
        ranked = shap_explanations[drug_name]
        print(f"\nTop {top_n} features for {drug_name}:")

        if ranked is None:
            # No model was stored for this antibiotic
            print("  Model not found.")
        elif "error" in ranked:
            # The SHAP computation failed; show the recorded message
            print(f"  Error: {ranked['error']}")
        else:
            # Entries are printed in the dictionary's own order, which the
            # producer has already sorted by importance.
            for feature_name in list(ranked)[:top_n]:
                print(f"  {feature_name}: {ranked[feature_name]:.4f}")



In [53]:
# Compute the explanations in this cell: `shap_explanations_subj1` was previously
# only present as leftover kernel state, so a fresh Restart & Run All failed here
# with a NameError.
shap_explanations_subj1 = get_all_shap_values_for_subject(
    subj1, bacteria_name, model_df, trained_models, pairs
)
print_top_shap_features(shap_explanations_subj1, top_n=5)


Top 5 features for AMPICILLIN:
  SKN004: -0.0322
  M03BX: -0.0302
  N02BF: -0.0258
  EXT027: -0.0217
  CIR028: -0.0208

Top 5 features for AMPICILLIN/SULBACTAM:
  B01AB: -0.0072
  EXT025: -0.0054
  procedure_ct: -0.0052
  C03CA: -0.0043
  J01CR: -0.0042

Top 5 features for CEFAZOLIN:
  INF003: -0.0163
  days_since_last_proc: 0.0158
  N02BF: -0.0142
  Blood - Hematocrit: 0.0121
  Blood - RDW: 0.0115

Top 5 features for CEFEPIME:
  B05CX: -0.0340
  J01DE: -0.0250
  N01BB: -0.0214
  days_since_last_proc: 0.0162
  D01AA: -0.0157

Top 5 features for CEFTAZIDIME:
  days_since_last_proc: 0.0118
  FAC015: -0.0109
  P01AB: -0.0085
  Blood - Hematocrit: 0.0085
  N02BF: -0.0062

Top 5 features for CEFTRIAXONE:
  days_since_last_proc: 0.0163
  Blood - Red Blood Cells: 0.0104
  FAC015: -0.0099
  Blood - Hemoglobin: 0.0095
  Blood - Hematocrit: 0.0088

Top 5 features for CIPROFLOXACIN:
  Blood - RDW: 0.0228
  days_since_last_proc: 0.0219
  Blood - Hematocrit: 0.0165
  Blood - Hemoglobin: 0.0135
  p

In [54]:
# Full SHAP dictionary of a single antibiotic (all features, not only the top N)
shap_explanations_subj1["AMPICILLIN"]

{'SKN004': -0.032163525621358005,
 'M03BX': -0.030162918888260143,
 'N02BF': -0.025820242835705772,
 'EXT027': -0.021657510197842685,
 'CIR028': -0.02083403128775645,
 'J01CR': -0.019577963137232757,
 'C07AG': -0.01762848869781685,
 'MUS002': 0.013910462276088038,
 'CIR031': 0.012937463375419964,
 'Blood - Alkaline Phosphatase': 0.01136688273307464,
 'B01AB': 0.011299777087352128,
 'Blood - Hematocrit': -0.010924142349540793,
 'FAC015': -0.010472793891433903,
 'A07AA': 0.01027594558469156,
 'B05BA': 0.00990029187818088,
 'Blood - Hemoglobin': -0.009285144555957137,
 'B05AA': 0.007455627893106689,
 'J01EC': 0.007451888577747106,
 'N02AA': -0.006774586367784406,
 'Urine - Protein': 0.0064496731253173055,
 'END005': -0.0064078644924012084,
 'FAC009': -0.006219956461814838,
 'Blood - Asparate Aminotransferase (AST)': -0.005896144282983302,
 'SKN003': -0.005720062397761761,
 'N02BA': -0.005671039100712286,
 'N02BE': 0.005303270427787009,
 'NVS017': 0.005217460409210828,
 'FAC022': 0.0051561