Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ehr transformer ICU example #245

Merged
merged 53 commits into from Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
b37dd0b
Initial example of EHR Transformer based on BERT from Hugging Face
simona-rc Oct 2, 2022
d1149c6
Fix lightening trainer parameters
simona-rc Oct 3, 2022
7cd0410
Add preamble to python files
simona-rc Oct 6, 2022
1441b73
Preparation for moving OrigBertFuse to fuse core
simona-rc Nov 3, 2022
e298cad
changes
ellabarkan Dec 25, 2022
b143fe6
changes
ellabarkan Dec 25, 2022
1ee6960
changes
ellabarkan Dec 26, 2022
f82f344
changes
ellabarkan Dec 26, 2022
be3f0aa
changes
ellabarkan Dec 27, 2022
a4c07c0
changes
ellabarkan Dec 27, 2022
376f70e
changes
ellabarkan Dec 27, 2022
6b6d6b7
changes (fixed percentile generation to include both dynamic and stat…
ellabarkan Jan 4, 2023
2ba0162
added digitization Op
ellabarkan Jan 4, 2023
64671e1
fixed bugs (running large dataset) and added dropping of short patients
ellabarkan Jan 5, 2023
1e92d8c
added trajectories generation
ellabarkan Jan 5, 2023
7fe75b2
added trajectories generation + comments
ellabarkan Jan 8, 2023
0a2d128
added corpus
ellabarkan Jan 8, 2023
0951d0b
added corpus
ellabarkan Jan 8, 2023
b35af44
black formatter applyied
ellabarkan Jan 8, 2023
b46b8ad
black formatter applied
ellabarkan Jan 8, 2023
cbec12a
Revert "black formatter applied"
ellabarkan Jan 8, 2023
29f6e37
Revert "black formatter applyied"
ellabarkan Jan 8, 2023
f44df46
black formatter applied
ellabarkan Jan 8, 2023
56f048c
new black version formatter applied
ellabarkan Jan 8, 2023
1710c6a
fixes following pull request
ellabarkan Jan 9, 2023
087acdf
updated Op of generation trajectories, added postions, and indexes
ellabarkan Jan 9, 2023
e7051e8
updated Op of generation trajectories, added postions, and indexes
ellabarkan Jan 9, 2023
19cb648
bug fixes in generating trajectory of visits Op
ellabarkan Jan 10, 2023
128774e
black reformatting, adding TODO related to the heads management
ellabarkan Jan 10, 2023
7fb6c60
Merge branch 'master' of github.com:IBM/fuse-med-ml into ehr_transfor…
Jan 10, 2023
77736c7
less
ellabarkan Jan 10, 2023
8b5477d
fixes
ellabarkan Jan 10, 2023
e4219f0
Merge branch 'ehr_transformer_example' of https://github.com/BiomedSc…
ellabarkan Jan 10, 2023
24e61c5
moved WordVocab to local "utils.py"; added path to data pkl to config…
ellabarkan Jan 10, 2023
6d60137
fixed bug in passing percentile arg
ellabarkan Jan 10, 2023
60747b7
misc fixes
ellabarkan Jan 10, 2023
d35d525
Merge branch 'ehr_transformer_example' of github.com:IBM/fuse-med-ml …
Jan 10, 2023
bfd4d31
single head script with vanila transformer
Jan 10, 2023
b038bd8
added option of adding static details as a first special visit
ellabarkan Jan 11, 2023
8f8efcf
added Readme
ellabarkan Jan 11, 2023
fcd324a
deleted code of the old version of transforment implementation, added…
ellabarkan Jan 11, 2023
ea04710
more comments and flake8 fixes added
ellabarkan Jan 12, 2023
b094f3e
add auxilary heads
Jan 12, 2023
eecc052
add bert support
Jan 12, 2023
c2b17ff
fixes of Readme files and better configuration of pickle file
ellabarkan Jan 12, 2023
1e48f70
black reformatting
ellabarkan Jan 12, 2023
2fe806c
updating figures
ellabarkan Jan 12, 2023
332dae8
document
Jan 12, 2023
808d6f4
flake8 fix
Jan 12, 2023
c6e569e
add transformers to dependnecy lst
Jan 12, 2023
12bece7
Merge branch 'master' into ehr_transformer_example
mosheraboh Jan 12, 2023
440d7d0
cleanup
Jan 12, 2023
b8e2689
change default to bert
Jan 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Empty file added examples/__init__.py
Empty file.
Empty file.
@@ -0,0 +1 @@
__version__ = "0.0.1"
@@ -0,0 +1,130 @@
dataset:
raw_data_path: 'C:/D_Drive/Projects/EHR_Transformer/PhysioNet/predicting-mortality-of-icu-patients-the-physionetcomputing-in-cardiology-challenge-2012-1.0.0/predicting-mortality-of-icu-patients-the-physionet-computing-in-cardiology-challenge-2012-1.0.0'
#raw_data_pickle: 'raw_data.pkl'
split_filename: None
num_folds: 5
train_folds: [ 0, 1, 2 ]
validation_folds: [ 3 ]
test_folds: [ 4 ]
seed: 2580
reset_split: True
num_percentiles : 4 #number of bins (percentiles) for converting floating lab/vital measurements to categorical values
categorical_max_num_of_values : 5 #max number of uniq values for categorical variable for not to be digitized
min_hours_in_hospital: 46
visit_to_embed_static_variables: FIRST #embed to first (FIRST) or to all visits (ALL) or null (not to embed)
static_variables_to_embed: ['Age','ICUType','Height','Weight','BMI',]

data:
#### Test/Val set defs
days_to_inddate: 60 #90 #None #180 #only visits before days_to_ind to index date (or after index date, if negative) are considered for test/val
days_to_inddate_start: null #only the visits between ind_date-days_to_inddate_start and ind_date-days_to_inddate are considered for test/val
event_prediction_window_days: null #90
treatment_event_prediction_window_days: null #60
#### Train set defs
# in case of train we take all visits (we want to train on as much data as we can), so days_to_inddate_tr and days_to_inddate_start_tr should be both null
days_to_inddate_tr: 60 #90 #None #180 #only visits before days_to_ind to index date (or after index date, if negative) are considered for train
days_to_inddate_start_tr: null #only the visits between days_to_inddate_start and days_to_inddate are considered for train
event_prediction_window_days_tr: null #90

visit_days_resolution: 10 #consecutive visits that are less than visit_days_resolution days apart are merged into one visit.
#Note that merging stops only when number of days between consecutive visits is more than visit_days_resolution, so
#within a group of merged visits there may be some more than visit_days_resolution days apart.
#E.g. for visit days 1, 3, 5, 9, and for visit_days_resolution=2, days 1, 3, and 5 will be merged into a single visit, and day 10 will be a separate visit

limit_visits: null #210000 #null #50000000 #10000 #None #don't take too few (e.g. less than 10k), otherwise you'll get only a single visit per patient, which is not enough in case of next visit prediction
out_type: EVENT #EVENT or LABEL - determines the type of outcome GT, whether it is a list of events, or a binary label
task: CKD #RC #"RC24" #"PD"#"BC"#'EP' #The task is only used output path naming and in db access (ehrtransformers.data_access.db2_data), to choose which dataset to read
subtask: ALL_SERV
task_type: outcome
output_name: admdate_visit
use_procedures: false #true #false
num_loader_workers: 1
input_column_names: [CCS, ]
min_visits_per_patient: 10 #patients with less visits are ignored

# When predicting next visit, this will map regular visit dictionary (codes to indices, index per code) to reduced
# visit dictionary (codes to indices, code groups mapped to (fewer) indices). When no mapping is used, this is null.
reduce_pred_visit_vocab_mapping_path: null #/projects/msieve_dev3/usr/vadim/my_work/Code/BEHRT/ehrtransformers/data_access/icd_to_ccs.json #null


data_source: csv #sql_db #csv
data_source_str: '/home/simona/Dev/data/ckd5/ckd5_all_transformer.csv' #/gpfs/haifa/projects/m/msieve3/usr/Capable/MS_renal/data/rc_for_ehr_final.csv
#categorial_codes_str: /gpfs/haifa/projects/m/msieve3/usr/Capable/MS_renal/data/categories_codes_v2.csv
gt_source_str: null
sample_data_source_str: '/home/simona/Dev/data/ckd5/ckd5_all_transformer_10000.csv' #/data/usr/vadim/EHR/rc_for_ehr_sample.csv

learning:
batch_size: 100 #64 300, 480, 420
device: cuda:0 #cpu
is_data_parallel: false
gradient_accumulation_steps: 1
save_model: true
stop_file: stop.txt
optimization:
lr: 1.0e-05 #1.0e-5 #3.0e-06 #1e-4, #3e-6,
warmup_proportion: 0.1
weight_decay: 0.01
global:
age_symbol: null
debug_mode: true #if false, outputs to a log file instead of the main output
global_stat_file: global_stats.csv
max_age: 110
max_len_seq: 100 #How many visits in a BEHRT input vector
max_label_len: 10 #How many classes in the multiclass task TODO: remove, this is no longer used
min_visit: 5 #Visit trajectories with fewer visits are ignored
month: 1 # choice whether to use months for age (=1) or years (=12)
uber_base_path: '/home/simona/Dev/EHR_experiments/ckd5' #/data/usr/simona/EHR/ #path utils to all EHR experiments
model:
age_vocab_size: null # number of vocab for age embedding
attention_probs_dropout_prob: 0.22 # multi-head attention dropout rate
hidden_act: gelu # activation function ("gelu", 'relu', 'swish')
hidden_dropout_prob: 0.2 # dropout rate
hidden_size: 240 #48 #240 #72 #120 #288, # word embedding and seg embedding hidden size (needs to be a multiple of attention heads)
initializer_range: 0.02 # parameter weight initializer range
intermediate_size: 512 #128 #512 # the size of the "intermediate" layer in the transformer encoder
max_position_embedding: 100 # maximum number of tokens
num_attention_heads: 24 # number of attention heads
num_hidden_layers: 6 # number of multi-head attention layers required
# pretrain_model: /data/usr/vadim/EHR/PD_ALL_SERV_270_after_ind_90_to_event_tr_90_to_event_merge_5_days/EVENT/models/PD/outcome_admdate_visit/run_139/Logs/model_dir/checkpoint_last_epoch.pth # pretrained model path, None if not to continue from pretrained
# pretrain_model: /data/usr/vadim/EHR/PD_ALL_SERV_270_after_ind_90_to_event_tr_90_to_event_merge_5_days/EVENT/models/PD/outcome_admdate_visit/run_136/Logs/model_dir/checkpoint_last_epoch.pth # pretrained model path, None if not to continue from pretrained
# pretrain_model: /data/usr/vadim/EHR/PD_ALL_SERV_270_after_ind_90_to_event_tr_90_to_event_merge_10_days/EVENT/models/PD/outcome_admdate_visit/run_104/Logs/model_dir/checkpoint_last_epoch.pth # pretrained model path, None if not to continue from pretrained
# pretrain_model: /data/usr/vadim/EHR/PD_ALL_SERV_270_after_ind_90_to_event_tr_90_to_event_merge_10_days_with_procedures/EVENT/models/PD/outcome_admdate_visit/run_113/Logs/model_dir/checkpoint_last_epoch.pth # pretrained model path, None if not to continue from pretrained
# pretrain_model: /data/usr/vadim/EHR/PD_ALL_SERV_270_after_ind_90_to_event_tr_90_to_event_merge_5_days_with_procedures/EVENT/models/PD/outcome_admdate_visit/run_105/Logs/model_dir/checkpoint_last_epoch.pth
# pretrain_model: /data/usr/vadim/EHR/RC_ALL_SERV_90_to_ind_tr_90_to_ind_90_to_event_tr_90_to_event_merge_10_days/EVENT/models/RC/outcome_admdate_visit/run_129/Logs/model_dir/checkpoint_last_epoch.pth
# pretrain_model: /home/simona/Dev/EHR_experiments/ckd5/CKD_ALL_SERV_60_to_ind_tr_60_to_ind_merge_10_days/EVENT/models/CKD/outcome_admdate_visit/run_101/Logs/model_dir/checkpoint_last_epoch.pth
# pretrain_model: /home/simona/Dev/EHR_experiments/ckd5/CKD_ALL_SERV_60_to_ind_tr_60_to_ind_merge_10_days/EVENT/models/CKD/outcome_admdate_visit/run_177/Logs/model_dir/last.ckpt
pretrain_model: null
seg_vocab_size: null # number of vocab for seg embedding
train_epochs: 2
vocab_size: null # number of disease + symbols for word embedding
reverse_input_direction: true
heads: [gender, next_vis, disease_prediction] #, response_profile, gender] #[event, next_vis, gender, cluster, contrastive, response_profile]
head_weights: [0.005, 32.0, 0.02]
sampler_weights: [0.9, 0.1]
contrastive_instances: 1 #number of patient visits used in contrastive loss (reducing distance between embeddings of the same patient)

naming_conventions:
age_key: AGE
age_month_key: AGE_MON
date_key: ADMDATE
index_date_key: INDDATE
svc_date_key: SVCDATE
adm_date_key: ADMDATE
date_birth_key: DOBYR
dxver_key: DXVER
diagnosis_vec_key: DX
gender_key: SEX
male_val: 1 #'1'
female_val: 2 #'2'
label_key: label
outcome_key: CKD #RC #Disease column
event_key: EVENTS #event column
treatment_event_key: TREATMENT_EVENTS
patient_id_key: ENROLID
disease_key: CKD #RC #this is not used
healthy_val: 0
sick_val: 1
split_key: split
fold_key: fold
next_visit_key: DXNEXTVIS