In [6]:
# autoreload
%load_ext autoreload
%autoreload 2

In [7]:
import os
import sys
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

In [8]:
mimic_iv_path = "/data/wang/junh/datasets/physionet.org/files/mimiciv/2.2"
mm_dir = "/data/wang/junh/datasets/multimodal"

output_dir = os.path.join(mm_dir, "preprocessing")

In [9]:
include_notes = True
include_cxr = True
standard_scale = True
include_missing = True
include_ecg = True

In [10]:
ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, "ts_labs_vitals_new.pkl"))
ireg_vitals_ts_df = ireg_vitals_ts_df.drop(columns=["hosp_time_delta"])
ireg_vitals_ts_df = ireg_vitals_ts_df.rename(columns={"icu_time_delta": "timedelta"})
imputed_vitals = pd.read_pickle(os.path.join(output_dir, "imputed_ts_labs_vitals_new.pkl"))

# ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, "ts_labs_vitals_icu.pkl"))
# imputed_vitals = pd.read_pickle(os.path.join(output_dir, "imputed_ts_labs_vitals_icu.pkl"))

ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] >= 0]
imputed_vitals = imputed_vitals[imputed_vitals['timedelta'] >= 0]

ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] <= 48]
imputed_vitals = imputed_vitals[imputed_vitals['timedelta'] <= 48]


In [11]:
if include_notes:
    notes_df = pd.read_pickle(os.path.join(output_dir, "icu_notes_text_embeddings.pkl"))
    # notes_df = pd.read_pickle(os.path.join(output_dir, "notes_text.pkl"))
    notes_df = notes_df[notes_df['stay_id'].notnull()]

    notes_df = notes_df[notes_df['icu_time_delta'] >= 0]
    notes_df = notes_df[notes_df['icu_time_delta'] <= 48]

if include_cxr:
    cxr_df = pd.read_pickle(os.path.join(output_dir, "cxr_embeddings_stay.pkl"))
    cxr_df = cxr_df[cxr_df['icu_time_delta'] >= 0]
    cxr_df = cxr_df[cxr_df['icu_time_delta'] <= 48]

if include_ecg:
    ecg_df = pd.read_pickle(os.path.join(output_dir, "ecg_embeddings_icu.pkl"))
    ecg_df = ecg_df[ecg_df['icu_time_delta'] >= 0]
    ecg_df = ecg_df[ecg_df['icu_time_delta'] <= 48]

In [12]:
icustays_df = pd.read_csv(os.path.join(mimic_iv_path, "icu", "icustays.csv.gz"), low_memory=False)
icustays_df['intime'] = pd.to_datetime(icustays_df['intime'])
icustays_df['outtime'] = pd.to_datetime(icustays_df['outtime'])

icustays_df = icustays_df[icustays_df['los'] >= 2]

In [13]:
valid_stay_ids = icustays_df['stay_id'].unique()

ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(valid_stay_ids)]
imputed_vitals = imputed_vitals[imputed_vitals['stay_id'].isin(valid_stay_ids)]

if include_notes:
    notes_df = notes_df[notes_df['stay_id'].isin(valid_stay_ids)]

if include_cxr:
    cxr_df = cxr_df[cxr_df['stay_id'].isin(valid_stay_ids)]

if include_ecg:
    ecg_df = ecg_df[ecg_df['stay_id'].isin(valid_stay_ids)]

In [14]:
admissions_df = pd.read_csv(os.path.join(mimic_iv_path, "hosp", "admissions.csv.gz"))
admissions_df = admissions_df.rename(columns={"hospital_expire_flag": "died"})
admissions_df = admissions_df[["subject_id", "hadm_id", "died"]]

In [15]:
if not include_missing:
    unique_stays = ireg_vitals_ts_df['stay_id'].unique()
    print(f"Number of stays with vitals: {len(unique_stays)}")

    if include_notes:
        unique_stays = np.intersect1d(unique_stays, notes_df['stay_id'].unique())
        print(f"Number of stays with notes: {len(unique_stays)}")

    if include_cxr:
        unique_stays = np.intersect1d(unique_stays, cxr_df['stay_id'].unique())
        print(f"Number of stays with cxr: {len(unique_stays)}")

    if include_ecg:
        unique_stays = np.intersect1d(unique_stays, ecg_df['stay_id'].unique())
        print(f"Number of stays with ecg: {len(unique_stays)}")        
