# PyHealth GAMENet Reproduction Study

## Imports

First, import python libraries, pyhealth, and related pyhealth libraries

In [1]:
import argparse
import sys
import pandas
import json

import pyhealth

from pyhealth.datasets import MIMIC4Dataset, MIMIC3Dataset
from pyhealth.tasks import drug_recommendation_mimic4_fn, drug_recommendation_mimic3_fn
# import dataloader related functions
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.datasets import split_by_patient, get_dataloader
# import gamenet model
from pyhealth.models import RETAIN, GAMENet
# import trainer
from pyhealth.trainer import Trainer


  from tqdm.autonotebook import trange


Next, import our custom libraries we've designed for this study

In [2]:
from drug_rec_task import drug_recommendation_mimic4_no_hist, drug_recommendation_mimic4_no_proc
from model import ModelWrapper
from alt_gamenets import GAMENetNoHist, GAMENetNoProc
from mimic import MIMIC4, MIMICWrapper

# import our constants
from constants import (
    DEV,
    EPOCHS, LR, DECAY_WEIGHT,
    DRUG_REC_TN, ALL_TASKS,
    GN_KEY, RT_KEY,
    MODEL_TYPES_PER_TASK, RETAIN_FEATS_PER_TASK,
    GAMENET_EXP, RETAIN_EXP,
    SCORE_KEY, DPV_KEY, DDI_RATE_KEY,
    BASE_DDI_RATE
)

## Load Data

We use the `MIMIC4` data class from the `mimic` import.
Another option would be to import the `MIMIC3` data class and use that as the `dataset` below.
For this purpose, we just use `MIMIC4` and use that to load the data and prepare it with the appropriate tasks.

The data class (either `MIMIC4` or `MIMIC3` decides where the data root is.
By default, it reads in data from `./hiddendata/extracted/{mimic3/4}/`.
So, for MIMIC4 data the data would need to be in: `./hiddendata/extracted/mimic4/`.
This can be changed by either modifying the MIMIC4 class directly, or modifying the data root default in the `constants` file.

In [3]:
# whether to read in "dev" mode or not
DEV = False

In [4]:
# save data in the ./hiddendata/extracted/ directory
## this uses MIMIC4, so place data in ./hiddendata/extracted/mimic4/
## could use MIMIC3 with from mimic import MIMIC3 and placing data in ./hiddendata/extracted/mimic3/
dataset = MIMIC4
mimic = MIMICWrapper(datasource=dataset, tasks=dataset.all_tasks())
mimic_data = mimic.load_data(dev=DEV)
drug_task_data = mimic.drug_task_data()
dataloaders = mimic.create_dataloaders()

reading mimic4 data...
---DATA STATS FOR mimic4 DATA---
stat

Statistics of base dataset (dev=False):
	- Dataset: MIMIC4Dataset
	- Number of patients: 180733
	- Number of visits: 431231
	- Number of visits per patient: 2.3860
	- Number of events per visit in diagnoses_icd: 11.0296
	- Number of events per visit in procedures_icd: 1.5518
	- Number of events per visit in prescriptions: 54.2354

info

dataset.patients: patient_id -> <Patient>

<Patient>
    - visits: visit_id -> <Visit> 
    - other patient-level info
    
    <Visit>
        - event_list_dict: table_name -> List[Event]
        - other visit-level info
    
        <Event>
            - code: str
            - other event-level info

***run task: drug_recommendation


Generating samples for drug_recommendation_mimic4_fn: 100%|███████████| 180733/180733 [00:36<00:00, 4890.30it/s]


{'visit_id': '22595853', 'patient_id': '10000032', 'conditions': [['5723', '78959', '5715', '07070', '496', '29680', '30981', 'V1582']], 'procedures': [['5491']], 'drugs': ['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B'], 'drugs_all': [['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B']]}
***run task: no_hist


Generating samples for drug_recommendation_mimic4_no_hist: 100%|██████| 180733/180733 [00:39<00:00, 4550.12it/s]


{'visit_id': '22595853', 'patient_id': '10000032', 'conditions': [['5723', '78959', '5715', '07070', '496', '29680', '30981', 'V1582']], 'procedures': [['5491']], 'drugs': ['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B'], 'drugs_all': [['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B']]}
***run task: no_proc


Generating samples for drug_recommendation_mimic4_no_proc: 100%|██████| 180733/180733 [00:45<00:00, 3955.30it/s]


{'visit_id': '22595853', 'patient_id': '10000032', 'conditions': [['5723', '78959', '5715', '07070', '496', '29680', '30981', 'V1582']], 'drugs': ['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B'], 'drugs_all': [['B01A', 'J07B', 'A12B', 'C03D', 'C03C', 'N02B', 'J05A', 'R03A', 'N07B', 'R03B']]}


In [5]:
drug_task_data

{'drug_recommendation': <pyhealth.datasets.sample_dataset.SampleDataset at 0x7f969fa65f10>,
 'no_hist': <pyhealth.datasets.sample_dataset.SampleDataset at 0x7f96698b7e80>,
 'no_proc': <pyhealth.datasets.sample_dataset.SampleDataset at 0x7f966875ce50>}

## Create DDI Matrices

In order to calculate DDI Rate, we need to create the DDI matrices.
GAMENet models have this built-in, but RETAIN does not.
So, we craft our DDI matrices ahead of time to help us calculate the rate later.

In [6]:
ddi_mats = {}

for taskname in mimic.get_task_names():
    model_type = MODEL_TYPES_PER_TASK[taskname][GN_KEY]
    ddi_mats[taskname] = model_type(drug_task_data[taskname]).generate_ddi_adj()

## Train the Models

In [7]:
retain = {}
gamenet = {}

In [8]:
# baseline
print("---RETAIN TRAINING---")
for taskname,dataloader in dataloaders.items():
    print("--training retain on {} data--".format(taskname))
    # create and train retain model
    retain[taskname] = ModelWrapper(
        drug_task_data[taskname],
        model=MODEL_TYPES_PER_TASK[taskname][RT_KEY],
        feature_keys=RETAIN_FEATS_PER_TASK[taskname],
        experiment="{}_task_{}".format(RETAIN_EXP, taskname)
    )
    retain[taskname].train_model(
        dataloader["train"], dataloader["val"],
        decay_weight=DECAY_WEIGHT,
        learning_rate=LR,
        epochs=EPOCHS
    )

# gamenet
print("---GAMENET TRAINING---")
for taskname,dataloader in dataloaders.items():
    print("--training gamenet on {} data--".format(taskname))
    # create and train gamenet model
    gamenet[taskname] = ModelWrapper(
        drug_task_data[taskname],
        model=MODEL_TYPES_PER_TASK[taskname][GN_KEY],
        experiment="{}_task_{}".format(GAMENET_EXP, taskname)
    )
    gamenet[taskname].train_model(
        dataloader["train"], dataloader["val"],
        decay_weight = DECAY_WEIGHT,
        learning_rate = LR,
        epochs=EPOCHS
    )

---RETAIN TRAINING---
--training retain on drug_recommendation data--
making retain model


RETAIN(
  (embeddings): ModuleDict(
    (conditions): Embedding(19186, 128, padding_idx=0)
    (procedures): Embedding(10605, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (conditions): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=200, bias=True)
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision

Epoch 0 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-0, step-1842 ---
loss: 0.2492
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.13it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-0, step-1842 ---
jaccard_samples: 0.4020
accuracy: 0.0017
hamming_loss: 0.0865
precision_samples: 0.7345
recall_samples: 0.4885
pr_auc_samples: 0.7072
f1_samples: 0.5540
loss: 0.2183
New best accuracy score (0.0017) at epoch-0, step-1842



Epoch 1 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-1, step-3684 ---
loss: 0.2130
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.45it/s]
--- Eval epoch-1, step-3684 ---
jaccard_samples: 0.4207
accuracy: 0.0020
hamming_loss: 0.0824
precision_samples: 0.7548
recall_samples: 0.5013
pr_auc_samples: 0.7268
f1_samples: 0.5729
loss: 0.2046
New best accuracy score (0.0020) at epoch-1, step-3684



Epoch 2 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-2, step-5526 ---
loss: 0.2041
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 57.24it/s]
--- Eval epoch-2, step-5526 ---
jaccard_samples: 0.4341
accuracy: 0.0018
hamming_loss: 0.0807
precision_samples: 0.7574
recall_samples: 0.5188
pr_auc_samples: 0.7376
f1_samples: 0.5860
loss: 0.1987



Epoch 3 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-3, step-7368 ---
loss: 0.1992
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.96it/s]
--- Eval epoch-3, step-7368 ---
jaccard_samples: 0.4382
accuracy: 0.0017
hamming_loss: 0.0796
precision_samples: 0.7657
recall_samples: 0.5192
pr_auc_samples: 0.7434
f1_samples: 0.5903
loss: 0.1956



