In [354]:
import json
from typing import List

# Separator token literal; none of the visible cells reference it.
SEP_TOKEN = '[SEP]'


class Vocab(object):
    """Token vocabulary: token <-> integer id maps plus occurrence counts.

    Ids are assigned sequentially starting from 1, so id 0 is never given to
    a token by this class.
    """

    def __init__(self):
        self.tok2id = {}   # token -> id
        self.id2tok = {}   # id -> token
        self.tok2cnt = {}  # token -> number of add_token calls for it
        self.cnt = 1       # id that the next unseen token will receive

    def add_token(self, token: str):
        """Count one occurrence of `token`, giving it a fresh id if unseen."""
        if token not in self.tok2id:
            self.tok2id[token] = self.cnt
            self.id2tok[self.cnt] = token
            self.tok2cnt[token] = 1
            self.cnt += 1
        else:
            self.tok2cnt[token] += 1

    def top_tokens(self, top: int):
        """Return the set of the `top` most frequent tokens.

        Tokens are ranked by count, highest first; equal counts are ordered
        by the token itself (descending) so the result is deterministic.
        """
        # Rank the (token, count) pairs directly. Inverting the mapping to
        # count -> token first would keep only one token per distinct count
        # and silently drop every other token sharing that count.
        ranked = sorted(
            self.tok2cnt.items(),
            key=lambda item: (item[1], item[0]),
            reverse=True,
        )
        return {token for token, _ in ranked[:top]}

    def to_json(self, path: str):
        """Write the vocabulary state to `path` as UTF-8 encoded JSON.

        JSON object keys are always strings, so the integer keys of `id2tok`
        are stored as strings; `from_json` converts them back.
        """
        vocab_data = {
            'tok2id': self.tok2id,
            'id2tok': self.id2tok,
            'tok2cnt': self.tok2cnt,
            'cnt': self.cnt
        }

        with open(path, 'w', encoding='utf-8') as f:
            json.dump(vocab_data, f, indent=4)

    @classmethod
    def from_json(cls, path: str):
        """Rebuild a vocabulary from a file written by `to_json`."""
        # Read with the same encoding `to_json` writes with.
        with open(path, 'r', encoding='utf-8') as f:
            vocab_data = json.load(f)

        # Instantiate through `cls` so subclasses get an instance of their
        # own type.
        vocab = cls()
        vocab.tok2id = {tok: int(idx) for tok, idx in vocab_data['tok2id'].items()}
        vocab.id2tok = {int(idx): tok for idx, tok in vocab_data['id2tok'].items()}
        vocab.tok2cnt = {tok: int(n) for tok, n in vocab_data['tok2cnt'].items()}
        vocab.cnt = vocab_data['cnt']

        return vocab

    def __len__(self):
        """Return `cnt`: the number of distinct tokens plus one (ids start at 1)."""
        return self.cnt


In [424]:
import os

# Locations of the preprocessed training artifacts, relative to the notebook.
processed_dir = '../data/processed/'
data_path = os.path.join(processed_dir, 'train')
# Single assignment; the original repeated the target (`index_path = index_path = ...`).
index_path = os.path.join(processed_dir, "train_idxs.npy")

In [425]:
import numpy as np

# Load the array stored at `index_path`; later cells index into it to obtain
# patient ids.
indexes = np.load(index_path)

In [426]:
# Restore the vocabulary from its JSON dump and build the paths of the two
# .h5 files (static notes and time-series notes) under `data_path`.
vocab = Vocab.from_json(
    os.path.join(processed_dir, 'vocab.json'))
notes_static_path = os.path.join(data_path, 'notes_static.h5')
notes_ts_path = os.path.join(data_path, 'notes_ts.h5')

In [427]:
import h5py

# Collect the integer ids present in each notes file: every top-level key is
# split on '_' and its last piece is parsed as an int.
with h5py.File(notes_static_path, 'r') as f:
    nst_ids = {int(key.split('_')[-1]) for key in f.keys()}

with h5py.File(notes_ts_path, 'r') as f:
    nts_ids = {int(key.split('_')[-1]) for key in f.keys()}

### Sample Get (Single)

In [465]:
# Positions 0..99 of `indexes` (batch case).
item_idx = np.arange(100)
# Alternative for a single sample: item_idx = 0

In [466]:
# NOTE(review): with an array-valued `item_idx` this holds many ids despite
# the singular name.
pat_id = indexes[item_idx]

In [213]:
# Target length passed to pad_axis (axis 0) in a later cell.
TIMESERIES_DIM = 150
# Length of the notes time axis; sizes the time-series notes masks.
NOTES_TIME_DIM = 128

In [300]:
# Read one patient's time-series notes group and format it into (notes, mask).
# NOTE(review): `_format_notes_ts_group` is defined in a later cell, so this
# cell raises NameError on a fresh top-to-bottom run. It also presumably
# expects a scalar `pat_id`, since the id is interpolated into the key.
with h5py.File(notes_ts_path, 'r') as f:
    nts = f[f'pat_id_{pat_id}']
    nts = _format_notes_ts_group(nts)
    notes, mask = nts

