Code to reproduce the results of the paper DeepJoint: Robust Survival Modelling Under Clinical Presence Shift. The paper shows how modelling the interaction between patients and the healthcare system in a multi-task setting may improve both predictive performance and robustness to changes in the observational process.
The model consists of a recurrent neural network that takes each new observation as input. The resulting embedding is used to model the survival outcome with a DeepSurv model (S) and, in parallel, the observation process:
- L: the longitudinal evolution, using a neural network that outputs the mean and variance of each laboratory test.
- I: the inter-observation times, using a temporal point process network that outputs the intensity function.
- M: the missingness process, using a neural network with a Bernoulli output.
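Each of the four heads (S, L, I, M) is trained with its own likelihood. The sketch below writes plausible per-head losses in plain Python, under stated assumptions: the Cox partial likelihood for S (the standard DeepSurv objective, with Breslow handling of ties), a Gaussian negative log-likelihood restricted to observed values for L, an intensity of the form exp(a + b·s) in the time s since the last observation for I (an RMTPP-style choice; the repository's point process parametrisation may differ), and a binary cross-entropy for M. The function names and the weighted-sum combination are hypothetical and are not the repository's API.

```python
import math


def cox_partial_nll(risks, times, events):
    """S: negative Cox partial log-likelihood (DeepSurv objective), Breslow ties."""
    nll, n_events = 0.0, 0
    for r_i, t_i, e_i in zip(risks, times, events):
        if not e_i:
            continue  # censored patients only contribute through risk sets
        # Risk set: all patients whose time is at least t_i
        at_risk = [r_j for r_j, t_j in zip(risks, times) if t_j >= t_i]
        log_denominator = math.log(sum(math.exp(r_j) for r_j in at_risk))
        nll -= r_i - log_denominator
        n_events += 1
    return nll / max(n_events, 1)


def gaussian_nll(values, means, variances, observed):
    """L: Gaussian negative log-likelihood, computed on observed lab values only."""
    total, count = 0.0, 0
    for x, mu, var, o in zip(values, means, variances, observed):
        if o:
            total += 0.5 * (math.log(2 * math.pi * var) + (x - mu) ** 2 / var)
            count += 1
    return total / max(count, 1)


def exp_intensity_nll(gaps, biases, slopes):
    """I: point-process NLL with intensity exp(a + b * s), s = time since last observation."""
    total = 0.0
    for tau, a, b in zip(gaps, biases, slopes):
        log_intensity = a + b * tau
        # Compensator: integral of the intensity between 0 and tau
        if abs(b) < 1e-8:
            compensator = math.exp(a) * tau
        else:
            compensator = math.exp(a) * (math.exp(b * tau) - 1.0) / b
        total -= log_intensity - compensator
    return total / len(gaps)


def bernoulli_nll(observed, probs, eps=1e-7):
    """M: binary cross-entropy between the mask and predicted observation probabilities."""
    total = 0.0
    for o, p in zip(observed, probs):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        total -= o * math.log(p) + (1 - o) * math.log(1.0 - p)
    return total / len(observed)


def joint_loss(losses, weights=None):
    """Hypothetical multi-task objective: weighted sum of the S, L, I and M losses."""
    weights = weights or [1.0] * len(losses)
    return sum(w * loss for w, loss in zip(weights, losses))
```

For example, `exp_intensity_nll([1.0], [0.0], [0.0])` returns 1.0: a constant unit intensity over a unit gap has log-intensity 0 and compensator 1. Summing the heads is the simplest multi-task choice; per-head weights are one way to balance losses on different scales.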
To use the model, one needs to execute:
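```python
from models import RNNJoint

model = RNNJoint(inputdim, outputdim, **hyperparameter)
# Train on the covariates, the observation process (inter-observation times, mask) and the survival outcome (event, time)
model.fit(covariates, inter_observation, mask, event, time)
model.predict(covariates, inter_observation, mask)
```

To reproduce the paper's results: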
- Clone the repository with its dependencies: `git clone git@github.com:Jeanselme/ClinicalPresence.git --recursive`.
- Create a conda environment with all necessary libraries: `pytorch`, `pandas`, `numpy` and `tqdm`.
- Download the MIMIC III dataset and extract the data following `1. MIMIC - Temporal Lab Extraction.ipynb`.
- Select the laboratory tests of interest using `2. MIMIC - Analysis.ipynb`.
- Run the experiments with `3. MIMIC - Death - Survival.ipynb`, once for each split of interest (weekend, weekday or random). `Script.py` runs the same set of experiments from the command line.
- Analyse the results using `4. MIMIC - Analysis Results.ipynb`.
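The weekend and weekday splits are presumably meant to induce a shift in the observational process between training and test patients, while the random split serves as an unshifted reference. A minimal sketch of such a split, assuming each admission has a timestamp and that the split name designates the training population (evaluated on the other group); `split_admissions` is a hypothetical helper, not the notebook's code, which may define the groups differently.

```python
import random
from datetime import datetime


def split_admissions(admissions, split="weekday", test_fraction=0.2, seed=0):
    """Hypothetical split returning (train_ids, test_ids).

    admissions: dict mapping a patient id to its admission datetime.
    "weekday": train on weekday admissions, test on weekend ones (assumption).
    "weekend": the reverse. "random": shuffled split with no induced shift.
    """
    ids = sorted(admissions)
    if split == "random":
        shuffled = ids[:]
        random.Random(seed).shuffle(shuffled)
        n_test = int(round(len(ids) * test_fraction))
        return sorted(shuffled[n_test:]), sorted(shuffled[:n_test])
    # datetime.weekday(): Monday is 0, Saturday is 5 and Sunday is 6
    weekend = [pid for pid in ids if admissions[pid].weekday() >= 5]
    weekday = [pid for pid in ids if admissions[pid].weekday() < 5]
    if split == "weekday":
        return weekday, weekend
    if split == "weekend":
        return weekend, weekday
    raise ValueError(f"Unknown split: {split!r}")


admissions = {
    "a": datetime(2024, 1, 6),   # Saturday
    "b": datetime(2024, 1, 8),   # Monday
    "c": datetime(2024, 1, 7),   # Sunday
    "d": datetime(2024, 1, 10),  # Wednesday
}
train_ids, test_ids = split_admissions(admissions, "weekday")
print(train_ids, test_ids)  # ['b', 'd'] ['a', 'c']
```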
Future work:
- Competing risks.
- MIMIC IV.
All models are in the `models` folder, which is divided into the RNN structure and the observational and survival components. All scripts are at the root of the repository.
The model relies on pytorch, pandas, numpy and tqdm.
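The `fit` and `predict` calls in the usage snippet take `covariates`, `inter_observation` and `mask`. The exact shapes expected by `RNNJoint` are not documented here, so the following is a hypothetical sketch assuming one row per observation time, one column per laboratory test, zero-filled missing values flagged by a binary mask, and the time elapsed since the previous observation. `to_model_inputs` and `pad` are illustrative helpers, not part of the repository; the actual preprocessing presumably lives in the extraction notebooks.

```python
def to_model_inputs(observations, labs):
    """Hypothetical: turn one patient's timeline into (covariates, inter_observation, mask).

    observations: list of (time, {lab_name: value}) pairs, sorted by time
    labs: ordered list of laboratory test names, one column each
    """
    covariates, inter_observation, mask = [], [], []
    previous_time = None
    for time, values in observations:
        # Unobserved tests are filled with 0.0 and flagged as 0 in the mask
        covariates.append([values.get(lab, 0.0) for lab in labs])
        mask.append([1 if lab in values else 0 for lab in labs])
        # Time since the previous observation (0.0 for the first one)
        inter_observation.append(0.0 if previous_time is None else time - previous_time)
        previous_time = time
    return covariates, inter_observation, mask


def pad(sequence, length, fill):
    """Right-pad a per-patient sequence so that patients can be stacked into one batch."""
    return list(sequence) + [fill] * (length - len(sequence))


timeline = [
    (0.0, {"sodium": 140.0}),
    (2.5, {"sodium": 138.0, "potassium": 4.1}),
    (6.0, {"potassium": 3.9}),
]
covariates, inter_observation, mask = to_model_inputs(timeline, ["sodium", "potassium"])
print(inter_observation)  # [0.0, 2.5, 3.5]
```

Keeping the mask separate from the zero-filled values lets the M head learn from missingness itself rather than from an imputed value.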
