In [1]:
#running python 3.9.6
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["PYTHONHASHSEED"] = "42"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import sys
import json
import re
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from tableone import TableOne
from functools import reduce, partial

import statistics
import scipy.stats as stats
from statsmodels.stats.multitest import multipletests
from sklearn.metrics import confusion_matrix

#xgboost
from xgboost import XGBClassifier
from xgboost import plot_importance

# Evaluation of models
from sklearn.model_selection import train_test_split, StratifiedKFold, GroupKFold, cross_val_score
from sklearn.metrics import (
    roc_auc_score, roc_curve, accuracy_score, precision_score, recall_score, f1_score,
    precision_recall_curve, auc, confusion_matrix, average_precision_score, classification_report, ConfusionMatrixDisplay
)
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import calibration_curve
from itertools import combinations
from collections import defaultdict

import pyroc
import optuna
import random
import statannotations.Annotator

import shap

sns.set_context("paper")
mpl.rcParams['pdf.fonttype'] = 42  # edit-able in illustrator
mpl.rcParams['font.sans-serif'] = 'Arial'
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)
        
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"



pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)

In [2]:
def set_all_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    
set_all_seeds(42)

## Functions

In [3]:
def prettify_feature_names(columns):
    # handle specific names
    special_tokens = ['gcs', 'fio2', 'po2', 'rass', 'pco2', 'wbc', 'rdw', 'ldh', 'sofa', 'bmi']
    
    pretty_names = []

    for col in columns:
        if col == 'sofa_points_htn':
            pretty_names.append('Cardiovascular SOFA Score')
        elif col.startswith('po2_fio2_ratio_'):
            suffix = col.split('_')[-1].capitalize()
            pretty_names.append(r'PaO$_2$/FiO$_2$ ' + suffix)
        elif col.startswith('po2_art_'):
            suffix = col.split('_')[-1].capitalize()
            pretty_names.append(r'PaO$_2$ ' + suffix)
        elif col.startswith('abs_lymphocytes_'):
            suffix = col.split('_')[-1].capitalize()
            pretty_names.append('Absolute Lymphocytes ' + suffix)
        elif col == 'edw_adm_age':
            pretty_names.append('Age')
        elif col.startswith('pco2_art_'):
            suffix = col.split('_')[-1].capitalize()
            pretty_names.append(r'PaCO$_2$ ' + suffix)
        elif col.startswith('fio2_'):
            suffix = col.split('_')[-1].capitalize()
            pretty_names.append(r'FiO$_2$ ' + suffix)
        else:
            parts = col.split('_')
            pretty_parts = []
            for part in parts:
                if part.lower() in special_tokens:
                    pretty_parts.append(part.upper())
                else:
                    pretty_parts.append(part.capitalize())
            pretty_names.append(' '.join(pretty_parts))
    return pretty_names

# SHAP plot
# Initialize SHAP explainer
def plot_shap(model,data, plot_title, plot=True, save_path=None, max_display=10):
    explainer = shap.Explainer(model)
    shap_values = explainer(data)

    # get cleaner column names
    pretty_names = prettify_feature_names(data.columns)

    # Plot SHAP values
    if plot:
        new_title =  plot_title.replace('_', ' ')
        plt.figure(figsize=(10, 6))
        shap.summary_plot(shap_values, data, feature_names=pretty_names, max_display=max_display, show=False)
        plt.title(new_title, fontsize=14)
        plt.tight_layout()
        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)

        # plt.show()
        plt.close()
    return shap_values

# ROC plot
def plot_auroc(y_train, y_test, y_train_pred_proba, y_test_pred_proba, title,
               fontsize=14, figsize=(10, 10)):

    y_train_preds = y_train_pred_proba
    y_test_preds = y_test_pred_proba

    y_train = y_train.astype(int)
    y_test = y_test.astype(int)

    # ROC curve data
    fpr1, tpr1, _ = roc_curve(y_train, y_train_preds)
    fpr2, tpr2, _ = roc_curve(y_test, y_test_preds)

    auc1 = roc_auc_score(y_train, y_train_preds)
    auc2 = roc_auc_score(y_test, y_test_preds)


    fig, ax = plt.subplots(1,1, figsize=figsize)
    ax.plot(fpr1, tpr1, color='C1', label="Train set (AUROC = "+str(round(auc1, 3))+")")
    ax.plot(fpr2, tpr2, color='C2', label="Test set (AUROC = "+str(round(auc2, 3))+")")

    # Add line for random chance
    ax.plot([0, 1], [0, 1], color='black', linestyle='--', label='Random Chance')

    ax.set_ylabel('True Positive Rate', fontsize=fontsize)
    ax.set_xlabel('False Positive Rate', fontsize=fontsize)
    ax.tick_params(axis='x', labelsize=fontsize - 2)
    ax.tick_params(axis='y', labelsize=fontsize - 2)
    ax.grid(linestyle=':')
    ax.legend(loc='best', fontsize=fontsize - 2)
    ax.set_title(title, fontsize=fontsize)
    plt.tight_layout()
    plt.show()

def plot_shap_customized_axis(model,data, plot_title, 
                              left_label=None, right_label=None,
                              plot=True, save_path=None, 
                              max_display=10):
    explainer = shap.Explainer(model)
    shap_values = explainer(data)

    # get cleaner column names
    pretty_names = prettify_feature_names(data.columns)

    # Plot SHAP values
    if plot:
        plt.figure(figsize=(10, 6))
        shap.summary_plot(
            shap_values, data, feature_names=pretty_names,
            max_display=max_display, show=False
        )

        # Axis and colorbar tweaks
        ax = plt.gca()
        fig = plt.gcf()
        ax.tick_params(axis='y', labelsize=18)
        ax.tick_params(axis='x', labelsize=12)
        
        # Colorbar tweak
        if fig.axes and (left_label is not None or right_label is not None):
            colorbar_ax = fig.axes[-1]
            colorbar_ax.tick_params(labelsize=15)
            colorbar_ax.set_yticklabels(['Low/No', 'High/Yes'])
            colorbar_ax.yaxis.label.set_size(10)

            # Custom left/right labels under the colorbar
            if left_label is not None:
                ax.text(-0.1, -0.18, left_label,
                    transform=ax.transAxes,
                    fontsize=15,
                    horizontalalignment='left',
                    fontweight='bold')

            if right_label is not None:
                ax.text(1.1, -0.18, right_label,
                    transform=ax.transAxes,
                    fontsize=15,
                    horizontalalignment='right',
                    fontweight='bold')

        plt.title(plot_title.replace('_', ' '), fontsize=20, pad=20)
        # plt.xlabel('SHAP value (impact on model output)', fontsize=14)
        ax.xaxis.label.set_visible(False)
        plt.tight_layout()
        # plt.subplots_adjust(bottom=0.22) 

        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', dpi=300)

        # plt.show()
        plt.close()

def cross_validation_metrics(best_model, X, y, episode_ids, n_splits=5):
    # Set up n-fold cross-validation
    gkf = GroupKFold(n_splits=n_splits)

    # Arrays to store metrics
    tprs = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)
    precisions = []
    aps = []
    mean_recall = np.linspace(0, 1, 100)


    for train_index, val_index in gkf.split(X, y, groups=episode_ids):
        X_tr, X_val = X.iloc[train_index], X.iloc[val_index]
        y_tr, y_val = y.iloc[train_index], y.iloc[val_index]
        
        # Fit model
        best_model.fit(X_tr, y_tr)
        
        # Predict probabilities
        y_val_preds = best_model.predict_proba(X_val)[:, 1]
        
        # Compute ROC curve and AUC
        fpr, tpr, thresholds = roc_curve(y_val.astype(int), y_val_preds)
        auc_score = auc(fpr, tpr)
        aucs.append(auc_score)
        
        # Interpolate TPR
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0
        tprs.append(interp_tpr)
        

        # PRC
        precision, recall, _ = precision_recall_curve(y_val.astype(int), y_val_preds)
        ap_score = average_precision_score(y_val.astype(int), y_val_preds)
        aps.append(ap_score)
        # Interpolate precision to mean_recall grid
        interp_precision = np.interp(mean_recall, recall[::-1], precision[::-1])  # flip for increasing recall
        precisions.append(interp_precision)


    # Compute mean TPR and AUC
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0 
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    std_tpr = np.std(tprs, axis=0)

    # Compute means/std for PRC
    mean_precision = np.mean(precisions, axis=0)
    std_precision = np.std(precisions, axis=0)
    mean_ap = np.mean(aps)
    std_ap = np.std(aps)

    return {
        "aurocs": aucs,
        "mean_fpr": mean_fpr,
        "mean_tpr": mean_tpr,
        "std_tpr": std_tpr,
        "aps": aps,
        "mean_recall": mean_recall,
        "mean_precision": mean_precision,
        "std_precision": std_precision
    }

# get metric results
def get_metrics_dict(results):
    metrics = results['metrics']
    
    accuracy = metrics.get('accuracy', None)
    precision = metrics.get('ppv', None) or metrics.get('precision', None)
    recall = metrics.get('sensitivity', None) or metrics.get('recall', None)
    specificity = metrics.get('specificity', None)
    f1 = metrics.get('f1', None) or metrics.get('f1-score', None)
    auroc = metrics.get('auroc', None)
    aupr = metrics.get('aupr', None)
    npv = metrics.get('npv', None)

    return {
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "Specificity": specificity,
        "F1-score": f1,
        "AUROC": auroc,
        "AUPR": aupr,
        "NPV": npv,
    }

# File Path and Read Data

In [4]:
# Input file (Please replace the path to your file)
input_path = "./data/input.csv"

