In [None]:
from datetime import datetime
import os
import math
import pickle
import pandas as pd
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from bidict import bidict
import scipy
from scipy.signal import savgol_filter

matplotlib.rcParams['figure.dpi'] = 300

In [None]:
# Run-configuration switches for this notebook. Each flag gates one or more cells below.

# If True, prepare the training data (including filtering out common and rare values)
# and regenerate the mappers that are pickled from the training set.
PREPARE_TRAINING_DATA = True

# If True, prepare the validation and testing datasets (using bidicts generated using training data)
PREPARE_VALI_TEST_DATA = True

# If True, write separate subsets of NOTEEVENTS.csv for the train, vali, and test splits
MAKE_SEPARATE_NOTE_CSVS = True

# If True, regenerate the dictionary of keywords for clinical notes (takes ~30-40 minutes)
BUILD_NOTES_DICTIONARY = True

# If True, combine the per-modality csvs into mixehr-formatted csv/txt files
COMBINE_DATA = True

# Helper functions

In [None]:
# Helper functions

def only_these_subj_and_hadm_ids(df, subj_ids, hadm_ids):
    """Return the rows of ``df`` that belong to the given patients and admissions.

    A row is kept when its ``SUBJECT_ID`` is in ``subj_ids`` AND its ``HADM_ID``
    is in ``hadm_ids``. Row order and index labels are preserved.
    """
    subject_ok = df['SUBJECT_ID'].isin(subj_ids)
    admission_ok = df['HADM_ID'].isin(hadm_ids)
    return df[subject_ok & admission_ok]
              

# Restrict to one hadm_id for each patient.
def only_one_hadm_per_subj(df):
    """Keep the rows of a single, randomly drawn HADM_ID for every SUBJECT_ID.

    The draw is made with ``np.random.choice`` over the HADM_ID values of each
    patient's rows, after reseeding NumPy's global RNG with 0 (so the result is
    repeatable, and the global RNG state is reset as a side effect). Rows are
    then kept whenever their HADM_ID is one of the drawn values.
    """
    # Fixed seed: the same admission is selected for each patient on every run
    np.random.seed(0)

    hadm_by_subject = df.groupby('SUBJECT_ID')['HADM_ID']
    chosen_hadm_ids = hadm_by_subject.apply(lambda hadm_col: np.random.choice(hadm_col))

    return df[df['HADM_ID'].isin(chosen_hadm_ids)]


# Keep only items within "secs" seconds of the start of the patient's ICU stay.
def within_x_secs_of_admit(df, secs, icustays, col_name="CHARTTIME"):
    """Keep rows of ``df`` timestamped less than ``secs`` seconds after ICU admission.

    Parameters
    ----------
    df : DataFrame with a ``HADM_ID`` column and a timestamp column ``col_name``
        holding strings formatted as ``%Y-%m-%d %H:%M:%S``.
    secs : number of seconds after admission (exclusive upper bound).
    icustays : DataFrame with ``HADM_ID`` and ``INTIME`` columns (same timestamp
        format). When a HADM_ID has several rows, the first one listed is used.
    col_name : name of the timestamp column in ``df``.

    Returns
    -------
    A filtered view of ``df`` with the original columns, order and index labels.
    Rows with a missing timestamp, or whose HADM_ID has no row in ``icustays``,
    are dropped. There is no lower bound: rows timestamped before the
    admission time are kept.

    Raises
    ------
    ValueError if a timestamp does not match the expected format.
    """
    rgt = '%Y-%m-%d %H:%M:%S'

    df = df[df[col_name].notna()]

    # One admission time per HADM_ID: the first row listed in icustays
    intime_by_hadm = (
        icustays.dropna(subset=["HADM_ID"])
        .drop_duplicates(subset="HADM_ID")
        .set_index("HADM_ID")["INTIME"]
    )

    # Vectorized parsing instead of a per-row strptime plus a per-row scan of icustays
    event_time = pd.to_datetime(df[col_name], format=rgt)
    admit_time = pd.to_datetime(df["HADM_ID"].map(intime_by_hadm), format=rgt)

    # Unknown admissions give NaT -> NaN seconds, which compares False and drops the row
    elapsed_secs = (event_time - admit_time).dt.total_seconds()
    return df[elapsed_secs < secs]


# Collapse duplicate (patientId, typeId, pheId, stateId) rows into one row each,
# recording how many duplicates there were in a new "freq" column.
def get_mixehr_format_with_freqs(df):
    """Count occurrences of every (patientId, typeId, pheId, stateId) combination.

    ``df`` must contain those four columns. The result has exactly the columns
    patientId, typeId, pheId, stateId, freq, with one row per distinct
    combination and ``freq`` equal to the number of input rows that had it.
    """
    key_cols = ["patientId", "typeId", "pheId", "stateId"]
    counts = df.groupby(key_cols).size().reset_index(name="freq")
    return counts[key_cols + ["freq"]]


def prepare_vali_test_sets(df, mapper, pheIdCol, idset):
    """Attach integer ``pheId`` values to a validation/test frame using a training mapper.

    Parameters
    ----------
    df : DataFrame containing the column ``pheIdCol``.
    mapper : object whose ``.inverse`` maps raw values (e.g. codes) to integer ids.
    pheIdCol : name of the column holding the raw values.
    idset : name of the split; accepted for call-site symmetry, not used here.

    Returns
    -------
    A new DataFrame with an integer ``pheId`` column. Rows whose raw value has
    no entry in ``mapper.inverse`` are dropped. The input ``df`` is left
    untouched (no column is written into the caller's frame).
    """
    phe_ids = df[pheIdCol].map(mapper.inverse)
    known = phe_ids.notna()
    # assign() returns a fresh frame, so the caller's (possibly sliced) frame is never mutated
    return df[known].assign(pheId=phe_ids[known].astype(int))

# Create a dataframe for the metadata file needed by mixehr.
# Input mixehr_df should have columns "patientId", "typeId", "pheId", "stateId", and "freq"
def get_mixehr_metadata(mixehr_df):
    """Build the (typeId, pheId, stateCnt) metadata table.

    For each (typeId, pheId) pair the number of distinct stateId values is
    counted; ``stateCnt`` is then set, for every pheId of a given typeId, to
    the largest such count found within that typeId.
    """
    distinct_states = mixehr_df[["typeId", "pheId", "stateId"]].drop_duplicates()
    meta_df = distinct_states.groupby(["typeId", "pheId"]).size().reset_index(name="stateCnt")
    # Every pheId of a type reports that type's maximum state count
    meta_df["stateCnt"] = meta_df.groupby("typeId")["stateCnt"].transform("max")
    return meta_df


