# Build Patient Trajectories

This file queries the MIMIC database and generates patient trajectories: ordered lists of tokenised clinical events, together with their associated target endpoints, that can be fed into prediction models.

Notes:

1. For simplicity, these scripts mostly work directly on parquet files. The exception is vital signs, where loading the data into memory is prohibitive: you must run a Glue crawler so that the processed data is added to the Glue catalogue before proceeding with vital signs processing.

2. To avoid label leakage into embedding generation, ICU admissions and discharges are tokenised separately from the rest of the patient trajectory.

3. We generate two versions of the patient trajectory: one based on clinical and administrative events (pathology, prescriptions, procedures, historical diagnoses, ventilation and patient demography) and one based on the trajectory of charted vital signs observations.

4. For non-discrete events such as numeric pathology results, procedure durations, ventilation durations or numeric chart observations, we bin the values into quantiles (e.g. deciles or quartiles) and use the bin as a discrete label.

In [None]:
from pyathena import connect
from pyathena.util import as_pandas
from pyathena.util import to_sql
import pandas as pd
import awswrangler as wr
import sys,os,os.path, datetime, os, boto3, pickle
import numpy as np
import tensorflow as tf
from itertools import product, repeat
from collections import Counter
import s3fs

In [None]:
# Module-level S3 filesystem handle; anon=False means it does not connect anonymously.
s3 = s3fs.S3FileSystem(anon=False)

from utils.connections import connection, cursor, gluedatabase, processed_db, upload_file, processed_data_bucket, download_file, get_or_create_bucket, account_id, get_s3_keys_as_generator
from utils.datagen import PROCESSED_DATAPATH

In [None]:
# Sanity check: the first character of the TensorFlow version string must be '2'.
# NOTE(review): this is a character test, not a parsed major-version comparison.
assert(tf.__version__[0]=='2')

## Trajectory pre-processing

In [None]:
def get_distinct_lab_events(RELOAD=False):
    """Return the distinct lab-event test types (``itemid`` values).

    With ``RELOAD=True`` the cached parquet dataset under ``distinct_path`` is
    deleted, the distinct item ids are re-queried from the ``labevents`` table
    and written back to S3 (registered as table ``distinct_path`` in the Glue
    database).  Otherwise the cached parquet dataset is read from S3.

    Returns:
        DataFrame of distinct item ids.
    """
    cache_path = f's3://{processed_data_bucket}/distinct_path'

    # Fast path: reuse what was stored by a previous run.
    if not RELOAD:
        return wr.s3.read_parquet(path=cache_path)

    wr.s3.delete_objects(path=cache_path)
    distinct_ids = pd.read_sql(f'select distinct(itemid) from {gluedatabase}.labevents;', connection)
    wr.s3.to_parquet(df=distinct_ids,
                     path=cache_path,
                     dataset=True,
                     database=gluedatabase,
                     table='distinct_path')
    return distinct_ids

In [None]:
def get_distinct_chart_events(RELOAD=False):
    """Return the distinct chart-event observation types (``itemid`` values).

    With ``RELOAD=True`` the cached parquet dataset under ``distinct_chart`` is
    deleted, the distinct item ids are re-queried from the ``chartevents``
    table and written back to S3 (registered as table ``distinct_chart`` in
    the Glue database).  Otherwise the cached parquet dataset is read from S3.

    Returns:
        DataFrame of distinct item ids.
    """
    cache_path = f's3://{processed_data_bucket}/distinct_chart'

    # Fast path: reuse what was stored by a previous run.
    if not RELOAD:
        return wr.s3.read_parquet(path=cache_path)

    wr.s3.delete_objects(path=cache_path)
    distinct_ids = pd.read_sql(f'select distinct(itemid) from {gluedatabase}.chartevents;', connection)
    wr.s3.to_parquet(df=distinct_ids,
                     path=cache_path,
                     dataset=True,
                     database=gluedatabase,
                     table='distinct_chart')
    return distinct_ids

In [None]:
def get_path_tokens(item_ids, RELOAD=False):
    """Tokenise lab events per test type and append them to ``path_tokens``.

    For every item id not yet present in the ``path_tokens`` table, all lab
    results for that item are loaded, units are harmonised, and each result is
    labelled ``<itemid>_<decile>``.  Items whose results cannot be split into
    deciles are labelled with the bare ``<itemid>``.  Labels with 20 or fewer
    occurrences are dropped before writing.

    Args:
        item_ids: DataFrame with an ``itemid`` column listing the test types.
        RELOAD: when True, delete the stored tokens and rebuild every item;
            otherwise only items missing from the stored table are processed.

    Returns:
        None.  Results are written to S3 / the Glue catalogue.
    """
    tokens_path = f's3://{processed_data_bucket}/path_tokens'

    if RELOAD:
        wr.s3.delete_objects(path=tokens_path)
        keys = set()
    elif wr.catalog.does_table_exist(database=gluedatabase, table='path_tokens'):
        # Explicit existence check replaces catching an exception class that
        # was never imported in this file.
        try:
            done = wr.s3.read_parquet_table(database=gluedatabase, table='path_tokens')
            # Compared as strings: a partition column may be read back as
            # string/categorical rather than int -- TODO confirm.
            keys = {str(k) for k in done.itemid.unique()}
        except wr.exceptions.NoFilesFound:
            # Table is registered but holds no data yet.
            keys = set()
    else:
        keys = set()

    all_ids = list(item_ids.itemid)
    found_keys = [i for i in all_ids if str(i) in keys]
    to_find_keys = [i for i in all_ids if str(i) not in keys]
    print(f'loading {len(found_keys)} keys from file and {len(to_find_keys)} keys from db')

    # source unit -> (replacement unit, multiplier applied to the value)
    unit_mapper = {'mg/24hr':('mg/24hours', 1),
                   'ug/dl':('ng/ml', 10),
                   '#/ul':('#/cu mm', 1),
                   'ug/l':('ng/ml', 1),
                   'mosm/kg':('mosm/l', 1),
                   'mg/l':('mg/dl', 0.1),
                   'miu/ml':('miu/l', 1000),
                   'iu/ml':('i.u.', 1),
                   'sec':('seconds', 1),
                   'eu/dl':('mg/dl', 0.037427),
                   'uiu/ml':('uu/ml', 1)}

    for i, item in enumerate(to_find_keys):
        # for all test types that have numeric data, we select all test results for that given test type
        query=f'select subject_id, hadm_id, itemid, charttime, valuenum, valueuom from {gluedatabase}.labevents where itemid={item};'
        item_range = pd.read_sql(query, connection)
        item_range = item_range.rename(columns={'valuenum': 'val', 'valueuom': 'units', 'charttime': 'time'})
        # results not tied to a hospital admission are discarded
        item_range = item_range.dropna(subset=['hadm_id'])
        item_range = item_range.astype({'subject_id': 'int64',
                                        'hadm_id': 'int64',
                                        'itemid': 'int64',
                                        'val': 'double',
                                        'units': 'string'})
        item_range.units = item_range.units.str.lower()
        try:
            unit_count = item_range.units.value_counts()
            if len(unit_count) > 1:
                # several units reported for one test: convert the known ones
                for unit in unit_count.index:
                    if unit in unit_mapper:
                        replacement_unit, multiplier = unit_mapper[unit]
                        unit_rows = item_range.units == unit
                        item_range.loc[unit_rows, 'val'] *= multiplier
                        item_range.loc[unit_rows, 'units'] = replacement_unit
            item_range['decile'] = pd.qcut(item_range.val, q=10, labels=list(range(10)))
            item_range['label'] = item_range.itemid.astype(str) + '_' + item_range.decile.astype(str)
        except Exception:
            # will fail if non-numeric so no deciles for this selection - just include that
            # this non-numeric test was ordered (presumably still has some signal value even
            # without the result value, as it can be a proxy for clinician's intuition/which
            # variables were viewed as in-scope)
            item_range['label'] = item_range.itemid.astype(str)

        # for every token type, we retain the label only if there are more than 20 instances
        # which means for numeric tests, we do not include them in the trajectory if there are
        # fewer than 200 valid numeric results in the database
        label_counts = item_range.label.value_counts()
        retain_labels = list(label_counts[label_counts > 20].index)
        d = item_range[item_range.label.isin(retain_labels)]

        if len(d) > 0:
            wr.s3.to_parquet(df=d[['subject_id', 'hadm_id', 'time', 'label', 'itemid']],
                             database=gluedatabase,
                             table='path_tokens',
                             dataset=True,
                             partition_cols=['itemid'],
                             path=tokens_path)
        else:
            print(f'empty dataframe for itemid={item}')
        if i % 20 == 0:
            print(f'{i} of {len(to_find_keys)}')