In [414]:
# Shape of the formatted notes array from the previous cell.
notes.shape

(128, 330)

In [467]:
# Batch lookup of static and time-series notes with their masks.
# NOTE(review): `getpatients_notes` is defined in a later cell.
nst, nts, nst_msk, nts_msk = getpatients_notes(pat_id)

In [464]:
# Shapes of the static notes, time-series notes and their two masks.
print(nst.shape,nts.shape,nst_msk.shape,nts_msk.shape)

(4, 1) (4, 128, 500) (4,) (4, 128)


In [468]:
# Same shape check as the previous cell, executed later (higher execution
# count) against a different batch.
print(nst.shape,nts.shape,nst_msk.shape,nts_msk.shape)

(100, 1299) (100, 128, 4799) (100,) (100, 128)


In [460]:
def getpatients_notes(pat_ids):
    """Fetch static and time-series notes for a batch of patient ids.

    Reads from the module-level `notes_static_path` / `notes_ts_path` files and
    uses the module-level id sets `nst_ids` / `nts_ids` to decide which
    patients have data.

    Returns:
        nst: one row per entry of `pat_ids` holding the zero-padded static
            notes; rows of patients without static notes are all zero. When no
            patient has static notes the array has shape (len(pat_ids), 1).
        nts: per-patient formatted time-series notes, zero rows for patients
            without any. When no patient has time-series notes the array has
            shape (len(pat_ids), NOTES_TIME_DIM, 1).
        nst_msk: array of length len(pat_ids), 1 where static notes exist.
        nts_msk: array of shape (len(pat_ids), NOTES_TIME_DIM) holding each
            patient's time mask from `_format_notes_ts_group`.
    """
    n_patients = len(pat_ids)

    # ---- static notes ----
    static_rows = []
    nst_msk = np.zeros(n_patients)
    with h5py.File(notes_static_path, 'r') as static_file:
        for position, patient in enumerate(pat_ids):
            if patient not in nst_ids:
                continue
            static_rows.append(static_file[f'pat_id_{patient}'][:])
            nst_msk[position] = 1

    if static_rows:
        # Pad to a common width, then re-insert zero rows for absent patients.
        nst = pad_missing(padded_stack(static_rows), nst_msk)
    else:
        nst = np.zeros((n_patients, 1))

    # ---- time-series notes ----
    ts_rows = []
    nts_msk = np.zeros((n_patients, NOTES_TIME_DIM))
    with h5py.File(notes_ts_path, 'r') as ts_file:
        for position, patient in enumerate(pat_ids):
            if patient not in nts_ids:
                continue
            patient_notes, patient_mask = _format_notes_ts_group(
                ts_file[f'pat_id_{patient}'])
            ts_rows.append(patient_notes)
            nts_msk[position] = patient_mask

    if ts_rows:
        # A patient counts as present when its time mask has any step set.
        has_ts = nts_msk.sum(axis=1) > 0
        nts = pad_missing(padded_stack(ts_rows), has_ts)
    else:
        nts = np.zeros((n_patients, NOTES_TIME_DIM, 1))

    return nst, nts, nst_msk, nts_msk

In [418]:
def _getpatients_notes(self, pat_ids):
    """Collect static and time-series notes for `pat_ids` plus a missing table.

    Returns:
        nst: the found static notes after `self._format_notes_static`.
        nts: list with one formatted time-series entry per patient that has
            one (patients without are skipped, not padded).
        missing: boolean array with one (static_missing, ts_missing) pair per
            entry of `pat_ids`.
    """
    # ---- static notes ----
    have_static = {pid for pid in pat_ids if pid in self.nst_ids}
    nst, missing_st = [], []
    with h5py.File(self.notes_static_path, 'r') as f:
        for pid in pat_ids:
            found = pid in have_static
            if found:
                # NOTE(review): this key uses the 'row_' prefix while the other
                # cells read static notes under 'pat_id_' -- confirm which
                # prefix the file actually uses.
                nst.append(f[f'row_{pid}'][:])
            missing_st.append(not found)

    # ---- time-series notes ----
    have_ts = {pid for pid in pat_ids if pid in self.nts_ids}
    nts, missing_ts = [], []
    with h5py.File(self.notes_ts_path, 'r') as f:
        for pid in pat_ids:
            found = pid in have_ts
            if found:
                nts.append(self._format_notes_ts_group(f[f'pat_id_{pid}']))
            missing_ts.append(not found)

    nst = self._format_notes_static(nst)
    missing = np.array(list(zip(missing_st, missing_ts)))

    return nst, nts, missing

In [415]:
def _format_notes_ts_group(nts_group):
        group_size = len(nts_group)
        times, notes = [0]*group_size, [0]*group_size
        for d in nts_group.keys():
            _, gidx, _, time = d.split('_')
            gidx, time = int(gidx), int(time)
            times[gidx] = time
            notes[gidx] = nts_group[d][:]

        mask = np.zeros(NOTES_TIME_DIM)
        for idx in times:
            mask[idx] = 1

        notes = padded_stack(notes)
        notes = pad_missing(notes, mask)

        return notes, mask