Epoch 4 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-4, step-9210 ---
loss: 0.1960
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.71it/s]
--- Eval epoch-4, step-9210 ---
jaccard_samples: 0.4439
accuracy: 0.0017
hamming_loss: 0.0792
precision_samples: 0.7621
recall_samples: 0.5293
pr_auc_samples: 0.7469
f1_samples: 0.5956
loss: 0.1936



Epoch 5 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-5, step-11052 ---
loss: 0.1938
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.85it/s]
--- Eval epoch-5, step-11052 ---
jaccard_samples: 0.4485
accuracy: 0.0017
hamming_loss: 0.0787
precision_samples: 0.7638
recall_samples: 0.5348
pr_auc_samples: 0.7507
f1_samples: 0.6003
loss: 0.1922



Epoch 6 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-6, step-12894 ---
loss: 0.1922
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.35it/s]
--- Eval epoch-6, step-12894 ---
jaccard_samples: 0.4457
accuracy: 0.0020
hamming_loss: 0.0784
precision_samples: 0.7710
recall_samples: 0.5266
pr_auc_samples: 0.7513
f1_samples: 0.5973
loss: 0.1915



Epoch 7 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-7, step-14736 ---
loss: 0.1910
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.41it/s]
--- Eval epoch-7, step-14736 ---
jaccard_samples: 0.4406
accuracy: 0.0020
hamming_loss: 0.0786
precision_samples: 0.7793
recall_samples: 0.5152
pr_auc_samples: 0.7518
f1_samples: 0.5924
loss: 0.1918
New best accuracy score (0.0020) at epoch-7, step-14736



Epoch 8 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-8, step-16578 ---
loss: 0.1901
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.48it/s]
--- Eval epoch-8, step-16578 ---
jaccard_samples: 0.4432
accuracy: 0.0020
hamming_loss: 0.0785
precision_samples: 0.7776
recall_samples: 0.5194
pr_auc_samples: 0.7529
f1_samples: 0.5953
loss: 0.1910



Epoch 9 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-9, step-18420 ---
loss: 0.1893
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.62it/s]
--- Eval epoch-9, step-18420 ---
jaccard_samples: 0.4476
accuracy: 0.0022
hamming_loss: 0.0784
precision_samples: 0.7707
recall_samples: 0.5297
pr_auc_samples: 0.7535
f1_samples: 0.5991
loss: 0.1907
New best accuracy score (0.0022) at epoch-9, step-18420



Epoch 10 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-10, step-20262 ---
loss: 0.1885
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.61it/s]
--- Eval epoch-10, step-20262 ---
jaccard_samples: 0.4528
accuracy: 0.0022
hamming_loss: 0.0782
precision_samples: 0.7628
recall_samples: 0.5409
pr_auc_samples: 0.7533
f1_samples: 0.6045
loss: 0.1905



Epoch 11 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-11, step-22104 ---
loss: 0.1879
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.89it/s]
--- Eval epoch-11, step-22104 ---
jaccard_samples: 0.4463
accuracy: 0.0022
hamming_loss: 0.0785
precision_samples: 0.7730
recall_samples: 0.5267
pr_auc_samples: 0.7536
f1_samples: 0.5978
loss: 0.1906



Epoch 12 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-12, step-23946 ---
loss: 0.1872
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 57.19it/s]
--- Eval epoch-12, step-23946 ---
jaccard_samples: 0.4527
accuracy: 0.0024
hamming_loss: 0.0786
precision_samples: 0.7611
recall_samples: 0.5427
pr_auc_samples: 0.7528
f1_samples: 0.6040
loss: 0.1906
New best accuracy score (0.0024) at epoch-12, step-23946



Epoch 13 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-13, step-25788 ---
loss: 0.1866
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.17it/s]
--- Eval epoch-13, step-25788 ---
jaccard_samples: 0.4487
accuracy: 0.0018
hamming_loss: 0.0787
precision_samples: 0.7686
recall_samples: 0.5333
pr_auc_samples: 0.7537
f1_samples: 0.6002
loss: 0.1908



Epoch 14 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-14, step-27630 ---
loss: 0.1861
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.81it/s]
--- Eval epoch-14, step-27630 ---
jaccard_samples: 0.4457
accuracy: 0.0018
hamming_loss: 0.0787
precision_samples: 0.7722
recall_samples: 0.5268
pr_auc_samples: 0.7534
f1_samples: 0.5975
loss: 0.1909



Epoch 15 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-15, step-29472 ---
loss: 0.1856
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.57it/s]
--- Eval epoch-15, step-29472 ---
jaccard_samples: 0.4443
accuracy: 0.0019
hamming_loss: 0.0787
precision_samples: 0.7760
recall_samples: 0.5218
pr_auc_samples: 0.7539
f1_samples: 0.5963
loss: 0.1912



Epoch 16 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-16, step-31314 ---
loss: 0.1851
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.38it/s]
--- Eval epoch-16, step-31314 ---
jaccard_samples: 0.4433
accuracy: 0.0017
hamming_loss: 0.0788
precision_samples: 0.7770
recall_samples: 0.5210
pr_auc_samples: 0.7538
f1_samples: 0.5947
loss: 0.1914



Epoch 17 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-17, step-33156 ---
loss: 0.1845
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.58it/s]
--- Eval epoch-17, step-33156 ---
jaccard_samples: 0.4485
accuracy: 0.0019
hamming_loss: 0.0789
precision_samples: 0.7670
recall_samples: 0.5339
pr_auc_samples: 0.7532
f1_samples: 0.6000
loss: 0.1911



Epoch 18 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-18, step-34998 ---
loss: 0.1843
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.13it/s]
--- Eval epoch-18, step-34998 ---
jaccard_samples: 0.4506
accuracy: 0.0022
hamming_loss: 0.0792
precision_samples: 0.7612
recall_samples: 0.5415
pr_auc_samples: 0.7524
f1_samples: 0.6017
loss: 0.1912



Epoch 19 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-19, step-36840 ---
loss: 0.1839
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 56.61it/s]
--- Eval epoch-19, step-36840 ---
jaccard_samples: 0.4508
accuracy: 0.0017
hamming_loss: 0.0789
precision_samples: 0.7637
recall_samples: 0.5390
pr_auc_samples: 0.7532
f1_samples: 0.6021
loss: 0.1908
Loaded best model


--training retain on no_hist data--
making retain model


RETAIN(
  (embeddings): ModuleDict(
    (conditions): Embedding(19186, 128, padding_idx=0)
    (procedures): Embedding(10605, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (conditions): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=256, out_features=200, bias=True)
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision

Epoch 0 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-0, step-1838 ---
loss: 0.2469
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 111.71it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-0, step-1838 ---
jaccard_samples: 0.3952
accuracy: 0.0015
hamming_loss: 0.0857
precision_samples: 0.7390
recall_samples: 0.4734
pr_auc_samples: 0.7001
f1_samples: 0.5465
loss: 0.2218
New best accuracy score (0.0015) at epoch-0, step-1838



Epoch 1 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-1, step-3676 ---
loss: 0.2168
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 113.92it/s]
--- Eval epoch-1, step-3676 ---
jaccard_samples: 0.4125
accuracy: 0.0017
hamming_loss: 0.0831
precision_samples: 0.7499
recall_samples: 0.4909
pr_auc_samples: 0.7177
f1_samples: 0.5645
loss: 0.2097
New best accuracy score (0.0017) at epoch-1, step-3676



Epoch 2 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-2, step-5514 ---
loss: 0.2082
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 111.54it/s]
--- Eval epoch-2, step-5514 ---
jaccard_samples: 0.4284
accuracy: 0.0017
hamming_loss: 0.0809
precision_samples: 0.7523
recall_samples: 0.5101
pr_auc_samples: 0.7298
f1_samples: 0.5806
loss: 0.2015



Epoch 3 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-3, step-7352 ---
loss: 0.2026
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 113.25it/s]
--- Eval epoch-3, step-7352 ---
jaccard_samples: 0.4342
accuracy: 0.0017
hamming_loss: 0.0799
precision_samples: 0.7609
recall_samples: 0.5140
pr_auc_samples: 0.7377
f1_samples: 0.5863
loss: 0.1968



Epoch 4 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-4, step-9190 ---
loss: 0.1992
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 114.22it/s]
--- Eval epoch-4, step-9190 ---
jaccard_samples: 0.4369
accuracy: 0.0016
hamming_loss: 0.0793
precision_samples: 0.7657
recall_samples: 0.5149
pr_auc_samples: 0.7424
f1_samples: 0.5889
loss: 0.1947