else:
    unique_stays = ireg_vitals_ts_df['stay_id'].unique()
    print(f"Number of stays with vitals: {len(unique_stays)}")

    if include_notes:
        # Get stays with either TS or notes
        unique_stays = np.union1d(unique_stays, notes_df['stay_id'].unique())
        print(f"Number of stays with either TS or notes: {len(unique_stays)}")
    
    if include_cxr:
        unique_stays = np.union1d(unique_stays, cxr_df['stay_id'].unique())
        print(f"Number of stays with either TS, notes, cxr: {len(unique_stays)}")

    if include_ecg:
        unique_stays = np.union1d(unique_stays, ecg_df['stay_id'].unique())
        print(f"Number of stays with either TS, notes, cxr, ecg: {len(unique_stays)}")

Number of stays with vitals: 35131
Number of stays with either TS or notes: 35131
Number of stays with either TS, notes, cxr: 35131
Number of stays with either TS, notes, cxr, ecg: 35131


In [16]:
# stays_with_labels = []
# stay_to_hadm_mapping = dict(zip(ireg_vitals_ts_df['stay_id'], ireg_vitals_ts_df['hadm_id']))

# for stay in unique_stays:
#     curr_stay_dict = {}
#     curr_stay_dict['stay_id'] = stay

#     # Retrieve hadm_id using the mapping
#     if stay in stay_to_hadm_mapping:
#         curr_stay_dict['hadm_id'] = stay_to_hadm_mapping[stay]
#         # Retrieve the label (e.g., 'died') from admissions_df
#         curr_stay_dict['label'] = admissions_df[
#             admissions_df['hadm_id'] == curr_stay_dict['hadm_id']
#         ]['died'].iloc[0]

#     stays_with_labels.append(curr_stay_dict)

# stays_df = pd.DataFrame(stays_with_labels)

In [17]:
# groups = stays_df.groupby('label')
# train_stays, val_stays, test_stays = [], [], []

# np.random.seed(0)  # Ensure reproducibility

# for label, group in groups:
#     shuffled_group = group.sample(frac=1, random_state=0)  # Shuffle
#     n = len(shuffled_group)
    
#     # Calculate split indices
#     train_end = int(0.7 * n)
#     val_end = int(0.85 * n)
    
#     # Append stratified splits
#     train_stays.extend(shuffled_group.iloc[:train_end].to_dict('records'))
#     val_stays.extend(shuffled_group.iloc[train_end:val_end].to_dict('records'))
#     test_stays.extend(shuffled_group.iloc[val_end:].to_dict('records'))

In [18]:
#Create train, val, test splits
np.random.seed(0)
np.random.shuffle(unique_stays)
train_stays = unique_stays[:int(0.7*len(unique_stays))]
val_stays = unique_stays[int(0.7*len(unique_stays)):int(0.85*len(unique_stays))]
test_stays = unique_stays[int(0.85*len(unique_stays)):]

In [19]:
# # Convert train_stays, val_stays, and test_stays dictionaries back to unique_stays format
# train_stays = [stay['stay_id'] for stay in train_stays]
# val_stays = [stay['stay_id'] for stay in val_stays]
# test_stays = [stay['stay_id'] for stay in test_stays]


In [20]:
train_ireg_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].astype(int).isin(train_stays)].copy()
train_imputed_df = imputed_vitals[imputed_vitals['stay_id'].astype(int).isin(train_stays)].copy()

cols = train_ireg_ts_df.columns.tolist()
cols = [col for col in cols if col not in ['subject_id', 'hadm_id', 'stay_id', 'timedelta']]

if standard_scale:
    for col in cols:
        scaler = StandardScaler()
        scaler.fit(train_ireg_ts_df[[col]])
        ireg_vitals_ts_df[col] = scaler.transform(ireg_vitals_ts_df[[col]])

        scaler = StandardScaler()
        scaler.fit(train_imputed_df[[col]])
        imputed_vitals[col] = scaler.transform(imputed_vitals[[col]])


