Code to reproduce the results of the paper DeepJoint: Robust Survival Modelling Under Clinical Presence Shift. The paper shows how modelling the interaction between the patient and the healthcare system in a multi-task setting can improve both predictive performance and robustness to changes in the observation process.
The model consists of a recurrent neural network that takes each new observation as input. The resulting embedding is then used to model the survival outcome with a DeepSurv model (S) and, in parallel, the observation process:
- L: the longitudinal evolution, using a neural network that outputs the mean and variance of each laboratory test.
- I: the inter-observation times, using a temporal point process network that outputs the intensity function.
- M: the missingness process, using a neural network with a Bernoulli output.
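The objective that ties these four components together is not spelled out above. Below is a minimal, framework-free sketch, assuming a weighted sum of per-head negative log-likelihoods: a Gaussian likelihood for L, a Bernoulli likelihood for M, the Cox partial likelihood that DeepSurv optimises for S and, as a simplification swapped in for the learned intensity network of I, an exponential point process whose intensity is constant between two observations. All function names and the equal default weights are hypothetical; in the model, the means, variances, intensities, probabilities and risks would be produced from the RNN embedding.

```python
import math


def gaussian_nll(value, mean, var):
    """Head L: NLL of an observed lab value under N(mean, var)."""
    return 0.5 * (math.log(2 * math.pi * var) + (value - mean) ** 2 / var)


def bernoulli_nll(observed, prob):
    """Head M: NLL of a lab being measured (1) or missing (0) with probability prob."""
    return -(observed * math.log(prob) + (1 - observed) * math.log(1 - prob))


def exponential_tpp_nll(delta, intensity):
    """Head I (simplified): NLL of the next observation arriving after delta,
    assuming a constant intensity since the previous observation."""
    return intensity * delta - math.log(intensity)


def cox_nll(risks, times, events):
    """Head S: negative Cox partial log-likelihood (Breslow handling of ties)."""
    loss = 0.0
    for r_i, t_i, e_i in zip(risks, times, events):
        if e_i:
            # Risk set: every patient still under observation at time t_i
            at_risk = [r_j for r_j, t_j in zip(risks, times) if t_j >= t_i]
            loss -= r_i - math.log(sum(math.exp(r_j) for r_j in at_risk))
    return loss


def joint_loss(survival, longitudinal, inter_observation, missingness,
               weights=(1.0, 1.0, 1.0, 1.0)):
    """Hypothetical multi-task loss: weighted sum of the four heads.

    survival: (risks, times, events) for S
    longitudinal: (value, mean, var) triples, for observed labs only (L)
    inter_observation: (delta, intensity) pairs (I)
    missingness: (observed, prob) pairs for every lab at every observation (M)
    """
    w_s, w_l, w_i, w_m = weights
    return (w_s * cox_nll(*survival)
            + w_l * sum(gaussian_nll(*x) for x in longitudinal)
            + w_i * sum(exponential_tpp_nll(*x) for x in inter_observation)
            + w_m * sum(bernoulli_nll(*x) for x in missingness))
```

Summing the terms trains the shared embedding on both the outcome and the observation process; the weights control how much each component contributes.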
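To use the model, run:

```python
from models import RNNJoint

# inputdim, outputdim and hyperparameter must be defined beforehand
model = RNNJoint(inputdim, outputdim, **hyperparameter)
# Fit on the covariates, inter-observation times and mask,
# together with the outcome (event and time)
model.fit(covariates, inter_observation, mask, event, time)
# Predict from the observation data only (no outcome needed)
model.predict(covariates, inter_observation, mask)
```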
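The call above leaves the input format implicit. As a minimal sketch, assuming each patient is a sequence of observation events where `covariates` holds the lab values, `inter_observation` the time elapsed since the previous observation and `mask` which labs were actually measured, one patient's records could be turned into these three sequences as follows. The helper `to_model_inputs`, the record layout and the zero filling of missing values are hypothetical; the exact array shapes, padding and normalisation expected by `RNNJoint.fit` are not specified here.

```python
def to_model_inputs(records, labs):
    """Hypothetical helper: build per-observation sequences for one patient.

    records: list of (time, {lab_name: value}) pairs, one per observation event.
    labs: ordered list of lab names defining the covariate dimensions.
    """
    records = sorted(records, key=lambda r: r[0])
    covariates, inter_observation, mask = [], [], []
    previous_time = None
    for time, values in records:
        # Time elapsed since the previous observation (0.0 for the first one)
        inter_observation.append(0.0 if previous_time is None else time - previous_time)
        previous_time = time
        # Unmeasured labs are filled with 0.0 and flagged with 0 in the mask
        covariates.append([values.get(lab, 0.0) for lab in labs])
        mask.append([1 if lab in values else 0 for lab in labs])
    return covariates, inter_observation, mask
```

For example, two observations where potassium is only measured the second time give a mask of `[[1, 0], [1, 1]]`.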
To reproduce the paper's results:
- Clone the repository with its dependencies: `git clone git@github.com:Jeanselme/ClinicalPresence.git --recursive`
- Create a conda environment with all necessary libraries: `pytorch`, `pandas` and `numpy`.
- Download the MIMIC III dataset and extract the data by following `1. Temporal Lab Extraction.ipynb`.
- Then select the subset of laboratory tests of interest using `2. Analysis.ipynb`.
- Finally, run the experiments with `3. Death - Survival.ipynb`, choosing the split of interest (weekend, weekday or random). `Script.py` runs the same set of experiments from the command line.
- Analyse the results using `4. Analysis Results.ipynb`.
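The split is chosen inside the notebook. As a rough illustration only, assuming a split partitions patients by the day of the week of their admission (training on one group and evaluating on the other, one plausible way to induce a shift in the observation process), a hypothetical `split_patients` helper could look like this. The function, its arguments and the direction of each split are assumptions, not the logic of `3. Death - Survival.ipynb` or `Script.py`.

```python
import random
from datetime import datetime


def split_patients(admissions, split="random", test_fraction=0.2, seed=42):
    """Hypothetical split: return (train_ids, test_ids).

    admissions: dict mapping patient id -> admission datetime.
    'weekend': test on weekend admissions, train on weekday ones (assumed)
    'weekday': the reverse
    'random': hold out a random test_fraction of patients
    """
    ids = sorted(admissions)
    if split == "random":
        shuffled = ids[:]
        random.Random(seed).shuffle(shuffled)
        n_test = int(round(test_fraction * len(ids)))
        return sorted(shuffled[n_test:]), sorted(shuffled[:n_test])
    # datetime.weekday(): Monday == 0 ... Sunday == 6, so >= 5 is the weekend
    weekend = {i for i in ids if admissions[i].weekday() >= 5}
    if split == "weekend":
        test = weekend
    elif split == "weekday":
        test = set(ids) - weekend
    else:
        raise ValueError(f"Unknown split: {split}")
    return [i for i in ids if i not in test], [i for i in ids if i in test]


if __name__ == "__main__":
    admissions = {"a": datetime(2024, 1, 6),   # Saturday
                  "b": datetime(2024, 1, 8),   # Monday
                  "c": datetime(2024, 1, 9)}   # Tuesday
    print(split_patients(admissions, "weekend"))  # (['b', 'c'], ['a'])
```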
- Competing risks.
- MIMIC IV.
All models are in the `models` folder, which is divided into the RNN structure and the Observational and Survival components. All scripts are at the root.
The model relies on `pytorch`, `pandas`, `numpy` and `tqdm`.