In [None]:
def check_vs_to_process():
    """List the chart-event item ids that still need vital-signs pre-processing.

    The quantity of vital signs data is very large, so work is tracked per
    item: the number of rows already stored in ``vs_windows`` is compared with
    the number of rows expected from the first 6 hours of each ICU admission.
    An item is returned when fewer rows are stored than expected.

    Returns:
        list of item ids still to be processed.
    """
    # rows already written, per item
    done_counts = pd.read_sql(
        f'select distinct(itemid), count(*) from {processed_db}.vs_windows group by itemid;',
        connection)

    base_query = (f'select distinct(itemid), count(*) FROM {gluedatabase}.icustays i INNER JOIN {gluedatabase}.chartevents c '
                  f'ON i.icustay_id = c.icustay_id WHERE ')

    # NOTE(review): both filters use strict comparisons between charttime and
    # storetime, so rows where the two are equal (or one is null) match
    # neither -- confirm this is intended.
    chart_is_later = '((c.charttime > c.storetime) AND (c.charttime > i.intime) AND ((to_unixtime(c.charttime) - to_unixtime(i.intime))/3600 < 6))'
    store_is_later = '((c.storetime > c.charttime) AND (c.storetime > i.intime) AND ((to_unixtime(c.storetime) - to_unixtime(i.intime))/3600 < 6))'

    # rows expected, per item, for each of the two time filters
    expected_by_chart = pd.read_sql(f'{base_query} {chart_is_later} group by itemid', connection)
    expected_by_store = pd.read_sql(f'{base_query} {store_is_later} group by itemid', connection)

    done_counts.itemid = done_counts.itemid.astype(int)

    expected = pd.merge(expected_by_chart, expected_by_store, how='outer', on='itemid')
    expected.columns = ['itemid', 'query1', 'query2']
    expected = expected.fillna(0)

    status = pd.merge(expected, done_counts, how='left', on='itemid')
    status.columns = ['itemid', 'query1', 'query2', 'done']
    status['todo'] = status.query1 + status.query2
    # items never processed have no 'done' count
    status = status.fillna(0)

    outstanding = status[status.done < status.todo]
    print(f'number of v.s. types still to process: {len(outstanding)}')
    return list(outstanding.itemid)


def make_vital_signs_windows(todo_items):
    """Write the first-6-hours chart events for each item to ``vs_windows``.

    The chart events are filtered so that only data captured within the first
    6 hours of each ICU admission is kept, and the output is partitioned by
    ``itemid`` to speed up the later tokenisation steps.

    Args:
        todo_items: iterable of chart-event item ids to process.

    Returns:
        None.  Rows are written in chunks to S3 / the Glue catalogue.
    """
    base_query = (f'SELECT i.icustay_id, i.intime, i.outtime, c.itemid, c.charttime, c.storetime, c.value, '
                  f'c.valuenum, c.valueuom FROM {gluedatabase}.icustays i INNER JOIN {gluedatabase}.chartevents c '
                  f'ON i.icustay_id = c.icustay_id WHERE ')

    # Chart time and store time are inconsistent, so the later of the two is
    # taken as the time the value became available in the record; this keeps
    # only data that was unquestionably available for real-time prediction.
    # Keyed by the column that supplies the output 'time'.
    time_filters = {
        'charttime': '((c.charttime > c.storetime) AND (c.charttime > i.intime) AND ((to_unixtime(c.charttime) - to_unixtime(i.intime))/3600 < 6))',
        'storetime': '((c.storetime > c.charttime) AND (c.storetime > i.intime) AND ((to_unixtime(c.storetime) - to_unixtime(i.intime))/3600 < 6))',
    }
    output_columns = ['icustay_id', 'intime', 'outtime', 'itemid',
                      'time', 'value', 'valuenum', 'valueuom']

    processed = 0  # running count across both passes, used only for progress output

    for time_col, time_filter in time_filters.items():
        print(time_col)
        for item in todo_items:
            processed += 1
            chunks = pd.read_sql(f'{base_query} {time_filter} and itemid = {item}', connection, chunksize=5000)
            for chunk in chunks:
                chunk = chunk.rename(columns={time_col: 'time'})
                chunk = chunk.astype({'valuenum': float, 'value': str, 'valueuom': str})
                wr.s3.to_parquet(df=chunk[output_columns],
                                 database=gluedatabase,
                                 table='vs_windows',
                                 dataset=True,
                                 partition_cols=['itemid'],
                                 path=f's3://{processed_data_bucket}/vs_windows')
            if processed % 200 == 0:
                print(processed)
                