Epoch 5 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-5, step-11028 ---
loss: 0.1967
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 112.91it/s]
--- Eval epoch-5, step-11028 ---
jaccard_samples: 0.4370
accuracy: 0.0017
hamming_loss: 0.0789
precision_samples: 0.7717
recall_samples: 0.5124
pr_auc_samples: 0.7445
f1_samples: 0.5888
loss: 0.1932



Epoch 6 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-6, step-12866 ---
loss: 0.1951
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 112.20it/s]
--- Eval epoch-6, step-12866 ---
jaccard_samples: 0.4445
accuracy: 0.0015
hamming_loss: 0.0785
precision_samples: 0.7633
recall_samples: 0.5277
pr_auc_samples: 0.7463
f1_samples: 0.5959
loss: 0.1919



Epoch 7 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-7, step-14704 ---
loss: 0.1936
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 113.10it/s]
--- Eval epoch-7, step-14704 ---
jaccard_samples: 0.4392
accuracy: 0.0019
hamming_loss: 0.0785
precision_samples: 0.7746
recall_samples: 0.5140
pr_auc_samples: 0.7482
f1_samples: 0.5908
loss: 0.1916
New best accuracy score (0.0019) at epoch-7, step-14704



Epoch 8 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-8, step-16542 ---
loss: 0.1927
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 113.79it/s]
--- Eval epoch-8, step-16542 ---
jaccard_samples: 0.4463
accuracy: 0.0017
hamming_loss: 0.0785
precision_samples: 0.7631
recall_samples: 0.5303
pr_auc_samples: 0.7482
f1_samples: 0.5976
loss: 0.1911



Epoch 9 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-9, step-18380 ---
loss: 0.1918
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 112.57it/s]
--- Eval epoch-9, step-18380 ---
jaccard_samples: 0.4439
accuracy: 0.0017
hamming_loss: 0.0784
precision_samples: 0.7704
recall_samples: 0.5239
pr_auc_samples: 0.7501
f1_samples: 0.5952
loss: 0.1907



Epoch 10 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-10, step-20218 ---
loss: 0.1911
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 114.66it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-10, step-20218 ---
jaccard_samples: 0.4420
accuracy: 0.0015
hamming_loss: 0.0784
precision_samples: 0.7718
recall_samples: 0.5209
pr_auc_samples: 0.7497
f1_samples: 0.5935
loss: 0.1908



Epoch 11 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-11, step-22056 ---
loss: 0.1904
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 114.40it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-11, step-22056 ---
jaccard_samples: 0.4416
accuracy: 0.0019
hamming_loss: 0.0782
precision_samples: 0.7766
recall_samples: 0.5170
pr_auc_samples: 0.7508
f1_samples: 0.5931
loss: 0.1906



Epoch 12 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-12, step-23894 ---
loss: 0.1897
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 111.72it/s]
--- Eval epoch-12, step-23894 ---
jaccard_samples: 0.4459
accuracy: 0.0017
hamming_loss: 0.0785
precision_samples: 0.7667
recall_samples: 0.5283
pr_auc_samples: 0.7499
f1_samples: 0.5975
loss: 0.1906



Epoch 13 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-13, step-25732 ---
loss: 0.1892
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 112.67it/s]
--- Eval epoch-13, step-25732 ---
jaccard_samples: 0.4444
accuracy: 0.0017
hamming_loss: 0.0784
precision_samples: 0.7714
recall_samples: 0.5235
pr_auc_samples: 0.7511
f1_samples: 0.5958
loss: 0.1904



Epoch 14 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-14, step-27570 ---
loss: 0.1886
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 113.21it/s]
--- Eval epoch-14, step-27570 ---
jaccard_samples: 0.4403
accuracy: 0.0017
hamming_loss: 0.0787
precision_samples: 0.7767
recall_samples: 0.5159
pr_auc_samples: 0.7508
f1_samples: 0.5916
loss: 0.1906



Epoch 15 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-15, step-29408 ---
loss: 0.1880
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 114.27it/s]
--- Eval epoch-15, step-29408 ---
jaccard_samples: 0.4421
accuracy: 0.0017
hamming_loss: 0.0789
precision_samples: 0.7718
recall_samples: 0.5222
pr_auc_samples: 0.7499
f1_samples: 0.5931
loss: 0.1911



Epoch 16 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-16, step-31246 ---
loss: 0.1875
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 112.94it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-16, step-31246 ---
jaccard_samples: 0.4419
accuracy: 0.0016
hamming_loss: 0.0787
precision_samples: 0.7727
recall_samples: 0.5203
pr_auc_samples: 0.7500
f1_samples: 0.5932
loss: 0.1909



Epoch 17 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-17, step-33084 ---
loss: 0.1869
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 111.73it/s]
--- Eval epoch-17, step-33084 ---
jaccard_samples: 0.4426
accuracy: 0.0013
hamming_loss: 0.0786
precision_samples: 0.7724
recall_samples: 0.5214
pr_auc_samples: 0.7505
f1_samples: 0.5940
loss: 0.1910



Epoch 18 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-18, step-34922 ---
loss: 0.1863
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 112.65it/s]
--- Eval epoch-18, step-34922 ---
jaccard_samples: 0.4334
accuracy: 0.0018
hamming_loss: 0.0791
precision_samples: 0.7841
recall_samples: 0.5033
pr_auc_samples: 0.7501
f1_samples: 0.5846
loss: 0.1920



Epoch 19 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-19, step-36760 ---
loss: 0.1858
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:02<00:00, 112.96it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-19, step-36760 ---
jaccard_samples: 0.4417
accuracy: 0.0016
hamming_loss: 0.0787
precision_samples: 0.7731
recall_samples: 0.5199
pr_auc_samples: 0.7500
f1_samples: 0.5928
loss: 0.1913
Loaded best model


--training retain on no_proc data--
making retain model


RETAIN(
  (embeddings): ModuleDict(
    (conditions): Embedding(22643, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (retain): ModuleDict(
    (conditions): RETAINLayer(
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (alpha_gru): GRU(128, 128, batch_first=True)
      (beta_gru): GRU(128, 128, batch_first=True)
      (alpha_li): Linear(in_features=128, out_features=1, bias=True)
      (beta_li): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (fc): Linear(in_features=128, out_features=201, bias=True)
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'recall_samples', 'pr_auc_samples', 'f1_samples']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 1e-05
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f9660473100>
Monitor: accuracy
Monitor criterion: max
Epochs: 20



Epoch 0 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-0, step-3437 ---
loss: 0.2331
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.98it/s]
--- Eval epoch-0, step-3437 ---
jaccard_samples: 0.3448
accuracy: 0.0006
hamming_loss: 0.0824
precision_samples: 0.7396
recall_samples: 0.4063
pr_auc_samples: 0.6674
f1_samples: 0.4957
loss: 0.2080
New best accuracy score (0.0006) at epoch-0, step-3437



Epoch 1 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-1, step-6874 ---
loss: 0.2024
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.46it/s]
--- Eval epoch-1, step-6874 ---
jaccard_samples: 0.3778
accuracy: 0.0006
hamming_loss: 0.0791
precision_samples: 0.7369
recall_samples: 0.4531
pr_auc_samples: 0.6900
f1_samples: 0.5310
loss: 0.1978



Epoch 2 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-2, step-10311 ---
loss: 0.1957
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 83.70it/s]
--- Eval epoch-2, step-10311 ---
jaccard_samples: 0.3871
accuracy: 0.0007
hamming_loss: 0.0777
precision_samples: 0.7441
recall_samples: 0.4632
pr_auc_samples: 0.7003
f1_samples: 0.5403
loss: 0.1929
New best accuracy score (0.0007) at epoch-2, step-10311



Epoch 3 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-3, step-13748 ---
loss: 0.1924
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.09it/s]
--- Eval epoch-3, step-13748 ---
jaccard_samples: 0.3907
accuracy: 0.0009
hamming_loss: 0.0772
precision_samples: 0.7458
recall_samples: 0.4676
pr_auc_samples: 0.7043
f1_samples: 0.5437
loss: 0.1909
New best accuracy score (0.0009) at epoch-3, step-13748



Epoch 4 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-4, step-17185 ---
loss: 0.1904
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 83.35it/s]
--- Eval epoch-4, step-17185 ---
jaccard_samples: 0.3906
accuracy: 0.0009
hamming_loss: 0.0769
precision_samples: 0.7545
recall_samples: 0.4626
pr_auc_samples: 0.7083
f1_samples: 0.5439
loss: 0.1900



