In [1]:
# autoreload
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

In [3]:
mimic_iv_path = "/data/wang/junh/datasets/physionet.org/files/mimiciv/2.2"
mm_dir = "/data/wang/junh/datasets/multimodal"

output_dir = os.path.join(mm_dir, "preprocessing")

In [4]:
restrict_48_hours = True
include_notes = True
include_cxr = True
include_ecg = True
standard_scale = True
include_missing = True

In [5]:
ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, "ts_labs_vitals_new.pkl"))
# ireg_vitals_ts_df = ireg_vitals_ts_df.drop(columns=["hosp_time_delta"])
# ireg_vitals_ts_df.rename(columns={'icu_time_delta': 'timedelta'}, inplace=True)
imputed_vitals = pd.read_pickle(os.path.join(output_dir, "imputed_ts_labs_vitals_new_hosp.pkl"))

# ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, "ts_labs_vitals_icu.pkl"))
# imputed_vitals = pd.read_pickle(os.path.join(output_dir, "imputed_ts_labs_vitals_icu.pkl"))
print(ireg_vitals_ts_df.columns)
print(imputed_vitals.columns)

Index(['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta', 'icu_time_delta',
       'Anion Gap', 'Bicarbonate', 'Calcium, Total', 'Chloride', 'Creatinine',
       'Diastolic BP', 'GCS - Eye Opening', 'GCS - Motor Response',
       'GCS - Verbal Response', 'Glucose', 'Heart Rate', 'Hematocrit',
       'Hemoglobin', 'MCH', 'MCHC', 'MCV', 'Magnesium', 'Mean BP',
       'Neutrophils', 'O2 Saturation', 'Phosphate', 'Platelet Count', 'RDW',
       'Red Blood Cells', 'Respiratory Rate', 'Sodium', 'Systolic BP',
       'Urea Nitrogen', 'Vancomycin', 'White Blood Cells'],
      dtype='object')
Index(['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta', 'icu_time_delta',
       'Anion Gap', 'Bicarbonate', 'Calcium, Total', 'Chloride', 'Creatinine',
       'Diastolic BP', 'GCS - Eye Opening', 'GCS - Motor Response',
       'GCS - Verbal Response', 'Glucose', 'Heart Rate', 'Hematocrit',
       'Hemoglobin', 'MCH', 'MCHC', 'MCV', 'Magnesium', 'Mean BP',
       'Neutrophils', 'O2 Saturation', '

In [6]:
print(len(ireg_vitals_ts_df.columns))

35


In [7]:
ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['hosp_time_delta'] >= 0]
imputed_vitals = imputed_vitals[imputed_vitals['hosp_time_delta'] >= 0]

if restrict_48_hours:
    ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['hosp_time_delta'] <= 48]
    imputed_vitals = imputed_vitals[imputed_vitals['hosp_time_delta'] <= 48]

In [33]:
if include_notes:
    #notes_df = pd.read_pickle(os.path.join(output_dir, "icu_notes_text_embeddings.pkl"))
    notes_df = pd.read_pickle(os.path.join(output_dir, "merge_notes_text.pkl"))
    print(notes_df.columns)
    # notes_df = pd.read_pickle(os.path.join(output_dir, "notes_text.pkl"))
    notes_df = notes_df[notes_df['stay_id'].notnull()]

    notes_df = notes_df[notes_df['hosp_time_delta'] >= 0]
    if restrict_48_hours:
        notes_df = notes_df[notes_df['hosp_time_delta'] <= 48]

if include_cxr:
    cxr_df = pd.read_pickle(os.path.join(output_dir, "cxr_embeddings_stay.pkl"))
    print(cxr_df.columns)
    cxr_df = cxr_df[cxr_df['hosp_time_delta'] >= 0]
    if restrict_48_hours:
        cxr_df = cxr_df[cxr_df['hosp_time_delta'] <= 48]

if include_ecg:
    ecg_df = pd.read_pickle(os.path.join(output_dir, "ecg_embeddings_hosp.pkl"))
    ecg_df = ecg_df[ecg_df['hosp_time_delta'] >= 0]
    if restrict_48_hours:
        ecg_df = ecg_df[ecg_df['hosp_time_delta'] <= 48]

Index(['note_id', 'subject_id', 'hadm_id', 'note_type', 'note_seq',
       'charttime', 'storetime', 'clinic_text', 'stay_id', 'icu_time_delta',
       'hosp_time_delta', 'text'],
      dtype='object')
