# MIMIC-III M3Care Baseline

### Load Modules and Define Constants

In [1]:
from mimic.models import (MLP, TimeSeriesTransformer, TimeSeriesNLEmbedder, NLEmbedder)
from mimic.dataset import MIMICDataset
from mimic.vocab import Vocab
from general.m3care import M3Care

from torch import nn

import os
import torch

In [2]:
# Directory holding the preprocessed files read below (MIMICDataset data and vocab.json).
processed_dir = "./mimic/data/processed"

### Load Datasets

In [3]:
# Only the test split is loaded for the sample run; the train split can be
# loaded the same way with True as the second argument.
test_ds = MIMICDataset(processed_dir, False)

vocab_path = os.path.join(processed_dir, 'vocab.json')
vocab = Vocab.from_json(vocab_path)

In [4]:
# Feature widths handed to the per-modality encoders below, plus the shared
# embedding width (EMB_DIM) every encoder is given.
DEM_DIM: int = 18
VIT_DIM: int = 104
ITV_DIM: int = 14
EMB_DIM: int = 512
DROPOUT: float = 0.3

# Sequence lengths; the VIT/ITV values are used as max_len for the
# time-series encoders.
VIT_TIMESTEPS: int = 150
ITV_TIMESTEPS: int = 150
NTS_TIMESTEPS: int = 128

# NOTE(review): NTS_TIMESTEPS, NST_WORD_LIMIT and NTS_WORD_LIMIT are not
# referenced in any visible cell -- confirm whether they are still needed.
NST_WORD_LIMIT: int = 10_000
NTS_WORD_LIMIT: int = 5_000

### Instantiate Unimodal Extraction Models

In [5]:
# One encoder per modality; each one is given EMB_DIM as its embedding width.
dem_mdl = MLP(
    in_dim=DEM_DIM,
    hidden_dim=[128, 192, 256],
    out_dim=EMB_DIM,
    bias=True,
    relu=True,
    norm=True,
)
vit_mdl = TimeSeriesTransformer(
    VIT_DIM, EMB_DIM, max_len=VIT_TIMESTEPS, dropout=DROPOUT
)
itv_mdl = TimeSeriesTransformer(
    ITV_DIM, EMB_DIM, max_len=ITV_TIMESTEPS, dropout=DROPOUT
)
# NOTE(review): 16, 8 and 2048 are positional magic numbers whose meaning is
# fixed by the NLEmbedder / TimeSeriesNLEmbedder signatures (not visible here);
# consider naming them as constants. nts_mdl is built but not included in
# unimodal_models in the visible cells.
nst_mdl = NLEmbedder(vocab, 16, EMB_DIM, 8, 2048, dropout=DROPOUT)
nts_mdl = TimeSeriesNLEmbedder(vocab, 16, EMB_DIM, 8, 2048, dropout=DROPOUT)

### Instantiate M3Care Model

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The notes time-series encoder (nts_mdl) is currently left out. To include it,
# append it here and extend every per-modality list below accordingly.
unimodal_models = nn.ModuleList([dem_mdl, vit_mdl, itv_mdl, nst_mdl])

# Per-modality flags, in the same order as unimodal_models (dem, vit, itv, nst).
missing_modals = [False, True, True, True]
time_modals = [False, True, True, False]
# Presumably one entry per modality flagged True in time_modals (vit, itv).
# Reuse the constants that sized those encoders so the values cannot drift.
timesteps_modals = [VIT_TIMESTEPS, ITV_TIMESTEPS]
mask_modals = [False, True, True, True]
output_dim = 2
keep_prob = 1 - DROPOUT

# Fail early if a modality is added/removed without updating the flag lists.
if not (len(missing_modals) == len(time_modals) == len(mask_modals) == len(unimodal_models)):
    raise ValueError('per-modality flag lists must have one entry per unimodal model')

In [7]:
# Assemble M3Care from the per-modality encoders and flags, then move its
# parameters to `device`.
model = M3Care(
    unimodal_models,
    missing_modals,
    time_modals,
    timesteps_modals,
    mask_modals,
    EMB_DIM,
    output_dim,
    device,
    keep_prob,
).to(device)



### Sample Run and NaN Debugging

In [8]:
# Take the first four test examples as a small smoke-test sample.
sample = test_ds[:4]
(
    dem, vit, itv, nst, nts,
    vit_msk, itv_msk, nst_msk, nts_msk,
    lbl,
) = sample