def get_vs_tokens(RELOAD=False):
    """Tokenise vital-signs observations per item and append them to ``vs_tokens``.

    Every item with more than 100 rows in ``vs_windows`` is loaded and each
    observation is labelled with its quantile bin (up to 50 bins):

    * one real unit:       ``<itemid>_<bin>``
    * several real units:  ``<itemid>_<unit index>_<bin>`` (binned per unit)
    * binning not possible: ``<itemid>_<category code of the raw value>``

    Labels with 20 or fewer occurrences are dropped before writing.

    Args:
        RELOAD: when True, delete the stored tokens and rebuild every item;
            otherwise only items missing from ``vs_tokens`` are processed.

    Returns:
        None.  Results are written to S3 / the Glue catalogue.
    """
    query = f'select distinct(itemid), count(*) from {processed_db}.vs_windows group by itemid;'
    processed_keys = pd.read_sql(query, connection)
    processed_keys.columns = ['itemid', 'c']

    distinct_vs = list(processed_keys[processed_keys.c > 100].itemid)

    # The RELOAD argument is honoured (it used to be overwritten with True,
    # which made the incremental branch unreachable).
    if RELOAD:
        wr.s3.delete_objects(path=f's3://{processed_data_bucket}/vs_tokens')
        done_keys = set()
    else:
        query = f'select distinct(itemid) from {processed_db}.vs_tokens'
        done_keys = set(pd.read_sql(query, connection).itemid)

    found_keys = [i for i in distinct_vs if i in done_keys]
    to_find_keys = [int(i) for i in distinct_vs if i not in done_keys]
    print(f'loading {len(found_keys)} keys from file and {len(to_find_keys)} keys from db')

    query = f'select subject_id, hadm_id, icustay_id from {gluedatabase}.icustays'
    subj_stays = pd.read_sql(query, connection)

    # unit harmonisation for vital signs is slightly different than for test results,
    # because there are a few tests with non-convertible units (i.e. cm-->kg) therefore instead of
    # using a mapper, we split those tests into two sub-types, as they still contain useful
    # numerical results.

    for i, item in enumerate(to_find_keys):
        query = (f'select icustay_id, intime, outtime, itemid, time, value, valuenum, valueuom from '
                 f"{processed_db}.vs_windows where itemid = '{item}'")
        item_range = pd.read_sql(query, connection)
        item_range.valueuom = item_range.valueuom.str.lower()
        try:
            unit_count = item_range.valueuom.value_counts()
            # if units are either x or '.' or 'none', assume all should be x
            unit_list = [x for x in unit_count.index if x != 'none' and x != '.']
            if len(unit_list) > 1:
                print(f'splitting results for test {item}')
                it_list = []
                for j, unit in enumerate(unit_list):
                    it = item_range[item_range.valueuom == unit].copy()
                    # no fixed labels: duplicates='drop' may leave fewer than 50 bins
                    it['decile'] = pd.qcut(it.valuenum, q=50, duplicates='drop')
                    it['decile_str'] = it.decile.astype('category').cat.codes.astype(str)
                    it['label'] = it.itemid.astype(str) + '_' + str(j) + '_' + it.decile_str
                    it_list.append(it)
                # rows whose unit is 'none' or '.' are not part of any sub-type
                labelled = pd.concat(it_list)
            else:
                labelled = item_range.copy()
                labelled['decile'] = pd.qcut(labelled.valuenum, q=50, duplicates='drop')
                labelled['decile_str'] = labelled.decile.astype('category').cat.codes.astype(str)
                labelled['label'] = labelled.itemid.astype(str) + '_' + labelled.decile_str
        except Exception:
            # Binning failed (e.g. no usable numeric values): fall back to the
            # category code of the raw value, prefixed with the item id so
            # that tokens of different items cannot collide.
            labelled = item_range.copy()
            labelled['label'] = (labelled.itemid.astype(str) + '_'
                                 + labelled.value.astype('category').cat.codes.astype(str))

        # attach subject / hospital admission ids to every observation
        it = pd.merge(labelled[['icustay_id', 'time', 'label', 'itemid']], subj_stays, how='left', on='icustay_id')

        label_counts = it.label.value_counts()
        retain_labels = list(label_counts[label_counts > 20].index)
        d = it[it.label.isin(retain_labels)]

        if len(d) > 0:
            wr.s3.to_parquet(df=d[['subject_id', 'hadm_id', 'time', 'label', 'itemid']],
                             database=gluedatabase,
                             table='vs_tokens',
                             dataset=True,
                             partition_cols=['itemid'],
                             path=f's3://{processed_data_bucket}/vs_tokens')
        if i % 500 == 0:
            print(i)

