# <u>Setup</u>

In [1]:
import pandas as pd
import numpy as np
import os

from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score, average_precision_score
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from scipy.stats import uniform, randint

from transformers import AutoTokenizer, AutoModel
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
mimic_iii_path = '/mnt/2TB-HDD-Ubuntu/GitHub-Repositories/UT Austin/2025 Spring/AI-Healthcare/MIMIC-III-Dataset/mimic-iii-clinical-database-1.4'

# <u>Import data</u>

Tables of interest:
- admissions
- caregivers (could include caregiver <u>label</u>)
- chartevents (huge and not much of interest actually)
- cptevents (not much here)
- dictionaries (not much additional info needed actually)
- datetimeevents (filter 24hrs after admission)
- diagnoses_icd (could add <u>primary diagnosis</u>, assuming it was given in first 24 hrs)
- icustays (specific admission location, # transfers, etc) -- taking info from transfers instead
- noteevents (NLP) -- notes only generated after discharge -- use discharge notes from <u>previous admissions</u>
- patients (gender and age)
- prescriptions (Rxs started within 24 hrs)
- procedures_icd (no timestamps)
- services (nothing helpful here)
- transfers (wards visited and # of transfers in 24 hrs)

In [6]:
def filter_df_24h(admissions_df, data_df, log_dt_col:str) -> pd.DataFrame:
    # Only keeping ADMITTIME from admissions and all columns from datetimeevents
    merged_df = pd.merge(
        data_df,
        admissions_df[['HADM_ID', 'ADMITTIME', 'DEATH']],
        on='HADM_ID',
        how='inner'
    )

    # Calculate time difference between CHARTTIME and ADMITTIME
    merged_df['time_diff_hours'] = (merged_df[log_dt_col] - merged_df['ADMITTIME']).dt.total_seconds() / 3600

    # Filter rows where CHARTTIME is within 24 hours after ADMITTIME
    merged_24h_df = merged_df[merged_df['time_diff_hours'] <= 24].copy()

    # Drop the temporary columns we created
    merged_24h_df = merged_24h_df.drop(['time_diff_hours'], axis=1)

    # Display the first few rows and shape of the filtered dataframe
    return merged_24h_df

## Admissions

In [7]:
admissions_df = pd.read_csv(os.path.join(mimic_iii_path, 'ADMISSIONS.csv'))
admissions_df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ADMITTIME,DISCHTIME,DEATHTIME,ADMISSION_TYPE,ADMISSION_LOCATION,DISCHARGE_LOCATION,INSURANCE,LANGUAGE,RELIGION,MARITAL_STATUS,ETHNICITY,EDREGTIME,EDOUTTIME,DIAGNOSIS,HOSPITAL_EXPIRE_FLAG,HAS_CHARTEVENTS_DATA
0,21,22,165315,2196-04-09 12:26:00,2196-04-10 15:54:00,,EMERGENCY,EMERGENCY ROOM ADMIT,DISC-TRAN CANCER/CHLDRN H,Private,,UNOBTAINABLE,MARRIED,WHITE,2196-04-09 10:06:00,2196-04-09 13:24:00,BENZODIAZEPINE OVERDOSE,0,1
1,22,23,152223,2153-09-03 07:15:00,2153-09-08 19:10:00,,ELECTIVE,PHYS REFERRAL/NORMAL DELI,HOME HEALTH CARE,Medicare,,CATHOLIC,MARRIED,WHITE,,,CORONARY ARTERY DISEASE\CORONARY ARTERY BYPASS...,0,1
2,23,23,124321,2157-10-18 19:34:00,2157-10-25 14:00:00,,EMERGENCY,TRANSFER FROM HOSP/EXTRAM,HOME HEALTH CARE,Medicare,ENGL,CATHOLIC,MARRIED,WHITE,,,BRAIN MASS,0,1
3,24,24,161859,2139-06-06 16:14:00,2139-06-09 12:48:00,,EMERGENCY,TRANSFER FROM HOSP/EXTRAM,HOME,Private,,PROTESTANT QUAKER,SINGLE,WHITE,,,INTERIOR MYOCARDIAL INFARCTION,0,1
4,25,25,129635,2160-11-02 02:06:00,2160-11-05 14:55:00,,EMERGENCY,EMERGENCY ROOM ADMIT,HOME,Private,,UNOBTAINABLE,MARRIED,WHITE,2160-11-02 01:01:00,2160-11-02 04:27:00,ACUTE CORONARY SYNDROME,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
58971,58594,98800,191113,2131-03-30 21:13:00,2131-04-02 15:02:00,,EMERGENCY,CLINIC REFERRAL/PREMATURE,HOME,Private,ENGL,NOT SPECIFIED,SINGLE,WHITE,2131-03-30 19:44:00,2131-03-30 22:41:00,TRAUMA,0,1
58972,58595,98802,101071,2151-03-05 20:00:00,2151-03-06 09:10:00,2151-03-06 09:10:00,EMERGENCY,CLINIC REFERRAL/PREMATURE,DEAD/EXPIRED,Medicare,ENGL,CATHOLIC,WIDOWED,WHITE,2151-03-05 17:23:00,2151-03-05 21:06:00,SAH,1,1
58973,58596,98805,122631,2200-09-12 07:15:00,2200-09-20 12:08:00,,ELECTIVE,PHYS REFERRAL/NORMAL DELI,HOME HEALTH CARE,Private,ENGL,NOT SPECIFIED,MARRIED,WHITE,,,RENAL CANCER/SDA,0,1
58974,58597,98813,170407,2128-11-11 02:29:00,2128-12-22 13:11:00,,EMERGENCY,EMERGENCY ROOM ADMIT,SNF,Private,ENGL,CATHOLIC,MARRIED,WHITE,2128-11-10 23:48:00,2128-11-11 03:16:00,S/P FALL,0,0


In [8]:
assert len(admissions_df[~admissions_df['DEATHTIME'].isna()]) == len(admissions_df[admissions_df['DISCHARGE_LOCATION'] == 'DEAD/EXPIRED'])

In [9]:
admissions_df['ADMISSION_TYPE'].value_counts()

ADMISSION_TYPE
EMERGENCY    42071
NEWBORN       7863
ELECTIVE      7706
URGENT        1336
Name: count, dtype: int64

In [10]:
admissions_df['ADMISSION_LOCATION'].value_counts()

ADMISSION_LOCATION
EMERGENCY ROOM ADMIT         22754
PHYS REFERRAL/NORMAL DELI    15079
CLINIC REFERRAL/PREMATURE    12032
TRANSFER FROM HOSP/EXTRAM     8456
TRANSFER FROM SKILLED NUR      273
** INFO NOT AVAILABLE **       204
HMO REFERRAL/SICK              102
TRANSFER FROM OTHER HEALT       71
TRSF WITHIN THIS FACILITY        5
Name: count, dtype: int64

In [11]:
admissions_df['INSURANCE'].value_counts()

INSURANCE
Medicare      28215
Private       22582
Medicaid       5785
Government     1783
Self Pay        611
Name: count, dtype: int64

In [8]:
admissions_cols = [
    'SUBJECT_ID',
    'HADM_ID',
    'ADMITTIME',
    'DEATHTIME',
    # 'ADMISSION_TYPE',
    'ADMISSION_LOCATION',
    'INSURANCE',
    'ETHNICITY',
    'DIAGNOSIS'
]

In [9]:
emergency_urgent = [
    'EMERGENCY',
    'URGENT'
]

admissions_df_filtered = admissions_df[admissions_df['ADMISSION_TYPE'].isin(emergency_urgent)][admissions_cols].reset_index(drop=True)
admissions_df_filtered['DEATH'] = ~admissions_df_filtered['DEATHTIME'].isna()

admissions_df_filtered['ADMITTIME'] = pd.to_datetime(admissions_df_filtered['ADMITTIME'])
admissions_df_filtered['DEATHTIME'] = pd.to_datetime(admissions_df_filtered['DEATHTIME'])

admissions_df_filtered

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMITTIME,DEATHTIME,ADMISSION_LOCATION,INSURANCE,ETHNICITY,DIAGNOSIS,DEATH
0,22,165315,2196-04-09 12:26:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,BENZODIAZEPINE OVERDOSE,False
1,23,124321,2157-10-18 19:34:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,BRAIN MASS,False
2,24,161859,2139-06-06 16:14:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,WHITE,INTERIOR MYOCARDIAL INFARCTION,False
3,25,129635,2160-11-02 02:06:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,ACUTE CORONARY SYNDROME,False
4,26,197661,2126-05-06 15:16:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,UNKNOWN/NOT SPECIFIED,V-TACH,False
...,...,...,...,...,...,...,...,...,...
43402,98797,105447,2132-12-24 20:06:00,2132-12-25 12:00:00,EMERGENCY ROOM ADMIT,Medicare,WHITE,ALTERED MENTAL STATUS,True
43403,98800,191113,2131-03-30 21:13:00,NaT,CLINIC REFERRAL/PREMATURE,Private,WHITE,TRAUMA,False
43404,98802,101071,2151-03-05 20:00:00,2151-03-06 09:10:00,CLINIC REFERRAL/PREMATURE,Medicare,WHITE,SAH,True
43405,98813,170407,2128-11-11 02:29:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,S/P FALL,False


In [10]:
admissions_df_filtered['ADMISSION_LOCATION'].value_counts()

ADMISSION_LOCATION
EMERGENCY ROOM ADMIT         22754
CLINIC REFERRAL/PREMATURE    10020
TRANSFER FROM HOSP/EXTRAM     8414
PHYS REFERRAL/NORMAL DELI     1880
TRANSFER FROM SKILLED NUR      260
TRANSFER FROM OTHER HEALT       68
** INFO NOT AVAILABLE **         5
TRSF WITHIN THIS FACILITY        5
HMO REFERRAL/SICK                1
Name: count, dtype: int64

In [11]:
admissions_deathrate = admissions_df_filtered['DEATH'].mean()
admissions_deathrate

0.12889626097173268

## Caregivers 
Get caregiver label

In [12]:
caregivers_df = pd.read_csv(os.path.join(mimic_iii_path, 'CAREGIVERS.csv'))
caregivers_df

Unnamed: 0,ROW_ID,CGID,LABEL,DESCRIPTION
0,2228,16174,RO,Read Only
1,2229,16175,RO,Read Only
2,2230,16176,Res,Resident/Fellow/PA/NP
3,2231,16177,RO,Read Only
4,2232,16178,RT,Respiratory
...,...,...,...,...
7562,6300,20303,MD,
7563,6301,20304,RN,RN
7564,6302,20305,MDs,
7565,6303,20306,RPH,Pharmacist


## Chart Events

In [13]:
chartevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'CHARTEVENTS.csv'), nrows=10000)
chartevents_df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,ITEMID,CHARTTIME,STORETIME,CGID,VALUE,VALUENUM,VALUEUOM,WARNING,ERROR,RESULTSTATUS,STOPPED
0,788,36,165660,241249,223834,2134-05-12 12:00:00,2134-05-12 13:56:00,17525,15.00,15.00,L/min,0,0,,
1,789,36,165660,241249,223835,2134-05-12 12:00:00,2134-05-12 13:56:00,17525,100.00,100.00,,0,0,,
2,790,36,165660,241249,224328,2134-05-12 12:00:00,2134-05-12 12:18:00,20823,0.37,0.37,,0,0,,
3,791,36,165660,241249,224329,2134-05-12 12:00:00,2134-05-12 12:19:00,20823,6.00,6.00,min,0,0,,
4,792,36,165660,241249,224330,2134-05-12 12:00:00,2134-05-12 12:19:00,20823,2.50,2.50,,0,0,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,10986,109,176760,256504,220277,2142-01-21 13:00:00,2142-01-21 13:03:00,16643,100.00,100.00,%,0,0,,
9996,10987,109,176760,256504,226531,2142-01-21 13:00:00,2142-01-21 13:02:00,16643,113.30,113.30,,0,0,,
9997,10988,109,176760,256504,220045,2142-01-21 14:00:00,2142-01-21 14:11:00,16643,92.00,92.00,bpm,0,0,,
9998,10989,109,176760,256504,220179,2142-01-21 14:00:00,2142-01-21 14:11:00,16643,173.00,173.00,mmHg,0,0,,


## CPT Events

In [14]:
cptevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'CPTEVENTS.csv'))
cptevents_df

  cptevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'CPTEVENTS.csv'))


Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,COSTCENTER,CHARTDATE,CPT_CD,CPT_NUMBER,CPT_SUFFIX,TICKET_ID_SEQ,SECTIONHEADER,SUBSECTIONHEADER,DESCRIPTION
0,317,11743,129545,ICU,,99232,99232.0,,6.0,Evaluation and management,Hospital inpatient services,
1,318,11743,129545,ICU,,99232,99232.0,,7.0,Evaluation and management,Hospital inpatient services,
2,319,11743,129545,ICU,,99232,99232.0,,8.0,Evaluation and management,Hospital inpatient services,
3,320,11743,129545,ICU,,99232,99232.0,,9.0,Evaluation and management,Hospital inpatient services,
4,321,6185,183725,ICU,,99223,99223.0,,1.0,Evaluation and management,Hospital inpatient services,
...,...,...,...,...,...,...,...,...,...,...,...,...
573141,573142,78876,163404,Resp,2105-09-01 00:00:00,94003,94003.0,,,Medicine,Pulmonary,VENT MGMT;SUBSQ DAYS(INVASIVE)
573142,573143,78879,136071,Resp,2150-08-29 00:00:00,94003,94003.0,,,Medicine,Pulmonary,VENT MGMT;SUBSQ DAYS(INVASIVE)
573143,573144,78879,136071,Resp,2150-08-28 00:00:00,94002,94002.0,,,Medicine,Pulmonary,"VENT MGMT, 1ST DAY (INVASIVE)"
573144,573145,78892,175171,Resp,2125-06-11 00:00:00,94003,94003.0,,,Medicine,Pulmonary,VENT MGMT;SUBSQ DAYS(INVASIVE)


## Dictionaries

In [15]:
d_cpt_df = pd.read_csv(os.path.join(mimic_iii_path, 'D_CPT.csv'))
d_cpt_df

Unnamed: 0,ROW_ID,CATEGORY,SECTIONRANGE,SECTIONHEADER,SUBSECTIONRANGE,SUBSECTIONHEADER,CODESUFFIX,MINCODEINSUBSECTION,MAXCODEINSUBSECTION
0,1,1,99201-99499,Evaluation and management,99201-99216,Office/other outpatient services,,99201,99216
1,2,1,99201-99499,Evaluation and management,99217-99220,Hospital observation services,,99217,99220
2,3,1,99201-99499,Evaluation and management,99221-99239,Hospital inpatient services,,99221,99239
3,4,1,99201-99499,Evaluation and management,99241-99255,Consultations,,99241,99255
4,5,1,99201-99499,Evaluation and management,99261-99263,Follow-up inpatient consultations (deleted codes),,99261,99263
...,...,...,...,...,...,...,...,...,...
129,130,2,0001F-7025F,Performance measurement,5005F-5100F,Follow-up or other outcomes,F,5005,5100
130,131,2,0001F-7025F,Performance measurement,6005F-6045F,Patient safety,F,6005,6045
131,132,2,0001F-7025F,Performance measurement,7010F-7025F,Structural Measures,F,7010,7025
132,133,3,0016T-0207T,Emerging technology,0016T-0207T,Temporary codes,T,16,207


In [16]:
d_items_df = pd.read_csv(os.path.join(mimic_iii_path, 'D_ITEMS.csv'))
d_items_df

Unnamed: 0,ROW_ID,ITEMID,LABEL,ABBREVIATION,DBSOURCE,LINKSTO,CATEGORY,UNITNAME,PARAM_TYPE,CONCEPTID
0,457,497,Patient controlled analgesia (PCA) [Inject],,carevue,chartevents,,,,
1,458,498,PCA Lockout (Min),,carevue,chartevents,,,,
2,459,499,PCA Medication,,carevue,chartevents,,,,
3,460,500,PCA Total Dose,,carevue,chartevents,,,,
4,461,501,PCV Exh Vt (Obser),,carevue,chartevents,,,,
...,...,...,...,...,...,...,...,...,...,...
12482,14518,226757,GCSMotorApacheIIValue,GCSMotorApacheIIValue,metavision,chartevents,Scores - APACHE II,,Text,
12483,14519,226758,GCSVerbalApacheIIValue,GCSVerbalApacheIIValue,metavision,chartevents,Scores - APACHE II,,Text,
12484,14520,226759,HCO3ApacheIIValue,HCO3ApacheIIValue,metavision,chartevents,Scores - APACHE II,,Numeric,
12485,14521,226760,HCO3Score,HCO3Score,metavision,chartevents,Scores - APACHE II,,Numeric,


