# LSTM

**Assumptions:**

- We will consider all ventilator types as if patient was intubated

<span style="color:red;"><b>
In `data_wide`:  
    - Consider including charlson comorbidity index (from `mimiciv.derived.charlson` table)  
    - Get ICU hour data | Calculate mortality at 24h, 48h, 96h 
</span></b>

**Variables to add**:

 - GCS
 - Charlson comorbidity index


**Considerations** 

 - Add valve patients
 - Exclude ICU admissions <4h
 - Exclude ICU admission >30 days

## Importing and Cleaning Data

In [114]:
import pandas as pd
from os import listdir
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import RandomizedSearchCV
from sklearn.base import BaseEstimator
import os

import tensorflow as tf
#from tensorflow import keras

Loading data

In [115]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [116]:
import os

# Get the current Conda environment
conda_env = os.environ.get('CONDA_DEFAULT_ENV')

# Print the environment name
print(f"Current Conda environment: {conda_env}")


Current Conda environment: base


**Importing main data files and concatenating them**

In [117]:
data_path = r'../data/final_data'

# Files to read
# data_files = listdir(data_path)
data_files = [file for file in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, file))]

# Load data
data = pd.concat([pd.read_csv(data_path + '/' + file) for file in data_files])#.drop(columns=['endtime'])

  data = pd.concat([pd.read_csv(data_path + '/' + file) for file in data_files])#.drop(columns=['endtime'])
  data = pd.concat([pd.read_csv(data_path + '/' + file) for file in data_files])#.drop(columns=['endtime'])


In [118]:
data.head(5)

Unnamed: 0,insurance,race,marital_status,deathtime,discharge_location,gender,anchor_age,endtime,value,amount,amountuom,label,Died,Unique Stay,sequence_num
0,Other,WHITE,MARRIED,,SKILLED NURSING FACILITY,M,87,2173-12-01 12:00:00,Drager,1.0,,Ventilator Type,True,1578920124078757,2680
1,Other,WHITE,MARRIED,,SKILLED NURSING FACILITY,M,87,2173-12-01 12:00:00,,83.0,bpm,Heart Rate,True,1578920124078757,2681
2,Other,WHITE,MARRIED,,SKILLED NURSING FACILITY,M,87,2173-12-01 12:00:00,,78.0,mmHg,Arterial Blood Pressure mean,True,1578920124078757,2682
3,Other,WHITE,MARRIED,,SKILLED NURSING FACILITY,M,87,2173-12-01 12:00:00,,78.0,mmHg,Arterial Blood Pressure mean,True,1578920124078757,2683
4,Other,WHITE,MARRIED,,SKILLED NURSING FACILITY,M,87,2173-12-01 12:00:00,,83.0,bpm,Heart Rate,True,1578920124078757,2684


**Verifying that data includes only CABG patients**

In [119]:
CABG_ICDS = ['0210093', '0210098', '0210099', '021009C', '021009F', '021009W',
'02100A3', '02100A8', '02100A9', '02100AC', '02100AF', '02100AW',
'02100J3', '02100J8', '02100J9', '02100JC', '02100JF', '02100JW',
'02100K3', '02100K8', '02100K9', '02100KC', '02100KF', '02100KW',
'02100Z3', '02100Z8', '02100Z9', '02100ZC', '02100ZF',
'0211093', '0211098', '0211099', '021109C', '021109F', '021109W',
'02110A3', '02110A8', '02110A9', '02110AC', '02110AF', '02110AW',
'02110J3', '02110J8', '02110J9', '02110JC', '02110JF', '02110JW',
'02110K3', '02110K8', '02110K9', '02110KC', '02110KF', '02110KW',
'02110Z3', '02110Z8', '02110Z9', '02110ZC', '02110ZF',
'0212093', '0212098', '0212099', '021209C', '021209F', '021209W',
'02120A3', '02120A8', '02120A9', '02120AC', '02120AF', '02120AW',
'02120J3', '02120J8', '02120J9', '02120JC', '02120JF', '02120JW',
'02120K3', '02120K8', '02120K9', '02120KC', '02120KF', '02120KW',
'02120Z3', '02120Z8', '02120Z9', '02120ZC', '02120ZF',
'0213093', '0213098', '0213099', '021309C', '021309F', '021309W',
'02130A3', '02130A8', '02130A9', '02130AC', '02130AF', '02130AW',
'02130J3', '02130J8', '02130J9', '02130JC', '02130JF', '02130JW',
'02130K3', '02130K8', '02130K9', '02130KC', '02130KF', '02130KW',
'02130Z3', '02130Z8', '02130Z9', '02130ZC', '02130ZF',
'3610', '3611', '3612', '3613', '3614', '3615', '3616', '3617', '3619']

In [120]:
#cabg_pts = pd.read_csv('B:/Databases/MIMIC-IV/CABG-filtered/cabg_pts.csv')

In [121]:
# # Filter data to only include CABG patients
# cabg_pts[cabg_pts.icd_code.isin(CABG_ICDS)]

# # Unique patient IDs
# cabg_pts.subject_id.nunique() # 5647 unique patients

# Create unique stay ID
# unique_stay = cabg_pts.subject_id.astype(str) + cabg_pts.hadm_id.astype(str)
# unique_stay = unique_stay.to_list()

# Number of patients in the data
data.shape[0]

1214970

In [122]:
# Number of patients in the data that are CABG patients
#data['Unique Stay'].astype(str).isin(unique_stay).sum()

**Keeping only one ICU admission per patient**

For patients admitted to the ICU multiple times, we will only look at their first admission. This is to ensure that observations are all independent.

In [123]:
# ICU stays
icu_data = pd.read_csv('../data/datasets/icu_stays_filtered.csv')

# Converting to time format
icu_data['intime'] = pd.to_datetime(icu_data['intime'])

icu_data['Unique Stay'] =  icu_data['subject_id'].astype(str) + icu_data['hadm_id'].astype(str)

In [124]:
# Looking at number of ICU readmissions in general
icu_data.groupby(['subject_id', 'hadm_id'])['stay_id'].nunique().value_counts()

1    6000
2     583
3      55
4      10
6       1
Name: stay_id, dtype: int64

In [125]:
# Keeping only 1 ICU admission per patient
icu_data = icu_data.sort_values('intime').groupby(['Unique Stay'], as_index=False).first()

In [126]:
icu_data.groupby(['subject_id', 'hadm_id'])['stay_id'].nunique().value_counts()

1    6649
Name: stay_id, dtype: int64

In [127]:
# Merging data back with original data
data['Unique Stay'] = data['Unique Stay'].astype(str)
data = pd.merge(icu_data[['Unique Stay', 'subject_id', 'stay_id']], data, on = 'Unique Stay', how = 'inner')

In [128]:
data.head()

Unnamed: 0,Unique Stay,subject_id,stay_id,insurance,race,marital_status,deathtime,discharge_location,gender,anchor_age,endtime,value,amount,amountuom,label,Died,sequence_num
0,1000201323581541,10002013,39060235,Medicare,OTHER,SINGLE,,HOME HEALTH CARE,F,53,2160-05-18 10:26:00,,3.3,mmol/L,Lactic Acid,False,1
1,1000201323581541,10002013,39060235,Medicare,OTHER,SINGLE,,HOME HEALTH CARE,F,53,2160-05-18 10:26:00,,421.0,mmHg,Arterial O2 pressure,False,2
2,1000201323581541,10002013,39060235,Medicare,OTHER,SINGLE,,HOME HEALTH CARE,F,53,2160-05-18 11:23:00,,2.8,mmol/L,Lactic Acid,False,3
3,1000201323581541,10002013,39060235,Medicare,OTHER,SINGLE,,HOME HEALTH CARE,F,53,2160-05-18 12:20:00,,3.1,mmol/L,Lactic Acid,False,4
4,1000201323581541,10002013,39060235,Medicare,OTHER,SINGLE,,HOME HEALTH CARE,F,53,2160-05-18 12:20:00,,384.0,mmHg,Arterial O2 pressure,False,5


In [129]:
data.subject_id.nunique()

5346

In [130]:
data.stay_id.nunique()

5361

In [131]:
data['Unique Stay'].nunique()

5361

**Importing Charlson Comorbidities Data**

In [132]:
# Importing data
charlson_df = pd.read_csv('../data/datasets/charlson_cabg.csv')
charlson_df['Unique Stay'] = charlson_df['subject_id'].astype(str) + charlson_df['hadm_id'].astype(str)

# Merging charlson data with main data
data = pd.merge(data, charlson_df.drop(['subject_id', 'hadm_id'], axis =1), on = 'Unique Stay', how = 'left')
data.rename(columns = {'charlson_comorbidity_index': 'charlson'}, inplace = True)

**Importing ventilation data**

In [133]:
data.shape

(1214970, 18)

In [134]:
vent_data = pd.read_csv('../data/datasets/cabg_ventilation.csv')

In [135]:
# Converting start and stop vent time to date/time format
vent_data['vent_starttime'] = pd.to_datetime(vent_data['starttime'])
vent_data['vent_endtime'] = pd.to_datetime(vent_data['endtime'])
# Filtering vent data to only include CABBG patients
patient_id = [row[0:8] for row in data['Unique Stay'].astype('str')]
# Filtering the ventilator data for CABG patients based on identified stay_ids
vent_data = vent_data[vent_data['subject_id'].astype('str').isin(patient_id)].drop(['starttime', 'endtime'], axis = 1)

In [136]:
vent_data.head()

Unnamed: 0,subject_id,stay_id,ventilation_status,vent_starttime,vent_endtime
0,17997568,36100181,HFNC,2148-09-26 08:00:00,2148-09-26 17:23:00
1,17557505,34278743,HFNC,2167-04-23 12:00:00,2167-04-24 07:00:00
2,17557505,34278743,HFNC,2167-04-24 10:00:00,2167-04-24 16:00:00
3,11799303,32117491,HFNC,2183-11-08 10:00:00,2183-11-08 12:00:00
4,17579017,36645193,HFNC,2137-08-30 20:00:00,2137-08-30 21:30:00


**Converting to long format**

