## Mortality Training
1. Load the MIMIC III Dataset
2. Normalize Data
3. Load into Pytorch Dataloader
4. Train
5. Evaluate

In [3]:
from comet_ml import Experiment
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from core.model import SAnD
from mimic3_benchmarks.mimic3benchmark.readers import InHospitalMortalityReader as Reader
from mimic3_benchmarks.mimic3models.preprocessing import Discretizer, Normalizer
from utils.ihm_utils import load_data
from utils.functions import get_weighted_sampler, get_weights
from utils.trainer import NeuralNetworkClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
task = "in-hospital-mortality"

## Load Data
Using the [MimicIII Benchmark's](https://github.com/YerevaNN/mimic3-benchmarks) InHospitalMortalityReader and associated functions, load the data into memory.

If you are following along you will need to arrange the data following the directions on the link.


In [4]:
train_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/train",
                              listfile=f"mimic3_benchmarks/data/{task}/train_listfile.csv")
val_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/train",
                              listfile=f"mimic3_benchmarks/data/{task}/val_listfile.csv")
test_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/test",
                              listfile=f"mimic3_benchmarks/data/{task}/test_listfile.csv")

### Example data

In [13]:
ex = pd.DataFrame.from_dict(train_reader.read_example(0)['X'])
ex.columns = train_reader.read_example(0)["header"]
ex

Unnamed: 0,Hours,Capillary refill rate,Diastolic blood pressure,Fraction inspired oxygen,Glascow coma scale eye opening,Glascow coma scale motor response,Glascow coma scale total,Glascow coma scale verbal response,Glucose,Heart Rate,Height,Mean blood pressure,Oxygen saturation,Respiratory rate,Systolic blood pressure,Temperature,Weight,pH
0,0.2138888888888889,,71.0,,,,,,,94,,95,92.0,16,141.0,,,
1,0.46388888888888885,,71.0,,,,,,,91,,95,93.0,16,141.0,,,
2,0.7138888888888889,,,,,,,,,89,,,93.0,19,,,,
3,0.9638888888888889,,,,,,,,,89,,,94.0,13,,,,
4,1.1305555555555555,,,,,,,,,,,,,,,,,7.51
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
66,44.21388888888889,,58.0,,,,,,,103,,81,99.0,21,113.0,,,
67,45.21388888888889,,59.0,,4 Spontaneously,6 Obeys Commands,15,5 Oriented,,107,,81,97.0,22,116.0,,,
68,46.21388888888889,,61.0,,,,,,,100,,85,97.0,20,120.0,,,
69,47.21388888888889,,58.0,,,,,,,90,,83,100.0,18,123.0,37.0,,


### Normalize and Discretize data
From the MIMIC-III Benchmark repo

In [5]:
discretizer = Discretizer(timestep=1.0,
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

In [6]:
normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = 'train/ihm_ts1.0.input_str-previous.start_time-zero.normalizer'
normalizer.load_params(normalizer_state)

In [None]:
train_raw = load_data(train_reader, discretizer, normalizer)
val_raw = load_data(val_reader, discretizer, normalizer)
test_raw = load_data(val_reader, discretizer, normalizer)

N, seq_len, feature_count = train_raw[0].shape

In [7]:
batch_size = 256

train_ds = TensorDataset(train_raw[0], train_raw[1])
val_ds = TensorDataset(val_raw[0], val_raw[1])
test_ds = TensorDataset(test_raw[0], test_raw[1])

train_loader = DataLoader(train_ds, batch_size=batch_size)#sampler=get_weighted_sampler(y_train))
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size)

## Create model, Train, and Evaluate
Train the model, this will export results to Comet ML.
Creating is done in the same step such that a new experiment made each time
Evaluation is done in the same step such that the Comet ML experiment is also deleted.

In [8]:
n_heads = 8
factor = 12 # M
num_class = 2
num_layers = 2 # N
epochs = 30
betas = (0.9, 0.98)
lr = 0.0005
eps = 4e-09
weight_decay = 5e-4

experiment = Experiment(
    api_key="eQ3INeSsFGUYKahSdEtjhry42",
    project_name="general",
    workspace="samdoud"
)

