# MIMIC-III M3Care Baseline

### Testing

In [1]:
from mimic.models import (MLP, TimeSeriesTransformer, TimeSeriesNLEmbedder, NLEmbedder)
from mimic.dataset import MIMICDataset
from mimic.vocab import Vocab
from general.m3care import M3Care
from util.kfold import KFoldDatasetLoader

from torch import nn

import numpy as np

import os
import torch

processed_dir = './mimic/data/processed'
test_ds = MIMICDataset(processed_dir, False)
vocab = Vocab.from_json(os.path.join(processed_dir, 'vocab.json'))

DEM_DIM = 18
EMB_DIM = 512
dem_mdl = MLP(in_dim=DEM_DIM, hidden_dim=[128,192,256], out_dim=EMB_DIM, bias=True, relu=True, norm=True)

DROPOUT = 0.3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
unimodal_models = nn.ModuleList([dem_mdl])
missing_modals = [False]
time_modals = [False]
timesteps_modals = []
mask_modals = [False]
output_dim = 2
keep_prob = 1 - DROPOUT

model = M3Care(unimodal_models, missing_modals, time_modals, timesteps_modals, mask_modals, EMB_DIM, output_dim, device, keep_prob).to(device)

sample = test_ds[:32]
dem, vit, itv, nst, vit_msk, itv_msk, nst_msk, lbl = sample
dem_ten = torch.from_numpy(dem).float().to(device)
lbl_ten = torch.from_numpy(lbl).float().to(device)
res = model(dem_ten, 32)



### Load modules and define constants.

In [1]:
from mimic.models import (MLP, TimeSeriesTransformer, TimeSeriesNLEmbedder, NLEmbedder)
from mimic.dataset import MIMICDataset
from mimic.vocab import Vocab
from general.m3care import M3Care
from util.kfold import KFoldDatasetLoader

from sklearn.metrics import recall_score, precision_score, f1_score, accuracy_score
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

from torch import nn

import numpy as np
import math

import os
import torch

In [2]:
processed_dir = './mimic/data/processed'

### Load Datasets

In [3]:
# train_ds = MIMICDataset(processed_dir, True)
test_ds = MIMICDataset(processed_dir, False)

vocab = Vocab.from_json(os.path.join(processed_dir, 'vocab.json'))

In [4]:
DEM_DIM = 18
VIT_DIM = 104
ITV_DIM = 14
EMB_DIM = 512
DROPOUT = 0.3

VIT_TIMESTEPS = 150
ITV_TIMESTEPS = 150
NTS_TIMESTEPS = 128

NST_WORD_LIMIT = 10000
NTS_WORD_LIMIT = 5000

### Instantiate Unimodal Extraction Models

In [5]:
dem_mdl = MLP(in_dim=DEM_DIM, hidden_dim=[128,192,256], out_dim=EMB_DIM, bias=True, relu=True, norm=True)
vit_mdl = TimeSeriesTransformer(VIT_DIM, EMB_DIM, max_len=VIT_TIMESTEPS, dropout=DROPOUT)
itv_mdl = TimeSeriesTransformer(ITV_DIM, EMB_DIM, max_len=ITV_TIMESTEPS, dropout=DROPOUT)
nst_mdl = NLEmbedder(vocab, 16, EMB_DIM, 8, 2048, dropout=DROPOUT)

### Instantiate M3Care Model

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Unimodal feature extractor
# unimodal_models = nn.ModuleList([dem_mdl, vit_mdl, itv_mdl, nst_mdl])
unimodal_models = nn.ModuleList([dem_mdl])
# Which models have missing values
# missing_modals = [False, True, True, True]
missing_modals = [False]
# Which models are time-based
# time_modals = [False, True, True, False]
time_modals = [False]
# For time based models what's the max sequence length
# timesteps_modals = [150, 150]
timesteps_modals = []
# Which modalities have a mask
# mask_modals = [False, True, True, True]
mask_modals = [False]
# Output dim that's put into out_mdl
output_dim = 2
# Dropout inverse
keep_prob = 1 - DROPOUT

In [7]:
model = M3Care(unimodal_models, missing_modals, time_modals, timesteps_modals, mask_modals, EMB_DIM, output_dim, device, keep_prob).to(device)



### Training Loop

In [8]:
EPOCHS = 3
BATCH_SIZE = 32
FOLDS = 8
LEARNING_RATE = 1e-3

In [9]:
loss_bce = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loader = KFoldDatasetLoader(test_ds, FOLDS, BATCH_SIZE)

In [10]:
BATCHES_PER_LOG = 5
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = SummaryWriter(f'runs/m3care_mimic_{timestamp}')