Epoch 5 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-5, step-20622 ---
loss: 0.1893
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.12it/s]
--- Eval epoch-5, step-20622 ---
jaccard_samples: 0.3918
accuracy: 0.0010
hamming_loss: 0.0766
precision_samples: 0.7575
recall_samples: 0.4632
pr_auc_samples: 0.7098
f1_samples: 0.5448
loss: 0.1890
New best accuracy score (0.0010) at epoch-5, step-20622



Epoch 6 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-6, step-24059 ---
loss: 0.1886
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.93it/s]
--- Eval epoch-6, step-24059 ---
jaccard_samples: 0.3981
accuracy: 0.0009
hamming_loss: 0.0763
precision_samples: 0.7491
recall_samples: 0.4755
pr_auc_samples: 0.7103
f1_samples: 0.5514
loss: 0.1883



Epoch 7 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-7, step-27496 ---
loss: 0.1881
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.28it/s]
--- Eval epoch-7, step-27496 ---
jaccard_samples: 0.3980
accuracy: 0.0011
hamming_loss: 0.0764
precision_samples: 0.7509
recall_samples: 0.4750
pr_auc_samples: 0.7111
f1_samples: 0.5512
loss: 0.1884
New best accuracy score (0.0011) at epoch-7, step-27496



Epoch 8 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-8, step-30933 ---
loss: 0.1876
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.00it/s]
--- Eval epoch-8, step-30933 ---
jaccard_samples: 0.3940
accuracy: 0.0011
hamming_loss: 0.0763
precision_samples: 0.7579
recall_samples: 0.4663
pr_auc_samples: 0.7119
f1_samples: 0.5470
loss: 0.1882



Epoch 9 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-9, step-34370 ---
loss: 0.1873
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 77.44it/s]
--- Eval epoch-9, step-34370 ---
jaccard_samples: 0.3942
accuracy: 0.0012
hamming_loss: 0.0764
precision_samples: 0.7566
recall_samples: 0.4666
pr_auc_samples: 0.7115
f1_samples: 0.5472
loss: 0.1880
New best accuracy score (0.0012) at epoch-9, step-34370



Epoch 10 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-10, step-37807 ---
loss: 0.1871
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 78.98it/s]
--- Eval epoch-10, step-37807 ---
jaccard_samples: 0.3897
accuracy: 0.0010
hamming_loss: 0.0766
precision_samples: 0.7641
recall_samples: 0.4568
pr_auc_samples: 0.7120
f1_samples: 0.5426
loss: 0.1885



Epoch 11 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-11, step-41244 ---
loss: 0.1869
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.52it/s]
--- Eval epoch-11, step-41244 ---
jaccard_samples: 0.3974
accuracy: 0.0012
hamming_loss: 0.0762
precision_samples: 0.7533
recall_samples: 0.4732
pr_auc_samples: 0.7123
f1_samples: 0.5505
loss: 0.1875



Epoch 12 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-12, step-44681 ---
loss: 0.1868
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 80.04it/s]
--- Eval epoch-12, step-44681 ---
jaccard_samples: 0.3930
accuracy: 0.0014
hamming_loss: 0.0764
precision_samples: 0.7629
recall_samples: 0.4620
pr_auc_samples: 0.7128
f1_samples: 0.5459
loss: 0.1882
New best accuracy score (0.0014) at epoch-12, step-44681



Epoch 13 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-13, step-48118 ---
loss: 0.1866
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 85.23it/s]
--- Eval epoch-13, step-48118 ---
jaccard_samples: 0.3932
accuracy: 0.0011
hamming_loss: 0.0763
precision_samples: 0.7615
recall_samples: 0.4629
pr_auc_samples: 0.7128
f1_samples: 0.5462
loss: 0.1880



Epoch 14 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-14, step-51555 ---
loss: 0.1866
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.67it/s]
--- Eval epoch-14, step-51555 ---
jaccard_samples: 0.3978
accuracy: 0.0011
hamming_loss: 0.0762
precision_samples: 0.7533
recall_samples: 0.4737
pr_auc_samples: 0.7125
f1_samples: 0.5507
loss: 0.1872



Epoch 15 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-15, step-54992 ---
loss: 0.1864
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 84.48it/s]
--- Eval epoch-15, step-54992 ---
jaccard_samples: 0.3954
accuracy: 0.0010
hamming_loss: 0.0763
precision_samples: 0.7580
recall_samples: 0.4675
pr_auc_samples: 0.7130
f1_samples: 0.5484
loss: 0.1880



Epoch 16 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-16, step-58429 ---
loss: 0.1864
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 79.71it/s]
--- Eval epoch-16, step-58429 ---
jaccard_samples: 0.3962
accuracy: 0.0009
hamming_loss: 0.0762
precision_samples: 0.7581
recall_samples: 0.4699
pr_auc_samples: 0.7132
f1_samples: 0.5491
loss: 0.1878



Epoch 17 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-17, step-61866 ---
loss: 0.1863
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 78.47it/s]
--- Eval epoch-17, step-61866 ---
jaccard_samples: 0.3999
accuracy: 0.0011
hamming_loss: 0.0762
precision_samples: 0.7512
recall_samples: 0.4770
pr_auc_samples: 0.7128
f1_samples: 0.5533
loss: 0.1876



Epoch 18 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-18, step-65303 ---
loss: 0.1863
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 81.30it/s]
--- Eval epoch-18, step-65303 ---
jaccard_samples: 0.3945
accuracy: 0.0010
hamming_loss: 0.0761
precision_samples: 0.7600
recall_samples: 0.4662
pr_auc_samples: 0.7126
f1_samples: 0.5474
loss: 0.1878



Epoch 19 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-19, step-68740 ---
loss: 0.1861
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:05<00:00, 76.24it/s]
--- Eval epoch-19, step-68740 ---
jaccard_samples: 0.4026
accuracy: 0.0010
hamming_loss: 0.0761
precision_samples: 0.7471
recall_samples: 0.4837
pr_auc_samples: 0.7128
f1_samples: 0.5557
loss: 0.1873
Loaded best model


---GAMENET TRAINING---
--training gamenet on drug_recommendation data--
making gamenet model


GAMENet(
  (embeddings): ModuleDict(
    (conditions): Embedding(19186, 128, padding_idx=0)
    (procedures): Embedding(10605, 128, padding_idx=0)
  )
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (gamenet): GAMENetLayer(
    (ehr_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (ddi_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (fc): Linear(in_features=384, out_features=200, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'recall_samples', 'pr_auc_samples', 'f1_samples']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 1e-

Epoch 0 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-0, step-1842 ---
loss: 0.2342
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 54.68it/s]
--- Eval epoch-0, step-1842 ---
jaccard_samples: 0.3775
accuracy: 0.0003
hamming_loss: 0.0902
precision_samples: 0.7069
recall_samples: 0.4654
pr_auc_samples: 0.6745
f1_samples: 0.5304
loss: 0.2182
New best accuracy score (0.0003) at epoch-0, step-1842



Epoch 1 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-1, step-3684 ---
loss: 0.2092
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.17it/s]
--- Eval epoch-1, step-3684 ---
jaccard_samples: 0.4046
accuracy: 0.0005
hamming_loss: 0.0856
precision_samples: 0.7243
recall_samples: 0.4936
pr_auc_samples: 0.7009
f1_samples: 0.5583
loss: 0.2073
New best accuracy score (0.0005) at epoch-1, step-3684



Epoch 2 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-2, step-5526 ---
loss: 0.1999
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.62it/s]
--- Eval epoch-2, step-5526 ---
jaccard_samples: 0.4287
accuracy: 0.0015
hamming_loss: 0.0834
precision_samples: 0.7240
recall_samples: 0.5310
pr_auc_samples: 0.7212
f1_samples: 0.5813
loss: 0.2017
New best accuracy score (0.0015) at epoch-2, step-5526



Epoch 3 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-3, step-7368 ---
loss: 0.1939
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.27it/s]
--- Eval epoch-3, step-7368 ---
jaccard_samples: 0.4364
accuracy: 0.0008
hamming_loss: 0.0816
precision_samples: 0.7362
recall_samples: 0.5338
pr_auc_samples: 0.7311
f1_samples: 0.5893
loss: 0.1976



Epoch 4 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-4, step-9210 ---
loss: 0.1901
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 54.78it/s]
--- Eval epoch-4, step-9210 ---
jaccard_samples: 0.4331
accuracy: 0.0014
hamming_loss: 0.0807
precision_samples: 0.7579
recall_samples: 0.5150
pr_auc_samples: 0.7366
f1_samples: 0.5856
loss: 0.1964



