## Import Modules

In [2]:
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset
from torch import Tensor
from enum import Enum
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import math
import json
import os
import h5py

## Dataset

In [3]:
class MIMICDataset(Dataset):
    """Per-patient note time series read from a processed data directory.

    The constructor reads ``<processed_dir>/<split>_idxs.npy`` (patient ids),
    ``<processed_dir>/vocab.json`` and ``<processed_dir>/<split>/notes_ts.h5``
    where ``<split>`` is ``train`` or ``test``.

    Indexing with a scalar returns ``((times, cats, notes), missing)`` for one
    patient; indexing with something that selects several ids (so that
    ``self.indexes[item_idx]`` is an ndarray) returns ``(nts_list, missing_list)``.
    """

    def __init__(self, processed_dir: str, train: bool):
        """Load the index array and vocabulary and collect the stored patient ids.

        Args:
            processed_dir: Directory containing the processed files.
            train: Select the ``train`` split when true, otherwise ``test``.

        If a required file is missing, a hint is printed and the instance is
        left as an empty dataset (no indexes, ``vocab`` is ``None``).
        """
        self.processed_dir = processed_dir

        split = 'train' if train else 'test'
        data_path = os.path.join(processed_dir, f'{split}/')
        index_path = os.path.join(processed_dir, f'{split}_idxs.npy')

        # Defaults so every attribute exists even when loading fails below.
        self.indexes = np.empty(0, dtype=np.int64)
        self.vocab = None
        self.notes_ts_path = os.path.join(data_path, 'notes_ts.h5')
        self.nts_ids = set()

        try:
            indexes = np.load(index_path)
            vocab = Vocab.from_json(os.path.join(processed_dir, 'vocab.json'))
            if not os.path.exists(self.notes_ts_path):
                # Carry the path so the printed hint says which file is absent.
                raise FileNotFoundError(self.notes_ts_path)
        except FileNotFoundError as e:
            print("Make sure data has been processed: ", e)
            return

        self.indexes, self.vocab = indexes, vocab

        # Group names end in the numeric patient id (they are read back as
        # 'pat_id_<id>' in the accessors below).
        with h5py.File(self.notes_ts_path, 'r') as f:
            self.nts_ids = {int(k.split('_')[-1]) for k in f.keys()}

    def __len__(self):
        """Return the number of patient ids in the index array."""
        return len(self.indexes)

    def __getitem__(self, item_idx):
        """Return data for one patient or for several.

        The result of ``self.indexes[item_idx]`` decides the path: an ndarray
        goes to ``_getpatients``, anything else to ``_getpatient``.
        """
        pat_id = self.indexes[item_idx]

        if isinstance(pat_id, np.ndarray):
            return self._getpatients(pat_id)
        return self._getpatient(pat_id)

    def _getpatient(self, pat_id):
        """Return ``((times, cats, notes), missing)`` for a single patient.

        When the patient has no group in the HDF5 file the three arrays are
        empty and ``missing`` is ``True``.
        """
        if pat_id not in self.nts_ids:
            return (np.empty(0), np.empty(0), np.empty(0)), True

        with h5py.File(self.notes_ts_path, 'r') as f:
            nts = self._format_notes_ts_group(f[f'pat_id_{pat_id}'])

        return nts, False

    def _getpatients(self, pat_ids):
        """Return ``(nts, missing)`` for several patients.

        ``missing`` holds one flag per requested id, ``True`` when the patient
        has no stored notes (the same meaning as in ``_getpatient``). ``nts``
        holds one ``(times, cats, notes)`` tuple per patient that *was* found,
        in request order; nothing is appended for missing patients.
        """
        # nts_ids is a set, so each membership test is O(1).
        present = [pat_id in self.nts_ids for pat_id in pat_ids]
        missing = [not found for found in present]

        nts = []
        if any(present):  # nothing to read otherwise, so skip opening the file
            with h5py.File(self.notes_ts_path, 'r') as f:
                nts = [
                    self._format_notes_ts_group(f[f'pat_id_{pat_id}'])
                    for pat_id, found in zip(pat_ids, present)
                    if found
                ]

        return nts, missing

    @staticmethod
    def _format_notes_ts_group(nts_group):
        """Convert one patient's group into ``(times, cats, notes)`` arrays.

        Each key is split on ``'_'`` into six fields; the 2nd is the position
        of the note within the group, the 4th its time and the 6th a string of
        digits that becomes the category vector. Notes are right-padded with
        zeros to the longest note in the group. An empty group yields three
        empty arrays.
        """
        group_size = len(nts_group)
        times, cats, notes = [0]*group_size, [0]*group_size, [0]*group_size
        for d in nts_group.keys():
            _, gidx, _, time, _, cat = d.split('_')
            gidx, time, cat = int(gidx), int(time), np.array([int(c) for c in cat])
            times[gidx] = time
            cats[gidx] = cat
            notes[gidx] = nts_group[d][:]

        times, cats = np.array(times), np.array(cats)

        # default=0 keeps an empty group from raising in max().
        max_note_len = max((len(note) for note in notes), default=0)
        notes = np.array([np.pad(note, (0, max_note_len-len(note))) for note in notes])

        return times, cats, notes