In [11]:
def load_batch(loader):
    dem, vit, itv, nst, vit_msk, itv_msk, nst_msk, lbl = loader.next()
    dem_ten = torch.from_numpy(dem).float().to(device)
    # vit_ten = torch.from_numpy(vit).float().to(device)
    # itv_ten = torch.from_numpy(itv).float().to(device)
    # nst_ten = torch.from_numpy(nst).float().to(device)
    # vit_msk_ten = torch.from_numpy(vit_msk).bool().to(device)
    # itv_msk_ten = torch.from_numpy(itv_msk).bool().to(device)
    # nst_msk_ten = torch.from_numpy(nst_msk).bool().to(device)
    lbl_ten = torch.from_numpy(lbl).float().to(device)
    # return (dem_ten, vit_ten, itv_ten, nst_ten, vit_msk_ten, itv_msk_ten, nst_msk_ten), lbl_ten
    return (dem_ten), lbl_ten

In [12]:
losses = []
accuracy, precision, recall, f1 = [], [], [], []
total_train_batches = math.floor(len(test_ds)/BATCH_SIZE)

running_loss = 0.0
y_truths, y_preds = [], []

# Each epoch ones only one fold
for epoch_idx in range(EPOCHS):
    loader.train()

    batch_idx = 1

    # Iterate through each training batch
    while not loader.end():
        # Load batch from the loader
        X, y = load_batch(loader)
        y = torch.stack([~y.bool(),y.bool()], axis=1).float()

        optimizer.zero_grad()
        y_pred, lstab = model(X, BATCH_SIZE)

        loss = loss_bce(y_pred, y)

        # Adjust model accordingly.
        loss.backward()
        optimizer.step()

        # Save results for single iteration
        running_loss += loss.detach().item()
        y_truths.append(y[:,-1].detach().cpu().numpy())
        y_preds.append(y_pred.argmax(axis=1).detach().cpu().numpy())

        # Save results for single log iteration
        if batch_idx % BATCHES_PER_LOG == 0:
            last_loss = running_loss / BATCHES_PER_LOG * BATCH_SIZE
            losses.append(last_loss)
            running_loss = 0.0

            y_truths_c = np.concatenate(y_truths)
            y_preds_c = np.concatenate(y_preds)
            accuracy.append(accuracy_score(y_truths_c, y_preds_c))
            precision.append(precision_score(y_truths_c, y_preds_c))
            recall.append(recall_score(y_truths_c, y_preds_c))
            f1.append(f1_score(y_truths_c, y_preds_c))
            y_truths.clear()
            y_preds.clear()

            # Display Results
            print(f"\tbatch {batch_idx} loss: {last_loss}")
            print(f"\tbest loss: {min(losses)}")

            # Write to tensorboard
            tb_idx = (epoch_idx * total_train_batches + batch_idx) * BATCH_SIZE
            writer.add_scalar('Loss/train', last_loss, tb_idx)

        batch_idx += 1

        if batch_idx == 10:
            break

    break

    loader.val()

    # Iterate through each validation batch
    while not loader.end():
        x = loader.next()

    loader.next_fold()
    if loader.end_fold():
        loader.reset()

	batch 5 loss: 17.238424015045165
	best loss: 17.238424015045165


In [14]:
pred_y = torch.tensor([True, False, True, True])
y = torch.tensor([True, False, False, True])

pred_y = torch.stack([~pred_y, pred_y], axis=1)
y = torch.stack([~y, y], axis=1)

In [20]:
y_truths_c = np.concatenate(y_truths)
y_preds_c = np.concatenate(y_preds)

In [22]:
recall_score(y_truths_c, y_preds_c)

0.07407407407407407

In [42]:
y_pred

tensor([[ 2.4446, -2.1665],
        [ 2.6084, -2.1263],
        [ 2.1012, -1.8023],
        [ 2.9684, -2.2325],
        [ 1.9413, -1.5935],
        [ 2.4603, -2.0854],
        [ 2.8656, -2.1502],
        [ 2.1422, -1.4622],
        [ 2.7696, -1.7585],
        [ 2.4717, -2.3189],
        [ 2.4094, -2.0613],
        [ 3.2527, -2.3141],
        [ 2.9413, -2.4611],
        [ 1.6488, -1.5808],
        [ 1.7613, -1.6040],
        [ 1.9012, -1.5147],
        [ 2.5354, -1.9792],
        [ 2.3909, -2.1359],
        [ 2.7984, -2.2884],
        [ 2.0891, -1.4728],
        [ 2.8681, -2.1382],
        [ 2.0934, -1.5964],
        [ 3.2246, -2.3791],
        [ 2.7333, -2.3112],
        [ 3.4254, -3.0554],
        [ 2.0117, -1.7837],
        [ 3.5364, -2.7268],
        [ 2.4244, -1.8148],
        [ 2.2837, -1.8553],
        [ 2.5749, -2.1558],
        [ 2.5648, -2.3115],
        [ 1.9434, -1.4515]], device='cuda:0', grad_fn=<SqueezeBackward1>)