## Phenotyping Training
1. Load the MIMIC III Dataset
2. Normalize Data
3. Load into PyTorch DataLoader
4. Train
5. Evaluate

In [1]:
import os

# comet_ml is kept ahead of torch/sklearn: Comet's auto-logging expects to be imported first.
from comet_ml import Experiment
import numpy as np
import pandas as pd
# NOTE(review): StandardScaler's public home is sklearn.preprocessing; import path left unchanged here.
from sklearn.discriminant_analysis import StandardScaler
from sklearn.impute import SimpleImputer
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from core.model import SAnD
from mimic3_benchmarks.mimic3benchmark.readers import PhenotypingReader as Reader
from mimic3_benchmarks.mimic3models.preprocessing import Discretizer, Normalizer
from utils.functions import get_weighted_sampler, get_weights
from utils.trainer import NeuralNetworkClassifier
from utils.batch_gen import BatchGen
from utils.phenotyping_utils import load_data


# Seed the global NumPy and PyTorch RNGs so that random draws (e.g. parameter
# initialisation and DataLoader shuffling) start from the same state on every run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Benchmark task name; interpolated into the dataset paths when the readers are built.
task = "phenotyping"

## Load Data
Using the [MIMIC-III Benchmark's](https://github.com/YerevaNN/mimic3-benchmarks) `PhenotypingReader` and associated functions, load the data into memory.

If you are following along you will need to arrange the data following the directions on the link.


In [2]:
# One reader per split. The validation reader points at the *train* directory
# and is distinguished only by its listfile (val_listfile.csv).
train_reader = Reader(
    dataset_dir=f"mimic3_benchmarks/data/{task}/train",
    listfile=f"mimic3_benchmarks/data/{task}/train_listfile.csv",
)
val_reader = Reader(
    dataset_dir=f"mimic3_benchmarks/data/{task}/train",
    listfile=f"mimic3_benchmarks/data/{task}/val_listfile.csv",
)
test_reader = Reader(
    dataset_dir=f"mimic3_benchmarks/data/{task}/test",
    listfile=f"mimic3_benchmarks/data/{task}/test_listfile.csv",
)

### Example data

In [3]:
# Fetch the first training example once and reuse it for the features, the
# header, and the labels instead of calling read_example(0) three times.
example = train_reader.read_example(0)

# Build a preview table of the raw time series, labelled with the reader's header.
ex = pd.DataFrame.from_dict(example['X'])
ex.columns = example["header"]

# Show the example's label vector, then display the table as the cell output.
print(example['y'])
ex

[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


Unnamed: 0,Hours,Capillary refill rate,Diastolic blood pressure,Fraction inspired oxygen,Glascow coma scale eye opening,Glascow coma scale motor response,Glascow coma scale total,Glascow coma scale verbal response,Glucose,Heart Rate,Height,Mean blood pressure,Oxygen saturation,Respiratory rate,Systolic blood pressure,Temperature,Weight,pH
0,0.3052777777777777,,86.0,,1 No Response,5 Localizes Pain,7.0,1.0 ET/Trach,,92.0,,104.0,100.0,14.0,141.0,,,
1,0.4719444444444444,,,,,,,,142.0,,,,100.0,14.0,,,,
2,0.6719444444444445,,,,,,,,,,,,,,,,,7.43
3,1.3052777777777778,,90.0,,,,,,,98.0,,109.0,100.0,14.0,144.0,38.16669845581055,,
4,2.305277777777777,,90.0,,,,,,,99.0,,109.0,100.0,14.0,147.0,,,
5,3.305277777777777,,93.0,,,,,,,106.0,,113.0,100.0,14.0,152.0,39.11111195882162,,
6,4.3052777777777775,,88.0,,,,,,,106.0,,106.0,100.0,14.0,145.0,,,
7,5.3052777777777775,,86.0,,3 To speech,5 Localizes Pain,9.0,1.0 ET/Trach,,103.0,,103.0,100.0,14.0,141.0,39.11111195882162,,
8,5.8052777777777775,,85.0,,,,,,,112.0,,102.0,100.0,16.0,135.0,,,
9,6.3052777777777775,,88.0,,,,,,,105.0,,110.0,100.0,9.0,151.0,,,


### Normalize and Discretize data
From the MIMIC-III Benchmark repo

In [4]:
# Discretizer from the MIMIC-III benchmark code, configured with a timestep of
# 1.0, mask storage enabled, 'previous' imputation, and start time 'zero'.
discretizer = Discretizer(
    timestep=1.0,
    store_masks=True,
    impute_strategy='previous',
    start_time='zero',
)

# Element 1 of transform()'s result is a comma-separated header string;
# split it into one name per discretized channel.
discretizer_header = (
    discretizer
    .transform(train_reader.read_example(0)["X"])[1]
    .split(',')
)

# Indices of the header entries whose name does not contain "->".
# NOTE(review): presumably "->" marks one-hot categorical channels, leaving the
# continuous ones here — confirm against the Discretizer implementation.
cont_channels = [idx for idx, name in enumerate(discretizer_header) if "->" not in name]

In [5]:
# Path of the saved normalizer parameters that are loaded below.
normalizer_state = 'train/ph_ts1.0.input_str-previous.start_time-zero.normalizer'

# Only the channel indices in cont_channels are passed as the fields to standardize.
normalizer = Normalizer(fields=cont_channels)
normalizer.load_params(normalizer_state)

In [6]:
batch_size = 128
max_seq_len = 256
small_part = False  # forwarded to load_data unchanged

# Run every split through the same discretizer/normalizer pipeline.
# Elements 0 and 1 of each result are tensors (presumably inputs and labels —
# confirm against utils.phenotyping_utils.load_data).
train_raw = load_data(train_reader, discretizer, normalizer, max_seq_len, small_part)
val_raw = load_data(val_reader, discretizer, normalizer, max_seq_len, small_part)
test_raw = load_data(test_reader, discretizer, normalizer, max_seq_len, small_part)

train_ds = TensorDataset(train_raw[0], train_raw[1])
val_ds = TensorDataset(val_raw[0], val_raw[1])
test_ds = TensorDataset(test_raw[0], test_raw[1])

# Shuffle only the training split so each epoch sees the batches in a new order.
# Validation and test keep their fixed order (DataLoader's default), which makes
# their evaluation passes deterministic.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)
test_loader = DataLoader(test_ds, batch_size=batch_size)

# The training inputs are a 3-D tensor; the last two dimensions are used later
# to size the model.
N, seq_len, feature_count = train_raw[0].shape
print(N, seq_len, feature_count)

## Create model, Train, and Evaluate
Train the model; this will export results to Comet ML.
The model is created in the same step so that a new experiment is made each time.
Evaluation is done in the same step so that the Comet ML experiment is also ended.

In [None]:
# Model and optimizer hyperparameters.
n_heads = 8
factor = 12 # M
num_class = 25
num_layers = 2 # N
epochs = 30
betas = (0.9, 0.98)
lr = 0.0005
eps = 4e-09
weight_decay = 5e-4

# Never commit the Comet API key: read it from the COMET_API_KEY environment
# variable. When the variable is unset, None is passed and Comet falls back to
# its own configuration lookup. The key that used to be hardcoded here should be
# revoked, since it has been published with the notebook.
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="general",
    workspace="samdoud"
)

try:
    # A fresh model/classifier is built in the same cell as the experiment so
    # that every run of this cell logs to a new experiment.
    clf = NeuralNetworkClassifier(
        SAnD(feature_count, seq_len, n_heads, factor, num_class, num_layers, dropout_rate=0.3),
        nn.BCEWithLogitsLoss(),
        optim.Adam, optimizer_config={
            "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
        experiment=experiment
    )

    clf.fit(
        {
            "train": train_loader,
            "val": val_loader
        },
        validation=True,
        epochs=epochs,
        verbose=True,
    )

    clf.evaluate(test_loader)
finally:
    # Always close the experiment, even if training or evaluation raises, so a
    # failed run does not leave a live experiment behind.
    experiment.end()

COMET INFO: Experiment is live on comet.com https://www.comet.com/samdoud/general/90e532db41844621b578580c3e220a22

[36mTraining[0m - Epochs: 001/030: 100%|██████████| 1000/1000 [00:02<00:00, 645.53it/s]

[32mTrain finished. [0mAccuracy: 0.5275 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 001/030: 100%|██████████| 1000/1000 [00:02<00:00, 368.77it/s]


[31mValidation finished. [0mAccuracy: 0.5832 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 002/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2612.94it/s]

[32mTrain finished. [0mAccuracy: 0.5424 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 002/030: 100%|██████████| 1000/1000 [00:00<00:00, 1555.29it/s]


[31mValidation finished. [0mAccuracy: 0.6001 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 003/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2661.85it/s]

[32mTrain finished. [0mAccuracy: 0.5792 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 003/030: 100%|██████████| 1000/1000 [00:00<00:00, 1548.37it/s]


[31mValidation finished. [0mAccuracy: 0.5910 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 004/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2679.53it/s]

[32mTrain finished. [0mAccuracy: 0.5763 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 004/030: 100%|██████████| 1000/1000 [00:00<00:00, 1566.76it/s]


[31mValidation finished. [0mAccuracy: 0.5764 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 005/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2705.46it/s]

[32mTrain finished. [0mAccuracy: 0.5690 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 005/030: 100%|██████████| 1000/1000 [00:00<00:00, 1579.67it/s]


[31mValidation finished. [0mAccuracy: 0.6174 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 006/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2694.57it/s]

[32mTrain finished. [0mAccuracy: 0.5903 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 006/030: 100%|██████████| 1000/1000 [00:00<00:00, 1571.49it/s]


[31mValidation finished. [0mAccuracy: 0.6148 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 007/030: 100%|██████████| 1000/1000 [00:00<00:00, 1579.08it/s]

[32mTrain finished. [0mAccuracy: 0.5997 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.5917 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 008/030: 100%|██████████| 1000/1000 [00:00<00:00, 1558.72it/s]

[32mTrain finished. [0mAccuracy: 0.5924 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6179 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 009/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2712.27it/s]

[32mTrain finished. [0mAccuracy: 0.6218 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 009/030: 100%|██████████| 1000/1000 [00:00<00:00, 1552.83it/s]


[31mValidation finished. [0mAccuracy: 0.6283 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 010/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2646.94it/s]

[32mTrain finished. [0mAccuracy: 0.6305 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 010/030: 100%|██████████| 1000/1000 [00:00<00:00, 1559.30it/s]


[31mValidation finished. [0mAccuracy: 0.6220 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 011/030: 100%|██████████| 1000/1000 [00:00<00:00, 1576.03it/s]

[32mTrain finished. [0mAccuracy: 0.6079 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6268 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 012/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2651.12it/s]

[32mTrain finished. [0mAccuracy: 0.6341 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 012/030: 100%|██████████| 1000/1000 [00:00<00:00, 1561.48it/s]


[31mValidation finished. [0mAccuracy: 0.6318 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 013/030: 100%|██████████| 1000/1000 [00:00<00:00, 1574.93it/s]

[32mTrain finished. [0mAccuracy: 0.6284 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6186 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 014/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2730.00it/s]

[32mTrain finished. [0mAccuracy: 0.6109 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 014/030: 100%|██████████| 1000/1000 [00:00<00:00, 1595.57it/s]


[31mValidation finished. [0mAccuracy: 0.6291 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 015/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2705.83it/s]

[32mTrain finished. [0mAccuracy: 0.6405 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 015/030: 100%|██████████| 1000/1000 [00:00<00:00, 1584.04it/s]


[31mValidation finished. [0mAccuracy: 0.6299 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 016/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2694.78it/s]

[32mTrain finished. [0mAccuracy: 0.6432 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 016/030: 100%|██████████| 1000/1000 [00:00<00:00, 1587.57it/s]


[31mValidation finished. [0mAccuracy: 0.6406 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 017/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2689.63it/s]

[32mTrain finished. [0mAccuracy: 0.6493 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 017/030: 100%|██████████| 1000/1000 [00:00<00:00, 1569.29it/s]


[31mValidation finished. [0mAccuracy: 0.6334 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 018/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2733.10it/s]

[32mTrain finished. [0mAccuracy: 0.6389 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 018/030: 100%|██████████| 1000/1000 [00:00<00:00, 1605.61it/s]


[31mValidation finished. [0mAccuracy: 0.6472 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 019/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2722.28it/s]

[32mTrain finished. [0mAccuracy: 0.6513 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 019/030: 100%|██████████| 1000/1000 [00:00<00:00, 1587.38it/s]


[31mValidation finished. [0mAccuracy: 0.6554 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 020/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2657.38it/s]

[32mTrain finished. [0mAccuracy: 0.6465 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 020/030: 100%|██████████| 1000/1000 [00:00<00:00, 1545.37it/s]


[31mValidation finished. [0mAccuracy: 0.6496 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 021/030: 100%|██████████| 1000/1000 [00:00<00:00, 2223.46it/s]

[32mTrain finished. [0mAccuracy: 0.6468 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 021/030: 100%|██████████| 1000/1000 [00:00<00:00, 1416.22it/s]


[31mValidation finished. [0mAccuracy: 0.6472 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 022/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2683.81it/s]

[32mTrain finished. [0mAccuracy: 0.6705 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 022/030: 100%|██████████| 1000/1000 [00:00<00:00, 1583.86it/s]


[31mValidation finished. [0mAccuracy: 0.6397 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 023/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2689.68it/s]

[32mTrain finished. [0mAccuracy: 0.6829 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 023/030: 100%|██████████| 1000/1000 [00:00<00:00, 1571.24it/s]


[31mValidation finished. [0mAccuracy: 0.6458 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 024/030: 100%|██████████| 1000/1000 [00:00<00:00, 1588.92it/s]

[32mTrain finished. [0mAccuracy: 0.6862 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6441 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 025/030: 100%|██████████| 1000/1000 [00:00<00:00, 1597.57it/s]

[32mTrain finished. [0mAccuracy: 0.6776 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6541 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 026/030: 100%|██████████| 1000/1000 [00:00<00:00, 1588.53it/s]

[32mTrain finished. [0mAccuracy: 0.6910 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6384 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 027/030: 100%|██████████| 1000/1000 [00:00<00:00, 1592.55it/s]

[32mTrain finished. [0mAccuracy: 0.7055 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6394 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 028/030: 100%|██████████| 1000/1000 [00:00<00:00, 1563.07it/s]

[32mTrain finished. [0mAccuracy: 0.6967 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6380 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[36mTraining[0m - Epochs: 029/030:  90%|████████▉ | 896/1000 [00:00<00:00, 2740.68it/s]

[32mTrain finished. [0mAccuracy: 0.7082 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6460 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 029/030: 100%|██████████| 1000/1000 [00:00<00:00, 1602.02it/s]
[36mTraining[0m - Epochs: 030/030: 100%|██████████| 1000/1000 [00:00<00:00, 1545.18it/s]

[32mTrain finished. [0mAccuracy: 0.6803 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000
[31mValidation finished. [0mAccuracy: 0.6445 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000



[32mEvaluating[0m: 100%|██████████| 1000/1000 [00:00<00:00, 5074.37it/s]
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/samdoud/general/90e532db41844621b578580c3e220a22
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     test_loss [8]          : (3.0554118156433105, 24.504605054855347)
COMET INFO:     train_accuracy [30]    : (0.5274913398166459, 0.7082459128123378)
COMET INFO:     train_loss [264]       : (1.197020411491394, 44.14838790893555)
COMET INFO:     validate_accuracy [30] : (0.5763724518267537, 0.6554216928311779)
COMET INFO:     validate_loss [240]    : (1.5009020566940308, 24.989328384399414)
COMET INFO:   Parameters:
COMET INFO:     batch_size    : 128
COMET INFO:     betas         : (0.9, 0.98)
COMET INFO:     epochs        : 30
COMET INFO:     eps           : 4e-09
CO

error on test
[35mEvaluation finished. [0mAccuracy: 0.6445 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...
