In [35]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, classification_report, roc_curve
from catboost import CatBoostClassifier

Bootstrapping method to find the confidence interval of the AUROC score

In [36]:
def bootstrap_auc_ci(y_true, y_scores, n_bootstraps=2000, ci=0.95):
    rng = np.random.default_rng(42)
    aucs = []

    y_true = np.array(y_true)
    y_scores = np.array(y_scores)

    for _ in range(n_bootstraps):
        idx = rng.integers(0, len(y_true), len(y_true))
        if len(np.unique(y_true[idx])) < 2:
            continue
        aucs.append(roc_auc_score(y_true[idx], y_scores[idx]))

    lower = np.percentile(aucs, (1 - ci) / 2 * 100)
    upper = np.percentile(aucs, (1 + ci) / 2 * 100)
    return np.mean(aucs), lower, upper

Loading the dataset, pre-processing, and analysing the data

In [37]:
cohort_data = pd.read_csv('../cohort_data_new.csv')
cohort_data

Unnamed: 0,icustay_id,anion_gap_mean,anion_gap_sd,anion_gap_min,anion_gap_max,bicarbonate_mean,bicarbonate_sd,bicarbonate_min,bicarbonate_max,calcium_total_mean,calcium_total_sd,calcium_total_min,calcium_total_max,chloride_mean,chloride_sd,chloride_min,chloride_max,creatinine_mean,creatinine_sd,creatinine_min,creatinine_max,glucose_mean,glucose_sd,glucose_min,glucose_max,hematocrit_mean,hematocrit_sd,hematocrit_min,hematocrit_max,hemoglobin_mean,hemoglobin_sd,hemoglobin_min,hemoglobin_max,mchc_mean,mchc_sd,mchc_min,mchc_max,mch_mean,mch_sd,mch_min,...,pt_mean,pt_sd,pt_min,pt_max,phosphate_mean,phosphate_sd,phosphate_min,phosphate_max,platelet_count_mean,platelet_count_sd,platelet_count_min,platelet_count_max,potassium_mean,potassium_sd,potassium_min,potassium_max,rdw_mean,rdw_sd,rdw_min,rdw_max,red_blood_cells_mean,red_blood_cells_sd,red_blood_cells_min,red_blood_cells_max,sodium_mean,sodium_sd,sodium_min,sodium_max,urea_nitrogen_mean,urea_nitrogen_sd,urea_nitrogen_min,urea_nitrogen_max,white_blood_cells_mean,white_blood_cells_sd,white_blood_cells_min,white_blood_cells_max,age,gender,icu_los_hours,target
0,200003,13.375000,3.583195,9.0,21.0,25.250000,3.105295,18.0,28.0,7.771429,0.292770,7.5,8.3,108.125000,2.356602,105.0,111.0,0.757143,0.113389,0.7,1.0,108.250000,26.596187,81.0,159.0,31.077778,1.943436,28.5,35.0,10.283333,0.421505,9.6,10.8,33.483333,0.711102,32.8,34.8,30.233333,0.524087,29.6,...,14.540000,2.440901,12.7,18.8,3.312500,0.820170,2.5,4.7,118.857143,6.568322,109.0,126.0,3.587500,0.356320,3.1,4.2,14.583333,0.278687,14.1,14.9,3.403333,0.141657,3.17,3.57,143.125000,1.246423,141.0,145.0,15.571429,4.577377,10.0,21.0,26.471429,13.176711,13.2,43.9,48,M,141,0
1,200007,15.500000,2.121320,14.0,17.0,23.000000,1.414214,22.0,24.0,8.900000,,8.9,8.9,102.000000,1.414214,101.0,103.0,0.800000,0.000000,0.8,0.8,225.000000,11.313709,217.0,233.0,37.750000,0.494975,37.4,38.1,13.050000,0.353553,12.8,13.3,34.600000,0.424264,34.3,34.9,26.400000,0.424264,26.1,...,13.700000,,13.7,13.7,2.400000,,2.4,2.4,236.000000,15.556349,225.0,247.0,3.850000,0.070711,3.8,3.9,13.200000,0.141421,13.1,13.3,4.945000,0.049497,4.91,4.98,136.500000,2.121320,135.0,138.0,9.000000,1.414214,8.0,10.0,10.300000,1.272792,9.4,11.2,44,M,30,0
2,200009,9.500000,2.121320,8.0,11.0,23.333333,2.081666,21.0,25.0,8.000000,,8.0,8.0,113.333333,1.527525,112.0,115.0,0.500000,0.000000,0.5,0.5,108.500000,24.748737,91.0,126.0,29.366667,1.888121,26.3,31.2,10.057143,0.704408,9.0,10.7,34.371429,0.309377,34.0,34.7,32.257143,0.723089,31.7,...,14.480000,1.269646,12.9,16.2,2.700000,,2.7,2.7,139.428571,59.642985,75.0,221.0,4.200000,0.294392,3.9,4.6,15.214286,0.445079,14.3,15.6,3.117143,0.194398,2.84,3.32,142.000000,1.414214,141.0,143.0,17.333333,3.214550,15.0,21.0,12.471429,1.471637,10.5,14.3,47,F,51,0
3,200012,,,,,,,,,,,,,,,,,,,,,,,,,31.000000,,31.0,31.0,10.400000,,10.4,10.4,33.500000,,33.5,33.5,29.200000,,29.2,...,,,,,,,,,129.000000,,129.0,129.0,,,,,12.700000,,12.7,12.7,3.550000,,3.55,3.55,,,,,,,,,4.900000,,4.9,4.9,33,F,10,0
4,200014,10.000000,1.732051,9.0,12.0,24.000000,1.000000,23.0,25.0,7.733333,0.057735,7.7,7.8,111.333333,3.055050,108.0,114.0,0.633333,0.057735,0.6,0.7,110.000000,7.810250,101.0,115.0,33.050000,2.661453,29.8,36.3,11.033333,0.702377,10.3,11.7,33.433333,1.150362,32.3,34.6,30.033333,0.945163,29.3,...,13.066667,0.115470,13.0,13.2,2.450000,0.070711,2.4,2.5,121.000000,8.544004,113.0,130.0,4.000000,0.200000,3.8,4.2,13.300000,0.100000,13.2,13.4,3.690000,0.347707,3.32,4.01,141.333333,3.055050,138.0,144.0,23.000000,1.732051,21.0,24.0,13.233333,2.203028,10.7,14.7,85,M,41,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30484,299992,15.375000,2.856153,11.0,25.0,23.125000,2.609556,15.0,26.0,8.307143,0.255597,7.7,8.7,107.730769,3.539013,102.0,115.0,0.565217,0.098205,0.4,0.7,126.864865,26.268023,62.0,178.0,28.628000,2.299551,24.8,32.8,9.930435,0.803082,8.5,11.3,34.686957,0.779455,33.3,36.1,29.660870,0.575041,28.5,...,12.841667,0.553433,12.3,13.9,3.419048,0.910834,0.8,4.9,475.869565,272.759852,102.0,861.0,3.981250,0.314647,3.2,5.0,13.808696,0.432659,13.0,14.9,3.352609,0.290223,2.85,3.84,141.976744,3.180958,137.0,149.0,16.000000,4.662524,8.0,23.0,14.134783,3.781727,8.1,22.1,41,M,499,0
30485,299993,9.400000,1.341641,8.0,11.0,29.600000,2.073644,26.0,31.0,8.000000,0.216025,7.7,8.2,99.800000,3.492850,98.0,106.0,0.520000,0.044721,0.5,0.6,119.400000,23.776038,97.0,158.0,28.680000,1.042593,27.3,30.2,10.175000,0.403113,9.6,10.5,36.025000,0.928709,35.2,37.1,32.875000,0.531507,32.3,...,11.700000,0.565685,11.3,12.1,1.920000,0.486826,1.5,2.5,319.000000,49.598387,268.0,370.0,3.660000,0.409878,3.3,4.2,13.575000,0.150000,13.4,13.7,3.095000,0.084261,2.97,3.15,135.400000,1.516575,133.0,137.0,13.000000,1.224745,12.0,15.0,12.600000,0.605530,12.0,13.3,26,M,67,0
30486,299994,16.157895,2.477973,13.0,24.0,21.631579,3.451417,17.0,31.0,8.100000,0.316228,7.6,8.7,109.315789,4.203869,101.0,117.0,4.336842,0.791069,3.0,5.3,146.526316,80.865024,63.0,430.0,29.848276,2.714988,24.9,38.1,10.923810,1.191178,9.2,14.0,36.114286,1.165026,33.4,37.7,30.647619,0.499619,29.9,...,15.328571,2.356086,13.1,24.1,5.914286,0.868117,4.6,7.9,156.636364,46.494798,109.0,271.0,4.771429,0.882124,3.4,6.7,15.300000,0.626897,13.8,16.4,3.566667,0.405553,2.99,4.68,142.315789,4.397634,130.0,148.0,44.578947,12.102873,28.0,63.0,10.076190,2.642329,5.3,14.5,74,F,152,1
30487,299998,11.500000,1.732051,10.0,14.0,23.500000,1.290994,22.0,25.0,8.800000,0.416333,8.3,9.3,108.500000,1.290994,107.0,110.0,1.050000,0.057735,1.0,1.1,171.000000,32.269697,130.0,206.0,29.480000,2.426314,26.9,32.5,9.680000,0.690652,8.9,10.4,32.820000,0.544977,32.0,33.5,29.520000,0.389872,28.9,...,12.700000,0.282843,12.5,12.9,3.333333,0.513160,2.9,3.9,190.800000,10.848963,173.0,199.0,4.150000,0.208167,3.9,4.4,14.980000,0.164317,14.7,15.1,3.278000,0.268179,2.98,3.60,139.500000,1.732051,137.0,141.0,20.750000,0.957427,20.0,22.0,9.900000,1.210372,7.9,11.0,87,M,46,1