# Patient to exclude (A file include patient id to exclude, default column name is id, please chnage to the name meet your file)
exclude_path = "./data/exclude_list.csv"

# Define id column
id_col = 'pt_study_id'

# Output path
output_path = './output/plots'

In [5]:
data = pd.read_csv(input_path)
exclude_ids = pd.read_csv(exclude_path)['id'].tolist()

# Exlcude patient not meet the criteria
data = data.loc[~data[id_col].isin(exclude_ids), :]
print("Number of Unique Ids:", data[id_col].nunique())

Number of Unique Ids: 702


  data = pd.read_csv(input_path)


# 1. Data preparation

## 1.1 Episode cleaning

In [6]:
# clean episode resolution
def get_episode_resolution_l2(x):
    if pd.isnull(x):
        return x
    elif 'indeterminate' in x.lower():
        return 'indeterminate'        
    elif 'cure' in x.lower():
        return 'cured'
    elif 'super' in x.lower() or 'persistence' in x.lower(): 
        return 'not cured'  

data['episode_resolution_d7_l2'] = data['clin_impression_d78'].apply(get_episode_resolution_l2)
data['episode_resolution_d10_l2'] = data['clin_impression_d10'].apply(get_episode_resolution_l2)
data['episode_resolution_d14_l2'] = data['clin_impression_d14'].apply(get_episode_resolution_l2)

### 1.1.1 Exclude viral cases

In [7]:
idx = data.episode_etiology.eq('Viral')
viral_episodes_id = data.loc[idx, 'episode_id'].unique()
data = data[~data['episode_id'].isin(viral_episodes_id)].copy()

In [8]:
# Get patient without day 7 episode resolution
idx = (
    data.episode_etiology.notna() 
    & data.episode_etiology.ne('Viral') 
    & data.episode_resolution_d7_l2.isna()
)
episodes_without_d7 = data.episode_id[idx]

### 1.1.2 Create episode_day column

In [9]:
data['episode_day'] = data.groupby(['pt_study_id','episode_id']).cumcount() + 1

### 1.1.3 Update episode resolution d7 with cured column

* for those patients have < 7 day episode, update their d7 using cured column

In [10]:
for ep in episodes_without_d7:
    idx = data.episode_id.eq(ep)
    if idx.sum() < 7:
        data.loc[idx, 'episode_resolution_d7_l2'] = data.cured[idx].str.lower()

* for cases with all missing values in clinical adjudication columns, updated as global clinical cure

In [11]:
idx = (
    (data['cured'].notna()) &
    (data[['episode_resolution_d7_l2', 'episode_resolution_d10_l2', 'episode_resolution_d14_l2']].isna().all(axis=1)) &
    (data['episode_duration']>=7) & (data['episode_duration'].notna())
)
    
missing_d7_id = data.loc[idx,'episode_id'].unique()
data.loc[idx, 'episode_resolution_d7_l2'] = data.loc[idx, 'cured'].str.lower()

print("Patients with all missing clinical adjudication:", len(missing_d7_id))

Patients with all missing clinical adjudication: 27


### 1.1.4 keep patients have at least 3 consecutive episode days

In [12]:
def is_consecutive_episode_day(group, day):
    consecutive_day_list = list(range(1, day+1))
    if len(group) < day:
        return False
    return sorted(group['episode_day'].head(day).tolist()) == consecutive_day_list


n_day = 3
data = data.sort_values(by=['pt_study_id', 'admission_datetime', 'episode_id', 'episode_day'])
mask = (
    data.groupby(['pt_study_id', 'episode_id'])
        .apply(lambda group: is_consecutive_episode_day(group, day=n_day))
        .reset_index(name='has_consecutive_day')
)
idx = mask[mask['has_consecutive_day'] == True][['pt_study_id', 'episode_id']]
data = data.merge(idx, on=['pt_study_id', 'episode_id'])

  data.groupby(['pt_study_id', 'episode_id'])


In [13]:
#ffill episode resolution column
mask = data['episode_id'].notna()
data.loc[mask, 'episode_resolution_d7_l2'] = (
    data.loc[mask]
    .groupby('episode_id')['episode_resolution_d7_l2']
    .transform(lambda x: x.fillna(x.dropna().iloc[0]) if x.notna().any() else x)
)

### 1.1.5 Convert categorical column

In [14]:
mask = data['episode_id'].notna()
data.loc[mask, 'episode_category'] = (
    data.loc[mask]
    .groupby('episode_id')['episode_category']
    .transform(lambda x: x.fillna(x.dropna().iloc[0]) if x.notna().any() else x)
)

data['episode_category'] = data['episode_category'].astype('category')

### 1.1.6 Binarize outcome variables

In [15]:
data['d7_episode_resolution_binary'] = np.where(data.episode_resolution_d7_l2=='cured',0 ,1)

## 1.2 Get CAP only data

In [16]:
#keeping CAP episodes only
ep_cap = data.episode_id[data.episode_category == 'CAP']
cap = data.loc[data.episode_id.isin(ep_cap)].copy()

#longest duration of CAP episode
print("Longest CAP episode duration day:", cap.episode_duration.max())

Longest CAP episode duration day: 30.0


## 1.3 Get day1 to day3  data

In [17]:
day_list = list(range(1, 4))
cap_1to3 = cap[cap.episode_day.isin(day_list)].copy()

In [18]:
static_cols = [
                'pt_study_id',
                'episode_day',
                'episode_id',
                'd7_episode_resolution_binary'
]

demo_feature = [
    'edw_adm_age',
    'bmi'
]


lab_feature = [
            # 'wbc_avg',
            'wbc_max',
            # 'abs_lymphocytes_avg',
            'abs_lymphocytes_max',
            # 'abs_neutrophils_avg',
            'abs_neutrophils_max',
            # 'abs_eosinophils_avg',
            'abs_eosinophils_max',
            # 'hemoglobin_avg',
            'hemoglobin_min',
            # 'platelet_avg',
            'platelet_min',
            # 'rdw_avg',
            'rdw_max',
            # 'sodium_avg',
            'sodium_min',
            # 'bicarbonate_avg',
            'bicarbonate_min',
            # 'creatinine_avg',
            'creatinine_max',
            # 'bun_avg',
            'bun_max',
            # 'glucose_avg',
            'glucose_min',
            # 'albumin_avg',
            'albumin_min',
            # 'bilirubin_avg',
            'bilirubin_max',
            # 'pt_avg',
            'pt_max',
            # 'ptt_avg',
            'ptt_max',
            # 'crp_avg',
            'crp_max',
            # 'lactic_acid_avg',
            'lactic_acid_max',
            # 'ldh_avg',
            'ldh_max',
            'sofa_points_p_f_ratio',
            'sofa_points_platelet',
            'sofa_points_bilirubin',
            'sofa_points_htn',
            'sofa_points_gcs',
            'sofa_points_renal',
            'sofa_score'
]

ventilator_feature = [
    'peep_max',
    'plateau_pressure_max',
    'peak_airway_pressure_max',
    'minute_ventilation_max',
    'driving_pressure',
    'ph_art_min',
    'pco2_art_max'
]

med_feature = [
                'hydrocort_equivalent_steroid_dose_day', 
                'cumulative_steroid_dose_until_today',
                'cumulative_nat_score',
                'received_remdesivir', 
                'remdesivir_study_drug'
]

halms_feature = [
                    'temperature_avg', 
                    'temperature_max', 
                    'heart_rate_avg', 
                    'heart_rate_max', 
                    'systolic_blood_pressure_avg', 
                    'systolic_blood_pressure_min', 
                    'diastolic_blood_pressure_avg', 
                    'diastolic_blood_pressure_min', 
                    'respiratory_rate_avg', 
                    'respiratory_rate_max', 
                    'oxygen_saturation_avg', 
                    'oxygen_saturation_min', 
                    'norepinephrine_avg', 
                    'norepinephrine_max', 
                    'rass_avg', 
                    'rass_min', 
                    'gcs_eye_opening_avg', 
                    'gcs_motor_response_avg', 
                    'gcs_verbal_response_avg', 
                    'gcs_eye_opening_min', 
                    'gcs_motor_response_min', 
                    'gcs_verbal_response_min', 
                    'fio2_avg', 
                    'fio2_max', 
                    'po2_art_avg', 
                    'po2_art_min', 
                    'po2_fio2_ratio_avg', 
                    'po2_fio2_ratio_min', 
]



# all variables to keep
feature_columns_all = static_cols + halms_feature + lab_feature + ventilator_feature + med_feature + demo_feature
cap_1to3_sub = cap_1to3[feature_columns_all].copy()

## 1.4 Create binning columns for Halm's criteria

In [19]:
# create gcs_total first
cap_1to3_sub['gcs_total_avg'] = cap_1to3_sub['gcs_eye_opening_avg'] + cap_1to3_sub['gcs_motor_response_avg'] + cap_1to3_sub['gcs_verbal_response_avg']
cap_1to3_sub['gcs_total_min'] = cap_1to3_sub['gcs_eye_opening_min'] + cap_1to3_sub['gcs_motor_response_min'] + cap_1to3_sub['gcs_verbal_response_min']

In [20]:
# Set HR > 100 as 1 (Halms)
cap_1to3_sub['heart_rate_bin'] = np.where(cap_1to3_sub.heart_rate_avg > 100, 1, 0)

# Set RR > 24 as 1 (Halms)
cap_1to3_sub['respiratory_rate_bin'] = np.where(cap_1to3_sub.respiratory_rate_avg > 24, 1, 0)

# Set SBP >= 90 as 1 (Halms)
cap_1to3_sub['systolic_blood_pressure_bin'] = np.where(cap_1to3_sub.systolic_blood_pressure_avg >= 90, 1, 0)

