## Mortality Training
1. Load the MIMIC III Dataset
2. Normalize Data
3. Load into Pytorch Dataloader
4. Train
5. Evaluate

In [1]:
from comet_ml import Experiment
import numpy as np
import pandas as pd
from sklearn.discriminant_analysis import StandardScaler
from sklearn.impute import SimpleImputer
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from core.model import SAnD
from mimic3_benchmarks.mimic3benchmark.readers import PhenotypingReader as Reader
from mimic3_benchmarks.mimic3models.preprocessing import Discretizer, Normalizer
from utils.functions import get_weighted_sampler, get_weights
from utils.trainer import NeuralNetworkClassifier
from utils.batch_gen import BatchGen
from utils.phenotyping_utils import load_data


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
task = "phenotyping"

## Load Data
Using the [MimicIII Benchmark's](https://github.com/YerevaNN/mimic3-benchmarks) InHospitalMortalityReader and associated functions, load the data into memory.

If you are following along you will need to arrange the data following the directions on the link.


In [2]:
train_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/train",
                              listfile=f"mimic3_benchmarks/data/{task}/train_listfile.csv")
val_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/train",
                              listfile=f"mimic3_benchmarks/data/{task}/val_listfile.csv")
test_reader = Reader(dataset_dir=f"mimic3_benchmarks/data/{task}/test",
                              listfile=f"mimic3_benchmarks/data/{task}/test_listfile.csv")

### Example data

In [3]:
ex = pd.DataFrame.from_dict(train_reader.read_example(0)['X'])
ex.columns = train_reader.read_example(0)["header"]
print(train_reader.read_example(0)['y'])
ex

[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


Unnamed: 0,Hours,Capillary refill rate,Diastolic blood pressure,Fraction inspired oxygen,Glascow coma scale eye opening,Glascow coma scale motor response,Glascow coma scale total,Glascow coma scale verbal response,Glucose,Heart Rate,Height,Mean blood pressure,Oxygen saturation,Respiratory rate,Systolic blood pressure,Temperature,Weight,pH
0,0.3052777777777777,,86.0,,1 No Response,5 Localizes Pain,7.0,1.0 ET/Trach,,92.0,,104.0,100.0,14.0,141.0,,,
1,0.4719444444444444,,,,,,,,142.0,,,,100.0,14.0,,,,
2,0.6719444444444445,,,,,,,,,,,,,,,,,7.43
3,1.3052777777777778,,90.0,,,,,,,98.0,,109.0,100.0,14.0,144.0,38.16669845581055,,
4,2.305277777777777,,90.0,,,,,,,99.0,,109.0,100.0,14.0,147.0,,,
5,3.305277777777777,,93.0,,,,,,,106.0,,113.0,100.0,14.0,152.0,39.11111195882162,,
6,4.3052777777777775,,88.0,,,,,,,106.0,,106.0,100.0,14.0,145.0,,,
7,5.3052777777777775,,86.0,,3 To speech,5 Localizes Pain,9.0,1.0 ET/Trach,,103.0,,103.0,100.0,14.0,141.0,39.11111195882162,,
8,5.8052777777777775,,85.0,,,,,,,112.0,,102.0,100.0,16.0,135.0,,,
9,6.3052777777777775,,88.0,,,,,,,105.0,,110.0,100.0,9.0,151.0,,,


### Normalize and Discretize data
From the MIMIC-III Benchmark repo

In [4]:
discretizer = Discretizer(timestep=1.0,
                          store_masks=True,
                          impute_strategy='previous',
                          start_time='zero')

discretizer_header = discretizer.transform(train_reader.read_example(0)["X"])[1].split(',')
cont_channels = [i for (i, x) in enumerate(discretizer_header) if x.find("->") == -1]

In [5]:
normalizer = Normalizer(fields=cont_channels)  # choose here which columns to standardize
normalizer_state = 'train/ph_ts1.0.input_str-previous.start_time-zero.normalizer'
normalizer.load_params(normalizer_state)

In [6]:
batch_size = 128
max_seq_len = 256
small_part = False

train_raw = load_data(train_reader, discretizer, normalizer, max_seq_len, small_part)
val_raw = load_data(val_reader, discretizer, normalizer, max_seq_len, small_part)
test_raw = load_data(test_reader, discretizer, normalizer, max_seq_len, small_part)

train_ds = TensorDataset(train_raw[0], train_raw[1])
val_ds = TensorDataset(val_raw[0], val_raw[1])
test_ds = TensorDataset(test_raw[0], test_raw[1])

train_loader = DataLoader(train_ds, batch_size=batch_size)#sampler=get_weighted_sampler(y_train))
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size)

N, seq_len, feature_count = train_raw[0].shape
print(N, seq_len, feature_count)

29250 256 76


## Create model, Train, and Evaluate
Train the model, this will export results to Comet ML.
Creating is done in the same step such that a new experiment made each time
Evaluation is done in the same step such that the Comet ML experiment is also deleted.

