### Training tabular models on MIMIC-IV dataset for Sepsis classification

In [9]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, balanced_accuracy_score
from sklearn.metrics import roc_auc_score

from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier

import xgboost as xgb
from catboost import CatBoostClassifier

import warnings
warnings.filterwarnings('ignore')

pd.set_option("display.max_columns", None)

In [10]:
train_df = pd.read_csv("sepsis_data/train.csv") 
dsel_df  = pd.read_csv("sepsis_data/dsel.csv") 
test_df  = pd.read_csv("sepsis_data/test.csv") 

In [11]:
train_df.head()

Unnamed: 0,hadm_id,HR_mean,HR_max,HR_min,HR_std,HR_var,HR_skewness,HR_kurtosis,HR_q1,HR_q2,HR_q3,RR_mean,RR_max,RR_min,RR_std,RR_var,RR_skewness,RR_kurtosis,RR_q1,RR_q2,RR_q3,SysBP_mean,SysBP_max,SysBP_min,SysBP_std,SysBP_var,SysBP_skewness,SysBP_kurtosis,SysBP_q1,SysBP_q2,SysBP_q3,DiasBP_mean,DiasBP_max,DiasBP_min,DiasBP_std,DiasBP_var,DiasBP_skewness,DiasBP_kurtosis,DiasBP_q1,DiasBP_q2,DiasBP_q3,TempC_mean,TempC_max,TempC_min,TempC_std,TempC_var,TempC_skewness,TempC_kurtosis,TempC_q1,TempC_q2,TempC_q3,SpO2_mean,SpO2_max,SpO2_min,SpO2_std,SpO2_var,SpO2_skewness,SpO2_kurtosis,SpO2_q1,SpO2_q2,SpO2_q3,Glucose_mean,Glucose_max,Glucose_min,Glucose_std,Glucose_var,Glucose_skewness,Glucose_kurtosis,Glucose_q1,Glucose_q2,Glucose_q3,subject_id,admission_type,gender,anchor_age,sepsis,text
0,21209744,82.833333,86.0,80.0,2.562551,6.566667,-0.045561,-2.129661,80.5,83.0,84.75,17.833333,24.0,15.0,3.311596,10.966667,1.655783,2.687059,16.0,16.5,18.5,132.5,152.0,119.0,11.623253,135.1,0.799528,0.897474,125.0,132.0,136.0,68.666667,81.0,58.0,9.709102,94.266667,0.33288,-2.027427,61.0,67.0,76.75,98.2,98.5,97.9,0.424264,0.18,0.0,0.0,98.05,98.2,98.35,97.166667,99.0,96.0,0.983192,0.966667,1.437962,3.602854,97.0,97.0,97.0,111.0,111.0,111.0,0.0,0.0,0.0,0.0,111.0,111.0,111.0,19099492,5,1,53,1.0,INDICATION: ___ with cough // eval infiltrat...
1,23571962,93.8,104.0,89.0,4.299327,18.484211,1.101513,0.802174,90.0,93.0,95.5,17.173913,23.0,9.0,4.14128,17.150198,-0.774524,-0.274389,15.5,18.0,19.5,93.166667,101.0,77.0,6.875708,47.275362,-0.622453,-0.621134,86.0,94.0,99.25,58.791667,76.0,51.0,6.724253,45.21558,1.591712,2.07646,55.0,57.0,60.0,99.65,100.8,98.9,0.720417,0.519,0.565665,-0.038647,99.05,99.7,99.9,93.65,98.0,90.0,2.277464,5.186842,0.867996,0.127984,92.0,93.0,94.25,87.666667,101.0,81.0,11.547005,133.333333,1.732051,0.0,81.0,81.0,91.0,10546701,7,1,62,0.0,INDICATION: ___ year old man s/p tracheoplast...
2,21203213,99.875,105.0,93.0,3.943802,15.553571,-0.584271,-0.165304,98.25,100.0,103.0,18.375,21.0,16.0,1.59799,2.553571,0.301953,-0.164859,17.75,18.0,19.25,102.571429,114.0,95.0,7.39047,54.619048,0.967008,-0.809823,98.5,100.0,106.0,55.142857,63.0,47.0,6.17599,38.142857,-0.470591,-1.210725,50.5,57.0,59.0,98.5,98.8,98.1,0.360555,0.13,-1.15207,0.0,98.35,98.6,98.7,94.125,96.0,92.0,1.246423,1.553571,-0.304319,0.146492,93.75,94.0,95.0,126.0,126.0,126.0,0.0,0.0,0.0,0.0,126.0,126.0,126.0,11040157,5,0,50,0.0,INDICATION: ___ with peds mvc// trauma\n\nCOM...
3,24078707,55.75,67.0,46.0,5.775564,33.357143,0.491623,2.833984,54.0,55.5,56.5,13.5,15.0,12.0,1.581139,2.5,0.0,-2.571429,12.0,13.5,15.0,152.888889,179.0,131.0,16.669166,277.861111,0.06211,-1.261471,137.0,153.0,165.0,71.666667,85.0,59.0,9.836158,96.75,0.201192,-1.406116,63.0,72.0,79.0,98.3,98.3,98.3,0.0,0.0,0.0,0.0,98.3,98.3,98.3,100.0,100.0,100.0,0.0,0.0,0.0,0.0,100.0,100.0,100.0,109.0,109.0,109.0,0.0,0.0,0.0,0.0,109.0,109.0,109.0,11731531,6,1,86,0.0,INDICATION: ___ year old man with hx of parki...
4,22812743,109.166667,115.0,103.0,5.076088,25.766667,-0.475812,-1.739715,104.75,110.5,112.5,27.666667,36.0,20.0,6.218253,38.666667,0.343815,-1.395571,23.75,26.5,32.25,116.666667,122.0,106.0,6.055301,36.666667,-1.238284,1.374446,114.75,118.0,121.25,82.0,105.0,74.0,11.401754,130.0,2.314624,5.54587,78.0,78.0,78.75,97.5,97.5,97.5,0.0,0.0,0.0,0.0,97.5,97.5,97.5,99.5,100.0,98.0,0.83666,0.7,-1.536722,1.428571,99.25,100.0,100.0,206.0,210.0,202.0,5.656854,32.0,0.0,0.0,204.0,206.0,208.0,11048504,5,1,68,0.0,INDICATION: ___ with SOB// PNA\n\nTECHNIQUE: ...