Epoch 5 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-5, step-11052 ---
loss: 0.1870
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 54.93it/s]
--- Eval epoch-5, step-11052 ---
jaccard_samples: 0.4496
accuracy: 0.0021
hamming_loss: 0.0821
precision_samples: 0.7216
recall_samples: 0.5629
pr_auc_samples: 0.7372
f1_samples: 0.6013
loss: 0.1969
New best accuracy score (0.0021) at epoch-5, step-11052



Epoch 6 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-6, step-12894 ---
loss: 0.1841
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.22it/s]
--- Eval epoch-6, step-12894 ---
jaccard_samples: 0.4501
accuracy: 0.0015
hamming_loss: 0.0820
precision_samples: 0.7292
recall_samples: 0.5602
pr_auc_samples: 0.7389
f1_samples: 0.6016
loss: 0.1963



Epoch 7 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-7, step-14736 ---
loss: 0.1811
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 54.70it/s]
--- Eval epoch-7, step-14736 ---
jaccard_samples: 0.4466
accuracy: 0.0020
hamming_loss: 0.0813
precision_samples: 0.7416
recall_samples: 0.5461
pr_auc_samples: 0.7408
f1_samples: 0.5985
loss: 0.1961



Epoch 8 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-8, step-16578 ---
loss: 0.1778
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 51.96it/s]
--- Eval epoch-8, step-16578 ---
jaccard_samples: 0.4434
accuracy: 0.0013
hamming_loss: 0.0814
precision_samples: 0.7471
recall_samples: 0.5393
pr_auc_samples: 0.7410
f1_samples: 0.5948
loss: 0.1972



Epoch 9 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-9, step-18420 ---
loss: 0.1742
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.49it/s]
--- Eval epoch-9, step-18420 ---
jaccard_samples: 0.4441
accuracy: 0.0017
hamming_loss: 0.0815
precision_samples: 0.7467
recall_samples: 0.5411
pr_auc_samples: 0.7422
f1_samples: 0.5955
loss: 0.1982



Epoch 10 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-10, step-20262 ---
loss: 0.1702
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.19it/s]
--- Eval epoch-10, step-20262 ---
jaccard_samples: 0.4518
accuracy: 0.0020
hamming_loss: 0.0824
precision_samples: 0.7292
recall_samples: 0.5633
pr_auc_samples: 0.7406
f1_samples: 0.6033
loss: 0.2014



Epoch 11 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-11, step-22104 ---
loss: 0.1663
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.00it/s]
--- Eval epoch-11, step-22104 ---
jaccard_samples: 0.4411
accuracy: 0.0014
hamming_loss: 0.0834
precision_samples: 0.7437
recall_samples: 0.5404
pr_auc_samples: 0.7378
f1_samples: 0.5914
loss: 0.2041



Epoch 12 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-12, step-23946 ---
loss: 0.1628
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.60it/s]
--- Eval epoch-12, step-23946 ---
jaccard_samples: 0.4445
accuracy: 0.0017
hamming_loss: 0.0845
precision_samples: 0.7306
recall_samples: 0.5543
pr_auc_samples: 0.7365
f1_samples: 0.5948
loss: 0.2087



Epoch 13 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-13, step-25788 ---
loss: 0.1598
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.31it/s]
--- Eval epoch-13, step-25788 ---
jaccard_samples: 0.4423
accuracy: 0.0009
hamming_loss: 0.0857
precision_samples: 0.7311
recall_samples: 0.5525
pr_auc_samples: 0.7341
f1_samples: 0.5921
loss: 0.2126



Epoch 14 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-14, step-27630 ---
loss: 0.1572
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.19it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-14, step-27630 ---
jaccard_samples: 0.4443
accuracy: 0.0015
hamming_loss: 0.0843
precision_samples: 0.7382
recall_samples: 0.5490
pr_auc_samples: 0.7361
f1_samples: 0.5945
loss: 0.2144



Epoch 15 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-15, step-29472 ---
loss: 0.1551
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 54.43it/s]
--- Eval epoch-15, step-29472 ---
jaccard_samples: 0.4475
accuracy: 0.0010
hamming_loss: 0.0870
precision_samples: 0.7188
recall_samples: 0.5699
pr_auc_samples: 0.7312
f1_samples: 0.5973
loss: 0.2198



Epoch 16 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-16, step-31314 ---
loss: 0.1532
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.55it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-16, step-31314 ---
jaccard_samples: 0.4376
accuracy: 0.0010
hamming_loss: 0.0849
precision_samples: 0.7482
recall_samples: 0.5352
pr_auc_samples: 0.7336
f1_samples: 0.5874
loss: 0.2205



Epoch 17 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-17, step-33156 ---
loss: 0.1519
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.16it/s]
--- Eval epoch-17, step-33156 ---
jaccard_samples: 0.4429
accuracy: 0.0013
hamming_loss: 0.0855
precision_samples: 0.7331
recall_samples: 0.5511
pr_auc_samples: 0.7339
f1_samples: 0.5933
loss: 0.2177



Epoch 18 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-18, step-34998 ---
loss: 0.1505
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.99it/s]
--- Eval epoch-18, step-34998 ---
jaccard_samples: 0.4416
accuracy: 0.0015
hamming_loss: 0.0854
precision_samples: 0.7392
recall_samples: 0.5460
pr_auc_samples: 0.7316
f1_samples: 0.5912
loss: 0.2247



Epoch 19 / 20:   0%|          | 0/1842 [00:00<?, ?it/s]

--- Train epoch-19, step-36840 ---
loss: 0.1497
Evaluation: 100%|█████████████████████████████████████████████████████████████| 233/233 [00:04<00:00, 55.02it/s]
--- Eval epoch-19, step-36840 ---
jaccard_samples: 0.4406
accuracy: 0.0011
hamming_loss: 0.0875
precision_samples: 0.7292
recall_samples: 0.5532
pr_auc_samples: 0.7305
f1_samples: 0.5905
loss: 0.2253
Loaded best model


--training gamenet on no_hist data--
making gamenet model without hist...


GAMENetNoHist(
  (embeddings): ModuleDict(
    (conditions): Embedding(19186, 128, padding_idx=0)
    (procedures): Embedding(10605, 128, padding_idx=0)
  )
  (cond_rnn): GRU(128, 128, batch_first=True)
  (proc_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=256, out_features=128, bias=True)
  )
  (gamenet): GAMENetLayerNoDM(
    (ehr_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (ddi_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (fc): Linear(in_features=256, out_features=200, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'recall_samples', 'pr_auc_samples', 'f1_samples']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight 

Epoch 0 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-0, step-1838 ---
loss: 0.2189
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:01<00:00, 232.25it/s]
--- Eval epoch-0, step-1838 ---
jaccard_samples: 0.4149
accuracy: 0.0013
hamming_loss: 0.0833
precision_samples: 0.7417
recall_samples: 0.4980
pr_auc_samples: 0.7147
f1_samples: 0.5675
loss: 0.2023
New best accuracy score (0.0013) at epoch-0, step-1838



Epoch 1 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-1, step-3676 ---
loss: 0.1961
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 254.04it/s]
--- Eval epoch-1, step-3676 ---
jaccard_samples: 0.4349
accuracy: 0.0017
hamming_loss: 0.0800
precision_samples: 0.7604
recall_samples: 0.5163
pr_auc_samples: 0.7377
f1_samples: 0.5868
loss: 0.1941
New best accuracy score (0.0017) at epoch-1, step-3676



Epoch 2 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-2, step-5514 ---
loss: 0.1900
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 244.41it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-2, step-5514 ---
jaccard_samples: 0.4457
accuracy: 0.0022
hamming_loss: 0.0787
precision_samples: 0.7621
recall_samples: 0.5297
pr_auc_samples: 0.7461
f1_samples: 0.5974
loss: 0.1913
New best accuracy score (0.0022) at epoch-2, step-5514



Epoch 3 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-3, step-7352 ---
loss: 0.1865
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 253.52it/s]
--- Eval epoch-3, step-7352 ---
jaccard_samples: 0.4578
accuracy: 0.0023
hamming_loss: 0.0779
precision_samples: 0.7534
recall_samples: 0.5518
pr_auc_samples: 0.7514
f1_samples: 0.6097
loss: 0.1894
New best accuracy score (0.0023) at epoch-3, step-7352



Epoch 4 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-4, step-9190 ---
loss: 0.1839
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 248.80it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-4, step-9190 ---
jaccard_samples: 0.4571
accuracy: 0.0019
hamming_loss: 0.0774
precision_samples: 0.7640
recall_samples: 0.5441
pr_auc_samples: 0.7547
f1_samples: 0.6084
loss: 0.1887



