Skip to content

Commit

Permalink
Ehr transformer ICU example (#245)
Browse files Browse the repository at this point in the history
  • Loading branch information
ellabarkan committed Jan 13, 2023
1 parent 2f12ba7 commit 3390fda
Show file tree
Hide file tree
Showing 14 changed files with 1,498 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
Expand Up @@ -213,12 +213,12 @@ $ pip install fuse-med-ml[all]
* [**MNIST**](./examples/fuse_examples/imaging/classification/mnist/) - a simple example, including training, inference and evaluation over [MNIST dataset](http://yann.lecun.com/exdb/mnist/)
* [**STOIC**](./examples/fuse_examples/imaging/classification/stoic21/) - severe COVID-19 classifier baseline given a Computed-Tomography (CT), age group and gender. [Challenge description](https://stoic2021.grand-challenge.org/)


* [**KNIGHT Challenge**](./examples/fuse_examples/imaging/classification/knight) - preoperative prediction of risk class for patients with renal masses identified in clinical Computed Tomography (CT) imaging of the kidneys. Including data pre-processing, baseline implementation and evaluation pipeline for the challenge.
* [**Multimodality tutorial**](https://colab.research.google.com/github/BiomedSciAI/fuse-med-ml/blob/master/examples/fuse_examples/multimodality/image_clinical/multimodality_image_clinical.ipynb) - demonstration of two popular simple methods integrating imaging and clinical data (tabular) using FuseMedML
* [**Skin Lesion**](./examples/fuse_examples/imaging/classification/isic/) - skin lesion classification , including training, inference and evaluation over the public dataset introduced in [ISIC challenge](https://challenge.isic-archive.com/landing/2019)
* [**Breast Cancer Lesion Classification**](./examples/fuse_examples/imaging/classification/cmmd) - lesions classification of tumor ( benign, malignant) in breast mammography over the public dataset introduced in [The Chinese Mammography Database (CMMD)](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70230508)

* [**Mortality prediction for ICU patients**](./examples/fuse_examples/multimodality/ehr_transformer) - Example of EHR transformer applied to the data of Intensive Care Units patients for in-hospital mortality prediction. The dataset is from [PhysioNet Computing in Cardiology Challenge (2012)](https://physionet.org/content/challenge-2012/1.0.0/)

## Walkthrough template
* [**Walkthrough Template**](./fuse/dl/templates/walkthrough_template.py) - includes several TODO notes, marking the minimal scope of code required to get your pipeline up and running. The template also includes useful explanations and tips.

Expand Down
Empty file.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 65 additions & 0 deletions examples/fuse_examples/multimodality/ehr_transformer/README.md
@@ -0,0 +1,65 @@
# Predicting Mortality of ICU Patients with Fuse-Med-ML

## Introduction

The aim of this example is to demonstrate Fuse-Med-ML abilities for prediction tasks involving Electronic Health Records (EHR). In this example, Fuse will be applied for the prediction of In-hospital patients’ mortality after spending 48 hours in Intensive Care Unit.

<br/>


## Dataset
The dataset used in the example was taken from the PhysioNet Computing in Cardiology Challenge that was held in 2012 (https://physionet.org/content/challenge-2012/1.0.0/). The challenge data was extracted from MIMICC II Clinical Database. We used EHR records of 8000 patients that are available for public access in the above link. More details on the dataset can be found at https://ieeexplore.ieee.org/abstract/document/6420376

<br/>

## Task and Modeling
In this example, we will see Fuse-Med-ML applied to the prediction of in-hospital mortality, which is one of the tasks that was included in the PhysioNet challenge.

Figure 1 presents the structure of patient data in this data set. It can be seen that patients presented with both general information and time series of observations during the first 48 hours in the hospital. The architecture can be seen in Figure 2. We used a transformer-based deep-learning model that utilizes the BERT network design with additional adjustments.

<br/>

<figure>
<img src="./Fig1.jpg" alt="Trulli" style="width:100%">
<figcaption>Figure 1</figcaption>
</figure>

<br/>

The network was implemented in a manner that allows several output ”heads”, each responsible for a separate task. The network consists of several encoder layers, followed by task-specific heads (smaller neural network classifiers).

<br/>



<figure>
<img src="./Fig2.jpg" alt="Trulli" style="width:100%">
<figcaption>Figure 2</figcaption>
</figure>

<br/>

The patients’ data is passing a set of preprocessing steps before entering the modeling stage, which includes the exclusion of illegal patients, mapping continuous values of observation to categorical values by applying digitization, and combining both static data (general information) and time series in one input patient trajectory sequence.

As mentioned, the network has multiple prediction heads, related to multiple prediction tasks. Using multiple heads aims to improve patients' representation. We trained and tested multiple networks consisting of a combination of the following tasks:

<br/>- **Next Visit Prediction** A multi-label classification task of predicting the observation values that will appear in the visit following the last visit of the trajectory.
<br/>- **Gender Prediction**: A binary classifier that predicts the biologically determined gender of the
patient.
<br/>- **In-hospital mortality** prediction: A binary classifier that predicts will patient die or not during his hospital stay after the first 48 hours.

<br/>


## Usage Instructions
Download the following data to a designated folder in your system:
- set-a and set-b folders of the patient’s input
- Outcome-a.txt, Outcome-b.txt files

In the configuration file config.yaml hydra file:
1. Define environment variable CINC_DATA_PATH to the data folder.
2. There is a possibility to define an environment variable CINC_DATA_PKL of a path to the pickle file that will manage the loaded data for the next runs. Use 'null' (that is the default) in configuration if you are not using this option
3. Configure network hyperparameters

Run: python3 main_train.py

Empty file.
149 changes: 149 additions & 0 deletions examples/fuse_examples/multimodality/ehr_transformer/config.yaml
@@ -0,0 +1,149 @@
name: ehr_transformer
root: "."

target_key: "Target" # key that points to labels in sample_dict
max_len_seq: 350 # maximal number of tokens in the trajectory
aux_gender_classification: True # set to True to enable gender auxiliary head
aux_next_vis_classification: True # set to True to enable next visit auxiliary head

data:
dataset_cfg: # See PhysioNetCinC.dataset for details about the arguments
raw_data_pkl: null # in case we want to use pickle for managing raw data use ${oc.env:CINC_DATA_PKL} to define the
#path to the pickle file
raw_data_path: ${oc.env:CINC_DATA_PATH}
split_filename: None
num_folds: 5
train_folds: [ 0, 1, 2 ]
validation_folds: [ 3 ]
test_folds: [ 4 ]
seed: 2580
reset_split: True
num_percentiles : 4 #number of bins (percentiles) for converting floating lab/vital measurements to categorical values
categorical_max_num_of_values : 5 #max number of uniq values for categorical variable for not to be digitized
min_hours_in_hospital: 46
min_number_of_visits: 10
max_len_seq: ${max_len_seq}
static_variables_to_embed: ['Age','ICUType','Height','Weight','BMI',]
embed_static_in_all_visits: 0


batch_size: 128
target_key: ${target_key}

data_loader_train: # Dataloader constructor parameters
num_workers: 8

data_loader_valid: # Dataloader constructor parameters
num_workers: 8
batch_size: ${data.batch_size}


model:
encoder_type: "bert" # supported values: "bert" "transformer"

transformer_encoder: # TransformerEncoder constructor arguments - used when encoder type is "transformer"
num_tokens: ${max_len_seq}
token_dim: ${model.embed.emb_dim}
depth: 4
heads: 10
mlp_dim: 50
dropout: 0.0
emb_dropout: 0.0
num_cls_tokens: 1

bert_config_kwargs: # BertConfig constructor arguments - used when encoder type is "bert"
hidden_size: ${model.z_dim} # word embedding and seg embedding hidden size (needs to be a multiple of attention heads)
num_hidden_layers: 6 # number of multi-head attention layers required
num_attention_heads: 24 # number of attention heads
intermediate_size: 512 # the size of the "intermediate" layer in the transformer encoder
hidden_act: gelu # activation function ("gelu", 'relu', 'swish')
hidden_dropout_prob: 0.2 # dropout rate
attention_probs_dropout_prob: 0.22 # multi-head attention dropout rate
initializer_range: 0.02 # parameter weight initializer range

embed: # Embed constructor arguments
emb_dim: ${model.z_dim}

classifier_head: # HeadD1 constructor arguments - used for main classifer head
num_outputs: 2
layers_description: [256]

classifier_gender_head: # HeadD1 constructor arguments - used for gender classifer head
num_outputs: 2
layers_description: [256]

classifier_next_vis_head: # HeadD1 constructor arguments - used for next visit classifer head
layers_description: [256]

aux_gender_classification: ${aux_gender_classification}
aux_next_vis_classification: ${aux_next_vis_classification}

z_dim: 48

# train
train: # arguments for train() in classifiers_main_train.py
model_dir: ${root}/${name}
target_key: ${target_key}
target_loss_weight: 0.8
aux_gender_classification: ${aux_gender_classification}
gender_loss_weight: 0.1
aux_next_vis_classification: ${aux_next_vis_classification}
next_vis_loss_weight: 0.1

# uncomment to track in clearml
# track_clearml:
# project_name: "ehr_transformer"
# task_name: ${name}
# tags: "fuse_example"
# reuse_last_task_id: True
# continue_last_task: False


# uncomment for SGD
# opt:
# _partial_: true
# _target_: torch.optim.SGD
# momentum: 0.99
# nesterov: True
# lr: 0.001
# # weight_decay: 1e-5

# AdamW
op`t:
_partial_: true
_target_: torch.optim.AdamW
lr: 1e-3

# linear_schedule_with_warmup
lr_scheduler:
_partial_: True
_target_: transformers.get_linear_schedule_with_warmup
num_warmup_steps: 500
num_training_steps: 50000

# uncomment for lr sch ReduceLROnPlateau
# lr_scheduler:
# _partial_: true
# _target_: torch.optim.lr_scheduler.ReduceLROnPlateau

# uncomment for lr sch CosineAnnealingLR
# lr_scheduler:
# _target_: torch.optim.lr_scheduler.CosineAnnealingLR
# T_max: ${train.trainer_kwargs.max_epochs}
# eta_min: 1e-7
# last_epoch: -1

trainer_kwargs: # arguments for pl.Trainer
default_root_dir: ${train.model_dir}
max_epochs: 100
accelerator: "gpu"
devices: 1
strategy: null
auto_select_gpus: True
num_sanity_val_steps: 0

hydra:
run:
dir: ${root}/${name}
job:
chdir: False

0 comments on commit 3390fda

Please sign in to comment.