In [38]:
print(f"Dataset shape: {cohort_data.shape}")
print(f"Readmission rate: {cohort_data['target'].mean() * 100:.2f}%")

Dataset shape: (30489, 93)
Readmission rate: 10.74%


In [39]:
lab_cols = [
    'anion_gap_mean', 'anion_gap_min', 'anion_gap_max', 'anion_gap_sd',
    'bicarbonate_mean', 'bicarbonate_min', 'bicarbonate_max', 'bicarbonate_sd',
    'calcium_total_mean', 'calcium_total_min', 'calcium_total_max', 'calcium_total_sd',
    'chloride_mean', 'chloride_min', 'chloride_max', 'chloride_sd',
    'creatinine_mean', 'creatinine_min', 'creatinine_max', 'creatinine_sd',
    'glucose_mean', 'glucose_min', 'glucose_max', 'glucose_sd',
    'hematocrit_mean', 'hematocrit_min', 'hematocrit_max', 'hematocrit_sd',
    'hemoglobin_mean', 'hemoglobin_min', 'hemoglobin_max', 'hemoglobin_sd',
    'mchc_mean', 'mchc_min', 'mchc_max', 'mchc_sd',
    'mcv_mean', 'mcv_min', 'mcv_max', 'mcv_sd',
    'magnesium_mean', 'magnesium_min', 'magnesium_max', 'magnesium_sd',
    'pt_mean', 'pt_min', 'pt_max', 'pt_sd',
    'phosphate_mean', 'phosphate_min', 'phosphate_max', 'phosphate_sd',
    'platelet_count_mean', 'platelet_count_min', 'platelet_count_max', 'platelet_count_sd',
    'potassium_mean', 'potassium_min', 'potassium_max', 'potassium_sd',
    'rdw_mean', 'rdw_min', 'rdw_max', 'rdw_sd',
    'red_blood_cells_mean', 'red_blood_cells_min', 'red_blood_cells_max', 'red_blood_cells_sd',
    'sodium_mean', 'sodium_min', 'sodium_max', 'sodium_sd',
    'urea_nitrogen_mean', 'urea_nitrogen_min', 'urea_nitrogen_max', 'urea_nitrogen_sd',
    'white_blood_cells_mean', 'white_blood_cells_min', 'white_blood_cells_max', 'white_blood_cells_sd',
    'age', 'icu_los_hours'
]