In [17]:
d_icd_procedures_df = pd.read_csv(os.path.join(mimic_iii_path, 'D_ICD_PROCEDURES.csv'))
d_icd_procedures_df

Unnamed: 0,ROW_ID,ICD9_CODE,SHORT_TITLE,LONG_TITLE
0,264,851,Canthotomy,Canthotomy
1,265,852,Blepharorrhaphy,Blepharorrhaphy
2,266,859,Adjust lid position NEC,Other adjustment of lid position
3,267,861,Lid reconst w skin graft,Reconstruction of eyelid with skin flap or graft
4,268,862,Lid reconst w muc graft,Reconstruction of eyelid with mucous membrane ...
...,...,...,...,...
3877,3344,9959,Vaccination/innocula NEC,Other vaccination and inoculation
3878,3345,9960,Cardiopulm resuscita NOS,"Cardiopulmonary resuscitation, not otherwise s..."
3879,3346,9961,Atrial cardioversion,Atrial cardioversion
3880,3347,9962,Heart countershock NEC,Other electric countershock of heart


## Datetime Events

In [18]:
datetimeevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'DATETIMEEVENTS.csv'))
datetimeevents_df

  datetimeevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'DATETIMEEVENTS.csv'))


Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,ITEMID,CHARTTIME,STORETIME,CGID,VALUE,VALUEUOM,WARNING,ERROR,RESULTSTATUS,STOPPED
0,711,7657,121183.0,297945.0,3411,2172-03-14 11:00:00,2172-03-14 11:52:00,16446,,Date,,,,NotStopd
1,712,7657,121183.0,297945.0,3411,2172-03-14 13:00:00,2172-03-14 12:36:00,16446,,Date,,,,NotStopd
2,713,7657,121183.0,297945.0,3411,2172-03-14 15:00:00,2172-03-14 15:10:00,14957,,Date,,,,NotStopd
3,714,7657,121183.0,297945.0,3411,2172-03-14 17:00:00,2172-03-14 17:01:00,16446,,Date,,,,NotStopd
4,715,7657,121183.0,297945.0,3411,2172-03-14 19:00:00,2172-03-14 19:29:00,14815,,Date,,,,NotStopd
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4485932,4456093,99366,136021.0,218447.0,224279,2197-03-24 16:32:00,2197-03-24 16:32:00,18234,2197-03-24 13:03:00,Date and Time,0.0,0.0,,
4485933,4456094,99366,136021.0,218447.0,224280,2197-03-24 16:32:00,2197-03-24 16:32:00,18234,2197-03-24 00:00:00,Date,0.0,0.0,,
4485934,4456095,99366,136021.0,218447.0,224282,2197-03-24 16:32:00,2197-03-24 16:32:00,18234,2197-03-24 00:00:00,Date,0.0,0.0,,
4485935,4456096,99366,136021.0,218447.0,224284,2197-03-24 16:32:00,2197-03-24 16:32:00,18234,2197-03-24 00:00:00,Date,0.0,0.0,,


In [19]:
datetimeevents_df['CHARTTIME'] = pd.to_datetime(datetimeevents_df['CHARTTIME'])
datetimeevents_df['STORETIME'] = pd.to_datetime(datetimeevents_df['STORETIME'])

In [20]:
datetimeevents_df['ERROR'].value_counts()

ERROR
0.0    2686325
1.0        595
Name: count, dtype: int64

In [21]:
datetimeevents_df['VALUEUOM'].value_counts()

VALUEUOM
Date             3766608
Date and Time     719329
Name: count, dtype: int64

In [22]:
datetimeevents_cols = [
    'HADM_ID',
    'ITEMID',
    'CHARTTIME',
    'CGID'
]

datetimeevents_df_filtered = datetimeevents_df[datetimeevents_df['ERROR'] != 1][datetimeevents_cols].copy()
datetimeevents_df_filtered

Unnamed: 0,HADM_ID,ITEMID,CHARTTIME,CGID
0,121183.0,3411,2172-03-14 11:00:00,16446
1,121183.0,3411,2172-03-14 13:00:00,16446
2,121183.0,3411,2172-03-14 15:00:00,14957
3,121183.0,3411,2172-03-14 17:00:00,16446
4,121183.0,3411,2172-03-14 19:00:00,14815
...,...,...,...,...
4485932,136021.0,224279,2197-03-24 16:32:00,18234
4485933,136021.0,224280,2197-03-24 16:32:00,18234
4485934,136021.0,224282,2197-03-24 16:32:00,18234
4485935,136021.0,224284,2197-03-24 16:32:00,18234


In [23]:
datetimeevents_df_filtered_24h = filter_df_24h(admissions_df_filtered, datetimeevents_df_filtered, 'CHARTTIME').reset_index(drop=True)
datetimeevents_df_filtered_24h

Unnamed: 0,HADM_ID,ITEMID,CHARTTIME,CGID,ADMITTIME,DEATH
0,112755.0,3411,2132-08-07 17:30:00,15484,2132-08-07 17:24:00,False
1,112755.0,3411,2132-08-07 18:00:00,15484,2132-08-07 17:24:00,False
2,112755.0,3411,2132-08-07 19:00:00,15713,2132-08-07 17:24:00,False
3,112755.0,3411,2132-08-07 20:00:00,17604,2132-08-07 17:24:00,False
4,112755.0,3411,2132-08-07 21:00:00,14437,2132-08-07 17:24:00,False
...,...,...,...,...,...,...
398318,193894.0,224288,2157-12-31 04:54:00,20622,2157-12-30 20:00:00,False
398319,193894.0,224290,2157-12-31 04:54:00,20622,2157-12-30 20:00:00,False
398320,193894.0,224284,2157-12-31 08:01:00,18576,2157-12-30 20:00:00,False
398321,193894.0,224287,2157-12-31 08:01:00,18576,2157-12-30 20:00:00,False


In [24]:
datetimeevents_deathrate = datetimeevents_df_filtered_24h['DEATH'].mean()
datetimeevents_deathrate

0.1642912912385175

## Diagnoses ICD
Get primary diagnosis

In [25]:
diagnoses_icd_df = pd.read_csv(os.path.join(mimic_iii_path, 'DIAGNOSES_ICD.csv'))
diagnoses_icd_df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,1297,109,172335,1.0,40301
1,1298,109,172335,2.0,486
2,1299,109,172335,3.0,58281
3,1300,109,172335,4.0,5855
4,1301,109,172335,5.0,4254
...,...,...,...,...,...
651042,639798,97503,188195,2.0,20280
651043,639799,97503,188195,3.0,V5869
651044,639800,97503,188195,4.0,V1279
651045,639801,97503,188195,5.0,5275


## ICU Stays

In [26]:
icustays_df = pd.read_csv(os.path.join(mimic_iii_path, 'ICUSTAYS.csv'))
icustays_df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,DBSOURCE,FIRST_CAREUNIT,LAST_CAREUNIT,FIRST_WARDID,LAST_WARDID,INTIME,OUTTIME,LOS
0,365,268,110404,280836,carevue,MICU,MICU,52,52,2198-02-14 23:27:38,2198-02-18 05:26:11,3.2490
1,366,269,106296,206613,carevue,MICU,MICU,52,52,2170-11-05 11:05:29,2170-11-08 17:46:57,3.2788
2,367,270,188028,220345,carevue,CCU,CCU,57,57,2128-06-24 15:05:20,2128-06-27 12:32:29,2.8939
3,368,271,173727,249196,carevue,MICU,SICU,52,23,2120-08-07 23:12:42,2120-08-10 00:39:04,2.0600
4,369,272,164716,210407,carevue,CCU,CCU,57,57,2186-12-25 21:08:04,2186-12-27 12:01:13,1.6202
...,...,...,...,...,...,...,...,...,...,...,...,...
61527,59806,94944,143774,201233,metavision,CSRU,CSRU,15,15,2104-04-15 10:18:16,2104-04-17 14:51:00,2.1894
61528,59807,94950,123750,283653,metavision,CCU,CCU,7,7,2155-12-08 05:33:16,2155-12-10 17:24:58,2.4942
61529,59808,94953,196881,241585,metavision,SICU,SICU,57,57,2160-03-03 16:09:11,2160-03-04 14:22:33,0.9259
61530,59809,94954,118475,202802,metavision,CSRU,CSRU,15,15,2183-03-25 09:53:10,2183-03-27 17:55:03,2.3346


## Note Events
Get notes from previous admissions

In [27]:
noteevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'NOTEEVENTS.csv'), engine='c') #, nrows=50000, skiprows=(lambda x: x % 2 != 0))
noteevents_df

  noteevents_df = pd.read_csv(os.path.join(mimic_iii_path, 'NOTEEVENTS.csv'), engine='c') #, nrows=50000, skiprows=(lambda x: x % 2 != 0))


Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT
0,174,22532,167853.0,2151-08-04,,,Discharge summary,Report,,,Admission Date: [**2151-7-16**] Dischar...
1,175,13702,107527.0,2118-06-14,,,Discharge summary,Report,,,Admission Date: [**2118-6-2**] Discharg...
2,176,13702,167118.0,2119-05-25,,,Discharge summary,Report,,,Admission Date: [**2119-5-4**] D...
3,177,13702,196489.0,2124-08-18,,,Discharge summary,Report,,,Admission Date: [**2124-7-21**] ...
4,178,26880,135453.0,2162-03-25,,,Discharge summary,Report,,,Admission Date: [**2162-3-3**] D...
...,...,...,...,...,...,...,...,...,...,...,...
2083175,2070657,31097,115637.0,2132-01-21,2132-01-21 03:27:00,2132-01-21 03:38:00,Nursing/other,Report,17581.0,,NPN\n\n\n#1 Infant remains in RA with O2 sats...
2083176,2070658,31097,115637.0,2132-01-21,2132-01-21 09:50:00,2132-01-21 09:53:00,Nursing/other,Report,19211.0,,"Neonatology\nDOL #5, CGA 36 weeks.\n\nCVR: Con..."
2083177,2070659,31097,115637.0,2132-01-21,2132-01-21 16:42:00,2132-01-21 16:44:00,Nursing/other,Report,20104.0,,Family Meeting Note\nFamily meeting held with ...
2083178,2070660,31097,115637.0,2132-01-21,2132-01-21 18:05:00,2132-01-21 18:16:00,Nursing/other,Report,16023.0,,NPN 1800\n\n\n#1 Resp: [**Known lastname 2243*...


In [28]:
noteevents_df['TEXT'][0]

'Admission Date:  [**2151-7-16**]       Discharge Date:  [**2151-8-4**]\n\n\nService:\nADDENDUM:\n\nRADIOLOGIC STUDIES:  Radiologic studies also included a chest\nCT, which confirmed cavitary lesions in the left lung apex\nconsistent with infectious process/tuberculosis.  This also\nmoderate-sized left pleural effusion.\n\nHEAD CT:  Head CT showed no intracranial hemorrhage or mass\neffect, but old infarction consistent with past medical\nhistory.\n\nABDOMINAL CT:  Abdominal CT showed lesions of\nT10 and sacrum most likely secondary to osteoporosis. These can\nbe followed by repeat imaging as an outpatient.\n\n\n\n                            [**First Name8 (NamePattern2) **] [**First Name4 (NamePattern1) 1775**] [**Last Name (NamePattern1) **], M.D.  [**MD Number(1) 1776**]\n\nDictated By:[**Hospital 1807**]\nMEDQUIST36\n\nD:  [**2151-8-5**]  12:11\nT:  [**2151-8-5**]  12:21\nJOB#:  [**Job Number 1808**]\n'

In [29]:
noteevents_df['CHARTDATE'] = pd.to_datetime(noteevents_df['CHARTDATE'])
noteevents_df['CHARTTIME'] = pd.to_datetime(noteevents_df['CHARTTIME'])
noteevents_df['STORETIME'] = pd.to_datetime(noteevents_df['STORETIME'])

In [30]:
noteevents_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2083180 entries, 0 to 2083179
Data columns (total 11 columns):
 #   Column       Dtype         
---  ------       -----         
 0   ROW_ID       int64         
 1   SUBJECT_ID   int64         
 2   HADM_ID      float64       
 3   CHARTDATE    datetime64[ns]
 4   CHARTTIME    datetime64[ns]
 5   STORETIME    datetime64[ns]
 6   CATEGORY     object        
 7   DESCRIPTION  object        
 8   CGID         float64       
 9   ISERROR      float64       
 10  TEXT         object        
dtypes: datetime64[ns](3), float64(3), int64(2), object(3)
memory usage: 174.8+ MB


In [31]:
noteevents_df['CATEGORY'].value_counts()

CATEGORY
Nursing/other        822497
Radiology            522279
Nursing              223556
ECG                  209051
Physician            141624
Discharge summary     59652
Echo                  45794
Respiratory           31739
Nutrition              9418
General                8301
Rehab Services         5431
Social Work            2670
Case Management         967
Pharmacy                103
Consult                  98
Name: count, dtype: int64

In [32]:
noteevents_df['DESCRIPTION'].value_counts()

DESCRIPTION
Report                                                 1132519
Nursing Progress Note                                   191836
CHEST (PORTABLE AP)                                     169270
Physician Resident Progress Note                         62698
CHEST (PA & LAT)                                         43158
                                                        ...   
RP FOOT 1 VIEW RIGHT PORT                                    1
Intensvist                                                   1
O IVP NO TOMO IN O.R.                                        1
OPL KNEE (2 VIEWS) IN O.R. PORT LEFT                         1
RO HIP NAILING IN OR W/FILMS & FLUORO RIGHT IN O.R.          1
Name: count, Length: 3848, dtype: int64

In [33]:
noteevents_df['ISERROR'].value_counts()

ISERROR
1.0    886
Name: count, dtype: int64

In [34]:
noteevents_df['ISERROR'].isna().sum()

2082294

In [35]:
noteevents_df[noteevents_df['ISERROR'] == 1]

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT
308706,316321,28742,185325.0,2144-02-02,2144-02-02 07:24:00,2144-02-02 07:32:05,Nursing,Nursing Transfer Note,15065.0,1.0,73 y/o female with breast cancer metastatic to...
308924,318078,27370,101266.0,2125-02-01,2125-02-01 16:30:00,2125-02-01 19:55:08,Physician,Physician Attending Admission Note,18917.0,1.0,"Chief Complaint: abd pain and hematuria , t..."
309218,316816,820,193603.0,2144-02-28,2144-02-27 05:43:00,2144-02-28 15:33:37,Nursing,Nursing Transfer Note,14607.0,1.0,Demographics\n Attending MD:\n Admit diagn...
309237,316679,29552,139512.0,2109-01-17,2109-01-17 06:12:00,2109-01-17 06:57:55,Respiratory,Respiratory Care Shift Note,17698.0,1.0,------ Protected Section------\n Blank note ...
309266,316432,31608,152365.0,2133-01-19,2133-01-16 17:54:00,2133-01-19 14:21:30,Nursing,Nursing Progress Note,18576.0,1.0,"Respiratory failure, acute (not ARDS/[**Doctor..."
...,...,...,...,...,...,...,...,...,...,...,...
2063307,371723,8492,133491.0,2117-04-04,2117-04-03 08:31:00,2117-04-04 08:15:49,Physician,Physician Attending Progress Note,19697.0,1.0,Chief Complaint:\n I saw and examined the pa...
2063558,572116,59225,123146.0,2139-05-28,2139-05-28 07:46:00,2139-05-28 08:04:29,Physician,Physician Resident Progress Note,19338.0,1.0,Chief Complaint:\n 24 Hour Events:\n PA CATH...
2065368,529203,66690,125078.0,2106-04-30,2106-04-30 07:43:00,2106-04-30 09:28:27,Physician,Physician Resident Progress Note,16478.0,1.0,Chief Complaint:\n 24 Hour Events:\n Pt w/...
2065469,572091,59225,123146.0,2139-05-28,2139-05-28 07:04:00,2139-05-28 07:10:16,Physician,Physician Resident Progress Note,19338.0,1.0,Chief Complaint:\n 24 Hour Events:\n -Cr 2...


In [36]:
noteevents_cols = [
    'SUBJECT_ID',
    'HADM_ID',
    'CHARTTIME',
    'TEXT'
]

noteevents_df_filtered = noteevents_df[
    (noteevents_df['ISERROR'].isna())\
    &\
    (~noteevents_df['CHARTTIME'].isna())
][noteevents_cols].copy()
noteevents_df_filtered