In [12]:
train_df.shape

(5709, 77)

In [13]:
X_train = train_df.drop(["hadm_id", "sepsis", "text"], axis=1)
y_train = train_df["sepsis"] 

X_test = test_df.drop(["hadm_id", "sepsis", "text"], axis=1)
y_test = test_df["sepsis"] 

In [14]:
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

In [15]:
models = {
    "RandomForest": RandomForestClassifier(n_estimators=200, random_state=42),
    "CatBoost": CatBoostClassifier(iterations=200, verbose=0, random_state=42),
    "MLP": MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, random_state=42),
    "SVC": SVC(probability=True, random_state=42), 
    "XGBoost": xgb.XGBClassifier(n_estimators=200, use_label_encoder=False, eval_metric="mlogloss", random_state=42)
}

In [16]:
results = {}

for name, model in models.items():
    print(f"\nTraining {name}...")
    
    # Use scaled data for MLP and SVC
    model.fit(X_train_scaled, y_train)
    y_pred = model.predict(X_test_scaled)
    y_proba = model.predict_proba(X_test_scaled) if hasattr(model, "predict_proba") else None

    
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average="macro")
    precision = precision_score(y_test, y_pred, average="macro")
    recall = recall_score(y_test, y_pred, average="macro")
    balanced_acc = balanced_accuracy_score(y_test, y_pred)
    if y_proba is not None:
        auroc = roc_auc_score(y_test, y_proba[:, 1])
    else:
        auroc = None 
    
    print(f"{name} | Accuracy: {acc:.4f} | F1: {f1:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | Balanced Acc: {balanced_acc:.4f} | AUROC: {auroc:.4f}")
    
    results[name] = {
        "model": model,
        "accuracy": acc,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "balanced_accuracy": balanced_acc,
        "auroc": auroc,
        "y_pred": y_pred,
        "y_proba": y_proba
    }



Training RandomForest...
RandomForest | Accuracy: 0.6904 | F1: 0.6893 | Precision: 0.6931 | Recall: 0.6904 | Balanced Acc: 0.6904 | AUROC: 0.7587

Training CatBoost...
CatBoost | Accuracy: 0.7165 | F1: 0.7156 | Precision: 0.7193 | Recall: 0.7165 | Balanced Acc: 0.7165 | AUROC: 0.7772

Training MLP...
MLP | Accuracy: 0.6201 | F1: 0.6201 | Precision: 0.6201 | Recall: 0.6201 | Balanced Acc: 0.6201 | AUROC: 0.6635

Training SVC...
SVC | Accuracy: 0.6879 | F1: 0.6866 | Precision: 0.6912 | Recall: 0.6879 | Balanced Acc: 0.6879 | AUROC: 0.7436

Training XGBoost...
XGBoost | Accuracy: 0.7141 | F1: 0.7137 | Precision: 0.7150 | Recall: 0.7141 | Balanced Acc: 0.7141 | AUROC: 0.7683