REmove the ICUstay_id and the gender

In [40]:
drop_cols = [c for c in cohort_data.columns if 'icustay_id' in c.lower() or 'gender' in c.lower()]
df = cohort_data.drop(columns=['icustay_id', 'gender'], errors='ignore')

X = df.drop(columns=['target'])
y = df['target']

X

Unnamed: 0,anion_gap_mean,anion_gap_sd,anion_gap_min,anion_gap_max,bicarbonate_mean,bicarbonate_sd,bicarbonate_min,bicarbonate_max,calcium_total_mean,calcium_total_sd,calcium_total_min,calcium_total_max,chloride_mean,chloride_sd,chloride_min,chloride_max,creatinine_mean,creatinine_sd,creatinine_min,creatinine_max,glucose_mean,glucose_sd,glucose_min,glucose_max,hematocrit_mean,hematocrit_sd,hematocrit_min,hematocrit_max,hemoglobin_mean,hemoglobin_sd,hemoglobin_min,hemoglobin_max,mchc_mean,mchc_sd,mchc_min,mchc_max,mch_mean,mch_sd,mch_min,mch_max,...,ptt_min,ptt_max,pt_mean,pt_sd,pt_min,pt_max,phosphate_mean,phosphate_sd,phosphate_min,phosphate_max,platelet_count_mean,platelet_count_sd,platelet_count_min,platelet_count_max,potassium_mean,potassium_sd,potassium_min,potassium_max,rdw_mean,rdw_sd,rdw_min,rdw_max,red_blood_cells_mean,red_blood_cells_sd,red_blood_cells_min,red_blood_cells_max,sodium_mean,sodium_sd,sodium_min,sodium_max,urea_nitrogen_mean,urea_nitrogen_sd,urea_nitrogen_min,urea_nitrogen_max,white_blood_cells_mean,white_blood_cells_sd,white_blood_cells_min,white_blood_cells_max,age,icu_los_hours
0,13.375000,3.583195,9.0,21.0,25.250000,3.105295,18.0,28.0,7.771429,0.292770,7.5,8.3,108.125000,2.356602,105.0,111.0,0.757143,0.113389,0.7,1.0,108.250000,26.596187,81.0,159.0,31.077778,1.943436,28.5,35.0,10.283333,0.421505,9.6,10.8,33.483333,0.711102,32.8,34.8,30.233333,0.524087,29.6,30.8,...,25.9,32.1,14.540000,2.440901,12.7,18.8,3.312500,0.820170,2.5,4.7,118.857143,6.568322,109.0,126.0,3.587500,0.356320,3.1,4.2,14.583333,0.278687,14.1,14.9,3.403333,0.141657,3.17,3.57,143.125000,1.246423,141.0,145.0,15.571429,4.577377,10.0,21.0,26.471429,13.176711,13.2,43.9,48,141
1,15.500000,2.121320,14.0,17.0,23.000000,1.414214,22.0,24.0,8.900000,,8.9,8.9,102.000000,1.414214,101.0,103.0,0.800000,0.000000,0.8,0.8,225.000000,11.313709,217.0,233.0,37.750000,0.494975,37.4,38.1,13.050000,0.353553,12.8,13.3,34.600000,0.424264,34.3,34.9,26.400000,0.424264,26.1,26.7,...,49.6,51.8,13.700000,,13.7,13.7,2.400000,,2.4,2.4,236.000000,15.556349,225.0,247.0,3.850000,0.070711,3.8,3.9,13.200000,0.141421,13.1,13.3,4.945000,0.049497,4.91,4.98,136.500000,2.121320,135.0,138.0,9.000000,1.414214,8.0,10.0,10.300000,1.272792,9.4,11.2,44,30
2,9.500000,2.121320,8.0,11.0,23.333333,2.081666,21.0,25.0,8.000000,,8.0,8.0,113.333333,1.527525,112.0,115.0,0.500000,0.000000,0.5,0.5,108.500000,24.748737,91.0,126.0,29.366667,1.888121,26.3,31.2,10.057143,0.704408,9.0,10.7,34.371429,0.309377,34.0,34.7,32.257143,0.723089,31.7,33.6,...,29.9,39.3,14.480000,1.269646,12.9,16.2,2.700000,,2.7,2.7,139.428571,59.642985,75.0,221.0,4.200000,0.294392,3.9,4.6,15.214286,0.445079,14.3,15.6,3.117143,0.194398,2.84,3.32,142.000000,1.414214,141.0,143.0,17.333333,3.214550,15.0,21.0,12.471429,1.471637,10.5,14.3,47,51
3,,,,,,,,,,,,,,,,,,,,,,,,,31.000000,,31.0,31.0,10.400000,,10.4,10.4,33.500000,,33.5,33.5,29.200000,,29.2,29.2,...,,,,,,,,,,,129.000000,,129.0,129.0,,,,,12.700000,,12.7,12.7,3.550000,,3.55,3.55,,,,,,,,,4.900000,,4.9,4.9,33,10
4,10.000000,1.732051,9.0,12.0,24.000000,1.000000,23.0,25.0,7.733333,0.057735,7.7,7.8,111.333333,3.055050,108.0,114.0,0.633333,0.057735,0.6,0.7,110.000000,7.810250,101.0,115.0,33.050000,2.661453,29.8,36.3,11.033333,0.702377,10.3,11.7,33.433333,1.150362,32.3,34.6,30.033333,0.945163,29.3,31.1,...,28.1,30.1,13.066667,0.115470,13.0,13.2,2.450000,0.070711,2.4,2.5,121.000000,8.544004,113.0,130.0,4.000000,0.200000,3.8,4.2,13.300000,0.100000,13.2,13.4,3.690000,0.347707,3.32,4.01,141.333333,3.055050,138.0,144.0,23.000000,1.732051,21.0,24.0,13.233333,2.203028,10.7,14.7,85,41
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
30484,15.375000,2.856153,11.0,25.0,23.125000,2.609556,15.0,26.0,8.307143,0.255597,7.7,8.7,107.730769,3.539013,102.0,115.0,0.565217,0.098205,0.4,0.7,126.864865,26.268023,62.0,178.0,28.628000,2.299551,24.8,32.8,9.930435,0.803082,8.5,11.3,34.686957,0.779455,33.3,36.1,29.660870,0.575041,28.5,30.7,...,21.4,29.3,12.841667,0.553433,12.3,13.9,3.419048,0.910834,0.8,4.9,475.869565,272.759852,102.0,861.0,3.981250,0.314647,3.2,5.0,13.808696,0.432659,13.0,14.9,3.352609,0.290223,2.85,3.84,141.976744,3.180958,137.0,149.0,16.000000,4.662524,8.0,23.0,14.134783,3.781727,8.1,22.1,41,499
30485,9.400000,1.341641,8.0,11.0,29.600000,2.073644,26.0,31.0,8.000000,0.216025,7.7,8.2,99.800000,3.492850,98.0,106.0,0.520000,0.044721,0.5,0.6,119.400000,23.776038,97.0,158.0,28.680000,1.042593,27.3,30.2,10.175000,0.403113,9.6,10.5,36.025000,0.928709,35.2,37.1,32.875000,0.531507,32.3,33.5,...,26.2,27.8,11.700000,0.565685,11.3,12.1,1.920000,0.486826,1.5,2.5,319.000000,49.598387,268.0,370.0,3.660000,0.409878,3.3,4.2,13.575000,0.150000,13.4,13.7,3.095000,0.084261,2.97,3.15,135.400000,1.516575,133.0,137.0,13.000000,1.224745,12.0,15.0,12.600000,0.605530,12.0,13.3,26,67
30486,16.157895,2.477973,13.0,24.0,21.631579,3.451417,17.0,31.0,8.100000,0.316228,7.6,8.7,109.315789,4.203869,101.0,117.0,4.336842,0.791069,3.0,5.3,146.526316,80.865024,63.0,430.0,29.848276,2.714988,24.9,38.1,10.923810,1.191178,9.2,14.0,36.114286,1.165026,33.4,37.7,30.647619,0.499619,29.9,32.5,...,30.8,150.0,15.328571,2.356086,13.1,24.1,5.914286,0.868117,4.6,7.9,156.636364,46.494798,109.0,271.0,4.771429,0.882124,3.4,6.7,15.300000,0.626897,13.8,16.4,3.566667,0.405553,2.99,4.68,142.315789,4.397634,130.0,148.0,44.578947,12.102873,28.0,63.0,10.076190,2.642329,5.3,14.5,74,152
30487,11.500000,1.732051,10.0,14.0,23.500000,1.290994,22.0,25.0,8.800000,0.416333,8.3,9.3,108.500000,1.290994,107.0,110.0,1.050000,0.057735,1.0,1.1,171.000000,32.269697,130.0,206.0,29.480000,2.426314,26.9,32.5,9.680000,0.690652,8.9,10.4,32.820000,0.544977,32.0,33.5,29.520000,0.389872,28.9,29.9,...,28.9,30.2,12.700000,0.282843,12.5,12.9,3.333333,0.513160,2.9,3.9,190.800000,10.848963,173.0,199.0,4.150000,0.208167,3.9,4.4,14.980000,0.164317,14.7,15.1,3.278000,0.268179,2.98,3.60,139.500000,1.732051,137.0,141.0,20.750000,0.957427,20.0,22.0,9.900000,1.210372,7.9,11.0,87,46


