In [61]:
# autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [62]:
import os
import sys
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

In [63]:
mimic_iv_path = "/data/wang/junh/datasets/physionet.org/files/mimiciv/2.2"
mm_dir = "/data/wang/junh/datasets/multimodal"

output_dir = os.path.join(mm_dir, "preprocessing")

In [64]:
restrict_48_hours = True
include_notes = True
include_cxr = True
include_ecg = False
standard_scale = True
include_missing = False

In [65]:
ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, "ts_labs_vitals.pkl"))
ireg_vitals_ts_df = ireg_vitals_ts_df.drop(columns=["hosp_time_delta"])
ireg_vitals_ts_df.rename(columns={'icu_time_delta': 'timedelta'}, inplace=True)
imputed_vitals = pd.read_pickle(os.path.join(output_dir, "imputed_ts_labs_vitals.pkl"))

# ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, "ts_labs_vitals_icu.pkl"))
# imputed_vitals = pd.read_pickle(os.path.join(output_dir, "imputed_ts_labs_vitals_icu.pkl"))
print(ireg_vitals_ts_df.columns)
print(imputed_vitals.columns)

Index(['subject_id', 'hadm_id', 'stay_id', 'timedelta', 'Diastolic BP',
       'GCS - Eye Opening', 'GCS - Motor Response', 'GCS - Verbal Response',
       'Heart Rate', 'Mean BP', 'O2 Saturation', 'Respiratory Rate',
       'Systolic BP'],
      dtype='object')
Index(['subject_id', 'hadm_id', 'stay_id', 'timedelta', 'Diastolic BP',
       'GCS - Eye Opening', 'GCS - Motor Response', 'GCS - Verbal Response',
       'Heart Rate', 'Mean BP', 'O2 Saturation', 'Respiratory Rate',
       'Systolic BP'],
      dtype='object')


In [66]:

ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] >= 0]
imputed_vitals = imputed_vitals[imputed_vitals['timedelta'] >= 0]

if restrict_48_hours:
    ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] <= 48]
    imputed_vitals = imputed_vitals[imputed_vitals['timedelta'] <= 48]

In [67]:
if include_notes:
    notes_df = pd.read_pickle(os.path.join(output_dir, "icu_notes_text_embeddings.pkl"))
    print(notes_df.columns)
    # notes_df = pd.read_pickle(os.path.join(output_dir, "notes_text.pkl"))
    notes_df = notes_df[notes_df['stay_id'].notnull()]

    notes_df = notes_df[notes_df['icu_time_delta'] >= 0]
    if restrict_48_hours:
        notes_df = notes_df[notes_df['icu_time_delta'] <= 48]

if include_cxr:
    cxr_df = pd.read_pickle(os.path.join(output_dir, "cxr_embeddings_stay.pkl"))
    print(cxr_df.columns)
    cxr_df = cxr_df[cxr_df['icu_time_delta'] >= 0]
    if restrict_48_hours:
        cxr_df = cxr_df[cxr_df['icu_time_delta'] <= 48]

if include_ecg:
    ecg_df = pd.read_pickle(os.path.join(output_dir, "ecg_embeddings_icu.pkl"))
    ecg_df = ecg_df[ecg_df['icu_time_delta'] >= 0]
    if restrict_48_hours:
        ecg_df = ecg_df[ecg_df['icu_time_delta'] <= 48]

Index(['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'text', 'stay_id', 'icu_time_delta',
       'hosp_time_delta', 'biobert_embeddings'],
      dtype='object')
Index(['dicom_id', 'subject_id', 'study_id',
       'PerformedProcedureStepDescription', 'ViewPosition',
       'ProcedureCodeSequence_CodeMeaning', 'ViewCodeSequence_CodeMeaning',
       'PatientOrientationCodeSequence_CodeMeaning', 'densefeatures',
       'predictions', 'cxrtime', 'hadm_id', 'stay_id', 'icu_time_delta',
       'hosp_time_delta'],
      dtype='object')


In [68]:
icustays_df = pd.read_csv(os.path.join(mimic_iv_path, "icu", "icustays.csv.gz"), low_memory=False)
icustays_df['intime'] = pd.to_datetime(icustays_df['intime'])
icustays_df['outtime'] = pd.to_datetime(icustays_df['outtime'])

if restrict_48_hours:
    icustays_df = icustays_df[icustays_df['los'] >= 2]

In [69]:
valid_stay_ids = icustays_df['stay_id'].unique()

ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(valid_stay_ids)]
imputed_vitals = imputed_vitals[imputed_vitals['stay_id'].isin(valid_stay_ids)]

if include_notes:
    notes_df = notes_df[notes_df['stay_id'].isin(valid_stay_ids)]

if include_cxr:
    cxr_df = cxr_df[cxr_df['stay_id'].isin(valid_stay_ids)]

if include_ecg:
    ecg_df = ecg_df[ecg_df['stay_id'].isin(valid_stay_ids)]