# Write a df to a .txt file in mixehr format.
# Must specify if it is a metadata file, otherwise
# it is assumed to be a data file.
def mixehr_df_to_txt(mixehr_df, file_name, metadata=False, prepend_n_patients=False):
    """Write ``mixehr_df`` to ``file_name`` as space-delimited text, one row per line.

    Parameters
    ----------
    mixehr_df : DataFrame to write; values are formatted with ``%s``.
    file_name : output path. Any existing file is overwritten.
    metadata : True when writing a metadata table (never gets a header line).
    prepend_n_patients : for data files only, write the number of distinct
        ``patientId`` values as the first line.
    """
    # Mode 'w' truncates an existing file. The previous delete-then-append approach
    # silently appended to stale content whenever the delete failed.
    with open(file_name, 'w') as f:
        if not metadata and prepend_n_patients:
            f.write(str(mixehr_df["patientId"].nunique()) + "\n")
        np.savetxt(f, mixehr_df.to_numpy(), delimiter=' ', fmt='%s')

In [None]:
# More helper functions (for filtering out common values with inflection points)

# Get inflection point from frequency values.
# IMPORTANT NOTE: the array "arr" MUST be sorted in descending order.
def get_inflection_point(arr, plotlabel="", plot=True, deriv=1, sg_window=31, sg_order=2, n_pts=1, rare_ind=None):
    """Locate the "inflection point" of a descending frequency curve.

    Parameters
    ----------
    arr : 1-D sequence (list, ndarray or Series) of frequencies, sorted descending.
        Values are used by position; any pandas index labels are ignored.
    plotlabel : label used in the plot title and in the saved figure's file name.
    plot : if True, draw the curve, save it as ``<plotlabel.lower()>_inflection.png``
        and show it.
    deriv : 1 -> first position where the smoothed slope is flatter than
        ``-arr[0] / (0.5 * len(arr))``; 2 -> first position where the absolute
        smoothed second derivative is <= 1e-3.
    sg_window, sg_order : window length and polynomial order of the
        Savitzky-Golay filter.
    n_pts : divisor applied to the plotted values only.
    rare_ind : optional x position of a vertical "Rarity cutoff" line in the plot.

    Returns
    -------
    (inflection_val, inflection_ind): the smoothed value at the inflection
    position and the position itself (one past the first qualifying point, so
    it can be used directly as a slice start).

    Raises
    ------
    ValueError if ``arr`` is empty, ``deriv`` is not 1 or 2, or no point meets
    the criterion.
    """
    # Plain float array: a Series indexed by integers would otherwise treat arr[0]
    # as a label lookup rather than "the first (largest) value".
    values = np.asarray(arr, dtype=float)
    if values.size == 0:
        raise ValueError("arr must contain at least one frequency value")

    # Smooth the curve using a Savitzky-Golay filter
    smoothed = savgol_filter(values, window_length=sg_window, polyorder=sg_order, mode='interp')

    if deriv == 1:
        # Calculate 1st derivative of smoothed curve
        smoothed_1st_deriv = savgol_filter(values, window_length=sg_window, polyorder=sg_order, mode='interp', deriv=1)

        # Heuristic slope threshold derived from the largest value and the curve length
        infl_slope = -values[0]/(0.5*len(values))
        candidates = np.where(smoothed_1st_deriv > infl_slope)[0]
    elif deriv == 2:
        # Calculate second derivative of smoothed curve
        smoothed_second_deriv = savgol_filter(values, window_length=sg_window, polyorder=sg_order, mode='interp', deriv=2)

        # Find inflection point where second derivative of smoothed curve is near zero
        candidates = np.where(np.abs(smoothed_second_deriv) <= 1e-3)[0]
    else:
        raise ValueError("deriv must be 1 or 2")

    if candidates.size == 0:
        raise ValueError("no inflection point found: no position satisfies the deriv=" + str(deriv) + " criterion")

    inflection_ind = int(candidates[0]) + 1
    # The +1 can step one past the end when the last point is the first match
    inflection_val = smoothed[min(inflection_ind, len(values) - 1)]

    if plot:
        plt.figure(figsize=(3,2))
        x = np.arange(len(values), dtype=float)
        plt.plot(x, values/n_pts, label="Raw frequencies")
        plt.plot(x, smoothed/n_pts, "--", label="Savitzky-Golay")
        plt.plot([0, len(values)], [inflection_val/n_pts, inflection_val/n_pts], label="Inflection")
        if rare_ind is not None:
            # Only add the cutoff line (and its legend entry) when a cutoff was supplied
            plt.plot([rare_ind, rare_ind], [0, 0.2*(values.max()/n_pts)], label="Rarity cutoff")
        plt.legend()
        plt.title(plotlabel.lower() + " freqs")
        plt.tight_layout()
        plt.savefig(plotlabel.lower() + "_inflection.png")
        plt.show()

    return inflection_val, inflection_ind


def _first_rare_position(vcounts, rare_k):
    """Position of the first count below ``rare_k``, or ``len(vcounts)`` when none is that rare."""
    below = np.flatnonzero(np.asarray(vcounts) < rare_k)
    return int(below[0]) if below.size else len(vcounts)


# Given a dataframe with a column of interest (string 'col') (e.g. ICD-9 codes),
# Generate a bidirectional dictionary (bidict) mapping from integers to the actual value.
# If remove_common is True, remove the most common values using the frequency curve inflection point
# (similar to supplementary Figure 26 in original MixEHR paper)
def unique_mapping(df, col, remove_common=True, remove_rare=True, rare_k=3, plot=True, deriv=1, sg_window=31, sg_order=2):
    """Map the values of ``df[col]`` to consecutive integer ids (1..n).

    Values are ranked by the number of distinct ``SUBJECT_ID``s they occur for.

    Parameters
    ----------
    df : DataFrame with ``col`` and ``SUBJECT_ID`` columns.
    col : name of the column whose values are mapped.
    remove_common : drop the values ranked before the inflection point returned
        by ``get_inflection_point``.
    remove_rare : drop values occurring for fewer than ``rare_k`` patients.
    rare_k : patient-count threshold for ``remove_rare``.
    plot, deriv, sg_window, sg_order : forwarded to ``get_inflection_point``.

    Returns
    -------
    (mapping, df): a bidict from integer id to value, and a new DataFrame
    restricted to the kept values with an added ``pheId`` column. The input
    frame is not modified.
    """
    vcounts = df.groupby(col)["SUBJECT_ID"].nunique().sort_values(ascending=False)

    if remove_common or remove_rare:

        # Remove common values based on "inflection point" of frequency distribution
        if remove_common:
            # Where the rarity cutoff falls on the full curve; only used to annotate the plot
            rare_ind = _first_rare_position(vcounts, rare_k)
            _, inflection_ind = get_inflection_point(
                vcounts,
                plotlabel=col,
                plot=plot,
                deriv=deriv,
                sg_window=sg_window,
                sg_order=sg_order,
                n_pts=df["SUBJECT_ID"].nunique(),
                rare_ind=rare_ind,
            )
            len_before = len(vcounts)
            vcounts = vcounts.iloc[inflection_ind:]
            print(len_before - len(vcounts), "common", col, "values removed from the original set of", len_before)

        # Remove rare values that occur in fewer than k patients.
        if remove_rare:
            # Falls back to len(vcounts) (nothing removed) when no value is rare
            rare_ind = _first_rare_position(vcounts, rare_k)
            len_before = len(vcounts)
            vcounts = vcounts.iloc[:rare_ind]
            print(len_before - len(vcounts), "rare", col, "values removed from the original set of", len_before)

        keep = vcounts.index.tolist()
        orig_len = len(df)
        df = df[df[col].isin(keep)]
        print(orig_len - len(df), "rows with common or rare", col, "values removed of the original", orig_len, "rows")

    else:
        keep = vcounts.index.tolist()

    # Make bidict, a bidirectional dictionary
    mapping = bidict(dict(zip(range(1, len(keep)+1), keep)))

    # assign() builds a new frame, so neither a slice nor the caller's frame is written to
    df = df.assign(pheId=df[col].map(mapping.inverse))

    return mapping, df