Creating the final datasets

In [41]:
# train-test split
X_train_full, X_test, y_train_full, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=7, shuffle=True)
# train-validation split on the previous train set
X_train, X_eval, y_train, y_eval = train_test_split(X_train_full, y_train_full, test_size=0.3, stratify=y_train_full, random_state=7, shuffle=True)

print(np.shape(X_train))
print(np.shape(X_test))
print(np.shape(X_eval))
print(np.shape(y_train))
print(np.shape(y_test))
print(np.shape(y_eval))

print(f'# Readmissions in Train: {np.sum(y_train)}')
print(f'# Readmissions in Test: {np.sum(y_test) }')
print(f'# Readmissions in eval: {np.sum(y_eval) }')
print(f'% Readmissions in Train: {np.mean(y_train) * 100}')
print(f'% Readmissions in Test: {np.mean(y_test) * 100}')
print(f'% Readmissions in eval: {np.mean(y_eval) * 100}')
print(f'Total: {np.sum(y), np.mean(y)*100}')

(14939, 90)
(9147, 90)
(6403, 90)
(14939,)
(9147,)
(6403,)
# Readmissions in Train: 1605
# Readmissions in Test: 983
# Readmissions in eval: 688
% Readmissions in Train: 10.74369101010777
% Readmissions in Test: 10.746692904777523
% Readmissions in eval: 10.74496329845385
Total: (np.int64(3276), np.float64(10.74485880153498))


