In [1]:
import os

# Step up one directory relative to the kernel's current working directory.
# The path is relative, so an unguarded chdir would climb one more level on
# every re-run of this cell; the flag makes the move happen once per kernel.
if not globals().get('NOTEBOOK_CWD_INITIALIZED', False):
    os.chdir('../')
    NOTEBOOK_CWD_INITIALIZED = True

from promptehr import PromptEHR
from promptehr import load_demo_data

# demo EHR data loader and sequence dataset wrapper, both provided by the PyTrial package
from pytrial.data.demo_data import load_mimic_ehr_sequence
from pytrial.tasks.trial_simulation.data import SequencePatient

# Load the demo EHR sequences (n_sample=100).
demo = load_mimic_ehr_sequence(n_sample=100)

# Wrap the raw demo arrays into the sequence dataset used for fitting and
# generation in the cells below.
patient_arrays = {
    'v': demo['visit'],
    'y': demo['mortality'],
    'x': demo['feature'],
}
dataset_metadata = {
    'visit': {'mode': 'dense'},
    'label': {'mode': 'tensor'},
    'voc': demo['voc'],
    'max_visit': 20,
}
seqdata = SequencePatient(data=patient_arrays, metadata=dataset_metadata)

# Inspect the input format.
# Per-patient fields -- only the first patient's entry is shown:
#   visit     : list of visit events
#   mortality : label
#   feature   : array of patient baseline features
for field in ('visit', 'mortality', 'feature'):
    print(field, demo[field][0])

# Dataset-level fields:
#   voc               : dict keyed by code type, mapping indices to the original event names
#   order             : list of the code types
#   n_num_feature     : int, number of numerical patient features
#   cat_cardinalities : list of cardinalities of the categorical patient features
for field in ('voc', 'order', 'n_num_feature', 'cat_cardinalities'):
    print(field, demo[field])

visit [[[0, 1, 2, 3, 4, 5, 6, 7], [0, 1, 2], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], [[8, 9, 10, 7], [3, 4, 1], [0, 1, 2, 3, 5, 4, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18]]]
mortality False
feature [-1.02022055  0.          0.        ]
voc {'diag': <pytrial.tasks.trial_simulation.data.Voc object at 0x7efc6e498340>, 'prod': <pytrial.tasks.trial_simulation.data.Voc object at 0x7efc6e4983a0>, 'med': <pytrial.tasks.trial_simulation.data.Voc object at 0x7efc6e498400>}
order ['diag', 'prod', 'med']
n_num_feature 1
cat_cardinalities [2, 10]




In [2]:
# Configure the model from the demo data description, then fit it.
model_config = dict(
    code_type=demo['order'],
    n_num_feature=demo['n_num_feature'],
    cat_cardinalities=demo['cat_cardinalities'],
    num_worker=0,
    eval_step=1,
    epoch=1,
    # NOTE(review): looks like hard-coded GPU ids -- confirm they exist on
    # the machine running this notebook.
    device=[1, 2],
)
model = PromptEHR(**model_config)

# The same dataset is passed as both the training and the validation split.
model.fit(train_data=seqdata, val_data=seqdata)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BartTokenizer'. 
The class this function is called from is 'DataTokenizer'.
***** Running training *****
  Num examples = 100
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 1
Token indices sequence length is longer than the specified maximum sequence length for this model (552 > 512). Running this sequence through the model will result in indexing errors


Step,Training Loss,Validation Loss,Ppl Diag,Ppl Prod,Ppl Med
1,6.6957,No log,897.926819,353.614288,110.910278


evaluation for code diag.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 512
evaluation for code prod.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 512
evaluation for code med.
***** Running Evaluation *****
  Num examples = 100
  Batch size = 512


Saving model checkpoint to ./promptEHR_logs/checkpoint-1
Configuration saved in ./promptEHR_logs/checkpoint-1/config.json
Model weights saved in ./promptEHR_logs/checkpoint-1/pytorch_model.bin


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from ./promptEHR_logs/checkpoint-1 (score: 897.9268188476562).


In [3]:
# Persist the fitted model under a path relative to the working directory.
checkpoint_dir = './simulation/promptEHR'
model.save_model(checkpoint_dir)

Configuration saved in ./simulation/promptEHR/config.json


Save the trained model to: ./simulation/promptEHR


In [4]:
# Generate synthetic (fake) records with the fitted model, using the demo
# dataset as input.
# NOTE(review): presumably `n` is the total number of records to generate and
# `n_per_sample` the number generated per input patient -- confirm against the
# PromptEHR.predict documentation.
generation_options = {'n_per_sample': 10, 'n': 100, 'verbose': True}
res = model.predict(seqdata, **generation_options)

550it [00:47, 11.54it/s]                                                                                                                                                                                                                                                                                          


In [5]:
# Show the raw generation result returned by model.predict.
# NOTE(review): this prints the whole structure, which is a very long dump;
# consider printing only a small slice (e.g. the first entry of res['visit'])
# when sharing the notebook.
print(res)

{'visit': [[[[1, 3, 4, 6, 7, 202, 235, 684, 2], [601, 226, 9, 7]], [[0, 2, 71], [153, 3, 175]], [[97, 2, 100, 4, 5, 6, 9, 74, 11, 12, 15, 18, 19, 51, 87, 93], [0, 1, 2, 3, 4, 6, 40, 8, 11, 79, 15, 16, 17, 23, 56, 30]]], [[[2, 4, 5, 6, 202, 235, 621, 530, 2], [8, 9, 202, 7]], [[0, 1, 2], [153, 71, 175]], [[64, 97, 0, 3, 4, 5, 6, 7, 8, 10, 11, 14, 19, 30], [64, 1, 2, 3, 4, 6, 7, 9, 11, 15, 80, 16, 82, 51, 17, 23]]], [[[3, 4, 6, 7, 235, 684, 530, 632, 2], [8, 601, 10, 9]], [[0, 2, 175], [3, 4, 71]], [[2, 3, 4, 8, 10, 11, 12, 13, 14, 47, 15, 82, 51, 19, 23], [64, 1, 3, 4, 6, 74, 11, 10, 13, 14, 16, 82, 51, 18, 94, 30]]], [[[2, 3, 7, 202, 684, 530, 632, 637, 2], [8, 9, 202, 226]], [[2, 19, 71], [153, 71, 175]], [[64, 1, 2, 97, 100, 5, 7, 8, 42, 10, 12, 14, 15, 18], [33, 2, 97, 4, 38, 7, 40, 8, 74, 9, 10, 14, 79, 82, 87, 56, 30]]], [[[0, 1, 3, 6, 7, 621, 530, 637, 2], [601, 202, 10, 7]], [[19, 71, 175], [153, 4, 71]], [[0, 1, 97, 2, 9, 10, 11, 12, 14, 47, 15, 18, 19, 51, 23], [0, 97, 1, 2, 4

In [6]:
# Final cell: print a completion message.
print('we are done! :)')

we are done! :)