# Train/test/vali splits

In [None]:
### Train-vali-test split on subject_id

import random

# If the pickled lists of subject and hadm ids are available, just load them from pickle.
# NOTE(review): pickle.load can run code embedded in the file, so these .pkl files
# should only ever be ones written by this notebook.
loaded_id_lists = {}
all_files_found = True
for idtype in ("subj", "hadm"):
    for idset in ("train", "vali", "test"):
        filename = idset + "_" + idtype + "_ids_list.pkl"
        if os.path.isfile(filename):
            with open(filename, 'rb') as file:
                loaded_id_lists[idset + "_" + idtype + "_ids"] = pickle.load(file)
        else:
            all_files_found = False

if all_files_found:
    # Bind the loaded lists explicitly (no exec on constructed strings)
    train_subj_ids = loaded_id_lists["train_subj_ids"]
    vali_subj_ids = loaded_id_lists["vali_subj_ids"]
    test_subj_ids = loaded_id_lists["test_subj_ids"]
    train_hadm_ids = loaded_id_lists["train_hadm_ids"]
    vali_hadm_ids = loaded_id_lists["vali_hadm_ids"]
    test_hadm_ids = loaded_id_lists["test_hadm_ids"]
    print("Train, vali, and test ids loaded from pickled lists")
else:
    # At least one list is missing: regenerate (and re-pickle) all six
    random.seed(0)

    # Get our subject IDs from trajectories
    trajectories = pd.read_csv("./../../trajectories.csv")
    trajectories["SUBJECT_ID"] = trajectories["subject_id"]
    trajectories["HADM_ID"] = trajectories["hadm_id"]


    # Only keep admissions that have an ICU stay of 36 hours or longer
    icustays = pd.read_csv("./../../ICUSTAYS.csv")
    icustays = icustays[icustays["LOS"] >= 1.5] # Must have stayed at least 36 hours
    trajectories = trajectories[trajectories["HADM_ID"].isin(icustays["HADM_ID"])]

    trajectories = only_one_hadm_per_subj(trajectories)
    subj_ids = list(trajectories["SUBJECT_ID"].unique())

    # Remove a few very specific patients
    # (enough data for some conditions but not others, causes dimensionality issues)
    remove_these_patients = [64523, 41408, 4064, 18818, 2148, 19980, 6256, 86193, 59762, 886, 19872, 13437, 73833, 14469]
    for pt in remove_these_patients:
        try:
            subj_ids.remove(pt)
        except ValueError:
            print("Subj ID", pt, "not found")

    # Randomly select training, test, vali sets
    random.shuffle(subj_ids)
    train_frac = 0.8
    vali_frac = 0.1
    test_frac = 0.1  # the test set is whatever remains after train and vali
    train_subj_ids = subj_ids[0:int(train_frac*len(subj_ids))]
    vali_subj_ids = subj_ids[int(train_frac*len(subj_ids)):int((train_frac+vali_frac)*len(subj_ids))]
    test_subj_ids = subj_ids[int((train_frac+vali_frac)*len(subj_ids)):]

    train_subj_ids.sort()
    vali_subj_ids.sort()
    test_subj_ids.sort()

    with open("train_subj_ids_list.pkl", 'wb') as file: pickle.dump(train_subj_ids, file)
    with open("vali_subj_ids_list.pkl", 'wb') as file: pickle.dump(vali_subj_ids, file)
    with open("test_subj_ids_list.pkl", 'wb') as file: pickle.dump(test_subj_ids, file)

    train_hadm_ids = trajectories[trajectories["SUBJECT_ID"].isin(train_subj_ids)]["HADM_ID"].unique().tolist()
    vali_hadm_ids = trajectories[trajectories["SUBJECT_ID"].isin(vali_subj_ids)]["HADM_ID"].unique().tolist()
    test_hadm_ids = trajectories[trajectories["SUBJECT_ID"].isin(test_subj_ids)]["HADM_ID"].unique().tolist()

    train_hadm_ids.sort()
    vali_hadm_ids.sort()
    test_hadm_ids.sort()

    with open("train_hadm_ids_list.pkl", 'wb') as file: pickle.dump(train_hadm_ids, file)
    with open("vali_hadm_ids_list.pkl", 'wb') as file: pickle.dump(vali_hadm_ids, file)
    with open("test_hadm_ids_list.pkl", 'wb') as file: pickle.dump(test_hadm_ids, file)

print("N patients in training set: ", len(train_subj_ids))
print("N patients in validation set: ", len(vali_subj_ids))
print("N patients in test set: ", len(test_subj_ids))

In [None]:
# Sanity checks: no hadm_id is repeated within a split, and no hadm_id is shared between splits
hadm_ids_by_split = {"train": train_hadm_ids, "vali": vali_hadm_ids, "test": test_hadm_ids}

for split_hadm_ids in hadm_ids_by_split.values():
    assert len(set(split_hadm_ids)) == len(split_hadm_ids)

for first_split, second_split in (("train", "test"), ("train", "vali"), ("vali", "test")):
    assert set(hadm_ids_by_split[first_split]).isdisjoint(set(hadm_ids_by_split[second_split]))

In [None]:
# from datetime import datetime

# icustays = only_these_subj_and_hadm_ids(pd.read_csv("./../ICUSTAYS.csv"), train_subj_ids, train_hadm_ids)
# icustays["intime_dt"] = icustays.apply(lambda row: datetime.strptime(row["INTIME"], '%Y-%m-%d %H:%M:%S'), axis=1)

# icustays["los_delta"] = icustays.apply(lambda row: (datetime.strptime(row["OUTTIME"], '%Y-%m-%d %H:%M:%S') - datetime.strptime(row["INTIME"], '%Y-%m-%d %H:%M:%S')).total_seconds()/(60*60*24), axis=1)

# plt.hist(icustays["LOS"], bins=500)
# plt.xlim([0, 5])

# ICD Codes

In [None]:
# ICD codes (both diagnoses and procedures), stacked into one frame and tagged as mixehr type 1

icd_diag = pd.read_csv("./../../MIMIC_III_DIAGNOSES_ICD.csv")
icd_proc = pd.read_csv("./../../PROCEDURES_ICD.csv")

icds = pd.concat([icd_diag, icd_proc]).assign(
    patientId=lambda frame: frame["SUBJECT_ID"],
    typeId=1,
    stateId=1,  # ICD codes are simply marked "present"
)

In [None]:
# ICD codes for training set: build the code -> pheId mapper, pickle it, and write the mixehr csv