# Set stable temperature (<= 99F and > 95.9F) as 0 (modified Halms)
cap_1to3_sub['temperature_bin'] = np.where((cap_1to3_sub.temperature_avg <= 95.9) | (cap_1to3_sub.temperature_avg > 99), 1, 0)

# Set oxygen saturation < 90 as 1
cap_1to3_sub['oxygen_saturation_bin'] = np.where(cap_1to3_sub.oxygen_saturation_avg < 90, 1, 0)

# Set RASS >= -2 as 1
cap_1to3_sub['rass_score_bin'] = np.where(cap_1to3_sub.rass_avg >= -2, 1, 0)

# Set GCS >= 11 as 1 (this is a little random, but setting cutoff at 10 gives better GCS contribution than setting cutoff at 12)
cap_1to3_sub['gcs_total_bin'] = np.where(cap_1to3_sub.gcs_total_avg >= 11, 1, 0)

# Set pao2fio2_ratio >= 400 as 1
cap_1to3_sub['po2_fio2_ratio_bin'] = np.where(cap_1to3_sub.po2_fio2_ratio_avg >= 400, 1, 0)

# Set norepinephrine_flag if not 0
cap_1to3_sub['norepinephrine_flag'] = np.where(cap_1to3_sub.norepinephrine_avg > 0, 1, 0)

# Clinical stability is defined as halms_vs = 0, otherwise 1
cap_1to3_sub['halms_complex'] = np.where((cap_1to3_sub.heart_rate_bin == 0) &
                                        (cap_1to3_sub.respiratory_rate_bin == 0) &
                                        (cap_1to3_sub.systolic_blood_pressure_bin == 0) &
                                        (cap_1to3_sub.temperature_bin == 0) &
                                        (cap_1to3_sub.oxygen_saturation_bin == 0) &
                                        (cap_1to3_sub.rass_score_bin == 0) &
                                        (cap_1to3_sub.gcs_total_bin == 0) &
                                        (cap_1to3_sub.po2_fio2_ratio_bin == 0) &
                                        (cap_1to3_sub.norepinephrine_flag == 0), 0, 1)


# Clinical stability is defined as halms_vs = 0, otherwise 1
cap_1to3_sub['halms_vs'] = np.where((cap_1to3_sub.heart_rate_bin == 0) &
                                (cap_1to3_sub.respiratory_rate_bin == 0) &
                                (cap_1to3_sub.systolic_blood_pressure_bin == 0) &
                                (cap_1to3_sub.temperature_bin == 0) &
                                (cap_1to3_sub.oxygen_saturation_bin == 0), 0, 1)

# Clinical stability is defined as halms_pred = 0, otherwise 1
cap_1to3_sub['halms_pred'] = cap_1to3_sub['halms_complex']

## 1.5 Define Halm's related variables group

### Variables 1: Halms binning

In [21]:
halms_bin_col = [
    'temperature_bin',
    'heart_rate_bin',
    'systolic_blood_pressure_bin',
    'norepinephrine_flag',
    'respiratory_rate_bin',
    'oxygen_saturation_bin',
    'po2_fio2_ratio_bin',
    'rass_score_bin',
    'gcs_total_bin'
] + static_cols

### Variables 2: No binning, avg and worst Halm's variables

In [22]:
halms_feature_sub = [
                   'temperature_avg', 
                   'temperature_max', 
                   'heart_rate_avg', 
                   'heart_rate_max', 
                   'systolic_blood_pressure_avg', 
                   'systolic_blood_pressure_min', 
                   'respiratory_rate_avg', 
                   'respiratory_rate_max', 
                   'oxygen_saturation_avg', 
                   'oxygen_saturation_min', 
                   'norepinephrine_avg', 
                   'norepinephrine_max', 
                   'rass_avg', 
                   'rass_min', 
                   'gcs_eye_opening_avg', 
                   'gcs_motor_response_avg', 
                   'gcs_verbal_response_avg',
                   'gcs_eye_opening_min', 
                   'gcs_motor_response_min', 
                   'gcs_verbal_response_min', 
                   'fio2_avg', 
                   'fio2_max', 
                   'po2_art_avg', 
                   'po2_art_min', 
                   'po2_fio2_ratio_avg', 
                   'po2_fio2_ratio_min',
] + static_cols

### Variables 3: No binning, avg Halm's variables

In [23]:
halms_feature_avg = [c for c in halms_feature_sub if c.endswith('_avg')] + static_cols

### Variables 4: No binning, Worst Halm's variables

In [24]:
halms_feature_worst = [c for c in halms_feature_sub if not c.endswith('_avg')]

# 2. Model training

In [25]:
# initialize result dictionary to store all model results
results = {}

### 2.0 Create baseline (Halm's criteria)

In [None]:
y_true = cap_1to3_sub['d7_episode_resolution_binary']
y_pred = cap_1to3_sub['halms_pred']

# For rule-based methods, predicted "probabilities" are just the predicted class
y_pred_prob = y_pred

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
sensitivity = recall_score(y_true, y_pred)
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
ppv = precision_score(y_true, y_pred)
npv = tn / (tn + fn) if (tn + fn) > 0 else 0
f1 = f1_score(y_true, y_pred)

metrics = {
    'accuracy': accuracy_score(y_true, y_pred),
    'auroc': roc_auc_score(y_true, y_pred) if len(np.unique(y_true)) > 1 else None,
    'confusion_matrix': confusion_matrix(y_true, y_pred),
    'report': classification_report(y_true, y_pred, output_dict=True),
    'sensitivity': sensitivity,
    'specificity': specificity,
    'ppv': ppv,
    'npv': npv,
    'f1': f1,
}

baseline_results = {
    'y_test': y_true,
    'test_pred': y_pred,
    'test_pred_prob': y_pred_prob,
    'metrics': metrics,
    'model': None,
    'cv_results': None,
}

results['halms_baseline'] = baseline_results

## 2.1 Define fine-tuning and model pipeline

In [27]:
def fine_tune_and_train_xgb(X_train, y_train, pt_ids, n_trials=50, random_state=42):
    def objective(trial, X, y, pt_ids):
        params = {
            "n_estimators": trial.suggest_int("n_estimators", 50, 300),
            "max_depth": trial.suggest_int("max_depth", 2, 10),
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "eval_metric": "logloss",
            "random_state": random_state,
            "verbosity": 0,
            "enable_categorical": True,
            "tree_method": "hist"
        }
        model = XGBClassifier(**params)
        gkf = GroupKFold(n_splits=5)
        train_scores = []
        val_scores = []
        
        for train_idx, val_idx in gkf.split(X, y, groups=pt_ids):
            X_train_fold = X.iloc[train_idx]
            X_val_fold = X.iloc[val_idx]
            y_train_fold = y.iloc[train_idx]
            y_val_fold = y.iloc[val_idx]
            
            model.fit(X_train_fold, y_train_fold)
            train_pred = model.predict_proba(X_train_fold)[:, 1]
            val_pred = model.predict_proba(X_val_fold)[:, 1]
        
            train_auc = roc_auc_score(y_train_fold, train_pred)
            val_auc = roc_auc_score(y_val_fold, val_pred)
            
            train_scores.append(train_auc)
            val_scores.append(val_auc)
    
        avg_train = np.mean(train_scores)
        avg_val = np.mean(val_scores)

        overfitting_penalty = max(0, avg_train - avg_val - 0.1)
        return avg_val - overfitting_penalty

    sampler = optuna.samplers.TPESampler(seed=random_state)
    study = optuna.create_study(direction="maximize", sampler=sampler)
    study.optimize(partial(objective, X=X_train, y=y_train, pt_ids=pt_ids), n_trials=n_trials)

    best_params = study.best_params
    best_params.update({
        "eval_metric": "logloss",
        "random_state": random_state,
        "verbosity": 0,
        "enable_categorical": True,
        "tree_method": "hist"
    })

    model = XGBClassifier(**best_params)
    model.fit(X_train, y_train)

    return model, best_params


In [28]:
def evaluate_model(model, X_train, X_test, y_train, y_test, threshold=0.5):
    # Predict probabilities
    y_train_pred_proba = model.predict_proba(X_train)[:, 1]
    y_test_pred_proba = model.predict_proba(X_test)[:, 1]

    # Predict class using threshold
    y_test_pred = (y_test_pred_proba >= threshold).astype(int)

    # Confusion matrix for specificity/NPV/etc.
    tn, fp, fn, tp = confusion_matrix(y_test, y_test_pred).ravel()

    # Metrics
    accuracy = accuracy_score(y_test, y_test_pred)
    sensitivity = recall_score(y_test, y_test_pred)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    auroc = roc_auc_score(y_test, y_test_pred_proba)
    f1 = f1_score(y_test, y_test_pred)
    ppv = precision_score(y_test, y_test_pred)
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0

    precision, recall, _ = precision_recall_curve(y_test, y_test_pred_proba)
    aupr = auc(recall, precision)

    print(f"Test Set Evaluation")
    print(f"Accuracy: {accuracy:.3f}")
    print(f"Sensitivity: {sensitivity:.3f}")
    print(f"Specificity: {specificity:.3f}")
    print(f"AUROC: {auroc:.3f}")
    print(f"F1 Score: {f1:.3f}")
    print(f"AUPR: {aupr:.3f}")
    print(f"PPV: {ppv:.3f}")
    print(f"NPV: {npv:.3f}")

    return {
        "train_pred": y_train_pred_proba,
        "test_pred": y_test_pred_proba,
        "y_test_pred": y_test_pred,
        "metrics": {
            "accuracy": accuracy,
            "sensitivity": sensitivity,
            "specificity": specificity,
            "auroc": auroc,
            "f1": f1,
            "aupr": aupr,
            "ppv": ppv,
            "npv": npv,
            "confusion_matrix": confusion_matrix(y_test, y_test_pred)
        }
    }

