Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2f12ba7
commit 3390fda
Showing
14 changed files
with
1,498 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 65 additions & 0 deletions
65
examples/fuse_examples/multimodality/ehr_transformer/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Predicting Mortality of ICU Patients with Fuse-Med-ML | ||
|
||
## Introduction | ||
|
||
The aim of this example is to demonstrate Fuse-Med-ML abilities for prediction tasks involving Electronic Health Records (EHR). In this example, Fuse will be applied for the prediction of In-hospital patients’ mortality after spending 48 hours in Intensive Care Unit. | ||
|
||
<br/> | ||
|
||
|
||
## Dataset | ||
The dataset used in the example was taken from the PhysioNet Computing in Cardiology Challenge that was held in 2012 (https://physionet.org/content/challenge-2012/1.0.0/). The challenge data was extracted from MIMICC II Clinical Database. We used EHR records of 8000 patients that are available for public access in the above link. More details on the dataset can be found at https://ieeexplore.ieee.org/abstract/document/6420376 | ||
|
||
<br/> | ||
|
||
## Task and Modeling | ||
In this example, we will see Fuse-Med-ML applied to the prediction of in-hospital mortality, which is one of the tasks that was included in the PhysioNet challenge. | ||
|
||
Figure 1 presents the structure of patient data in this data set. It can be seen that patients presented with both general information and time series of observations during the first 48 hours in the hospital. The architecture can be seen in Figure 2. We used a transformer-based deep-learning model that utilizes the BERT network design with additional adjustments. | ||
|
||
<br/> | ||
|
||
<figure> | ||
<img src="./Fig1.jpg" alt="Trulli" style="width:100%"> | ||
<figcaption>Figure 1</figcaption> | ||
</figure> | ||
|
||
<br/> | ||
|
||
The network was implemented in a manner that allows several output ”heads”, each responsible for a separate task. The network consists of several encoder layers, followed by task-specific heads (smaller neural network classifiers). | ||
|
||
<br/> | ||
|
||
|
||
|
||
<figure> | ||
<img src="./Fig2.jpg" alt="Trulli" style="width:100%"> | ||
<figcaption>Figure 2</figcaption> | ||
</figure> | ||
|
||
<br/> | ||
|
||
The patients’ data is passing a set of preprocessing steps before entering the modeling stage, which includes the exclusion of illegal patients, mapping continuous values of observation to categorical values by applying digitization, and combining both static data (general information) and time series in one input patient trajectory sequence. | ||
|
||
As mentioned, the network has multiple prediction heads, related to multiple prediction tasks. Using multiple heads aims to improve patients' representation. We trained and tested multiple networks consisting of a combination of the following tasks: | ||
|
||
<br/>- **Next Visit Prediction** A multi-label classification task of predicting the observation values that will appear in the visit following the last visit of the trajectory. | ||
<br/>- **Gender Prediction**: A binary classifier that predicts the biologically determined gender of the | ||
patient. | ||
<br/>- **In-hospital mortality** prediction: A binary classifier that predicts will patient die or not during his hospital stay after the first 48 hours. | ||
|
||
<br/> | ||
|
||
|
||
## Usage Instructions | ||
Download the following data to a designated folder in your system: | ||
- set-a and set-b folders of the patient’s input | ||
- Outcome-a.txt, Outcome-b.txt files | ||
|
||
In the configuration file config.yaml hydra file: | ||
1. Define environment variable CINC_DATA_PATH to the data folder. | ||
2. There is a possibility to define an environment variable CINC_DATA_PKL of a path to the pickle file that will manage the loaded data for the next runs. Use 'null' (that is the default) in configuration if you are not using this option | ||
3. Configure network hyperparameters | ||
|
||
Run: python3 main_train.py | ||
|
Empty file.
149 changes: 149 additions & 0 deletions
149
examples/fuse_examples/multimodality/ehr_transformer/config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
name: ehr_transformer | ||
root: "." | ||
|
||
target_key: "Target" # key that points to labels in sample_dict | ||
max_len_seq: 350 # maximal number of tokens in the trajectory | ||
aux_gender_classification: True # set to True to enable gender auxiliary head | ||
aux_next_vis_classification: True # set to True to enable next visit auxiliary head | ||
|
||
data: | ||
dataset_cfg: # See PhysioNetCinC.dataset for details about the arguments | ||
raw_data_pkl: null # in case we want to use pickle for managing raw data use ${oc.env:CINC_DATA_PKL} to define the | ||
#path to the pickle file | ||
raw_data_path: ${oc.env:CINC_DATA_PATH} | ||
split_filename: None | ||
num_folds: 5 | ||
train_folds: [ 0, 1, 2 ] | ||
validation_folds: [ 3 ] | ||
test_folds: [ 4 ] | ||
seed: 2580 | ||
reset_split: True | ||
num_percentiles : 4 #number of bins (percentiles) for converting floating lab/vital measurements to categorical values | ||
categorical_max_num_of_values : 5 #max number of uniq values for categorical variable for not to be digitized | ||
min_hours_in_hospital: 46 | ||
min_number_of_visits: 10 | ||
max_len_seq: ${max_len_seq} | ||
static_variables_to_embed: ['Age','ICUType','Height','Weight','BMI',] | ||
embed_static_in_all_visits: 0 | ||
|
||
|
||
batch_size: 128 | ||
target_key: ${target_key} | ||
|
||
data_loader_train: # Dataloader constructor parameters | ||
num_workers: 8 | ||
|
||
data_loader_valid: # Dataloader constructor parameters | ||
num_workers: 8 | ||
batch_size: ${data.batch_size} | ||
|
||
|
||
model: | ||
encoder_type: "bert" # supported values: "bert" "transformer" | ||
|
||
transformer_encoder: # TransformerEncoder constructor arguments - used when encoder type is "transformer" | ||
num_tokens: ${max_len_seq} | ||
token_dim: ${model.embed.emb_dim} | ||
depth: 4 | ||
heads: 10 | ||
mlp_dim: 50 | ||
dropout: 0.0 | ||
emb_dropout: 0.0 | ||
num_cls_tokens: 1 | ||
|
||
bert_config_kwargs: # BertConfig constructor arguments - used when encoder type is "bert" | ||
hidden_size: ${model.z_dim} # word embedding and seg embedding hidden size (needs to be a multiple of attention heads) | ||
num_hidden_layers: 6 # number of multi-head attention layers required | ||
num_attention_heads: 24 # number of attention heads | ||
intermediate_size: 512 # the size of the "intermediate" layer in the transformer encoder | ||
hidden_act: gelu # activation function ("gelu", 'relu', 'swish') | ||
hidden_dropout_prob: 0.2 # dropout rate | ||
attention_probs_dropout_prob: 0.22 # multi-head attention dropout rate | ||
initializer_range: 0.02 # parameter weight initializer range | ||
|
||
embed: # Embed constructor arguments | ||
emb_dim: ${model.z_dim} | ||
|
||
classifier_head: # HeadD1 constructor arguments - used for main classifer head | ||
num_outputs: 2 | ||
layers_description: [256] | ||
|
||
classifier_gender_head: # HeadD1 constructor arguments - used for gender classifer head | ||
num_outputs: 2 | ||
layers_description: [256] | ||
|
||
classifier_next_vis_head: # HeadD1 constructor arguments - used for next visit classifer head | ||
layers_description: [256] | ||
|
||
aux_gender_classification: ${aux_gender_classification} | ||
aux_next_vis_classification: ${aux_next_vis_classification} | ||
|
||
z_dim: 48 | ||
|
||
# train | ||
train: # arguments for train() in classifiers_main_train.py | ||
model_dir: ${root}/${name} | ||
target_key: ${target_key} | ||
target_loss_weight: 0.8 | ||
aux_gender_classification: ${aux_gender_classification} | ||
gender_loss_weight: 0.1 | ||
aux_next_vis_classification: ${aux_next_vis_classification} | ||
next_vis_loss_weight: 0.1 | ||
|
||
# uncomment to track in clearml | ||
# track_clearml: | ||
# project_name: "ehr_transformer" | ||
# task_name: ${name} | ||
# tags: "fuse_example" | ||
# reuse_last_task_id: True | ||
# continue_last_task: False | ||
|
||
|
||
# uncomment for SGD | ||
# opt: | ||
# _partial_: true | ||
# _target_: torch.optim.SGD | ||
# momentum: 0.99 | ||
# nesterov: True | ||
# lr: 0.001 | ||
# # weight_decay: 1e-5 | ||
|
||
# AdamW | ||
op`t: | ||
_partial_: true | ||
_target_: torch.optim.AdamW | ||
lr: 1e-3 | ||
|
||
# linear_schedule_with_warmup | ||
lr_scheduler: | ||
_partial_: True | ||
_target_: transformers.get_linear_schedule_with_warmup | ||
num_warmup_steps: 500 | ||
num_training_steps: 50000 | ||
|
||
# uncomment for lr sch ReduceLROnPlateau | ||
# lr_scheduler: | ||
# _partial_: true | ||
# _target_: torch.optim.lr_scheduler.ReduceLROnPlateau | ||
|
||
# uncomment for lr sch CosineAnnealingLR | ||
# lr_scheduler: | ||
# _target_: torch.optim.lr_scheduler.CosineAnnealingLR | ||
# T_max: ${train.trainer_kwargs.max_epochs} | ||
# eta_min: 1e-7 | ||
# last_epoch: -1 | ||
|
||
trainer_kwargs: # arguments for pl.Trainer | ||
default_root_dir: ${train.model_dir} | ||
max_epochs: 100 | ||
accelerator: "gpu" | ||
devices: 1 | ||
strategy: null | ||
auto_select_gpus: True | ||
num_sanity_val_steps: 0 | ||
|
||
hydra: | ||
run: | ||
dir: ${root}/${name} | ||
job: | ||
chdir: False |
Oops, something went wrong.