Index(['dicom_id', 'subject_id', 'study_id',
       'PerformedProcedureStepDescription', 'ViewPosition',
       'ProcedureCodeSequence_CodeMeaning', 'ViewCodeSequence_CodeMeaning',
       'PatientOrientationCodeSequence_CodeMeaning', 'densefeatures',
       'predictions', 'cxrtime', 'hadm_id', 'stay_id', 'icu_time_delta',
       'hosp_time_delta'],
      dtype='object')


In [34]:
print("Number of stays in ecg: ", ecg_df['stay_id'].nunique())
print("Number of stays in cxr: ", cxr_df['stay_id'].nunique())
print("Number of stays in notes: ", notes_df['stay_id'].nunique())

Number of stays in ecg:  25560
Number of stays in cxr:  10232
Number of stays in notes:  2783


In [35]:
print("Number of subjects in ecg: ", ecg_df['subject_id'].nunique())
print("Number of subjects in cxr: ", cxr_df['subject_id'].nunique())
print("Number of subjects in notes: ", notes_df['subject_id'].nunique())

Number of subjects in ecg:  22077
Number of subjects in cxr:  21727
Number of subjects in notes:  2671


In [11]:
icustays_df = pd.read_csv(os.path.join(mimic_iv_path, "icu", "icustays.csv.gz"), low_memory=False)
icustays_df['intime'] = pd.to_datetime(icustays_df['intime'])
icustays_df['outtime'] = pd.to_datetime(icustays_df['outtime'])

if restrict_48_hours:
    icustays_df = icustays_df[icustays_df['los'] <= 2]

In [12]:
valid_stay_ids = icustays_df['stay_id'].unique()

ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(valid_stay_ids)]
imputed_vitals = imputed_vitals[imputed_vitals['stay_id'].isin(valid_stay_ids)]

if include_notes:
    notes_df = notes_df[notes_df['stay_id'].isin(valid_stay_ids)]

if include_cxr:
    cxr_df = cxr_df[cxr_df['stay_id'].isin(valid_stay_ids)]

if include_ecg:
    ecg_df = ecg_df[ecg_df['stay_id'].isin(valid_stay_ids)]

In [13]:
print("Number of stays in ecg: ", ecg_df['stay_id'].nunique())
print("Number of stays in cxr: ", cxr_df['stay_id'].nunique())
print("Number of stays in notes: ", notes_df['stay_id'].nunique())

Number of stays in ecg:  11996
Number of stays in cxr:  3945
Number of stays in notes:  2299


In [14]:
print(ireg_vitals_ts_df.shape)

(1431987, 35)


In [15]:
admissions_df = pd.read_csv(os.path.join(mimic_iv_path, "hosp", "admissions.csv.gz"))
admissions_df = admissions_df.rename(columns={"hospital_expire_flag": "died"})
admissions_df = admissions_df[["subject_id", "hadm_id", "died"]]

In [16]:

if not include_missing:
    unique_stays = ireg_vitals_ts_df['stay_id'].unique()
    print(f"Number of stays with vitals: {len(unique_stays)}")

    if include_notes:
        unique_stays = np.intersect1d(unique_stays, notes_df['stay_id'].unique())
        print(f"Number of stays with notes: {len(unique_stays)}")

    if include_cxr:
        unique_stays = np.intersect1d(unique_stays, cxr_df['stay_id'].unique())
        print(f"Number of stays with cxr: {len(unique_stays)}")

    if include_ecg:
        unique_stays = np.intersect1d(unique_stays, ecg_df['stay_id'].unique())
        print(f"Number of stays with ecg: {len(unique_stays)}")
else:
    unique_stays = ireg_vitals_ts_df['stay_id'].unique()
    print(f"Number of stays with vitals: {len(unique_stays)}")

    if include_notes:
        # Get stays with either TS or notes
        unique_stays = np.union1d(unique_stays, notes_df['stay_id'].unique())
        print(f"Number of stays with either TS or notes: {len(unique_stays)}")
    
    if include_cxr:
        unique_stays = np.union1d(unique_stays, cxr_df['stay_id'].unique())
        print(f"Number of stays with either TS, notes, cxr: {len(unique_stays)}")
    
    if include_ecg:
        unique_stays = np.union1d(unique_stays, ecg_df['stay_id'].unique())
        print(f"Number of stays with either TS, notes, cxr, ecg: {len(unique_stays)}")

Number of stays with vitals: 31836
Number of stays with either TS or notes: 31836
Number of stays with either TS, notes, cxr: 31837
Number of stays with either TS, notes, cxr, ecg: 31837