Unnamed: 0,SUBJECT_ID,HADM_ID,CHARTTIME,TEXT
308691,29075,179159.0,2116-02-07 14:08:00,"67M w/ h/o multiplemyeloma Dx [**2111**], neur..."
308692,18082,181163.0,2156-03-12 14:23:00,"[**Age over 90 52**]F with COPD on home O2, CA..."
308693,18082,181163.0,2156-03-12 14:28:00,"[**Age over 90 52**]F with COPD on home O2, CA..."
308694,16605,109285.0,2138-03-21 15:02:00,Chief Complaint:\n 24 Hour Events:\n Continu...
308695,29075,179159.0,2116-02-07 15:37:00,Chief Complaint:\n 24 Hour Events:\n EGD d...
...,...,...,...,...
2083175,31097,115637.0,2132-01-21 03:27:00,NPN\n\n\n#1 Infant remains in RA with O2 sats...
2083176,31097,115637.0,2132-01-21 09:50:00,"Neonatology\nDOL #5, CGA 36 weeks.\n\nCVR: Con..."
2083177,31097,115637.0,2132-01-21 16:42:00,Family Meeting Note\nFamily meeting held with ...
2083178,31097,115637.0,2132-01-21 18:05:00,NPN 1800\n\n\n#1 Resp: [**Known lastname 2243*...


In [37]:
noteevents_df_filtered_24h = filter_df_24h(admissions_df_filtered, noteevents_df_filtered, 'CHARTTIME').reset_index(drop=True)
noteevents_df_filtered_24h