clf = NeuralNetworkClassifier(
    SAnD(feature_count, seq_len, n_heads, factor, num_class, num_layers, dropout_rate=0.3),
    nn.CrossEntropyLoss(weight=torch.tensor(get_weights(train_raw[1], level=1), dtype=torch.float32).to(device=device)),
    optim.Adam, optimizer_config={
        "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
    experiment=experiment
)

clf.fit(
    {
        "train": train_loader,
        "val": val_loader
    },
    validation=True,
    epochs=epochs,
    verbose=True,
)

clf.evaluate(test_loader)
experiment.end()

COMET INFO: Experiment is live on comet.com https://www.comet.com/samdoud/general/975b737db72048e9a0730bae7fced099

[36mTraining[0m - Epochs: 001/030:  99%|█████████▉| 14592/14681 [00:04<00:00, 5969.89it/s]

[32mTrain finished. [0mAccuracy: 0.6446 MSE: 0.3554 AUROC: 0.6383 AUPRC: 0.4490


[36mTraining[0m - Epochs: 001/030: 100%|██████████| 14681/14681 [00:04<00:00, 3252.31it/s]


[31mValidation finished. [0mAccuracy: 0.4789 MSE: 0.1144 AUROC: 0.6697 AUPRC: 0.5690


[36mTraining[0m - Epochs: 002/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5918.43it/s]

[32mTrain finished. [0mAccuracy: 0.6971 MSE: 0.3029 AUROC: 0.6928 AUPRC: 0.4962


[36mTraining[0m - Epochs: 002/030: 100%|██████████| 14681/14681 [00:02<00:00, 6074.85it/s]


[31mValidation finished. [0mAccuracy: 0.4053 MSE: 0.1305 AUROC: 0.6358 AUPRC: 0.5690


[36mTraining[0m - Epochs: 003/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5646.81it/s]

[32mTrain finished. [0mAccuracy: 0.6934 MSE: 0.3066 AUROC: 0.6935 AUPRC: 0.4982


[36mTraining[0m - Epochs: 003/030: 100%|██████████| 14681/14681 [00:02<00:00, 5761.14it/s]


[31mValidation finished. [0mAccuracy: 0.5869 MSE: 0.0907 AUROC: 0.7089 AUPRC: 0.5616


[36mTraining[0m - Epochs: 004/030: 100%|██████████| 14681/14681 [00:02<00:00, 5267.14it/s]

[32mTrain finished. [0mAccuracy: 0.7049 MSE: 0.2951 AUROC: 0.7073 AUPRC: 0.5114


[36mTraining[0m - Epochs: 004/030: 100%|██████████| 14681/14681 [00:02<00:00, 5648.72it/s]


[31mValidation finished. [0mAccuracy: 0.8358 MSE: 0.0360 AUROC: 0.7164 AUPRC: 0.5162


[36mTraining[0m - Epochs: 005/030: 100%|██████████| 14681/14681 [00:02<00:00, 5289.62it/s]

[32mTrain finished. [0mAccuracy: 0.7143 MSE: 0.2857 AUROC: 0.7096 AUPRC: 0.5113


[36mTraining[0m - Epochs: 005/030: 100%|██████████| 14681/14681 [00:02<00:00, 5806.17it/s]


[31mValidation finished. [0mAccuracy: 0.6148 MSE: 0.0845 AUROC: 0.7241 AUPRC: 0.5670


[36mTraining[0m - Epochs: 006/030: 100%|██████████| 14681/14681 [00:02<00:00, 5152.20it/s]

[32mTrain finished. [0mAccuracy: 0.7372 MSE: 0.2628 AUROC: 0.7336 AUPRC: 0.5346


[36mTraining[0m - Epochs: 006/030: 100%|██████████| 14681/14681 [00:02<00:00, 5686.63it/s]


[31mValidation finished. [0mAccuracy: 0.7781 MSE: 0.0487 AUROC: 0.7392 AUPRC: 0.5346


[36mTraining[0m - Epochs: 007/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5920.88it/s]

[32mTrain finished. [0mAccuracy: 0.7390 MSE: 0.2610 AUROC: 0.7366 AUPRC: 0.5378


[36mTraining[0m - Epochs: 007/030: 100%|██████████| 14681/14681 [00:02<00:00, 5875.33it/s]


[31mValidation finished. [0mAccuracy: 0.7120 MSE: 0.0632 AUROC: 0.7532 AUPRC: 0.5650


[36mTraining[0m - Epochs: 008/030: 100%|██████████| 14681/14681 [00:02<00:00, 5478.01it/s]

[32mTrain finished. [0mAccuracy: 0.7491 MSE: 0.2509 AUROC: 0.7492 AUPRC: 0.5509


[36mTraining[0m - Epochs: 008/030: 100%|██████████| 14681/14681 [00:02<00:00, 5749.51it/s]


[31mValidation finished. [0mAccuracy: 0.7300 MSE: 0.0593 AUROC: 0.7568 AUPRC: 0.5644


[36mTraining[0m - Epochs: 009/030: 100%|██████████| 14681/14681 [00:02<00:00, 5182.58it/s]

[32mTrain finished. [0mAccuracy: 0.7523 MSE: 0.2477 AUROC: 0.7473 AUPRC: 0.5479


[36mTraining[0m - Epochs: 009/030: 100%|██████████| 14681/14681 [00:02<00:00, 5638.32it/s]


[31mValidation finished. [0mAccuracy: 0.4516 MSE: 0.1204 AUROC: 0.6635 AUPRC: 0.5763


[36mTraining[0m - Epochs: 010/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5410.01it/s]

[32mTrain finished. [0mAccuracy: 0.7740 MSE: 0.2260 AUROC: 0.7746 AUPRC: 0.5776


[36mTraining[0m - Epochs: 010/030: 100%|██████████| 14681/14681 [00:02<00:00, 5502.25it/s]


[31mValidation finished. [0mAccuracy: 0.4640 MSE: 0.1176 AUROC: 0.6668 AUPRC: 0.5736


[36mTraining[0m - Epochs: 011/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5630.19it/s]

[32mTrain finished. [0mAccuracy: 0.7836 MSE: 0.2164 AUROC: 0.7849 AUPRC: 0.5888


[36mTraining[0m - Epochs: 011/030: 100%|██████████| 14681/14681 [00:02<00:00, 5603.52it/s]


[31mValidation finished. [0mAccuracy: 0.4140 MSE: 0.1286 AUROC: 0.6350 AUPRC: 0.5633


[36mTraining[0m - Epochs: 012/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5841.47it/s]

[32mTrain finished. [0mAccuracy: 0.7918 MSE: 0.2082 AUROC: 0.7958 AUPRC: 0.6012


[36mTraining[0m - Epochs: 012/030: 100%|██████████| 14681/14681 [00:02<00:00, 5905.43it/s]


[31mValidation finished. [0mAccuracy: 0.5034 MSE: 0.1090 AUROC: 0.6674 AUPRC: 0.5535


[36mTraining[0m - Epochs: 013/030: 100%|██████████| 14681/14681 [00:02<00:00, 5329.86it/s]

[32mTrain finished. [0mAccuracy: 0.8090 MSE: 0.1910 AUROC: 0.8148 AUPRC: 0.6234


[36mTraining[0m - Epochs: 013/030: 100%|██████████| 14681/14681 [00:02<00:00, 5828.10it/s]


[31mValidation finished. [0mAccuracy: 0.5475 MSE: 0.0993 AUROC: 0.6793 AUPRC: 0.5453


[36mTraining[0m - Epochs: 014/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5711.42it/s]

[32mTrain finished. [0mAccuracy: 0.8177 MSE: 0.1823 AUROC: 0.8226 AUPRC: 0.6330


[36mTraining[0m - Epochs: 014/030: 100%|██████████| 14681/14681 [00:02<00:00, 5673.47it/s]


[31mValidation finished. [0mAccuracy: 0.4513 MSE: 0.1204 AUROC: 0.6469 AUPRC: 0.5570


[36mTraining[0m - Epochs: 015/030: 100%|██████████| 14681/14681 [00:02<00:00, 5442.94it/s]

[32mTrain finished. [0mAccuracy: 0.7887 MSE: 0.2113 AUROC: 0.7897 AUPRC: 0.5943


[36mTraining[0m - Epochs: 015/030: 100%|██████████| 14681/14681 [00:02<00:00, 5837.66it/s]


[31mValidation finished. [0mAccuracy: 0.6893 MSE: 0.0682 AUROC: 0.7101 AUPRC: 0.5201


[36mTraining[0m - Epochs: 016/030: 100%|██████████| 14681/14681 [00:02<00:00, 5338.46it/s]

[32mTrain finished. [0mAccuracy: 0.8149 MSE: 0.1851 AUROC: 0.8238 AUPRC: 0.6340


[36mTraining[0m - Epochs: 016/030: 100%|██████████| 14681/14681 [00:02<00:00, 5733.34it/s]


[31mValidation finished. [0mAccuracy: 0.7328 MSE: 0.0586 AUROC: 0.7013 AUPRC: 0.4959


[36mTraining[0m - Epochs: 017/030: 100%|██████████| 14681/14681 [00:02<00:00, 5274.19it/s]

[32mTrain finished. [0mAccuracy: 0.8329 MSE: 0.1671 AUROC: 0.8380 AUPRC: 0.6525


[36mTraining[0m - Epochs: 017/030: 100%|██████████| 14681/14681 [00:02<00:00, 5799.53it/s]


[31mValidation finished. [0mAccuracy: 0.7629 MSE: 0.0520 AUROC: 0.6917 AUPRC: 0.4776


[36mTraining[0m - Epochs: 018/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5892.91it/s]

[32mTrain finished. [0mAccuracy: 0.8434 MSE: 0.1566 AUROC: 0.8515 AUPRC: 0.6697


[36mTraining[0m - Epochs: 018/030: 100%|██████████| 14681/14681 [00:02<00:00, 5905.11it/s]


[31mValidation finished. [0mAccuracy: 0.7322 MSE: 0.0588 AUROC: 0.7039 AUPRC: 0.4992


[36mTraining[0m - Epochs: 019/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5898.01it/s]

[32mTrain finished. [0mAccuracy: 0.8563 MSE: 0.1437 AUROC: 0.8638 AUPRC: 0.6871


[36mTraining[0m - Epochs: 019/030: 100%|██████████| 14681/14681 [00:02<00:00, 5850.91it/s]


[31mValidation finished. [0mAccuracy: 0.7564 MSE: 0.0535 AUROC: 0.6821 AUPRC: 0.4665


[36mTraining[0m - Epochs: 020/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5608.84it/s]

[32mTrain finished. [0mAccuracy: 0.8577 MSE: 0.1423 AUROC: 0.8647 AUPRC: 0.6885


[36mTraining[0m - Epochs: 020/030: 100%|██████████| 14681/14681 [00:02<00:00, 5606.14it/s]


[31mValidation finished. [0mAccuracy: 0.8163 MSE: 0.0403 AUROC: 0.6626 AUPRC: 0.4421


[36mTraining[0m - Epochs: 021/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5898.98it/s]

[32mTrain finished. [0mAccuracy: 0.8500 MSE: 0.1500 AUROC: 0.8570 AUPRC: 0.6777


[36mTraining[0m - Epochs: 021/030: 100%|██████████| 14681/14681 [00:02<00:00, 5883.74it/s]


[31mValidation finished. [0mAccuracy: 0.8243 MSE: 0.0386 AUROC: 0.6556 AUPRC: 0.4361


[36mTraining[0m - Epochs: 022/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5918.27it/s]

[32mTrain finished. [0mAccuracy: 0.8544 MSE: 0.1456 AUROC: 0.8587 AUPRC: 0.6809


[36mTraining[0m - Epochs: 022/030: 100%|██████████| 14681/14681 [00:02<00:00, 5781.50it/s]


[31mValidation finished. [0mAccuracy: 0.7905 MSE: 0.0460 AUROC: 0.6719 AUPRC: 0.4503


[36mTraining[0m - Epochs: 023/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5737.06it/s]

[32mTrain finished. [0mAccuracy: 0.8405 MSE: 0.1595 AUROC: 0.8424 AUPRC: 0.6591


[36mTraining[0m - Epochs: 023/030: 100%|██████████| 14681/14681 [00:02<00:00, 5796.43it/s]


[31mValidation finished. [0mAccuracy: 0.8380 MSE: 0.0356 AUROC: 0.6500 AUPRC: 0.4370


[36mTraining[0m - Epochs: 024/030: 100%|██████████| 14681/14681 [00:02<00:00, 5287.82it/s]

[32mTrain finished. [0mAccuracy: 0.8525 MSE: 0.1475 AUROC: 0.8557 AUPRC: 0.6771


[36mTraining[0m - Epochs: 024/030: 100%|██████████| 14681/14681 [00:02<00:00, 5740.77it/s]


[31mValidation finished. [0mAccuracy: 0.8442 MSE: 0.0342 AUROC: 0.6294 AUPRC: 0.4163


[36mTraining[0m - Epochs: 025/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5887.52it/s]

[32mTrain finished. [0mAccuracy: 0.8813 MSE: 0.1187 AUROC: 0.8842 AUPRC: 0.7203


[36mTraining[0m - Epochs: 025/030: 100%|██████████| 14681/14681 [00:02<00:00, 5910.99it/s]


[31mValidation finished. [0mAccuracy: 0.8184 MSE: 0.0398 AUROC: 0.6532 AUPRC: 0.4306


[36mTraining[0m - Epochs: 026/030: 100%|██████████| 14681/14681 [00:02<00:00, 5470.85it/s]

[32mTrain finished. [0mAccuracy: 0.9012 MSE: 0.0988 AUROC: 0.9070 AUPRC: 0.7565


[36mTraining[0m - Epochs: 026/030: 100%|██████████| 14681/14681 [00:02<00:00, 5836.42it/s]


[31mValidation finished. [0mAccuracy: 0.7840 MSE: 0.0474 AUROC: 0.6797 AUPRC: 0.4604


[36mTraining[0m - Epochs: 027/030: 100%|██████████| 14681/14681 [00:02<00:00, 5445.62it/s]

[32mTrain finished. [0mAccuracy: 0.8999 MSE: 0.1001 AUROC: 0.9081 AUPRC: 0.7564


[36mTraining[0m - Epochs: 027/030: 100%|██████████| 14681/14681 [00:02<00:00, 5907.00it/s]


[31mValidation finished. [0mAccuracy: 0.7921 MSE: 0.0456 AUROC: 0.6718 AUPRC: 0.4503


[36mTraining[0m - Epochs: 028/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5778.78it/s]

[32mTrain finished. [0mAccuracy: 0.9061 MSE: 0.0939 AUROC: 0.9111 AUPRC: 0.7646


[36mTraining[0m - Epochs: 028/030: 100%|██████████| 14681/14681 [00:02<00:00, 5919.60it/s]


[31mValidation finished. [0mAccuracy: 0.7806 MSE: 0.0482 AUROC: 0.6690 AUPRC: 0.4468


[36mTraining[0m - Epochs: 029/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5996.23it/s]

[32mTrain finished. [0mAccuracy: 0.9074 MSE: 0.0926 AUROC: 0.9127 AUPRC: 0.7672


[36mTraining[0m - Epochs: 029/030: 100%|██████████| 14681/14681 [00:02<00:00, 5895.75it/s]


[31mValidation finished. [0mAccuracy: 0.8383 MSE: 0.0355 AUROC: 0.6473 AUPRC: 0.4338


[36mTraining[0m - Epochs: 030/030:  99%|█████████▉| 14592/14681 [00:02<00:00, 5826.17it/s]

[32mTrain finished. [0mAccuracy: 0.8976 MSE: 0.1024 AUROC: 0.8992 AUPRC: 0.7463


[36mTraining[0m - Epochs: 030/030: 100%|██████████| 14681/14681 [00:02<00:00, 5750.56it/s]


[31mValidation finished. [0mAccuracy: 0.7831 MSE: 0.0476 AUROC: 0.6898 AUPRC: 0.4733


[32mEvaluating[0m: 100%|██████████| 3222/3222 [00:00<00:00, 11385.22it/s]
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/samdoud/general/975b737db72048e9a0730bae7fced099
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     test_AUPRC [13]         : (0.4699000046633014, 0.523934503496798)
COMET INFO:     test_AUROC [13]         : (0.68517571313601, 0.705429292929293)
COMET INFO:     test_MSE [13]           : (0.216796875, 0.234375)
COMET INFO:     test_accuracy [13]      : (0.765625, 0.783203125)
COMET INFO:     test_loss [13]          : (1.9413737058639526, 26.193230867385864)
COMET INFO:     train_AUPRC [1740]      : (0.2807586733367984, 0.8437258126934984)
COMET INFO:     train_AUROC [1740]      : (0.46840676292731087, 0.9425013248542661)
COMET INFO:     train_MSE [1740]        : (0.

[35mEvaluation finished. [0mAccuracy: 0.7831 MSE: 0.2169 AUROC: 0.6898 AUPRC: 0.4733


COMET INFO: Uploading metrics, params, and assets to Comet before program termination (may take several seconds)
COMET INFO: The Python SDK has 3600 seconds to finish before aborting...