In [None]:
def get_proc_tokens(RELOAD=False):
    """Return procedure tokens, rebuilding them from ``procedureevents_mv`` on RELOAD.

    Procedures in a category that has durations (unit day/hour/min) are
    converted to minutes and labelled ``<ordercategoryname>_<quartile>``.
    Procedures in categories with other units are labelled with their item id.
    All tokens are timestamped with ``storetime``.  Labels with 20 or fewer
    occurrences are dropped.

    Args:
        RELOAD: when True, delete and rebuild the stored tokens; otherwise
            read them from the ``proc_tokens`` table.

    Returns:
        DataFrame with columns ``subject_id``, ``hadm_id``, ``time``, ``label``.
        Both paths return the filtered tokens, i.e. what is stored.

    Raises:
        ValueError: if a timed category contains a unit other than
            day/hour/min.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='proc_tokens')

    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/proc_tokens')
    query = f'select subject_id, hadm_id, starttime, endtime, itemid, value, valueuom, storetime, ordercategoryname from {gluedatabase}.procedureevents_mv'
    proc_df = pd.read_sql(query, connection)

    # minutes per duration unit
    unit_mapper = {'day':1440,
                   'hour':60,
                   'min':1}
    has_duration = proc_df.valueuom.isin(list(unit_mapper))
    timed_categories = proc_df[has_duration].ordercategoryname.unique()

    timed_tokens = []
    for category in timed_categories:
        timed_category = proc_df[proc_df.ordercategoryname == category].copy()
        # convert all durations to minutes
        for unit in timed_category.valueuom.unique():
            if unit not in unit_mapper:
                raise ValueError(unit)
            unit_rows = timed_category.valueuom == unit
            timed_category.loc[unit_rows, 'value'] *= unit_mapper[unit]
            timed_category.loc[unit_rows, 'valueuom'] = 'min'
        timed_category['quartile'] = pd.qcut(timed_category.value, 4, labels=list(range(4)))
        timed_category['label'] = timed_category.ordercategoryname + '_' + timed_category.quartile.astype(str)
        timed_tokens.append(timed_category[['subject_id', 'hadm_id', 'storetime', 'label']])

    all_timed_tokens = pd.concat(timed_tokens)
    all_timed_tokens = all_timed_tokens.rename(columns={'storetime': 'time'})

    # for untimed categories, we simply use the itemid at the time that it was stored in the database
    untimed_categories = proc_df[~has_duration].ordercategoryname.unique()

    untimed_tokens = proc_df[proc_df.ordercategoryname.isin(untimed_categories)][['subject_id', 'hadm_id', 'storetime', 'itemid']].copy()
    untimed_tokens.itemid = untimed_tokens.itemid.astype(str)
    untimed_tokens = untimed_tokens.rename(columns={'storetime': 'time', 'itemid': 'label'})

    proc_tokens = pd.concat([untimed_tokens, all_timed_tokens])

    label_counts = proc_tokens.label.value_counts()
    retain_labels = list(label_counts[label_counts > 20].index)
    d = proc_tokens[proc_tokens.label.isin(retain_labels)][['subject_id', 'hadm_id', 'time', 'label']]

    wr.s3.to_parquet(df=d,
                     database=gluedatabase,
                     table='proc_tokens',
                     dataset=True,
                     path=f's3://{processed_data_bucket}/proc_tokens')

    # Return what was stored, so a rebuild and a cached read agree.
    return d

In [None]:
import warnings

# Rebuild the vital-signs tokens from scratch (RELOAD=True), with all Python
# warnings suppressed for the duration of the call.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    get_vs_tokens(RELOAD=True)

In [None]:
def diagnosis_splitter(diagnosis_tokens, to_split):
    """Explode delimited diagnosis labels into one row per piece.

    Args:
        diagnosis_tokens: DataFrame with columns ``subject_id``, ``hadm_id``,
            ``time`` and ``label``.
        to_split: delimiter on which each ``label`` is split.

    Returns:
        DataFrame with columns ``subject_id``, ``hadm_id``, ``time``,
        ``label``: all first pieces, then all second pieces, and so on.
        The original index values are kept, so they repeat.
    """
    pieces = diagnosis_tokens['label'].str.split(to_split, expand=True)
    widened = pd.concat([diagnosis_tokens, pieces], axis=1)
    id_columns = ['subject_id', 'hadm_id', 'time']

    # one frame per piece position, keeping only rows that have that piece
    frames = [
        widened.loc[widened[position].notna(), id_columns + [position]]
               .rename(columns={position: 'label'})
        for position in range(len(pieces.columns))
    ]
    return pd.concat(frames)

In [None]:
def getage(admtime, dob):
    """Return the age token ``age_<n>`` for an admission.

    The age comes from ``over18``; when that yields ``None`` no token is
    produced and ``None`` is returned.
    """
    years = over18(admtime, dob)
    return None if years is None else f'age_{years}'

def over18(admtime, dob):
    """Return the age in whole years at admission, or None if under 18.

    Args:
        admtime: admission time (any object with ``year``/``month``/``day``).
        dob: date of birth (same attributes).

    Returns:
        Age in completed years, or ``None`` when that age is below 18.
    """
    age = admtime.year - dob.year
    # The birthday has not yet been reached this year when the admission's
    # (month, day) precedes the birth (month, day); comparing the month alone
    # over-counts admissions earlier in the birth month.
    if (admtime.month, admtime.day) < (dob.month, dob.day):
        age -= 1
    if age < 18:
        return None
    return age

def get_demography_tokens(RELOAD=False):
    """Return demographic tokens, rebuilding them from the database on RELOAD.

    For each hospital admission of a patient aged 18 or over, the tokens are:
    ``admission`` (at admit time), ``discharge`` (at discharge time), the
    gender and ``age_<n>`` (both at admit time), and the pieces of the
    free-text diagnosis (at discharge time plus 3 days).  Labels with 20 or
    fewer occurrences are dropped.

    Args:
        RELOAD: when True, delete and rebuild the stored tokens; otherwise
            read them from the ``demographics_tokens`` table.

    Returns:
        DataFrame with columns ``subject_id``, ``hadm_id``, ``time``, ``label``.
        Both paths return the filtered tokens, i.e. what is stored.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='demographics_tokens')

    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/demographics_tokens')
    query = f'SELECT p.subject_id, a.hadm_id, a.admittime, a.dischtime, a.diagnosis, p.gender, p.dob '\
            f'FROM {gluedatabase}.admissions a inner join {gluedatabase}.patients p on a.subject_id = p.subject_id;'
    demography_df = pd.read_sql(query, connection)
    demography_df['admission_age'] = demography_df.apply(lambda x: getage(x.admittime, x.dob), axis=1)
    # dropping patients < 18 years of age (getage yields None for them)
    demography_df = demography_df.dropna(axis=0, subset=['admission_age'])

    admission_tokens = demography_df[['subject_id', 'hadm_id', 'admittime']].copy()
    admission_tokens['label'] = 'admission'
    admission_tokens = admission_tokens.rename(columns={'admittime':'time'})
    discharge_tokens = demography_df[['subject_id', 'hadm_id', 'dischtime']].copy()
    discharge_tokens['label'] = 'discharge'
    discharge_tokens = discharge_tokens.rename(columns={'dischtime':'time'})
    gender_tokens = demography_df[['subject_id', 'hadm_id', 'gender', 'admittime']].copy()
    gender_tokens = gender_tokens.rename(columns={'admittime':'time', 'gender':'label'})
    age_tokens = demography_df[['subject_id', 'hadm_id', 'admission_age', 'admittime']].copy()
    age_tokens = age_tokens.rename(columns={'admittime':'time', 'admission_age':'label'})

    diagnosis_tokens = demography_df[['subject_id', 'hadm_id', 'diagnosis', 'dischtime']].copy()
    diagnosis_tokens = diagnosis_tokens.rename(columns={'diagnosis': 'label', 'dischtime': 'time'})
    # Split on '\' and then split every resulting piece on ';'.  Splitting the
    # original text twice and concatenating both results would emit each
    # undelimited diagnosis twice and keep partially split strings as tokens.
    split_on_backslash = diagnosis_splitter(diagnosis_tokens, '\\')
    # the splitter repeats index values; make them unique before re-splitting
    diagnosis_tokens = diagnosis_splitter(split_on_backslash.reset_index(drop=True), ';')

    # add 3 days to every diagnosis that is available, on the assumption that this reflects
    # a realistic delay in diagnosis availability.
    diagnosis_tokens['time'] = diagnosis_tokens['time'] + np.timedelta64(3, 'D')

    demographics_tokens = pd.concat([admission_tokens, discharge_tokens, gender_tokens, age_tokens, diagnosis_tokens], sort=False)
    demographics_tokens = demographics_tokens.dropna(axis=0)

    label_counts = demographics_tokens.label.value_counts()
    retain_labels = list(label_counts[label_counts > 20].index)
    d = demographics_tokens[demographics_tokens.label.isin(retain_labels)][['subject_id', 'hadm_id', 'time', 'label']]

    wr.s3.to_parquet(df=d,
                     database=gluedatabase,
                     table='demographics_tokens',
                     dataset=True,
                     path=f's3://{processed_data_bucket}/demographics_tokens')

    # Return what was stored, so a rebuild and a cached read agree.
    return d