In [32]:
unique_stays = ireg_vitals_ts_df['stay_id'].unique()
unique_stays = np.intersect1d(unique_stays, notes_df['stay_id'].unique())
print(f"Number of stays with both vitals and notes: {len(unique_stays)}")

Number of stays with both vitals and notes: 2299


In [30]:
# Create train, val, test splits
np.random.seed(0)
np.random.shuffle(unique_stays)
train_stays = unique_stays[:int(0.7*len(unique_stays))]
val_stays = unique_stays[int(0.7*len(unique_stays)):int(0.85*len(unique_stays))]
test_stays = unique_stays[int(0.85*len(unique_stays)):]

In [31]:
print(len(train_stays), len(val_stays), len(test_stays))

22285 4775 4776


In [25]:
train_ireg_ts_df = irregular_vitals_df[irregular_vitals_df['stay_id'].isin(train_stays)].copy()
train_imputed_df = imputed_vitals[imputed_vitals['stay_id'].isin(train_stays)].copy()

cols = train_ireg_ts_df.columns.tolist()
cols = [col for col in cols if col not in ['subject_id', 'hadm_id', 'stay_id', 'timedelta']]

if standard_scale:
    for col in cols:
        scaler = StandardScaler()
        scaler.fit(train_ireg_ts_df[[col]])
        irregular_vitals_df[col] = scaler.transform(irregular_vitals_df[[col]])

        scaler = StandardScaler()
        scaler.fit(train_imputed_df[[col]])
        imputed_vitals[col] = scaler.transform(imputed_vitals[[col]])


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  irregular_vitals_df[col] = scaler.transform(irregular_vitals_df[[col]])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  irregular_vitals_df[col] = scaler.transform(irregular_vitals_df[[col]])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  irregular_vitals_df[col] = scaler.transform(irregular_vitals_

In [26]:
def get_stay_list(stays):
    stays_list = []

    for curr_stay in tqdm(stays):
        
        curr_stay_ireg = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'] == curr_stay].copy()
        #print(f"Initial irregular time series shape (before dropping columns): {curr_stay_ireg.shape}")
        curr_stay_imputed = imputed_vitals[imputed_vitals['stay_id'] == curr_stay].copy()

        if len(curr_stay_ireg) == 0:
            continue

        curr_stay_dict = {}
        curr_stay_dict['name'] = curr_stay_ireg['subject_id'].iloc[0]
        curr_stay_dict['hadm_id'] = curr_stay_ireg['hadm_id'].iloc[0]
        curr_stay_dict['stay_id'] = curr_stay
        curr_stay_dict['ts_tt'] = curr_stay_ireg['hosp_time_delta'].values

        curr_stay_ireg.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta'], inplace=True)
        ireg_ts_mask = curr_stay_ireg.notnull()
        curr_stay_ireg.fillna(0, inplace=True)
        curr_stay_dict['irg_ts'] = curr_stay_ireg.values
        curr_stay_dict['irg_ts_mask'] = ireg_ts_mask.values.astype(int)

        curr_stay_imputed.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta'], inplace=True)
        curr_stay_dict['reg_ts'] = curr_stay_imputed.values

        if include_notes:
            curr_stay_notes = notes_df[notes_df['stay_id'] == curr_stay].copy()

            if len(curr_stay_notes) == 0:
                curr_stay_dict['text_data'] = []
                curr_stay_dict['text_time_to_end'] = []
                curr_stay_dict['text_embeddings'] = []
                curr_stay_dict['text_missing'] = 1
            else:
                curr_stay_dict['text_data'] = curr_stay_notes['text'].tolist()
                curr_stay_dict['text_time_to_end'] = curr_stay_notes['hosp_time_delta'].values
                curr_stay_dict['text_embeddings'] = [emb[0][0] for emb in curr_stay_notes['biobert_embeddings']]
                curr_stay_dict['text_missing'] = 0

        if include_cxr:
            curr_stay_cxr = cxr_df[cxr_df['stay_id'] == curr_stay].copy()
            
            if len(curr_stay_cxr) == 0:
                curr_stay_dict['cxr_feats'] = []
                curr_stay_dict['cxr_time'] = []
                curr_stay_dict['cxr_missing'] = 1
            else:
                curr_stay_dict['cxr_feats'] = curr_stay_cxr['densefeatures'].tolist()
                curr_stay_dict['cxr_time'] = curr_stay_cxr['hosp_time_delta'].values
                curr_stay_dict['cxr_missing'] = 0
                #print(f"Number of CXR features: {len(curr_stay_dict['cxr_feats'])}")

        if include_ecg:
            curr_stay_ecg = ecg_df[ecg_df['stay_id'] == curr_stay].copy()
            if len(curr_stay_ecg) == 0:
                curr_stay_dict['ecg_feats'] = []
                curr_stay_dict['ecg_time'] = []
                curr_stay_dict['ecg_missing'] = 1
            else:
                curr_stay_dict['ecg_feats'] = curr_stay_ecg['embeddings'].tolist()
                curr_stay_dict['ecg_time'] = curr_stay_ecg['hosp_time_delta'].values
                curr_stay_dict['ecg_missing'] = 0

        curr_stay_dict['label'] = admissions_df[admissions_df['hadm_id'] == curr_stay_dict['hadm_id']]['died'].iloc[0]

        stays_list.append(curr_stay_dict)

    return stays_list



In [27]:
train_stays_list = get_stay_list(train_stays)
val_stays_list = get_stay_list(val_stays)
test_stays_list = get_stay_list(test_stays)

  0%|          | 0/1609 [00:00<?, ?it/s]


KeyError: 'biobert_embeddings'

In [33]:
# Example to check the first patient stay
first_stay = train_stays[18]
first_stay_data = get_stay_list([first_stay])


100%|██████████| 1/1 [00:00<00:00,  6.98it/s]


In [39]:
print(first_stay_data)

[{'name': 15257559, 'hadm_id': 20095131.0, 'stay_id': 39871108, 'ts_tt': array([0.0033333333333333335, 0.02, 0.08666666666666667, 0.27, 0.47, 0.77,
       1.77, 1.7866666666666666, 2.77, 2.7866666666666666, 3.77,
       3.7866666666666666, 3.8033333333333332, 4.1866666666666665, 4.77,
       5.77, 6.77, 6.786666666666667, 7.77, 7.803333333333334, 8.77, 9.77,
       9.786666666666667, 10.77, 10.786666666666667, 11.77,
       11.786666666666667, 11.803333333333333, 12.77, 12.786666666666667,
       12.936666666666667, 12.97, 13.77, 13.786666666666667, 14.77,
       14.803333333333333, 15.77, 15.786666666666667, 16.77,
       16.786666666666665, 17.186666666666667, 17.27, 17.586666666666666,
       17.77, 17.786666666666665, 18.77, 18.786666666666665, 19.77,
       19.786666666666665, 20.77, 20.786666666666665, 21.77,
       21.786666666666665, 22.77, 22.786666666666665, 23.77,
       23.786666666666665, 24.236666666666668, 24.77, 24.803333333333335,
       25.77, 25.786666666666665, 26.7

In [34]:
# Save the data
import pickle

base_name = "clinic_ihm"
if restrict_48_hours:
    base_name += "-48"
else:
    base_name += "-all"

if include_cxr:
    if include_notes:
        base_name += "-cxr-notes"
    else:
        base_name += "-cxr"

if include_ecg:
    base_name += "-ecg"

if include_missing:
    base_name += "-missingInd"

f_path = os.path.join(output_dir, f"train_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving train stays to {f_path}")
    pickle.dump(train_stays_list, f)

f_path = os.path.join(output_dir, f"val_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving val stays to {f_path}")
    pickle.dump(val_stays_list, f)

f_path = os.path.join(output_dir, f"test_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving test stays to {f_path}")
    pickle.dump(test_stays_list, f)


Saving train stays to /data/wang/junh/datasets/multimodal/preprocessing/train_clinic_ihm-48-cxr-notes-ecg-missingInd_stays.pkl
Saving val stays to /data/wang/junh/datasets/multimodal/preprocessing/val_clinic_ihm-48-cxr-notes-ecg-missingInd_stays.pkl
Saving test stays to /data/wang/junh/datasets/multimodal/preprocessing/test_clinic_ihm-48-cxr-notes-ecg-missingInd_stays.pkl


In [36]:
# Save the data
import pickle
file_path = '/data/wang/junh/datasets/multimodal/preprocessing/train_clinic_ihm-48-cxr-notes-ecg-missingInd_stays.pkl'
with open(file_path, 'rb') as f:
    train_stays_list = pickle.load(f)
print(len(train_stays_list))

672


In [37]:
print(train_stays_list[0].keys())
print(train_stays_list[0]['irg_ts'].shape)

dict_keys(['name', 'hadm_id', 'stay_id', 'ts_tt', 'irg_ts', 'irg_ts_mask', 'reg_ts', 'text_data', 'text_time_to_end', 'text_embeddings', 'text_missing', 'cxr_feats', 'cxr_time', 'cxr_missing', 'ecg_feats', 'ecg_time', 'ecg_missing', 'label'])
(69, 30)