In [7]:
# Training split (train=True) read from the processed-data directory.
train_ds = MIMICDataset('../data/processed/', True)

## Embedder

In [6]:
# Separator token string; WordEmbedder looks up its id in the vocabulary and
# maps that id to embedding row 0.
SEP_TOKEN = '[SEP]'

class Vocab(object):
    """Token vocabulary with id assignment and occurrence counts.

    Ids are handed out in first-seen order starting at 1; ``cnt`` always holds
    the next id to assign, so ``len(vocab)`` is one more than the number of
    distinct tokens.
    """

    def __init__(self):
        self.tok2id = {}   # token -> id
        self.id2tok = {}   # id -> token
        self.tok2cnt = {}  # token -> number of times add_token saw it
        self.cnt = 1       # next id to assign

    def add_token(self, token: str):
        """Register ``token`` on first sight, otherwise bump its count."""
        if token not in self.tok2id:
            self.tok2id[token] = self.cnt
            self.id2tok[self.cnt] = token
            self.tok2cnt[token] = 1
            self.cnt += 1
        else:
            self.tok2cnt[token] += 1

    def top_tokens(self, top: int):
        """Return the set of the ``top`` most frequent tokens.

        Tokens are ranked by count, highest first; ties are broken by the
        token string in descending order so the result is deterministic.
        Ranking the (token, count) pairs directly keeps tokens that share a
        count, which a count -> token dictionary would silently collapse.
        """
        ranked = sorted(
            self.tok2cnt.items(),
            key=lambda item: (item[1], item[0]),
            reverse=True,
        )
        return {tok for tok, _ in ranked[:top]}

    def to_json(self, path: str):
        """Write the vocabulary to ``path`` as UTF-8 JSON."""
        vocab_data = {
            'tok2id': self.tok2id,
            'id2tok': self.id2tok,
            'tok2cnt': self.tok2cnt,
            'cnt': self.cnt
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(vocab_data, f, indent=4)

    @classmethod
    def from_json(cls, path: str):
        """Build a vocabulary from a file written by ``to_json``.

        JSON object keys are always strings, so the ids in ``id2tok`` are
        converted back to ``int`` here.
        """
        with open(path, 'r', encoding='utf-8') as f:
            vocab_data = json.load(f)

        vocab = cls()
        vocab.tok2id = {tok: int(idx) for tok, idx in vocab_data['tok2id'].items()}
        vocab.id2tok = {int(idx): tok for idx, tok in vocab_data['id2tok'].items()}
        vocab.tok2cnt = {tok: int(n) for tok, n in vocab_data['tok2cnt'].items()}
        vocab.cnt = vocab_data['cnt']

        return vocab

    def __len__(self):
        """Return ``cnt`` (number of distinct tokens plus one)."""
        return self.cnt

In [8]:
class WordEmbedder(nn.Module):
    """Map vocabulary token ids to embeddings projected to the model width.

    Embedding rows are laid out as
    ``SEP[0]  VOCAB[1..V]  UNKNOWN[V+1]  PADDING[V+2]`` with ``V = vocab_size``.
    """

    def __init__(self, vocab, vocab_size, embed_dim, d_model):
        """
        Arguments:
            vocab: object providing ``top_tokens``, ``tok2id`` and ``tok2cnt``;
                it must contain ``SEP_TOKEN``.
            vocab_size: number of most frequent tokens that get their own row.
            embed_dim: width of the embedding table.
            d_model: width of the output after the linear projection.
        """
        super(WordEmbedder, self).__init__()

        # Create idx -> embed_idx mapping.
        # One extra token is requested so V remain after dropping SEP.
        top_words = set(vocab.top_tokens(vocab_size+1))
        top_words.discard(SEP_TOKEN)
        # Rank explicitly (count desc, then token) instead of relying on set
        # iteration order, which differs between interpreter runs, and cap at V
        # so no token can land on the UNKNOWN row when SEP was not in the top set.
        ranked = sorted(top_words, key=lambda tok: (-vocab.tok2cnt[tok], tok))[:vocab_size]
        self.emb_idx_map = {vocab.tok2id[tok]: idx for idx, tok in enumerate(ranked, start=1)}
        self.sep_tok_id = vocab.tok2id[SEP_TOKEN]
        self.vocab_size = vocab_size

        # SEP[0] VOCAB[1,V] UNKNOWN[V+1] PADDING[V+2]
        self.embedder = nn.Embedding(vocab_size+3, embed_dim)
        self.linear = nn.Linear(embed_dim, d_model)

    def forward(self, x: Tensor):
        """
        Arguments:
            x: 2-D collection of vocabulary token ids (rows are sequences).

        Returns:
            ``relu(linear(embedding(mapped_ids)))`` with one vector per token.
        """
        # Convert to plain Python ints first: tensor elements hash by object
        # identity, so they would never be found in emb_idx_map.
        rows = torch.as_tensor(x).tolist()
        mapped = [[self.embed_idx(tok) for tok in seq] for seq in rows]
        x = torch.tensor(mapped, dtype=torch.long, device=self.embedder.weight.device)
        x = self.embedder(x)
        return F.relu(self.linear(x))

    def embed_idx(self, idx):
        """Translate one vocabulary id into its embedding row."""
        if idx == self.sep_tok_id:
            return 0
        elif idx in self.emb_idx_map:
            return self.emb_idx_map[idx]
        elif idx == 0:
            # id 0 is treated as padding
            return self.vocab_size+2
        else:
            # anything else is out-of-vocabulary
            return self.vocab_size+1

## Encoder

In [9]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The table ``pe`` has shape ``[1, max_len, d_model]`` and is stored as a
    buffer, so it follows the module across devices but is not trained.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Arguments:
            d_model: size of the last dimension of the inputs (even or odd).
            dropout: dropout probability applied after adding the encoding.
            max_len: longest sequence length the table covers.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the frequencies to the number of cosine slots.
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        """
        # Slice the table to the sequence length; it broadcasts over the batch.
        x = x + self.pe[:,:x.size(1)]
        return self.dropout(x)

## Time2Vec

In [10]:
class Time2Vec(nn.Module):
    """Learned per-position transform of a sequence of time stamps.

    Position ``i`` gets its own frequency ``omega[i]`` and phase ``phi[i]``.
    The affine result is passed through ``sin`` for every position, except that
    position 0 is left linear when the first time stamp equals zero.
    """

    def __init__(self, max_len=5000):
        """
        Arguments:
            max_len: number of learned (omega, phi) pairs, i.e. the longest
                sequence the module has parameters for.
        """
        super(Time2Vec, self).__init__()

        self.omega = nn.Parameter(torch.randn(max_len))
        self.phi = nn.Parameter(torch.randn(max_len))

    def forward(self, tau):
        """
        Arguments:
            tau: time stamps indexed along dimension 0 (the code treats it as
                a 1-D sequence).

        Returns:
            Tensor of the same length as ``tau``.
        """
        seq_len = tau.size(0)
        projected = (tau*self.omega[:seq_len]) + self.phi[:seq_len]
        if seq_len == 0:
            # Nothing to transform, and tau[0] below would raise.
            return projected

        zero_start = bool(tau[0] == 0)
        if not zero_start:
            return torch.sin(projected)

        # Assemble the result out of place: assigning sin(...) back into a
        # slice of `projected` would overwrite a tensor autograd saved for the
        # backward pass of sin and make backward() fail.
        return torch.cat([projected[:1], torch.sin(projected[1:])])

## Model

Determine the compute device (GPU if available, otherwise CPU)

In [11]:
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Hyperparameters

In [12]:
# General
# Dropout probability passed to the positional encoding and the encoder layer.
DROPRATE = 0.3

# Embedder
# Width of the token embedding table, before the linear projection to TRAN_DMODEL.
EMBD_DIM = 10

# Transformer
# Model width (d_model), number of attention heads, and feed-forward hidden size.
TRAN_DMODEL = 512
TRAN_NHEAD = 8
TRAN_DFF = 2048

Model components

In [13]:
# Stand-alone instances of each building block, all placed on `device`.
word_embed = WordEmbedder(train_ds.vocab, 10000, EMBD_DIM, TRAN_DMODEL).to(device)

pos_encode = PositionalEncoding(TRAN_DMODEL, DROPRATE).to(device)

enc_layer = nn.TransformerEncoderLayer(TRAN_DMODEL, TRAN_NHEAD, TRAN_DFF, DROPRATE, batch_first=True).to(device)

# Moved to `device` like the modules above so its parameters live alongside theirs.
time_2_vec = Time2Vec().to(device)

In [44]:
class IrregularTimeNLP(nn.Module):
    """Combine note embeddings, note categories and encoded times per note.

    Each note is embedded token by token, position-encoded, passed through one
    transformer encoder layer and averaged over its tokens. The result is
    concatenated with the Time2Vec value of the note's time and with its
    category vector.
    """

    def __init__(self, vocab, embed_dim, model_dim, tran_heads, tran_dff, vocab_size=10000, dropout=0.3):
        """
        Arguments:
            vocab: vocabulary handed to ``WordEmbedder``.
            embed_dim: width of the token embedding table.
            model_dim: width used by the positional encoding and encoder layer.
            tran_heads: number of attention heads in the encoder layer.
            tran_dff: feed-forward width of the encoder layer.
            vocab_size: number of tokens given their own embedding row.
            dropout: dropout probability for positional encoding and encoder.
        """
        super().__init__()

        # Sub-modules are created in the same order as before so that random
        # parameter initialisation is unchanged for a given seed.
        self.word_embed = WordEmbedder(
            vocab=vocab,
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            d_model=model_dim,
        )
        self.pos_encode = PositionalEncoding(d_model=model_dim, dropout=dropout)
        self.enc_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=tran_heads,
            dim_feedforward=tran_dff,
            dropout=dropout,
            batch_first=True,
        )

        self.time2vec = Time2Vec()

    def forward(self, times, cats, notes):
        """
        Arguments:
            times: one time stamp per note.
            cats: one category row per note (assumed 2-D, notes x categories).
            notes: token ids, one row per note.

        Returns:
            Tensor whose rows are ``[time2vec(time), *cats, *mean note vector]``.
        """
        token_feats = self.pos_encode(self.word_embed(notes))
        # Average over the token axis; padded positions are included as-is.
        note_vecs = self.enc_layer(token_feats).mean(dim=1)

        time_col = self.time2vec(times).unsqueeze(1)

        return torch.cat((time_col, cats, note_vecs), dim=1)

In [45]:
# Build the full model from the named hyperparameters rather than repeating
# their literal values here.
model = IrregularTimeNLP(
    train_ds.vocab,
    embed_dim=EMBD_DIM,
    model_dim=TRAN_DMODEL,
    tran_heads=TRAN_NHEAD,
    tran_dff=TRAN_DFF,
    dropout=DROPRATE,
).to(device)

In [46]:
# First patient's (times, cats, notes) triple; the trailing `missing` flag is dropped.
sample = train_ds[0][0]
# Convert each array to a tensor on the target device.
sample_time, sample_cat, sample_note = (
    torch.tensor(part).to(device) for part in (sample[0], sample[1], sample[2])
)

In [47]:
# Forward pass on the single sampled patient (arguments: times, cats, notes).
res = model(sample_time, sample_cat, sample_note)

In [50]:
sample

(array([  0,   3,   5,  13,  14,  15,  24,  26,  27,  36,  37,  38,  50,
         50,  50,  50,  51,  62,  64,  65,  74,  76,  86,  86,  86,  86,
         91,  97,  99, 110, 110, 110, 110, 110, 110, 110, 110, 110, 119,
        121, 136]),
 array([[0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0