In [None]:
def get_vent_tokens(RELOAD=False):
    """Return ventilation tokens, rebuilding them from ``ventdurations`` on RELOAD.

    Each ventilation episode yields ``start_vent`` at its start time, plus
    ``end_vent`` and ``ventdur_<decile of duration_hours>`` at its end time.
    Labels with 20 or fewer occurrences are dropped.

    Args:
        RELOAD: when True, delete and rebuild the stored tokens; otherwise
            read them from the ``vent_tokens`` table.

    Returns:
        DataFrame with columns ``subject_id``, ``hadm_id``, ``time``, ``label``.
        Both paths return the filtered tokens, i.e. what is stored.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='vent_tokens')

    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/vent_tokens')

    query = f'select i.subject_id, i.hadm_id, i.icustay_id, '\
            f'v.ventnum, v.starttime, v.endtime, v.duration_hours '\
            f'from {gluedatabase}.icustays as i inner join '\
            f'{gluedatabase}.ventdurations as v on i.icustay_id = v.icustay_id'
    vent_df = pd.read_sql(query, connection)

    start_vent = vent_df[['subject_id', 'hadm_id', 'starttime']].copy()
    start_vent['label'] = 'start_vent'
    start_vent = start_vent.rename(columns={'starttime': 'time'})

    end_vent = vent_df[['subject_id', 'hadm_id', 'endtime']].copy()
    end_vent['label'] = 'end_vent'
    end_vent = end_vent.rename(columns={'endtime': 'time'})

    # the duration is only known once ventilation has ended, hence endtime
    vent_df['dur_quant'] = pd.qcut(vent_df.duration_hours, 10, labels=list(range(10)))
    vent_df['label'] = 'ventdur_' + vent_df.dur_quant.astype(str)
    vent_df = vent_df[['subject_id', 'hadm_id', 'endtime', 'label']]
    vent_df = vent_df.rename(columns={'endtime': 'time'})

    vent_df = pd.concat([vent_df, start_vent, end_vent])

    label_counts = vent_df.label.value_counts()
    retain_labels = list(label_counts[label_counts > 20].index)
    d = vent_df[vent_df.label.isin(retain_labels)]

    wr.s3.to_parquet(df=d,
                     database=gluedatabase,
                     table='vent_tokens',
                     dataset=True,
                     path=f's3://{processed_data_bucket}/vent_tokens')

    # Return what was stored, so a rebuild and a cached read agree.
    return d

In [None]:
def get_med_tokens(RELOAD=False):
    """Return medication tokens, rebuilding them from ``prescriptions`` on RELOAD.

    Each prescription is labelled ``<formulary_drug_cd>-<form_val_disp>`` and
    timestamped with its start date.  Labels with 20 or fewer occurrences are
    dropped.

    Args:
        RELOAD: when True, delete and rebuild the stored tokens; otherwise
            read them from the ``med_tokens`` table.

    Returns:
        DataFrame with columns ``subject_id``, ``hadm_id``, ``time``, ``label``.
        Both paths return the filtered tokens, i.e. what is stored.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='med_tokens')

    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/med_tokens')
    query = f'select subject_id, hadm_id, startdate, formulary_drug_cd, form_val_disp from {gluedatabase}.prescriptions'
    med_tokens = pd.read_sql(query, connection)
    med_tokens['label'] = med_tokens.formulary_drug_cd + '-' + med_tokens.form_val_disp
    med_tokens = med_tokens.rename(columns={'startdate':'time'})
    med_tokens = med_tokens[['subject_id', 'hadm_id', 'time', 'label']]

    label_counts = med_tokens.label.value_counts()
    retain_labels = list(label_counts[label_counts > 20].index)
    d = med_tokens[med_tokens.label.isin(retain_labels)]

    wr.s3.to_parquet(df=d[['subject_id', 'hadm_id', 'time', 'label']],
                     database=gluedatabase,
                     table='med_tokens',
                     dataset=True,
                     path=f's3://{processed_data_bucket}/med_tokens')

    # Return what was stored, so a rebuild and a cached read agree.
    return d

In [None]:
def is_new_ICU_adm(row):
    """Tell whether a transfer row marks the start of a new ICU admission.

    A row qualifies when its ``icustay_id`` is a float, its current care unit
    name contains one of 'ICU', 'CCU' or 'CSRU' (checked in that order; the
    first match decides), and the previous care unit is absent or does not
    contain that same marker.

    Args:
        row: object exposing ``icustay_id``, ``curr_careunit`` and
            ``prev_careunit``.

    Returns:
        bool.
    """
    # The removed ``np.float`` alias was the builtin float, so this is the
    # same check; np.float64 values are float instances too.
    if not isinstance(row.icustay_id, float):
        return False

    current = row.curr_careunit
    if not isinstance(current, str):
        # None or NaN: no current care unit recorded
        return False

    previous = row.prev_careunit
    # NOTE(review): the marker must match between previous and current unit,
    # so a move between unit families (e.g. a 'CCU' unit to an 'ICU' unit)
    # counts as a new admission -- confirm this is intended.
    for marker in ('ICU', 'CCU', 'CSRU'):
        if marker in current:
            return (not isinstance(previous, str)) or (marker not in previous)
    return False

In [None]:

def get_endpoints_df(RELOAD=False):
    """Return one row per ICU stay with the fields needed to derive endpoints.

    For simplicity, the prediction model considers only the first ICU
    admission of 6 or more hours duration for each hospital admission.

    4 endpoints are considered:

    1. at 6 hours after ICU admission, can we predict death within this ICU admission?
    2. at 6 hours after ICU admission, can we predict death within this hospital admission?
    3. at 6 hours after ICU admission, can we predict readmission to ICU within this hospital admission?
    4. at 6 hours after ICU admission, can we predict icu duration > 48hr?

    Args:
        RELOAD: when True, delete and rebuild the stored table from the
            database; otherwise read it from the ``endpoint_df`` table.

    Returns:
        DataFrame sorted by ``hadm_id`` and ``intime`` (on rebuild), with the
        queried columns plus ``duration`` and ``admission_age``.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='endpoint_df')

    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/endpoint_df')
    sql = (f'SELECT a.subject_id, a.hadm_id, a.hospital_expire_flag, a.admittime, a.dischtime,'
           f' p.dob, a.deathtime, i.intime, i.outtime FROM {gluedatabase}.admissions '
           f' a INNER JOIN {gluedatabase}.icustays i'
           f' ON a.hadm_id = i.hadm_id INNER JOIN {gluedatabase}.patients p on a.subject_id = p.subject_id')
    stays = pd.read_sql(sql, connection)

    # length of the ICU stay; stays without one are discarded
    stays['duration'] = stays.outtime - stays.intime
    stays = stays.dropna(axis=0, subset=['duration'])

    # NOTE(review): over18 yields None below 18, so under-age rows end up with
    # a missing admission_age rather than a number -- confirm downstream
    # filters account for that.
    stays['admission_age'] = stays.apply(lambda row: over18(row.admittime, row.dob), axis=1)

    stays.duration = stays.duration.astype(int)
    stays.hospital_expire_flag = stays.hospital_expire_flag.astype(int)
    stays = stays.sort_values(['hadm_id', 'intime'])

    wr.s3.to_parquet(df=stays,
                     path=f's3://{processed_data_bucket}/endpoint_df',
                     dataset=True,
                     database=gluedatabase,
                     table='endpoint_df')
    return stays


def get_icu_adm(RELOAD=False):
    """Return ICU admissions reconstructed from the ``transfers`` table.

    We must load ICU admissions from the transfers table so that we can
    account for ICU admissions on the same day.  Note that in the original
    MIMIC data, any ICU readmissions on the same day are given the same
    ICUSTAY_ID.  All transfer rows where a patient is transferred but remains
    within the ICU are combined into a single ICU admission.  All transfer
    rows where a patient is transferred from ICU onto a general ward and then
    back to ICU are counted as a new ICU admission (ICU readmission).

    Args:
        RELOAD: when True, delete and rebuild the stored table; otherwise
            read it from the ``icu_adm`` table.

    Returns:
        DataFrame with columns ``subject_id``, ``hadm_id``, ``icu_stay``,
        ``new_icu_stay`` (running id), ``intime`` and ``outtime``.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='icu_adm')

    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/icu_adm')

    query = f'select t.subject_id, t.hadm_id, t.icustay_id, t.eventtype, t.prev_careunit, t.curr_careunit,'\
            f' t.intime, t.outtime, t.los from {gluedatabase}.transfers t'
    transfers_df = pd.read_sql(query, connection)

    # locate all transfers that mark the start of a new ICU admission
    transfers_df['new_admin'] = transfers_df.apply(is_new_ICU_adm, axis=1)

    # for all transfers that do not mark the start of a new ICU admin, but are within-ICU transfers, we
    # must amalgamate the overall intime/outime to get the true full ICU admission
    transfers_df = transfers_df[~transfers_df.icustay_id.isna()]
    transfers_df = transfers_df.sort_values(['subject_id', 'hadm_id', 'intime'])

    icu_list = []
    new_id = 0

    # groupby(sort=False) walks the groups in order of first appearance, i.e.
    # the same order as iterating .unique() values of the sorted frame, but
    # without re-filtering the whole frame for every subject and admission.
    # Assumes a hadm_id belongs to a single subject -- TODO confirm.
    for idx, (subject, subject_rows) in enumerate(transfers_df.groupby('subject_id', sort=False)):
        for hadm_id, transfers in subject_rows.groupby('hadm_id', sort=False):
            for icustay, stay_rows in transfers.groupby('icustay_id', sort=False):
                icurecords = stay_rows[stay_rows.los.notna()]
                admission_starts = icurecords[icurecords.new_admin]
                if len(admission_starts) == 0:
                    continue

                # consecutive start times delimit the admissions; the last one
                # runs up to the latest outtime of the stay
                intimes = list(admission_starts.intime) + [max(icurecords.outtime)]
                for i, o in zip(intimes[:-1], intimes[1:]):
                    full_icu_stay = icurecords[(icurecords.intime >= i) & (icurecords.intime < o)]
                    new_id += 1
                    if full_icu_stay.empty:
                        # no transfer falls inside this window: report and skip
                        print(hadm_id, icustay)
                        continue
                    icu_list.append({'subject_id': subject,
                                     'hadm_id': hadm_id,
                                     'icu_stay': icustay,
                                     'new_icu_stay': new_id,
                                     'intime': min(full_icu_stay.intime),
                                     'outtime': max(full_icu_stay.outtime)
                                    })
        if idx % 5000 == 0:
            print(idx)
    icu_admission_endpoints = pd.DataFrame(icu_list)

    wr.s3.to_parquet(df=icu_admission_endpoints,
                     database=gluedatabase,
                     table='icu_adm',
                     dataset=True,
                     path=f's3://{processed_data_bucket}/icu_adm')
    return icu_admission_endpoints