In [432]:
def getpatient_notes(pat_id):
    """Fetch static and time-series notes for a single patient id.

    Uses the module-level id sets and file paths. Returns
    `(nst, nts, nst_msk, nts_msk)`; when the patient has no data of a kind,
    the corresponding value is an empty placeholder and its mask is all zero.
    """
    nst_msk = np.zeros(1)
    if pat_id in nst_ids:
        with h5py.File(notes_static_path, 'r') as f:
            nst = f[f'pat_id_{pat_id}'][:]
            nst_msk[0] = 1
    else:
        nst = np.empty(0)

    if pat_id in nts_ids:
        with h5py.File(notes_ts_path, 'r') as f:
            nts, nts_msk = _format_notes_ts_group(f[f'pat_id_{pat_id}'])
    else:
        # NOTE(review): the placeholder is a tuple of three empty arrays,
        # whereas the found case yields whatever `_format_notes_ts_group`
        # returns as its first element -- confirm callers handle both.
        nts = (np.empty(0), np.empty(0), np.empty(0))
        nts_msk = np.zeros(NOTES_TIME_DIM)

    return nst, nts, nst_msk, nts_msk

In [411]:
def pad_axis(arr: np.array, fill_to: int, axis: int):
    """Zero-pad `arr` at the end of `axis` until that axis has length `fill_to`.

    All other axes are left untouched. `axis` must be smaller than `arr.ndim`.
    """
    assert axis < arr.ndim

    extra = fill_to - arr.shape[axis]
    # No padding anywhere except after the existing entries of `axis`.
    pad_widths = [(0, 0) for _ in range(arr.ndim)]
    pad_widths[axis] = (0, extra)
    return np.pad(arr, pad_widths)


def padded_stack(mat: List[np.ndarray], fill_dims=None):
    """Zero-pad arrays of equal rank to a common shape and stack them.

    Args:
        mat: non-empty list of arrays that all have the same number of
            dimensions.
        fill_dims: target shape of each padded array.
            * None: use the per-axis maximum over `mat`.
            * an integer (Python or NumPy) when the arrays are 1-D: target
              length of the single axis.
            * a sequence with one entry per axis, where -1 means "use the
              maximum over `mat` for this axis".

    Returns:
        Array of shape (len(mat), *fill_dims); every input is padded with
        zeros at the end of each axis.

    Raises:
        IndexError: if `mat` is empty.
        ValueError: (from `np.pad`) if a target size is smaller than an input.
    """
    ndim = mat[0].ndim
    max_dims = np.max([sub.shape for sub in mat], axis=0)

    if fill_dims is None:
        fill_dims = max_dims
    else:
        # isinstance also accepts NumPy integer scalars (e.g. a value taken
        # from an array), which the exact `type(...) == int` check rejected.
        if ndim == 1 and isinstance(fill_dims, (int, np.integer)):
            fill_dims = [fill_dims]
        # Resolve -1 placeholders to the observed maximum on that axis.
        fill_dims = [mdim if dim == -1 else dim
                     for mdim, dim in zip(max_dims, fill_dims)]

    padded_mats = []
    for submat in mat:
        padding = [(0, fill_to - dim)
                   for dim, fill_to in zip(submat.shape, fill_dims)]
        padded_mats.append(np.pad(submat, padding))

    return np.array(padded_mats)


def pad_missing(mat: np.array, mask: np.array):
    """Spread the rows of `mat` over the positions where `mask` is truthy.

    The result has `mask.shape[0]` rows and the trailing shape of `mat`. It is
    created with `np.zeros`, so positions where `mask` is falsy stay zero. The
    k-th truthy position of `mask` receives `mat[k]`.
    """
    out_shape = mask.shape[0:1] + mat.shape[1:]
    full = np.zeros(out_shape)

    # Walk the present positions in order, pairing each with the next row.
    present_rows = (row for row, flag in enumerate(mask) if flag)
    for source, target in enumerate(present_rows):
        full[target] = mat[source]

    return full

### Sample Get (Multiple)

In [221]:
# First four positions of `indexes` (small batch).
item_idx = [0,1,2,3]

In [222]:
# Look up the ids for the selected positions and pull their rows from `vitals`.
# NOTE(review): `vitals` is not defined in any visible cell.
pat_ids = indexes[item_idx]
vit = vitals.loc[pat_ids]

In [247]:
# Format the vitals batch into a padded array and its mask, then show shapes.
# NOTE(review): `_format_ts_batch` is not defined in any visible cell.
vitp, vit_msk = _format_ts_batch(vit)
print(vitp.shape)
print(vit_msk.shape)

(4, 150, 104)
(4, 150)


In [243]:
# Keeps the whole return value under one name (another cell unpacks it into
# two); this cell has a lower execution count than the unpacking cell above.
vitp = _format_ts_batch(vit)

In [239]:
# Pad axis 0 of every element of `vitp` to TIMESERIES_DIM and list the shapes.
[pad_axis(row, TIMESERIES_DIM, 0).shape for row in vitp]

[(150, 150, 104), (150, 150)]

In [241]:
# Display the current value of `vitp`.
vitp

2