Training the model

In [42]:
model = CatBoostClassifier()
model.fit(X_train, y_train, eval_set=(X_eval, y_eval), use_best_model=True)

Learning rate set to 0.061797
0:	learn: 0.6416053	test: 0.6416407	best: 0.6416407 (0)	total: 12.3ms	remaining: 12.3s
1:	learn: 0.5965894	test: 0.5969061	best: 0.5969061 (1)	total: 22.4ms	remaining: 11.2s
2:	learn: 0.5598351	test: 0.5602592	best: 0.5602592 (2)	total: 31.6ms	remaining: 10.5s
3:	learn: 0.5279621	test: 0.5285811	best: 0.5285811 (3)	total: 42ms	remaining: 10.4s
4:	learn: 0.4983333	test: 0.4993170	best: 0.4993170 (4)	total: 51.8ms	remaining: 10.3s
5:	learn: 0.4745638	test: 0.4758232	best: 0.4758232 (5)	total: 61.8ms	remaining: 10.2s
6:	learn: 0.4542014	test: 0.4558303	best: 0.4558303 (6)	total: 72.6ms	remaining: 10.3s
7:	learn: 0.4367481	test: 0.4387206	best: 0.4387206 (7)	total: 82.1ms	remaining: 10.2s
8:	learn: 0.4221267	test: 0.4243150	best: 0.4243150 (8)	total: 93.3ms	remaining: 10.3s
9:	learn: 0.4080044	test: 0.4105979	best: 0.4105979 (9)	total: 105ms	remaining: 10.4s
10:	learn: 0.3967518	test: 0.3996229	best: 0.3996229 (10)	total: 114ms	remaining: 10.2s
11:	learn: 0.38