In [29]:
def run_xgb_model_pipeline(model_name, 
                       X_train, X_test, y_train, y_test, pt_ids_train, n_splits=10,
                       n_trials=30, random_state=42, threshold=0.5, plot_sp=True, 
                       max_display=10, save_path=False):
    
    results = {}
    
    # Step 1: Make a copy
    results['X_train'] = X_train
    results['X_test'] = X_test
    results['y_train'] = y_train
    results['y_test'] = y_test

    # Step 2: Fine-tune and train
    model, best_params = fine_tune_and_train_xgb(X_train, y_train, pt_ids_train,
                                                 n_trials=n_trials, random_state=random_state)
    results['model'] = model
    results['best_parameters'] = best_params

    # Step 3: Evaluation
    eval_result = evaluate_model(model, X_train, X_test,
                                 y_train, y_test, threshold=threshold)
    results['metrics'] = eval_result['metrics']
    results['train_pred'] = eval_result['train_pred']
    results['test_pred'] = eval_result['test_pred']

    # Step 4: Cross-validation
    results['cv_results'] = cross_validation_metrics(model, X_train, y_train, pt_ids_train, n_splits=n_splits)

    # Step 5: SHAP
    results['shap_values'] = plot_shap(model, X_test, plot_title=model_name, plot=plot_sp, save_path=save_path, max_display=max_display)

    return model_name, results


## 2.2 run model

In [30]:
feature_sets = {
    'avg_and_worst_Halm\'s_features': halms_feature_sub + demo_feature,
    'avg_Halm\'s_features': halms_feature_avg + demo_feature,
    'worst_Halm\'s_features': halms_feature_worst + demo_feature,
    'Halm\'s_and_labs_features': halms_feature_worst + lab_feature + demo_feature,
    'Halm\'s_and_ventilator_features': halms_feature_worst + ventilator_feature + demo_feature,
    'Halm\'s_and_meds_features': halms_feature_worst + med_feature + demo_feature,
    'Halm\'s_and_labs_meds_and_ventilator_features': halms_feature_worst + lab_feature + ventilator_feature + med_feature + demo_feature,
}

In [31]:
# 1. Create a copy
df = cap_1to3_sub.copy()

# 2. train/test split
unique_pt_ids = df['pt_study_id'].unique()
train_ids, test_ids = train_test_split(unique_pt_ids, test_size=0.2, random_state=42)

train_mask = df['pt_study_id'].isin(train_ids)
test_mask = df['pt_study_id'].isin(test_ids)

train =df[train_mask].copy()
test = df[test_mask].copy()

print("Number of episodes:")
print(f"Train: {train['episode_id'].nunique()}, Test: {test['episode_id'].nunique()}")
print("Number of unique patients:")
print(f"Train: {train['pt_study_id'].nunique()}, Test: {test['pt_study_id'].nunique()}")

# 3. Run models
# get X_train, X_test, y_train, y_test
X_train = train.drop(columns=['pt_study_id', 'd7_episode_resolution_binary', 'episode_id'])
X_test = test.drop(columns=['pt_study_id', 'd7_episode_resolution_binary', 'episode_id'])
y_train = train['d7_episode_resolution_binary'].copy()
y_test = test['d7_episode_resolution_binary'].copy()

pt_ids_train = train['pt_study_id'].copy()

for model_name, feature_cols in feature_sets.items():
    available_features = [c for c in feature_cols if c not in ['pt_study_id', 'd7_episode_resolution_binary', 'episode_id']]
    full_model_name = f"xgb_{model_name}"
    X_train_sub = X_train[available_features].copy()
    X_test_sub = X_test[available_features].copy()

    os.makedirs(output_path, exist_ok=True)
    
    print(f"Running {full_model_name} (XGBoost)...")
    name, result = run_xgb_model_pipeline(
        model_name=full_model_name,
        X_train=X_train_sub,
        X_test=X_test_sub,
        y_train=y_train,
        y_test=y_test,
        pt_ids_train=pt_ids_train,
        n_trials=30,
        n_splits=10,
        random_state=42,
        threshold=0.5,
        plot_sp=False,
        max_display=10
    )

    results[full_model_name] = result

[I 2025-07-03 10:17:19,861] A new study created in memory with name: no-name-e18261c1-671c-4227-8d0b-aa30fce9db53
[I 2025-07-03 10:17:19,978] Trial 0 finished with value: 0.3389808515998992 and parameters: {'n_estimators': 144, 'max_depth': 10, 'learning_rate': 0.22227824312530747, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182}. Best is trial 0 with value: 0.3389808515998992.


Number of episodes:
Train: 68, Test: 17
Number of unique patients:
Train: 68, Test: 17
Running xgb_avg_and_worst_Halm's_features (XGBoost)...


