# Create Model Input Data

This file queries the tokenized trajectory tables (all_tokens, vs_tokens), endpoint table (pred_endpt) and ICU-specific input data table (icu_tokens) to create tensorflow record files that can be used to generate embeddings and final models.

In [1]:
import datetime, os, boto3, pickle, h5py, sys
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as tfk
import matplotlib.pyplot as plt

from contextlib import ExitStack
from itertools import product, repeat
from sklearn.model_selection import train_test_split, KFold
from IPython.display import display, HTML
from collections import Counter, defaultdict

In [3]:
from utils.connections import connection, cursor, gluedatabase, processed_db, upload_file, processed_data_bucket, download_file
from utils.datagen import PROCESSED_DATAPATH, MODEL_INPUT_DATAPATH
from utils.utils import read_data, dump_data

In [4]:
def get_vocabulary(table):
    query = f'select label, count(*) from {processed_db}.{table} group by label'
    vocab = pd.read_sql(query, connection)
    
    vocab_to_int = {v:i for i, v in enumerate(vocab[vocab._col1>=20].sort_values('_col1').label)}
    int_to_vocab = {i:v for i, v in enumerate(vocab[vocab._col1>=20].sort_values('_col1').label)}
    
    rare = max(vocab_to_int.values()) + 1

    def rare_str():
        return 'rare'

    def rare_int():
        return rare
    
    i2v = defaultdict(rare_str, int_to_vocab)
    v2i = defaultdict(rare_int, vocab_to_int)
    return v2i, i2v

In [5]:
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _int_list_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _timestep_feature(value):
    step_list = []
    for v in value:
        step_list.append(tf.train.Feature(int64_list=tf.train.Int64List(value=v)))
    step_feature = tf.train.FeatureList(feature=step_list)
    return tf.train.FeatureLists(feature_list={'data': step_feature})

In [6]:
np.random.seed(42)

def get_train_validation_splits(RELOAD=False):
    
    # we are splitting by subject, not admission, so as not to leak data across training and validation 
    # sets, although in this instance it's likely that would harm sensitivity rather than help, as for 
    # death endpoints, an earlier admission would have definitionally the opposite outcome for the 
    # overlapped data, rather than the same outcome.
    
    if RELOAD:
        query = f'select distinct subject_id from {processed_db}.full_endpoints'
        subjects = np.array(pd.read_sql(query, connection))
        kf = KFold(n_splits=5, shuffle=True)
        kfold_split = list(kf.split(subjects))
        train_ids = [subjects[k[0]] for k in kfold_split]
        valid_ids = [subjects[k[1]] for k in kfold_split]
        with open(os.path.join(PROCESSED_DATAPATH, 'training_validation_splits'), 'wb') as outfile:
            pickle.dump((train_ids, valid_ids), outfile)
    else:
        with open(os.path.join(PROCESSED_DATAPATH, 'training_validation_splits'), 'rb') as infile:
            train_ids, valid_ids = pickle.load(infile)
    return train_ids, valid_ids

In [7]:
vs_v2i, vs_i2v = get_vocabulary('vs_tokens')
clin_v2i, clin_i2v = get_vocabulary('all_tokens')

In [8]:
train_ids, valid_ids = get_train_validation_splits(True)