In [70]:
admissions_df = pd.read_csv(os.path.join(mimic_iv_path, "hosp", "admissions.csv.gz"))
admissions_df = admissions_df.rename(columns={"hospital_expire_flag": "died"})
admissions_df = admissions_df[["subject_id", "hadm_id", "died"]]

In [71]:

if not include_missing:
    unique_stays = ireg_vitals_ts_df['stay_id'].unique()
    print(f"Number of stays with vitals: {len(unique_stays)}")

    if include_notes:
        unique_stays = np.intersect1d(unique_stays, notes_df['stay_id'].unique())
        print(f"Number of stays with notes: {len(unique_stays)}")

    if include_cxr:
        unique_stays = np.intersect1d(unique_stays, cxr_df['stay_id'].unique())
        print(f"Number of stays with cxr: {len(unique_stays)}")

    if include_ecg:
        unique_stays = np.intersect1d(unique_stays, ecg_df['stay_id'].unique())
        print(f"Number of stays with ecg: {len(unique_stays)}")
else:
    unique_stays = ireg_vitals_ts_df['stay_id'].unique()
    print(f"Number of stays with vitals: {len(unique_stays)}")

    if include_notes:
        # Get stays with either TS or notes
        unique_stays = np.union1d(unique_stays, notes_df['stay_id'].unique())
        print(f"Number of stays with either TS or notes: {len(unique_stays)}")
    
    if include_cxr:
        unique_stays = np.union1d(unique_stays, cxr_df['stay_id'].unique())
        print(f"Number of stays with either TS, notes, cxr: {len(unique_stays)}")
    
    if include_ecg:
        unique_stays = np.union1d(unique_stays, ecg_df['stay_id'].unique())
        print(f"Number of stays with either TS, notes, cxr, ecg: {len(unique_stays)}")

Number of stays with vitals: 35078
Number of stays with notes: 31987
Number of stays with cxr: 8770


In [72]:
# Create train, val, test splits
np.random.seed(0)
np.random.shuffle(unique_stays)
train_stays = unique_stays[:int(0.7*len(unique_stays))]
val_stays = unique_stays[int(0.7*len(unique_stays)):int(0.85*len(unique_stays))]
test_stays = unique_stays[int(0.85*len(unique_stays)):]

In [73]:
train_ireg_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(train_stays)].copy()
train_imputed_df = imputed_vitals[imputed_vitals['stay_id'].isin(train_stays)].copy()

cols = train_ireg_ts_df.columns.tolist()
cols = [col for col in cols if col not in ['subject_id', 'hadm_id', 'stay_id', 'timedelta']]

if standard_scale:
    for col in cols:
        scaler = StandardScaler()
        scaler.fit(train_ireg_ts_df[[col]])
        ireg_vitals_ts_df[col] = scaler.transform(ireg_vitals_ts_df[[col]])

        scaler = StandardScaler()
        scaler.fit(train_imputed_df[[col]])
        imputed_vitals[col] = scaler.transform(imputed_vitals[[col]])