Epoch 5 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-5, step-11028 ---
loss: 0.1816
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 252.21it/s]
--- Eval epoch-5, step-11028 ---
jaccard_samples: 0.4589
accuracy: 0.0022
hamming_loss: 0.0777
precision_samples: 0.7604
recall_samples: 0.5501
pr_auc_samples: 0.7554
f1_samples: 0.6099
loss: 0.1885



Epoch 6 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-6, step-12866 ---
loss: 0.1788
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 257.72it/s]
--- Eval epoch-6, step-12866 ---
jaccard_samples: 0.4539
accuracy: 0.0015
hamming_loss: 0.0779
precision_samples: 0.7718
recall_samples: 0.5362
pr_auc_samples: 0.7570
f1_samples: 0.6052
loss: 0.1894



Epoch 7 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-7, step-14704 ---
loss: 0.1755
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 252.26it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-7, step-14704 ---
jaccard_samples: 0.4607
accuracy: 0.0024
hamming_loss: 0.0786
precision_samples: 0.7525
recall_samples: 0.5602
pr_auc_samples: 0.7552
f1_samples: 0.6107
loss: 0.1905
New best accuracy score (0.0024) at epoch-7, step-14704



Epoch 8 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-8, step-16542 ---
loss: 0.1708
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 250.70it/s]
--- Eval epoch-8, step-16542 ---
jaccard_samples: 0.4595
accuracy: 0.0017
hamming_loss: 0.0814
precision_samples: 0.7399
recall_samples: 0.5693
pr_auc_samples: 0.7511
f1_samples: 0.6090
loss: 0.1953



Epoch 9 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-9, step-18380 ---
loss: 0.1655
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 253.31it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-9, step-18380 ---
jaccard_samples: 0.4576
accuracy: 0.0024
hamming_loss: 0.0803
precision_samples: 0.7515
recall_samples: 0.5572
pr_auc_samples: 0.7514
f1_samples: 0.6078
loss: 0.1981



Epoch 10 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-10, step-20218 ---
loss: 0.1602
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 250.04it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-10, step-20218 ---
jaccard_samples: 0.4608
accuracy: 0.0021
hamming_loss: 0.0808
precision_samples: 0.7417
recall_samples: 0.5687
pr_auc_samples: 0.7494
f1_samples: 0.6109
loss: 0.2020



Epoch 11 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-11, step-22056 ---
loss: 0.1560
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 246.51it/s]
--- Eval epoch-11, step-22056 ---
jaccard_samples: 0.4582
accuracy: 0.0019
hamming_loss: 0.0826
precision_samples: 0.7383
recall_samples: 0.5687
pr_auc_samples: 0.7455
f1_samples: 0.6079
loss: 0.2085



Epoch 12 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-12, step-23894 ---
loss: 0.1524
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 253.45it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-12, step-23894 ---
jaccard_samples: 0.4502
accuracy: 0.0022
hamming_loss: 0.0811
precision_samples: 0.7584
recall_samples: 0.5443
pr_auc_samples: 0.7464
f1_samples: 0.5998
loss: 0.2137



Epoch 13 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-13, step-25732 ---
loss: 0.1500
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 253.49it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-13, step-25732 ---
jaccard_samples: 0.4529
accuracy: 0.0023
hamming_loss: 0.0820
precision_samples: 0.7474
recall_samples: 0.5557
pr_auc_samples: 0.7444
f1_samples: 0.6026
loss: 0.2156



Epoch 14 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-14, step-27570 ---
loss: 0.1484
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 253.30it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-14, step-27570 ---
jaccard_samples: 0.4567
accuracy: 0.0023
hamming_loss: 0.0826
precision_samples: 0.7380
recall_samples: 0.5670
pr_auc_samples: 0.7433
f1_samples: 0.6066
loss: 0.2198



Epoch 15 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-15, step-29408 ---
loss: 0.1466
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 250.98it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-15, step-29408 ---
jaccard_samples: 0.4569
accuracy: 0.0021
hamming_loss: 0.0822
precision_samples: 0.7406
recall_samples: 0.5655
pr_auc_samples: 0.7433
f1_samples: 0.6068
loss: 0.2210



Epoch 16 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-16, step-31246 ---
loss: 0.1455
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 251.88it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-16, step-31246 ---
jaccard_samples: 0.4549
accuracy: 0.0023
hamming_loss: 0.0814
precision_samples: 0.7480
recall_samples: 0.5572
pr_auc_samples: 0.7440
f1_samples: 0.6050
loss: 0.2230



Epoch 17 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-17, step-33084 ---
loss: 0.1448
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 249.50it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-17, step-33084 ---
jaccard_samples: 0.4523
accuracy: 0.0025
hamming_loss: 0.0819
precision_samples: 0.7523
recall_samples: 0.5523
pr_auc_samples: 0.7426
f1_samples: 0.6020
loss: 0.2236
New best accuracy score (0.0025) at epoch-17, step-33084



Epoch 18 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-18, step-34922 ---
loss: 0.1440
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 256.04it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-18, step-34922 ---
jaccard_samples: 0.4559
accuracy: 0.0024
hamming_loss: 0.0847
precision_samples: 0.7281
recall_samples: 0.5747
pr_auc_samples: 0.7388
f1_samples: 0.6053
loss: 0.2251



Epoch 19 / 20:   0%|          | 0/1838 [00:00<?, ?it/s]

--- Train epoch-19, step-36760 ---
loss: 0.1431
Evaluation: 100%|████████████████████████████████████████████████████████████| 234/234 [00:00<00:00, 252.19it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-19, step-36760 ---
jaccard_samples: 0.4559
accuracy: 0.0021
hamming_loss: 0.0838
precision_samples: 0.7351
recall_samples: 0.5692
pr_auc_samples: 0.7400
f1_samples: 0.6052
loss: 0.2247
Loaded best model


--training gamenet on no_proc data--
making gamenet model without procedures...


GAMENetNoProc(
  (embeddings): ModuleDict(
    (conditions): Embedding(22643, 128, padding_idx=0)
  )
  (cond_rnn): GRU(128, 128, batch_first=True)
  (query): Sequential(
    (0): ReLU()
    (1): Linear(in_features=128, out_features=128, bias=True)
  )
  (gamenet): GAMENetLayer(
    (ehr_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (ddi_gcn): GCN(
      (gcn1): GCNLayer()
      (dropout_layer): Dropout(p=0.5, inplace=False)
      (gcn2): GCNLayer()
    )
    (fc): Linear(in_features=384, out_features=201, bias=True)
    (bce_loss_fn): BCEWithLogitsLoss()
  )
)
Metrics: ['jaccard_samples', 'accuracy', 'hamming_loss', 'precision_samples', 'recall_samples', 'pr_auc_samples', 'f1_samples']
Device: cuda

Training:
Batch size: 64
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 1e-05
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7f96

Epoch 0 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-0, step-3437 ---
loss: 0.2143
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 60.90it/s]
--- Eval epoch-0, step-3437 ---
jaccard_samples: 0.3791
accuracy: 0.0005
hamming_loss: 0.0811
precision_samples: 0.6974
recall_samples: 0.4749
pr_auc_samples: 0.6706
f1_samples: 0.5333
loss: 0.2006
New best accuracy score (0.0005) at epoch-0, step-3437



Epoch 1 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-1, step-6874 ---
loss: 0.1919
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 60.40it/s]
--- Eval epoch-1, step-6874 ---
jaccard_samples: 0.3977
accuracy: 0.0005
hamming_loss: 0.0780
precision_samples: 0.7236
recall_samples: 0.4890
pr_auc_samples: 0.6969
f1_samples: 0.5515
loss: 0.1922
New best accuracy score (0.0005) at epoch-1, step-6874



Epoch 2 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-2, step-10311 ---
loss: 0.1862
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 62.10it/s]
--- Eval epoch-2, step-10311 ---
jaccard_samples: 0.4051
accuracy: 0.0011
hamming_loss: 0.0768
precision_samples: 0.7309
recall_samples: 0.4949
pr_auc_samples: 0.7057
f1_samples: 0.5587
loss: 0.1893
New best accuracy score (0.0011) at epoch-2, step-10311



Epoch 3 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-3, step-13748 ---
loss: 0.1835
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:06<00:00, 62.69it/s]
--- Eval epoch-3, step-13748 ---
jaccard_samples: 0.4131
accuracy: 0.0011
hamming_loss: 0.0764
precision_samples: 0.7263
recall_samples: 0.5094
pr_auc_samples: 0.7110
f1_samples: 0.5672
loss: 0.1878
New best accuracy score (0.0011) at epoch-3, step-13748



