The goal is to predict whether a patient will survive their stay. This script runs all experiments and saves the results in a format that allows further analysis.

In [None]:
import pandas as pd
import numpy as np

In [None]:
# Lab measurements: the first two CSV columns form the row index, the first row is the header
labs = pd.read_csv(
    'data/mimic/labs_first_day_subselection.csv',
    index_col=[0, 1],
    header=[0],
)
# Per-patient outcomes, indexed by the first CSV column
outcomes = pd.read_csv(
    'data/mimic/outcomes_first_day_subselection.csv',
    index_col=0,
)

In [None]:
# Turn Death into a boolean event indicator: True wherever the raw Death value is present
outcomes['Death'] = outcomes['Death'].notna()

# Split 

In [None]:
# Split on date - weekend vs weekdays
mode = "random" # One of "random", "weekday", "weekend"

In [None]:
# Build the boolean training mask (True = patient used for training), the
# results folder and the oversampling ratio for the selected split mode.
if mode == "weekend":
    # Train only on weekends but test on both
    training = outcomes.Day > 4 # presumably Day is 0 (Monday) to 6 (Sunday) - TODO confirm
    results = 'results/mimic/weekends/'
    ratio = (~training).sum() / training.sum() # Oversample: number of non-training patients per training patient
elif mode == "weekday":
    # Train only on weekdays but test on both
    training = outcomes.Day <= 4
    results = 'results/mimic/weekdays/'
    ratio = 0. # Do not oversample
else:
    # Random split: 80% of the patients are drawn for training (fixed seed)
    training = pd.Series(outcomes.index.isin(outcomes.sample(frac = 0.8, random_state = 0).index), index = outcomes.index)
    results = 'results/mimic/random/'
    ratio = 0. # Do not oversample
results += 'survival_' # Common prefix of every result file of this notebook

In [None]:
# Debug switch: when True the data are cut down to the first 500 patients and
# the experiments that receive `force = test, save = not(test)` are forced and not saved
test = True

In [None]:
if test:
    # Keep the first 500 patients and restrict labs and the training mask to them
    outcomes = outcomes.head(500)
    kept_patients = outcomes.index
    labs = labs.loc[labs.index.get_level_values(0).isin(kept_patients)]
    training = training.loc[training.index.isin(kept_patients)]

In [None]:
# Report the size of the cohort and of the training subset
n_total = len(training)
n_training = training.sum()
print('Total patients: {}'.format(n_total))
print('Training patients: {}'.format(n_training))

# Models

In [None]:
from experiment import ShiftExperiment

In [None]:
def process(data, labels):
    """
        Builds the model inputs from an irregularly sampled time series.

        Args:
            data (pd.DataFrame): Observations indexed by a ('Patient', 'Time')
                MultiIndex (in that order); NaN marks a value that was not
                observed at that time.
            labels (pd.DataFrame): Outcomes indexed by patient, with columns
                `LOS` and `Death`. LOS is assumed to be expressed in the same
                unit as the 'Time' level - TODO confirm.

        Returns:
            cov (pd.DataFrame): Covariates as float, imputed by forward fill
                within patient, then patient mean, then mean of patient means.
            ie_time (pd.Series): Time elapsed since the previous row of the
                same patient (0 for the first row), indexed like `data`.
            mask (pd.DataFrame): True where `data` is observed.
            time_event (pd.DataFrame): Single column, LOS minus the row time.
            event (pd.Series): `labels.Death`, unchanged.
    """
    values = data.astype(float) # astype already returns a new frame, no explicit copy needed
    per_patient = values.groupby('Patient')

    cov = per_patient.ffill() # Forward last value
    patient_mean = per_patient.mean()
    cov = cov.fillna(patient_mean) # Impute by patient mean (aligned on the 'Patient' level)
    cov = cov.fillna(patient_mean.mean()) # Impute by population mean => There is at least one value otherwise test wouldn't be in dataset

    # Vectorised within-patient difference of the observation times.
    # groupby.apply is avoided here: it returns a wide DataFrame when all
    # patients share the same time grid (e.g. a single patient) and reorders
    # rows by patient, whereas this always stays aligned with data.index.
    obs_times = data.index.get_level_values('Time').to_series(index = data.index)
    ie_time = obs_times.groupby(level = 'Patient').diff().fillna(0)

    mask = data.notna() # False if not observed

    # Remaining time until the end of stay at each observation
    patient_ids = data.index.get_level_values(0)
    row_times = data.index.get_level_values(1)
    remaining = labels.LOS.loc[patient_ids].to_numpy() - row_times.to_numpy()
    time_event = pd.DataFrame(remaining, index = data.index)

    return cov, ie_time, mask, time_event, labels.Death

### DeepSurv

In [None]:
# Candidate hidden-layer configurations: zero to three layers of 50 units
layers = [[50] * depth for depth in range(4)]

As a baseline, we build a DeepSurv model on the last observation carried forward.

##### Last Carried Forward

In [None]:
# Last value per patient after forward filling; labs that are still missing
# are imputed by the population mean (mean of the per-patient means)
population_mean = labs.groupby('Patient').mean().mean()
last = labs.groupby('Patient').ffill().groupby('Patient').last()
last = last.fillna(population_mean)