In [21]:
def get_stay_list(stays):
    stays_list = []

    for curr_stay in tqdm(stays):
        curr_stay_ireg = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'] == curr_stay].copy()
        curr_stay_imputed = imputed_vitals[imputed_vitals['stay_id'] == curr_stay].copy()
        curr_icustay = icustays_df[icustays_df['stay_id'] == curr_stay].copy()

        try:
            curr_hadm_id = curr_stay_ireg['hadm_id'].iloc[0]
            died = admissions_df[admissions_df['hadm_id'] == curr_hadm_id]['died'].iloc[0]
        except:
            print("error!")
            continue

        intime = icustays_df[icustays_df['stay_id'] == curr_stay]['intime'].iloc[0]
        outtime = icustays_df[icustays_df['stay_id'] == curr_stay]['outtime'].iloc[0]
        timedelta = (outtime - intime).total_seconds() / 3600 / 24  # Calculate the duration in days

        # Assign labels based on the duration of stay
        if timedelta <= 3:
            label = 0
        elif timedelta <= 7:
            label = 1
        elif timedelta <= 14:
            label = 2
        else:
            label = 3

        # intime = icustays_df[icustays_df['stay_id'] == curr_stay]['intime'].iloc[0]
        # outtime = icustays_df[icustays_df['stay_id'] == curr_stay]['outtime'].iloc[0]
        # timedelta = (outtime - intime).total_seconds() / 3600

        # if died == 1:
        #     pass

        # if (timedelta < 96) & (died == 0):
        #     label = 1
        # else:
        #     label = 0

        if len(curr_stay_ireg) == 0:
            continue

        if include_notes:
            curr_stay_notes = notes_df[notes_df['stay_id'] == curr_stay].copy()

        if include_cxr:
            curr_stay_cxr = cxr_df[cxr_df['stay_id'] == curr_stay].copy()

        curr_stay_dict = {}
        curr_stay_dict['name'] = curr_stay_ireg['subject_id'].iloc[0]
        curr_stay_dict['hadm_id'] = curr_stay_ireg['hadm_id'].iloc[0]
        curr_stay_dict['stay_id'] = curr_stay
        curr_stay_dict['ts_tt'] = curr_stay_ireg['timedelta'].values

        # Save the feature names before dropping unnecessary columns
        feature_names = curr_stay_ireg.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'timedelta']).columns.tolist()
        curr_stay_ireg.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'timedelta'], inplace=True)
        ireg_ts_mask = curr_stay_ireg.notnull()
        curr_stay_ireg.fillna(0, inplace=True)
        curr_stay_dict['irg_ts'] = curr_stay_ireg.values
        curr_stay_dict['irg_ts_mask'] = ireg_ts_mask.values.astype(int)

        curr_stay_imputed.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'timedelta'], inplace=True)
        curr_stay_dict['reg_ts'] = curr_stay_imputed.values

        if include_notes:
            curr_stay_notes = notes_df[notes_df['stay_id'] == curr_stay].copy()
            if len(curr_stay_notes) == 0:
                curr_stay_dict['text_data'] = []
                curr_stay_dict['text_time_to_end'] = []
                curr_stay_dict['text_embeddings'] = []
                curr_stay_dict['text_missing'] = 1
            else:
                curr_stay_dict['text_data'] = curr_stay_notes['text'].tolist()
                curr_stay_dict['text_time_to_end'] = curr_stay_notes['icu_time_delta'].values
                curr_stay_dict['text_embeddings'] = [emb[0][0] for emb in curr_stay_notes['biobert_embeddings']]
                curr_stay_dict['text_missing'] = 0

        if include_cxr:
            curr_stay_cxr = cxr_df[cxr_df['stay_id'] == curr_stay].copy()
            if curr_stay_cxr.empty:
                curr_stay_dict['cxr_metadata'] = []
                curr_stay_dict['image_paths'] = []
                curr_stay_dict['cxr_missing'] = 1
            else:
                curr_stay_dict['cxr_feats'] = curr_stay_cxr['densefeatures'].tolist()
                metadata = curr_stay_cxr[['dicom_id', 'PerformedProcedureStepDescription', 'ViewPosition', 'ProcedureCodeSequence_CodeMeaning', 'ViewCodeSequence_CodeMeaning', 'PatientOrientationCodeSequence_CodeMeaning']].to_dict('records')
                curr_stay_dict['cxr_metadata'] = metadata
                image_paths = curr_stay_cxr.apply(lambda row: os.path.join('/data/wang/junh/datasets/physionet.org/files/mimic-cxr-jpg/2.0.0/mimic-cxr-jpg-2.1.0.physionet.org/files', f"p{str(row['subject_id'])[:2]}", f"p{row['subject_id']}", f"s{row['study_id']}", f"{row['dicom_id']}.jpg"), axis=1).tolist()
                curr_stay_dict['image_paths'] = image_paths
                curr_stay_dict['cxr_time'] = curr_stay_cxr['icu_time_delta'].values
                curr_stay_dict['cxr_missing'] = 0

        if include_ecg and ecg_df is not None:
            curr_stay_ecg = ecg_df[ecg_df['stay_id'] == curr_stay].copy()
            if len(curr_stay_ecg) == 0:
                curr_stay_dict['ecg_feats'] = []
                curr_stay_dict['ecg_time'] = []
                curr_stay_dict['ecg_missing'] = 1
            else:
                curr_stay_dict['ecg_feats'] = curr_stay_ecg['embeddings'].tolist()
                curr_stay_dict['ecg_time'] = curr_stay_ecg['icu_time_delta'].values
                curr_stay_dict['ecg_missing'] = 0
        
        curr_stay_dict['label'] = label
        stays_list.append(curr_stay_dict)

    return stays_list