if PREPARE_TRAINING_DATA:
    icds_mapper, icds_subset = unique_mapping(
        only_these_subj_and_hadm_ids(icds, train_subj_ids, train_hadm_ids),
        "ICD9_CODE",
        remove_common=True,
        remove_rare=True,
        plot=True,
    )

    # The pickled mapper is what the vali/test preparation reads back
    with open("icds_mapper.pkl", 'wb') as file:
        pickle.dump(icds_mapper, file)

    icds_subset_mixehr = get_mixehr_format_with_freqs(icds_subset)
    icds_subset_mixehr.to_csv("train_icds_mixehr.csv", index=False)

In [None]:
# ICD codes for vali and test sets (pheIds come from the mapper pickled for the training set)

if PREPARE_VALI_TEST_DATA:
    # The mapper does not depend on the split, so load it once (and close the file)
    with open("icds_mapper.pkl", 'rb') as file:
        icds_mapper = pickle.load(file)

    # Explicit (split name, subject ids, hadm ids) triples instead of eval() on variable names
    for idset, subj_ids, hadm_ids in (
        ("vali", vali_subj_ids, vali_hadm_ids),
        ("test", test_subj_ids, test_hadm_ids),
    ):
        icds_subset = only_these_subj_and_hadm_ids(icds, subj_ids, hadm_ids)
        icds_subset = prepare_vali_test_sets(icds_subset, icds_mapper, "ICD9_CODE", idset)
        icds_subset_mixehr = get_mixehr_format_with_freqs(icds_subset)
        icds_subset_mixehr.to_csv(idset + "_icds_mixehr.csv", index=False)

# Prescriptions

In [None]:
# Prescriptions

# I make these simplifications to routes of administration:
# All oral and gastric tube administration routes are denoted "PO/GT" without distinction
# All duodenal routes (e.g. nasoduodenal) are marked as "PO/GT", same as gastric routes
# All intraocular drugs are "OU" (no left vs right vs both eyes etc)
# All intraaural drugs are "EAR" (no left vs right vs both ears etc)
# All nasal routes are "NU"
# Routes that are not listed here are left unchanged by the mapping below.
route_dict = {
    "PO": "PO/GT", # By mouth or gastric tube
    "ORAL": "PO/GT",
    "PO/NG": "PO/GT",
    "PO/OG": "PO/GT",
    "NG": "PO/GT",
    "OG": "PO/GT",
    "PO OR ENTERAL TUBE": "PO/GT",  # (was listed twice; a repeated key just overwrites itself)
    "G TUBE": "PO/GT",
    "ND": "PO/GT",
    # NOTE(review): the "?" is kept exactly as written; presumably it matches how the
    # value appears in the ROUTE column. Confirm against the data.
    "ENTERAL TUBE ONLY ? NOT ORAL": "PO/GT",
    "RIGHT EYE": "OU", # Ocular/eye
    "LEFT EYE": "OU",
    "BOTH EYES": "OU",
    "OS": "OU",
    "OD": "OU",
    "AD": "EAR",
    "AU": "EAR",
    "AS": "EAR",
    "BOTH EARS": "EAR",
    "LEFT EAR": "EAR",
    "RIGHT EAR": "EAR",
    "NAS": "NU", # Nasal
    "NS": "NU",
    "IN": "NU"
}

prescrips = pd.read_csv("./../../PRESCRIPTIONS.csv")
prescrips = prescrips.assign(
    patientId=prescrips["SUBJECT_ID"],
    typeId=2,
    stateId=1,  # Prescriptions are simply marked "present"
    ROUTE=prescrips["ROUTE"].fillna("UNSPECIFIED"),
)

# Simplify the route abbreviations - e.g., for this cardiac mixehr, assume anything oral/gastric/intraduodenal is the same.
# Routes without an entry in route_dict are kept as they are.
prescrips["simplified_route"] = prescrips["ROUTE"].map(lambda route: route_dict.get(route, route)).astype(str)

# A drug is identified by its name together with its simplified route
prescrips["drug_id"] = prescrips["DRUG"].astype(str) + "-" + prescrips["simplified_route"]

In [None]:
# Prescriptions for training set: keep those started less than 24 h after ICU admission,
# build the drug_id -> pheId mapper, pickle it, and write the mixehr csv

if PREPARE_TRAINING_DATA:
    train_icustays = only_these_subj_and_hadm_ids(
        pd.read_csv("./../../ICUSTAYS.csv"), train_subj_ids, train_hadm_ids
    )

    prescrips_subset = within_x_secs_of_admit(
        only_these_subj_and_hadm_ids(prescrips, train_subj_ids, train_hadm_ids),
        3600*24,
        train_icustays,
        col_name="STARTDATE",
    )

    prescrips_mapper, prescrips_subset = unique_mapping(
        prescrips_subset, "drug_id", remove_common=True, remove_rare=True, plot=True
    )
    with open("prescrips_mapper.pkl", 'wb') as file:
        pickle.dump(prescrips_mapper, file)

    prescrips_subset_mixehr = get_mixehr_format_with_freqs(prescrips_subset)
    prescrips_subset_mixehr.to_csv("train_prescrips_mixehr.csv", index=False)

In [None]:
# Prescriptions for vali and test sets (pheIds come from the mapper pickled for the training set)

if PREPARE_VALI_TEST_DATA:
    # Neither the mapper nor the ICU stay table depends on the split: load each once
    with open("prescrips_mapper.pkl", 'rb') as file:
        prescrips_mapper = pickle.load(file)
    all_icustays = pd.read_csv("./../../ICUSTAYS.csv")

    # Explicit (split name, subject ids, hadm ids) triples instead of eval() on variable names
    for idset, subj_ids, hadm_ids in (
        ("vali", vali_subj_ids, vali_hadm_ids),
        ("test", test_subj_ids, test_hadm_ids),
    ):
        prescrips_subset = only_these_subj_and_hadm_ids(prescrips, subj_ids, hadm_ids)

        # Keep prescriptions started less than 24 h after ICU admission
        prescrips_subset = within_x_secs_of_admit(
            prescrips_subset,
            3600*24,
            only_these_subj_and_hadm_ids(all_icustays, subj_ids, hadm_ids),
            col_name="STARTDATE"
        )

        prescrips_subset = prepare_vali_test_sets(prescrips_subset, prescrips_mapper, "drug_id", idset)
        prescrips_subset_mixehr = get_mixehr_format_with_freqs(prescrips_subset)
        prescrips_subset_mixehr.to_csv(idset + "_prescrips_mixehr.csv", index=False)

# Lab tests

In [None]:
# LAB TESTS

# For some reason, lots of missing HADM_IDs in the labs
# They could mostly not be recovered from ICUSTAYS.csv
# using SUBJECT_ID and timestamps, so those rows are dropped.
labs = (
    pd.read_csv("./../../LABEVENTS.csv")
    .assign(patientId=lambda frame: frame["SUBJECT_ID"], typeId=3)
    .dropna(subset=["HADM_ID"])
    .assign(lab_id=lambda frame: frame["ITEMID"])
)

In [None]:
# Lab tests for training set: keep results charted less than 24 h after ICU admission,
# build the lab_id -> pheId mapper, pickle it, and write the mixehr csv

