## Mortality Training
1. Load the MIMIC III Dataset
2. Normalize Data
3. Load into Pytorch Dataloader
4. Train
5. Evaluate

In [12]:
from comet_ml import Experiment
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from core.model import SAnD
from mimic3_benchmarks.mimic3benchmark.readers import PhenotypingReader as Reader
from mimic3_benchmarks.mimic3models.preprocessing import Discretizer, Normalizer
from utils.functions import get_weighted_sampler, get_weights
from utils.trainer import NeuralNetworkClassifier
from utils.batch_gen import BatchGen
from utils.phenotyping_utils import load_data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
task = "phenotyping"

## Load Data
Using the [MimicIII Benchmark's](https://github.com/YerevaNN/mimic3-benchmarks) InHospitalMortalityReader and associated functions, load the data into memory.

If you are following along you will need to arrange the data following the directions on the link.


In [2]:
train_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/train",
                              listfile=f"mimic3_benchmarks/data/{task}/train_listfile.csv")
val_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/train",
                              listfile=f"mimic3_benchmarks/data/{task}/val_listfile.csv")
test_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/test",
                              listfile=f"mimic3_benchmarks/data/{task}/test_listfile.csv")

### Example data

In [3]:
ex = pd.DataFrame.from_dict(train_reader.read_example(0)['X'])
ex.columns = train_reader.read_example(0)["header"]
print(train_reader.read_example(0)['y'])
ex