In [8]:
n_heads = 8
factor = 120 # M
num_class = 25
num_layers = 2 # N
epochs = 30
betas = (0.9, 0.98)
lr = 0.0005
eps = 4e-09
weight_decay = 5e-4
dropout = 0.4

experiment = Experiment(
    api_key="eQ3INeSsFGUYKahSdEtjhry42",
    project_name="general",
    workspace="samdoud"
)

clf = NeuralNetworkClassifier(
    SAnD(feature_count, seq_len, n_heads, factor, num_class, num_layers, dropout_rate=dropout),
    nn.BCEWithLogitsLoss(),
    optim.Adam, optimizer_config={
        "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay},
    experiment=experiment
)

clf.fit(
    {
        "train": train_loader,
        "val": val_loader
    },
    validation=True,
    epochs=epochs,
    verbose=True,
)

clf.evaluate(test_loader)
experiment.end()

COMET INFO: Experiment is live on comet.com https://www.comet.com/samdoud/general/836ca91398974768a1daee8a8fa3b772

[36mTraining[0m - Epochs: 001/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2308.47it/s]

[32mTrain finished. [0mAccuracy: 0.5506 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 001/030: 100%|██████████| 29250/29250 [00:14<00:00, 2058.08it/s]


[31mValidation finished. [0mAccuracy: 0.6474 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 002/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2328.46it/s]

[32mTrain finished. [0mAccuracy: 0.5504 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 002/030: 100%|██████████| 29250/29250 [00:13<00:00, 2121.58it/s]


[31mValidation finished. [0mAccuracy: 0.6652 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 003/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2362.62it/s]

[32mTrain finished. [0mAccuracy: 0.5486 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 003/030: 100%|██████████| 29250/29250 [00:13<00:00, 2110.93it/s]


[31mValidation finished. [0mAccuracy: 0.6751 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 004/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2383.89it/s]

[32mTrain finished. [0mAccuracy: 0.5512 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 004/030: 100%|██████████| 29250/29250 [00:13<00:00, 2116.99it/s]


[31mValidation finished. [0mAccuracy: 0.6856 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 005/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2390.11it/s]

[32mTrain finished. [0mAccuracy: 0.5541 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 005/030: 100%|██████████| 29250/29250 [00:13<00:00, 2113.67it/s]


[31mValidation finished. [0mAccuracy: 0.6869 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 006/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2403.14it/s]

[32mTrain finished. [0mAccuracy: 0.5556 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 006/030: 100%|██████████| 29250/29250 [00:13<00:00, 2130.59it/s]


[31mValidation finished. [0mAccuracy: 0.6801 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 007/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2393.52it/s]

[32mTrain finished. [0mAccuracy: 0.5653 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 007/030: 100%|██████████| 29250/29250 [00:13<00:00, 2114.02it/s]


[31mValidation finished. [0mAccuracy: 0.6880 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 008/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2392.65it/s]

[32mTrain finished. [0mAccuracy: 0.5649 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 008/030: 100%|██████████| 29250/29250 [00:13<00:00, 2114.66it/s]


[31mValidation finished. [0mAccuracy: 0.7006 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 009/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2392.07it/s]

[32mTrain finished. [0mAccuracy: 0.5728 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 009/030: 100%|██████████| 29250/29250 [00:13<00:00, 2118.52it/s]


[31mValidation finished. [0mAccuracy: 0.7048 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 010/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2343.18it/s]

[32mTrain finished. [0mAccuracy: 0.5703 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 010/030: 100%|██████████| 29250/29250 [00:14<00:00, 2074.41it/s]


[31mValidation finished. [0mAccuracy: 0.6950 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 011/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2386.51it/s]

[32mTrain finished. [0mAccuracy: 0.5772 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 011/030: 100%|██████████| 29250/29250 [00:13<00:00, 2108.72it/s]


[31mValidation finished. [0mAccuracy: 0.7093 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 012/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 1954.85it/s]

[32mTrain finished. [0mAccuracy: 0.5790 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 012/030: 100%|██████████| 29250/29250 [00:13<00:00, 2104.44it/s]


[31mValidation finished. [0mAccuracy: 0.7105 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 013/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2382.73it/s]

[32mTrain finished. [0mAccuracy: 0.5850 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 013/030: 100%|██████████| 29250/29250 [00:13<00:00, 2121.97it/s]


[31mValidation finished. [0mAccuracy: 0.7122 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 014/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2383.13it/s]

[32mTrain finished. [0mAccuracy: 0.5871 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 014/030: 100%|██████████| 29250/29250 [00:13<00:00, 2116.07it/s]


[31mValidation finished. [0mAccuracy: 0.7129 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 015/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2387.86it/s]

[32mTrain finished. [0mAccuracy: 0.5924 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 015/030: 100%|██████████| 29250/29250 [00:13<00:00, 2111.44it/s]


[31mValidation finished. [0mAccuracy: 0.7157 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 016/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2393.51it/s]

[32mTrain finished. [0mAccuracy: 0.5978 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 016/030: 100%|██████████| 29250/29250 [00:13<00:00, 2119.91it/s]


[31mValidation finished. [0mAccuracy: 0.7174 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 017/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2388.64it/s]

[32mTrain finished. [0mAccuracy: 0.6037 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 017/030: 100%|██████████| 29250/29250 [00:13<00:00, 2110.38it/s]


[31mValidation finished. [0mAccuracy: 0.7187 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 018/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2360.09it/s]

[32mTrain finished. [0mAccuracy: 0.6116 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 018/030: 100%|██████████| 29250/29250 [00:13<00:00, 2110.71it/s]


[31mValidation finished. [0mAccuracy: 0.7210 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 019/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2384.40it/s]

[32mTrain finished. [0mAccuracy: 0.6311 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 019/030: 100%|██████████| 29250/29250 [00:13<00:00, 2115.51it/s]


[31mValidation finished. [0mAccuracy: 0.7134 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 020/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2357.01it/s]

[32mTrain finished. [0mAccuracy: 0.6513 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 020/030: 100%|██████████| 29250/29250 [00:14<00:00, 2087.76it/s]


[31mValidation finished. [0mAccuracy: 0.7183 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 021/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2418.39it/s]

[32mTrain finished. [0mAccuracy: 0.6622 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 021/030: 100%|██████████| 29250/29250 [00:13<00:00, 2118.13it/s]


[31mValidation finished. [0mAccuracy: 0.7209 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 022/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2406.74it/s]

[32mTrain finished. [0mAccuracy: 0.6687 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 022/030: 100%|██████████| 29250/29250 [00:13<00:00, 2109.54it/s]


[31mValidation finished. [0mAccuracy: 0.7222 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 023/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2390.51it/s]

[32mTrain finished. [0mAccuracy: 0.6731 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 023/030: 100%|██████████| 29250/29250 [00:13<00:00, 2119.18it/s]


[31mValidation finished. [0mAccuracy: 0.7229 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 024/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2351.58it/s]

[32mTrain finished. [0mAccuracy: 0.6773 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 024/030: 100%|██████████| 29250/29250 [00:13<00:00, 2109.20it/s]


[31mValidation finished. [0mAccuracy: 0.7215 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 025/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2388.54it/s]

[32mTrain finished. [0mAccuracy: 0.6772 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 025/030: 100%|██████████| 29250/29250 [00:13<00:00, 2112.24it/s]


[31mValidation finished. [0mAccuracy: 0.7254 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 026/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2364.48it/s]

[32mTrain finished. [0mAccuracy: 0.6807 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 026/030: 100%|██████████| 29250/29250 [00:13<00:00, 2109.60it/s]


[31mValidation finished. [0mAccuracy: 0.7238 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 027/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2408.36it/s]

[32mTrain finished. [0mAccuracy: 0.6824 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 027/030: 100%|██████████| 29250/29250 [00:13<00:00, 2127.73it/s]


[31mValidation finished. [0mAccuracy: 0.7247 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 028/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2413.88it/s]

[32mTrain finished. [0mAccuracy: 0.6847 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 028/030: 100%|██████████| 29250/29250 [00:13<00:00, 2119.90it/s]


[31mValidation finished. [0mAccuracy: 0.7250 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 029/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2367.66it/s]

[32mTrain finished. [0mAccuracy: 0.6866 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 029/030: 100%|██████████| 29250/29250 [00:13<00:00, 2108.40it/s]


[31mValidation finished. [0mAccuracy: 0.7246 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 030/030: 100%|█████████▉| 29184/29250 [00:12<00:00, 2314.58it/s]

[32mTrain finished. [0mAccuracy: 0.6886 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[36mTraining[0m - Epochs: 030/030: 100%|██████████| 29250/29250 [00:14<00:00, 2078.87it/s]


[31mValidation finished. [0mAccuracy: 0.7261 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


[32mEvaluating[0m: 100%|██████████| 6281/6281 [00:01<00:00, 4612.78it/s]
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.com/samdoud/general/836ca91398974768a1daee8a8fa3b772
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     test_loss [50]         : (0.5519314408302307, 29.14244955778122)
COMET INFO:     train_accuracy [30]    : (0.5486425066445836, 0.6885715084333469)
COMET INFO:     train_loss [7557]      : (0.32722023129463196, 261.7335510253906)
COMET INFO:     validate_accuracy [30] : (0.6473843322424915, 0.7261181287295638)
COMET INFO:     validate_loss [1500]   : (0.5097749829292297, 22.033466339111328)
COMET INFO:   Parameters:
COMET INFO:     batch_size    : 128
COMET INFO:     betas         : (0.9, 0.98)
COMET INFO:     epochs        : 30
COMET INFO:     eps           : 4e-09
CO

error on test
[35mEvaluation finished. [0mAccuracy: 0.7261 MSE: 0.0000 AUROC: 0.0000 AUPRC: 0.0000


COMET INFO: Uploading 1 metrics, params and output messages