if PREPARE_TRAINING_DATA:
    labs_subset = only_these_subj_and_hadm_ids(labs, train_subj_ids, train_hadm_ids)

    labs_subset = within_x_secs_of_admit(
        labs_subset,
        3600*24,
        only_these_subj_and_hadm_ids(pd.read_csv("./../../ICUSTAYS.csv"), train_subj_ids, train_hadm_ids)
    )

    # 0 for nan or "DELTA" value in FLAG column, 1 for "abnormal".
    # Vectorized comparison (NaN compares unequal, hence 0); assign() avoids writing into a slice.
    labs_subset = labs_subset.assign(stateId=(labs_subset["FLAG"] == "abnormal").astype(int))

    labs_mapper, labs_subset = unique_mapping(labs_subset, "lab_id", remove_common=True, remove_rare=True, plot=True, sg_window=51)
    with open("labs_mapper.pkl", 'wb') as file:
        pickle.dump(labs_mapper, file)

    labs_subset_mixehr = get_mixehr_format_with_freqs(labs_subset)
    labs_subset_mixehr.to_csv("train_labs_mixehr.csv", index=False)

In [None]:
# Labs for vali and test sets (pheIds come from the mapper pickled for the training set)

if PREPARE_VALI_TEST_DATA:
    # Neither the mapper nor the ICU stay table depends on the split: load each once
    with open("labs_mapper.pkl", 'rb') as file:
        labs_mapper = pickle.load(file)
    all_icustays = pd.read_csv("./../../ICUSTAYS.csv")

    # Explicit (split name, subject ids, hadm ids) triples instead of eval() on variable names
    for idset, subj_ids, hadm_ids in (
        ("vali", vali_subj_ids, vali_hadm_ids),
        ("test", test_subj_ids, test_hadm_ids),
    ):
        labs_subset = only_these_subj_and_hadm_ids(labs, subj_ids, hadm_ids)

        # Keep results charted less than 24 h after ICU admission
        labs_subset = within_x_secs_of_admit(
            labs_subset,
            3600*24,
            only_these_subj_and_hadm_ids(all_icustays, subj_ids, hadm_ids)
        )

        # 0 for nan or "DELTA" value in FLAG column, 1 for "abnormal".
        # Vectorized comparison (NaN compares unequal, hence 0); assign() avoids writing into a slice.
        labs_subset = labs_subset.assign(stateId=(labs_subset["FLAG"] == "abnormal").astype(int))

        labs_subset = prepare_vali_test_sets(labs_subset, labs_mapper, "lab_id", idset)
        labs_subset_mixehr = get_mixehr_format_with_freqs(labs_subset)
        labs_subset_mixehr.to_csv(idset + "_labs_mixehr.csv", index=False)

# Clinical notes

In [None]:
# Set up NLTK for clinical notes

import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from string import punctuation
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

# Fetch the NLTK resources used below: the stopword lists read by stopwords.words()
# and the 'punkt' tokenizer data (presumably what word_tokenize relies on; verify
# against the installed NLTK version).
nltk.download('stopwords')
nltk.download('punkt')

def get_tokens(text, stop_words, punctuation, restricted_dict=None):
    """Tokenize a note and drop unwanted tokens.

    The text has every "**" removed (common around dates etc in MIMIC notes),
    is lowercased, and is split with ``word_tokenize``. A token is kept when it
    is not in ``stop_words`` and, if ``restricted_dict`` is given, is contained
    in ``restricted_dict``. ``punctuation`` is accepted but not used here.

    Returns the kept tokens as a list, in order of appearance (repeats included).
    """
    cleaned_text = text.replace("**", "").lower()

    kept_tokens = []
    for token in word_tokenize(cleaned_text):
        if token in stop_words:
            continue
        # With a restricted vocabulary, anything outside it is discarded
        if restricted_dict is not None and token not in restricted_dict:
            continue
        kept_tokens.append(token)
    return kept_tokens

# English stopwords, plus the set of characters (punctuation marks and digits)
# that disqualify a word from the notes dictionary
stop_words = set(stopwords.words('english'))
punctuation = set(punctuation) | set("0123456789")

In [None]:
# Save separate note files for training, vali, and test notes

if MAKE_SEPARATE_NOTE_CSVS:
    notes = pd.read_csv("./../../NOTEEVENTS.csv")

    notes_train = only_these_subj_and_hadm_ids(notes, train_subj_ids, train_hadm_ids)
    notes_vali = only_these_subj_and_hadm_ids(notes, vali_subj_ids, vali_hadm_ids)
    notes_test = only_these_subj_and_hadm_ids(notes, test_subj_ids, test_hadm_ids)

    # One csv per split, written with pandas' default settings
    for split_notes, split_csv in (
        (notes_train, "TRAIN_NOTEEVENTS.csv"),
        (notes_vali, "VALI_NOTEEVENTS.csv"),
        (notes_test, "TEST_NOTEEVENTS.csv"),
    ):
        split_notes.to_csv(split_csv)

In [None]:
# Build dictionary from clinical notes training set (takes about 40 minutes)
from tqdm import tqdm

if BUILD_NOTES_DICTIONARY:
    notes = pd.read_csv("./TRAIN_NOTEEVENTS.csv", usecols=["SUBJECT_ID", "HADM_ID", "CHARTTIME", "TEXT"])
    num_notes = len(notes)
    notes = only_these_subj_and_hadm_ids(notes, train_subj_ids, train_hadm_ids)
    assert num_notes == len(notes), "Inconsistent selection of training set!"

    # Keep notes charted less than 24 h after ICU admission
    notes = within_x_secs_of_admit(
        notes,
        3600*24,
        only_these_subj_and_hadm_ids(pd.read_csv("./../../ICUSTAYS.csv"), train_subj_ids, train_hadm_ids)
    )

    # Initialize an empty FreqDist: for each word, the number of patients whose notes contain it
    fdist = nltk.FreqDist()

    # One pass over the notes grouped by patient, rather than re-filtering the whole
    # frame once per patient. Patients without any remaining note contribute nothing.
    for pt, pt_texts in tqdm(notes.groupby("SUBJECT_ID")["TEXT"]):
        concat_notes = " ".join(pt_texts)
        tokens = get_tokens(concat_notes, stop_words, punctuation)
        # set(): a word counts once per patient, however often it appears
        fdist.update(set(tokens))

    word_freq_df = pd.DataFrame(fdist.items(), columns=['word', 'frequency'])
    word_freq_df = word_freq_df.sort_values(by=["frequency"], ascending=False).reset_index(drop=True)
    word_freq_df.to_csv("train_notes_word_freqs.csv", index=False)

In [None]:
# Do some filtering on the extracted dictionary and make a bidirectional mapping