In [None]:
def get_icu_tokens(full_endpoint_df, RELOAD=False):
    """Return ICU admission tokens, rebuilding them from the endpoints on RELOAD.

    Three tokens are produced per row of ``full_endpoint_df``: ``icu_admit``
    and ``icuadm_<nth>`` at the ICU ``intime``, and ``icudur_<decile>`` at the
    ICU ``outtime``.

    Args:
        full_endpoint_df: DataFrame with columns ``subject_id``, ``hadm_id``,
            ``intime``, ``outtime``, ``nth`` and ``duration``.
        RELOAD: when True, delete and rebuild the stored tokens; otherwise
            read them from the ``icu_tokens`` table.

    Returns:
        DataFrame of tokens; on rebuild it is sorted by ``subject_id``,
        ``hadm_id`` and ``time``.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='icu_tokens')

    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/icu_tokens')
    id_columns = ['subject_id', 'hadm_id']

    admit = full_endpoint_df[id_columns + ['intime']].rename(columns={'intime': 'time'})
    admit['label'] = 'icu_admit'

    # tokens stating this admission's ordinality for the subject in question are known
    # at the time of admission to ICU
    ordinal = full_endpoint_df[id_columns + ['intime', 'nth']].copy()
    ordinal['nth'] = 'icuadm_' + ordinal['nth'].astype('str')
    ordinal = ordinal.rename(columns={'intime': 'time', 'nth': 'label'})

    # duration tokens are only available at discharge
    length = full_endpoint_df[id_columns + ['outtime', 'duration']].copy()
    duration_decile = pd.qcut(length.duration, q=10, labels=list(range(10)))
    length['duration'] = 'icudur_' + duration_decile.astype('str')
    length = length.rename(columns={'outtime': 'time', 'duration': 'label'})

    combined = pd.concat([admit, ordinal, length])
    icu_tokens = combined.sort_values(['subject_id', 'hadm_id', 'time'])

    wr.s3.to_parquet(df=icu_tokens,
                     path=f's3://{processed_data_bucket}/icu_tokens',
                     dataset=True,
                     database=gluedatabase,
                     table='icu_tokens')
    return icu_tokens

In [None]:
def get_full_endpoints_df(icu_admission_endpoints, endpoint_df, RELOAD=False):
    """Join ICU stays to hospital-admission endpoints and apply exclusions.

    With RELOAD=True the stored `full_endpoints` data is deleted, rebuilt from
    the two input dataframes and written back to the `full_endpoints` glue
    table. With RELOAD=False the stored table is read and both dataframe
    arguments are ignored.

    A hospital admission is removed entirely when its first ICU stay
    (`nth == 1`) meets any of:

    * the stay is shorter than 6 hours,
    * `admission_age` is below 18,
    * `hospital_expire_flag == 1` and `deathtime` falls within 6 hours of the
      ICU `intime`.

    Args:
        icu_admission_endpoints: one row per ICU stay; must provide hadm_id,
            subject_id, new_icu_stay, intime and outtime. Not modified.
        endpoint_df: hospital-admission rows; must provide subject_id,
            hadm_id, intime, hospital_expire_flag, admittime, dischtime,
            deathtime and admission_age.
        RELOAD: regenerate and store the table when True, otherwise load it.

    Returns:
        DataFrame of included ICU stays sorted by subject_id, hadm_id, intime.
        `duration` is the integer form of outtime - intime (nanoseconds for
        timedelta64[ns] input) and `nth` is the 1-based order of the stay
        within its hospital admission.
    """
    if not RELOAD:
        return wr.s3.read_parquet_table(database=gluedatabase, table='full_endpoints')

    # Clear the dataset this function writes to. Deleting any other prefix
    # would leave old rows in place, and because the write below adds to the
    # dataset, every reload would then duplicate the table contents.
    wr.s3.delete_objects(path=f's3://{processed_data_bucket}/full_endpoints')

    print(f'data input contains {len(icu_admission_endpoints)} ICU admissions, across {len(icu_admission_endpoints.hadm_id.unique())} hospital admissions and {len(icu_admission_endpoints.subject_id.unique())} patients')

    # Work on a copy so the caller's dataframe does not gain a `duration` column.
    icu_stays = icu_admission_endpoints[['hadm_id', 'new_icu_stay', 'intime', 'outtime']].copy()
    icu_stays['duration'] = icu_stays.outtime - icu_stays.intime

    endpoint_df = endpoint_df.sort_values(['subject_id', 'hadm_id', 'intime'])

    full_endpoint_df = pd.merge(icu_stays,
                                endpoint_df[['subject_id', 'hadm_id', 'hospital_expire_flag', 'admittime',
                                             'dischtime', 'deathtime', 'admission_age']],
                                how='left', left_on=['hadm_id'], right_on=['hadm_id'])

    # Ordinality of each ICU stay within its hospital admission. cumcount
    # follows row order, so sort chronologically first instead of relying on
    # the order in which the caller supplied the stays.
    full_endpoint_df = full_endpoint_df.sort_values(['hadm_id', 'intime'])
    full_endpoint_df['nth'] = full_endpoint_df.groupby('hadm_id').cumcount() + 1

    # Exclusion criteria, evaluated on the first ICU stay of each admission.
    six_hours = np.timedelta64(6, 'h')
    is_first_stay = full_endpoint_df.nth == 1
    too_short = full_endpoint_df.duration < six_hours
    under_age = full_endpoint_df.admission_age < 18
    early_death = ((full_endpoint_df.hospital_expire_flag == 1) &
                   (full_endpoint_df.deathtime < full_endpoint_df.intime + six_hours))
    excluded_hadm_id = full_endpoint_df[is_first_stay &
                                        (too_short | under_age | early_death)].hadm_id.unique()

    # The whole hospital admission is dropped, not just the offending stay.
    is_excluded = full_endpoint_df.hadm_id.isin(excluded_hadm_id)
    excluded = full_endpoint_df[is_excluded]
    included = full_endpoint_df[~is_excluded]

    print(f'Excluded rows: {len(excluded)}, included rows: {len(included)}')
    print(f'Excluded hospital admissions: {len(excluded.hadm_id.unique())}, hospital admissions: {len(included.hadm_id.unique())}')
    print(f'Excluded patients: {len(excluded.subject_id.unique())}, included patients: {len(included.subject_id.unique())}')

    full_endpoint_df = included.copy()
    full_endpoint_df = full_endpoint_df.sort_values(['subject_id', 'hadm_id', 'intime'])
    # Store the duration as a plain integer rather than a timedelta.
    full_endpoint_df['duration'] = full_endpoint_df.duration.astype(int)
    # Stays with no matching endpoint row get a NaN flag from the left join;
    # treat those as "did not die in hospital".
    full_endpoint_df['hospital_expire_flag'] = full_endpoint_df.hospital_expire_flag.fillna(0).astype(int)

    # Stays with no matching endpoint row also have no subject_id; drop them.
    full_endpoint_df = full_endpoint_df.dropna(axis=0, subset=['subject_id'])
    full_endpoint_df['subject_id'] = full_endpoint_df.subject_id.astype(int)

    wr.s3.to_parquet(df=full_endpoint_df,
                     database=gluedatabase,
                     table='full_endpoints',
                     dataset=True,
                     path=f's3://{processed_data_bucket}/full_endpoints')
    return full_endpoint_df

In [None]:
def join_tokens(full_endpoint_df, RELOAD=False):
    """Collect every clinical token type into the single `all_tokens` table.

    The stored `all_tokens` data is always deleted and rewritten. RELOAD is
    only forwarded to the procedure, demography, medication and ventilation
    token getters; pathology tokens are always read from the `path_tokens`
    glue table. `full_endpoint_df` is accepted but not used here.

    ICU tokens are deliberately left out so that an embedding learned from the
    trajectory does not leak endpoint information; they are summarised and
    fed to the network separately. Vital signs are also handled separately
    because of their volume and frequency.

    Returns:
        None. The result is the `all_tokens` table.
    """
    destination = f's3://{processed_data_bucket}/all_tokens'
    wr.s3.delete_objects(path=destination)

    procedures = get_proc_tokens(RELOAD)
    print(f'got {len(procedures)} procedure tokens')

    demography = get_demography_tokens(RELOAD)
    print(f'got {len(demography)} demography tokens')

    medications = get_med_tokens(RELOAD)
    print(f'got {len(medications)} medication tokens')

    ventilation = get_vent_tokens(RELOAD)
    print(f'got {len(ventilation)} ventilation tokens')

    pathology = wr.s3.read_parquet_table(database=gluedatabase, table='path_tokens')
    print(f'got {len(pathology)} pathology tokens')

    # Each token type is written in turn to the same dataset, reduced to the
    # columns they all share.
    shared_columns = ['subject_id', 'hadm_id', 'time', 'label']
    for frame in (procedures, demography, medications, ventilation, pathology):
        wr.s3.to_parquet(df=frame[shared_columns],
                         database=gluedatabase,
                         table='all_tokens',
                         dataset=True,
                         path=destination)

In [None]:
def _write_pred_endpt_batch(rows):
    """Write a batch of prediction-endpoint rows to `pred_endpt`; no-op when empty."""
    if not rows:
        return
    wr.s3.to_parquet(df=pd.DataFrame(rows), database=gluedatabase, table='pred_endpt',
                     dataset=True, path=f's3://{processed_data_bucket}/pred_endpt')


def get_endpt_per_prediction(full_endpoint_df, RELOAD=False):
    """Generate one prediction time and its endpoints per hospital admission.

    For every (subject_id, hadm_id) pair found in `all_tokens`, the first ICU
    stay lasting more than 6 hours is located in `full_endpoint_df`. The
    prediction time is 6 hours after that stay's `intime`. Admissions with no
    such stay are skipped. Results are written to the `pred_endpt` glue table
    in batches of 1000 subjects, plus a final batch for the remainder.

    Args:
        full_endpoint_df: ICU stays with hadm_id, intime, outtime, deathtime,
            hospital_expire_flag and an integer `duration` in nanoseconds
            (it is compared against nanosecond thresholds below).
        RELOAD: when True, delete the stored `pred_endpt` data first.
            When False, new rows are written on top of whatever is stored.

    Returns:
        None. The result is the `pred_endpt` table.

    Raises:
        ValueError: if a row returned by the query is not in the expected
            `{field0=<subject_id>, field1=<hadm_id>}` form.
    """
    if RELOAD:
        wr.s3.delete_objects(path=f's3://{processed_data_bucket}/pred_endpt')

    # The query returns each (subject_id, hadm_id) pair as a single row-typed
    # column rendered as "{field0=<subject_id>, field1=<hadm_id>}".
    query = f'select distinct(subject_id, hadm_id) from {processed_db}.all_tokens'
    raw_pairs = pd.read_sql(query, connection)._col0
    # Parse with an anchored pattern. str.strip('{}field0=') removes any of
    # those *characters* from both ends, so it would also eat trailing zeros
    # of the hadm_id (e.g. 100010 -> 10001).
    subjects = raw_pairs.str.extract(r'^\{field0=(?P<subject_id>.*?), field1=(?P<hadm_id>.*)\}$')
    malformed = raw_pairs.notna() & subjects.hadm_id.isna()
    if malformed.any():
        raise ValueError(f'could not parse {int(malformed.sum())} (subject_id, hadm_id) rows from all_tokens')

    subject_list = list(subjects.subject_id.unique())
    print(f'loaded {len(subject_list)} subjects')

    # Thresholds in nanoseconds, the unit of full_endpoint_df.duration.
    six_hours = np.timedelta64(6, 'h')/np.timedelta64(1, 'ns')
    seven_days = np.timedelta64(7, 'D')/np.timedelta64(1, 'ns')

    input_df_list = []
    for i, subject in enumerate(subject_list):
        admission_list = subjects[subjects.subject_id == subject].hadm_id.unique()
        for hadm_id in admission_list:
            # hadm_id is a string parsed from the query; the dataframe column
            # is numeric, so every comparison must use the integer form.
            stays = full_endpoint_df[full_endpoint_df.hadm_id == int(hadm_id)]
            ef = stays[stays.duration > six_hours].sort_values('intime')
            if len(ef) == 0:
                continue

            # Prediction time is 6 hours after the first ICU admission that
            # lasted more than 6 hours.
            predict_time = (ef.intime + np.timedelta64(6, 'h')).iloc[0]

            inhosp_death = ef.hospital_expire_flag.iloc[0] == 1
            # -1.0 (a float) marks "no event", so the column keeps the same
            # dtype in every batch written to the table.
            if inhosp_death:
                time_to_death = (ef.deathtime.iloc[0] - predict_time)/np.timedelta64(1, 'h')
            else:
                time_to_death = -1.0
            inicu_death = inhosp_death & (ef.deathtime.iloc[0] <= ef.outtime.iloc[0])

            icu_len = ef.duration.iloc[0]
            long_icu = icu_len >= seven_days

            later_admissions = stays[stays.intime > predict_time]
            icu_readm = len(later_admissions) > 0
            if icu_readm:
                time_to_readm = (later_admissions.sort_values('intime').iloc[0].intime - predict_time)/np.timedelta64(1, 'h')
            else:
                time_to_readm = -1.0

            input_df_list.append({'subj': subject, 'adm': hadm_id,
                                  'time': predict_time,           # 6 hours after ICU admission time
                                  'inhosp_death': inhosp_death,   # true if hospital expire flag is set for this hospital admission
                                  'inicu_death': inicu_death,     # true if death time is before or the same as time of ICU discharge
                                  'long_icu': long_icu,           # true if this ICU stay lasted at least 7 days (NOTE(review): confirm intended threshold)
                                  'icu_readm': icu_readm,         # true if another ICU admission starts after the prediction time within this hospital admission
                                  'time_to_death': time_to_death, # hours; not a prediction target, used to weight the augmentation strategy
                                  'time_to_readm': time_to_readm, # hours; as for time_to_death
                                  'duration': icu_len,            # nanoseconds; as for time_to_death
                                  'predict_time': predict_time    # time that these prediction endpoints are valid for this subject
                                 })

        # Flush every 1000 subjects to bound memory use.
        if (i + 1) % 1000 == 0:
            print(f'prediction endpoints generated for {i + 1} subjects')
            _write_pred_endpt_batch(input_df_list)
            input_df_list = []

    # Write the subjects after the last full batch of 1000; without this they
    # would be dropped.
    _write_pred_endpt_batch(input_df_list)

In [None]:
# Vital signs pre-processing.
# NOTE(review): the three helpers are defined outside this cell; from their
# names they find the outstanding vital-sign work, build the windows for it
# and list the distinct chart events -- confirm against their definitions.
vs_todo = check_vs_to_process()
make_vital_signs_windows(vs_todo)
chart_event_ids = get_distinct_chart_events(RELOAD=False)
# Run the glue crawler after this cell, so that the processed data is added
# to the glue catalogue before vital signs processing continues.

In [None]:
# Pathology token pre-processing: select the distinct test types from the lab
# events. RELOAD=False skips deleting the stored `distinct_path` data.
item_ids = get_distinct_lab_events(RELOAD=False)

In [None]:
# Pathology tokens for the lab items selected above; the return value is not
# kept. Presumably this populates the `path_tokens` table that join_tokens
# reads later -- confirm against get_path_tokens.
get_path_tokens(item_ids, RELOAD=False)

In [None]:
# Vital-sign tokens. The positional False is presumably the RELOAD flag --
# confirm against get_vs_tokens.
get_vs_tokens(False)

In [None]:
# Generate all endpoints and tokens for ICU admissions and discharges.
# Every step is called with RELOAD=True, so stored results are regenerated.
# Order matters: get_full_endpoints_df consumes the two dataframes produced
# before it, and get_icu_tokens consumes its result.
endpoint_df = get_endpoints_df(True)
icu_admission_endpoints = get_icu_adm(True)
full_endpoint_df = get_full_endpoints_df(icu_admission_endpoints, endpoint_df, True)
it = get_icu_tokens(full_endpoint_df, True)

In [None]:
# Process all clinical tokens other than vital signs and combine them into
# the `all_tokens` table. RELOAD=True is forwarded to the individual token
# getters; ICU tokens are not part of this table.
join_tokens(full_endpoint_df, True)

In [None]:
# Run the glue crawler before executing this cell: get_endpt_per_prediction
# queries the `all_tokens` table written by join_tokens.
# RELOAD=True deletes the stored `pred_endpt` data before regenerating it.
get_endpt_per_prediction(full_endpoint_df, True)