In [None]:
# DeepSurv experiment on the last observed values; in test mode the
# experiment is forced and its results are not saved
se = ShiftExperiment.create(
    model='deepsurv',
    hyper_grid={
        "survival_args": [{"layers": hidden} for hidden in layers],
        "lr": [1e-3, 1e-4],
        "batch": [100, 250],
    },
    path=results + 'deepsurv_last',
    force=test,
    save=not test,
)

In [None]:
# Train on the per-patient last values; `training` is the boolean mask selecting the training patients
se.train(last, outcomes.Remaining, outcomes.Death, training, oversampling_ratio = ratio)

##### Count

In [None]:
# Last value per patient after forward filling, imputed by the population mean when missing
population_mean = labs.groupby('Patient').mean().mean()
last = labs.groupby('Patient').ffill().groupby('Patient').last()
last = last.fillna(population_mean)
# Number of observed (non-missing) measurements of each lab per patient
count = labs.notna().groupby('Patient').sum()

In [None]:
# DeepSurv experiment on last values plus observation counts; in test mode
# the experiment is forced and its results are not saved
se = ShiftExperiment.create(
    model='deepsurv',
    hyper_grid={
        "survival_args": [{"layers": hidden} for hidden in layers],
        "lr": [1e-3, 1e-4],
        "batch": [100, 250],
    },
    path=results + 'deepsurv_count',
    force=test,
    save=not test,
)

In [None]:
# Concatenate last values and observation counts column-wise and train on them
last_and_count = pd.concat([last, count], axis=1)
se.train(last_and_count, outcomes.Remaining, outcomes.Death, training, oversampling_ratio=ratio)

## LSTM

In [None]:
# Hyper-parameter grid shared by the recurrent models below
hyper_grid = dict(
    layers=[1, 2, 3],
    hidden=[10, 30],
    survival_args=[{"layers": hidden} for hidden in layers],
    lr=[0.01],
    batch=[100, 250],
)

#### Value data only

Replace missing data and use time to predict

In [None]:
# Imputed covariates, inter-event times, observation mask, remaining time and event indicator from the raw labs
cov, ie, mask, time, event = process(labs.copy(), outcomes)

In [None]:
# Recurrent model on lab values only; in test mode the experiment is forced
# and its results are not saved. n_iter is set to 12 for this experiment only.
se = ShiftExperiment.create(
    model='joint',
    hyper_grid=hyper_grid,
    path=results + 'lstm_value',
    force=test,
    save=not test,
    n_iter=12,
)

In [None]:
# Use the split-dependent oversampling ratio, as every other experiment does,
# so that the weekend split is trained under the same conditions across models
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

#### Values and time and mask

In [None]:
# Append one float indicator column per lab (suffix '_mask', 1.0 where the lab is missing)
missing_flags = labs.isna().add_suffix('_mask').astype(float)
labs_selection = pd.concat([labs.copy(), missing_flags], axis=1)
# Add the time since the patient's previous row (0 for the first row) as a 'Time' column
index_frame = labs_selection.index.to_frame().reset_index(drop=True)
labs_selection['Time'] = index_frame.groupby('Patient').diff().fillna(0).values
cov, ie, mask, time, event = process(labs_selection, outcomes)

In [None]:
# Recurrent model on values, time and missingness indicators.
# NOTE(review): no force/save arguments are passed here, unlike the DeepSurv
# cells, so the `test` flag does not alter this call - confirm this is intended.
se = ShiftExperiment.create(
    model='joint',
    hyper_grid=hyper_grid,
    path=results + 'lstm_value+time+mask',
)

In [None]:
# Train on the values + time + missingness features with the split-dependent oversampling ratio
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

#### Values resampled

In [None]:
# Re-grid the labs to one row per hour and per patient before processing
labs_resample = labs.copy()
labs_resample = labs_resample.set_index(pd.to_datetime(labs_resample.index.get_level_values('Time'), unit = 'D'), append = True) # Append a datetime version of 'Time' (read as a number of days) as a third index level
labs_resample = labs_resample.groupby('Patient').resample('1H', level = 2).mean() # Hourly mean of each lab within each patient
labs_resample.index = labs_resample.index.map(lambda x: (x[0], x[1].hour / 24)) # Back to a (patient, fraction of day) index; only the hour is kept, so this assumes all times fall within a single day - TODO confirm

cov, ie, mask, time, event = process(labs_resample, outcomes) # Time is slightly different in that case as the last hour is rounded

In [None]:
# Recurrent model on the hourly resampled values.
# NOTE(review): no force/save arguments are passed here, unlike the DeepSurv
# cells, so the `test` flag does not alter this call - confirm this is intended.
se = ShiftExperiment.create(
    model='joint',
    hyper_grid=hyper_grid,
    path=results + 'lstm+resampled',
)

In [None]:
# Train on the resampled series with the split-dependent oversampling ratio
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

### GRU - D

In [None]:
# Same grid as the recurrent baseline (shallow copy) with the network type set to GRU-D
hyper_grid_gru = {**hyper_grid, "typ": ['GRUD']}