if BUILD_NOTES_DICTIONARY:
    # keep_default_na=False: tokens such as "null" or "nan" are real words here and must
    # not be turned into missing values by read_csv
    word_freq_df = pd.read_csv("train_notes_word_freqs.csv", keep_default_na=False)
    word_freq_df = word_freq_df.sort_values(by=["frequency"], ascending=False).reset_index(drop=True)
    word_freq_df.to_csv("train_notes_word_freqs.csv", index=False)

    # Only keep words that occur for at least k patients
    k = 3
    length_before = len(word_freq_df)
    word_freq_df = word_freq_df[word_freq_df["frequency"] >= k]
    print(length_before-len(word_freq_df), "of", length_before, "words removed because they occurred for less than", k, "patients")

    # Remove any words that contain punctuation or numbers.
    # The mask is kept as a separate Series, so nothing is written into the filtered slice.
    length_before = len(word_freq_df)
    has_no_punc = word_freq_df["word"].map(lambda word: set(str(word)).isdisjoint(punctuation)).astype(bool)
    word_freq_df = word_freq_df[has_no_punc].reset_index(drop=True)
    print(length_before-len(word_freq_df), "of", length_before, "words removed because they contained numbers or punctuation")

    # Remove common words based on inflection point
    # (frequencies are still in descending order, as get_inflection_point requires)
    length_before = len(word_freq_df)
    inflection_val, inflection_ind = get_inflection_point(word_freq_df["frequency"], plotlabel="Notes word")
    word_freq_df = word_freq_df[inflection_ind:].reset_index(drop=True)
    print(length_before-len(word_freq_df), "common words removed from the original set of", length_before)

    word_freq_df.to_csv("cleaned_notes_dict.csv", index=False)

    # Make bidict, a bidirectional dictionary (integer id, starting at 1 -> word)
    note_words_mapper = bidict(dict(zip(range(1, len(word_freq_df)+1), word_freq_df["word"].tolist())))
    with open("note_words_mapper.pkl", 'wb') as file:
        pickle.dump(note_words_mapper, file)

In [None]:
# Make the final dataframe(s) for clinical notes (takes about 10 minutes for training set)
from tqdm import tqdm

# Explicit split -> (subject ids, hadm ids) table instead of eval() on variable names
split_ids = {}
if PREPARE_TRAINING_DATA:
    split_ids["train"] = (train_subj_ids, train_hadm_ids)
if PREPARE_VALI_TEST_DATA:
    split_ids["vali"] = (vali_subj_ids, vali_hadm_ids)
    split_ids["test"] = (test_subj_ids, test_hadm_ids)
idsets = list(split_ids)

with open("note_words_mapper.pkl", 'rb') as file:
    note_words_mapper = pickle.load(file)

for idset in idsets:
    notes = pd.read_csv("./" + idset.upper() + "_NOTEEVENTS.csv", usecols=["SUBJECT_ID", "HADM_ID", "CHARTTIME", "TEXT"])
    num_notes = len(notes)
    subj_ids, hadm_ids = split_ids[idset]
    notes = only_these_subj_and_hadm_ids(notes, subj_ids, hadm_ids)
    assert num_notes == len(notes), "Inconsistent selection of " + idset + " set!"

    # Keep notes charted less than 24 h after ICU admission
    notes = within_x_secs_of_admit(
        notes,
        3600*24,
        only_these_subj_and_hadm_ids(pd.read_csv("./../../ICUSTAYS.csv"), subj_ids, hadm_ids)
    )

    # keep_default_na=False: words such as "null" or "nan" must be read back as text
    cleaned_notes_dict = set(pd.read_csv("cleaned_notes_dict.csv", keep_default_na=False)["word"].tolist())

    # One frame of (word, freq) per patient in the training, vali, or test DataFrame
    # ASSUME ONLY ONE HADM PER PATIENT!
    # Grouping visits each patient's notes once, rather than re-filtering the whole
    # frame per patient. Patients without any remaining note contribute no rows.
    pt_dfs = []
    for pt, pt_texts in tqdm(notes.groupby("SUBJECT_ID")["TEXT"]):
        concat_notes = " ".join(pt_texts)
        tokens = get_tokens(concat_notes, stop_words, punctuation, restricted_dict=cleaned_notes_dict)
        fdist = nltk.FreqDist(tokens)
        pt_df = pd.DataFrame(fdist.items(), columns=['word', 'freq'])
        pt_df["patientId"] = pt
        pt_dfs.append(pt_df)

    if pt_dfs:
        note_words_mixehr = pd.concat(pt_dfs)
    else:
        # No notes survived the filters: still write a (header-only) csv
        note_words_mixehr = pd.DataFrame(columns=["word", "freq", "patientId"])

    note_words_mixehr["typeId"] = 4
    phe_ids = note_words_mixehr["word"].map(note_words_mapper.inverse)
    if phe_ids.isna().any():
        # Same failure the per-row dictionary lookup would have produced
        raise KeyError("words missing from note_words_mapper: " + str(sorted(set(note_words_mixehr["word"][phe_ids.isna()]))[:10]))
    note_words_mixehr["pheId"] = phe_ids.astype(int)
    note_words_mixehr["stateId"] = 1

    note_words_mixehr = note_words_mixehr[["patientId", "typeId", "pheId", "stateId", "freq"]]

    note_words_mixehr.to_csv(idset + "_note_words_mixehr.csv", index=False)

In [None]:
# Used this to view some words with low/uncommon frequencies to decide on a cutoff

# word_freq_df = pd.read_csv("notes_word_freqs.csv")
# print(len(word_freq_df))
# pd.set_option('display.max_rows', None)
# print(len(word_freq_df[word_freq_df["frequency"] == 2]))
# word_freq_df[word_freq_df["frequency"] == 2].iloc[0:200]

# ECG

In [None]:
# How many distinct admissions does a patient have, on average, in the raw trajectories?
traj = pd.read_csv("./../../trajectories.csv")
hadm_count_per_subject = traj.groupby("subject_id")["hadm_id"].nunique()
print("Mean number of HADM_IDs per patient:", hadm_count_per_subject.mean())

In [None]:
ecg = pd.read_csv("./../../trajectories_with_features.csv")

# Define which ecg features we are using
misc_features = ["ibi", "bpm", "sdnn", "sdsd", "rmssd", "pnn50", "pnn20"]
cols_list = [col for col in list(ecg.columns.values) if (
    "full_" in col or "swt_" in col or "HRV_" in col or col in misc_features
)]

# Remove features with 100% missingness (infinite values count as missing)
ecg = ecg.replace([np.inf, -np.inf], np.nan)
ecg_features = []
for col in cols_list:
    prop_missing = ecg[col].isnull().sum()/len(ecg.index)
    if prop_missing < 1:
        ecg_features.append(col)
# This variable probably doesn't make sense as a feature.
# Filtering (rather than list.remove) does not fail when the column is absent or already excluded.
ecg_features = [feat for feat in ecg_features if feat != "full_waveform_duration"]

ecg = ecg.rename(columns={"subject_id": "SUBJECT_ID", "hadm_id": "HADM_ID"})[["SUBJECT_ID", "HADM_ID", "start_hr", "ts_idx"] + ecg_features]

# Explicit split -> (subject ids, hadm ids) table instead of eval() on variable names
split_ids = {}
if PREPARE_TRAINING_DATA:
    split_ids["train"] = (train_subj_ids, train_hadm_ids)
if PREPARE_VALI_TEST_DATA:
    split_ids["vali"] = (vali_subj_ids, vali_hadm_ids)
    split_ids["test"] = (test_subj_ids, test_hadm_ids)