<catboost.core.CatBoostClassifier at 0x2630f6fb1f0>

In [43]:
# calculate training error
y_train_pred = model.predict(X_train_full)
train_error = np.mean(y_train_pred != y_train_full)
print(f"Training error (Vanilla LR): {train_error:.3f}")

# calculate test error
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)[:, 1]
test_error = np.mean(y_pred != y_test)
print(f"Test error (Vanilla LR): {test_error:.3f}")

Training error (Vanilla LR): 0.097
Test error (Vanilla LR): 0.107


In [44]:
print("Classification Report:\n")
print(classification_report(y_test, y_pred, target_names=["No Readmission (0)", "Readmission (1)"]))

fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba, pos_label=1)
roc_auc = auc(fpr, tpr)
np.savez('../results/catboost_base_fpr_tpr_thresholds.npz', fpr, tpr, thresholds)
print(f"AUROC for class 1 (Readmission): {roc_auc:.3f}")

auc, lower, upper = bootstrap_auc_ci(y_test, y_pred_proba)
print(f"AUC = {auc:.4f}, 95% CI = [{lower:.4f}, {upper:.4f}]")

Classification Report:

                    precision    recall  f1-score   support

No Readmission (0)       0.89      1.00      0.94      8164
   Readmission (1)       0.60      0.02      0.03       983

          accuracy                           0.89      9147
         macro avg       0.75      0.51      0.49      9147
      weighted avg       0.86      0.89      0.85      9147

AUROC for class 1 (Readmission): 0.726
AUC = 0.7257, 95% CI = [0.7093, 0.7418]