In [None]:
# Lab values plus one float indicator column per lab (suffix '_mask', 1.0 where the lab is missing)
missing_flags = labs.isna().add_suffix('_mask').astype(float)
labs_selection = pd.concat([labs.copy(), missing_flags], axis=1)
cov, ie, mask, time, event = process(labs_selection, outcomes)

In [None]:
# GRU-D experiment on values and missingness indicators.
# NOTE(review): no force/save arguments are passed here, unlike the DeepSurv
# cells, so the `test` flag does not alter this call - confirm this is intended.
se = ShiftExperiment.create(
    model='joint',
    hyper_grid=hyper_grid_gru,
    path=results + 'gru_d+mask',
)

In [None]:
# Train the GRU-D grid with the split-dependent oversampling ratio
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

### Latent ODE

In [None]:
# Same grid as the recurrent baseline (shallow copy) with the network type set to ODE
hyper_grid_ode = {**hyper_grid, "typ": ['ODE']}

In [None]:
# Latent ODE experiment.
# NOTE(review): no force/save arguments are passed here, unlike the DeepSurv
# cells, so the `test` flag does not alter this call - confirm this is intended.
se = ShiftExperiment.create(
    model='joint',
    hyper_grid=hyper_grid_ode,
    path=results + 'ode',
)

In [None]:
# No new call to process in this section: cov, ie, mask, time and event are whatever the last executed process cell produced
# NOTE(review): in file order that is the GRU-D features (values + '_mask' columns) - confirm this is the intended input
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

# Proposed approach

In [None]:
# Grid of the proposed joint model: the recurrent grid (shallow copy) extended
# with the weight and the temporal, longitudinal and missing components
hyper_grid_joint = {
    **hyper_grid,
    "weight": [0.1, 0.5],
    "temporal": ["point"],
    "temporal_args": [{"layers": hidden} for hidden in layers],
    "longitudinal": ["neural"],
    "longitudinal_args": [{"layers": hidden} for hidden in layers],
    "missing": ["neural"],
    "missing_args": [{"layers": hidden} for hidden in layers],
}

### Joint model on value only

In [None]:
# Joint model on raw lab values only: no additional time or missingness columns
labs_selection = labs.copy()
cov, ie, mask, time, event = process(labs_selection, outcomes)

In [None]:
# Joint model on lab values only. The force/save flags follow the `test`
# switch like the other experiments that set them, instead of being
# hard-coded to never save: forced and unsaved in test mode, saved otherwise.
se = ShiftExperiment.create(model = 'joint', 
                     hyper_grid = hyper_grid_joint,
                     path = results + 'joint+value', force = test, save = not(test))

In [None]:
# Train the joint model on values only with the split-dependent oversampling ratio
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

In [None]:
# NOTE(review): same call, with the same arguments, as the cell just above - presumably a leftover re-run; confirm whether a second training pass is intended
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

### Joint model on value, mask and time

In [None]:
# Lab values plus one float indicator column per lab (suffix '_mask', 1.0 where the lab is missing).
# axis is passed by keyword: the positional form is deprecated and rejected by pandas >= 2.0.
labs_selection = pd.concat([labs.copy(), labs.isna().add_suffix('_mask').astype(float)], axis = 1)
# Time since the patient's previous row (0 for the first row) as an extra 'Time' column
labs_selection['Time'] = labs_selection.index.to_frame().reset_index(drop = True).groupby('Patient').diff().fillna(0).values
cov, ie, mask, time, event = process(labs_selection, outcomes)

In [None]:
# Boolean vector over the covariate columns: True for the leading original lab columns, False for the added ones
n_features = len(cov.columns)
n_labs = len(labs.columns)
mask_mixture = np.arange(n_features) < n_labs

hyper_grid_joint['mixture_mask'] = [mask_mixture] # Avoids computing the observational process on the additional dimensions

In [None]:
# Joint model on values, time and missingness indicators.
# NOTE(review): no force/save arguments are passed here, unlike the DeepSurv
# cells, so the `test` flag does not alter this call - confirm this is intended.
se = ShiftExperiment.create(
    model='joint',
    hyper_grid=hyper_grid_joint,
    path=results + 'joint_value+time+mask',
)

In [None]:
# Train the joint model on values + time + missingness with the split-dependent oversampling ratio
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)

### Full fine tuning of the network

In [None]:
# Adds the option to the joint grid in place, so every later use of hyper_grid_joint carries it too
hyper_grid_joint['full_finetune'] = [True] 


In [None]:
# Joint model with full fine-tuning on values, time and missingness indicators.
# NOTE(review): no force/save arguments are passed here, unlike the DeepSurv
# cells, so the `test` flag does not alter this call - confirm this is intended.
se = ShiftExperiment.create(
    model='joint',
    hyper_grid=hyper_grid_joint,
    path=results + 'joint_full_finetune_value+time+mask',
)

In [None]:
# Train the fully fine-tuned joint model with the split-dependent oversampling ratio
se.train(cov, time, event, training, ie, mask, oversampling_ratio = ratio)