idsets = list(split_ids)

# If we're preparing the training data, we are generating the quantiles.
# For vali and test data, we use the quantile values drawn from the training set.
if not PREPARE_TRAINING_DATA and PREPARE_VALI_TEST_DATA:
    with open("ecg_quantiles_dict.pkl", 'rb') as file:
        quantiles_dict = pickle.load(file)
else:
    quantiles_dict = {}

for idset in idsets:
    subj_ids, hadm_ids = split_ids[idset]
    ecg_subset = only_these_subj_and_hadm_ids(ecg, subj_ids, hadm_ids)

    # Keep only trajectories that start within the first 12 hours of ICU stay
    ecg_subset = ecg_subset[ecg_subset["start_hr"] <= 12]

    # Prepare quantile features - i.e., "True" if the feature (e.g. ECG standard deviation in a one-hour segment)
    # is in some "extreme" quantile (below the 25th or above the 75th percentile of the training set)
    quantile_features = []
    quantile_flags = {}
    for feat in ecg_features:
        if PREPARE_TRAINING_DATA and idset == "train":
            quantiles_dict[feat] = [ecg_subset[feat].quantile(0.25), ecg_subset[feat].quantile(0.75)]
        low_quantile_feat_col = feat + "-low"
        high_quantile_feat_col = feat + "-high"
        quantile_flags["value@" + low_quantile_feat_col] = ecg_subset[feat] < quantiles_dict[feat][0]
        quantile_flags["value@" + high_quantile_feat_col] = ecg_subset[feat] > quantiles_dict[feat][1]
        quantile_features.extend([low_quantile_feat_col, high_quantile_feat_col])

    # Swap the raw feature columns for the boolean flag columns in one step: nothing is
    # written into the filtered slice, and the frame is not grown one column at a time.
    ecg_subset = pd.concat(
        [ecg_subset.drop(ecg_features, axis=1), pd.DataFrame(quantile_flags, index=ecg_subset.index)],
        axis=1,
    )

    if PREPARE_TRAINING_DATA and idset == "train":
        with open("ecg_quantiles_dict.pkl", 'wb') as file:
            pickle.dump(quantiles_dict, file)

    # This is the key reshaping function - from "wide" format with one column per quantile feature,
    # to each instance of any feature having its row in a df with lots and lots of rows.
    long_ecg = pd.wide_to_long(ecg_subset, stubnames="value", i=["SUBJECT_ID", "HADM_ID", "ts_idx"],
                               j="feature", sep="@", suffix="(" + "|".join(quantile_features) + ")").reset_index()

    # Keep only instances where the quantiles were exceeded
    long_ecg = long_ecg[long_ecg["value"]]

    # assign() returns a new frame, so the filtered slice above is not written to
    long_ecg = long_ecg.assign(
        patientId=long_ecg["SUBJECT_ID"],
        typeId=5, # ecg data designated as type 5
        stateId=1, # ecg quantile features are simply marked "present"
    )

    # Get mapping dictionary from feature names to mixehr integers
    if PREPARE_TRAINING_DATA and idset == "train":
        ecg_quantile_features_mapper = bidict(dict(zip(range(1, len(quantile_features)+1), quantile_features)))
        with open("ecg_quantile_features_mapper.pkl", 'wb') as file:
            pickle.dump(ecg_quantile_features_mapper, file)
    else:
        with open("ecg_quantile_features_mapper.pkl", 'rb') as file:
            ecg_quantile_features_mapper = pickle.load(file)

    long_ecg["pheId"] = long_ecg["feature"].map(ecg_quantile_features_mapper.inverse)

    ecg_quantiles_mixehr = get_mixehr_format_with_freqs(long_ecg)
    ecg_quantiles_mixehr.to_csv(idset + "_ecg_quantiles_mixehr.csv", index=False)

## Combining modalities for mixehr training

In [None]:
# Identify patients who are included in some conditions but not others
# (modify manually as new conditions are added)
for idset in ["train", "vali", "test"]:

    # Structured modalities only
    structured_csvs = [
        #idset + "_icds_mixehr.csv",
        idset + "_prescrips_mixehr.csv",
        idset + "_labs_mixehr.csv",
    ]
    dfs = [pd.read_csv(csv_path) for csv_path in structured_csvs]
    no_notes_no_waveforms = pd.concat(dfs)

    # Add note words and ECG quantiles on top of the structured modalities
    dfs += [
        pd.read_csv(idset + "_note_words_mixehr.csv"),
        pd.read_csv(idset + "_ecg_quantiles_mixehr.csv"),
    ]
    everything = pd.concat(dfs)

    ecg_quantiles_only = pd.read_csv(idset + "_ecg_quantiles_mixehr.csv")

    # Number of distinct patients under each condition
    print("------------")
    print(idset)
    for condition_df in (no_notes_no_waveforms, everything, ecg_quantiles_only):
        print(condition_df["patientId"].nunique())

    # Patients present in "everything" but absent from the narrower condition
    print((everything.apply(set) - no_notes_no_waveforms.apply(set))["patientId"])
    print((everything.apply(set) - ecg_quantiles_only.apply(set))["patientId"])

In [None]:
# Combine all of the mixehr-formatted csvs from prescriptions, labs, note words, and ecg feature quantiles

if COMBINE_DATA:
    for idset in ["train", "vali", "test"]:
        modality_csvs = [
            #idset + "_icds_mixehr.csv",
            idset + "_prescrips_mixehr.csv",
            idset + "_labs_mixehr.csv",
            idset + "_note_words_mixehr.csv",
            idset + "_ecg_quantiles_mixehr.csv",
        ]
        mixehr_data = pd.concat([pd.read_csv(csv_path) for csv_path in modality_csvs])
        mixehr_data = mixehr_data.sort_values(by=["patientId", "typeId", "pheId", "stateId"])
        mixehr_metadata = get_mixehr_metadata(mixehr_data)

        # The data file is written for every split, as csv and as mixehr txt
        mixehr_data.to_csv(idset + "_mixehr_early_with_ecg_quantiles.csv", index=False)
        mixehr_df_to_txt(mixehr_data, idset + "_mixehr_early_with_ecg_quantiles.txt", metadata=False)

        # The metadata file is only written for the training split
        if idset != "train":
            continue
        mixehr_metadata.to_csv(idset + "_mixehr_metadata_early_with_ecg_quantiles.csv", index=False)
        mixehr_df_to_txt(mixehr_metadata, idset + "_mixehr_metadata_early_with_ecg_quantiles.txt", metadata=True)

In [None]:
# Build the MixEHR input that contains only the ECG feature quantiles.

