In [72]:
from pytrial.tasks.indiv_outcome.data import SequencePatient

# Training split: every record except the last 200 (those are held out in
# the next cell). Visits, labels and features are sliced identically so the
# three arrays stay aligned.
seqdata = SequencePatient(
    data={
        key: data[field][:-200]
        for key, field in (('v', 'visit'), ('y', 'y'), ('x', 'feature'))
    },
    metadata={
        'visit': {
            'mode': 'dense',
            'order': data['order'],
        },
        'label': {
            'mode': 'tensor',
        },
        'voc': data['voc'],
        # Same value is passed to the RNN constructor as `max_visit`.
        'max_visit': 20,
    },
)


In [73]:
# Held-out split: the last 200 records, i.e. the complement of the training
# slice. This object is what `model.fit` receives as `valid_data`, so it acts
# as a validation set for checkpoint selection rather than an untouched test set.
val_seqdata = SequencePatient(
    data={
        key: data[field][-200:]
        for key, field in (('v', 'visit'), ('y', 'y'), ('x', 'feature'))
    },
    metadata={
        'visit': {
            'mode': 'dense',
            'order': data['order'],
        },
        'label': {
            'mode': 'tensor',
        },
        'voc': data['voc'],
        # Kept identical to the training metadata.
        'max_visit': 20,
    },
)

In [74]:
from pytrial.data.patient_data import SeqPatientCollator # we need a collation function to process the input SequencePatient dataset
from pytrial.tasks.indiv_outcome.sequence import RNN

In [75]:
# One vocabulary size per event type, listed in the same order as
# data['order']; the model uses these to build the event embedding layer.
event_vocab_sizes = [len(data['voc'][o]) for o in data['order']]

model = RNN(
    vocab_size=event_vocab_sizes,
    orders=data['order'],  # event-type order, matching the dataset metadata
    mode='binary',
    max_visit=20,  # same value as 'max_visit' in the SequencePatient metadata
    bidirectional=True,
    epochs=20,
    batch_size=16,
    device='cpu',
)

# Train on the training split and evaluate on the held-out split.
# In the recorded run, fit() logs the train loss and validation AUC every
# 100 steps and reloads the best checkpoint from ./checkpoints/best at the end.
# NOTE(review): in that log the validation AUC peaks around 0.59 at step 700
# while the train loss keeps falling — possibly overfitting; worth confirming.
model.fit(train_data=seqdata, valid_data=val_seqdata)

{'lr': 0.0001, 'weight_decay': 0.0001}
***** Running training *****
  Num examples = 800
  Num Epochs = 20
  Total optimization steps = 1000