In [137]:
def expand_ventilation_data_merge(vent_data):
    # Convert datetime columns to datetime type
    vent_data['vent_starttime'] = pd.to_datetime(vent_data['vent_starttime'])
    vent_data['vent_endtime'] = pd.to_datetime(vent_data['vent_endtime'])
    
    # Create time ranges for each unique stay
    time_ranges = []
    for (subject_id, stay_id), group in vent_data.groupby(['subject_id', 'stay_id']):
        times = pd.date_range(
            start=group['vent_starttime'].min(),
            end=group['vent_endtime'].max(),
            freq='H'
        )
        time_ranges.append(pd.DataFrame({
            'subject_id': subject_id,
            'stay_id': stay_id,
            'time': times
        }))
    
    # Combine all time ranges into a single DataFrame
    time_df = pd.concat(time_ranges, ignore_index=True)
    
    # Perform a merge to filter times within the vent_starttime and vent_endtime intervals
    result = (
        time_df.merge(vent_data, on=['subject_id', 'stay_id'], how='left')
        .query('vent_starttime <= time <= vent_endtime')
        .groupby(['subject_id', 'stay_id', 'time'])
        .first()
        .reset_index()
    ).drop(['vent_starttime', 'vent_endtime'], axis = 1)
    
    return result

# Assuming vent_data is your DataFrame with ventilation data
vent_data_long = expand_ventilation_data_merge(vent_data)


In [138]:
# Merging ventilator data with aggregated data!
vent_data_long = vent_data_long.rename(columns = {'time': 'time_bucket'})
vent_data_long['time_bucket'] = vent_data_long['time_bucket'].dt.round('H')

In [139]:
vent_data_long

Unnamed: 0,subject_id,stay_id,time_bucket,ventilation_status
0,10002013,39060235,2160-05-18 14:00:00,InvasiveVent
1,10002013,39060235,2160-05-18 15:00:00,InvasiveVent
2,10002013,39060235,2160-05-18 16:00:00,InvasiveVent
3,10002013,39060235,2160-05-18 17:00:00,InvasiveVent
4,10002013,39060235,2160-05-18 18:00:00,InvasiveVent
...,...,...,...,...
388074,19995790,34995866,2185-02-03 08:00:00,SupplementalOxygen
388075,19995790,34995866,2185-02-03 09:00:00,SupplementalOxygen
388076,19995790,34995866,2185-02-03 10:00:00,SupplementalOxygen
388077,19995790,34995866,2185-02-03 11:00:00,SupplementalOxygen


We don't need `ventilator type` and `ventilator mode` from the columns data.label anymore. So we can remove these rows are they are rendered useless with the new variable `ventilation_status`. This will also reduce the number of columns when we reshape the data from long to wide. 

In [140]:
# Deleting ventilator type and ventilator mode rows from dataset
data = data[~(data.label == 'Ventilator Type')]
data = data[~(data.label == 'Ventilator Mode')]

**Importing Vasoactive data**

In [141]:
vasoactive_data = pd.read_csv('../data/datasets/vasoactive_cabg.csv')

In [142]:
vasoactive_data.shape

(111538, 10)

In [143]:
# Merging with Unique Stay
vasoactive_data = pd.merge(vasoactive_data, icu_data[['stay_id', 'Unique Stay']], on = 'stay_id', how = 'inner')

# Converting start and stop vasoactive time to date/time format
vasoactive_data['vaso_starttime'] = pd.to_datetime(vasoactive_data['starttime'])
vasoactive_data['vaso_endtime'] = pd.to_datetime(vasoactive_data['endtime'])

# Filtering the vasoactive data for CABG patients based on identified stay_ids
vasoactive_data = vasoactive_data[vasoactive_data['stay_id'].astype('str').isin(data.stay_id.astype('str').to_list())].drop(['starttime', 'endtime'], axis = 1)