if COMBINE_DATA:
    mixehr_sort_cols = ["patientId", "typeId", "pheId", "stateId"]

    for idset in ("train", "vali", "test"):
        # Single source file, so there is nothing to stack before sorting.
        ecg_only = pd.read_csv(idset + "_ecg_quantiles_mixehr.csv")
        mixehr_data = ecg_only.sort_values(by=mixehr_sort_cols)
        mixehr_metadata = get_mixehr_metadata(mixehr_data)

        # Write the data as both csv and MixEHR text format.
        out_stem = idset + "_mixehr_early_only_ecg_quantiles"
        mixehr_data.to_csv(out_stem + ".csv", index=False)
        mixehr_df_to_txt(mixehr_data, out_stem + ".txt", metadata=False)

        # Metadata files are only written for the training split.
        if idset != "train":
            continue
        metadata_stem = idset + "_mixehr_metadata_early_only_ecg_quantiles"
        mixehr_metadata.to_csv(metadata_stem + ".csv", index=False)
        mixehr_df_to_txt(mixehr_metadata, metadata_stem + ".txt", metadata=True)

In [None]:
# Build the MixEHR input from prescriptions, labs, and note words.
# NO WAVEFORMS

if COMBINE_DATA:
    # ICD codes (idset + "_icds_mixehr.csv") are currently left out.
    mixehr_sources = ("prescrips", "labs", "note_words")
    mixehr_sort_cols = ["patientId", "typeId", "pheId", "stateId"]

    for idset in ("train", "vali", "test"):
        # Stack the per-modality csvs for this split, then order the rows.
        mixehr_data = pd.concat(
            [pd.read_csv(idset + "_" + source + "_mixehr.csv") for source in mixehr_sources]
        ).sort_values(by=mixehr_sort_cols)
        mixehr_metadata = get_mixehr_metadata(mixehr_data)

        # Write the data as both csv and MixEHR text format.
        out_stem = idset + "_mixehr_early_no_waveforms"
        mixehr_data.to_csv(out_stem + ".csv", index=False)
        mixehr_df_to_txt(mixehr_data, out_stem + ".txt", metadata=False)

        # Metadata files are only written for the training split.
        if idset != "train":
            continue
        metadata_stem = idset + "_mixehr_metadata_early_no_waveforms"
        mixehr_metadata.to_csv(metadata_stem + ".csv", index=False)
        mixehr_df_to_txt(mixehr_metadata, metadata_stem + ".txt", metadata=True)

In [None]:
# Build the MixEHR input from prescriptions and labs only.
# NO WAVEFORMS AND NO NOTES

if COMBINE_DATA:
    # ICD codes (idset + "_icds_mixehr.csv") are currently left out.
    mixehr_sources = ("prescrips", "labs")
    mixehr_sort_cols = ["patientId", "typeId", "pheId", "stateId"]

    for idset in ("train", "vali", "test"):
        # Stack the per-modality csvs for this split, then order the rows.
        mixehr_data = pd.concat(
            [pd.read_csv(idset + "_" + source + "_mixehr.csv") for source in mixehr_sources]
        ).sort_values(by=mixehr_sort_cols)
        mixehr_metadata = get_mixehr_metadata(mixehr_data)

        # Write the data as both csv and MixEHR text format.
        out_stem = idset + "_mixehr_early_no_notes_no_waveforms"
        mixehr_data.to_csv(out_stem + ".csv", index=False)
        mixehr_df_to_txt(mixehr_data, out_stem + ".txt", metadata=False)

        # Metadata files are only written for the training split.
        if idset != "train":
            continue
        metadata_stem = idset + "_mixehr_metadata_early_no_notes_no_waveforms"
        mixehr_metadata.to_csv(metadata_stem + ".csv", index=False)
        mixehr_df_to_txt(mixehr_metadata, metadata_stem + ".txt", metadata=True)

In [None]:
# Summary statistics for the mixehr-formatted dataframe
# (compare against mixehr's own printouts when debugging).
# NOTE: `mixehr_data` is not created in this cell; it is whatever an earlier
# cell left in the kernel, so this cell fails on a fresh kernel if no
# combine cell has run.

# Rows with typeId == 3 are reported as lab tests; everything else is non-lab.
lab_mask = mixehr_data["typeId"] == 3
lab_rows = mixehr_data[lab_mask]
non_lab_rows = mixehr_data[~lab_mask]

# Distinct pheId count per non-lab typeId (one row per typeId).
phetypes = non_lab_rows.groupby("typeId")[["pheId"]].nunique().reset_index()

print("Df length:", len(mixehr_data))
print("Num of typeId types (including labs): ", mixehr_data["typeId"].nunique())
print("Num of non-lab pheId types: ", phetypes["pheId"].sum())
print("Num of lab tests: ", lab_rows["pheId"].nunique())
print("Num of patients: ", mixehr_data["patientId"].nunique())
print("Max pt ID: ", mixehr_data["patientId"].max())

# Miscellaneous visualizations etc

In [None]:
# Plot the distribution of every usable ECG feature, with mean, median and
# quartiles marked, and save the figures for the features in `savelist`.
ecg = pd.read_csv("./../../trajectories_with_features.csv")

plt.style.use('ggplot')

# Define which ecg features we are using
misc_features = ["ibi", "bpm", "sdnn", "sdsd", "rmssd", "pnn50", "pnn20"]
cols_list = [col for col in list(ecg.columns.values) if (
    "full_" in col or "swt_" in col or "HRV_" in col or col in misc_features
)]

# Remove features with 100% missingness (infinite values count as missing)
ecg = ecg.replace([np.inf, -np.inf], np.nan)
ecg_features = [col for col in cols_list if ecg[col].notna().any()]

# This variable probably doesn't make sense as a feature. Guarded so the cell
# does not raise ValueError when the column is absent or was already dropped.
if "full_waveform_duration" in ecg_features:
    ecg_features.remove("full_waveform_duration")

# On-screen resolution; saved figures use an explicit dpi below.
matplotlib.rcParams['figure.dpi'] = 100

savelist = ["HRV_CVI", "HRV_pNN20", "HRV_pNN50"]

for feat in ecg_features:
    print(feat)

    # Compute each summary statistic once per feature.
    feat_values = ecg[feat]
    feat_mean = feat_values.mean()
    feat_median = feat_values.median()
    feat_std = feat_values.std()

    fig = plt.figure(figsize=(4,2))

    # NOTE(review): the histogram keeps values within 3 std of the MEDIAN,
    # while the x-axis limits are 3 std around the MEAN. Kept as-is; confirm
    # whether both were meant to use the same centre.
    in_window = (feat_values > feat_median - 3*feat_std) & (feat_values < feat_median + 3*feat_std)
    most_of_dist = ecg[in_window]
    plt.hist(most_of_dist[feat], bins=50, label="hist", color="dodgerblue", density=True)
    plt.xlim(feat_mean - 3*feat_std, feat_mean + 3*feat_std)

    plt.axvline(feat_mean, color="red", label="mean")
    plt.axvline(feat_median, color="green", label="median")
    # Both quartile lines share one legend entry.
    plt.axvline(feat_values.quantile(0.25), color="k", label="quantile")
    plt.axvline(feat_values.quantile(0.75), color="k")

    plt.legend(loc="upper left")
    plt.title(feat)
    plt.tight_layout()
    if feat in savelist:
        plt.savefig(feat + ".png", dpi=300)
    plt.show()
    # Release the figure so open figures do not accumulate across the loop.
    plt.close(fig)
    # Variables with interesting bimodal distributions: 
    # HRV_pNN50, HRV_pNN20, HRV_HTI, HRV_HF