Unnamed: 0,SUBJECT_ID,HADM_ID,CHARTTIME,TEXT,ADMITTIME,DEATH
0,18839,136227.0,2193-02-08 15:39:00,Attending Physician: [**Name10 (NameIs) 116**]...,2193-02-07 18:10:00,False
1,30015,157324.0,2154-03-25 07:02:00,"Chief Complaint: 75F with COPD, DM, [**Hospita...",2154-03-24 12:08:00,False
2,30015,157324.0,2154-03-25 07:02:00,"Chief Complaint: 75F with COPD, DM, [**Hospita...",2154-03-24 12:08:00,False
3,30015,157324.0,2154-03-25 07:02:00,"Chief Complaint: 75F with COPD, DM, [**Hospita...",2154-03-24 12:08:00,False
4,31975,165003.0,2138-02-11 05:09:00,"Pneumonia, other\n Assessment:\n Afebrile,...",2138-02-10 21:58:00,False
...,...,...,...,...,...,...
229481,29671,122817.0,2150-11-13 14:37:00,Clinical Nutrition:\nO:\n~Former 24 [**4-8**] ...,2150-11-12 15:48:00,False
229482,29671,122817.0,2150-11-12 18:08:00,Respiratory Care Note\nBaby Girl [**Known last...,2150-11-12 15:48:00,False
229483,29671,122817.0,2150-11-12 22:39:00,NICU NSG NOTE\n\n\n#1. Resp. O/ Conts on HFOV ...,2150-11-12 15:48:00,False
229484,29671,122817.0,2150-11-13 03:20:00,"1. Resp: O: Received infant on the HiFi vent, ...",2150-11-12 15:48:00,False


In [38]:
noteevents_df_filtered_24h['HADM_ID'].value_counts().mean()

5.981649941352796

In [39]:
noteevents_deathrate = noteevents_df_filtered_24h['DEATH'].mean()
noteevents_deathrate

0.1368100886328578

## Patients

- gender
- age (on admission)

In [40]:
patients_df = pd.read_csv(os.path.join(mimic_iii_path, 'PATIENTS.csv'))
patients_df

Unnamed: 0,ROW_ID,SUBJECT_ID,GENDER,DOB,DOD,DOD_HOSP,DOD_SSN,EXPIRE_FLAG
0,234,249,F,2075-03-13 00:00:00,,,,0
1,235,250,F,2164-12-27 00:00:00,2188-11-22 00:00:00,2188-11-22 00:00:00,,1
2,236,251,M,2090-03-15 00:00:00,,,,0
3,237,252,M,2078-03-06 00:00:00,,,,0
4,238,253,F,2089-11-26 00:00:00,,,,0
...,...,...,...,...,...,...,...,...
46515,31840,44089,M,2026-05-25 00:00:00,,,,0
46516,31841,44115,F,2124-07-27 00:00:00,,,,0
46517,31842,44123,F,2049-11-26 00:00:00,2135-01-12 00:00:00,2135-01-12 00:00:00,,1
46518,31843,44126,F,2076-07-25 00:00:00,,,,0


In [41]:
patients_df['DOB'] = pd.to_datetime(patients_df['DOB'])

In [42]:
patients_cols = [
    'SUBJECT_ID',
    'GENDER',
    'DOB'
]

patients_df_filtered = patients_df[patients_cols].copy()

# Convert gender to binary feature (0 for female, 1 for male)
patients_df_filtered['GENDER_BINARY'] = patients_df_filtered['GENDER'].apply(lambda x: 1 if x == 'M' else 0)

patients_df_filtered

Unnamed: 0,SUBJECT_ID,GENDER,DOB,GENDER_BINARY
0,249,F,2075-03-13,0
1,250,F,2164-12-27,0
2,251,M,2090-03-15,1
3,252,M,2078-03-06,1
4,253,F,2089-11-26,0
...,...,...,...,...
46515,44089,M,2026-05-25,1
46516,44115,F,2124-07-27,0
46517,44123,F,2049-11-26,0
46518,44126,F,2076-07-25,0


## Prescriptions

In [5]:
prescriptions_df = pd.read_csv(os.path.join(mimic_iii_path, 'PRESCRIPTIONS.csv'))
prescriptions_df

  prescriptions_df = pd.read_csv(os.path.join(mimic_iii_path, 'PRESCRIPTIONS.csv'))


Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,STARTDATE,ENDDATE,DRUG_TYPE,DRUG,DRUG_NAME_POE,DRUG_NAME_GENERIC,FORMULARY_DRUG_CD,GSN,NDC,PROD_STRENGTH,DOSE_VAL_RX,DOSE_UNIT_RX,FORM_VAL_DISP,FORM_UNIT_DISP,ROUTE
0,2214776,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Tacrolimus,Tacrolimus,Tacrolimus,TACR1,021796,4.690617e+08,1mg Capsule,2,mg,2,CAP,PO
1,2214775,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Warfarin,Warfarin,Warfarin,WARF5,006562,5.601728e+07,5mg Tablet,5,mg,1,TAB,PO
2,2215524,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Heparin Sodium,,,HEPAPREMIX,006522,3.380550e+08,"25,000 unit Premix Bag",25000,UNIT,1,BAG,IV
3,2216265,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,BASE,D5W,,,HEPBASE,,0.000000e+00,HEPARIN BASE,250,ml,250,ml,IV
4,2214773,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Furosemide,Furosemide,Furosemide,FURO20,008208,5.482972e+07,20mg Tablet,20,mg,1,TAB,PO
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4156445,3996662,98887,121032,238144.0,2144-09-06 00:00:00,2144-09-06 00:00:00,MAIN,PredniSONE,PredniSONE,PredniSONE,PRED20,006751,5.400182e+07,20 mg Tablet,40,mg,2,TAB,PO/NG
4156446,3996070,98887,121032,238144.0,2144-09-06 00:00:00,2144-09-06 00:00:00,MAIN,Ipratropium Bromide Neb,Ipratropium Bromide Neb,Ipratropium Bromide Neb,IPRA2H,021700,4.879801e+08,2.5mL Vial,1,NEB,1,VIAL,IH
4156447,3996063,98887,121032,238144.0,2144-09-06 00:00:00,2144-09-06 00:00:00,MAIN,HYDROmorphone (Dilaudid),HYDROmorphone (Dilaudid),HYDROmorphone,HYDR20/100NS,048078,6.155302e+10,20 mg / 100 mL Premix Bag,0.12,mg,0.01,BAG,IVPCA
4156448,3996062,98887,121032,238144.0,2144-09-06 00:00:00,2144-09-06 00:00:00,MAIN,Docusate Sodium,Docusate Sodium,Docusate Sodium,DOCU100,003009,9.042245e+08,100mg Capsule,100,mg,1,CAP,PO


In [6]:
prescriptions_df['STARTDATE'] = pd.to_datetime(prescriptions_df['STARTDATE'])

In [7]:
prescriptions_df[['DRUG', 'DRUG_NAME_POE', 'DRUG_NAME_GENERIC', 'GSN', 'NDC']].isna().sum()

DRUG                       0
DRUG_NAME_POE        1664234
DRUG_NAME_GENERIC    1662989
GSN                   507164
NDC                     4463
dtype: int64

In [8]:
prescriptions_df['DRUG'].value_counts()

DRUG
Potassium Chloride      192993
Insulin                 143465
D5W                     142241
Furosemide              133122
0.9% Sodium Chloride    130147
                         ...  
Renaphro                     1
Morphine Sulfat              1
humulin R                    1
Meperidine PCA               1
rasagiline (Azilect)         1
Name: count, Length: 4525, dtype: int64

In [9]:
prescriptions_df[prescriptions_df['NDC'].isna()]['DRUG'].value_counts()

DRUG
D5W                              215
Potassium Chloride               162
Propofol                         135
Fentanyl Citrate                  99
Insulin Pump                      95
                                ... 
atomoxitene                        1
Enema Disposable                   1
LIDODERM PATCHES 3.5%              1
Menest (esterified estrogens)      1
rasagiline (Azilect)               1
Name: count, Length: 1146, dtype: int64

In [10]:
assert len(prescriptions_df[prescriptions_df['DRUG'].isna()]) == 0

In [14]:
drugs = prescriptions_df['DRUG'].unique()
min_length = min(len(str(drug)) for drug in drugs)
max_length = max(len(str(drug)) for drug in drugs)
min_length_drug = min(drugs, key=lambda x: len(str(x)))
max_length_drug = max(drugs, key=lambda x: len(str(x)))

print(f"Minimum drug length: {min_length}")
print(f"Minimum length drug: {min_length_drug}")
print(f"Maximum drug length: {max_length}")
print(f"Maximum length drug: {max_length_drug}")

dose_vals = prescriptions_df['DOSE_VAL_RX'].unique()
min_length = min(len(str(dose_val)) for dose_val in dose_vals)
max_length = max(len(str(dose_val)) for dose_val in dose_vals)
min_length_dose = min(dose_vals, key=lambda x: len(str(x)))
max_length_dose = max(dose_vals, key=lambda x: len(str(x)))

print(f"Minimum dose_val length: {min_length}")
print(f"Minimum length dose_val: {min_length_dose}")
print(f"Maximum dose_val length: {max_length}")
print(f"Maximum length dose_val: {max_length_dose}")

dose_units = prescriptions_df['DOSE_UNIT_RX'].unique()
min_length = min(len(str(dose_unit)) for dose_unit in dose_units)
max_length = max(len(str(dose_unit)) for dose_unit in dose_units)
min_length_unit = min((u for u in dose_units), key=lambda x: len(str(x)))
max_length_unit = max((u for u in dose_units), key=lambda x: len(str(x)))

print(f"Minimum dose_unit length: {min_length}")
print(f"Minimum length dose_unit: {min_length_unit}")
print(f"Maximum dose_unit length: {max_length}")
print(f"Maximum length dose_unit: {max_length_unit}")

Minimum drug length: 1
Minimum length drug: H
Maximum drug length: 58
Maximum length drug: Alteplase 1mg/Flush Volume ( Dialysis/Pheresis Catheters )
Minimum dose_val length: 1
Minimum length dose_val: 2
Maximum dose_val length: 26
Maximum length dose_val: 16; give 1/2 dose when NPO
Minimum dose_unit length: 1
Minimum length dose_unit: g
Maximum dose_unit length: 32
Maximum length dose_unit: PE (Phenytoin Sodium Equivalent)


In [49]:
prescriptions_cols = [
    'HADM_ID',
    'STARTDATE',
    'DRUG'
]

prescriptions_df_filtered = prescriptions_df[prescriptions_cols].copy()
prescriptions_df_filtered

Unnamed: 0,HADM_ID,STARTDATE,DRUG
0,107064,2175-06-11,Tacrolimus
1,107064,2175-06-11,Warfarin
2,107064,2175-06-11,Heparin Sodium
3,107064,2175-06-11,D5W
4,107064,2175-06-11,Furosemide
...,...,...,...
4156445,121032,2144-09-06,PredniSONE
4156446,121032,2144-09-06,Ipratropium Bromide Neb
4156447,121032,2144-09-06,HYDROmorphone (Dilaudid)
4156448,121032,2144-09-06,Docusate Sodium


In [50]:
prescriptions_df_filtered_24h = filter_df_24h(admissions_df_filtered, prescriptions_df_filtered, 'STARTDATE').reset_index(drop=True)
prescriptions_df_filtered_24h

Unnamed: 0,HADM_ID,STARTDATE,DRUG,ADMITTIME,DEATH
0,143045,2167-01-08,D5W,2167-01-08 18:43:00,False
1,143045,2167-01-08,Heparin Sodium,2167-01-08 18:43:00,False
2,143045,2167-01-08,Nitroglycerin,2167-01-08 18:43:00,False
3,143045,2167-01-08,Docusate Sodium,2167-01-08 18:43:00,False
4,143045,2167-01-08,Insulin,2167-01-08 18:43:00,False
...,...,...,...,...,...
1265959,121032,2144-09-06,PredniSONE,2144-09-06 10:03:00,False
1265960,121032,2144-09-06,Ipratropium Bromide Neb,2144-09-06 10:03:00,False
1265961,121032,2144-09-06,HYDROmorphone (Dilaudid),2144-09-06 10:03:00,False
1265962,121032,2144-09-06,Docusate Sodium,2144-09-06 10:03:00,False


In [51]:
prescriptions_deathrate = prescriptions_df_filtered_24h['DEATH'].mean()
prescriptions_deathrate

0.14087683378042345

## Procedures ICD

In [52]:
procedures_icd_df = pd.read_csv(os.path.join(mimic_iii_path, 'PROCEDURES_ICD.csv'))
procedures_icd_df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,944,62641,154460,3,3404
1,945,2592,130856,1,9671
2,946,2592,130856,2,3893
3,947,55357,119355,1,9672
4,948,55357,119355,2,331
...,...,...,...,...,...
240090,228330,67415,150871,5,3736
240091,228331,67415,150871,6,3893
240092,228332,67415,150871,7,8872
240093,228333,67415,150871,8,3893


## Services

In [53]:
services_df = pd.read_csv(os.path.join(mimic_iii_path, 'SERVICES.csv'))
services_df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,TRANSFERTIME,PREV_SERVICE,CURR_SERVICE
0,758,471,135879,2122-07-22 14:07:27,TSURG,MED
1,759,471,135879,2122-07-26 18:31:49,MED,TSURG
2,760,472,173064,2172-09-28 19:22:15,,CMED
3,761,473,129194,2201-01-09 20:16:45,,NB
4,762,474,194246,2181-03-23 08:24:41,,NB
...,...,...,...,...,...,...
73338,72914,98932,174244,2176-08-13 20:28:00,,CMED
73339,72915,98939,115549,2166-11-15 01:21:49,,NMED
73340,72916,98941,141129,2118-02-08 01:52:28,,CSURG
73341,72917,98943,193747,2164-11-14 20:04:12,,TRAUM


## Transfers

- wards visited
- number of transfers

In [54]:
transfers_df = pd.read_csv(os.path.join(mimic_iii_path, 'TRANSFERS.csv'))
transfers_df

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,DBSOURCE,EVENTTYPE,PREV_CAREUNIT,CURR_CAREUNIT,PREV_WARDID,CURR_WARDID,INTIME,OUTTIME,LOS
0,657,111,192123,254245.0,carevue,transfer,CCU,MICU,7.0,23.0,2142-04-29 15:27:11,2142-05-04 20:38:33,125.19
1,658,111,192123,,carevue,transfer,MICU,,23.0,45.0,2142-05-04 20:38:33,2142-05-05 11:46:32,15.13
2,659,111,192123,,carevue,discharge,,,45.0,,2142-05-05 11:46:32,,
3,660,111,155897,249202.0,metavision,admit,,MICU,,52.0,2144-07-01 04:13:59,2144-07-01 05:19:39,1.09
4,661,111,155897,,metavision,transfer,MICU,,52.0,32.0,2144-07-01 05:19:39,2144-07-01 06:28:29,1.15
...,...,...,...,...,...,...,...,...,...,...,...,...,...
261892,259671,98385,195599,,metavision,transfer,,,36.0,49.0,2108-10-06 11:27:11,2108-10-06 13:05:57,1.65
261893,259672,98385,195599,292167.0,metavision,transfer,,SICU,49.0,33.0,2108-10-06 13:05:57,2108-10-11 17:00:31,123.91
261894,259673,98385,195599,,metavision,discharge,SICU,,33.0,,2108-10-11 17:00:31,,
261895,259674,98389,155368,,metavision,admit,,,,29.0,2153-10-14 22:12:58,2153-10-14 22:21:06,0.14


In [55]:
transfers_df['INTIME'] = pd.to_datetime(transfers_df['INTIME'])
transfers_df['OUTTIME'] = pd.to_datetime(transfers_df['OUTTIME'])

In [56]:
transfers_df[transfers_df['HADM_ID'] == 192123].sort_values('INTIME')

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,DBSOURCE,EVENTTYPE,PREV_CAREUNIT,CURR_CAREUNIT,PREV_WARDID,CURR_WARDID,INTIME,OUTTIME,LOS
556,655,111,192123,254245.0,carevue,admit,,CCU,,7.0,2142-04-24 06:56:33,2142-04-27 16:27:37,81.52
557,656,111,192123,254245.0,carevue,transfer,CCU,CCU,7.0,7.0,2142-04-27 16:27:37,2142-04-29 15:27:11,46.99
0,657,111,192123,254245.0,carevue,transfer,CCU,MICU,7.0,23.0,2142-04-29 15:27:11,2142-05-04 20:38:33,125.19
1,658,111,192123,,carevue,transfer,MICU,,23.0,45.0,2142-05-04 20:38:33,2142-05-05 11:46:32,15.13
2,659,111,192123,,carevue,discharge,,,45.0,,2142-05-05 11:46:32,NaT,


In [57]:
transfers_df['EVENTTYPE'].value_counts()

EVENTTYPE
transfer     144045
discharge     58919
admit         58909
Name: count, dtype: int64

In [58]:
transfers_df[transfers_df['EVENTTYPE'] == 'transfer'][['PREV_WARDID', 'CURR_WARDID', 'INTIME', 'OUTTIME']].isna().sum()

PREV_WARDID     0
CURR_WARDID     0
INTIME          0
OUTTIME        30
dtype: int64

In [59]:
assert transfers_df[(transfers_df['EVENTTYPE'] == 'transfer') & (~transfers_df['OUTTIME'].isna())]\
    [['PREV_WARDID', 'CURR_WARDID', 'INTIME', 'OUTTIME']]\
        .isna().sum().sum() \
            == 0

In [60]:
transfers_cols = [
    'HADM_ID',
    'PREV_WARDID',
    'CURR_WARDID',
    'INTIME'
]

transfers_df_filtered = transfers_df[(transfers_df['EVENTTYPE'] == 'transfer') & (~transfers_df['OUTTIME'].isna())]\
    [transfers_cols].copy()
transfers_df_filtered

Unnamed: 0,HADM_ID,PREV_WARDID,CURR_WARDID,INTIME
0,192123,7.0,23.0,2142-04-29 15:27:11
1,192123,23.0,45.0,2142-05-04 20:38:33
4,155897,52.0,32.0,2144-07-01 05:19:39
5,155897,32.0,52.0,2144-07-01 06:28:29
6,155897,52.0,32.0,2144-07-01 08:07:16
...,...,...,...,...
261888,153384,14.0,36.0,2108-08-22 22:02:27
261891,195599,36.0,36.0,2108-10-05 20:45:56
261892,195599,36.0,49.0,2108-10-06 11:27:11
261893,195599,49.0,33.0,2108-10-06 13:05:57


In [61]:
transfers_df_filtered_24h = filter_df_24h(admissions_df_filtered, transfers_df_filtered, 'INTIME').reset_index(drop=True)
transfers_df_filtered_24h

Unnamed: 0,HADM_ID,PREV_WARDID,CURR_WARDID,INTIME,ADMITTIME,DEATH
0,155897,52.0,32.0,2144-07-01 05:19:39,2144-07-01 04:12:00,True
1,155897,32.0,52.0,2144-07-01 06:28:29,2144-07-01 04:12:00,True
2,155897,52.0,32.0,2144-07-01 08:07:16,2144-07-01 04:12:00,True
3,155897,32.0,23.0,2144-07-01 08:13:51,2144-07-01 04:12:00,True
4,174105,12.0,3.0,2194-06-14 14:51:17,2194-06-13 18:39:00,False
...,...,...,...,...,...,...
25908,126800,3.0,7.0,2140-07-14 22:59:38,2140-07-14 18:30:00,False
25909,195599,36.0,36.0,2108-10-05 20:45:56,2108-10-05 20:35:00,False
25910,195599,36.0,49.0,2108-10-06 11:27:11,2108-10-05 20:35:00,False
25911,195599,49.0,33.0,2108-10-06 13:05:57,2108-10-05 20:35:00,False


In [62]:
transfers_deathrate = transfers_df_filtered_24h['DEATH'].mean()
transfers_deathrate

0.09636089993439587

# <u>Preliminary Feature Selection</u>

## Methods

In [None]:
def train_and_eval_prelim_xgb(encoded_df, positive_freq, label_col='DEATH', id_cols=['HADM_ID'], test_size=0.2, random_state=42):
    # Separate features and label
    drop_cols = id_cols + [label_col]
    X = encoded_df.drop(drop_cols, axis=1)
    y = encoded_df[label_col]

    # Split the dataset into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Initialize and train the model
    model = XGBClassifier(
        enable_categorical = True,  # Enable categorical feature support
        eval_metric = 'logloss',
        scale_pos_weight = 1 / positive_freq  # Adjust class imbalance
    )
    model.fit(X_train, y_train)

    # Evaluate the model
    y_pred = model.predict(X_test)
    print('Accuracy:', accuracy_score(y_test, y_pred))
    print(f"ROC AUC: {roc_auc_score(y_test, model.predict_proba(X_test)[:,1]):.4f}")
    print(f"Average Precision: {average_precision_score(y_test, model.predict_proba(X_test)[:,1]):.4f}")
    print(classification_report(y_test, y_pred))

    # Print feature importances
    importances = model.feature_importances_
    feat_importances = pd.DataFrame({
        'feature': X.columns,
        'importance': importances
    }).sort_values(by='importance', ascending=False)
    print('Feature importances:')
    display(feat_importances)

    return feat_importances

In [64]:
def top_p_features(df, p=0.8, importance_col='importance'):
    """
    Returns a DataFrame with the top features that together account for at least 'p' of the total importance.

    Args:
        df (pd.DataFrame): DataFrame that includes an importance column.
        p (float): Cumulative importance threshold, e.g. 0.8.
        importance_col (str): Name of the column with importance scores.

    Returns:
        pd.DataFrame: Filtered DataFrame with the top-p features.
    """
    # Sort feature importances in descending order
    fi_sorted = df.sort_values(by=importance_col, ascending=False).reset_index(drop=True)

    # Compute cumulative normalized importance
    cum_norm = fi_sorted[importance_col].cumsum() / fi_sorted[importance_col].sum()

    # Determine the number of rows to include (include the row that pushes cumulative importance ≥ p)
    num_features = (cum_norm <= p).sum()
    if num_features < len(fi_sorted):
        num_features += 1

    return fi_sorted.iloc[:num_features]

In [65]:
def assert_unique_HADM_ID(df):
    assert df['HADM_ID'].nunique() == len(df), 'Some HADM_IDs have multiple rows'

## Admissions (not needed)

### Feature extraction

In [66]:
admissions_feature_cols = [
    'HADM_ID',
    # 'ADMISSION_TYPE', # removed non emergency/urgent admissions
    'ADMISSION_LOCATION',
    'INSURANCE',
    'ETHNICITY'
    # 'DIAGNOSIS' # skip, we'll get diagnostic information from NOTEEVENTS
]
for column in admissions_feature_cols:
    print(f'{column} unique values: {admissions_df_filtered[column].nunique()}')

HADM_ID unique values: 43407
ADMISSION_LOCATION unique values: 9
INSURANCE unique values: 5
ETHNICITY unique values: 41


In [67]:
admissions_df_filtered['ETHNICITY'].value_counts(normalize=True)

ETHNICITY
WHITE                                                       0.698044
BLACK/AFRICAN AMERICAN                                      0.097450
UNKNOWN/NOT SPECIFIED                                       0.081876
HISPANIC OR LATINO                                          0.026977
OTHER                                                       0.022231
UNABLE TO OBTAIN                                            0.016173
ASIAN                                                       0.016034
PATIENT DECLINED TO ANSWER                                  0.006819
HISPANIC/LATINO - PUERTO RICAN                              0.004700
ASIAN - CHINESE                                             0.004354
BLACK/CAPE VERDEAN                                          0.003410
WHITE - RUSSIAN                                             0.003340
BLACK/HAITIAN                                               0.002143
MULTI RACE ETHNICITY                                        0.002050
HISPANIC/LATINO - DOMINI

In [68]:
ethnicity_buckets = [
    'WHITE',
    'BLACK',
    'HISPANIC',
    'ASIAN'
]

def map_ethnicity(ethnicity_str):
    ethnicity_upper = ethnicity_str.upper()
    for bucket in ethnicity_buckets:
        if bucket in ethnicity_upper:
            return bucket
    return 'OTHER'


admissions_df_filtered['ETHNICITY_BUCKET'] = admissions_df_filtered['ETHNICITY'].apply(map_ethnicity)
admissions_df_filtered['ETHNICITY_BUCKET'].value_counts(normalize=True)


ETHNICITY_BUCKET
WHITE       0.703919
OTHER       0.132306
BLACK       0.103854
HISPANIC    0.035547
ASIAN       0.024374
Name: proportion, dtype: float64

In [69]:
admissions_feature_cols.append('ETHNICITY_BUCKET')
admissions_feature_cols.remove('ETHNICITY')

In [70]:
admissions_features_df = admissions_df_filtered[admissions_feature_cols].copy()

# Convert string columns to 'category' dtype
for col in admissions_feature_cols:
    admissions_features_df[col] = admissions_features_df[col].astype('category')

admissions_features_df['DEATH'] = admissions_df_filtered['DEATH']
admissions_features_df

Unnamed: 0,HADM_ID,ADMISSION_LOCATION,INSURANCE,ETHNICITY_BUCKET,DEATH
0,165315,EMERGENCY ROOM ADMIT,Private,WHITE,False
1,124321,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,False
2,161859,TRANSFER FROM HOSP/EXTRAM,Private,WHITE,False
3,129635,EMERGENCY ROOM ADMIT,Private,WHITE,False
4,197661,TRANSFER FROM HOSP/EXTRAM,Medicare,OTHER,False
...,...,...,...,...,...
43402,105447,EMERGENCY ROOM ADMIT,Medicare,WHITE,True
43403,191113,CLINIC REFERRAL/PREMATURE,Private,WHITE,False
43404,101071,CLINIC REFERRAL/PREMATURE,Medicare,WHITE,True
43405,170407,EMERGENCY ROOM ADMIT,Private,WHITE,False


### Eval

In [71]:
admissions_feature_importances = train_and_eval_prelim_xgb(admissions_features_df, admissions_deathrate, label_col='DEATH')

Accuracy: 0.48963372494816865
ROC AUC: 0.6026
Average Precision: 0.1640
              precision    recall  f1-score   support

       False       0.91      0.46      0.61      7584
        True       0.16      0.70      0.26      1098

    accuracy                           0.49      8682
   macro avg       0.54      0.58      0.43      8682
weighted avg       0.82      0.49      0.57      8682

Feature importances:


Unnamed: 0,feature,importance
1,INSURANCE,0.480915
2,ETHNICITY_BUCKET,0.342122
0,ADMISSION_LOCATION,0.176963


In [72]:
admissions_keep_features = top_p_features(admissions_feature_importances)['feature'].tolist()
admissions_keep_features

['INSURANCE', 'ETHNICITY_BUCKET']

## Admissions-Patients (4 features)

### Join patients

In [73]:
admissions_patients_df = pd.merge(
    admissions_df_filtered,
    patients_df_filtered,
    on='SUBJECT_ID',
    how='inner'
)
admissions_patients_df

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMITTIME,DEATHTIME,ADMISSION_LOCATION,INSURANCE,ETHNICITY,DIAGNOSIS,DEATH,ETHNICITY_BUCKET,GENDER,DOB,GENDER_BINARY
0,22,165315,2196-04-09 12:26:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,BENZODIAZEPINE OVERDOSE,False,WHITE,F,2131-05-07,0
1,23,124321,2157-10-18 19:34:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,BRAIN MASS,False,WHITE,M,2082-07-17,1
2,24,161859,2139-06-06 16:14:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,WHITE,INTERIOR MYOCARDIAL INFARCTION,False,WHITE,M,2100-05-31,1
3,25,129635,2160-11-02 02:06:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,ACUTE CORONARY SYNDROME,False,WHITE,M,2101-11-21,1
4,26,197661,2126-05-06 15:16:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,UNKNOWN/NOT SPECIFIED,V-TACH,False,OTHER,M,2054-05-04,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
43402,98797,105447,2132-12-24 20:06:00,2132-12-25 12:00:00,EMERGENCY ROOM ADMIT,Medicare,WHITE,ALTERED MENTAL STATUS,True,WHITE,M,2044-12-27,1
43403,98800,191113,2131-03-30 21:13:00,NaT,CLINIC REFERRAL/PREMATURE,Private,WHITE,TRAUMA,False,WHITE,F,2111-11-05,0
43404,98802,101071,2151-03-05 20:00:00,2151-03-06 09:10:00,CLINIC REFERRAL/PREMATURE,Medicare,WHITE,SAH,True,WHITE,F,2067-09-21,0
43405,98813,170407,2128-11-11 02:29:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,S/P FALL,False,WHITE,F,2068-02-04,0


In [74]:
admissions_patients_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 43407 entries, 0 to 43406
Data columns (total 13 columns):
 #   Column              Non-Null Count  Dtype         
---  ------              --------------  -----         
 0   SUBJECT_ID          43407 non-null  int64         
 1   HADM_ID             43407 non-null  int64         
 2   ADMITTIME           43407 non-null  datetime64[ns]
 3   DEATHTIME           5595 non-null   datetime64[ns]
 4   ADMISSION_LOCATION  43407 non-null  object        
 5   INSURANCE           43407 non-null  object        
 6   ETHNICITY           43407 non-null  object        
 7   DIAGNOSIS           43395 non-null  object        
 8   DEATH               43407 non-null  bool          
 9   ETHNICITY_BUCKET    43407 non-null  object        
 10  GENDER              43407 non-null  object        
 11  DOB                 43407 non-null  datetime64[ns]
 12  GENDER_BINARY       43407 non-null  int64         
dtypes: bool(1), datetime64[ns](3), int64(3), objec

In [75]:
from pandas._libs.tslibs.np_datetime import OutOfBoundsDatetime

# Calculate age safely with error handling for extreme date ranges
def calculate_age(admit_date, birth_date):
    try:
        return (admit_date - birth_date).days / 365
    except (OverflowError, OutOfBoundsDatetime):
        # If dates are too extreme, check if year difference is reasonable
        year_diff = admit_date.year - birth_date.year
        if 0 <= year_diff <= 120:  # Reasonable age range
            return year_diff
        else:
            print(f'HADM_ID: {admit_date} - {birth_date} has unreasonable age difference: {year_diff}')
            return float('nan')  # Return NaN for unreasonable values

# Apply the safer calculation
admissions_patients_df['AGE_YRS'] = admissions_patients_df.apply(
    lambda row: calculate_age(row['ADMITTIME'], row['DOB']), axis=1
)
admissions_patients_df

HADM_ID: 2172-10-14 14:17:00 - 1872-10-14 00:00:00 has unreasonable age difference: 300
HADM_ID: 2186-07-18 16:46:00 - 1886-07-18 00:00:00 has unreasonable age difference: 300
HADM_ID: 2191-02-23 05:23:00 - 1886-07-18 00:00:00 has unreasonable age difference: 305
HADM_ID: 2137-07-11 17:56:00 - 1837-07-11 00:00:00 has unreasonable age difference: 300
HADM_ID: 2139-12-16 19:48:00 - 1837-07-11 00:00:00 has unreasonable age difference: 302
HADM_ID: 2104-01-02 02:01:00 - 1804-01-02 00:00:00 has unreasonable age difference: 300
HADM_ID: 2135-02-18 19:16:00 - 1835-02-18 00:00:00 has unreasonable age difference: 300
HADM_ID: 2145-05-06 20:00:00 - 1845-05-06 00:00:00 has unreasonable age difference: 300
HADM_ID: 2194-06-13 18:39:00 - 1894-06-13 00:00:00 has unreasonable age difference: 300
HADM_ID: 2196-09-27 18:21:00 - 1894-06-13 00:00:00 has unreasonable age difference: 302
HADM_ID: 2108-08-05 16:25:00 - 1808-08-05 00:00:00 has unreasonable age difference: 300
HADM_ID: 2141-05-18 17:21:00 - 1

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMITTIME,DEATHTIME,ADMISSION_LOCATION,INSURANCE,ETHNICITY,DIAGNOSIS,DEATH,ETHNICITY_BUCKET,GENDER,DOB,GENDER_BINARY,AGE_YRS
0,22,165315,2196-04-09 12:26:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,BENZODIAZEPINE OVERDOSE,False,WHITE,F,2131-05-07,0,64.969863
1,23,124321,2157-10-18 19:34:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,BRAIN MASS,False,WHITE,M,2082-07-17,1,75.304110
2,24,161859,2139-06-06 16:14:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,WHITE,INTERIOR MYOCARDIAL INFARCTION,False,WHITE,M,2100-05-31,1,39.041096
3,25,129635,2160-11-02 02:06:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,ACUTE CORONARY SYNDROME,False,WHITE,M,2101-11-21,1,58.989041
4,26,197661,2126-05-06 15:16:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,UNKNOWN/NOT SPECIFIED,V-TACH,False,OTHER,M,2054-05-04,1,72.052055
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
43402,98797,105447,2132-12-24 20:06:00,2132-12-25 12:00:00,EMERGENCY ROOM ADMIT,Medicare,WHITE,ALTERED MENTAL STATUS,True,WHITE,M,2044-12-27,1,88.049315
43403,98800,191113,2131-03-30 21:13:00,NaT,CLINIC REFERRAL/PREMATURE,Private,WHITE,TRAUMA,False,WHITE,F,2111-11-05,0,19.410959
43404,98802,101071,2151-03-05 20:00:00,2151-03-06 09:10:00,CLINIC REFERRAL/PREMATURE,Medicare,WHITE,SAH,True,WHITE,F,2067-09-21,0,83.506849
43405,98813,170407,2128-11-11 02:29:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,S/P FALL,False,WHITE,F,2068-02-04,0,60.808219


In [76]:
admissions_patients_df[admissions_patients_df['AGE_YRS'].isna()]

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMITTIME,DEATHTIME,ADMISSION_LOCATION,INSURANCE,ETHNICITY,DIAGNOSIS,DEATH,ETHNICITY_BUCKET,GENDER,DOB,GENDER_BINARY,AGE_YRS
5,30,104557,2172-10-14 14:17:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,UNKNOWN/NOT SPECIFIED,UNSTABLE ANGINA\CATH,False,OTHER,M,1872-10-14,1,
8,34,115799,2186-07-18 16:46:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,CHEST PAIN\CATH,False,WHITE,M,1886-07-18,1,
9,34,144319,2191-02-23 05:23:00,NaT,CLINIC REFERRAL/PREMATURE,Medicare,WHITE,BRADYCARDIA,False,WHITE,M,1886-07-18,1,
29,368,105889,2137-07-11 17:56:00,NaT,EMERGENCY ROOM ADMIT,Medicare,WHITE,PNEUMONIA,False,WHITE,M,1837-07-11,1,
30,368,138061,2139-12-16 19:48:00,NaT,EMERGENCY ROOM ADMIT,Medicare,WHITE,PNEUMONIA,False,WHITE,M,1837-07-11,1,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
43334,95344,195056,2141-01-09 20:48:00,NaT,CLINIC REFERRAL/PREMATURE,Medicare,WHITE,DYSPNEA,False,WHITE,M,1841-01-09,1,
43362,98647,167391,2156-02-28 17:21:00,NaT,EMERGENCY ROOM ADMIT,Medicare,WHITE,LEFT FEMUR FRACTURE,False,WHITE,M,1856-02-28,1,
43380,98713,105148,2140-04-04 13:11:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,ST-SEGMENT ELEVATION MYOCARDIAL INFARCTION\CAR...,False,WHITE,F,1840-04-04,0,
43381,98714,171885,2104-11-12 12:31:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,CORONARY ARTERY DISEASE\RIGHT CORONARY ARTERY ...,False,WHITE,M,1804-11-12,1,


In [77]:
admissions_patients_df.dropna(subset=['AGE_YRS'], inplace=True, ignore_index=True)
admissions_patients_df['AGE_YRS'].describe()

count    40900.000000
mean        61.983834
std         17.691616
min          0.000000
25%         51.035616
50%         64.230137
75%         76.394521
max         89.060274
Name: AGE_YRS, dtype: float64

In [78]:
admissions_patients_df

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMITTIME,DEATHTIME,ADMISSION_LOCATION,INSURANCE,ETHNICITY,DIAGNOSIS,DEATH,ETHNICITY_BUCKET,GENDER,DOB,GENDER_BINARY,AGE_YRS
0,22,165315,2196-04-09 12:26:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,BENZODIAZEPINE OVERDOSE,False,WHITE,F,2131-05-07,0,64.969863
1,23,124321,2157-10-18 19:34:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,WHITE,BRAIN MASS,False,WHITE,M,2082-07-17,1,75.304110
2,24,161859,2139-06-06 16:14:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,WHITE,INTERIOR MYOCARDIAL INFARCTION,False,WHITE,M,2100-05-31,1,39.041096
3,25,129635,2160-11-02 02:06:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,ACUTE CORONARY SYNDROME,False,WHITE,M,2101-11-21,1,58.989041
4,26,197661,2126-05-06 15:16:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicare,UNKNOWN/NOT SPECIFIED,V-TACH,False,OTHER,M,2054-05-04,1,72.052055
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40895,98797,105447,2132-12-24 20:06:00,2132-12-25 12:00:00,EMERGENCY ROOM ADMIT,Medicare,WHITE,ALTERED MENTAL STATUS,True,WHITE,M,2044-12-27,1,88.049315
40896,98800,191113,2131-03-30 21:13:00,NaT,CLINIC REFERRAL/PREMATURE,Private,WHITE,TRAUMA,False,WHITE,F,2111-11-05,0,19.410959
40897,98802,101071,2151-03-05 20:00:00,2151-03-06 09:10:00,CLINIC REFERRAL/PREMATURE,Medicare,WHITE,SAH,True,WHITE,F,2067-09-21,0,83.506849
40898,98813,170407,2128-11-11 02:29:00,NaT,EMERGENCY ROOM ADMIT,Private,WHITE,S/P FALL,False,WHITE,F,2068-02-04,0,60.808219


In [79]:
admissions_patients_df[admissions_patients_df['AGE_YRS'] == 0]

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMITTIME,DEATHTIME,ADMISSION_LOCATION,INSURANCE,ETHNICITY,DIAGNOSIS,DEATH,ETHNICITY_BUCKET,GENDER,DOB,GENDER_BINARY,AGE_YRS
2774,3411,141023,2156-01-02 23:12:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicaid,UNKNOWN/NOT SPECIFIED,PREMATURITY,False,OTHER,F,2156-01-02,0,0.0
3808,7690,189747,2110-03-16 06:42:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,UNKNOWN/NOT SPECIFIED,PREMATURITY,False,OTHER,F,2110-03-16,0,0.0
5377,5630,146285,2116-07-17 18:01:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,WHITE,PREMATURITY,False,WHITE,F,2116-07-17,0,0.0
5960,7832,122473,2183-06-29 05:39:00,2183-06-29 16:58:00,TRANSFER FROM HOSP/EXTRAM,Self Pay,UNKNOWN/NOT SPECIFIED,PREMATURITY,True,OTHER,F,2183-06-29,0,0.0
7510,10977,164955,2139-11-06 12:00:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,PATIENT DECLINED TO ANSWER,PREMATURITY,False,OTHER,M,2139-11-06,1,0.0
9272,14258,109672,2136-02-03 20:37:00,NaT,TRANSFER FROM HOSP/EXTRAM,Self Pay,UNKNOWN/NOT SPECIFIED,PREMATURITY,False,OTHER,M,2136-02-03,1,0.0
10344,11599,161114,2177-09-05 21:38:00,NaT,TRANSFER FROM HOSP/EXTRAM,Private,UNKNOWN/NOT SPECIFIED,PREMATURITY,False,OTHER,F,2177-09-05,0,0.0
10540,14601,110097,2198-05-12 11:48:00,NaT,PHYS REFERRAL/NORMAL DELI,Self Pay,UNKNOWN/NOT SPECIFIED,RESPIRATORY DISTRESS,False,OTHER,M,2198-05-12,1,0.0
12060,15167,174629,2191-03-18 21:28:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicaid,UNKNOWN/NOT SPECIFIED,PREMATURITY,False,OTHER,F,2191-03-18,0,0.0
15811,20823,193552,2178-01-19 15:07:00,NaT,TRANSFER FROM HOSP/EXTRAM,Medicaid,UNKNOWN/NOT SPECIFIED,PERINATAL DEPRESSION,False,OTHER,F,2178-01-19,0,0.0


In [82]:
admissions_patients_deathrate = admissions_patients_df['DEATH'].mean()
admissions_patients_deathrate

0.1234718826405868

### Feature extraction

In [83]:
admissions_patients_feature_cols = [
    'SUBJECT_ID',
    'HADM_ID',
    # 'ADMISSION_TYPE', # removed non emergency/urgent admissions
    'ADMISSION_LOCATION',
    'INSURANCE',
    'ETHNICITY',
    # 'DIAGNOSIS' # skip, we'll get diagnostic information from NOTEEVENTS
    # Patients features
    'GENDER_BINARY',
    'AGE_YRS'
]
for column in admissions_patients_feature_cols:
    print(f'{column} unique values: {admissions_patients_df[column].nunique()}')

SUBJECT_ID unique values: 31663
HADM_ID unique values: 40900
ADMISSION_LOCATION unique values: 9
INSURANCE unique values: 5
ETHNICITY unique values: 41
GENDER_BINARY unique values: 2
AGE_YRS unique values: 18887


In [84]:
admissions_df_filtered['ETHNICITY'].value_counts(normalize=True)

ETHNICITY
WHITE                                                       0.698044
BLACK/AFRICAN AMERICAN                                      0.097450
UNKNOWN/NOT SPECIFIED                                       0.081876
HISPANIC OR LATINO                                          0.026977
OTHER                                                       0.022231
UNABLE TO OBTAIN                                            0.016173
ASIAN                                                       0.016034
PATIENT DECLINED TO ANSWER                                  0.006819
HISPANIC/LATINO - PUERTO RICAN                              0.004700
ASIAN - CHINESE                                             0.004354
BLACK/CAPE VERDEAN                                          0.003410
WHITE - RUSSIAN                                             0.003340
BLACK/HAITIAN                                               0.002143
MULTI RACE ETHNICITY                                        0.002050
HISPANIC/LATINO - DOMINI

In [85]:
ethnicity_buckets = [
    'WHITE',
    'BLACK',
    'HISPANIC',
    'ASIAN'
]

def map_ethnicity(ethnicity_str):
    ethnicity_upper = ethnicity_str.upper()
    for bucket in ethnicity_buckets:
        if bucket in ethnicity_upper:
            return bucket
    return 'OTHER'


admissions_patients_df['ETHNICITY_BUCKET'] = admissions_patients_df['ETHNICITY'].apply(map_ethnicity)
admissions_patients_df['ETHNICITY_BUCKET'].value_counts(normalize=True)


ETHNICITY_BUCKET
WHITE       0.695941
OTHER       0.135306
BLACK       0.106895
HISPANIC    0.037335
ASIAN       0.024523
Name: proportion, dtype: float64

In [86]:
admissions_patients_feature_cols.append('ETHNICITY_BUCKET')
admissions_patients_feature_cols.remove('ETHNICITY')

In [87]:
admissions_patients_features_df = admissions_patients_df[admissions_patients_feature_cols].copy()

# Convert string columns to 'category' dtype
for col in admissions_patients_feature_cols:
    if col != 'AGE_YRS':  # Only convert non-numeric columns to category
        admissions_patients_features_df[col] = admissions_patients_features_df[col].astype('category')

admissions_patients_features_df['DEATH'] = admissions_patients_df['DEATH']
admissions_patients_features_df

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMISSION_LOCATION,INSURANCE,GENDER_BINARY,AGE_YRS,ETHNICITY_BUCKET,DEATH
0,22,165315,EMERGENCY ROOM ADMIT,Private,0,64.969863,WHITE,False
1,23,124321,TRANSFER FROM HOSP/EXTRAM,Medicare,1,75.304110,WHITE,False
2,24,161859,TRANSFER FROM HOSP/EXTRAM,Private,1,39.041096,WHITE,False
3,25,129635,EMERGENCY ROOM ADMIT,Private,1,58.989041,WHITE,False
4,26,197661,TRANSFER FROM HOSP/EXTRAM,Medicare,1,72.052055,OTHER,False
...,...,...,...,...,...,...,...,...
40895,98797,105447,EMERGENCY ROOM ADMIT,Medicare,1,88.049315,WHITE,True
40896,98800,191113,CLINIC REFERRAL/PREMATURE,Private,0,19.410959,WHITE,False
40897,98802,101071,CLINIC REFERRAL/PREMATURE,Medicare,0,83.506849,WHITE,True
40898,98813,170407,EMERGENCY ROOM ADMIT,Private,0,60.808219,WHITE,False


In [88]:
admissions_patients_features_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 40900 entries, 0 to 40899
Data columns (total 8 columns):
 #   Column              Non-Null Count  Dtype   
---  ------              --------------  -----   
 0   SUBJECT_ID          40900 non-null  category
 1   HADM_ID             40900 non-null  category
 2   ADMISSION_LOCATION  40900 non-null  category
 3   INSURANCE           40900 non-null  category
 4   GENDER_BINARY       40900 non-null  category
 5   AGE_YRS             40900 non-null  float64 
 6   ETHNICITY_BUCKET    40900 non-null  category
 7   DEATH               40900 non-null  bool    
dtypes: bool(1), category(6), float64(1)
memory usage: 3.3 MB


In [89]:
assert_unique_HADM_ID(admissions_patients_features_df)

### Eval

In [90]:
admissions_patients_feature_importances = train_and_eval_prelim_xgb(admissions_patients_features_df, admissions_patients_deathrate, label_col='DEATH', id_cols=['SUBJECT_ID', 'HADM_ID'])

Accuracy: 0.556601466992665
ROC AUC: 0.5944
Average Precision: 0.1680
              precision    recall  f1-score   support

       False       0.90      0.55      0.69      7176
        True       0.15      0.57      0.24      1004

    accuracy                           0.56      8180
   macro avg       0.53      0.56      0.46      8180
weighted avg       0.81      0.56      0.63      8180

Feature importances:


Unnamed: 0,feature,importance
3,AGE_YRS,0.249284
4,ETHNICITY_BUCKET,0.211507
1,INSURANCE,0.188572
0,ADMISSION_LOCATION,0.183278
2,GENDER_BINARY,0.16736


In [91]:
admissions_patients_keep_features = top_p_features(admissions_patients_feature_importances)['feature'].tolist()
admissions_patients_keep_features

['AGE_YRS', 'ETHNICITY_BUCKET', 'INSURANCE', 'ADMISSION_LOCATION']

## Datetime Events (2 features)

### Feature extraction

In [92]:
datetimeevents_feature_cols = [
    'HADM_ID',
    'ITEMID',
    'CGID'
]
for column in datetimeevents_feature_cols:
    print(f'{column} unique values: {datetimeevents_df_filtered_24h[column].nunique()}')

HADM_ID unique values: 22398
ITEMID unique values: 148
CGID unique values: 922


In [93]:
datetimeevents_features_df_many = datetimeevents_df_filtered_24h[datetimeevents_feature_cols].copy()

# Convert string columns to 'category' dtype
for col in datetimeevents_feature_cols:
    datetimeevents_features_df_many[col] = datetimeevents_features_df_many[col].astype('category')

datetimeevents_features_df_many['DEATH'] = datetimeevents_df_filtered_24h['DEATH']
datetimeevents_features_df_many

Unnamed: 0,HADM_ID,ITEMID,CGID,DEATH
0,112755.0,3411,15484,False
1,112755.0,3411,15484,False
2,112755.0,3411,15713,False
3,112755.0,3411,17604,False
4,112755.0,3411,14437,False
...,...,...,...,...
398318,193894.0,224288,20622,False
398319,193894.0,224290,20622,False
398320,193894.0,224284,18576,False
398321,193894.0,224287,18576,False


### Aggregate features

In [94]:
# Group by HADM_ID and aggregate unique ITEMID and CGID values into lists
datetimeevents_features_df_lists = datetimeevents_features_df_many.groupby('HADM_ID').agg(
    ITEMID_LIST=('ITEMID', lambda x: sorted(list(set(x)))),
    CGID_LIST=('CGID', lambda x: sorted(list(set(x)))),
    DEATH=('DEATH', 'first')  # Keep the first DEATH value (all are the same per HADM_ID)
).reset_index()

# Calculate list lengths for additional features
datetimeevents_features_df_lists['ITEMID_COUNT'] = datetimeevents_features_df_lists['ITEMID_LIST'].apply(len)
datetimeevents_features_df_lists['CGID_COUNT'] = datetimeevents_features_df_lists['CGID_LIST'].apply(len)

print("\nFeature statistics:")
print(f"Average number of unique ITEMIDs per admission: {datetimeevents_features_df_lists['ITEMID_COUNT'].mean():.2f}")
print(f"Average number of unique CGIDs per admission: {datetimeevents_features_df_lists['CGID_COUNT'].mean():.2f}")

datetimeevents_features_df_lists

  datetimeevents_features_df_lists = datetimeevents_features_df_many.groupby('HADM_ID').agg(



Feature statistics:
Average number of unique ITEMIDs per admission: 5.29
Average number of unique CGIDs per admission: 3.28


Unnamed: 0,HADM_ID,ITEMID_LIST,CGID_LIST,DEATH,ITEMID_COUNT,CGID_COUNT
0,100001.0,"[225754, 226515, 226719]","[15815, 18328, 20889, 21290]",False,3,4
1,100003.0,"[224280, 225755, 225756, 226515, 226724]","[15830, 16797, 17175, 17693, 20889]",False,5,5
2,100007.0,"[5684, 6703, 6704]","[15816, 18320]",False,3,2
3,100009.0,"[226515, 226724]",[20889],False,2,1
4,100011.0,"[224284, 224287, 224288, 224290, 225754, 22575...","[18222, 18792, 20884, 20889, 20951]",False,8,5
...,...,...,...,...,...,...
22393,199958.0,"[225755, 226515, 226724]","[17250, 18144, 20889]",False,3,3
22394,199962.0,"[224284, 224287, 224288, 224290, 225756, 226515]","[19710, 20081, 20889, 20890]",False,6,4
22395,199967.0,"[225754, 226515, 226724]","[18846, 20816, 20889]",False,3,3
22396,199984.0,"[225753, 225754, 226515, 226724]","[14431, 17155, 20889]",False,4,3


In [95]:
# Convert lists to strings for encoding
datetimeevents_features_df_lists['ITEMID_LIST_STR'] = datetimeevents_features_df_lists['ITEMID_LIST'].apply(lambda x: ','.join(map(str, x)))
datetimeevents_features_df_lists['CGID_LIST_STR'] = datetimeevents_features_df_lists['CGID_LIST'].apply(lambda x: ','.join(map(str, x)))

# Apply label encoding -- label encoding here not ideal, but best option for high-cardinality categorical data w/o training embeddings
itemid_encoder = LabelEncoder()
cgid_encoder = LabelEncoder()

datetimeevents_features_df_lists['ITEMID_ENCODED'] = itemid_encoder.fit_transform(datetimeevents_features_df_lists['ITEMID_LIST_STR'])
datetimeevents_features_df_lists['CGID_ENCODED'] = cgid_encoder.fit_transform(datetimeevents_features_df_lists['CGID_LIST_STR'])

datetimeevents_features_df_lists['ITEMID_ENCODED'] = datetimeevents_features_df_lists['ITEMID_ENCODED'].astype('category')
datetimeevents_features_df_lists['CGID_ENCODED'] = datetimeevents_features_df_lists['CGID_ENCODED'].astype('category')

# Drop intermediate string columns and original list columns
features_to_keep = ['HADM_ID', 'ITEMID_ENCODED', 'CGID_ENCODED', 'ITEMID_COUNT', 'CGID_COUNT', 'DEATH']
datetimeevents_features_df = datetimeevents_features_df_lists[features_to_keep]

# Print the first few rows
datetimeevents_features_df

Unnamed: 0,HADM_ID,ITEMID_ENCODED,CGID_ENCODED,ITEMID_COUNT,CGID_COUNT,DEATH
0,100001.0,3509,10014,3,4,False
1,100003.0,1766,10149,5,5,False
2,100007.0,3940,10104,3,2,False
3,100009.0,3645,17549,2,1,False
4,100011.0,2251,15531,8,5,False
...,...,...,...,...,...,...
22393,199958.0,3596,14067,3,3,False
22394,199962.0,2366,17054,6,4,False
22395,199967.0,3512,16306,3,3,False
22396,199984.0,3340,2418,4,3,False


In [96]:
datetimeevents_features_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 22398 entries, 0 to 22397
Data columns (total 6 columns):
 #   Column          Non-Null Count  Dtype   
---  ------          --------------  -----   
 0   HADM_ID         22398 non-null  category
 1   ITEMID_ENCODED  22398 non-null  category
 2   CGID_ENCODED    22398 non-null  category
 3   ITEMID_COUNT    22398 non-null  int64   
 4   CGID_COUNT      22398 non-null  int64   
 5   DEATH           22398 non-null  bool    
dtypes: bool(1), category(3), int64(2)
memory usage: 2.0 MB


In [97]:
assert_unique_HADM_ID(datetimeevents_features_df)

### Eval

In [98]:
datetimeevents_feature_importances = train_and_eval_prelim_xgb(datetimeevents_features_df, datetimeevents_deathrate, label_col='DEATH')

Accuracy: 0.8580357142857142
ROC AUC: 0.6224
Average Precision: 0.1691
              precision    recall  f1-score   support

       False       0.87      0.98      0.92      3900
        True       0.20      0.03      0.05       580

    accuracy                           0.86      4480
   macro avg       0.53      0.51      0.49      4480
weighted avg       0.78      0.86      0.81      4480

Feature importances:


Unnamed: 0,feature,importance
1,CGID_ENCODED,0.726136
0,ITEMID_ENCODED,0.187224
2,ITEMID_COUNT,0.054533
3,CGID_COUNT,0.032107


In [99]:
datetimeevents_keep_features = top_p_features(datetimeevents_feature_importances)['feature'].tolist()
datetimeevents_keep_features

['CGID_ENCODED', 'ITEMID_ENCODED']

## Note Events (10 PCA features)

### Extract Embeddings

In [100]:
clinicalBERT = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
clinicalBERT_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

def get_bert_embedding(text, model, tokenizer):
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        output = model(**inputs)
    return output.last_hidden_state.mean(dim=1).squeeze().numpy()

In [101]:
# batch embeddings and move to GPU
def get_bert_embeddings_batched(texts:list[str], model, tokenizer, batch_size=32, tokenizer_max_length=512) -> np.ndarray:
    if torch.cuda.is_available():
        print('Using CUDA')
        model = model.cuda()

    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch_idx = i // batch_size
        if batch_idx % 20 == 0: # Print every 50 batches
            print(f'Processing batch {batch_idx} of {len(texts) // batch_size}')
            
        batch_texts = texts[i:i+batch_size]
        inputs = tokenizer(batch_texts, padding="max_length", truncation=True, 
                          max_length=tokenizer_max_length, return_tensors="pt")  # Reduced max_length
        
        # Move to GPU if available
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
            
        with torch.no_grad():
            outputs = model(**inputs)
        
        batch_embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        embeddings.append(batch_embeddings)

    return np.vstack(embeddings)

In [102]:
example_text = "Patient presented with hypotension and low oxygen saturation."
embedding = get_bert_embedding(example_text, clinicalBERT, clinicalBERT_tokenizer)
embedding.shape

(768,)

In [103]:
noteevents_test_tokens = noteevents_df_filtered_24h['TEXT'].sample(1000, random_state=42)\
    .apply(lambda x: clinicalBERT_tokenizer.encode(x, truncation=False))
noteevents_test_tokens

185201    [101, 11138, 14940, 131, 24928, 11955, 131, 18...
13691     [101, 1512, 26063, 192, 7903, 1114, 188, 1513,...
223011    [101, 19192, 1920, 131, 185, 1204, 4120, 1106,...
14461     [101, 182, 1197, 119, 164, 115, 115, 1227, 131...
188972    [101, 14255, 112, 189, 1104, 1920, 131, 15508,...
                                ...                        
11960     [101, 1641, 131, 21685, 24841, 10540, 131, 160...
187069    [101, 1231, 20080, 1920, 117, 185, 1204, 119, ...
59873     [101, 2705, 12522, 131, 1572, 2396, 1958, 131,...
131598    [101, 164, 115, 115, 18615, 1580, 118, 130, 11...
139615    [101, 164, 115, 115, 13075, 1477, 118, 123, 11...
Name: TEXT, Length: 1000, dtype: object

In [104]:
print(f'Minimum token length: {noteevents_test_tokens.apply(len).min()}')
print(f'Maximum token length: {noteevents_test_tokens.apply(len).max()}')
print(f'Mean token length: {noteevents_test_tokens.apply(len).mean()}')
print(f'Median token length: {noteevents_test_tokens.apply(len).median()}')

Minimum token length: 13
Maximum token length: 5069
Mean token length: 673.372
Median token length: 459.0


In [105]:
# 229486 rows took 24 min with batch size of 128
# noteevents_embeddings = get_bert_embeddings_batched( # uncomment to run
#     list(noteevents_df_filtered_24h['TEXT']),
#     clinicalBERT,
#     clinicalBERT_tokenizer,
#     batch_size=256 # Takes about 10GB of GPU memory
# )
# noteevents_embeddings

In [106]:
# noteevents_embeddings.shape

In [107]:
# noteevents_embeddings[0]

In [108]:
# Save the embeddings to disk for future use

# Define a path to save the embeddings
noteevents_embeddings_path = os.path.join('feature-embeddings', 'noteevents_embeddings.npy')

# Save the numpy array to disk
# np.save(noteevents_embeddings_path, noteevents_embeddings)

# print(f"Saved embeddings to {noteevents_embeddings_path}")

### Aggregate Embeddings

In [109]:
noteevents_embeddings = np.load(noteevents_embeddings_path)

In [110]:
noteevents_embeddings.shape

(229486, 768)

In [111]:
noteevents_embeddings_df = pd.DataFrame({
    'HADM_ID': noteevents_df_filtered_24h['HADM_ID'],
    'NOTE_EMBEDDING': list(noteevents_embeddings),
    'DEATH': noteevents_df_filtered_24h['DEATH']
})
noteevents_embeddings_df

Unnamed: 0,HADM_ID,NOTE_EMBEDDING,DEATH
0,136227.0,"[0.12009446, -0.18610138, -0.16754776, 0.16050...",False
1,157324.0,"[-0.10679449, -0.12896663, -0.12561706, 0.2088...",False
2,157324.0,"[-0.10679449, -0.12896663, -0.12561706, 0.2088...",False
3,157324.0,"[-0.10679449, -0.12896663, -0.12561706, 0.2088...",False
4,165003.0,"[0.097343326, 0.021206781, 0.0088621415, -0.01...",False
...,...,...,...
229481,122817.0,"[0.08896427, 0.026162047, -0.28558433, -0.0401...",False
229482,122817.0,"[-0.019218164, 0.046230193, -0.1718328, 0.1286...",False
229483,122817.0,"[-0.043659694, -0.13485357, -0.12930332, 1.557...",False
229484,122817.0,"[-0.19425957, -0.082947776, -0.22138855, 0.201...",False


In [112]:
# Group by HADM_ID and aggregate note embeddings by taking the average
noteevents_embeddings_df_avg = noteevents_embeddings_df.groupby('HADM_ID').agg(
    NOTE_EMBEDDING_AVG=('NOTE_EMBEDDING', lambda x: np.mean(np.vstack(x), axis=0)),
    NOTE_COUNT=('NOTE_EMBEDDING', 'count'),  # Count number of notes per HADM_ID
    DEATH=('DEATH', 'first')  # Keep the first DEATH value (all are the same per HADM_ID)
).reset_index()

print("\nFeature statistics:")
print(f"Average number of unique notes per admission: {noteevents_embeddings_df_avg['NOTE_COUNT'].mean():.2f}")

noteevents_embeddings_df_avg


Feature statistics:
Average number of unique notes per admission: 5.98


Unnamed: 0,HADM_ID,NOTE_EMBEDDING_AVG,NOTE_COUNT,DEATH
0,100001.0,"[0.03339041, -0.06324264, -0.07525222, 0.06214...",1,False
1,100003.0,"[0.050145224, -0.071770415, -0.22003864, 0.094...",13,False
2,100006.0,"[-0.0016322257, -0.14573509, -0.15141164, 0.08...",4,False
3,100007.0,"[0.016669223, -0.018630615, -0.14267147, 0.035...",2,False
4,100009.0,"[0.16083565, -0.16822207, -0.24011362, -0.0431...",2,False
...,...,...,...,...
38360,199993.0,"[0.034504432, -0.093784995, -0.13057417, -0.02...",3,False
38361,199994.0,"[0.03164832, -0.04555378, -0.14422141, 0.09916...",7,False
38362,199995.0,"[0.10537794, -0.16849351, -0.15182771, 0.11871...",4,False
38363,199998.0,"[0.08895928, -0.12588014, -0.12862699, 0.03421...",2,False


### PCA Embeddings

In [113]:
# Stack the embeddings into a numpy array
noteevents_embeddings_array = np.vstack(noteevents_embeddings_df_avg['NOTE_EMBEDDING_AVG'].values)

# Apply PCA to reduce dimensions (10 components)
pca = PCA(n_components=10) # 10 components retain 72% variance
noteevents_reduced_embeddings = pca.fit_transform(noteevents_embeddings_array)

# Create DataFrame with reduced embeddings
noteevents_features_df = pd.DataFrame(
    noteevents_reduced_embeddings, 
    columns=[f'note_pca_{i}' for i in range(noteevents_reduced_embeddings.shape[1])]
)

# Add the DEATH & HADM_ID columns
noteevents_features_df['HADM_ID'] = noteevents_embeddings_df_avg['HADM_ID'].astype('category')
noteevents_features_df['NOTE_COUNT'] = noteevents_embeddings_df_avg['NOTE_COUNT']
noteevents_features_df['DEATH'] = noteevents_embeddings_df_avg['DEATH']

print(f"Variance explained by PCA components: {pca.explained_variance_ratio_.sum():.4f}")
noteevents_features_df


Variance explained by PCA components: 0.8556


Unnamed: 0,note_pca_0,note_pca_1,note_pca_2,note_pca_3,note_pca_4,note_pca_5,note_pca_6,note_pca_7,note_pca_8,note_pca_9,HADM_ID,NOTE_COUNT,DEATH
0,1.708682,-1.868920,-0.248604,-0.667687,-0.611248,-0.426874,0.394282,-0.239990,0.350912,0.432565,100001.0,1,False
1,-1.957087,1.143611,-0.526526,-0.312375,0.018176,-0.420411,-0.005299,-0.244733,-0.007793,-0.486583,100003.0,13,False
2,-2.146841,-0.247298,0.186115,-0.081406,0.558931,-0.003863,0.476413,-0.371807,0.001280,0.174869,100006.0,4,False
3,-4.285731,-1.483348,0.608325,-0.068249,0.411864,-0.114133,-0.188602,0.184932,0.312753,-0.505277,100007.0,2,False
4,3.005071,0.930612,0.307408,0.131238,-0.231828,0.493441,0.064175,0.095169,-0.205803,-0.421829,100009.0,2,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
38360,-3.545998,0.075124,0.736832,-0.691070,0.300872,0.442168,0.115442,0.060266,-0.099181,-0.164120,199993.0,3,False
38361,-1.172173,-0.525782,-0.142164,0.399935,0.376993,0.033307,0.279896,-0.018568,0.143480,-0.092740,199994.0,7,False
38362,1.982328,-0.179626,-0.432048,-0.100182,0.222691,0.110157,-0.344956,0.113496,0.121876,0.016949,199995.0,4,False
38363,1.875252,-0.587410,-0.712523,-0.417801,-0.010714,0.443398,0.037928,0.533628,-0.221201,-0.071894,199998.0,2,False


### Eval

In [114]:
assert_unique_HADM_ID(noteevents_features_df)

In [115]:
noteevents_deathrate = noteevents_features_df['DEATH'].mean()
noteevents_deathrate

0.12714713931969243

In [116]:
noteevents_feature_importances = train_and_eval_prelim_xgb(noteevents_features_df, noteevents_deathrate, label_col='DEATH')

Accuracy: 0.7283982796820018
ROC AUC: 0.7170
Average Precision: 0.2807
              precision    recall  f1-score   support

       False       0.92      0.75      0.83      6717
        True       0.24      0.54      0.33       956

    accuracy                           0.73      7673
   macro avg       0.58      0.65      0.58      7673
weighted avg       0.84      0.73      0.77      7673

Feature importances:


Unnamed: 0,feature,importance
3,note_pca_3,0.212199
6,note_pca_6,0.096281
7,note_pca_7,0.093039
0,note_pca_0,0.081601
8,note_pca_8,0.079843
10,NOTE_COUNT,0.077883
5,note_pca_5,0.074327
4,note_pca_4,0.072861
9,note_pca_9,0.071846
1,note_pca_1,0.070945


## Prescriptions (1 feature)

### Feature extraction

In [117]:
prescriptions_feature_cols = [
    'HADM_ID',
    'DRUG'
]
prescriptions_features_df_many = prescriptions_df_filtered_24h[prescriptions_feature_cols].copy()

# Convert string columns to 'category' dtype
# for col in prescriptions_feature_cols:
#     prescriptions_features_df_many[col] = prescriptions_features_df_many[col].astype('category')

prescriptions_features_df_many['DEATH'] = prescriptions_df_filtered_24h['DEATH']
prescriptions_features_df_many

Unnamed: 0,HADM_ID,DRUG,DEATH
0,143045,D5W,False
1,143045,Heparin Sodium,False
2,143045,Nitroglycerin,False
3,143045,Docusate Sodium,False
4,143045,Insulin,False
...,...,...,...
1265959,121032,PredniSONE,False
1265960,121032,Ipratropium Bromide Neb,False
1265961,121032,HYDROmorphone (Dilaudid),False
1265962,121032,Docusate Sodium,False


In [118]:
prescriptions_features_df_many.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1265964 entries, 0 to 1265963
Data columns (total 3 columns):
 #   Column   Non-Null Count    Dtype 
---  ------   --------------    ----- 
 0   HADM_ID  1265964 non-null  int64 
 1   DRUG     1265964 non-null  object
 2   DEATH    1265964 non-null  bool  
dtypes: bool(1), int64(1), object(1)
memory usage: 20.5+ MB


### Get ClinicalBERT Embeddings

In [119]:
prescriptions_test_tokens = prescriptions_features_df_many['DRUG'].sample(1000, random_state=42)\
    .apply(lambda x: clinicalBERT_tokenizer.encode(x, truncation=False))
prescriptions_test_tokens

814631       [101, 11437, 10542, 9717, 2042, 178, 1964, 102]
620040                                 [101, 183, 1116, 102]
40060        [101, 182, 1766, 21587, 28117, 9654, 2193, 102]
624717             [101, 185, 10436, 12415, 7880, 8643, 102]
897897     [101, 1899, 4184, 13166, 4063, 28117, 19557, 1...
                                 ...                        
78876              [101, 172, 27969, 2312, 1183, 16430, 102]
800337          [101, 121, 119, 130, 110, 15059, 21256, 102]
603181                    [101, 1143, 3365, 2386, 2042, 102]
1061589            [101, 16516, 3202, 2528, 3810, 1233, 102]
1131730           [101, 20866, 10436, 7889, 7412, 9685, 102]
Name: DRUG, Length: 1000, dtype: object

In [120]:
print(f'Minimum token length: {prescriptions_test_tokens.apply(len).min()}')
print(f'Maximum token length: {prescriptions_test_tokens.apply(len).max()}')
print(f'Mean token length: {prescriptions_test_tokens.apply(len).mean()}')
print(f'Median token length: {prescriptions_test_tokens.apply(len).median()}')

Minimum token length: 3
Maximum token length: 22
Mean token length: 7.182
Median token length: 7.0


In [121]:
# prescriptions_embeddings = get_bert_embeddings_batched( # uncomment to run
#     list(prescriptions_features_df_many['DRUG']),
#     clinicalBERT,
#     clinicalBERT_tokenizer,
#     tokenizer_max_length=32, # Reduce max_length to 32 to match max token length of 22
#     batch_size=256 # Takes about 10GB of GPU memory
# )
# prescriptions_embeddings

In [122]:
# prescriptions_embeddings.shape

In [123]:
prescriptions_embeddings_path = os.path.join(os.getcwd(), 'feature-embeddings', 'prescriptions_embeddings.npy')

# # Save the numpy array to disk
# np.save(prescriptions_embeddings_path, prescriptions_embeddings)

# print(f"Saved embeddings to {prescriptions_embeddings_path}")

In [124]:
prescriptions_embeddings = np.load(prescriptions_embeddings_path)
prescriptions_embeddings.shape

(1265964, 768)

In [125]:
prescriptions_embeddings_df = pd.DataFrame({
    'HADM_ID': prescriptions_features_df_many['HADM_ID'],
    'DRUG_EMBEDDING': list(prescriptions_embeddings),
    'DEATH': prescriptions_features_df_many['DEATH']
})
prescriptions_embeddings_df

Unnamed: 0,HADM_ID,DRUG_EMBEDDING,DEATH
0,143045,"[-0.07757282, -0.052578203, -0.1802096, 0.1119...",False
1,143045,"[-0.07849769, 0.012043957, 0.08178935, -0.0562...",False
2,143045,"[0.06420252, 0.050761342, -0.18494278, -0.1604...",False
3,143045,"[-0.059045363, 0.09230081, -0.036827233, 0.000...",False
4,143045,"[-0.15852846, -0.057911728, -0.07898846, 0.022...",False
...,...,...,...
1265959,121032,"[-0.054809596, -0.04827134, 0.1427296, 0.27137...",False
1265960,121032,"[0.102403425, 0.083332404, 0.15967578, 0.09949...",False
1265961,121032,"[-0.23464333, 0.09761905, 0.08254529, 0.363698...",False
1265962,121032,"[-0.05904512, 0.09230088, -0.03682725, 0.00074...",False


### Aggregate embeddings

In [126]:
# Group by HADM_ID and aggregate drug embeddings by taking the average
prescriptions_embeddings_df_avg = prescriptions_embeddings_df.groupby('HADM_ID').agg(
    DRUG_EMBEDDING_AVG=('DRUG_EMBEDDING', lambda x: np.mean(np.vstack(x), axis=0)),
    DRUG_COUNT=('DRUG_EMBEDDING', 'count'),  # Count number of drugs per HADM_ID
    DEATH=('DEATH', 'first')  # Keep the first DEATH value (all are the same per HADM_ID)
).reset_index()

print("\nFeature statistics:")
print(f"Average number of unique DRUGs per admission: {prescriptions_embeddings_df_avg['DRUG_COUNT'].mean():.2f}")

prescriptions_embeddings_df_avg


Feature statistics:
Average number of unique DRUGs per admission: 31.95


Unnamed: 0,HADM_ID,DRUG_EMBEDDING_AVG,DRUG_COUNT,DEATH
0,100001,"[-0.16638184, -0.09346691, -0.0629181, 0.14109...",54,False
1,100003,"[-0.165796, -0.10662353, -0.030378202, 0.11985...",20,False
2,100006,"[-0.10700698, 0.0104033165, 0.036883276, 0.049...",13,False
3,100007,"[-0.13051079, -0.063218296, -0.03339023, 0.125...",29,False
4,100009,"[-0.11628712, -0.07572043, -0.014120933, 0.072...",64,False
...,...,...,...,...
39614,199992,"[-0.11459154, -0.084150046, -0.05613905, 0.067...",34,False
39615,199993,"[-0.08653705, 0.0022539236, 0.031304993, 0.080...",15,False
39616,199995,"[-0.13258265, -0.0594926, 0.04097059, 0.101542...",6,False
39617,199998,"[-0.011335154, -0.0063688955, 0.054856975, 0.0...",15,False


In [127]:
prescriptions_embeddings_df_avg['DRUG_EMBEDDING_AVG'][0].shape

(768,)

### PCA Embeddings

In [128]:
# Stack the embeddings into a numpy array
prescriptions_embeddings_array = np.vstack(prescriptions_embeddings_df_avg['DRUG_EMBEDDING_AVG'].values)

# Apply PCA to reduce dimensions (10 components)
pca = PCA(n_components=10) # 10 components retain 72% variance
prescriptions_reduced_embeddings = pca.fit_transform(prescriptions_embeddings_array)

# Create DataFrame with reduced embeddings
prescriptions_features_df = pd.DataFrame(
    prescriptions_reduced_embeddings, 
    columns=[f'rx_pca_{i}' for i in range(prescriptions_reduced_embeddings.shape[1])]
)

# Add the DEATH & HADM_ID columns
prescriptions_features_df['HADM_ID'] = prescriptions_embeddings_df_avg['HADM_ID'].astype('category')
prescriptions_features_df['DRUG_COUNT'] = prescriptions_embeddings_df_avg['DRUG_COUNT']
prescriptions_features_df['DEATH'] = prescriptions_embeddings_df_avg['DEATH']

print(f"Variance explained by PCA components: {pca.explained_variance_ratio_.sum():.4f}")
prescriptions_features_df

Variance explained by PCA components: 0.7163


Unnamed: 0,rx_pca_0,rx_pca_1,rx_pca_2,rx_pca_3,rx_pca_4,rx_pca_5,rx_pca_6,rx_pca_7,rx_pca_8,rx_pca_9,HADM_ID,DRUG_COUNT,DEATH
0,-0.395903,-0.066814,-0.041165,0.233145,-0.130740,-0.038979,0.279914,-0.107969,0.034869,0.031378,100001,54,False
1,-1.105019,-0.037961,0.395463,-0.227490,-0.103449,-0.009000,0.068222,0.038178,-0.067452,0.176697,100003,20,False
2,0.080458,0.424011,-0.446228,0.084480,0.111681,0.182918,-0.348981,-0.149870,-0.077495,-0.280991,100006,13,False
3,-0.124133,-0.073028,0.376297,0.528203,0.108141,-0.070648,0.188848,-0.299253,0.181893,0.022235,100007,29,False
4,-0.402664,-0.063882,-0.189173,-0.325317,-0.014030,0.068377,0.228680,0.074844,0.013839,0.190815,100009,64,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
39614,-0.619789,-0.354167,-0.120729,-0.090109,0.042412,-0.067700,-0.266099,0.002710,0.211213,0.029865,199992,34,False
39615,0.857906,0.461963,-0.158439,0.242260,-0.324705,0.187143,0.038167,0.228518,0.157575,0.242861,199993,15,False
39616,-0.490933,0.385175,0.420620,0.108515,-0.156675,0.305624,-0.099204,0.106822,0.219252,0.063419,199995,6,False
39617,0.535985,0.619083,0.199973,0.089354,-0.145917,-0.195839,0.017735,0.113981,-0.207446,-0.203276,199998,15,False


In [129]:
prescriptions_features_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 39619 entries, 0 to 39618
Data columns (total 13 columns):
 #   Column      Non-Null Count  Dtype   
---  ------      --------------  -----   
 0   rx_pca_0    39619 non-null  float32 
 1   rx_pca_1    39619 non-null  float32 
 2   rx_pca_2    39619 non-null  float32 
 3   rx_pca_3    39619 non-null  float32 
 4   rx_pca_4    39619 non-null  float32 
 5   rx_pca_5    39619 non-null  float32 
 6   rx_pca_6    39619 non-null  float32 
 7   rx_pca_7    39619 non-null  float32 
 8   rx_pca_8    39619 non-null  float32 
 9   rx_pca_9    39619 non-null  float32 
 10  HADM_ID     39619 non-null  category
 11  DRUG_COUNT  39619 non-null  int64   
 12  DEATH       39619 non-null  bool    
dtypes: bool(1), category(1), float32(10), int64(1)
memory usage: 3.3 MB


In [130]:
assert_unique_HADM_ID(prescriptions_features_df)

### Eval

In [131]:
prescriptions_deathrate = prescriptions_features_df['DEATH'].mean()
prescriptions_deathrate

0.12670688306115752

In [132]:
prescriptions_feature_importances = train_and_eval_prelim_xgb(prescriptions_features_df, prescriptions_deathrate, label_col='DEATH')

Accuracy: 0.7455830388692579
ROC AUC: 0.6687
Average Precision: 0.2845
              precision    recall  f1-score   support

       False       0.90      0.79      0.84      6870
        True       0.24      0.43      0.31      1054

    accuracy                           0.75      7924
   macro avg       0.57      0.61      0.58      7924
weighted avg       0.81      0.75      0.77      7924

Feature importances:


Unnamed: 0,feature,importance
2,rx_pca_2,0.155943
1,rx_pca_1,0.103532
4,rx_pca_4,0.093875
5,rx_pca_5,0.088996
3,rx_pca_3,0.085067
9,rx_pca_9,0.08215
10,DRUG_COUNT,0.080662
8,rx_pca_8,0.080209
0,rx_pca_0,0.078174
7,rx_pca_7,0.077176


## Transfers (3 features)

### Feature extraction

In [133]:
# Count the number of transfers for each HADM_ID
transfer_counts = transfers_df_filtered_24h['HADM_ID'].value_counts().to_dict()

# Add the NUM_TRANSFERS column
transfers_df_filtered_24h['NUM_TRANSFERS'] = transfers_df_filtered_24h['HADM_ID'].map(transfer_counts)

# Display the first few rows to verify
transfers_df_filtered_24h

Unnamed: 0,HADM_ID,PREV_WARDID,CURR_WARDID,INTIME,ADMITTIME,DEATH,NUM_TRANSFERS
0,155897,52.0,32.0,2144-07-01 05:19:39,2144-07-01 04:12:00,True,4
1,155897,32.0,52.0,2144-07-01 06:28:29,2144-07-01 04:12:00,True,4
2,155897,52.0,32.0,2144-07-01 08:07:16,2144-07-01 04:12:00,True,4
3,155897,32.0,23.0,2144-07-01 08:13:51,2144-07-01 04:12:00,True,4
4,174105,12.0,3.0,2194-06-14 14:51:17,2194-06-13 18:39:00,False,2
...,...,...,...,...,...,...,...
25908,126800,3.0,7.0,2140-07-14 22:59:38,2140-07-14 18:30:00,False,1
25909,195599,36.0,36.0,2108-10-05 20:45:56,2108-10-05 20:35:00,False,3
25910,195599,36.0,49.0,2108-10-06 11:27:11,2108-10-05 20:35:00,False,3
25911,195599,49.0,33.0,2108-10-06 13:05:57,2108-10-05 20:35:00,False,3


In [134]:
transfers_feature_cols = [
    'HADM_ID',
    'PREV_WARDID',
    'CURR_WARDID',
    'NUM_TRANSFERS'
]
for column in transfers_feature_cols:
    print(f'{column} unique values: {transfers_df_filtered_24h[column].nunique()}')

HADM_ID unique values: 16722
PREV_WARDID unique values: 49
CURR_WARDID unique values: 51
NUM_TRANSFERS unique values: 10


In [135]:
transfers_features_df = transfers_df_filtered_24h[transfers_feature_cols].copy()

transfers_features_df[['PREV_WARDID','CURR_WARDID']] = transfers_features_df[['PREV_WARDID','CURR_WARDID']].astype('category')

transfers_features_df['DEATH'] = transfers_df_filtered_24h['DEATH']
transfers_features_df

Unnamed: 0,HADM_ID,PREV_WARDID,CURR_WARDID,NUM_TRANSFERS,DEATH
0,155897,52.0,32.0,4,True
1,155897,32.0,52.0,4,True
2,155897,52.0,32.0,4,True
3,155897,32.0,23.0,4,True
4,174105,12.0,3.0,2,False
...,...,...,...,...,...
25908,126800,3.0,7.0,1,False
25909,195599,36.0,36.0,3,False
25910,195599,36.0,49.0,3,False
25911,195599,49.0,33.0,3,False


In [136]:
transfers_features_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 25913 entries, 0 to 25912
Data columns (total 5 columns):
 #   Column         Non-Null Count  Dtype   
---  ------         --------------  -----   
 0   HADM_ID        25913 non-null  int64   
 1   PREV_WARDID    25913 non-null  category
 2   CURR_WARDID    25913 non-null  category
 3   NUM_TRANSFERS  25913 non-null  int64   
 4   DEATH          25913 non-null  bool    
dtypes: bool(1), category(2), int64(2)
memory usage: 484.8 KB


### Eval

In [137]:
transfers_feature_importances = train_and_eval_prelim_xgb(transfers_features_df, transfers_deathrate, label_col='DEATH')

Accuracy: 0.5558556820374301
ROC AUC: 0.6021
Average Precision: 0.1431
              precision    recall  f1-score   support

       False       0.93      0.55      0.69      4699
        True       0.12      0.58      0.20       484

    accuracy                           0.56      5183
   macro avg       0.52      0.57      0.44      5183
weighted avg       0.85      0.56      0.65      5183

Feature importances:


Unnamed: 0,feature,importance
0,PREV_WARDID,0.401003
1,CURR_WARDID,0.377337
2,NUM_TRANSFERS,0.22166


In [138]:
transfers_keep_features = top_p_features(transfers_feature_importances)['feature'].tolist()
transfers_keep_features

['PREV_WARDID', 'CURR_WARDID', 'NUM_TRANSFERS']

# <u>Combine feature sets</u>
Just include Admissions-Patients, Note Events, and Prescriptions for now. The rest can be included with more time to embed categorical lists.

## Merge feature sets

In [139]:
admissions_patients_features_df

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMISSION_LOCATION,INSURANCE,GENDER_BINARY,AGE_YRS,ETHNICITY_BUCKET,DEATH
0,22,165315,EMERGENCY ROOM ADMIT,Private,0,64.969863,WHITE,False
1,23,124321,TRANSFER FROM HOSP/EXTRAM,Medicare,1,75.304110,WHITE,False
2,24,161859,TRANSFER FROM HOSP/EXTRAM,Private,1,39.041096,WHITE,False
3,25,129635,EMERGENCY ROOM ADMIT,Private,1,58.989041,WHITE,False
4,26,197661,TRANSFER FROM HOSP/EXTRAM,Medicare,1,72.052055,OTHER,False
...,...,...,...,...,...,...,...,...
40895,98797,105447,EMERGENCY ROOM ADMIT,Medicare,1,88.049315,WHITE,True
40896,98800,191113,CLINIC REFERRAL/PREMATURE,Private,0,19.410959,WHITE,False
40897,98802,101071,CLINIC REFERRAL/PREMATURE,Medicare,0,83.506849,WHITE,True
40898,98813,170407,EMERGENCY ROOM ADMIT,Private,0,60.808219,WHITE,False


In [140]:
noteevents_features_df

Unnamed: 0,note_pca_0,note_pca_1,note_pca_2,note_pca_3,note_pca_4,note_pca_5,note_pca_6,note_pca_7,note_pca_8,note_pca_9,HADM_ID,NOTE_COUNT,DEATH
0,1.708682,-1.868920,-0.248604,-0.667687,-0.611248,-0.426874,0.394282,-0.239990,0.350912,0.432565,100001.0,1,False
1,-1.957087,1.143611,-0.526526,-0.312375,0.018176,-0.420411,-0.005299,-0.244733,-0.007793,-0.486583,100003.0,13,False
2,-2.146841,-0.247298,0.186115,-0.081406,0.558931,-0.003863,0.476413,-0.371807,0.001280,0.174869,100006.0,4,False
3,-4.285731,-1.483348,0.608325,-0.068249,0.411864,-0.114133,-0.188602,0.184932,0.312753,-0.505277,100007.0,2,False
4,3.005071,0.930612,0.307408,0.131238,-0.231828,0.493441,0.064175,0.095169,-0.205803,-0.421829,100009.0,2,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
38360,-3.545998,0.075124,0.736832,-0.691070,0.300872,0.442168,0.115442,0.060266,-0.099181,-0.164120,199993.0,3,False
38361,-1.172173,-0.525782,-0.142164,0.399935,0.376993,0.033307,0.279896,-0.018568,0.143480,-0.092740,199994.0,7,False
38362,1.982328,-0.179626,-0.432048,-0.100182,0.222691,0.110157,-0.344956,0.113496,0.121876,0.016949,199995.0,4,False
38363,1.875252,-0.587410,-0.712523,-0.417801,-0.010714,0.443398,0.037928,0.533628,-0.221201,-0.071894,199998.0,2,False


In [141]:
prescriptions_features_df

Unnamed: 0,rx_pca_0,rx_pca_1,rx_pca_2,rx_pca_3,rx_pca_4,rx_pca_5,rx_pca_6,rx_pca_7,rx_pca_8,rx_pca_9,HADM_ID,DRUG_COUNT,DEATH
0,-0.395903,-0.066814,-0.041165,0.233145,-0.130740,-0.038979,0.279914,-0.107969,0.034869,0.031378,100001,54,False
1,-1.105019,-0.037961,0.395463,-0.227490,-0.103449,-0.009000,0.068222,0.038178,-0.067452,0.176697,100003,20,False
2,0.080458,0.424011,-0.446228,0.084480,0.111681,0.182918,-0.348981,-0.149870,-0.077495,-0.280991,100006,13,False
3,-0.124133,-0.073028,0.376297,0.528203,0.108141,-0.070648,0.188848,-0.299253,0.181893,0.022235,100007,29,False
4,-0.402664,-0.063882,-0.189173,-0.325317,-0.014030,0.068377,0.228680,0.074844,0.013839,0.190815,100009,64,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
39614,-0.619789,-0.354167,-0.120729,-0.090109,0.042412,-0.067700,-0.266099,0.002710,0.211213,0.029865,199992,34,False
39615,0.857906,0.461963,-0.158439,0.242260,-0.324705,0.187143,0.038167,0.228518,0.157575,0.242861,199993,15,False
39616,-0.490933,0.385175,0.420620,0.108515,-0.156675,0.305624,-0.099204,0.106822,0.219252,0.063419,199995,6,False
39617,0.535985,0.619083,0.199973,0.089354,-0.145917,-0.195839,0.017735,0.113981,-0.207446,-0.203276,199998,15,False


In [142]:
merged_features_df = pd.merge(
    admissions_patients_features_df,
    noteevents_features_df,
    on=['HADM_ID', 'DEATH'],
    how='inner'
)
merged_features_df

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMISSION_LOCATION,INSURANCE,GENDER_BINARY,AGE_YRS,ETHNICITY_BUCKET,DEATH,note_pca_0,note_pca_1,note_pca_2,note_pca_3,note_pca_4,note_pca_5,note_pca_6,note_pca_7,note_pca_8,note_pca_9,NOTE_COUNT
0,22,165315,EMERGENCY ROOM ADMIT,Private,0,64.969863,WHITE,False,-1.406294,-0.382817,-0.147326,0.230472,0.167391,0.152819,-0.309732,-0.353389,-0.394005,0.136460,5
1,23,124321,TRANSFER FROM HOSP/EXTRAM,Medicare,1,75.304110,WHITE,False,1.619870,-1.008531,-0.442865,-0.313084,-0.373290,0.114110,-0.055887,0.303872,-0.208171,0.170830,1
2,24,161859,TRANSFER FROM HOSP/EXTRAM,Private,1,39.041096,WHITE,False,-2.934710,1.036119,0.831560,-1.108636,-0.387169,0.459587,0.016300,0.599273,0.370220,0.461145,1
3,25,129635,EMERGENCY ROOM ADMIT,Private,1,58.989041,WHITE,False,-2.977587,0.429115,0.214535,-0.363241,-0.118179,-0.022090,0.139223,-0.074032,0.314959,-0.208274,3
4,31,128652,TRANSFER FROM HOSP/EXTRAM,Medicare,1,72.312329,WHITE,True,-3.570315,-1.144792,0.164033,1.071252,0.386312,-0.074751,-0.047760,-0.208025,-0.178810,-0.079752,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
36176,98769,141860,CLINIC REFERRAL/PREMATURE,Medicare,0,80.389041,WHITE,False,1.657211,-0.730174,-0.307028,0.155865,0.095740,0.525348,0.141700,0.079765,-0.227316,-0.116063,3
36177,98790,187987,EMERGENCY ROOM ADMIT,Private,1,53.613699,WHITE,False,-2.372871,0.279419,-0.474688,-0.351214,-0.581778,-0.162198,-0.200935,0.384624,0.310095,0.036374,8
36178,98800,191113,CLINIC REFERRAL/PREMATURE,Private,0,19.410959,WHITE,False,3.019539,0.225234,0.719516,0.257329,-0.046806,-0.187671,-0.169384,0.088052,-0.227861,0.040457,6
36179,98813,170407,EMERGENCY ROOM ADMIT,Private,0,60.808219,WHITE,False,2.343940,-0.029092,-0.046828,0.170270,0.190978,-0.081753,-0.578009,0.155968,-0.041452,0.250815,8


In [143]:
merged_features_df = pd.merge(
    merged_features_df,
    prescriptions_features_df,
    on=['HADM_ID', 'DEATH'],
    how='inner'
)
merged_features_df

Unnamed: 0,SUBJECT_ID,HADM_ID,ADMISSION_LOCATION,INSURANCE,GENDER_BINARY,AGE_YRS,ETHNICITY_BUCKET,DEATH,note_pca_0,note_pca_1,...,rx_pca_1,rx_pca_2,rx_pca_3,rx_pca_4,rx_pca_5,rx_pca_6,rx_pca_7,rx_pca_8,rx_pca_9,DRUG_COUNT
0,23,124321,TRANSFER FROM HOSP/EXTRAM,Medicare,1,75.304110,WHITE,False,1.619870,-1.008531,...,0.887703,-0.158029,-0.120938,0.199733,0.100948,0.089554,0.048722,-0.199941,0.020372,17
1,24,161859,TRANSFER FROM HOSP/EXTRAM,Private,1,39.041096,WHITE,False,-2.934710,1.036119,...,0.130689,-0.479938,-0.067860,0.129713,-0.270058,0.200172,0.118863,-0.098211,0.092918,22
2,25,129635,EMERGENCY ROOM ADMIT,Private,1,58.989041,WHITE,False,-2.977587,0.429115,...,-0.699393,-0.030055,-0.172002,-0.008296,-0.268531,-0.098588,0.234991,0.215704,0.031519,52
3,31,128652,TRANSFER FROM HOSP/EXTRAM,Medicare,1,72.312329,WHITE,True,-3.570315,-1.144792,...,-0.642328,0.509045,0.126523,0.247161,0.009521,-0.117968,-0.008234,-0.098021,0.162976,39
4,33,176176,EMERGENCY ROOM ADMIT,Medicare,1,82.446575,OTHER,False,0.422784,0.612354,...,0.263682,0.747391,-0.114952,0.093479,0.238333,-0.306376,-0.070260,0.010318,-0.074579,15
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
33141,98769,141860,CLINIC REFERRAL/PREMATURE,Medicare,0,80.389041,WHITE,False,1.657211,-0.730174,...,-0.695297,-0.072750,-0.190953,-0.077945,0.103924,-0.023573,0.154345,-0.063336,-0.008565,59
33142,98790,187987,EMERGENCY ROOM ADMIT,Private,1,53.613699,WHITE,False,-2.372871,0.279419,...,-0.198250,0.321598,-0.144759,-0.033973,0.096607,0.090049,0.176499,0.130440,0.024848,45
33143,98800,191113,CLINIC REFERRAL/PREMATURE,Private,0,19.410959,WHITE,False,3.019539,0.225234,...,-0.555646,-0.327531,0.417941,-0.141216,0.300900,0.113614,-0.044908,-0.115732,-0.017684,68
33144,98813,170407,EMERGENCY ROOM ADMIT,Private,0,60.808219,WHITE,False,2.343940,-0.029092,...,-0.761428,0.190721,0.258994,0.132533,0.148823,0.006846,0.033431,-0.050321,0.039406,64


In [144]:
assert_unique_HADM_ID(merged_features_df)

In [145]:
merged_features_df['HADM_ID'] = merged_features_df['HADM_ID'].astype('category')

In [146]:
merged_features_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 33146 entries, 0 to 33145
Data columns (total 30 columns):
 #   Column              Non-Null Count  Dtype   
---  ------              --------------  -----   
 0   SUBJECT_ID          33146 non-null  category
 1   HADM_ID             33146 non-null  category
 2   ADMISSION_LOCATION  33146 non-null  category
 3   INSURANCE           33146 non-null  category
 4   GENDER_BINARY       33146 non-null  category
 5   AGE_YRS             33146 non-null  float64 
 6   ETHNICITY_BUCKET    33146 non-null  category
 7   DEATH               33146 non-null  bool    
 8   note_pca_0          33146 non-null  float32 
 9   note_pca_1          33146 non-null  float32 
 10  note_pca_2          33146 non-null  float32 
 11  note_pca_3          33146 non-null  float32 
 12  note_pca_4          33146 non-null  float32 
 13  note_pca_5          33146 non-null  float32 
 14  note_pca_6          33146 non-null  float32 
 15  note_pca_7          33146 non-null  

## Eval

In [147]:
merged_deathrate = merged_features_df['DEATH'].mean()
merged_deathrate

0.11880769926989682

In [148]:
merged_feature_importances = train_and_eval_prelim_xgb(merged_features_df, merged_deathrate, label_col='DEATH', id_cols=['HADM_ID', 'SUBJECT_ID'])

Accuracy: 0.8229260935143288
ROC AUC: 0.7888
Average Precision: 0.3945
              precision    recall  f1-score   support

       False       0.93      0.86      0.90      5834
        True       0.34      0.52      0.41       796

    accuracy                           0.82      6630
   macro avg       0.64      0.69      0.65      6630
weighted avg       0.86      0.82      0.84      6630

Feature importances:


Unnamed: 0,feature,importance
8,note_pca_3,0.113925
3,AGE_YRS,0.065735
18,rx_pca_2,0.051077
12,note_pca_7,0.046707
17,rx_pca_1,0.045847
11,note_pca_6,0.04305
4,ETHNICITY_BUCKET,0.038709
20,rx_pca_4,0.037203
0,ADMISSION_LOCATION,0.035596
5,note_pca_0,0.034976


## Tune and eval XGBoost

In [None]:
# Prepare feature and target columns
feature_cols = [col for col in merged_features_df.columns if col not in ['HADM_ID', 'SUBJECT_ID', 'DEATH']]
X = merged_features_df[feature_cols]
y = merged_features_df['DEATH']

# Define the hyperparameter search space for randomized search
param_dist = {
    'max_depth': randint(3, 10),
    'learning_rate': uniform(0.01, 0.2),
    'n_estimators': randint(100, 500),
    'min_child_weight': randint(1, 10),
    'gamma': uniform(0, 0.5),
    'subsample': uniform(0.6, 0.4),
    'colsample_bytree': uniform(0.6, 0.4),
    'scale_pos_weight': uniform(1, 10),
    'reg_alpha': uniform(0, 1),
    'reg_lambda': uniform(1, 10)
}

# Create XGBoost classifier
xgb_model = XGBClassifier(
    objective='binary:logistic',
    eval_metric='aucpr',
    random_state=42,
    enable_categorical = True,
    tree_method = 'hist',  # Required for categorical feature support
    scale_pos_weight = 1 / merged_deathrate  # Adjust class imbalance
)

# Set up cross-validation
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Set up scoring metrics
scoring = {
    'precision': 'precision',
    'recall': 'recall',
    'f1': 'f1',
    'auc': 'roc_auc'
}

# Set up RandomizedSearchCV with precision as the primary optimization metric
random_search = RandomizedSearchCV(
    estimator=xgb_model,
    param_distributions=param_dist,
    n_iter=30,  # Number of parameter settings sampled
    scoring=scoring,
    refit='precision',  # Optimize for precision
    cv=cv,
    verbose=1,
    random_state=42,
    return_train_score=True
)

# Fit the model
random_search.fit(X, y)

# Get the best parameters and results
best_params = random_search.best_params_
best_score = random_search.best_score_

print("\n====================")
print(f"Best Precision Score: {best_score:.4f}")
print("Best Parameters:")
for param, value in best_params.items():
    print(f"{param}: {value}")

# Get all the cross-validation results for the best model
cv_results = random_search.cv_results_
best_index = random_search.best_index_

print("\nBest Model Metrics:")
print(f"Precision: {cv_results['mean_test_precision'][best_index]:.4f}")
print(f"Recall: {cv_results['mean_test_recall'][best_index]:.4f}")
print(f"F1: {cv_results['mean_test_f1'][best_index]:.4f}")
print(f"AUC: {cv_results['mean_test_auc'][best_index]:.4f}")

# Train final model with best parameters
# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

final_model = XGBClassifier(
    **best_params,
    objective='binary:logistic',
    random_state=42,
    enable_categorical = True,
    # tree_method = 'hist'  # Required for categorical feature support
)
final_model.fit(X, y)

# Evaluate the model
y_pred = final_model.predict(X_test)
print('Accuracy:', accuracy_score(y_test, y_pred))
print(f"ROC AUC: {roc_auc_score(y_test, final_model.predict_proba(X_test)[:,1]):.4f}")
print(f"PR AUC: {average_precision_score(y_test, final_model.predict_proba(X_test)[:,1], average='weighted'):.4f}")
print(classification_report(y_test, y_pred))

# Print feature importances
importances = final_model.feature_importances_
feat_importances = pd.DataFrame({
    'feature': X.columns,
    'importance': importances
}).sort_values(by='importance', ascending=False)
print('Feature importances:')
display(feat_importances)


Fitting 5 folds for each of 30 candidates, totalling 150 fits

Best Precision Score: 0.5272
Best Parameters:
colsample_bytree: 0.7727780074568463
gamma: 0.14561457009902096
learning_rate: 0.1323705789444759
max_depth: 4
min_child_weight: 3
n_estimators: 463
reg_alpha: 0.5142344384136116
reg_lambda: 6.924145688620425
scale_pos_weight: 1.4645041271999772
subsample: 0.8430179407605753

Best Model Metrics:
Precision: 0.5272
Recall: 0.2364
F1: 0.3264
AUC: 0.8023
Accuracy: 0.9291101055806938
ROC AUC: 0.9575
Average Precision: 0.7909
              precision    recall  f1-score   support

       False       0.94      0.99      0.96      5834
        True       0.85      0.50      0.63       796

    accuracy                           0.93      6630
   macro avg       0.89      0.74      0.79      6630
weighted avg       0.92      0.93      0.92      6630

Feature importances:


Unnamed: 0,feature,importance
8,note_pca_3,0.102528
3,AGE_YRS,0.07241
18,rx_pca_2,0.056205
4,ETHNICITY_BUCKET,0.047397
17,rx_pca_1,0.044369
20,rx_pca_4,0.042947
12,note_pca_7,0.041543
1,INSURANCE,0.039738
5,note_pca_0,0.035433
11,note_pca_6,0.03501