Epoch 4 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-4, step-17185 ---
loss: 0.1817
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.57it/s]
--- Eval epoch-4, step-17185 ---
jaccard_samples: 0.4143
accuracy: 0.0009
hamming_loss: 0.0759
precision_samples: 0.7332
recall_samples: 0.5074
pr_auc_samples: 0.7147
f1_samples: 0.5678
loss: 0.1867



Epoch 5 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-5, step-20622 ---
loss: 0.1803
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.54it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-5, step-20622 ---
jaccard_samples: 0.4162
accuracy: 0.0010
hamming_loss: 0.0757
precision_samples: 0.7338
recall_samples: 0.5094
pr_auc_samples: 0.7153
f1_samples: 0.5696
loss: 0.1865



Epoch 6 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-6, step-24059 ---
loss: 0.1791
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.66it/s]
--- Eval epoch-6, step-24059 ---
jaccard_samples: 0.4144
accuracy: 0.0011
hamming_loss: 0.0755
precision_samples: 0.7406
recall_samples: 0.5036
pr_auc_samples: 0.7167
f1_samples: 0.5677
loss: 0.1859



Epoch 7 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-7, step-27496 ---
loss: 0.1779
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 62.06it/s]
--- Eval epoch-7, step-27496 ---
jaccard_samples: 0.4066
accuracy: 0.0010
hamming_loss: 0.0755
precision_samples: 0.7554
recall_samples: 0.4848
pr_auc_samples: 0.7189
f1_samples: 0.5591
loss: 0.1858



Epoch 8 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-8, step-30933 ---
loss: 0.1768
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.58it/s]
--- Eval epoch-8, step-30933 ---
jaccard_samples: 0.4200
accuracy: 0.0013
hamming_loss: 0.0758
precision_samples: 0.7323
recall_samples: 0.5171
pr_auc_samples: 0.7197
f1_samples: 0.5728
loss: 0.1853
New best accuracy score (0.0013) at epoch-8, step-30933



Epoch 9 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-9, step-34370 ---
loss: 0.1757
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.75it/s]
--- Eval epoch-9, step-34370 ---
jaccard_samples: 0.4204
accuracy: 0.0014
hamming_loss: 0.0752
precision_samples: 0.7391
recall_samples: 0.5123
pr_auc_samples: 0.7212
f1_samples: 0.5738
loss: 0.1850
New best accuracy score (0.0014) at epoch-9, step-34370



Epoch 10 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-10, step-37807 ---
loss: 0.1746
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 62.08it/s]
--- Eval epoch-10, step-37807 ---
jaccard_samples: 0.4165
accuracy: 0.0014
hamming_loss: 0.0750
precision_samples: 0.7490
recall_samples: 0.5019
pr_auc_samples: 0.7226
f1_samples: 0.5696
loss: 0.1853
New best accuracy score (0.0014) at epoch-10, step-37807



Epoch 11 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-11, step-41244 ---
loss: 0.1734
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.76it/s]
--- Eval epoch-11, step-41244 ---
jaccard_samples: 0.4156
accuracy: 0.0014
hamming_loss: 0.0752
precision_samples: 0.7465
recall_samples: 0.5008
pr_auc_samples: 0.7218
f1_samples: 0.5685
loss: 0.1874



Epoch 12 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-12, step-44681 ---
loss: 0.1721
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.32it/s]
--- Eval epoch-12, step-44681 ---
jaccard_samples: 0.4217
accuracy: 0.0017
hamming_loss: 0.0755
precision_samples: 0.7355
recall_samples: 0.5164
pr_auc_samples: 0.7218
f1_samples: 0.5744
loss: 0.1865
New best accuracy score (0.0017) at epoch-12, step-44681



Epoch 13 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-13, step-48118 ---
loss: 0.1709
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.46it/s]
--- Eval epoch-13, step-48118 ---
jaccard_samples: 0.4187
accuracy: 0.0011
hamming_loss: 0.0754
precision_samples: 0.7426
recall_samples: 0.5077
pr_auc_samples: 0.7219
f1_samples: 0.5722
loss: 0.1867



Epoch 14 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-14, step-51555 ---
loss: 0.1695
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.80it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-14, step-51555 ---
jaccard_samples: 0.4216
accuracy: 0.0011
hamming_loss: 0.0754
precision_samples: 0.7392
recall_samples: 0.5136
pr_auc_samples: 0.7219
f1_samples: 0.5748
loss: 0.1882



Epoch 15 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-15, step-54992 ---
loss: 0.1679
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 62.07it/s]
--- Eval epoch-15, step-54992 ---
jaccard_samples: 0.4236
accuracy: 0.0008
hamming_loss: 0.0766
precision_samples: 0.7267
recall_samples: 0.5255
pr_auc_samples: 0.7197
f1_samples: 0.5764
loss: 0.1886



Epoch 16 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-16, step-58429 ---
loss: 0.1663
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.98it/s]
--- Eval epoch-16, step-58429 ---
jaccard_samples: 0.4231
accuracy: 0.0016
hamming_loss: 0.0765
precision_samples: 0.7297
recall_samples: 0.5230
pr_auc_samples: 0.7203
f1_samples: 0.5754
loss: 0.1897



Epoch 17 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-17, step-61866 ---
loss: 0.1647
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.61it/s]
--- Eval epoch-17, step-61866 ---
jaccard_samples: 0.4173
accuracy: 0.0011
hamming_loss: 0.0774
precision_samples: 0.7345
recall_samples: 0.5137
pr_auc_samples: 0.7184
f1_samples: 0.5691
loss: 0.1912



Epoch 18 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-18, step-65303 ---
loss: 0.1634
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:06<00:00, 62.66it/s]
  _warn_prf(average, modifier, msg_start, len(result))
--- Eval epoch-18, step-65303 ---
jaccard_samples: 0.4116
accuracy: 0.0008
hamming_loss: 0.0768
precision_samples: 0.7472
recall_samples: 0.4981
pr_auc_samples: 0.7196
f1_samples: 0.5634
loss: 0.1918



Epoch 19 / 20:   0%|          | 0/3437 [00:00<?, ?it/s]

--- Train epoch-19, step-68740 ---
loss: 0.1623
Evaluation: 100%|█████████████████████████████████████████████████████████████| 435/435 [00:07<00:00, 61.64it/s]
--- Eval epoch-19, step-68740 ---
jaccard_samples: 0.4150
accuracy: 0.0011
hamming_loss: 0.0779
precision_samples: 0.7374
recall_samples: 0.5089
pr_auc_samples: 0.7167
f1_samples: 0.5664
loss: 0.1935
Loaded best model


## Evaluate the Models

In [9]:
baseline_result = {}
gamenet_result = {}

In [10]:
# baseline
print("---RETAIN EVALUATION---")
for taskname in mimic.get_task_names():
    print("--eval retain on {} data--".format(taskname))
    test_loader = dataloaders[taskname]["test"]
    baseline_result[taskname] = {}
    baseline_result[taskname][SCORE_KEY] = retain[taskname].evaluate_model(test_loader)
    baseline_result[taskname][SCORE_KEY][DPV_KEY] = retain[taskname].calc_avg_drugs_per_visit(test_loader)
    baseline_result[taskname][SCORE_KEY][DDI_RATE_KEY] = retain[taskname].calc_ddi_rate(
        test_loader, ddi_mats[taskname]
    )
    
# gamenet
print("---GAMENET EVALUATION---")
for taskname in mimic.get_task_names():
    print("--eval gamenet on {} data--".format(taskname))
    test_loader = dataloaders[taskname]["test"]
    gamenet_result[taskname] = {}
    gamenet_result[taskname][SCORE_KEY] = gamenet[taskname].evaluate_model(test_loader)
    gamenet_result[taskname][SCORE_KEY][DPV_KEY] = gamenet[taskname].calc_avg_drugs_per_visit(test_loader)
    gamenet_result[taskname][SCORE_KEY][DDI_RATE_KEY] = gamenet[taskname].calc_ddi_rate(
        test_loader, ddi_mats[taskname]
    )

---RETAIN EVALUATION---
--eval retain on drug_recommendation data--


Evaluation: 100%|█████████████████████████████████████████████████████████████| 230/230 [00:04<00:00, 52.35it/s]


{'jaccard_samples': 0.4540459546771232, 'accuracy': 0.0020420665713702267, 'hamming_loss': 0.07847287454904363, 'precision_samples': 0.7578135655729245, 'recall_samples': 0.5464850884204991, 'pr_auc_samples': 0.7534060235322836, 'f1_samples': 0.6047574833671272, 'loss': 0.18982936526122302}