In [22]:
train_stays_list = get_stay_list(train_stays)
val_stays_list = get_stay_list(val_stays)
test_stays_list = get_stay_list(test_stays)

100%|██████████| 24591/24591 [1:04:00<00:00,  6.40it/s]
100%|██████████| 5270/5270 [13:43<00:00,  6.40it/s]
100%|██████████| 5270/5270 [13:44<00:00,  6.39it/s]


In [28]:
print(train_stays_list[0].keys())
print(train_stays_list[0]['irg_ts'].shape)

dict_keys(['name', 'hadm_id', 'stay_id', 'ts_tt', 'irg_ts', 'irg_ts_mask', 'reg_ts', 'text_data', 'text_time_to_end', 'text_embeddings', 'text_missing', 'cxr_metadata', 'image_paths', 'cxr_missing', 'ecg_feats', 'ecg_time', 'ecg_missing', 'label'])
(116, 30)


In [29]:
# Save the data
import pickle

output_dir = os.path.join(mm_dir, "multiclass/")

restrict_48_hours = True

base_name = "los"
if restrict_48_hours:
    base_name += "-48"
else:
    base_name += "-all"

if include_cxr:
    if include_notes:
        base_name += "-cxr-notes"
    else:
        base_name += "-cxr"

if include_ecg:
    base_name += "-ecg"

if include_missing:
    base_name += "-missingInd"

f_path = os.path.join(output_dir, f"train_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving train stays to {f_path}")
    pickle.dump(train_stays_list, f)

f_path = os.path.join(output_dir, f"val_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving val stays to {f_path}")
    pickle.dump(val_stays_list, f)

f_path = os.path.join(output_dir, f"test_{base_name}_stays.pkl")
with open(f_path, 'wb') as f:
    print(f"Saving test stays to {f_path}")
    pickle.dump(test_stays_list, f)


Saving train stays to /data/wang/junh/datasets/multimodal/multiclass/train_los-48-cxr-notes-ecg-missingInd_stays.pkl
Saving val stays to /data/wang/junh/datasets/multimodal/multiclass/val_los-48-cxr-notes-ecg-missingInd_stays.pkl
Saving test stays to /data/wang/junh/datasets/multimodal/multiclass/test_los-48-cxr-notes-ecg-missingInd_stays.pkl


In [30]:
# Save the data
import pickle
file_path = '/data/wang/junh/datasets/multimodal/multiclass/train_los-48-cxr-notes-ecg-missingInd_stays.pkl'
with open(file_path, 'rb') as f:
    train_stays_list = pickle.load(f)
print(len(train_stays_list))

24591


In [32]:
labels = []
for data in train_stays_list:
    labels.append(data['label'])

labels.count(0), labels.count(1), labels.count(2), labels.count(3)

(8392, 10618, 3719, 1862)

In [3]:
train_stays_list[0]['TS_weight']

KeyError: 'TS_weights'