In [74]:
def get_stay_list(stays):
    stays_list = []

    for curr_stay in tqdm(stays):
        
        curr_stay_ireg = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'] == curr_stay].copy()
        #print(f"Initial irregular time series shape (before dropping columns): {curr_stay_ireg.shape}")
        curr_stay_imputed = imputed_vitals[imputed_vitals['stay_id'] == curr_stay].copy()

        if len(curr_stay_ireg) == 0:
            continue

        curr_stay_dict = {}
        curr_stay_dict['name'] = curr_stay_ireg['subject_id'].iloc[0]
        curr_stay_dict['hadm_id'] = curr_stay_ireg['hadm_id'].iloc[0]
        curr_stay_dict['stay_id'] = curr_stay
        curr_stay_dict['ts_tt'] = curr_stay_ireg['timedelta'].values

        curr_stay_ireg.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'timedelta'], inplace=True)
        #print(f"Irregular time series shape (after dropping columns): {curr_stay_ireg.shape}")
        ireg_ts_mask = curr_stay_ireg.notnull()
        curr_stay_ireg.fillna(0, inplace=True)
        #print(f"Number of features in irregular time series: {curr_stay_ireg.shape[1]}")
        curr_stay_dict['irg_ts'] = curr_stay_ireg.values
        curr_stay_dict['irg_ts_mask'] = ireg_ts_mask.values.astype(int)

        curr_stay_imputed.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'timedelta'], inplace=True)
        curr_stay_dict['reg_ts'] = curr_stay_imputed.values

        if include_notes:
            curr_stay_notes = notes_df[notes_df['stay_id'] == curr_stay].copy()

            if len(curr_stay_notes) == 0:
                curr_stay_dict['text_data'] = []
                curr_stay_dict['text_time_to_end'] = []
                curr_stay_dict['text_embeddings'] = []
                curr_stay_dict['text_missing'] = 1
            else:
                curr_stay_dict['text_data'] = curr_stay_notes['text'].tolist()
                curr_stay_dict['text_time_to_end'] = curr_stay_notes['icu_time_delta'].values
                curr_stay_dict['text_embeddings'] = [emb[0][0] for emb in curr_stay_notes['biobert_embeddings']]
                curr_stay_dict['text_missing'] = 0

        if include_cxr:
            curr_stay_cxr = cxr_df[cxr_df['stay_id'] == curr_stay].copy()
            
            if len(curr_stay_cxr) == 0:
                curr_stay_dict['cxr_feats'] = []
                curr_stay_dict['cxr_time'] = []
                curr_stay_dict['cxr_missing'] = 1
            else:
                curr_stay_dict['cxr_feats'] = curr_stay_cxr['densefeatures'].tolist()
                curr_stay_dict['cxr_time'] = curr_stay_cxr['icu_time_delta'].values
                curr_stay_dict['cxr_missing'] = 0
                #print(f"Number of CXR features: {len(curr_stay_dict['cxr_feats'])}")

        if include_ecg:
            curr_stay_ecg = ecg_df[ecg_df['stay_id'] == curr_stay].copy()
            if len(curr_stay_ecg) == 0:
                curr_stay_dict['ecg_feats'] = []
                curr_stay_dict['ecg_time'] = []
                curr_stay_dict['ecg_missing'] = 1
            else:
                curr_stay_dict['ecg_feats'] = curr_stay_ecg['embeddings'].tolist()
                curr_stay_dict['ecg_time'] = curr_stay_ecg['icu_time_delta'].values
                curr_stay_dict['ecg_missing'] = 0

        curr_stay_dict['label'] = admissions_df[admissions_df['hadm_id'] == curr_stay_dict['hadm_id']]['died'].iloc[0]

        stays_list.append(curr_stay_dict)

    return stays_list

train_stays_list = get_stay_list(train_stays)
val_stays_list = get_stay_list(val_stays)
test_stays_list = get_stay_list(test_stays)

  0%|          | 0/6139 [00:00<?, ?it/s]

100%|██████████| 6139/6139 [01:13<00:00, 83.76it/s]
100%|██████████| 1315/1315 [00:15<00:00, 83.11it/s]
100%|██████████| 1316/1316 [00:15<00:00, 83.14it/s]


In [75]:
# Example to check the first patient stay
first_stay = train_stays[18]
first_stay_data = get_stay_list([first_stay])
print(f"Data for the first stay: {first_stay_data}")


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 76.61it/s]

Data for the first stay: [{'name': 10459005, 'hadm_id': 25159727, 'stay_id': 32245271, 'ts_tt': array([4.778888888888889, 4.812222222222222, 4.828888888888889,
       4.845555555555555, 5.9622222222222225, 8.72888888888889,
       8.745555555555555, 8.778888888888888, 8.895555555555555,
       9.028888888888888, 9.045555555555556, 9.278888888888888,
       9.528888888888888, 9.662222222222223, 9.778888888888888,
       9.812222222222223, 10.028888888888888, 10.278888888888888,
       10.312222222222223, 11.278888888888888, 12.278888888888888,
       13.278888888888888, 13.778888888888888, 14.278888888888888,
       15.278888888888888, 16.27888888888889, 17.27888888888889,
       17.862222222222222, 18.27888888888889, 19.27888888888889,
       20.27888888888889, 21.27888888888889, 21.52888888888889,
       21.77888888888889, 22.02888888888889, 22.27888888888889,
       22.52888888888889, 22.77888888888889, 23.02888888888889,
       23.27888888888889, 24.27888888888889, 25.27888888888889




In [76]:
# Save the data
import pickle

base_name = "ihm"
if restrict_48_hours:
    base_name += "-48"
else:
    base_name += "-all"

if include_cxr:
    if include_notes:
        base_name += "-cxr-notes"
    else:
        base_name += "-cxr"

if include_ecg:
    base_name += "-ecg"

if include_missing:
    base_name += "-missingInd"

f_path = os.path.join(output_dir, f"train_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving train stays to {f_path}")
    pickle.dump(train_stays_list, f)

f_path = os.path.join(output_dir, f"val_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving val stays to {f_path}")
    pickle.dump(val_stays_list, f)

f_path = os.path.join(output_dir, f"test_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving test stays to {f_path}")
    pickle.dump(test_stays_list, f)


Saving train stays to /data/wang/junh/datasets/multimodal/preprocessing/train_ihm-48-cxr-notes_stays.pkl
Saving val stays to /data/wang/junh/datasets/multimodal/preprocessing/val_ihm-48-cxr-notes_stays.pkl
Saving test stays to /data/wang/junh/datasets/multimodal/preprocessing/test_ihm-48-cxr-notes_stays.pkl