[I 2025-07-03 10:17:20,063] Trial 1 finished with value: 0.3008679768203578 and parameters: {'n_estimators': 89, 'max_depth': 2, 'learning_rate': 0.2611910822747312, 'subsample': 0.8005575058716043, 'colsample_bytree': 0.8540362888980227}. Best is trial 0 with value: 0.3389808515998992.
[I 2025-07-03 10:17:20,129] Trial 2 finished with value: 0.3719438145628623 and parameters: {'n_estimators': 55, 'max_depth': 10, 'learning_rate': 0.2514083658321223, 'subsample': 0.6061695553391381, 'colsample_bytree': 0.5909124836035503}. Best is trial 2 with value: 0.3719438145628623.
[I 2025-07-03 10:17:20,222] Trial 3 finished with value: 0.3121201814058957 and parameters: {'n_estimators': 96, 'max_depth': 4, 'learning_rate': 0.16217936517334897, 'subsample': 0.7159725093210578, 'colsample_bytree': 0.645614570099021}. Best is trial 2 with value: 0.3719438145628623.
[I 2025-07-03 10:17:20,374] Trial 4 finished with value: 0.3529793398841018 and parameters: {'n_estimators': 203, 'max_depth': 3, 'lear

Test Set Evaluation
Accuracy: 0.725
Sensitivity: 0.708
Specificity: 0.741
AUROC: 0.812
F1 Score: 0.708
AUPR: 0.851
PPV: 0.708
NPV: 0.741


[I 2025-07-03 10:17:25,108] A new study created in memory with name: no-name-09204874-ed3d-4b4f-b353-a198c7fd3e42
[I 2025-07-03 10:17:25,199] Trial 0 finished with value: 0.4116641471403376 and parameters: {'n_estimators': 144, 'max_depth': 10, 'learning_rate': 0.22227824312530747, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182}. Best is trial 0 with value: 0.4116641471403376.
[I 2025-07-03 10:17:25,261] Trial 1 finished with value: 0.4028697404887883 and parameters: {'n_estimators': 89, 'max_depth': 2, 'learning_rate': 0.2611910822747312, 'subsample': 0.8005575058716043, 'colsample_bytree': 0.8540362888980227}. Best is trial 0 with value: 0.4116641471403376.


Running xgb_avg_Halm's_features (XGBoost)...


[I 2025-07-03 10:17:25,315] Trial 2 finished with value: 0.4530272108843537 and parameters: {'n_estimators': 55, 'max_depth': 10, 'learning_rate': 0.2514083658321223, 'subsample': 0.6061695553391381, 'colsample_bytree': 0.5909124836035503}. Best is trial 2 with value: 0.4530272108843537.
[I 2025-07-03 10:17:25,388] Trial 3 finished with value: 0.41399974804736706 and parameters: {'n_estimators': 96, 'max_depth': 4, 'learning_rate': 0.16217936517334897, 'subsample': 0.7159725093210578, 'colsample_bytree': 0.645614570099021}. Best is trial 2 with value: 0.4530272108843537.
[I 2025-07-03 10:17:25,508] Trial 4 finished with value: 0.41589443184681285 and parameters: {'n_estimators': 203, 'max_depth': 3, 'learning_rate': 0.09472194807521325, 'subsample': 0.6831809216468459, 'colsample_bytree': 0.728034992108518}. Best is trial 2 with value: 0.4530272108843537.
[I 2025-07-03 10:17:25,632] Trial 5 finished with value: 0.388219954648526 and parameters: {'n_estimators': 247, 'max_depth': 3, 'le

Test Set Evaluation
Accuracy: 0.706
Sensitivity: 0.625
Specificity: 0.778
AUROC: 0.765
F1 Score: 0.667
AUPR: 0.792
PPV: 0.714
NPV: 0.700


[I 2025-07-03 10:17:28,635] A new study created in memory with name: no-name-3e92b621-e2c2-4966-bf73-448a8f2fc50a
[I 2025-07-03 10:17:28,723] Trial 0 finished with value: 0.35385361552028216 and parameters: {'n_estimators': 144, 'max_depth': 10, 'learning_rate': 0.22227824312530747, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182}. Best is trial 0 with value: 0.35385361552028216.
[I 2025-07-03 10:17:28,781] Trial 1 finished with value: 0.31631267321743517 and parameters: {'n_estimators': 89, 'max_depth': 2, 'learning_rate': 0.2611910822747312, 'subsample': 0.8005575058716043, 'colsample_bytree': 0.8540362888980227}. Best is trial 0 with value: 0.35385361552028216.


Running xgb_worst_Halm's_features (XGBoost)...


[I 2025-07-03 10:17:28,837] Trial 2 finished with value: 0.35321113630637446 and parameters: {'n_estimators': 55, 'max_depth': 10, 'learning_rate': 0.2514083658321223, 'subsample': 0.6061695553391381, 'colsample_bytree': 0.5909124836035503}. Best is trial 0 with value: 0.35385361552028216.
[I 2025-07-03 10:17:28,908] Trial 3 finished with value: 0.3366439909297053 and parameters: {'n_estimators': 96, 'max_depth': 4, 'learning_rate': 0.16217936517334897, 'subsample': 0.7159725093210578, 'colsample_bytree': 0.645614570099021}. Best is trial 0 with value: 0.35385361552028216.
[I 2025-07-03 10:17:29,017] Trial 4 finished with value: 0.32597883597883615 and parameters: {'n_estimators': 203, 'max_depth': 3, 'learning_rate': 0.09472194807521325, 'subsample': 0.6831809216468459, 'colsample_bytree': 0.728034992108518}. Best is trial 0 with value: 0.35385361552028216.
[I 2025-07-03 10:17:29,137] Trial 5 finished with value: 0.2987528344671201 and parameters: {'n_estimators': 247, 'max_depth': 3,

Test Set Evaluation
Accuracy: 0.706
Sensitivity: 0.750
Specificity: 0.667
AUROC: 0.770
F1 Score: 0.706
AUPR: 0.802
PPV: 0.667
NPV: 0.750


[I 2025-07-03 10:17:31,947] A new study created in memory with name: no-name-99668d68-1c87-464a-bedd-10bbe71bc938
[I 2025-07-03 10:17:32,081] Trial 0 finished with value: 0.26904761904761887 and parameters: {'n_estimators': 144, 'max_depth': 10, 'learning_rate': 0.22227824312530747, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182}. Best is trial 0 with value: 0.26904761904761887.


Running xgb_Halm's_and_labs_features (XGBoost)...


[I 2025-07-03 10:17:32,177] Trial 1 finished with value: 0.25366717057193255 and parameters: {'n_estimators': 89, 'max_depth': 2, 'learning_rate': 0.2611910822747312, 'subsample': 0.8005575058716043, 'colsample_bytree': 0.8540362888980227}. Best is trial 0 with value: 0.26904761904761887.
[I 2025-07-03 10:17:32,257] Trial 2 finished with value: 0.3781368102796675 and parameters: {'n_estimators': 55, 'max_depth': 10, 'learning_rate': 0.2514083658321223, 'subsample': 0.6061695553391381, 'colsample_bytree': 0.5909124836035503}. Best is trial 2 with value: 0.3781368102796675.
[I 2025-07-03 10:17:32,365] Trial 3 finished with value: 0.3301612496850591 and parameters: {'n_estimators': 96, 'max_depth': 4, 'learning_rate': 0.16217936517334897, 'subsample': 0.7159725093210578, 'colsample_bytree': 0.645614570099021}. Best is trial 2 with value: 0.3781368102796675.
[I 2025-07-03 10:17:32,543] Trial 4 finished with value: 0.2644406651549508 and parameters: {'n_estimators': 203, 'max_depth': 3, 'le

Test Set Evaluation
Accuracy: 0.569
Sensitivity: 0.625
Specificity: 0.519
AUROC: 0.724
F1 Score: 0.577
AUPR: 0.778
PPV: 0.536
NPV: 0.609


[I 2025-07-03 10:17:37,361] A new study created in memory with name: no-name-8dabfa98-e2d3-4314-97d0-6c800fe06c67
[I 2025-07-03 10:17:37,454] Trial 0 finished with value: 0.23913580246913557 and parameters: {'n_estimators': 144, 'max_depth': 10, 'learning_rate': 0.22227824312530747, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182}. Best is trial 0 with value: 0.23913580246913557.
[I 2025-07-03 10:17:37,517] Trial 1 finished with value: 0.3669211388259005 and parameters: {'n_estimators': 89, 'max_depth': 2, 'learning_rate': 0.2611910822747312, 'subsample': 0.8005575058716043, 'colsample_bytree': 0.8540362888980227}. Best is trial 1 with value: 0.3669211388259005.


Running xgb_Halm's_and_ventilator_features (XGBoost)...


[I 2025-07-03 10:17:37,574] Trial 2 finished with value: 0.3244721592340639 and parameters: {'n_estimators': 55, 'max_depth': 10, 'learning_rate': 0.2514083658321223, 'subsample': 0.6061695553391381, 'colsample_bytree': 0.5909124836035503}. Best is trial 1 with value: 0.3669211388259005.
[I 2025-07-03 10:17:37,649] Trial 3 finished with value: 0.35664399092970533 and parameters: {'n_estimators': 96, 'max_depth': 4, 'learning_rate': 0.16217936517334897, 'subsample': 0.7159725093210578, 'colsample_bytree': 0.645614570099021}. Best is trial 1 with value: 0.3669211388259005.
[I 2025-07-03 10:17:37,770] Trial 4 finished with value: 0.337867220962459 and parameters: {'n_estimators': 203, 'max_depth': 3, 'learning_rate': 0.09472194807521325, 'subsample': 0.6831809216468459, 'colsample_bytree': 0.728034992108518}. Best is trial 1 with value: 0.3669211388259005.
[I 2025-07-03 10:17:37,898] Trial 5 finished with value: 0.30284706475182677 and parameters: {'n_estimators': 247, 'max_depth': 3, 'le

Test Set Evaluation
Accuracy: 0.745
Sensitivity: 0.708
Specificity: 0.778
AUROC: 0.787
F1 Score: 0.723
AUPR: 0.792
PPV: 0.739
NPV: 0.750
Running xgb_Halm's_and_meds_features (XGBoost)...


[I 2025-07-03 10:17:40,854] Trial 0 finished with value: 0.14927941546989165 and parameters: {'n_estimators': 144, 'max_depth': 10, 'learning_rate': 0.22227824312530747, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182}. Best is trial 0 with value: 0.14927941546989165.
[I 2025-07-03 10:17:40,914] Trial 1 finished with value: 0.2174565381708239 and parameters: {'n_estimators': 89, 'max_depth': 2, 'learning_rate': 0.2611910822747312, 'subsample': 0.8005575058716043, 'colsample_bytree': 0.8540362888980227}. Best is trial 1 with value: 0.2174565381708239.
[I 2025-07-03 10:17:40,966] Trial 2 finished with value: 0.2556273620559334 and parameters: {'n_estimators': 55, 'max_depth': 10, 'learning_rate': 0.2514083658321223, 'subsample': 0.6061695553391381, 'colsample_bytree': 0.5909124836035503}. Best is trial 2 with value: 0.2556273620559334.
[I 2025-07-03 10:17:41,036] Trial 3 finished with value: 0.251410934744268 and parameters: {'n_estimators': 96, 'max_depth': 4, 'l

Test Set Evaluation
Accuracy: 0.686
Sensitivity: 0.750
Specificity: 0.630
AUROC: 0.759
F1 Score: 0.692
AUPR: 0.792
PPV: 0.643
NPV: 0.739


[I 2025-07-03 10:17:44,396] A new study created in memory with name: no-name-390e721f-4535-4f1c-9bdc-c6b389b9d40a
[I 2025-07-03 10:17:44,544] Trial 0 finished with value: 0.23779037540942294 and parameters: {'n_estimators': 144, 'max_depth': 10, 'learning_rate': 0.22227824312530747, 'subsample': 0.7993292420985183, 'colsample_bytree': 0.5780093202212182}. Best is trial 0 with value: 0.23779037540942294.


Running xgb_Halm's_and_labs_meds_and_ventilator_features (XGBoost)...


[I 2025-07-03 10:17:44,656] Trial 1 finished with value: 0.25841773746535635 and parameters: {'n_estimators': 89, 'max_depth': 2, 'learning_rate': 0.2611910822747312, 'subsample': 0.8005575058716043, 'colsample_bytree': 0.8540362888980227}. Best is trial 1 with value: 0.25841773746535635.
[I 2025-07-03 10:17:44,747] Trial 2 finished with value: 0.2769110607205846 and parameters: {'n_estimators': 55, 'max_depth': 10, 'learning_rate': 0.2514083658321223, 'subsample': 0.6061695553391381, 'colsample_bytree': 0.5909124836035503}. Best is trial 2 with value: 0.2769110607205846.
[I 2025-07-03 10:17:44,871] Trial 3 finished with value: 0.2982224741748548 and parameters: {'n_estimators': 96, 'max_depth': 4, 'learning_rate': 0.16217936517334897, 'subsample': 0.7159725093210578, 'colsample_bytree': 0.645614570099021}. Best is trial 3 with value: 0.2982224741748548.
[I 2025-07-03 10:17:45,077] Trial 4 finished with value: 0.2513807004283194 and parameters: {'n_estimators': 203, 'max_depth': 3, 'le

Test Set Evaluation
Accuracy: 0.706
Sensitivity: 0.792
Specificity: 0.630
AUROC: 0.790
F1 Score: 0.717
AUPR: 0.799
PPV: 0.655
NPV: 0.773


# 3. Compare model results

### 3.1 Model Performance (Median)

In [32]:
# rename model names
model_rename = {
    'xgb_Halm\'s_and_labs_meds_and_ventilator_features': 'Halm\'s Features with \nLabs, Ventilator,\nand Medication Features',
    'xgb_Halm\'s_and_meds_features': 'Halm\'s Features with\nMedication Features',
    'xgb_Halm\'s_and_ventilator_features': 'Halm\'s Features with\nVentilator Features',
    'xgb_Halm\'s_and_labs_features': 'Halm\'s Features with\nLabs Features',
    'xgb_avg_and_worst_Halm\'s_features': 'Avg and\nWorst Halm\'s Features',
    'xgb_worst_Halm\'s_features': 'Worst Halm\'s Features',
    'xgb_avg_Halm\'s_features': 'Avg Halm\'s Features'
}

In [33]:
os.makedirs(output_path, exist_ok=True)

model_keys = [k for k in model_rename.keys() if k in results]
scores = [results[k]['cv_results']['aurocs'] for k in model_keys]

# get renamed x-labels
model_names = [model_rename[k] for k in model_keys]

x = np.arange(1, len(model_names) + 1)

#### --- BOX PLOT (median/IQR) + scatter --- ####
boxplot_file = os.path.join(output_path, 'model_performance_boxplot')

plt.figure(figsize=(max(10, len(model_names)*1.5), 6))
box = plt.boxplot(
        scores,
        patch_artist=True,
        boxprops=dict(facecolor="lightblue", color="black"),
        showfliers=False,
        medianprops=dict(color="red"),
        whiskerprops=dict(color="black"),
        capprops=dict(color="black"),
        flierprops=dict(markerfacecolor='gray', marker='o', markersize=5, linestyle='none')
    )
# Add scatter points (jittered)
for j, ss in enumerate(scores):
    x_jitter = np.random.normal(loc=x[j], scale=0.05, size=len(ss))
    plt.scatter(x_jitter, ss, color="black", alpha=0.7, zorder=3)

for i, median_line in enumerate(box['medians']):
    x_median, y_median = median_line.get_xdata()[1], median_line.get_ydata()[1]
    median_val = np.median(scores[i])
    plt.text(
        x_median + 0.03, y_median,
        f"{median_val:.3f}",
        va='center', ha='left',
        fontsize=10, color='red',
        bbox=dict(facecolor='white', edgecolor='none', alpha=0.7, pad=0.1)
    )

plt.title("Model Performance: AUROC (10-Fold CV ± STD)", fontsize=15)
plt.ylabel("AUROC")
plt.ylim(0, 1.0)
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.xticks(ticks=x, labels=model_names, rotation=45, fontsize=8)
plt.tight_layout()
plt.savefig(f"{boxplot_file}.pdf", bbox_inches='tight', dpi=300)
plt.savefig(f"{boxplot_file}.png", bbox_inches='tight', dpi=300)
# plt.show()
plt.close()

#### --- MEAN ± SD PLOT + scatter --- ####
meanplot_file = os.path.join(output_path, 'model_performance_mean_sd')

means = [np.mean(s) for s in scores]
stds = [np.std(s) for s in scores]

plt.figure(figsize=(8, 10))
plt.errorbar(
    means, x, xerr=stds, fmt='o', color='black',
    ecolor='black', elinewidth=2, capsize=6,
    markerfacecolor='grey', markeredgewidth=1.5, marker='o', markersize=10
)

# Add data points
for j, ss in enumerate(scores):
    y_jitter = np.random.normal(loc=x[j], scale=0.04, size=len(ss))
    plt.scatter(ss, y_jitter, color="darkgray", alpha=0.7, s=18, zorder=2)

for j, (mean, xj) in enumerate(zip(means, x)):
    plt.text(
        mean, xj - 0.15,
        f"{mean:.3f}",
        va='top', ha='center',
        fontsize=16, color='black', fontweight='bold',
        bbox=dict(facecolor='white', edgecolor='none', alpha=0.7, pad=0.1)
    )

plt.title("Model Performance: AUROC (10-Fold Cross-Validation, Mean ± SD)", fontsize=18)
plt.xlabel("AUROC", fontsize=15)
plt.xlim(0.3, 1.0)
plt.grid(axis='x', linestyle='--', alpha=0.5)
plt.xticks(fontsize=12)
plt.yticks(ticks=x, labels=model_names, fontsize=18)
plt.tight_layout()
plt.savefig(f"{meanplot_file}.pdf", bbox_inches='tight', dpi=300)
plt.savefig(f"{meanplot_file}.png", bbox_inches='tight', dpi=300)
plt.close()
# plt.show()

  plt.errorbar(


### 3.2 Model Performance (AUROC)

In [34]:
group1 = ['xgb_worst_Halm\'s_features']
group2 = ['xgb_Halm\'s_and_ventilator_features']

In [35]:
def plot_roc(models_to_plot, results, title, ax=None, ci_alpha=0.2, show_yaxis=True):
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    for model_name in models_to_plot:
        if model_name not in results:
            continue
        model_results = results[model_name]
        
        if model_name.startswith("logreg"):
            model_type = 'Logistic Regression'
            color_cv = 'blue'
            color_test = 'red'
        elif model_name.startswith("xgb"):
            model_type = 'XGBoost'
            color_cv = 'green'  
            color_test = 'orange' 
        elif 'baseline' in model_name.lower():
            model_type = "Halm's Criteria"
            color_cv = 'gray'
            color_test = 'black'
        else:
            model_type = model_name.replace('_', ' ').title()
            color_cv = 'black'
            color_test = 'black'


        # Plot ML models with CV
        if all(k in model_results for k in ['test_pred', 'y_test', 'cv_results']) and model_results['cv_results'] is not None:
            y_test = model_results['y_test']
            y_score_test = model_results.get('test_pred_prob', model_results['test_pred'])
            mean_fpr = model_results['cv_results']['mean_fpr']
            mean_tpr = model_results['cv_results']['mean_tpr']
            cv_aucs = model_results['cv_results']['aurocs']
            mean_auc = auc(mean_fpr, mean_tpr)
            std_auc = np.std(cv_aucs)

            # --- Confidence Interval ---
            if 'std_tpr' in model_results['cv_results']:
                std_tpr = model_results['cv_results']['std_tpr']
                upper = np.minimum(mean_tpr + std_tpr, 1)
                lower = np.maximum(mean_tpr - std_tpr, 0)
                ax.fill_between(mean_fpr, lower, upper, color=color_cv, alpha=ci_alpha)

            
            # Plot mean CV ROC
            ax.plot(mean_fpr, mean_tpr, lw=2, color=color_cv,
                    label=f"{model_type} CV (AUROC = {mean_auc:.3f} ± {std_auc:.3f})")
            # Plot test ROC
            fpr_test, tpr_test, _ = roc_curve(y_test.astype(int), y_score_test)
            auc_test = roc_auc_score(y_test, y_score_test)
            ax.plot(fpr_test, tpr_test, lw=2, color=color_test,
                    label=f"{model_type} Test (AUROC = {auc_test:.3f})")
            
            
    ax.plot([0, 1], [0, 1], linestyle='--', color='gray', lw=1)
    ax.set_xlabel("False Positive Rate", fontsize=16)
    ax.set_ylabel("True Positive Rate", fontsize=16)
    ax.tick_params(axis='x', labelsize=15)
    ax.tick_params(axis='y', labelsize=15)
    ax.set_title(title, fontsize=20)
    ax.legend(loc="lower right", fontsize=15)
    ax.grid(True)


In [36]:
file_path = os.path.join(output_path, 'roc_comparison')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=False)

plot_roc(group1, results, "Worst Halm's Features", ax=ax1, show_yaxis=True)
plot_roc(group2, results, "Worst Halm's Features with Worst Ventilator Features", ax=ax2, show_yaxis=True)

plt.tight_layout()
plt.savefig(f'{file_path}.pdf', bbox_inches='tight', dpi=300)
plt.savefig(f'{file_path}.png', bbox_inches='tight', dpi=300)
plt.close()
# plt.show()

### 3.3 SHAP plot 

In [37]:
model1 = results['xgb_worst_Halm\'s_features']['model']
data1 = results['xgb_worst_Halm\'s_features']['X_test']
model2 = results['xgb_Halm\'s_and_ventilator_features']['model']
data2 = results['xgb_Halm\'s_and_ventilator_features']['X_test']

plot_shap_customized_axis(model1, data1, 'Worst Halm\'s Features', plot=True, 
                                  left_label='Cured Likely',right_label='Cured Unlikely', 
                                  save_path=os.path.join(output_path, 'xgb_worst_Halms_features_shap.png'))
plot_shap_customized_axis(model2, data2, 'Halm\'s and Ventilator Features', plot=True, 
                                  left_label='Cured Likely',right_label='Cured Unlikely',
                                  save_path=os.path.join(output_path, 'xgb_Halms_and_vents_features.png'))


  shap.summary_plot(
  shap.summary_plot(


# 4. Table one

In [38]:
# create summary tables for during CAP episode
crrt_flag_positive = cap.groupby(['pt_study_id', 'episode_id'])['crrt_flag'].max().reset_index().rename(columns={"crrt_flag": "crrt_flag_positive"})
ecmo_flag_positive = cap.groupby(['pt_study_id', 'episode_id'])['ecmo_flag'].max().reset_index().rename(columns={"ecmo_flag": "ecmo_flag_positive"})
intub_flag_positive = cap.groupby(['pt_study_id', 'episode_id'])['intub_flag'].max().reset_index().rename(columns={"intub_flag": "intub_flag_positive"})
tracheostomy_flag_positive = cap.groupby(['pt_study_id', 'episode_id'])['tracheostomy_flag'].max().reset_index().rename(columns={"tracheostomy_flag": "tracheostomy_flag_positive"})
hd_flag_positive = cap.groupby(['pt_study_id', 'episode_id'])['hd_flag'].max().reset_index().rename(columns={"hd_flag": "hd_flag_positive"})

summary_tables = [
    tracheostomy_flag_positive,
    intub_flag_positive,
    hd_flag_positive,
    crrt_flag_positive,
    ecmo_flag_positive
]

patient_summary = reduce(lambda left, right: pd.merge(left, right, on=['pt_study_id', 'episode_id'], how='outer'), summary_tables)
cols = patient_summary.columns
cap_only = cap.merge(patient_summary, on=['pt_study_id', 'episode_id'], how='left')

cols_with_nan = cap_only[cols].isna().any()
cols_with_nan = cols_with_nan[cols_with_nan].index.tolist()
print("Number of column with missing values:", len(cols_with_nan))

Number of column with missing values: 0


In [39]:
# Keep episode day 1 for each patient
cap_only = cap_only.groupby(['pt_study_id', 'episode_id']).first().reset_index()

# set immunocompromised flag to integer
cap_only['immunocompromised_flag'] = cap_only['immunocompromised_flag'].astype(int)

In [40]:
group_by = 'd7_episode_resolution_binary'
nonnormal_cols = ['edw_adm_age', 'bmi', 'hospital_los_days', 'total_icu_los_days', 'episode_duration', 'sofa_score', 'cumulative_nat_score', 
                  'cumulative_steroid_dose_until_today']

categorical_cols = ['external_transfer_flag','race', 'gender', 'ethnicity', 
                   'discharge_disposition_name', 'unfavorable_outcome', 'immunocompromised_flag', 'smoking_history',
                   'ecmo_flag_positive', 'intub_flag_positive', 'hd_flag_positive', 'crrt_flag_positive', 'tracheostomy_flag_positive']

all_cols = nonnormal_cols + categorical_cols

In [41]:
# Step 1: Define the desired order
discharge_order = ['Home', 'Rehab', 'LTACH', 'SNF', 'Hospice','Died']

race_order = [
    'White',
    'Black or African American',
    'Asian',
    'Unknown or Not Reported',
]

smoking_order = [
    'Never Smoked',
    'Former Smoker',
    'Current Smoker',
    'Unknown'
]

gender_order = ['Female', 'Male']


# Step 2: Set the column as a categorical with the specified order
cap_only['discharge_disposition_name'] = pd.Categorical(
    cap_only['discharge_disposition_name'],
    categories=discharge_order,
    ordered=True
)

# Convert cured column
cap_only['d7_episode_resolution_binary'] = cap_only['d7_episode_resolution_binary'].replace({
    1: 'Not Cured',
    0: 'Cured'
})

# Rename smoking status values
cap_only['smoking_history'] = cap_only['smoking_history'].replace({
    'No': 'Never Smoked',
    'Yes': 'Current Smoker',
    'Former': 'Former Smoker',
    'Unknown Smoking Status': 'Unknown'
})


cap_only['smoking_history'] = pd.Categorical(
    cap_only['smoking_history'],
    categories=smoking_order,
    ordered=True)


# Convert Race column to a categorical type with the specified order
cap_only['race'] = pd.Categorical(
    cap_only['race'],
    categories=race_order,
    ordered=True
)

# Convert gender column
cap_only['gender'] = pd.Categorical(
    cap_only['gender'],
    categories=gender_order,
    ordered=True
)

# Convert binary column 1/0 -> Yes/No
binary_cols = ['hd_flag_positive', 'immunocompromised_flag', 'external_transfer_flag', 
               'hd_flag_positive', 'crrt_flag_positive', 'ecmo_flag_positive', 'tracheostomy_flag_positive', 
               'intub_flag_positive', 'unfavorable_outcome']

for b in binary_cols:
    cap_only[b] = cap_only[b].replace({
        1: 'Yes',
        0: 'No'
    })

# Rename columns
col_renames = {
    'edw_adm_age': 'Age (years)',
    'ethnicity': 'Ethnicity',
    'gender': 'Gender',
    'race': 'Race',
    'smoking_history': 'Smoking Status',
    'bmi': 'Body Mass Index (kg/m²)',
    
    'immunocompromised_flag': 'Immunocompromised',
    'external_transfer_flag': 'Transferred from External Facility',

    'sofa_score': 'SOFA Score on Episode Day 1',

    'cumulative_nat_score': 'Cumulative NAT Score on Episode Day 1',
    'cumulative_steroid_dose_until_today': 'Cumulative Steroid Dose on Episode Day 1',

    
    'hd_flag_positive': 'Received hemodialysis during Episode',
    'crrt_flag_positive': 'Received CRRT during Episode',
    'ecmo_flag_positive': 'Received ECMO during Episode',
    'tracheostomy_flag_positive': 'Underwent Tracheostomy',
    'intub_flag_positive': 'Intubated During Episode',

    'total_icu_los_days': 'Total ICU Days',
    'hospital_los_days': 'Total Hospitalization Days',
    'episode_duration': 'Episode Duration',

    'discharge_disposition_name': 'Discharge Disposition',
    'unfavorable_outcome': 'Unfavorable Outcome',
}

In [42]:
tb1 = TableOne(cap_only, 
               columns=all_cols, categorical=categorical_cols, nonnormal=nonnormal_cols, 
               groupby=group_by, rename=col_renames, missing=False, pval=True)

tb1_df = tb1.tableone

# only show value=1 rows for binary column
binary_vars = ['Transferred from External Facility', 'Unfavorable Outcome', 'Immunocompromised',
               'Received ECMO during Episode', 'Intubated During Episode',
               'Received hemodialysis during Episode', 'Received CRRT during Episode', 'Underwent Tracheostomy']


rows_to_keep = []
for idx in tb1_df.index:
    # idx[0] is the variable name, idx[1] is the value
    matched = any(var in idx[0] for var in binary_vars)
    if matched:
        if idx[1] == 'Yes': 
            rows_to_keep.append(idx)
    else:
        # For non-binary variables
        rows_to_keep.append(idx)

tb1_filtered = tb1_df.loc[rows_to_keep]
tb1_filtered.to_csv("./output/cap_cohort_tb1.csv")

# 5. Performance table

### Fixed threshold

In [55]:
y_true = results['xgb_Halm\'s_and_ventilator_features']['y_test']
y_pred_prob = results['xgb_Halm\'s_and_ventilator_features']['test_pred']

fpr, tpr, thresholds = roc_curve(y_true, y_pred_prob)
specificity = 1 - fpr
sensitivity = tpr

fixed_threshold = 0.35
file_path = os.path.join(output_path, 'performance_metrics_thresholds_comparison_halsm_only_fixed035')

# Interpolate sensitivity and specificity at threshold=0.35
interp_sens = np.interp(fixed_threshold, thresholds[::-1], sensitivity[::-1])
interp_spec = np.interp(fixed_threshold, thresholds[::-1], specificity[::-1])

plt.figure(figsize=(8, 5))
plt.plot(thresholds, sensitivity, label='Sensitivity (ROC)', lw=2)
plt.plot(thresholds, specificity, label='Specificity (ROC)', lw=2)

plt.scatter(fixed_threshold, interp_sens, color='red', zorder=5, label=f'Sensitivity at {fixed_threshold:.2f}')
plt.scatter(fixed_threshold, interp_spec, color='blue', zorder=5, label=f'Specificity at {fixed_threshold:.2f}')

plt.axvline(x=fixed_threshold, color='green', linestyle='--', label=f'Threshold = {fixed_threshold:.2f}')

plt.annotate(
    f"Threshold = 0.35",
    xy=(fixed_threshold, (interp_sens + interp_spec) / 2),
    xytext=(fixed_threshold - 0.07, (interp_sens + interp_spec) / 2 + 0.05),
    arrowprops=dict(facecolor='black', shrink=0.05),
    fontsize=13,
    fontweight='bold',
    ha='right',
    bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3)
)

plt.xlabel('Threshold', fontsize=15)
plt.ylabel('Value', fontsize=15)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.title(f'Performance at Threshold = {fixed_threshold:.2f}', size=20)
plt.legend(loc='best', fontsize=11)
plt.grid(True)
plt.tight_layout()
plt.savefig(f"{file_path}.pdf", dpi=300, bbox_inches='tight')
plt.savefig(f"{file_path}.png", dpi=300, bbox_inches='tight')
plt.close()

#### Youden's J 

In [53]:
file_path = os.path.join(output_path, 'performance_metrics_best_threshold_youdenJ')

y_true = results['xgb_Halm\'s_and_ventilator_features']['y_test']
y_pred_prob = results['xgb_Halm\'s_and_ventilator_features']['test_pred']

fpr, tpr, thresholds = roc_curve(y_true, y_pred_prob)
specificity = 1 - fpr
sensitivity = tpr

# Calculate Youden's J statistic for all thresholds
youden_j = sensitivity + specificity - 1

# Find the index of the best Youden's J
best_idx = np.argmax(youden_j)
best_threshold = thresholds[best_idx]
best_sens = sensitivity[best_idx]
best_spec = specificity[best_idx]
best_j = youden_j[best_idx]

# --- Plotting ---
plt.figure(figsize=(8, 5))
plt.plot(thresholds, sensitivity, label='Sensitivity', lw=2)
plt.plot(thresholds, specificity, label='Specificity', lw=2)

# Highlight and annotate the best threshold
plt.scatter(best_threshold, best_sens, color='red', zorder=5, label='Best Sensitivity (Youden\'s J)')
plt.scatter(best_threshold, best_spec, color='blue', zorder=5, label='Best Specificity (Youden\'s J)')
plt.axvline(x=best_threshold, color='green', linewidth=3, linestyle=':', label='Best Threshold')

plt.annotate(
    f"Threshold = {best_threshold:.3f}",
    xy=(best_threshold, (best_sens + best_spec) / 2),
    xytext=(best_threshold - 0.1, best_j + 0.15),
    arrowprops=dict(facecolor='gray', shrink=0.05),
    fontsize=13,
    ha='right',
    fontweight='bold',
    bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3)
)

plt.xlabel('Threshold', fontsize=15)
plt.ylabel('Value', fontsize=15)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.title(f'Performance at Threshold = {best_threshold:.3f}', size=20)
plt.legend(loc='best', fontsize=11)
plt.grid(True)
plt.tight_layout()
plt.savefig(f"{file_path}.pdf", dpi=300, bbox_inches='tight')
plt.savefig(f"{file_path}.png", dpi=300, bbox_inches='tight')
plt.close()
# plt.show()


#### Best threshold < 0.5

In [None]:
file_path = os.path.join(output_path, 'performance_metrics_best_threshold_below_05')

y_true = results['xgb_Halm\'s_and_ventilator_features']['y_test']
y_pred_prob = results['xgb_Halm\'s_and_ventilator_features']['test_pred']

fpr, tpr, thresholds = roc_curve(y_true, y_pred_prob)
specificity = 1 - fpr
sensitivity = tpr

# Calculate Youden's J statistic and get the best threshold < 0.5
youden_j = sensitivity + specificity - 1
mask = thresholds < (0.5 - 1e-10)
filtered_thresholds = thresholds[mask]
filtered_sens = sensitivity[mask]
filtered_spec = specificity[mask]
filtered_j = youden_j[mask]

if len(filtered_j) > 0:
    best_idx = np.argmax(filtered_j)
    best_threshold = filtered_thresholds[best_idx]
    best_sens = filtered_sens[best_idx]
    best_spec = filtered_spec[best_idx]

    # set the best value to -inf temporarily and find again
    if len(filtered_j) > 1:
        temp_j = filtered_j.copy()
        temp_j[best_idx] = -np.inf
        second_best_idx = np.argmax(temp_j)
        second_best_threshold = filtered_thresholds[second_best_idx]
        # print("Second Best threshold=", second_best_threshold)
    else:
        second_best_threshold = None
else:
    raise ValueError("No threshold found < 0.5")

# --- Plotting ---
plt.figure(figsize=(8, 5))
plt.plot(thresholds, sensitivity, label='Sensitivity', lw=2)
plt.plot(thresholds, specificity, label='Specificity', lw=2)

# Only highlight and annotate the best threshold < 0.5
plt.scatter(best_threshold, best_sens, color='red', zorder=5, label=f'Best Sensitivity < 0.5 (Youden\'s J)')
plt.scatter(best_threshold, best_spec, color='blue', zorder=5, label=f'Best Specificity < 0.5 (Youden\'s J)')
plt.axvline(x=best_threshold, color='green', linewidth=3, linestyle=':', label=f'Best Threshold')

plt.annotate(
    f"Threshold = {best_threshold:.3f} ",
    xy=(best_threshold, (best_sens + best_spec) / 2),
    xytext=(best_threshold - 0.07, (best_sens + best_spec) / 2 + 0.12),
    arrowprops=dict(facecolor='gray', shrink=0.05),
    fontsize=13,
    ha='right',
    fontweight='bold',
    bbox=dict(boxstyle='round,pad=0.2', fc='yellow', alpha=0.3)
)

plt.xlabel('Threshold', fontsize=15)
plt.ylabel('Value', fontsize=15)
plt.xticks(fontsize=13)
plt.yticks(fontsize=13)
plt.title(f'Performance at Threshold = {best_threshold:.3f}', size=20)
plt.legend(loc='best', fontsize=11)
plt.grid(True)
plt.tight_layout()
plt.savefig(f"{file_path}.pdf", dpi=300, bbox_inches='tight')
plt.savefig(f"{file_path}.png", dpi=300, bbox_inches='tight')
plt.close()
# plt.show()

Second Best threshold= 0.42470446


In [46]:
metrics_model_1 = get_metrics_dict(results['xgb_worst_Halm\'s_features'])
metrics_model_2 = get_metrics_dict(results['xgb_Halm\'s_and_ventilator_features'])
metrics_baseline = get_metrics_dict(results['halms_baseline'])

# get scores
best_model = results['xgb_Halm\'s_and_ventilator_features']['model']
X_train_ = results['xgb_Halm\'s_and_ventilator_features']['X_train']
X_test_ = results['xgb_Halm\'s_and_ventilator_features']['X_test']
y_train_ = results['xgb_Halm\'s_and_ventilator_features']['y_train']
y_test_ = results['xgb_Halm\'s_and_ventilator_features']['y_test']

# set the threshold (best)
new_results = evaluate_model(best_model, X_train_, X_test_, y_train_, y_test_, threshold=best_threshold)
metrics_model_best = get_metrics_dict(new_results)

# set the threshold (fixed)
new_results_fixed = evaluate_model(best_model, X_train_, X_test_, y_train_, y_test_, threshold=fixed_threshold)
metrics_model_fixed = get_metrics_dict(new_results_fixed)

# save performance metrics
summary_df = pd.DataFrame([metrics_baseline, metrics_model_1, metrics_model_2, metrics_model_best, metrics_model_fixed],
                          index=['Baseline', 'XGBoost Model Halms Only', 'XGBoost Model with Worst Vents', 
                                 'XGBoost Model Halms with Vents with Best Threshold',
                                 'XGBoost Model Halms with Vents with Fixed Threshold(0.35)'])

summary_df.to_csv("./output/plots/preformance_metrics.csv")

Test Set Evaluation
Accuracy: 0.725
Sensitivity: 0.667
Specificity: 0.778
AUROC: 0.799
F1 Score: 0.696
AUPR: 0.815
PPV: 0.727
NPV: 0.724
Test Set Evaluation
Accuracy: 0.667
Sensitivity: 0.833
Specificity: 0.519
AUROC: 0.799
F1 Score: 0.702
AUPR: 0.815
PPV: 0.606
NPV: 0.778


# 6. Confusion matrix

### Comparison

In [47]:
file_path = os.path.join(output_path, 'cap_pred_confusion_matrices_comparison')

cm_model1 = results["xgb_worst_Halm\'s_features"]["metrics"]["confusion_matrix"]
cm_model2 = results['xgb_Halm\'s_and_ventilator_features']["metrics"]["confusion_matrix"]
cm_model3 = new_results['metrics']['confusion_matrix']
cm_model4 = new_results_fixed['metrics']['confusion_matrix']
labels = ['Cured', 'Not Cured']

fig, axes = plt.subplots(1, 3, figsize=(10, 4))

# Halm's only
disp1 = ConfusionMatrixDisplay(confusion_matrix=cm_model1, display_labels=labels)
disp1.plot(ax=axes[0], colorbar=False, cmap='Blues')
axes[0].set_title("Halm's Features", fontsize=10)

# Halm's only with lab
disp2 = ConfusionMatrixDisplay(confusion_matrix=cm_model2, display_labels=labels)
disp2.plot(ax=axes[1], colorbar=False, cmap='Blues')
axes[1].set_title("Halm's and Ventilator Features", fontsize=10)

# Halm's only with best threshold
disp3 = ConfusionMatrixDisplay(confusion_matrix=cm_model3, display_labels=labels)
disp3.plot(ax=axes[2], colorbar=False, cmap='Blues')
axes[2].set_title(f"Halm's and Ventilator Features\n(Threshold={best_threshold:.3f})", fontsize=10)

# change text size
desired_fontsize = 15

for ax in axes:
    for text in ax.texts:
        text.set_fontsize(desired_fontsize)

plt.tight_layout()
plt.savefig(f"{file_path}.pdf", dpi=300, bbox_inches='tight')
plt.savefig(f"{file_path}.png", dpi=300, bbox_inches='tight')
plt.close()

### Best threshold only (Best threshold < 0.5)

In [48]:
file_path = os.path.join(output_path, 'cap_pred_confusion_matrices_best_threshold')

# Halm's only with best threshold
# Swap labels for x and y axes
cm_model3_swapped = cm_model3.T
disp = ConfusionMatrixDisplay(confusion_matrix=cm_model3_swapped, display_labels=labels)
disp.plot(colorbar=False, cmap='Blues')
plt.title(f"Halm's and Ventilator Features\n(Threshold={best_threshold:.3f})", fontsize=18)
plt.xlabel('True Label', fontsize=15)        
plt.ylabel('Predicted Label', fontsize=15)   
plt.xticks(fontsize=13, fontweight='bold')
plt.yticks(fontsize=13, fontweight='bold')
ax = plt.gca()
for text in ax.texts:
    text.set_fontsize(16)
plt.tight_layout()
plt.savefig(f"{file_path}_swapped.pdf", dpi=300, bbox_inches='tight')
plt.savefig(f"{file_path}_swapped.png", dpi=300, bbox_inches='tight')
plt.close()
# plt.show()

### Fixed threshold only

In [49]:
file_path = os.path.join(output_path, 'cap_pred_confusion_matrices_best_threshold')

# Halm's only with best threshold
# Swap labels for x and y axes
cm_model4_swapped = cm_model4.T
disp = ConfusionMatrixDisplay(confusion_matrix=cm_model4_swapped, display_labels=labels)
disp.plot(colorbar=False, cmap='Blues')
plt.title(f"Halm's and Ventilator Features\n(Threshold={fixed_threshold:.2f})", fontsize=18)
plt.xlabel('True Label', fontsize=15)        
plt.ylabel('Predicted Label', fontsize=15)   
plt.xticks(fontsize=13, fontweight='bold')
plt.yticks(fontsize=13, fontweight='bold')
ax = plt.gca()
for text in ax.texts:
    text.set_fontsize(16)
plt.tight_layout()
plt.savefig(f"{file_path}_swapped_fixed_threshold.pdf", dpi=300, bbox_inches='tight')
plt.savefig(f"{file_path}_swapped_fixed_threshold.png", dpi=300, bbox_inches='tight')
plt.close()
# plt.show()