Evaluation: 100%|█████████████████████████████████████████████████████████████| 230/230 [00:03<00:00, 57.90it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████| 230/230 [00:03<00:00, 57.66it/s]


--eval retain on no_hist data--


Evaluation: 100%|████████████████████████████████████████████████████████████| 233/233 [00:02<00:00, 103.44it/s]


{'jaccard_samples': 0.4382674911377477, 'accuracy': 0.002217443891950007, 'hamming_loss': 0.07837320252654213, 'precision_samples': 0.7756755937000301, 'recall_samples': 0.5130361715564632, 'pr_auc_samples': 0.7474079338817261, 'f1_samples': 0.5902018398301893, 'loss': 0.1914405864963204}


Evaluation: 100%|████████████████████████████████████████████████████████████| 233/233 [00:02<00:00, 107.89it/s]
Evaluation: 100%|████████████████████████████████████████████████████████████| 233/233 [00:02<00:00, 111.94it/s]


--eval retain on no_proc data--


Evaluation: 100%|█████████████████████████████████████████████████████████████| 432/432 [00:05<00:00, 78.14it/s]


{'jaccard_samples': 0.3945862331790098, 'accuracy': 0.0008335447396078716, 'hamming_loss': 0.07415591199224406, 'precision_samples': 0.7598288153973151, 'recall_samples': 0.4644770862305861, 'pr_auc_samples': 0.7123848065116613, 'f1_samples': 0.547580627090208, 'loss': 0.18402852725099633}


Evaluation: 100%|█████████████████████████████████████████████████████████████| 432/432 [00:05<00:00, 78.11it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████| 432/432 [00:05<00:00, 84.34it/s]


---GAMENET EVALUATION---
--eval gamenet on drug_recommendation data--


Evaluation: 100%|█████████████████████████████████████████████████████████████| 230/230 [00:04<00:00, 54.63it/s]


{'jaccard_samples': 0.4521940582122721, 'accuracy': 0.002246273228507249, 'hamming_loss': 0.08143319038867333, 'precision_samples': 0.720762798456624, 'recall_samples': 0.5670332694952991, 'pr_auc_samples': 0.737870056857211, 'f1_samples': 0.6032621188112071, 'loss': 0.19578040170928707}


Evaluation: 100%|█████████████████████████████████████████████████████████████| 230/230 [00:04<00:00, 53.87it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████| 230/230 [00:04<00:00, 54.14it/s]


--eval gamenet on no_hist data--


Evaluation: 100%|████████████████████████████████████████████████████████████| 233/233 [00:01<00:00, 226.37it/s]
  _warn_prf(average, modifier, msg_start, len(result))


{'jaccard_samples': 0.453919083037768, 'accuracy': 0.002217443891950007, 'hamming_loss': 0.08120279532320925, 'precision_samples': 0.7543395806053412, 'recall_samples': 0.5536175023860068, 'pr_auc_samples': 0.7441593145306628, 'f1_samples': 0.6034781299268797, 'loss': 0.22013351334011094}


Evaluation: 100%|████████████████████████████████████████████████████████████| 233/233 [00:01<00:00, 219.46it/s]
Evaluation: 100%|████████████████████████████████████████████████████████████| 233/233 [00:01<00:00, 230.37it/s]


--eval gamenet on no_proc data--


Evaluation: 100%|█████████████████████████████████████████████████████████████| 432/432 [00:06<00:00, 62.53it/s]


{'jaccard_samples': 0.4232919544592121, 'accuracy': 0.0013046787228644947, 'hamming_loss': 0.07339358727689425, 'precision_samples': 0.7338074555704448, 'recall_samples': 0.5193401782681243, 'pr_auc_samples': 0.7212806490382262, 'f1_samples': 0.5761553823846068, 'loss': 0.1825957858797025}


Evaluation: 100%|█████████████████████████████████████████████████████████████| 432/432 [00:06<00:00, 62.90it/s]
Evaluation: 100%|█████████████████████████████████████████████████████████████| 432/432 [00:06<00:00, 62.44it/s]


## Display Scores

In [11]:
#results = {RT_KEY: baseline_result, GN_KEY: gamenet_result}
#rtoutfile = open("./results_gamenet.json", 'w')
#gnoutfile = open("./results_retain.json", 'w')
#gnoutfile.write(json.dumps(gamenet_result))
#rtoutfile.write(json.dumps(baseline_result))
#gnoutfile.close()
#rtoutfile.close()

In [12]:
#gn_pd_object = pandas.read_json('./results_gamenet.json', typ='series')
#gn_df = pandas.DataFrame.from_records(gn_pd_object)
#display(gn_df)

Unnamed: 0,scores,avg_dpv,ddi_rate
0,"{'jaccard_samples': 0.127961854468781, 'accura...",70.103448,0.04354
1,"{'jaccard_samples': 0.133885752813687, 'accura...",70.791667,0.045327
2,"{'jaccard_samples': 0.11004927634283, 'accurac...",69.697368,0.037088


In [16]:
metrics_columns = [
    "accuracy", "precision_samples", "recall_samples",
    "pr_auc_samples", "f1_samples",
    "avg_dpv", "ddi_rate"
]


In [17]:
retain_res_df = pandas.DataFrame.from_dict({(t,v): baseline_result[t][v] 
                            for t in baseline_result.keys()
                            for v in baseline_result[t].keys()},
                           orient='index')
retain_metrics = retain_res_df[metrics_columns]


#task_ids = []
#frames =[]
#
#for taskname, res in gamenet_result.items():
#    print(taskname)
#    print(res)
#    task_ids.append(taskname)
#    frames.append(pandas.DataFrame.from_dict(res, orient="index"))
#    
#pandas.concat(frames, keys=task_ids)

In [18]:
gamenet_res_df = pandas.DataFrame.from_dict({(t,v): gamenet_result[t][v] 
                            for t in gamenet_result.keys()
                            for v in gamenet_result[t].keys()},
                           orient='index')
gamenet_metrics = gamenet_res_df[metrics_columns]

In [19]:
display(retain_metrics)

Unnamed: 0,Unnamed: 1,accuracy,precision_samples,recall_samples,pr_auc_samples,f1_samples,avg_dpv,ddi_rate
drug_recommendation,scores,0.002042,0.757814,0.546485,0.753406,0.604757,16.496835,0.062308
no_hist,scores,0.002217,0.775676,0.513036,0.747408,0.590202,14.570555,0.066517
no_proc,scores,0.000834,0.759829,0.464477,0.712385,0.547581,11.44051,0.08561


In [20]:
display(gamenet_metrics)

Unnamed: 0,Unnamed: 1,accuracy,precision_samples,recall_samples,pr_auc_samples,f1_samples,avg_dpv,ddi_rate
drug_recommendation,scores,0.002246,0.720763,0.567033,0.73787,0.603262,18.253829,0.060669
no_hist,scores,0.002217,0.75434,0.553618,0.744159,0.603478,17.461967,0.06069
no_proc,scores,0.001305,0.733807,0.51934,0.721281,0.576155,13.977458,0.072855


In [21]:
retain_metrics.to_latex()

'\\begin{tabular}{llrrrrrrr}\n\\toprule\n &  & accuracy & precision_samples & recall_samples & pr_auc_samples & f1_samples & avg_dpv & ddi_rate \\\\\n\\midrule\ndrug_recommendation & scores & 0.002042 & 0.757814 & 0.546485 & 0.753406 & 0.604757 & 16.496835 & 0.062308 \\\\\n\\cline{1-9}\nno_hist & scores & 0.002217 & 0.775676 & 0.513036 & 0.747408 & 0.590202 & 14.570555 & 0.066517 \\\\\n\\cline{1-9}\nno_proc & scores & 0.000834 & 0.759829 & 0.464477 & 0.712385 & 0.547581 & 11.440510 & 0.085610 \\\\\n\\cline{1-9}\n\\bottomrule\n\\end{tabular}\n'

In [22]:
gamenet_metrics.to_latex()

'\\begin{tabular}{llrrrrrrr}\n\\toprule\n &  & accuracy & precision_samples & recall_samples & pr_auc_samples & f1_samples & avg_dpv & ddi_rate \\\\\n\\midrule\ndrug_recommendation & scores & 0.002246 & 0.720763 & 0.567033 & 0.737870 & 0.603262 & 18.253829 & 0.060669 \\\\\n\\cline{1-9}\nno_hist & scores & 0.002217 & 0.754340 & 0.553618 & 0.744159 & 0.603478 & 17.461967 & 0.060690 \\\\\n\\cline{1-9}\nno_proc & scores & 0.001305 & 0.733807 & 0.519340 & 0.721281 & 0.576155 & 13.977458 & 0.072855 \\\\\n\\cline{1-9}\n\\bottomrule\n\\end{tabular}\n'