# Reordering columns
vasoactive_data = vasoactive_data[['stay_id', 'vaso_starttime', 'vaso_endtime', 'epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'milrinone', 'dopamine']]

Converting vasoactive data it into an hourly time series, where each row represents one hour and shows the doses of all vasoactive medications (epinephrine, norepinephrine, phenylephrine, dobutamine, milrinone, and dopamine) that a patient was receiving during that hour.

In [144]:
def expand_vasoactive_data_merge(vasoactive_data):
    # Convert datetime columns
    vasoactive_data['vaso_starttime'] = pd.to_datetime(vasoactive_data['vaso_starttime'])
    vasoactive_data['vaso_endtime'] = pd.to_datetime(vasoactive_data['vaso_endtime'])
    
    # Create time ranges for each stay
    time_ranges = []
    for stay_id, group in vasoactive_data.groupby('stay_id'):
        times = pd.date_range(
            start=group['vaso_starttime'].min(),
            end=group['vaso_endtime'].max(),
            freq='H'
        )
        time_ranges.append(pd.DataFrame({
            'stay_id': stay_id,
            'time': times
        }))
    
    time_df = pd.concat(time_ranges, ignore_index=True)
    
    # Merge with original data using interval join logic
    result = time_df.merge(vasoactive_data, on='stay_id', how='left')
    
    # Filter for rows where time falls within the start and end times
    result = result[
        (result['time'] >= result['vaso_starttime']) & 
        (result['time'] < result['vaso_endtime'])
    ]
    
    # Get the latest dose for each medication at each hour
    vasoactive_agents = ['epinephrine', 'norepinephrine', 'phenylephrine', 
                        'dobutamine', 'milrinone', 'dopamine']
    
    # Group by stay_id and time, and get the last non-null value for each medication
    result = result.groupby(['stay_id', 'time'])[vasoactive_agents].last().reset_index()
    
    # Fill NaN values with 0 to indicate no medication at that time
    result[vasoactive_agents] = result[vasoactive_agents].fillna(0)
    
    return result

# Create the dataset
vasoactive_long = expand_vasoactive_data_merge(vasoactive_data)

# Merging ventilator data with aggregated data!
vasoactive_long = vasoactive_long.rename(columns = {'time': 'time_bucket'})
vasoactive_long['time_bucket'] = vasoactive_long['time_bucket'].dt.floor('H')

In [145]:
# Deleting vasoactive data from label variable since these will be superceded by the vasoactive_long table
data = data[~data.label.isin(['Norepinephrine', 'Dobutamine', 'Dopamine', 'Epinephrine', 'Epinephrine.'])]

In [146]:
data.label.value_counts()

Heart Rate                      526542
Arterial Blood Pressure mean    336641
Inspired O2 Fraction             71942
Arterial O2 pressure             61075
Platelet Count                   36745
Creatinine (serum)               34050
Lactic Acid                      31799
Total Bilirubin                   4365
Name: label, dtype: int64

**Merging vasoactive with ventilator data**  

This will create on long format table, where for each hour, and for each patient, we have information regarding what drips they were on and what were the vent settings

In [147]:
# Merging ventilator with vasoactive data
rx_data_long = pd.merge(vasoactive_long, vent_data_long, on=['stay_id', 'time_bucket'], how='outer', suffixes=('', '_vaso'))


# Dropping subject_id column
rx_data_long = rx_data_long.drop('subject_id', axis = 1)

We will replace missing values for drips with 0 - we assume that if patients did not have any drips recorded at that time, that they were not receiving anything. 

Same goes for ventilation status: if there is no documented ventilation/oxygenation treatment, we assume that patients were not receiving treatment and thus we set the value to be `none`. But we will do that later, because we need ventilation status to impute missing fio2 status

In [148]:
# Replace missing values with 0
rx_data_long[['epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'dopamine', 'milrinone']] = rx_data_long[['epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'dopamine', 'milrinone']].fillna(0)

# Rounding to 3 digits
rx_data_long[['epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'dopamine', 'milrinone']] = rx_data_long[['epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'dopamine', 'milrinone']].round(3)

In [149]:
rx_data_long.stay_id.nunique()

6791

In [150]:
rx_data_long['stay_id'].nunique()

6791

In [151]:
rx_data_long.head()

Unnamed: 0,stay_id,time_bucket,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,ventilation_status
0,30001148,2156-08-30 14:00:00,0.0,0.0,0.4,0.0,0.0,0.0,
1,30001148,2156-08-30 15:00:00,0.0,0.0,0.0,0.0,0.0,0.0,InvasiveVent
2,30001148,2156-08-30 16:00:00,0.0,0.0,0.0,0.0,0.0,0.0,InvasiveVent
3,30001148,2156-08-30 17:00:00,0.0,0.0,0.0,0.0,0.0,0.0,InvasiveVent
4,30001148,2156-08-30 18:00:00,0.0,0.0,0.5,0.0,0.0,0.0,InvasiveVent


### Data Preprocessing

**Checking accuracy in coding of the variable `Died`**. 

Many patients who were coded as 'Died' were in fact discharged from the hospital. This is because this variable records whether patients died, even after discharge from the hospital. We are only interested in in-hospital mortality, so we need to create a new variable for that, which we will call `mortality`

In [152]:
data[data.Died == True].discharge_location.value_counts()

DIED                            75521
CHRONIC/LONG TERM ACUTE CARE    70267
SKILLED NURSING FACILITY        51304
REHAB                           32346
HOME HEALTH CARE                26007
HOSPICE                          6374
HOME                              363
OTHER FACILITY                    281
Name: discharge_location, dtype: int64

In [153]:
# Creating New Mortality Data
data['mortality'] = np.where(data['discharge_location']  == 'DIED', True, False)

In [154]:
# Dropping variable Died and discharge_location
data.drop('Died', axis = 1, inplace = True)
data.drop('discharge_location', axis = 1, inplace = True)

**Creating variable that codes for TIME OF DEATH, as a time-to-event outcome**

I will do that later - for now, we're interested in looking at patterns that predict in-hospital mortality in general. It's a classification problem. Later, we will build a survival model so that we can predict time-to-mortality.

In [155]:
# Recoding variable death time
data['deathtime'] = pd.to_datetime(data['deathtime'])

In [156]:
data[~data.deathtime.isna()].head()

Unnamed: 0,Unique Stay,subject_id,stay_id,insurance,race,marital_status,deathtime,gender,anchor_age,endtime,value,amount,amountuom,label,sequence_num,charlson,mortality
8726,1010445023157316,10104450,31524085,Medicare,WHITE,MARRIED,2137-10-13 00:01:00,M,82,2137-10-10 06:20:00,,203.0,K/uL,Platelet Count,1,13,True
8727,1010445023157316,10104450,31524085,Medicare,WHITE,MARRIED,2137-10-13 00:01:00,M,82,2137-10-10 06:20:00,,2.9,mg/dL,Creatinine (serum),2,13,True
8728,1010445023157316,10104450,31524085,Medicare,WHITE,MARRIED,2137-10-13 00:01:00,M,82,2137-10-10 07:46:00,,99.0,bpm,Heart Rate,3,13,True
8729,1010445023157316,10104450,31524085,Medicare,WHITE,MARRIED,2137-10-13 00:01:00,M,82,2137-10-10 07:48:00,,94.0,bpm,Heart Rate,4,13,True
8730,1010445023157316,10104450,31524085,Medicare,WHITE,MARRIED,2137-10-13 00:01:00,M,82,2137-10-10 07:48:00,,94.0,bpm,Heart Rate,5,13,True


In [157]:
# Making sure that all patients who died have a time of death recorded (should return True)
data[~data.deathtime.isna()].mortality.value_counts()[0] == data[~data.deathtime.isna()].shape[0]

True

<span style = 'color:maroon'> <b> ... to be continued later </span> </b>

**Recoding `race` data**

In [158]:
# Recoding race data
data['race_new']  = data.race.replace({
                              'WHITE': 'White',
                              'WHITE - OTHER EUROPEAN': 'White',
                              'WHITE - RUSSIAN': 'White',
                              'WHITE - EASTERN EUROPEAN': 'White',
                              'PORTUGUESE': 'White',
                              'WHITE - BRAZILIAN': 'White',
                              # Black
                              'BLACK/AFRICAN AMERICAN': 'Black',
                              'BLACK/AFRICAN': 'Black',
                              'HISPANIC/LATINO - GUATEMALAN': 'Black',
                              # Asian
                              'ASIAN - ASIAN INDIAN': 'Asian',
                              'ASIAN': 'Asian',
                              'ASIAN - CHINESE': 'Asian',
                              'ASIAN - SOUTH EAST ASIAN': 'Asian',
                              'ASIAN - KOREAN': 'Asian',
                              # Hispanic
                              'HISPANIC/LATINO - PUERTO RICAN': 'Hispanic',
                              'HISPANIC/LATINO - DOMINICAN': 'Hispanic',
                              'HISPANIC OR LATINO': 'Hispanic',
                              'HISPANIC': 'Hispanic',
                              'HISPANIC/LATINO - CUBAN': 'Hispanic',
                              'HISPANIC/LATINO - SALVADORAN': 'Hispanic',
                              'SOUTH AMERICAN': 'Hispanic',
                              'HISPANIC/LATINO - COLUMBIAN': 'Hispanic',
                              'HISPANIC/LATINO - CUBAN': 'Hispanic',
                              'HISPANIC/LATINO - HONDURAN': 'Hispanic',
                              'HISPANIC/LATINO - CENTRAL AMERICAN': 'Hispanic',
                              'HISPANIC/LATINO - MEXICAN': 'Hispanic',
                              # Native American/Pacific Islander
                              'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER': 'Native American',
                              'AMERICAN INDIAN/ALASKA NATIVE': 'Native American',
                              'BLACK/CAPE VERDEAN': 'Black',
                              'BLACK/CARIBBEAN ISLAND' : 'Black',
                              # Other
                              'OTHER': 'Other',
                              'MULTIPLE RACE/ETHNICITY': 'Other',
                              # Unknown
                              'UNABLE TO OBTAIN': np.nan,
                              'UNKNOWN': np.nan,
                              'PATIENT DECLINED TO ANSWER': np.nan,
                              })

data.race = data.race_new
data.drop(['race_new'], axis = 1, inplace = True)

data.race.value_counts()

White              790246
Black               45499
Hispanic            31957
Other               30206
Asian               22902
Native American      6440
Name: race, dtype: int64

In [159]:
# Checking for discrepancies in the race data

# Creating data for patient ID
data['subject_id'] = data['Unique Stay'].astype(str).str[0:8]

# Checking if all race data is the same for each patient
data.groupby('subject_id').race.nunique().value_counts()

1    4533
0     813
Name: race, dtype: int64

In [160]:
# **Cleaning up the group of variables `value`, `amount`, `amountuom`, `label`.**

In [161]:
data.label.value_counts()

Heart Rate                      526542
Arterial Blood Pressure mean    336641
Inspired O2 Fraction             71942
Arterial O2 pressure             61075
Platelet Count                   36745
Creatinine (serum)               34050
Lactic Acid                      31799
Total Bilirubin                   4365
Name: label, dtype: int64

In [162]:
# CODE BELOW NO LONGER NEEDED SINCE I AM RETRIEVING VASOACTIVE DATA DIRECTLY FROM THE MIMICIV DERIVED VASOACTIVE TABLE

# # Remove the '.' from 'Epinephrine.' in the 'label' column
# data.loc[data.label == 'Epinephrine.', 'label'] = 'Epinephrine'

#There are 2 label values for epinephrine: 'Epinephrine' and 'Epinephrine.'. We will combine these into one label value.
#Also, there are 2 different units of measure for epinephrine: 'mcg' and 'mg'. We will convert all values to 'mg'. But we need to make sure that the values that are coded as mcg are actually mcg and not mg (the values for mcg and mg should different by a factor of 1000).
# # Convert 'amount' from mcg to mg and change 'amountuom' to 'mg'
# data.loc[(data.label == 'Epinephrine') & (data.amountuom == 'mcg'), 'amount'] /= 1000
# data.loc[(data.label == 'Epinephrine') & (data.amountuom == 'mcg'), 'amountuom'] = 'mg'

# # Making sure that units and values are consistent
# print('Heart rate: ', data[data.label == 'Heart Rate'].amountuom.value_counts(), '\n')

# print('MAP: ', data[data.label == 'Arterial Blood Pressure mean'].amountuom.value_counts(), '\n')

# print('FiO2: ', data[data.label == 'Inspired O2 Fraction'].amountuom.value_counts(), '\n')

# print('PaO2: ', data[data.label == 'Arterial O2 pressure'].amountuom.value_counts(), '\n')

# print('Platelet Count: ', data[data.label == 'Platelet Count'].amountuom.value_counts(), '\n')

# print('Creatinine (serum): ', data[data.label == 'Creatinine (serum)'].amountuom.value_counts(), '\n')

# print('Lactic Acid: ', data[data.label == 'Lactic Acid'].amountuom.value_counts(), '\n')

# print('Norepinephrine: ', data[data.label == 'Norepinephrine'].amountuom.value_counts(), '\n')

# print('Epinephrine: ', data[data.label == 'Epinephrine'].amountuom.value_counts(), '\n')

# print('Total Bilirubin: ', data[data.label == 'Total Bilirubin'].amountuom.value_counts(), '\n')

# print('Dobutamine: ', data[data.label == 'Dobutamine'].amountuom.value_counts(), '\n')

# print('Dopamine: ', data[data.label == 'Dopamine'].amountuom.value_counts(), '\n')

**Pivoting Data**

In [163]:
data.head()

Unnamed: 0,Unique Stay,subject_id,stay_id,insurance,race,marital_status,deathtime,gender,anchor_age,endtime,value,amount,amountuom,label,sequence_num,charlson,mortality
0,1000201323581541,10002013,39060235,Medicare,Other,SINGLE,NaT,F,53,2160-05-18 10:26:00,,3.3,mmol/L,Lactic Acid,1,7,False
1,1000201323581541,10002013,39060235,Medicare,Other,SINGLE,NaT,F,53,2160-05-18 10:26:00,,421.0,mmHg,Arterial O2 pressure,2,7,False
2,1000201323581541,10002013,39060235,Medicare,Other,SINGLE,NaT,F,53,2160-05-18 11:23:00,,2.8,mmol/L,Lactic Acid,3,7,False
3,1000201323581541,10002013,39060235,Medicare,Other,SINGLE,NaT,F,53,2160-05-18 12:20:00,,3.1,mmol/L,Lactic Acid,4,7,False
4,1000201323581541,10002013,39060235,Medicare,Other,SINGLE,NaT,F,53,2160-05-18 12:20:00,,384.0,mmHg,Arterial O2 pressure,5,7,False


In [164]:
# Pivot data from long to wide
data_wide = data.pivot_table(
    index=['insurance', 'race', 'endtime', 'marital_status', 'gender', 'anchor_age', 'mortality', 'Unique Stay', 'stay_id', 'sequence_num', 'charlson'],
    columns='label',
    values='amount',
    aggfunc='first'
).reset_index()


In [165]:
pd.set_option('display.max_columns', None)
data_wide.head(5)

label,insurance,race,endtime,marital_status,gender,anchor_age,mortality,Unique Stay,stay_id,sequence_num,charlson,Arterial Blood Pressure mean,Arterial O2 pressure,Creatinine (serum),Heart Rate,Inspired O2 Fraction,Lactic Acid,Platelet Count,Total Bilirubin
0,Medicaid,Asian,2113-06-29 16:57:00,MARRIED,M,69,False,1356278426556587,30830290,1,4,,332.0,,,,,,
1,Medicaid,Asian,2113-06-29 16:57:00,MARRIED,M,69,False,1356278426556587,30830290,2,4,,,,,,4.6,,
2,Medicaid,Asian,2113-06-29 16:57:00,MARRIED,M,69,False,1356278426556587,30830290,3,4,,,,,,4.6,,
3,Medicaid,Asian,2113-06-29 18:00:00,MARRIED,M,69,False,1356278426556587,30830290,5,4,,,,,50.0,,,
4,Medicaid,Asian,2113-06-29 18:00:00,MARRIED,M,69,False,1356278426556587,30830290,6,4,,,,,50.0,,,


<b>For each patient, need to create buckets of 1 hour, and then project values of all the measurements for that hour</b>

In [166]:
# Making sure endtime is in correct date time format
data_wide['endtime'] = pd.to_datetime(data_wide['endtime'])

# Create a new column with the rounded down hourly bucket
data_wide['time_bucket'] = data_wide['endtime'].dt.floor('H')

In [167]:
# Sample aggregation functions
agg_funcs = {
    'stay_id': 'first', 
    'anchor_age': 'first', 
    'insurance': 'first', 
    'race': 'first',
    'charlson': 'first',
    'marital_status': 'first',
    'gender': 'first',
    'mortality': 'max',        
    'sequence_num': 'min',
    'Arterial Blood Pressure mean': 'mean',
    'Arterial O2 pressure': 'mean',
    'Creatinine (serum)': 'mean',
    'Heart Rate': 'mean',
    'Inspired O2 Fraction': 'mean',
    'Lactic Acid': 'mean',
    'Platelet Count': 'mean',
}

In [168]:
# Sort data by 'Unique Stay' and 'time_bucket'
data_wide = data_wide.sort_values(by=['Unique Stay', 'time_bucket'])

# Group by 'Unique Stay' and 'time_bucket' and apply aggregation
data_wide_agg = data_wide.drop('endtime', axis=1).groupby(['Unique Stay', 'time_bucket']).agg(agg_funcs).reset_index()
data_wide_agg.head(7)

label,Unique Stay,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,sequence_num,Arterial Blood Pressure mean,Arterial O2 pressure,Creatinine (serum),Heart Rate,Inspired O2 Fraction,Lactic Acid,Platelet Count
0,1000201323581541,2160-05-18 10:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,1,,421.0,,,,3.3,
1,1000201323581541,2160-05-18 11:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,3,,,,,,2.8,
2,1000201323581541,2160-05-18 12:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,4,,384.0,,,,3.1,
3,1000201323581541,2160-05-18 13:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,7,,311.0,,,,2.6,252.0
4,1000201323581541,2160-05-18 14:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,11,94.0,332.0,1.1,80.0,100.0,,254.0
5,1000201323581541,2160-05-18 15:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,21,66.0,,,,50.0,,
6,1000201323581541,2160-05-18 16:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,27,,,,92.0,,,


**I only kept the first ICU admission for each patient that has multiple ICU admission**

In [169]:
data_wide_agg.stay_id.nunique()

4451

In [170]:
data_wide_agg['Unique Stay'].nunique()

4451

In [171]:
# Creating index for grouping patients
group_idx = ['Unique Stay', 'time_bucket']

In [172]:
data_wide_agg.shape

(313940, 18)

**Merging data with ventilator and vasoactive data**

In [173]:
rx_data_long.shape

(411096, 9)

In [174]:
data_merged = pd.merge(data_wide_agg.drop(['Unique Stay'], axis = 1), rx_data_long, on = ['stay_id', 'time_bucket'], how = 'left')

In [175]:
data_merged['stay_id'].nunique()

4451

In [176]:
data_merged['ventilation_status'].value_counts()

SupplementalOxygen    118242
InvasiveVent           76386
HFNC                    2422
Tracheostomy            2295
NonInvasiveVent         1585
None                      60
Name: ventilation_status, dtype: int64

In [177]:
# Imputing missing ventilation status as 'None' (see explanation above)
data_merged['ventilation_status'] = data_merged['ventilation_status'].fillna('None')

In [178]:
# Reconding ventilation into yes/no binary variable
data_merged['vent'] = data_merged['ventilation_status'].apply(lambda x: 1 if x in ['InvasiveVent', 'Tracheostomy'] else 0)

# Will drop the old ventilation_status variable later

In [179]:
data_merged

Unnamed: 0,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,sequence_num,Arterial Blood Pressure mean,Arterial O2 pressure,Creatinine (serum),Heart Rate,Inspired O2 Fraction,Lactic Acid,Platelet Count,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,ventilation_status,vent
0,2160-05-18 10:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,1,,421.0,,,,3.3,,,,,,,,,0
1,2160-05-18 11:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,3,,,,,,2.8,,,,,,,,,0
2,2160-05-18 12:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,4,,384.0,,,,3.1,,,,,,,,,0
3,2160-05-18 13:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,7,,311.0,,,,2.6,252.0,,,,,,,,0
4,2160-05-18 14:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,11,94.0,332.0,1.1,80.0,100.0,,254.0,0.0,0.0,0.701,0.0,0.0,0.0,InvasiveVent,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319026,2185-02-03 10:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,76,,,,104.0,,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0
319027,2185-02-03 11:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,78,,,,82.0,,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0
319028,2185-02-03 12:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,81,,,,85.0,,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0
319029,2185-02-03 13:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,83,,,,79.0,,,,,,,,,,,0


**Forward/Backward filling missing values for _FIXED_ variables**

In [180]:
# List of columns to forward fill
columns_to_fill = ['anchor_age', 'insurance', 'race', 'marital_status', 'gender', 'mortality']

# Forward fill specific columns within each group
data_merged[columns_to_fill] = data_merged.groupby('stay_id')[columns_to_fill].ffill().bfill()

In [181]:
# Recoding Sequence_number variable
data_merged['seq_num'] = data_merged.groupby(['stay_id']).cumcount()+1

data_merged.drop('sequence_num', axis = 1, inplace = True)

In [182]:
data_merged.head()

Unnamed: 0,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,Arterial Blood Pressure mean,Arterial O2 pressure,Creatinine (serum),Heart Rate,Inspired O2 Fraction,Lactic Acid,Platelet Count,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,ventilation_status,vent,seq_num
0,2160-05-18 10:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,421.0,,,,3.3,,,,,,,,,0,1
1,2160-05-18 11:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,,,,,2.8,,,,,,,,,0,2
2,2160-05-18 12:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,384.0,,,,3.1,,,,,,,,,0,3
3,2160-05-18 13:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,311.0,,,,2.6,252.0,,,,,,,,0,4
4,2160-05-18 14:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,94.0,332.0,1.1,80.0,100.0,,254.0,0.0,0.0,0.701,0.0,0.0,0.0,InvasiveVent,1,5


In [183]:
# Renaming columns => easier to work with lower case names
data_merged = data_merged.rename(columns={'Arterial Blood Pressure mean': 'map',
                                      'Arterial O2 pressure': 'pao2',
                                      'Creatinine (serum)': 'creatinine',
                                      'Heart Rate': 'hr',
                                      'Inspired O2 Fraction': 'fio2',
                                      'Lactic Acid': 'lactate',
                                      'Platelet Count': 'platelets'})

**Doing the same thing as before: filling missing treatment observations with 0**

In [184]:
data_merged[['epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'dopamine', 'milrinone', 'vent']] = data_merged[['epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'dopamine', 'milrinone', 'vent']].fillna(0)

In [185]:
data_merged.head()

Unnamed: 0,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,map,pao2,creatinine,hr,fio2,lactate,platelets,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,ventilation_status,vent,seq_num
0,2160-05-18 10:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,421.0,,,,3.3,,0.0,0.0,0.0,0.0,0.0,0.0,,0,1
1,2160-05-18 11:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,,,,,2.8,,0.0,0.0,0.0,0.0,0.0,0.0,,0,2
2,2160-05-18 12:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,384.0,,,,3.1,,0.0,0.0,0.0,0.0,0.0,0.0,,0,3
3,2160-05-18 13:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,311.0,,,,2.6,252.0,0.0,0.0,0.0,0.0,0.0,0.0,,0,4
4,2160-05-18 14:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,94.0,332.0,1.1,80.0,100.0,,254.0,0.0,0.0,0.701,0.0,0.0,0.0,InvasiveVent,1,5


**Merging data with OXYGEN DELIVERY FIO2 data from mimic IV derived table** 

In [186]:
# Importing FiO2 data from Mimic IV derived data - this specifically includes the Fio2 for when patients were on the vent. 
o2_delivery = pd.read_csv('../data/datasets/oxygen_delivery_cabg.csv')
o2_delivery['time_bucket'] = pd.to_datetime(o2_delivery['charttime']).dt.floor('H') # create 1hour time buckets
o2_delivery.drop('charttime', axis = 1, inplace = True)

# Keeping only rows where there is an oxygen delivery device 
o2_delivery = o2_delivery[~ (o2_delivery['o2_delivery_device_1'] == 'None')]
o2_delivery = o2_delivery[~(o2_delivery.o2_delivery_device_1.isna())]

# Dropping endotrachel tube
o2_delivery = o2_delivery[~ (o2_delivery['o2_delivery_device_1'] == 'Endotracheal tube')]

# Defining maximal flow rates for each device
max_flow_rates = {
    'Nasal cannula': 6,
    'Aerosol-cool': 15,
    'Face tent': 15,
    'High flow nasal cannula': 60,
    'Trach mask': 15,
    'High flow neb': 15,
    'CPAP mask': 60,
    'Bipap mask': 60,
    'Non-rebreather': 15,
    'Oxymizer': 15,
    'Medium conc mask': 10,
    'Other': np.nan,  # Variable; keep as NaN or assign a default if known
    'Venti mask': 15,
    'T-piece': 15,
    'Tracheostomy tube': 60,
    'Ultrasonic neb': 15}

# Capping flow rates with values capped by max flow rates
o2_delivery['o2_flow'] = o2_delivery.apply(lambda row: np.nan if row['o2_flow'] > max_flow_rates.get(row['o2_delivery_device_1'], np.nan) else row['o2_flow'], axis=1)
o2_delivery['o2_flow_additional'] = o2_delivery.apply(lambda row: np.nan if row['o2_flow_additional'] > max_flow_rates.get(row['o2_delivery_device_1'], np.nan) else row['o2_flow_additional'], axis=1)

# Imputing Fio2 based on flow rate and oxygen delivery device. Based on this paper: https://static-content.springer.com/esm/art%3A10.1038%2Fs41598-019-38491-0/MediaObjects/41598_2019_38491_MOESM1_ESM.pdf
def calculate_fio2(row):
    # Use the maximum of the two flow rates, or one if the other is NaN
    o2_flow = row['o2_flow'] if not np.isnan(row['o2_flow']) else 0
    o2_flow_additional = row['o2_flow_additional'] if not np.isnan(row['o2_flow_additional']) else 0
    max_flow = max(o2_flow, o2_flow_additional)
    
    # Apply FiO2 formulas based on device type
    device = row['o2_delivery_device_1']
    if device == 'Aerosol-cool':
        return min(60, 21 + (max_flow * 4))
    if device == 'Face tent':
        return min(40, 21 + (max_flow * 4) if max_flow > 0 else 25)
    elif device == 'Nasal cannula':
        return min(40, 21 + (max_flow * 4))
    elif device == 'High flow nasal cannula':
        return min(100, 48 + ((max_flow - 6) * 2) if max_flow >= 6 else 48)
    elif device == 'CPAP mask':
        return 40
    elif device == 'BIPAP mask':
        return 40    
    elif device == 'Non-rebreather':
        return min(100, 80 + (min((max_flow - 10), 2) * 10))
    elif device == 'Venti mask':
        return min(55, 26 + ((max_flow - 4) * 2.5) if max_flow >= 4 else 26)
    elif device == 'Oxymizer':
        return 40  # Default FiO2 for Oxymizer
    elif device == 'T-piece':
        return 40  # Default FiO2 for Trach mask
    elif device == 'Tracheostomy tube':
        return 40  # Default FiO2 for Trach mask
    else:
        return np.nan  # Return NaN if device type is unknown or not specified

# Apply the function to calculate FiO2
o2_delivery['fio2'] = o2_delivery.apply(calculate_fio2, axis=1)

o2_delivery.drop(['subject_id', 'o2_flow', 'o2_flow_additional'], axis = 1 ,inplace = True)
o2_delivery.sort_values(by=['stay_id','time_bucket'], inplace = True)

# Group by 'stay_id' and 'time_bucket' and apply aggregation
o2_delivery = o2_delivery.groupby(['stay_id', 'time_bucket']).agg({'fio2': 'mean', 'o2_delivery_device_1': 'first'}).reset_index()

# Dropping missing rows
o2_delivery = o2_delivery[~o2_delivery.fio2.isna()]

o2_delivery

Unnamed: 0,stay_id,time_bucket,fio2,o2_delivery_device_1
0,30001148,2156-08-30 18:00:00,40.0,Face tent
1,30001148,2156-08-30 20:00:00,40.0,Face tent
2,30001148,2156-08-31 00:00:00,37.0,Nasal cannula
3,30001148,2156-08-31 04:00:00,29.0,Nasal cannula
4,30004530,2165-08-01 05:00:00,60.0,Aerosol-cool
...,...,...,...,...
74601,39995735,2124-08-21 16:00:00,29.0,Nasal cannula
74602,39995735,2124-08-21 20:00:00,29.0,Nasal cannula
74603,39995735,2124-08-22 00:00:00,33.0,Nasal cannula
74604,39995735,2124-08-22 04:00:00,33.0,Nasal cannula


In [187]:
# Merging dataset
data_merged = pd.merge(data_merged, o2_delivery, on = ['stay_id', 'time_bucket'], how = 'left', suffixes = ('_original', '_delivery'))

# Replace missing values in 'fio2_original' with 'fio2_delivery' if 'fio2_delivery' is not missing
data_merged['fio2'] = np.where(data_merged['fio2_original'].isna() & ~data_merged['fio2_delivery'].isna(), data_merged['fio2_delivery'], data_merged['fio2_original'])

# Dropping columns
data_merged.drop(['fio2_delivery', 'fio2_original'], axis = 1, inplace = True)

In [188]:
data_merged

Unnamed: 0,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,map,pao2,creatinine,hr,lactate,platelets,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,ventilation_status,vent,seq_num,o2_delivery_device_1,fio2
0,2160-05-18 10:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,421.0,,,3.3,,0.0,0.0,0.000,0.0,0.0,0.0,,0,1,,
1,2160-05-18 11:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,,,,2.8,,0.0,0.0,0.000,0.0,0.0,0.0,,0,2,,
2,2160-05-18 12:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,384.0,,,3.1,,0.0,0.0,0.000,0.0,0.0,0.0,,0,3,,
3,2160-05-18 13:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,311.0,,,2.6,252.0,0.0,0.0,0.000,0.0,0.0,0.0,,0,4,,
4,2160-05-18 14:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,94.0,332.0,1.1,80.0,,254.0,0.0,0.0,0.701,0.0,0.0,0.0,InvasiveVent,1,5,,100.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319026,2185-02-03 10:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,104.0,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0,20,,
319027,2185-02-03 11:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,82.0,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0,21,,
319028,2185-02-03 12:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,85.0,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0,22,Nasal cannula,33.0
319029,2185-02-03 13:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,79.0,,,0.0,0.0,0.000,0.0,0.0,0.0,,0,23,,


**Merging data with ventilator FiO2 data**

In [189]:
# Importing FiO2 data from Mimic IV derived data - this specifically includes the Fio2 for when patients were on the vent. 
fio2 = pd.read_csv('../data/datasets/fio2_vent.csv')
fio2['time_bucket'] = pd.to_datetime(fio2['charttime']).dt.floor('H') # create 1hour time buckets
fio2.drop(['subject_id', 'flow_rate','ventilator_mode', 'ventilator_mode_hamilton','ventilator_type', 'charttime'], axis = 1 ,inplace = True)
fio2.sort_values(by=['stay_id','time_bucket'], inplace = True)

# Group by 'stay_id' and 'time_bucket' and apply aggregation
fio2 = fio2.groupby(['stay_id', 'time_bucket']).mean().reset_index()

# Dropping missing rows
fio2 = fio2[~fio2.fio2.isna()]

In [190]:
# Merging
data_merged = pd.merge(data_merged, fio2, on = ['stay_id', 'time_bucket'], how = 'left', suffixes=('_original', '_vent'))

# Replace old FiO2 data with newly merged data - This will overwrite old fio2 values (even those from oxygen delivery table)
data_merged['fio2_original'] = np.where(~data_merged['fio2_vent'].isna(), data_merged['fio2_vent'], data_merged['fio2_original'])
data_merged.drop('fio2_vent', axis = 1, inplace = True)
data_merged.rename(columns = {'fio2_original': 'fio2'}, inplace = True)
data_merged

Unnamed: 0,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,map,pao2,creatinine,hr,lactate,platelets,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,ventilation_status,vent,seq_num,o2_delivery_device_1,fio2
0,2160-05-18 10:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,421.0,,,3.3,,0.0,0.0,0.000,0.0,0.0,0.0,,0,1,,
1,2160-05-18 11:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,,,,2.8,,0.0,0.0,0.000,0.0,0.0,0.0,,0,2,,
2,2160-05-18 12:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,384.0,,,3.1,,0.0,0.0,0.000,0.0,0.0,0.0,,0,3,,
3,2160-05-18 13:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,,311.0,,,2.6,252.0,0.0,0.0,0.000,0.0,0.0,0.0,,0,4,,
4,2160-05-18 14:00:00,39060235,53,Medicare,Other,7,SINGLE,F,False,94.0,332.0,1.1,80.0,,254.0,0.0,0.0,0.701,0.0,0.0,0.0,InvasiveVent,1,5,,100.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319026,2185-02-03 10:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,104.0,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0,20,,
319027,2185-02-03 11:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,82.0,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0,21,,
319028,2185-02-03 12:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,85.0,,,0.0,0.0,0.000,0.0,0.0,0.0,SupplementalOxygen,0,22,Nasal cannula,33.0
319029,2185-02-03 13:00:00,34995866,66,Medicare,White,3,DIVORCED,M,False,,,,79.0,,,0.0,0.0,0.000,0.0,0.0,0.0,,0,23,,


Checking correctness of ventilation_status by cross-tabulating with O2 delivery device variable

In [191]:
pd.crosstab(data_merged['ventilation_status'].fillna('Missing'), data_merged['o2_delivery_device_1'].fillna('Missing'), dropna=False, margins = True).T

ventilation_status,HFNC,InvasiveVent,NonInvasiveVent,None,SupplementalOxygen,Tracheostomy,All
o2_delivery_device_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
Aerosol-cool,1,1355,7,211,3077,18,4669
Bipap mask,0,0,0,1,0,0,1
CPAP mask,0,1,0,0,4,0,5
Face tent,1,508,8,93,1528,0,2138
High flow nasal cannula,824,4,12,39,36,0,915
High flow neb,0,1,0,0,2,0,3
Missing,1528,73384,1454,116195,84854,2260,279675
Nasal cannula,55,1118,102,1527,28544,3,31349
Non-rebreather,2,3,2,11,105,0,123
Oxymizer,11,0,0,0,89,0,100


In [192]:
# Need to recode rows in `ventilation_status` from None to SupplementalOxygen if O2 deliver_device_1 is not missing
data_merged['ventilation_status'] = np.where((data_merged['ventilation_status'] == 'None') & (~data_merged['o2_delivery_device_1'].isna()), 'SupplementalOxygen', data_merged['ventilation_status'])

**Recoding Outlier Values**

In [193]:
# Recoding MAP 
def recode_map(value):
    if value <= 0:
        return np.NaN
    elif value>=220:
        return np.NaN
    else:
        return value
    
data_merged['map'] = data_merged['map'].apply(recode_map)

# Recoding PAO2 
def recode_pao2(value):
    if value <= 0:
        return np.NaN
    elif value>=500:
        return np.NaN
    else:
        return value
    
data_merged['pao2'] = data_merged['pao2'].apply(recode_pao2)

# Recoding creatinine 
def recode_cr(value):
    if value <= 0:
        return np.NaN
    elif value>=999999:
        return np.NaN
    else:
        return value
    
data_merged['creatinine'] = data_merged['creatinine'].apply(recode_cr)

# Recoding FiO2 
def recode_fio2(value):
    if value <= 21:
        return 21
    else:
        return value
    
data_merged['fio2'] = data_merged['fio2'].apply(recode_fio2)

# No Need to recode HR

# Recoding Platelet Count & lactic acid (can use the same code as creatinine)
data_merged['platelets'] = data_merged['platelets'].apply(recode_cr)
data_merged['lactate'] = data_merged['lactate'].apply(recode_cr)

In [194]:
# Sort by Unique Stay and sequence number to get events in correct order
data_merged.sort_values(by=['stay_id', 'time_bucket', 'seq_num'], inplace=True)

In [195]:
data_merged.head()

Unnamed: 0,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,map,pao2,creatinine,hr,lactate,platelets,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,ventilation_status,vent,seq_num,o2_delivery_device_1,fio2
76574,2165-07-31 12:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,,305.0,,,,,0.0,0.0,0.0,0.0,0.0,0.0,,0,1,,
76575,2165-07-31 13:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,73.333333,300.0,,71.0,,,0.0,0.0,1.0,0.0,0.0,0.0,InvasiveVent,1,2,,100.0
76576,2165-07-31 14:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,,,,84.0,,,0.0,0.0,0.0,0.0,0.0,0.0,InvasiveVent,1,3,,40.0
76577,2165-07-31 15:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,,95.0,,83.0,,,0.0,0.0,0.0,0.0,0.0,0.0,InvasiveVent,1,4,,40.0
76578,2165-07-31 16:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,65.0,96.0,,71.0,,,0.0,0.0,0.0,0.0,0.0,0.0,InvasiveVent,1,5,,


**Importing GCS Data**

In [196]:
gcs = pd.read_csv('../data/datasets/gcs_cabg.csv')
gcs['time_bucket'] = pd.to_datetime(gcs['charttime']).dt.floor('H')
gcs.drop('charttime', axis = 1, inplace = True)
gcs = gcs.groupby(['stay_id', 'time_bucket']).min().reset_index() # Keeping lowest GCS in an hour

# Merging with main data
data_merged = pd.merge(data_merged, gcs, on = ['stay_id', 'time_bucket'], how = 'left')

**HANDLING MISSING LAB/VITAL SIGNS VALUES**

*Forward/Backward filling of variables*: 

* Average between the preceding and succeeding value. 
* If no precedine value, apply backward filling. 
* If no next/succeeding value, apply forward filling

In [197]:
df = data_merged[['stay_id','map','pao2','creatinine','hr','fio2','lactate','platelets', 'gcs']]
df.head()

Unnamed: 0,stay_id,map,pao2,creatinine,hr,fio2,lactate,platelets,gcs
0,30004530,,305.0,,,,,,
1,30004530,73.333333,300.0,,71.0,100.0,,,
2,30004530,,,,84.0,40.0,,,15.0
3,30004530,,95.0,,83.0,40.0,,,
4,30004530,65.0,96.0,,71.0,,,,15.0


In [198]:
def handle_missing_values_optimized(df, id_column='stay_id'):
    """
    Optimized version of missing value handler for lab/vital signs data.
    Uses vectorized operations for better performance.
    
    Parameters:
    -----------
    df : pandas.DataFrame
        Input dataframe containing the time series data
    id_column : str, default='stay_id'
        Column name containing the patient/stay identifier
    
    Returns:
    --------
    pandas.DataFrame
        DataFrame with imputed values
    """
    # Create a copy to avoid modifying the original dataframe
    df_processed = df.copy()
    
    # Get all columns except the ID column
    value_columns = [col for col in df.columns if col != id_column]
    
    # Process each group using vectorized operations
    def process_group_vectorized(group):
        for col in value_columns:
            # Skip if column has no missing values
            if not group[col].isna().any():
                continue
            
            # Get series of non-null values and their indices
            non_null = group[col].dropna()
            
            if len(non_null) == 0:
                continue
                
            # Get indices of missing and non-null values
            missing_idx = group[col].isna()
            
            # Calculate forward and backward fill values
            ffill_values = group[col].ffill()
            bfill_values = group[col].bfill()
            
            # Create arrays for interpolation
            x = np.arange(len(group))
            x_non_null = x[~missing_idx]
            y_non_null = group[col][~missing_idx]
            
            if len(x_non_null) > 1:
                # Use linear interpolation for values between known points
                interpolated = np.interp(
                    x[missing_idx],
                    x_non_null,
                    y_non_null
                )
                
                # Assign interpolated values
                group.loc[missing_idx, col] = interpolated
            else:
                # If only one non-null value exists, use forward/backward fill
                group.loc[missing_idx, col] = np.where(
                    pd.notna(ffill_values[missing_idx]),
                    ffill_values[missing_idx],
                    bfill_values[missing_idx]
                )
        
        return group
    
    # Apply the vectorized processing to each group
    df_processed = df_processed.groupby(id_column, group_keys=False).apply(process_group_vectorized)
    
    return df_processed

# Creating imputed dataset for select variables
df_imputed = handle_missing_values_optimized(df)

# Replace in original dataset
data_imputed = data_merged.copy()
data_imputed[['stay_id','map','pao2','creatinine','hr','fio2','lactate','platelets', 'gcs']] = df_imputed

In [199]:
# Imputing FiO2 values based on ventilation status (fio2 = 21% if no ventilation)
data_imputed['fio2'] = np.where(data_imputed['ventilation_status'] == 'None', 21, data_imputed['fio2'])

# Dropping ventilation variable
data_imputed.drop(['ventilation_status', 'o2_delivery_device_1'], axis = 1, inplace = True)

In [200]:
data_imputed

Unnamed: 0,time_bucket,stay_id,anchor_age,insurance,race,charlson,marital_status,gender,mortality,map,pao2,creatinine,hr,lactate,platelets,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,vent,seq_num,fio2,gcs
0,2165-07-31 12:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,73.333333,305.0,1.000000,71.0,1.3,141.0,0.0,0.0,0.0,0.0,0.0,0.0,0,1,21.0,15.0
1,2165-07-31 13:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,73.333333,300.0,1.000000,71.0,1.3,141.0,0.0,0.0,1.0,0.0,0.0,0.0,1,2,100.0,15.0
2,2165-07-31 14:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,70.555556,197.5,1.000000,84.0,1.3,141.0,0.0,0.0,0.0,0.0,0.0,0.0,1,3,40.0,15.0
3,2165-07-31 15:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,67.777778,95.0,1.000000,83.0,1.3,141.0,0.0,0.0,0.0,0.0,0.0,0.0,1,4,40.0,15.0
4,2165-07-31 16:00:00,30004530,63,Medicare,White,5,DIVORCED,M,False,65.000000,96.0,1.000000,71.0,1.3,141.0,0.0,0.0,0.0,0.0,0.0,0.0,1,5,40.0,15.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319026,2124-08-27 15:00:00,39995735,62,Other,White,5,MARRIED,M,False,63.000000,98.0,1.197297,70.0,0.8,269.0,0.0,0.0,0.0,0.0,0.0,0.0,0,261,21.0,15.0
319027,2124-08-27 16:00:00,39995735,62,Other,White,5,MARRIED,M,False,65.000000,98.0,1.200000,70.0,0.8,269.0,0.0,0.0,0.0,0.0,0.0,0.0,0,262,21.0,15.0
319028,2124-08-27 17:00:00,39995735,62,Other,White,5,MARRIED,M,False,60.000000,98.0,1.200000,70.0,0.8,269.0,0.0,0.0,0.0,0.0,0.0,0.0,0,263,21.0,15.0
319029,2124-08-27 18:00:00,39995735,62,Other,White,5,MARRIED,M,False,60.000000,98.0,1.200000,70.0,0.8,269.0,0.0,0.0,0.0,0.0,0.0,0.0,0,264,21.0,15.0


If patents still have missing fio2 data, even after grouping, averaging, forward, and backward filling, we will recode Fio2 to be equal to 21% if patient is not of ventilator

Spot checking imputed dataset - looking more closely at rows/values that were not imputed

In [201]:
# Missing values for each column
df_imputed.isnull().sum().to_frame().T

Unnamed: 0,stay_id,map,pao2,creatinine,hr,fio2,lactate,platelets,gcs
0,0,821,430,3425,17,45,14921,1299,819


In [202]:
df_imputed.groupby('stay_id').apply(lambda x: x.isna().all()).sum().to_frame().T

Unnamed: 0,stay_id,map,pao2,creatinine,hr,fio2,lactate,platelets,gcs
0,0,23,9,119,1,3,368,42,10


In [203]:
# Get the number of patients that have missing values for ALL of the columns map, pao2, creatinine, hr, lactate, and platelet
df_imputed[df_imputed.loc[:, df_imputed.columns != 'stay_id'].isnull().all(axis=1)]['stay_id'].nunique()

0

In [204]:
# Get the number of patients that have missing values for ANY the columns map, pao2, creatinine, hr, lactate, and platelet
df_imputed[df_imputed.loc[:, df_imputed.columns != 'stay_id'].isnull().any(axis=1)]['stay_id'].nunique()

535

In [205]:
df_imputed['stay_id'].nunique()

4451

In [206]:
# Rearranging columns in a way that makes sense
data_imputed = data_imputed[
    [
        'stay_id', 'time_bucket', 'seq_num',   # Identifiers (add 'subject_id' if available)
        'anchor_age', 'gender', 'race', 'marital_status', 'insurance',  # Demographics
        'vent', 'charlson', # Clinical status
        'map', 'hr', 'pao2', 'fio2', 'creatinine', 'lactate', 'platelets', 'gcs', # Vital signs & lab results
        'epinephrine', 'norepinephrine', 'phenylephrine', 'dobutamine', 'milrinone', 'dopamine',  # Medications
        'mortality' # Outcomes
    ]
]


I will create 2 datasets:   
    1. First dataset will exclude patients that have missing values for ANY of the columns map, pao2, creatinine, hr, lactate, and platelet  
    2. The second dataset will impute missing values for these variables with normal physiologic values

In [207]:
# Complete data DF
complete_final_df = data_imputed.copy()
complete_final_df = complete_final_df[~df_imputed.loc[:, df_imputed.columns != 'stay_id'].isnull().any(axis=1)]
complete_final_df

Unnamed: 0,stay_id,time_bucket,seq_num,anchor_age,gender,race,marital_status,insurance,vent,charlson,map,hr,pao2,fio2,creatinine,lactate,platelets,gcs,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,mortality
0,30004530,2165-07-31 12:00:00,1,63,M,White,DIVORCED,Medicare,0,5,73.333333,71.0,305.0,21.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
1,30004530,2165-07-31 13:00:00,2,63,M,White,DIVORCED,Medicare,1,5,73.333333,71.0,300.0,100.0,1.000000,1.3,141.0,15.0,0.0,0.0,1.0,0.0,0.0,0.0,False
2,30004530,2165-07-31 14:00:00,3,63,M,White,DIVORCED,Medicare,1,5,70.555556,84.0,197.5,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
3,30004530,2165-07-31 15:00:00,4,63,M,White,DIVORCED,Medicare,1,5,67.777778,83.0,95.0,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
4,30004530,2165-07-31 16:00:00,5,63,M,White,DIVORCED,Medicare,1,5,65.000000,71.0,96.0,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319026,39995735,2124-08-27 15:00:00,261,62,M,White,MARRIED,Other,0,5,63.000000,70.0,98.0,21.0,1.197297,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
319027,39995735,2124-08-27 16:00:00,262,62,M,White,MARRIED,Other,0,5,65.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
319028,39995735,2124-08-27 17:00:00,263,62,M,White,MARRIED,Other,0,5,60.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
319029,39995735,2124-08-27 18:00:00,264,62,M,White,MARRIED,Other,0,5,60.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False


In [208]:
# Imputed data
imputed_final_df = data_imputed.copy()

# Fill missing values with fixed physiologic values
imputed_final_df['map'] = imputed_final_df['map'].fillna(75)
imputed_final_df['hr'] = imputed_final_df['hr'].fillna(70)
imputed_final_df['pao2'] = imputed_final_df['pao2'].fillna(100)
imputed_final_df['creatinine'] = imputed_final_df['creatinine'].fillna(0.8)
imputed_final_df['lactate'] = imputed_final_df['lactate'].fillna(0.7)
imputed_final_df['gcs'] = imputed_final_df['gcs'].fillna(15)

imputed_final_df

Unnamed: 0,stay_id,time_bucket,seq_num,anchor_age,gender,race,marital_status,insurance,vent,charlson,map,hr,pao2,fio2,creatinine,lactate,platelets,gcs,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,mortality
0,30004530,2165-07-31 12:00:00,1,63,M,White,DIVORCED,Medicare,0,5,73.333333,71.0,305.0,21.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
1,30004530,2165-07-31 13:00:00,2,63,M,White,DIVORCED,Medicare,1,5,73.333333,71.0,300.0,100.0,1.000000,1.3,141.0,15.0,0.0,0.0,1.0,0.0,0.0,0.0,False
2,30004530,2165-07-31 14:00:00,3,63,M,White,DIVORCED,Medicare,1,5,70.555556,84.0,197.5,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
3,30004530,2165-07-31 15:00:00,4,63,M,White,DIVORCED,Medicare,1,5,67.777778,83.0,95.0,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
4,30004530,2165-07-31 16:00:00,5,63,M,White,DIVORCED,Medicare,1,5,65.000000,71.0,96.0,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319026,39995735,2124-08-27 15:00:00,261,62,M,White,MARRIED,Other,0,5,63.000000,70.0,98.0,21.0,1.197297,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
319027,39995735,2124-08-27 16:00:00,262,62,M,White,MARRIED,Other,0,5,65.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
319028,39995735,2124-08-27 17:00:00,263,62,M,White,MARRIED,Other,0,5,60.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False
319029,39995735,2124-08-27 18:00:00,264,62,M,White,MARRIED,Other,0,5,60.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,False


In [209]:
print(f"Total number of unique patients in imputed dataset: {imputed_final_df.stay_id.nunique()}, and number of deaths: {imputed_final_df.groupby('stay_id').mortality.any().sum()} ({round((100*imputed_final_df.groupby('stay_id').mortality.any().sum())/imputed_final_df.stay_id.nunique(),2)}%)")
print(f"Total number of unique patients in compelte dataset: {complete_final_df.stay_id.nunique()}, and number of deaths: {complete_final_df.groupby('stay_id').mortality.any().sum()} ({round((100*complete_final_df.groupby('stay_id').mortality.any().sum())/complete_final_df.stay_id.nunique(),2)}%)")

Total number of unique patients in imputed dataset: 4451, and number of deaths: 58 (1.3%)
Total number of unique patients in compelte dataset: 3916, and number of deaths: 57 (1.46%)


## Data Augmentation

In [227]:
import deepecho
from sdv.sequential import PARSynthesizer
from sdv.metadata import Metadata

In [339]:
df_complete = complete_final_df.drop('seq_num', axis = 1).copy()

# Creating metadata

# Auto-detecting metdata
metadata = Metadata.detect_from_dataframes(data = {'df': df_complete})
metadata.set_sequence_index(column_name='time_bucket')
metadata.update_column(column_name='mortality', sdtype='boolean')
metadata.update_column(column_name='stay_id', sdtype='id',  regex_format = "[0-9]{8}")
metadata.update_column(column_name='gcs', sdtype = 'numerical', computer_representation='Int64')
metadata.update_column(column_name='charlson', sdtype='numerical', computer_representation='Int64')
metadata.set_sequence_key(column_name='stay_id')
metadata.validate()

# Creating synethetizer
synthesizer = PARSynthesizer(
    metadata,
    enforce_min_max_values = True,
    enforce_rounding = True,
    context_columns=['anchor_age', 'gender', 'race', 'marital_status', 'insurance', 'charlson', 'mortality'],
    verbose = True,
    cuda = True,
    epochs= 512,
)

# # Adding constraints
# charlson_constraint = {
#     'constraint_class': 'ScalarRange',
#     'constraint_parameters': {
#         'column_name': 'charlson',
#         'low_value': 0,
#         'high_value': 24,
#         'strict_boundaries': False}
# }

# synthesizer.add_constraints(constraints=[charlson_constraint])



In [None]:
synthesizer.fit(df_complete)

Loss (-0.073):  87%|████████▋ | 447/512 [8:40:44<1:16:31, 70.63s/it]

In [255]:
data_types = {
    'stay_id': 'categorical',           # Changed from set to string
    #'time_bucket': 'datetime',
    'anchor_age': 'continuous',
    'gender': 'categorical',
    'race': 'categorical',
    'marital_status': 'categorical',
    'insurance': 'categorical',
    'vent': 'categorical',
    'charlson': 'categorical',
    'mortality': 'categorical',
    'map': 'continuous',
    'hr': 'continuous',
    'pao2': 'continuous',
    'fio2': 'continuous',
    'creatinine': 'continuous',
    'platelets': 'continuous',
    'lactate': 'continuous',
    'gcs': 'continuous',
    'epinephrine': 'continuous',
    'norepinephrine': 'continuous',
    'phenylephrine': 'continuous',
    'dobutamine': 'continuous',
    'milrinone': 'continuous',
    'dopamine': 'continuous'
}

# Adding constraints
model.add_constraints(
    constraints=[""]
)

model = deepecho.PARModel(epochs=10, cuda=True)
model.fit(
    data=df_complete,
    data_types=data_types,
    entity_columns=['stay_id'],        # Changed to list
    context_columns=['anchor_age', 'gender', 'race', 'marital_status', 'insurance','charlson', 'mortality'],
    sequence_index='time_bucket'
)


Loss (0.023): 100%|██████████| 10/10 [11:16<00:00, 67.64s/it]


In [257]:
synthetic_data = model.sample(1000)
synthetic_data.head()

100%|██████████| 1000/1000 [03:27<00:00,  4.82it/s]


Unnamed: 0,stay_id,anchor_age,gender,race,marital_status,insurance,vent,charlson,map,hr,pao2,fio2,creatinine,lactate,platelets,gcs,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,mortality
0,0,75,M,White,MARRIED,Medicare,0,9,74.047609,83.169117,125.931684,28.27613,3.041878,28.635222,174.337555,14.741282,-0.007124,0.039957,0.091275,0.302379,0.01438,0.032075,False
1,0,75,M,White,MARRIED,Medicare,1,9,91.953187,83.169117,136.063508,34.687484,2.513777,28.635222,174.337555,14.254432,0.002187,-0.055327,-0.271954,0.039425,0.01438,0.032075,False
2,0,75,M,White,MARRIED,Medicare,0,9,85.063773,83.169117,136.063508,27.336356,2.376842,28.635222,150.721411,15.242582,-0.009501,0.008452,0.535625,-0.034628,0.196989,-0.08919,False
3,0,75,M,White,MARRIED,Medicare,1,9,65.963619,83.169117,86.559377,34.687484,1.568552,28.635222,174.337555,16.256231,0.002187,0.008649,0.108718,0.657287,-0.01217,0.032075,False
4,0,75,M,White,MARRIED,Medicare,0,9,74.047609,83.169117,219.234408,34.687484,1.568552,3574.394318,79.161357,14.694727,0.002187,-0.020644,-0.135681,0.039425,-0.002262,0.192302,False


In [277]:
# Number of deaths in synthetic data
synthetic_data.groupby('stay_id').mortality.any().sum()

58

In [282]:
complete_final_df.race.describe()

count     298389
unique         6
top        White
freq      254212
Name: race, dtype: object

## Model Building

In [107]:
# Creating a copy of the complete data to work with
model_df_complete = complete_final_df.copy()

Encoding categorical columns

In [108]:
label_encoder = LabelEncoder()

# Encode categorical columns
categorical_columns = ['insurance', 'gender', 'race', 'marital_status', 'vent', 'charlson']
for col in categorical_columns:
    model_df_complete[col] = label_encoder.fit_transform(model_df_complete[col].astype(str))

In [109]:
# Encode target variable `Died` (binary classification: 0 = Alive, 1 = Died)
model_df_complete['mortality'] = label_encoder.fit_transform(model_df_complete['mortality'])

In [110]:
model_df_complete

Unnamed: 0,stay_id,time_bucket,seq_num,anchor_age,gender,race,marital_status,insurance,vent,charlson,map,hr,pao2,fio2,creatinine,lactate,platelets,gcs,epinephrine,norepinephrine,phenylephrine,dobutamine,milrinone,dopamine,mortality
0,30004530,2165-07-31 12:00:00,1,63,1,5,0,1,0,10,73.333333,71.0,305.0,21.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0
1,30004530,2165-07-31 13:00:00,2,63,1,5,0,1,1,10,73.333333,71.0,300.0,100.0,1.000000,1.3,141.0,15.0,0.0,0.0,1.0,0.0,0.0,0.0,0
2,30004530,2165-07-31 14:00:00,3,63,1,5,0,1,1,10,70.555556,84.0,197.5,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0
3,30004530,2165-07-31 15:00:00,4,63,1,5,0,1,1,10,67.777778,83.0,95.0,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0
4,30004530,2165-07-31 16:00:00,5,63,1,5,0,1,1,10,65.000000,71.0,96.0,40.0,1.000000,1.3,141.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
319026,39995735,2124-08-27 15:00:00,261,62,1,5,1,2,0,10,63.000000,70.0,98.0,21.0,1.197297,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0
319027,39995735,2124-08-27 16:00:00,262,62,1,5,1,2,0,10,65.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0
319028,39995735,2124-08-27 17:00:00,263,62,1,5,1,2,0,10,60.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0
319029,39995735,2124-08-27 18:00:00,264,62,1,5,1,2,0,10,60.000000,70.0,98.0,21.0,1.200000,0.8,269.0,15.0,0.0,0.0,0.0,0.0,0.0,0.0,0


In [111]:
model_df_complete.drop(columns = ['time_bucket'], inplace = True)

In [112]:
# Create sequences of features for each Unique Stay
sequence_data = []
sequence_labels = []

unique_stays = model_df_complete['stay_id'].unique()

In [113]:
Stop

NameError: name 'Stop' is not defined

Iterate over unique stays

In [None]:
for stay in unique_stays:
    stay_data = model_df_complete[model_df_complete['stay_id'] == stay]
    features = stay_data[["anchor_age", "gender", "race", "marital_status", "insurance", "charlson", "vent", "map", "hr", "pao2", "fio2", "gcs", "creatinine", "lactate", "platelets", "epinephrine", "norepinephrine", "phenylephrine", "dobutamine", "milrinone", "dopamine"]].values
    label = stay_data['mortality'].values[-1]  # Use the last event to define the label (it doesn't really matter if we're looking to build classification model, but for survival model, it does!)
    
    # Add the sequence and its corresponding label
    sequence_data.append(features)
    sequence_labels.append(label)

Pad sequences to ensure uniform input length

In [None]:
import keras

In [None]:
# Summary of number of rows per patient
model_df_complete.groupby('stay_id').seq_num.count().quantile([0.1, 0.25, 0.5, 0.75, 0.90, 0.99])

0.10     24.0
0.25     27.0
0.50     45.0
0.75     77.0
0.90    143.5
0.99    584.1
Name: seq_num, dtype: float64

In [None]:
# from tensorflow.keras.utils import pad_sequences
# Padding sequences based on 90th percentile of stay duration
# Padding with values of -1
sequence_data = keras.utils.pad_sequences(sequence_data, padding='post', maxlen=150, value = -1)  # Adjust maxlen as needed

In [None]:
# Convert labels to numpy array
sequence_labels = np.array(sequence_labels)

Train-test split

In [None]:
X_train, X_test, y_train, y_test = train_test_split(sequence_data, sequence_labels, test_size=0.2, random_state=42)

Define custom wrapper for the Keras model to use with RandomizedSearchCV

In [None]:
from sklearn.base import BaseEstimator, ClassifierMixin
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Masking
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Precision, Recall, AUC
import numpy as np

In [None]:
class KerasModelWrapper(BaseEstimator, ClassifierMixin):
    def __init__(self, learning_rate=0.001, lstm_units=64, dropout_rate=0.2, epochs=10, batch_size=32):
        self.learning_rate = learning_rate
        self.lstm_units = lstm_units
        self.dropout_rate = dropout_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.model = None
        self.classes_ = np.array([0, 1])  # Binary classification
        
    def create_model(self):
        model = Sequential([
            layers.InputLayer(shape=(X_train.shape[1], X_train.shape[2])),
            layers.Masking(mask_value=-1),
            layers.LSTM(self.lstm_units, return_sequences=False, dropout=self.dropout_rate),
            layers.Dense(1, activation='sigmoid')
        ])
        
        metrics = [
            'accuracy',
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc')
        ]
        
        model.compile(optimizer=Adam(learning_rate=self.learning_rate), 
                     loss='binary_crossentropy', 
                     metrics=metrics)
        return model
        
    def fit(self, X, y):
        # Store unique classes
        self.classes_ = np.unique(y)
        
        self.model = self.create_model()
        history = self.model.fit(
            X, y,
            epochs=self.epochs,
            batch_size=self.batch_size,
            verbose=0
        )
        return self
        
    def predict(self, X):
        """Return predicted class labels"""
        probas = self.model.predict(X, verbose=0)
        return (probas > 0.5).astype(int).ravel()
    
    def predict_proba(self, X):
        """Return probability estimates for both classes"""
        probas = self.model.predict(X, verbose=0)
        # Return probabilities for both classes (negative and positive)
        return np.hstack([1 - probas, probas])

    def score(self, X, y):
        """Return accuracy score"""
        scores = self.model.evaluate(X, y, verbose=0)
        return scores[1]  # Return accuracy
    
#     def predict(self, X):
#         return self.model.predict(X)
        
#     def score(self, X, y):
#         # Convert y to float32 here as well
#         # y = tf.cast(y)#, tf.float32)
#         scores = self.model.evaluate(X, y, verbose=0)
#         return dict(zip(self.model.metrics_names, scores))


Instantiate the model wrapper, define the parameters for grid tuning, and conduct RandomizedSearchCV to tune hyperparameters

In [None]:
from sklearn.model_selection import RandomizedSearchCV

param_distributions = {
    'learning_rate': [0.001, 0.01, 0.1],
    'lstm_units': [32, 64, 128],
    'dropout_rate': [0.2, 0.3, 0.5],
    'epochs': [10, 20, 50],
    'batch_size': [16, 32, 64]
}

# Create the wrapper model
model = KerasModelWrapper()

# Create RandomizedSearchCV with specific scoring metrics
random_search = RandomizedSearchCV(
    estimator=model,
    param_distributions=param_distributions,
    n_iter=10,
    cv=3,
    scoring={
        'accuracy': 'accuracy',
        'precision': 'precision',
        'recall': 'recall',
        'auc': 'roc_auc'
    },
    refit='accuracy'  # Choose which metric to use for selecting the best model
)

In [None]:
# Fit the random search
random_search_result = random_search.fit(X_train, y_train)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Best parameters and score

In [None]:
print("Best Hyperparameters:", random_search_result.best_params_)
print("Best Accuracy:", random_search_result.best_score_)

Best Hyperparameters: {'lstm_units': 128, 'learning_rate': 0.001, 'epochs': 50, 'dropout_rate': 0.2, 'batch_size': 32}
Best Accuracy: 0.9881864623243933


Evaluate the best model on the test set

In [None]:
best_model = random_search_result.best_estimator_
test_acc = best_model.score(X_test, y_test)
print(f'Test Accuracy: {test_acc:.4f}')

Test Accuracy: 0.9911


In [None]:
# Get all the scores for the best model
best_model_scores = random_search_result.cv_results_['mean_test_auc'][random_search_result.best_index_]
print(f"Best model AUC: {best_model_scores}")

# You can also get all metrics for the best model:
best_metrics = {
    'accuracy': random_search_result.cv_results_['mean_test_accuracy'][random_search_result.best_index_],
    'precision': random_search_result.cv_results_['mean_test_precision'][random_search_result.best_index_],
    'recall': random_search_result.cv_results_['mean_test_recall'][random_search_result.best_index_],
    'auc': random_search_result.cv_results_['mean_test_auc'][random_search_result.best_index_]
}

print("\nBest model metrics:")
for metric, value in best_metrics.items():
    print(f"{metric}: {value:.4f}")

# If you want to get the parameters that achieved these scores:
print("\nBest parameters:", random_search_result.best_params_)

Best model AUC: 0.8871190012970168

Best model metrics:
accuracy: 0.9882
precision: 0.9333
recall: 0.2500
auc: 0.8871

Best parameters: {'lstm_units': 128, 'learning_rate': 0.001, 'epochs': 50, 'dropout_rate': 0.2, 'batch_size': 32}


In [None]:
pd.set_option('display.max_colwidth', None)  # Display full content in each column

In [None]:
# Get all parameter combinations and their corresponding AUC scores
all_results = pd.DataFrame(random_search_result.cv_results_)
results_summary = all_results[['params', 'mean_test_auc', 'mean_test_accuracy']]
results_summary = results_summary.sort_values('mean_test_auc', ascending=False)
print("\nAll parameter combinations and their AUC scores:")
print(results_summary)


All parameter combinations and their AUC scores:
                                                                                             params  mean_test_auc  mean_test_accuracy
2  {'lstm_units': 128, 'learning_rate': 0.001, 'epochs': 50, 'dropout_rate': 0.2, 'batch_size': 32}       0.887119            0.988186
6   {'lstm_units': 128, 'learning_rate': 0.01, 'epochs': 10, 'dropout_rate': 0.3, 'batch_size': 32}       0.701625            0.984674
9    {'lstm_units': 32, 'learning_rate': 0.01, 'epochs': 20, 'dropout_rate': 0.2, 'batch_size': 32}       0.662988            0.984674
4    {'lstm_units': 32, 'learning_rate': 0.01, 'epochs': 20, 'dropout_rate': 0.5, 'batch_size': 32}       0.651406            0.984674
8    {'lstm_units': 64, 'learning_rate': 0.01, 'epochs': 50, 'dropout_rate': 0.3, 'batch_size': 16}       0.616438            0.984355
0   {'lstm_units': 128, 'learning_rate': 0.01, 'epochs': 10, 'dropout_rate': 0.2, 'batch_size': 16}       0.612070            0.984355
3    

Save the best model

In [None]:
best_model = random_search_result

Save the best model to a file

In [None]:
model_save_path = 'lstm_best_model.keras'
best_model.model.save(model_save_path)

AttributeError: 'RandomizedSearchCV' object has no attribute 'model'

In [None]:
print(f"Best model saved at: {model_save_path}")

Evaluate the best model on the test set

In [None]:
test_acc = best_model.score(X_test, y_test)
print(f'Test Accuracy: {test_acc:.4f}')

Load the saved model and re-evaluate on the test set

In [None]:
loaded_model = load_model(model_save_path)
loaded_test_loss, loaded_test_acc = loaded_model.evaluate(X_test, y_test, verbose=0)

In [None]:
print(f"Loaded Model Test Accuracy: {loaded_test_acc:.4f}")