In [9]:
# Move the sample onto `device`: feature arrays as float32, masks as bool.
dem_ten, vit_ten, itv_ten, nst_ten = (
    torch.from_numpy(arr).float().to(device) for arr in (dem, vit, itv, nst)
)
vit_msk_ten, itv_msk_ten, nst_msk_ten = (
    torch.from_numpy(arr).bool().to(device) for arr in (vit_msk, itv_msk, nst_msk)
)
# nts / nts_msk are not converted because the nts modality is excluded from
# unimodal_models; when re-enabling it, convert nts_msk with .bool() like the
# other masks.
# NOTE(review): labels are cast to bool; a class-index loss (output_dim == 2)
# would typically need .long() -- confirm against the training code.
lbl_ten = torch.from_numpy(lbl).bool().to(device)

In [10]:
def nanCnt(x):
    """Count the NaN entries of tensor ``x``.

    Returns a 0-dim tensor (on ``x``'s device) holding the count.
    """
    return torch.isnan(x).sum()

In [11]:
# Forward pass on the sample; the nts inputs are not passed, matching the
# four encoders in unimodal_models.
# NOTE(review): the trailing 4 presumably is the batch size and must match the
# slice size used for `sample` -- confirm against M3Care.forward.
orig, emb = model(dem_ten, vit_ten, itv_ten, nst_ten, vit_msk_ten, itv_msk_ten, nst_msk_ten, 4)

In [None]:
# For each (orig, emb) pair, report the NaN count and shape of both tensors.
for o, e in zip(orig, emb):
    for t in (o, e):
        print(nanCnt(t), t.shape)

In [17]:
# Inspect the nst input and its mask (one line-group per tensor).
print(nst_ten, nst_msk_ten, sep='\n')

tensor([[0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')
tensor([[False],
        [False],
        [False],
        [False]], device='cuda:0')


In [68]:
# Re-run the nst encoder by hand, one submodule at a time, on the sample.
x, mask = nst_ten, nst_msk_ten

In [69]:
# src_key_padding_mask marks positions to IGNORE with True; the dataset mask
# appears to mark valid positions with True, so invert it (same as ~mask).
pad_mask = mask.bitwise_not()

In [70]:
# Debug experiment: un-pad position 0 of example 3 so exactly one row has an
# attendable position. This edits pad_mask in place; pad_mask was produced by
# inverting mask (a new tensor), so mask itself is unchanged.
pad_mask[3,0] = False

In [71]:
# Display the edited padding mask (True = position ignored by attention).
pad_mask

tensor([[ True],
        [ True],
        [ True],
        [False]], device='cuda:0')

In [72]:
# Step 1 of the nst encoder: token embedding.
# NOTE(review): x is a float tensor here (nst_ten was built with .float());
# confirm word_embed is meant to accept float token ids.
x = nst_mdl.word_embed(x)

In [73]:
# Step 2 of the nst encoder: positional encoding.
x = nst_mdl.pos_encode(x)

In [74]:
# Step 3 of the nst encoder: transformer encoder layer with the padding mask.
# Rows whose every position is padded (True) have nothing to attend to; the
# output of `v` further down shows those rows come out entirely NaN,
# presumably because the attention softmax runs over an all -inf score row.
# This is a likely source of the NaNs seen in the model outputs when a sample
# has no nst tokens.
v = nst_mdl.enc_layer(x, src_key_padding_mask=pad_mask)

In [76]:
# Shape check: 4 sample rows, 1 position, EMB_DIM (512) features per the output.
x.shape

torch.Size([4, 1, 512])

In [31]:
# NOTE: this cell's execution count (31) is earlier than its neighbours, so the
# all-True output below predates the pad_mask[3,0] = False edit.
pad_mask

tensor([[True],
        [True],
        [True],
        [True]], device='cuda:0')

In [79]:
# nan_to_num() returns a copy with NaNs replaced by 0 for inspection; v itself
# is not modified. This only hides the symptom -- a real fix belongs in the
# encoder's handling of fully padded rows.
v.nan_to_num()

tensor([[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.4918,  1.2983, -0.5192,  ..., -0.6496,  0.4102,  1.9774]]],
       device='cuda:0', grad_fn=<NanToNumBackward0>)

In [80]:
# Raw encoder output: rows 0-2 (fully padded) are NaN, row 3 (one un-padded
# position) is finite.
v

tensor([[[    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

        [[    nan,     nan,     nan,  ...,     nan,     nan,     nan]],

        [[ 0.4918,  1.2983, -0.5192,  ..., -0.6496,  0.4102,  1.9774]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)