Training Epoch:   0%|          | 0/20 [00:00<?, ?it/s]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  12%|█▏        | 6/50 [00:00<00:00, 58.14it/s][A
Iteration:  24%|██▍       | 12/50 [00:00<00:00, 49.00it/s][A
Iteration:  36%|███▌      | 18/50 [00:00<00:00, 51.00it/s][A
Iteration:  48%|████▊     | 24/50 [00:00<00:00, 50.80it/s][A
Iteration:  60%|██████    | 30/50 [00:00<00:00, 51.37it/s][A
Iteration:  72%|███████▏  | 36/50 [00:00<00:00, 51.36it/s][A
Iteration:  84%|████████▍ | 42/50 [00:00<00:00, 50.72it/s][A
Iteration: 100%|██████████| 50/50 [00:00<00:00, 51.45it/s][A
Training Epoch:   5%|▌         | 1/20 [00:00<00:18,  1.03it/s]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  12%|█▏        | 6/50 [00:00<00:00, 47.50it/s][A
Iteration:  24%|██▍       | 12/50 [00:00<00:00, 53.17it/s][A
Iteration:  36%|███▌      | 18/50 [00:00<00:00, 48.56it/s][A
Iteration:  48%|████▊     | 24/50 [00:00<00:00, 52.28it/s][A
Iteration:  60%|██████    | 30/50 


######### Train Loss 100 #########
0 0.6519 


######### Eval 100 #########
auc: 0.4077
New best score: from -inf to 0.4077299945563418
Best checkpoint is updated at 100 with auc 0.4077299945563418.



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  14%|█▍        | 7/50 [00:00<00:00, 66.96it/s][A
Iteration:  28%|██▊       | 14/50 [00:00<00:00, 59.76it/s][A
Iteration:  42%|████▏     | 21/50 [00:00<00:00, 54.01it/s][A
Iteration:  56%|█████▌    | 28/50 [00:00<00:00, 56.51it/s][A
Iteration:  68%|██████▊   | 34/50 [00:00<00:00, 52.67it/s][A
Iteration:  80%|████████  | 40/50 [00:00<00:00, 54.13it/s][A
Iteration: 100%|██████████| 50/50 [00:00<00:00, 51.51it/s][A
Training Epoch:  15%|█▌        | 3/20 [00:03<00:17,  1.01s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:00, 47.87it/s][A
Iteration:  22%|██▏       | 11/50 [00:00<00:00, 49.30it/s][A
Iteration:  34%|███▍      | 17/50 [00:00<00:00, 50.55it/s][A
Iteration:  46%|████▌     | 23/50 [00:00<00:00, 52.92it/s][A
Iteration:  58%|█████▊    | 29/50 [00:00<00:00, 53.32it/s][A
Iteration:  70%|███████   | 35/50 [00:00<00:00, 49.32it/s][A
Iteration:  82%|████████▏ 


######### Train Loss 200 #########
0 0.5338 


######### Eval 200 #########
auc: 0.4275
New best score: from 0.4077299945563418 to 0.4275086191253856
Best checkpoint is updated at 200 with auc 0.4275086191253856.



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  12%|█▏        | 6/50 [00:00<00:00, 46.88it/s][A
Iteration:  22%|██▏       | 11/50 [00:00<00:01, 38.30it/s][A
Iteration:  34%|███▍      | 17/50 [00:00<00:00, 45.92it/s][A
Iteration:  44%|████▍     | 22/50 [00:00<00:00, 47.07it/s][A
Iteration:  60%|██████    | 30/50 [00:00<00:00, 53.28it/s][A
Iteration:  72%|███████▏  | 36/50 [00:00<00:00, 50.64it/s][A
Iteration:  84%|████████▍ | 42/50 [00:00<00:00, 48.51it/s][A
Iteration: 100%|██████████| 50/50 [00:01<00:00, 49.34it/s][A
Training Epoch:  25%|██▌       | 5/20 [00:05<00:15,  1.03s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  12%|█▏        | 6/50 [00:00<00:00, 57.02it/s][A
Iteration:  24%|██▍       | 12/50 [00:00<00:00, 50.46it/s][A
Iteration:  36%|███▌      | 18/50 [00:00<00:00, 49.55it/s][A
Iteration:  48%|████▊     | 24/50 [00:00<00:00, 50.30it/s][A
Iteration:  60%|██████    | 30/50 [00:00<00:00, 50.69it/s][A
Iteration:  72%|███████▏  


######### Train Loss 300 #########
0 0.5031 


######### Eval 300 #########
auc: 0.4580
New best score: from 0.4275086191253856 to 0.45799310469969157
Best checkpoint is updated at 300 with auc 0.45799310469969157.



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  12%|█▏        | 6/50 [00:00<00:00, 48.60it/s][A
Iteration:  24%|██▍       | 12/50 [00:00<00:00, 51.55it/s][A
Iteration:  36%|███▌      | 18/50 [00:00<00:00, 50.00it/s][A
Iteration:  48%|████▊     | 24/50 [00:00<00:00, 52.91it/s][A
Iteration:  60%|██████    | 30/50 [00:00<00:00, 54.87it/s][A
Iteration:  72%|███████▏  | 36/50 [00:00<00:00, 54.36it/s][A
Iteration:  84%|████████▍ | 42/50 [00:00<00:00, 50.31it/s][A
Iteration: 100%|██████████| 50/50 [00:00<00:00, 51.02it/s][A
Training Epoch:  35%|███▌      | 7/20 [00:07<00:13,  1.02s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  14%|█▍        | 7/50 [00:00<00:00, 62.75it/s][A
Iteration:  28%|██▊       | 14/50 [00:00<00:00, 58.23it/s][A
Iteration:  40%|████      | 20/50 [00:00<00:00, 58.44it/s][A
Iteration:  52%|█████▏    | 26/50 [00:00<00:00, 57.57it/s][A
Iteration:  64%|██████▍   | 32/50 [00:00<00:00, 57.04it/s][A
Iteration:  76%|███████▌  


######### Train Loss 400 #########
0 0.4766 


######### Eval 400 #########
auc: 0.5035
New best score: from 0.45799310469969157 to 0.5035383777898748
Best checkpoint is updated at 400 with auc 0.5035383777898748.



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:01, 41.29it/s][A
Iteration:  22%|██▏       | 11/50 [00:00<00:00, 44.58it/s][A
Iteration:  32%|███▏      | 16/50 [00:00<00:00, 44.69it/s][A
Iteration:  44%|████▍     | 22/50 [00:00<00:00, 49.92it/s][A
Iteration:  56%|█████▌    | 28/50 [00:00<00:00, 52.21it/s][A
Iteration:  70%|███████   | 35/50 [00:00<00:00, 55.90it/s][A
Iteration:  82%|████████▏ | 41/50 [00:00<00:00, 53.08it/s][A
Iteration: 100%|██████████| 50/50 [00:00<00:00, 51.63it/s][A
Training Epoch:  45%|████▌     | 9/20 [00:09<00:11,  1.00s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:01, 44.81it/s][A
Iteration:  22%|██▏       | 11/50 [00:00<00:00, 50.79it/s][A
Iteration:  34%|███▍      | 17/50 [00:00<00:00, 45.32it/s][A
Iteration:  44%|████▍     | 22/50 [00:00<00:00, 42.17it/s][A
Iteration:  54%|█████▍    | 27/50 [00:00<00:00, 43.67it/s][A
Iteration:  64%|██████▍   


######### Train Loss 500 #########
0 0.4389 


######### Eval 500 #########
auc: 0.5527
New best score: from 0.5035383777898748 to 0.5527127563055707
Best checkpoint is updated at 500 with auc 0.5527127563055707.



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:01, 43.81it/s][A
Iteration:  22%|██▏       | 11/50 [00:00<00:00, 51.01it/s][A
Iteration:  34%|███▍      | 17/50 [00:00<00:00, 49.77it/s][A
Iteration:  44%|████▍     | 22/50 [00:00<00:00, 48.11it/s][A
Iteration:  54%|█████▍    | 27/50 [00:00<00:00, 46.79it/s][A
Iteration:  64%|██████▍   | 32/50 [00:00<00:00, 47.26it/s][A
Iteration:  78%|███████▊  | 39/50 [00:00<00:00, 52.18it/s][A
Iteration: 100%|██████████| 50/50 [00:01<00:00, 49.07it/s][A
Training Epoch:  55%|█████▌    | 11/20 [00:11<00:09,  1.05s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:00, 49.36it/s][A
Iteration:  20%|██        | 10/50 [00:00<00:00, 44.18it/s][A
Iteration:  32%|███▏      | 16/50 [00:00<00:00, 46.31it/s][A
Iteration:  42%|████▏     | 21/50 [00:00<00:00, 42.08it/s][A
Iteration:  52%|█████▏    | 26/50 [00:00<00:00, 39.10it/s][A
Iteration:  60%|██████   


######### Train Loss 600 #########
0 0.3809 


######### Eval 600 #########
auc: 0.5830
New best score: from 0.5527127563055707 to 0.583015786608601
Best checkpoint is updated at 600 with auc 0.583015786608601.



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:01, 42.30it/s][A
Iteration:  20%|██        | 10/50 [00:00<00:00, 46.44it/s][A
Iteration:  32%|███▏      | 16/50 [00:00<00:00, 51.45it/s][A
Iteration:  44%|████▍     | 22/50 [00:00<00:00, 50.41it/s][A
Iteration:  56%|█████▌    | 28/50 [00:00<00:00, 48.66it/s][A
Iteration:  66%|██████▌   | 33/50 [00:00<00:00, 47.61it/s][A
Iteration:  76%|███████▌  | 38/50 [00:00<00:00, 47.31it/s][A
Iteration:  86%|████████▌ | 43/50 [00:00<00:00, 46.24it/s][A
Iteration: 100%|██████████| 50/50 [00:01<00:00, 48.07it/s][A
Training Epoch:  65%|██████▌   | 13/20 [00:13<00:07,  1.09s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:00, 49.17it/s][A
Iteration:  20%|██        | 10/50 [00:00<00:00, 42.19it/s][A
Iteration:  30%|███       | 15/50 [00:00<00:00, 41.97it/s][A
Iteration:  40%|████      | 20/50 [00:00<00:00, 42.29it/s][A
Iteration:  50%|█████    


######### Train Loss 700 #########
0 0.3219 


######### Eval 700 #########
auc: 0.5901
New best score: from 0.583015786608601 to 0.5900925421883506
Best checkpoint is updated at 700 with auc 0.5900925421883506.



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:01, 38.86it/s][A
Iteration:  18%|█▊        | 9/50 [00:00<00:01, 36.68it/s][A
Iteration:  26%|██▌       | 13/50 [00:00<00:00, 37.49it/s][A
Iteration:  36%|███▌      | 18/50 [00:00<00:00, 39.46it/s][A
Iteration:  48%|████▊     | 24/50 [00:00<00:00, 43.90it/s][A
Iteration:  58%|█████▊    | 29/50 [00:00<00:00, 45.07it/s][A
Iteration:  68%|██████▊   | 34/50 [00:00<00:00, 46.11it/s][A
Iteration:  78%|███████▊  | 39/50 [00:00<00:00, 46.05it/s][A
Iteration:  88%|████████▊ | 44/50 [00:01<00:00, 44.25it/s][A
Iteration: 100%|██████████| 50/50 [00:01<00:00, 42.88it/s][A
Training Epoch:  75%|███████▌  | 15/20 [00:16<00:05,  1.14s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  10%|█         | 5/50 [00:00<00:00, 45.49it/s][A
Iteration:  20%|██        | 10/50 [00:00<00:00, 45.56it/s][A
Iteration:  30%|███       | 15/50 [00:00<00:00, 41.87it/s][A
Iteration:  40%|████      


######### Train Loss 800 #########
0 0.2744 


######### Eval 800 #########
auc: 0.5727



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:   8%|▊         | 4/50 [00:00<00:01, 36.58it/s][A
Iteration:  20%|██        | 10/50 [00:00<00:00, 49.10it/s][A
Iteration:  32%|███▏      | 16/50 [00:00<00:00, 50.40it/s][A
Iteration:  44%|████▍     | 22/50 [00:00<00:00, 50.24it/s][A
Iteration:  56%|█████▌    | 28/50 [00:00<00:00, 49.14it/s][A
Iteration:  68%|██████▊   | 34/50 [00:00<00:00, 50.62it/s][A
Iteration:  82%|████████▏ | 41/50 [00:00<00:00, 53.41it/s][A
Iteration: 100%|██████████| 50/50 [00:00<00:00, 52.26it/s][A
Training Epoch:  85%|████████▌ | 17/20 [00:18<00:03,  1.08s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  12%|█▏        | 6/50 [00:00<00:00, 52.50it/s][A
Iteration:  24%|██▍       | 12/50 [00:00<00:00, 46.96it/s][A
Iteration:  34%|███▍      | 17/50 [00:00<00:00, 47.92it/s][A
Iteration:  46%|████▌     | 23/50 [00:00<00:00, 50.61it/s][A
Iteration:  58%|█████▊    | 29/50 [00:00<00:00, 50.02it/s][A
Iteration:  72%|███████▏ 


######### Train Loss 900 #########
0 0.2363 


######### Eval 900 #########
auc: 0.5718



Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  14%|█▍        | 7/50 [00:00<00:00, 58.20it/s][A
Iteration:  26%|██▌       | 13/50 [00:00<00:00, 53.95it/s][A
Iteration:  38%|███▊      | 19/50 [00:00<00:00, 55.50it/s][A
Iteration:  52%|█████▏    | 26/50 [00:00<00:00, 55.04it/s][A
Iteration:  64%|██████▍   | 32/50 [00:00<00:00, 56.07it/s][A
Iteration:  76%|███████▌  | 38/50 [00:00<00:00, 50.85it/s][A
Iteration:  88%|████████▊ | 44/50 [00:00<00:00, 49.96it/s][A
Iteration: 100%|██████████| 50/50 [00:00<00:00, 51.33it/s][A
Training Epoch:  95%|█████████▌| 19/20 [00:20<00:01,  1.05s/it]
Iteration:   0%|          | 0/50 [00:00<?, ?it/s][A
Iteration:  16%|█▌        | 8/50 [00:00<00:00, 65.20it/s][A
Iteration:  30%|███       | 15/50 [00:00<00:00, 61.45it/s][A
Iteration:  44%|████▍     | 22/50 [00:00<00:00, 58.94it/s][A
Iteration:  56%|█████▌    | 28/50 [00:00<00:00, 51.66it/s][A
Iteration:  68%|██████▊   | 34/50 [00:00<00:00, 50.02it/s][A
Iteration:  80%|████████ 


######### Train Loss 1000 #########
0 0.1937 


######### Eval 1000 #########
auc: 0.5676
Load best ckpt from `./checkpoints/best`.
Training completes.


In [76]:
# Score every patient in `seqdata` with the fitted model, then show the
# first 20 rows of the result.
# NOTE(review): `seqdata` is the training split; pass `val_seqdata` instead
# if held-out predictions are what is wanted here.
predict = model.predict(seqdata)

first_rows = predict[:20]
print(first_rows)

[[0.03759647]
 [0.3407742 ]
 [0.08466253]
 [0.04212925]
 [0.07062326]
 [0.61914057]
 [0.11943546]
 [0.28911486]
 [0.03431963]
 [0.1817149 ]
 [0.10272896]
 [0.4975288 ]
 [0.23787639]
 [0.53194153]
 [0.5683119 ]
 [0.42574254]
 [0.05658188]
 [0.10878944]
 [0.01620503]
 [0.54372704]]


In [43]:
# Persist the fitted model under the checkpoint directory below.
# NOTE(review): this cell's execution count (43) is lower than the cells
# above it (72-76), so it was last run out of order — re-run top to bottom
# to make sure the saved model is the one trained above.
checkpoint_dir = './checkpoints/indiv_outcome.sequence/'
model.save_model(checkpoint_dir)