[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


Unnamed: 0,Hours,Capillary refill rate,Diastolic blood pressure,Fraction inspired oxygen,Glascow coma scale eye opening,Glascow coma scale motor response,Glascow coma scale total,Glascow coma scale verbal response,Glucose,Heart Rate,Height,Mean blood pressure,Oxygen saturation,Respiratory rate,Systolic blood pressure,Temperature,Weight,pH
0,0.3052777777777777,,86.0,,1 No Response,5 Localizes Pain,7.0,1.0 ET/Trach,,92.0,,104.0,100.0,14.0,141.0,,,
1,0.4719444444444444,,,,,,,,142.0,,,,100.0,14.0,,,,
2,0.6719444444444445,,,,,,,,,,,,,,,,,7.43
3,1.3052777777777778,,90.0,,,,,,,98.0,,109.0,100.0,14.0,144.0,38.16669845581055,,
4,2.305277777777777,,90.0,,,,,,,99.0,,109.0,100.0,14.0,147.0,,,
5,3.305277777777777,,93.0,,,,,,,106.0,,113.0,100.0,14.0,152.0,39.11111195882162,,
6,4.3052777777777775,,88.0,,,,,,,106.0,,106.0,100.0,14.0,145.0,,,
7,5.3052777777777775,,86.0,,3 To speech,5 Localizes Pain,9.0,1.0 ET/Trach,,103.0,,103.0,100.0,14.0,141.0,39.11111195882162,,
8,5.8052777777777775,,85.0,,,,,,,112.0,,102.0,100.0,16.0,135.0,,,
9,6.3052777777777775,,88.0,,,,,,,105.0,,110.0,100.0,9.0,151.0,,,


### Normalize and Discretize data
From the MIMIC-III Benchmark repo

In [4]:
discretizer = Discretizer(timestep=1.0,
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

In [5]:
normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = 'train/ph_ts1.0.input_str-previous.start_time-zero.normalizer'
normalizer.load_params(normalizer_state)

In [15]:
batch_size = 128
max_seq_len = 256
small_part = True

train_raw = load_data(train_reader, discretizer, normalizer, max_seq_len, small_part)
val_raw = load_data(val_reader, discretizer, normalizer, max_seq_len, small_part)
test_raw = load_data(test_reader, discretizer, normalizer, max_seq_len, small_part)

train_ds = TensorDataset(train_raw[0], train_raw[1])
val_ds = TensorDataset(val_raw[0], val_raw[1])
test_ds = TensorDataset(test_raw[0], test_raw[1])

train_loader = DataLoader(train_ds, batch_size=batch_size)#sampler=get_weighted_sampler(y_train))
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size)

N, seq_len, feature_count = train_raw[0].shape
print(N, seq_len, feature_count)

ValueError: not enough values to unpack (expected 3, got 2)

## Create model, Train, and Evaluate
Train the model, this will export results to Comet ML.
Creating is done in the same step such that a new experiment made each time
Evaluation is done in the same step such that the Comet ML experiment is also deleted.

In [13]:
n_heads = 8
factor = 120 # M
num_class = 25
num_layers = 2 # N
epochs = 30
betas = (0.9, 0.98)
lr = 0.0005
eps = 4e-09
weight_decay = 5e-4
dropout = 0.4

experiment = Experiment(
    api_key="eQ3INeSsFGUYKahSdEtjhry42",
    project_name="general",
    workspace="samdoud"
)

clf = NeuralNetworkClassifier(
    SAnD(feature_count, seq_len, n_heads, factor, num_class, num_layers, dropout_rate=dropout),
    nn.BCEWithLogitsLoss(),
    optim.Adam, optimizer_config={
        "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
    experiment=experiment
)

clf.fit(
    {
        "train": train_loader,
        "val": val_loader
    },
    validation=True,
    epochs=epochs,
    verbose=True,
)

clf.evaluate(test_loader)
experiment.end()

COMET INFO: Experiment is live on comet.com https://www.comet.com/samdoud/general/c6305004fc1649369bf5560d61f3697d

[36mTraining[0m - Epochs: 001/030: 100%|█████████▉| 29184/29250 [08:59<00:00, 2363.28it/s]

[32mTrain finished. [0mAccuracy: 0.5455 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 001/030: 100%|██████████| 29250/29250 [09:00<00:00, 54.07it/s]  


[31mValidation finished. [0mAccuracy: 0.6646 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 002/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2537.39it/s]

[32mTrain finished. [0mAccuracy: 0.5493 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 002/030: 100%|██████████| 29250/29250 [00:13<00:00, 2245.95it/s]


[31mValidation finished. [0mAccuracy: 0.6740 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 003/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2535.26it/s]

[32mTrain finished. [0mAccuracy: 0.5467 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 003/030: 100%|██████████| 29250/29250 [00:13<00:00, 2244.42it/s]


[31mValidation finished. [0mAccuracy: 0.6878 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 004/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2639.01it/s]

[32mTrain finished. [0mAccuracy: 0.5590 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 004/030: 100%|██████████| 29250/29250 [00:12<00:00, 2286.83it/s]


[31mValidation finished. [0mAccuracy: 0.6797 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 005/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2626.42it/s]

[32mTrain finished. [0mAccuracy: 0.5594 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 005/030: 100%|██████████| 29250/29250 [00:12<00:00, 2302.32it/s]


[31mValidation finished. [0mAccuracy: 0.6922 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 006/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2633.24it/s]

[32mTrain finished. [0mAccuracy: 0.5662 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 006/030: 100%|██████████| 29250/29250 [00:12<00:00, 2320.28it/s]


[31mValidation finished. [0mAccuracy: 0.6881 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 007/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2631.08it/s]

[32mTrain finished. [0mAccuracy: 0.5723 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 007/030: 100%|██████████| 29250/29250 [00:12<00:00, 2292.99it/s]


[31mValidation finished. [0mAccuracy: 0.7023 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 008/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2571.81it/s]

[32mTrain finished. [0mAccuracy: 0.5761 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 008/030: 100%|██████████| 29250/29250 [00:12<00:00, 2288.91it/s]


[31mValidation finished. [0mAccuracy: 0.7063 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 009/030: 100%|██████████| 29250/29250 [00:11<00:00, 2482.60it/s]

[32mTrain finished. [0mAccuracy: 0.5786 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 009/030: 100%|██████████| 29250/29250 [00:12<00:00, 2273.28it/s]


[31mValidation finished. [0mAccuracy: 0.7052 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 010/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2583.22it/s]

[32mTrain finished. [0mAccuracy: 0.5841 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 010/030: 100%|██████████| 29250/29250 [00:12<00:00, 2306.14it/s]


[31mValidation finished. [0mAccuracy: 0.7101 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 011/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2552.79it/s]

[32mTrain finished. [0mAccuracy: 0.5928 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 011/030: 100%|██████████| 29250/29250 [00:12<00:00, 2284.89it/s]


[31mValidation finished. [0mAccuracy: 0.7091 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 012/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2614.22it/s]

[32mTrain finished. [0mAccuracy: 0.6121 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 012/030: 100%|██████████| 29250/29250 [00:12<00:00, 2279.54it/s]


[31mValidation finished. [0mAccuracy: 0.7106 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 013/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2617.21it/s]

[32mTrain finished. [0mAccuracy: 0.6172 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 013/030: 100%|██████████| 29250/29250 [00:12<00:00, 2322.38it/s]


[31mValidation finished. [0mAccuracy: 0.7100 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 014/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2548.78it/s]

[32mTrain finished. [0mAccuracy: 0.6277 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 014/030: 100%|██████████| 29250/29250 [00:12<00:00, 2313.29it/s]


[31mValidation finished. [0mAccuracy: 0.7113 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 015/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2611.07it/s]

[32mTrain finished. [0mAccuracy: 0.6387 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 015/030: 100%|██████████| 29250/29250 [00:12<00:00, 2330.26it/s]


[31mValidation finished. [0mAccuracy: 0.7154 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 016/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2598.44it/s]

[32mTrain finished. [0mAccuracy: 0.6452 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 016/030: 100%|██████████| 29250/29250 [00:12<00:00, 2309.82it/s]


[31mValidation finished. [0mAccuracy: 0.7172 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 017/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2599.53it/s]

[32mTrain finished. [0mAccuracy: 0.6494 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 017/030: 100%|██████████| 29250/29250 [00:12<00:00, 2312.55it/s]


[31mValidation finished. [0mAccuracy: 0.7189 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 018/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2613.53it/s]

[32mTrain finished. [0mAccuracy: 0.6513 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 018/030: 100%|██████████| 29250/29250 [00:12<00:00, 2329.54it/s]


[31mValidation finished. [0mAccuracy: 0.7190 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 019/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2584.54it/s]

[32mTrain finished. [0mAccuracy: 0.6533 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 019/030: 100%|██████████| 29250/29250 [00:12<00:00, 2305.71it/s]


[31mValidation finished. [0mAccuracy: 0.7210 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 020/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2483.54it/s]

[32mTrain finished. [0mAccuracy: 0.6574 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 020/030: 100%|██████████| 29250/29250 [00:12<00:00, 2252.51it/s]


[31mValidation finished. [0mAccuracy: 0.7218 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 021/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2605.11it/s]

[32mTrain finished. [0mAccuracy: 0.6592 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 021/030: 100%|██████████| 29250/29250 [00:12<00:00, 2324.40it/s]


[31mValidation finished. [0mAccuracy: 0.7230 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 022/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2614.24it/s]

[32mTrain finished. [0mAccuracy: 0.6637 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 022/030: 100%|██████████| 29250/29250 [00:12<00:00, 2317.39it/s]


[31mValidation finished. [0mAccuracy: 0.7233 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 023/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2627.52it/s]

[32mTrain finished. [0mAccuracy: 0.6652 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 023/030: 100%|██████████| 29250/29250 [00:12<00:00, 2319.03it/s]


[31mValidation finished. [0mAccuracy: 0.7240 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 024/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2608.77it/s]

[32mTrain finished. [0mAccuracy: 0.6598 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 024/030: 100%|██████████| 29250/29250 [00:12<00:00, 2326.27it/s]


[31mValidation finished. [0mAccuracy: 0.7240 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 025/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2598.32it/s]

[32mTrain finished. [0mAccuracy: 0.6736 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 025/030: 100%|██████████| 29250/29250 [00:12<00:00, 2312.75it/s]


[31mValidation finished. [0mAccuracy: 0.7240 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 026/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2598.10it/s]

[32mTrain finished. [0mAccuracy: 0.6694 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 026/030: 100%|██████████| 29250/29250 [00:12<00:00, 2325.87it/s]


[31mValidation finished. [0mAccuracy: 0.7229 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 027/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2615.65it/s]

[32mTrain finished. [0mAccuracy: 0.6731 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 027/030: 100%|██████████| 29250/29250 [00:12<00:00, 2319.32it/s]


[31mValidation finished. [0mAccuracy: 0.7257 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 028/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2601.63it/s]

[32mTrain finished. [0mAccuracy: 0.6749 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 028/030: 100%|██████████| 29250/29250 [00:12<00:00, 2317.43it/s]


[31mValidation finished. [0mAccuracy: 0.7244 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 029/030: 100%|█████████▉| 29184/29250 [00:11<00:00, 2628.96it/s]

[32mTrain finished. [0mAccuracy: 0.6736 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 029/030: 100%|██████████| 29250/29250 [00:12<00:00, 2334.60it/s]


[31mValidation finished. [0mAccuracy: 0.7246 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 030/030: 100%|██████████| 29250/29250 [00:11<00:00, 2463.03it/s]

[32mTrain finished. [0mAccuracy: 0.6757 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 030/030: 100%|██████████| 29250/29250 [00:12<00:00, 2274.27it/s]


[31mValidation finished. [0mAccuracy: 0.7255 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[32mEvaluating[0m: 100%|██████████| 6281/6281 [00:01<00:00, 4611.65it/s]
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/samdoud/general/c6305004fc1649369bf5560d61f3697d
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     test_loss [50]         : (1.1974750757217407, 62.925976037979126)
COMET INFO:     train_accuracy [30]    : (0.5454524357840066, 0.6757104408058814)
COMET INFO:     train_loss [7557]      : (0.3341692388057709, 280.1351013183594)
COMET INFO:     validate_accuracy [30] : (0.6645940886693542, 0.7257305206106658)
COMET INFO:     validate_loss [1500]   : (0.8510227799415588, 12.934517860412598)
COMET INFO:   Parameters:
COMET INFO:     batch_size    : 128
COMET INFO:     betas         : (0.9, 0.98)
COMET INFO:     epochs        : 30
COMET INFO:     eps           : 4e-09
CO

error on test
[35mEvaluation finished. [0mAccuracy: 0.7255 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


COMET INFO: Uploading 785 metrics, params and output messages


ModuleNotFoundError: No module named 'mimic3models'