In [9]:
def get_oversample_dist():
    
    # this function figures out the required multipliers to generate the correct 
    # oversample distributions (both simple oversample and time to event oversample)
    # for all included endpoints

    query = f'select * from {processed_db}.full_endpoints'
    full_endpoint_df = pd.read_sql(query, connection)

    six_hours = np.timedelta64(6, 'h')/np.timedelta64(1, 'ns')
    seven_days = np.timedelta64(7, 'D')/np.timedelta64(1, 'ns')
    long_enough = full_endpoint_df[(full_endpoint_df.nth==1)&
                                   (full_endpoint_df.duration >= six_hours) & # duration is in ns in this df
                                   (((full_endpoint_df.intime + np.timedelta64(6, 'h')) < full_endpoint_df.deathtime)
                                    |(full_endpoint_df.hospital_expire_flag == 0))]
    in_hosp_death_rate = len(long_enough[long_enough.hospital_expire_flag==1])/len(long_enough)
    in_icu_death_rate = len(long_enough[(long_enough.hospital_expire_flag==1) & (long_enough.deathtime <= long_enough.outtime)])/len(long_enough)
    long_icu_stay_rate = len(long_enough[long_enough.duration >= seven_days])/len(long_enough)
    readmitted = full_endpoint_df[full_endpoint_df.hadm_id.isin(long_enough.hadm_id.unique()) & (full_endpoint_df.nth==2)]
    readmitted_rate = len(readmitted)/len(long_enough)
    print(f'Death in hospital: {in_hosp_death_rate}, Death in this ICU: {in_icu_death_rate}, Long ICU stay: {long_icu_stay_rate}, ICU readmission: {readmitted_rate}')

    # we are targeting oversample to 50% of data, so what multiplier gets us to approximately
    # that distribution for each endpoint?
    
    oversampling_rates = {'inhosp_death': round(0.5/in_hosp_death_rate, 0), 
                          'inicu_death': round(0.5/in_icu_death_rate, 0), 
                          'long_icu': round(0.5/long_icu_stay_rate, 0), 
                          'icu_readm': round(0.5/readmitted_rate, 0)}
    print(oversampling_rates)

    first_icu_adms = full_endpoint_df[full_endpoint_df.nth == 1].copy()
    first_icu_adms['predtime'] = first_icu_adms.intime + np.timedelta64(6, 'h')

    # get distribution of time to event values in hours for weighting strategy
    death_times = (first_icu_adms.deathtime - first_icu_adms.predtime)[first_icu_adms.hospital_expire_flag == 1]/np.timedelta64(1, 'h')
    in_icu_death_times = (first_icu_adms.deathtime - first_icu_adms.predtime)[(first_icu_adms.hospital_expire_flag == 1) & (first_icu_adms.deathtime <= first_icu_adms.outtime)]/np.timedelta64(1, 'h')

    # distribution of admission time in weeks if > 1 week
    duration_dist = np.histogram([d//seven_days if d//seven_days < 8 else 8 for d in first_icu_adms.duration], bins=9)[0][1:]

    # getting time to readmission 
    full_endpoint_df = full_endpoint_df[['hadm_id', 'new_icu_stay', 'intime', 'outtime', 'duration',
                                         'subject_id', 'hospital_expire_flag', 'admittime', 'dischtime',
                                         'deathtime', 'admission_age']].drop_duplicates()

    first_admissions = full_endpoint_df.sort_values('intime').groupby(['hadm_id']).first().reset_index()
    subsequent = full_endpoint_df[~full_endpoint_df.new_icu_stay.isin(first_admissions.new_icu_stay)]

    readmissions = subsequent[['hadm_id', 'intime']].sort_values('intime').groupby(['hadm_id']).first().reset_index()
    readmissions = pd.merge(full_endpoint_df[['hadm_id','outtime']], readmissions, how='left')

    full_endpoint_df['readm'] = np.where(full_endpoint_df.hadm_id.isin(subsequent.hadm_id), 1, 0)
    full_endpoint_df['time_to_readm'] = np.where((full_endpoint_df.readm==1)&(full_endpoint_df.new_icu_stay.isin(first_admissions.new_icu_stay)), (readmissions.intime - readmissions.outtime)//np.timedelta64(24, 'h'), -1)

    readmit_times = full_endpoint_df.time_to_readm
    
    # we make the assumption that once the time to event is more than 1 week in the future, the relationship between 
    # current patient status and endpoint time (not endpoint occurrance, just timing) becomes less deterministic
    # so focus on over-weighting those subjects with the endpoint within 1 week most highly
    
    death_time_dist = np.histogram([d//24 if d <= 168 else 7 for d in death_times], bins=8)[0]
    in_icu_death_time_dist = np.histogram([d//24 if d <= 168 else 7 for d in in_icu_death_times], bins=8)[0]
    readmit_time_dist = np.histogram([d//24 if d <= 168 else 7 for d in readmit_times], bins=8)[0]

    # targeting 5* oversampling for in-hosp death
    target = 5*sum(death_time_dist)
    in_hosp_multiplier = [9,8,7,6,5,4,3,2]
    print(sum(in_hosp_multiplier*death_time_dist)/target)

    # targeting 8* oversampling for in-ICU death
    target = 8*sum(death_time_dist)
    in_icu_multiplier = [16,15,14,13,12,11,10,9]
    print(sum(in_icu_multiplier*in_icu_death_time_dist)/target)

    # targeting 3* oversampling for long-ICU but there is no 'time-to' component, so we 
    # instead overweight by duration - oversampling the longest stays by the most
    target = 3*sum(duration_dist)
    long_stay_multiplier = [2,3,4,5,6,7,8,9]
    print(sum(long_stay_multiplier*duration_dist)/target)

    target = 8*sum(death_time_dist)
    in_icu_multiplier = [16,15,14,13,12,11,10,9]
    print(sum(in_icu_multiplier*in_icu_death_time_dist)/target)

    # targeting 4* oversampling for readmission
    target = 4*sum(readmit_time_dist)
    readmit_multiplier = [4,3,3,3,3,3,3,3]
    print(sum(readmit_multiplier*readmit_time_dist)/target)
    
    return oversampling_rates, in_hosp_multiplier, in_icu_multiplier, readmit_multiplier, long_stay_multiplier


In [10]:
oversampling_rates, in_hosp_multiplier, in_icu_multiplier, readmit_multiplier, long_stay_multiplier = get_oversample_dist()
# dictionary of {endpoint : associated time weighting input} variables
endpoint_lookup = {'inhosp_death': 'tt_dth', 'inicu_death': 'tt_dth', 'long_icu': 'duration', 'icu_readm': 'tt_readm'}
oversampling_multipliers = {'inhosp_death': in_hosp_multiplier, 'inicu_death': in_icu_multiplier, 'icu_readm': readmit_multiplier, 'long_icu': long_stay_multiplier}

Death in hospital: 0.10052308392085513, Death in this ICU: 0.06341444924569782, Long ICU stay: 0.152850428322341, ICU readmission: 0.11617769691456296
{'inhosp_death': 5.0, 'inicu_death': 8.0, 'long_icu': 3.0, 'icu_readm': 4.0}
0.9855580693815988
1.0122784690799396
1.0707170903079148
1.0122784690799396
0.9747103386809269


In [11]:
def get_stepped_data(data, times):
    # converts a serial list of events into steps of events that all occur
    # within the same 1-hr span
    
    evt_list_stepped = []
    timesteps = list(reversed(sorted(np.unique(times))))[-MAX_TIMESTEPS:]

    for step in timesteps:
        x_stepped = data[np.where(times==step)[0]][-TIMESTEP_WIDTH:]
        evt_list_stepped.append(np.pad(x_stepped, (TIMESTEP_WIDTH-len(x_stepped), 0), 'constant'))
    if len(evt_list_stepped) > 0:
        return np.vstack(evt_list_stepped), timesteps
    else:
        return [], timesteps

def augment(serial, stepped, aug_count):
    
    # randomly creates 10*aug_count augmentated versions of the input trajectory
    # by truncating / masking / shuffling combinations of events within the list
    try:
        z = [shuffle_stepped(stepped, len(serial)) for _ in range(aug_count)]
        augmentation_selector = np.random.choice([0, 1], 10*aug_count, p=[0.7, 0.3], replace=True)
        first_data = min(np.nonzero(serial)[0])
        data_elems = len(serial[first_data:])
        x = [mask_serial(z[i//10], first_data, data_elems) for i, x in enumerate(augmentation_selector) if x == 0]
        y = [truncate_serial(z[i//10], first_data, data_elems) for i, x in enumerate(augmentation_selector) if x == 1] 
        return x + y
    except:
        return serial*(aug_count*20)

def mask_serial(serial, first_data, data_elems):
    # removes somewhere between 1 and half the number of elements in the list - random mask
    mask_num = np.random.randint(1, max(2, data_elems//2))
    mask = np.random.choice(list(range(first_data, len(serial))), data_elems - mask_num, replace=False)
    return np.hstack([[0]*(mask_num+first_data), serial[sorted(mask)]]).astype(np.int64)   

def shuffle_stepped(stepped, n):
    # shuffles events that occur within the same 1hr period and reassembles into a serial trajectory
    try:
        for i in range(stepped.shape[0]):
            nz = np.nonzero(stepped[i])[0]
            if len(nz) > 0:
                first_data = min(nz)
                np.random.shuffle(stepped[i][first_data:])
        s = stepped[np.nonzero(stepped)].flatten()
        return np.hstack([[0]*(max(0, n-len(s))), s[-n:]]).astype(np.int64)
    except:
        return np.array(stepped)

def truncate_serial(serial, first_data, data_elems):
    # removes somewhere between 1 and a third of the number of elements in the list - drops oldest events
    truncate_num = np.random.randint(1, max(2, data_elems//3))
    return np.hstack([[0]*(first_data+truncate_num), serial[first_data + truncate_num:]]).astype(np.int64)

In [12]:
def get_vocab(x, cat):
    if cat=='clin':
        return int(clin_v2i[x])
    return int(vs_v2i[x])

In [13]:
def get_data(traj, predtime, cat):
    
    # gets portion of trajectory that was available for prediction at a given prediction time
    
    traj = traj[traj.time < predtime].sort_values('time')
    serial_data = np.array([get_vocab(x, cat) for x in traj.label])
    serial_times = np.array([round((predtime-t)/np.timedelta64(1, 'h'), 0) for t in traj.time])
    data_stepped, times_stepped = get_stepped_data(serial_data, serial_times)
    
    serial_data = np.pad(serial_data[-SERIAL_TIMESTEPS:], (SERIAL_TIMESTEPS-len(serial_data[-SERIAL_TIMESTEPS:]), 0), mode='constant')
    serial_times = np.pad(serial_times[-SERIAL_TIMESTEPS:], (SERIAL_TIMESTEPS-len(serial_times[-SERIAL_TIMESTEPS:]), 0), mode='constant')

    return serial_data, serial_times, data_stepped, times_stepped

In [14]:
def get_feature_dict(ct, vt, it, endpoints):
    
    # creates feature dictionary for tfrecord file from clincal trajectory, vital signs
    # trajectory and endpoints at a given prediction time
    
    predtime = endpoints.predtime
    clin_data, clin_times, clin_stepped_data, clin_stepped_times = get_data(ct, predtime, 'clin')
    vs_data, vs_times, vs_stepped_data, vs_stepped_times = get_data(vt, predtime, 'vs')
    hist_icu = len(it[(it.time<predtime)&(it.label.str.contains('admit'))])
    durations = it[(it.time<predtime)&(it.label.str.contains('icudur'))]
    try:
        d = durations.label.str.split('_', expand=True)[1]
        icu_ave_dur = np.mean(d)
        icu_time_since = int((predtime - max(durations.time))//np.timedelta64(1, 'h'))
    except:
        icu_av_dur = -1
        icu_time_since = -1

    feature_dict = {'hosp_death': _int_feature(endpoints.hospital_expire_flag),
                    'icu_death': _int_feature(endpoints.icu_death),
                    'icu_readm': _int_feature(endpoints.readm),
                    'long_icu': _int_feature(endpoints.long_icu), 
                    'subject': _int_feature(endpoints.subject_id),
                    'hosp_adm': _int_feature(endpoints.hadm_id),
                    'tt_dth': _int_feature(int(endpoints.time_to_death)),
                    'tt_readm': _int_feature(int(endpoints.time_to_readm)),
                    'duration': _int_feature(endpoints.duration_wk),
                    'hist_icu': _int_feature(hist_icu),
                    'ave_dur': _int_feature(icu_av_dur),
                    'time_since': _int_feature(icu_time_since),
                    'data_clin': tf.train.Feature(int64_list=tf.train.Int64List(value=clin_data.astype(int))),
                    'times_clin': tf.train.Feature(int64_list=tf.train.Int64List(value=clin_times.astype(int))),
                    'data_vs': tf.train.Feature(int64_list=tf.train.Int64List(value=vs_data.astype(int))),
                    'times_vs': tf.train.Feature(int64_list=tf.train.Int64List(value=vs_times.astype(int)))}
    
    return feature_dict, clin_data, vs_data, clin_stepped_data, vs_stepped_data

In [15]:
def get_icu_endpt():
    
    # selects endpoints from database ready to combine with trajetories 
    # into tf record files
    
    query = f'select * from {processed_db}.full_endpoints'
    full_endpoint_df = pd.read_sql(query, connection)

    
    full_endpoint_df = full_endpoint_df[['hadm_id', 'new_icu_stay', 'intime', 'outtime', 'duration',
                                     'subject_id', 'hospital_expire_flag', 'admittime', 'dischtime',
                                     'deathtime', 'admission_age']].drop_duplicates()

    
    query_icu = f'select * from {processed_db}.icu_tokens'
    icu_traj = pd.read_sql(query_icu, connection)
    
    seven_days = np.timedelta64(7, 'D')/np.timedelta64(1, 'ns')
    full_endpoint_df['predtime'] = full_endpoint_df.intime + np.timedelta64(6, 'h')
    full_endpoint_df['long_icu'] = np.where(full_endpoint_df.duration >= seven_days, 1, 0)
    full_endpoint_df['icu_death'] = np.where((full_endpoint_df.hospital_expire_flag==1)&(full_endpoint_df.deathtime <= full_endpoint_df.outtime), 1, 0)
    
    first_admissions = full_endpoint_df.sort_values('intime').groupby(['hadm_id']).first().reset_index()
    subsequent = full_endpoint_df[~full_endpoint_df.new_icu_stay.isin(first_admissions.new_icu_stay)]

    readmissions = subsequent[['hadm_id', 'intime']].sort_values('intime').groupby(['hadm_id']).first().reset_index()
    readmissions = pd.merge(full_endpoint_df[['hadm_id','outtime']], readmissions, how='left')

    full_endpoint_df['readm'] = np.where(full_endpoint_df.hadm_id.isin(subsequent.hadm_id), 1, 0)
    full_endpoint_df['time_to_readm'] = np.where((full_endpoint_df.readm==1)&(full_endpoint_df.new_icu_stay.isin(first_admissions.new_icu_stay)), (readmissions.intime - readmissions.outtime)//np.timedelta64(24, 'h'), -1)

    full_endpoint_df['time_to_death'] = np.where(full_endpoint_df.hospital_expire_flag==1, 
                                                 (full_endpoint_df.deathtime - full_endpoint_df.predtime)//np.timedelta64(24, 'h'), -1)
    full_endpoint_df['duration_wk'] = full_endpoint_df.duration//(np.timedelta64(7, 'D')//np.timedelta64(1, 'ns'))-1
    
    return icu_traj, full_endpoint_df

In [16]:
SERIAL_TIMESTEPS = 500
MAX_TIMESTEPS = 200
TIMESTEP_WIDTH = 100
import copy

def serialize_data():
    # combines data from db into basic feature_dict that is ready for data augmentation processes
    offset = 1000
    icu_traj, full_endpoint_df = get_icu_endpt()
    
    subject_list = full_endpoint_df.subject_id.unique()
    # batching data into managable chunks of length (offset) subjects
    for i in range(0, len(subject_list), offset):
        data_list = {}
        subject_subset = tuple(subject_list[i:i+offset])
        query_clin = f'select * from {processed_db}.all_tokens where subject_id in {subject_subset}'
        query_vs = f'select * from {processed_db}.vs_tokens where subject_id in {subject_subset}'
        clin_tokenized_traj = pd.read_sql(query_clin, connection)
        vs_tokenized_traj = pd.read_sql(query_vs, connection)
        print(f'selected data batch {i//offset} ({i} of {len(subject_list)})')
        for j, admission in enumerate(clin_tokenized_traj.hadm_id.unique()):
            ct = clin_tokenized_traj[clin_tokenized_traj.hadm_id == admission]
            vt = vs_tokenized_traj[vs_tokenized_traj.hadm_id == admission]
            it = icu_traj[icu_traj.hadm_id == admission]
            endpoints = full_endpoint_df[full_endpoint_df.hadm_id == admission].sort_values('intime')
            if len(endpoints)>0:
                endpoints = endpoints.iloc[0]
                data_list[admission] = (get_feature_dict(ct, vt, it, endpoints))
        dump_data(os.path.join(PROCESSED_DATAPATH, f'serialized_{i}'), data_list)
        
def get_feature_value(feature, label):
    # for convenience because tfrecord feature dictionaries are annoying
    return feature[label].int64_list.value[0]

In [17]:
def get_multiplier(value, target, strategy):
    
    # lookup appropriate multiplier according to endpoint value, endpoint target
    # and weighting strategy
    
    if value < 0:
        return 1
    if strategy == 'basic':
        return oversampling_rates[target]
    lookup = min(value, len(oversampling_multipliers[target]) - 1)
    return oversampling_multipliers[target][lookup]

def do_oversampling(strategy):
    
    # strategy is either 'basic' or 'tte' - used to select the multiplier for positive
    # class members - this can be either weighted to time to event or a single rate
    # for all minority class
    
    serialized_filelist = [f for f in os.listdir(PROCESSED_DATAPATH) if 'serialized' in f]
    for filename in serialized_filelist:
        basic_data = read_data(os.path.join(PROCESSED_DATAPATH, filename))
        oversampled_data = {t:{'train':{}, 'valid':{}} for t in oversampling_rates.keys()}        
        augmented_data = {t: {'train':{}, 'valid':{}} for t in oversampling_rates.keys()}
        
        for admission in basic_data.keys():
            feature_dict, clin_data, vs_data, clin_stepped_data, vs_stepped_data = basic_data[admission]

            for target in oversampling_rates.keys():
                # insert original data into the oversampled and augmented lists
                oversampled_data[target]['train'][admission] = [feature_dict]
                oversampled_data[target]['valid'][admission] = [feature_dict]

                # what is the value of the target feature for this record?
                value = get_feature_value(feature_dict, endpoint_lookup[target])
                multiplier = get_multiplier(value, target, strategy)

                for _ in range(int(multiplier) - 1):
                    # add rate-1 more copies of this record to the oversampling list
                    oversampled_data[target]['train'][admission].append(feature_dict)
                
                # make augmentation shuffles for both clinical and vital sign data
                # note: we can augment for validation set and take average prediction as, 
                # final prediction value, but for validation we augment the same number
                # of times, regardless of the feature value, whereas for the training
                # set, we set the rate according to the multiplication factor as per
                # oversampling strategy
                
                clin_augmented_data = {'train': augment(clin_data, clin_stepped_data, multiplier),
                                       'valid': augment(clin_data, clin_stepped_data, 1)}
                vs_augmented_data = {'train': augment(vs_data, vs_stepped_data, multiplier),
                                     'valid': augment(vs_data, vs_stepped_data, 1)}
                for phase in ['train', 'valid']:
                    augmented_data[target][phase][admission] = [feature_dict]
                    for a, v in zip(clin_augmented_data[phase], vs_augmented_data[phase]):
                        augmented_dict = copy.deepcopy(feature_dict)
                        try:
                            augmented_dict['data_clin'] = tf.train.Feature(int64_list=tf.train.Int64List(value=a))
                            augmented_dict['data_vs'] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
                            augmented_data[target][phase][admission].append(augmented_dict)
                        except TypeError:
                            # this will except if the augmented trajectory is a single event i.e. not iterable
                            pass
                                    
        suffix = filename.split('_')[1]
        print(f'basic_rate_{suffix}')
        dump_data(os.path.join(PROCESSED_DATAPATH, f'{strategy}_rate_oversample_{suffix}'), oversampled_data)
        dump_data(os.path.join(PROCESSED_DATAPATH, f'{strategy}_rate_augment_{suffix}'), augmented_data)

In [18]:
def make_data_files_weighted_distribution(weighting, strategy):
    
    # now that we have the data all in the right format and appropriately oversampled, combine
    # them according to their k-fold assignment into the final tfrecord file that will be fed 
    # to the models
    
    for w in weighting:
        for s in strategy:
            print(w, s)
            files = [f for f in os.listdir(PROCESSED_DATAPATH) if s in f and w in f]
            train_file_names = [os.path.join(MODEL_INPUT_DATAPATH, f'train_{s}_{w}_{fold}') for fold in range(5)]
            valid_file_names = [os.path.join(MODEL_INPUT_DATAPATH, f'valid_{s}_{w}_{fold}') for fold in range(5)]
            with ExitStack() as stack:
                train_files = {endpoint: [stack.enter_context(tf.io.TFRecordWriter(f'{t}_{endpoint}')) for t in train_file_names] for endpoint in endpoint_lookup.keys()}
                valid_files = {endpoint: [stack.enter_context(tf.io.TFRecordWriter(f'{v}_{endpoint}')) for v in valid_file_names] for endpoint in endpoint_lookup.keys()}

                for f in files:
                    print(f)
                    datafile = read_data(os.path.join(PROCESSED_DATAPATH, f))
                    for target, target_data in datafile.items():
                        for phase, phase_data in target_data.items():
                            for visit, visit_data in phase_data.items():
                                if len(phase_data) > 0:
                                    for fold in range(5):
                                        for d in visit_data:
                                            subject = get_feature_value(d, 'subject')
                                            traj_feat = tf.train.Features(feature=d)
                                            example = tf.train.Example(features=traj_feat)
                                            if subject in train_ids[fold]:                        
                                                train_files[target][fold].write(example.SerializeToString())
                                            else:
                                                valid_files[target][fold].write(example.SerializeToString())

In [19]:
def make_data_files_original_distribution():
    
    # as per make_data_files_weighted_distribution but without weightings
    
    train_file_names = [os.path.join(MODEL_INPUT_DATAPATH, f'train_original_{fold}') for fold in range(5)]
    valid_file_names = [os.path.join(MODEL_INPUT_DATAPATH, f'valid_original_{fold}') for fold in range(5)]
    with ExitStack() as stack:
        train_files = [stack.enter_context(tf.io.TFRecordWriter(t)) for t in train_file_names]
        valid_files = [stack.enter_context(tf.io.TFRecordWriter(v)) for v in valid_file_names]
        serialized_filelist = [f for f in os.listdir(PROCESSED_DATAPATH) if 'serialized' in f]
        for filename in serialized_filelist:
            print(filename)
            basic_data = read_data(os.path.join(PROCESSED_DATAPATH, filename))
            for admission in basic_data.keys():
                feature_dict, clin_data, vs_data, clin_stepped_data, vs_stepped_data = basic_data[admission]
                subject = get_feature_value(feature_dict, 'subject')
                traj_feat = tf.train.Features(feature=feature_dict)
                example = tf.train.Example(features=traj_feat)
                for fold in range(5):
                    if subject in train_ids[fold]:                        
                        train_files[fold].write(example.SerializeToString())
                    else:
                        valid_files[fold].write(example.SerializeToString())

In [20]:
serialize_data() # this step produces the endpoints tt_dth and tt_readm in full days and duration in full weeks

selected data batch 0 (0 of 42207)
selected data batch 1 (1000 of 42207)
selected data batch 2 (2000 of 42207)
selected data batch 3 (3000 of 42207)
selected data batch 4 (4000 of 42207)
selected data batch 5 (5000 of 42207)
selected data batch 6 (6000 of 42207)
selected data batch 7 (7000 of 42207)
selected data batch 8 (8000 of 42207)
selected data batch 9 (9000 of 42207)
selected data batch 10 (10000 of 42207)
selected data batch 11 (11000 of 42207)
selected data batch 12 (12000 of 42207)
selected data batch 13 (13000 of 42207)
selected data batch 14 (14000 of 42207)
selected data batch 15 (15000 of 42207)
selected data batch 16 (16000 of 42207)
selected data batch 17 (17000 of 42207)
selected data batch 18 (18000 of 42207)
selected data batch 19 (19000 of 42207)
selected data batch 20 (20000 of 42207)
selected data batch 21 (21000 of 42207)
selected data batch 22 (22000 of 42207)
selected data batch 23 (23000 of 42207)
selected data batch 24 (24000 of 42207)
selected data batch 25 

In [21]:
for strategy in ['basic', 'tte']:
    do_oversampling(strategy) # this step takes serialised data and produces augmented / oversampled data according to the endpoint

basic_rate_8000
basic_rate_27000
basic_rate_28000
basic_rate_24000
basic_rate_9000
basic_rate_39000
basic_rate_29000
basic_rate_30000
basic_rate_18000
basic_rate_6000
basic_rate_14000
basic_rate_35000
basic_rate_41000
basic_rate_13000
basic_rate_20000
basic_rate_5000
basic_rate_23000
basic_rate_36000
basic_rate_42000
basic_rate_21000
basic_rate_0
basic_rate_3000
basic_rate_22000
basic_rate_38000
basic_rate_33000
basic_rate_40000
basic_rate_37000
basic_rate_17000
basic_rate_11000
basic_rate_16000
basic_rate_4000
basic_rate_25000
basic_rate_15000
basic_rate_34000
basic_rate_1000
basic_rate_12000
basic_rate_32000
basic_rate_7000
basic_rate_19000
basic_rate_10000
basic_rate_31000
basic_rate_26000
basic_rate_2000
basic_rate_8000
basic_rate_27000
basic_rate_28000
basic_rate_24000
basic_rate_9000
basic_rate_39000
basic_rate_29000
basic_rate_30000
basic_rate_18000
basic_rate_6000
basic_rate_14000
basic_rate_35000
basic_rate_41000
basic_rate_13000
basic_rate_20000
basic_rate_5000
basic_rate_230

In [22]:
weighting = ['tte', 'basic']
strategy = ['augment', 'oversample']
make_data_files_weighted_distribution(weighting, strategy)

tte augment
tte_rate_augment_40000
tte_rate_augment_41000
tte_rate_augment_27000
tte_rate_augment_7000
tte_rate_augment_20000
tte_rate_augment_15000
tte_rate_augment_12000
tte_rate_augment_30000
tte_rate_augment_18000
tte_rate_augment_23000
tte_rate_augment_1000
tte_rate_augment_3000
tte_rate_augment_13000
tte_rate_augment_19000
tte_rate_augment_26000
tte_rate_augment_28000
tte_rate_augment_22000
tte_rate_augment_10000
tte_rate_augment_21000
tte_rate_augment_5000
tte_rate_augment_16000
tte_rate_augment_24000
tte_rate_augment_32000
tte_rate_augment_25000
tte_rate_augment_34000
tte_rate_augment_2000
tte_rate_augment_6000
tte_rate_augment_31000
tte_rate_augment_0
tte_rate_augment_17000
tte_rate_augment_33000
tte_rate_augment_29000
tte_rate_augment_37000
tte_rate_augment_8000
tte_rate_augment_11000
tte_rate_augment_42000
tte_rate_augment_9000
tte_rate_augment_39000
tte_rate_augment_38000
tte_rate_augment_4000
tte_rate_augment_35000
tte_rate_augment_14000
tte_rate_augment_36000
tte oversamp