In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

In [32]:
from torch.utils.data import Dataset
from enum import Enum
import pandas as pd
import numpy as np
import os


class MIMICDataset(Dataset):
    """Map-style dataset over pre-processed per-patient tables.

    The constructor reads, from ``processed_dir``:

    * ``train_idxs.npy`` / ``test_idxs.npy`` -- array of patient ids that
      defines the length and ordering of the dataset;
    * ``train/`` or ``test/`` -- a directory holding ``demographic.csv``,
      ``vitals.csv``, ``interventions.csv``, ``notes_static.csv`` and
      ``notes_ts.csv``.

    ``demographic.csv`` and ``notes_static.csv`` are indexed by ``pat_id``;
    the remaining tables are indexed by ``(pat_id, hours_in)``.

    Args:
        processed_dir: Directory containing the index files and the
            ``train/`` and ``test/`` sub-directories.
        train: If true, load the training split; otherwise the test split.

    Raises:
        FileNotFoundError: If the index file or any of the CSV files is
            missing (i.e. the data has not been processed yet).
    """

    def __init__(self, processed_dir: str, train: bool):
        self.processed_dir = processed_dir

        split = 'train' if train else 'test'
        self.data_path = os.path.join(self.processed_dir, split + '/')
        self.index_path = os.path.join(
            self.processed_dir, split + '_idxs.npy')

        try:
            self.idxs = np.load(self.index_path)
            self.demographics = self._read_indexed(
                'demographic.csv', 'pat_id')
            self.vitals = self._read_indexed(
                'vitals.csv', ['pat_id', 'hours_in'])
            self.interventions = self._read_indexed(
                'interventions.csv', ['pat_id', 'hours_in'])
            self.notes_static = self._read_indexed(
                'notes_static.csv', 'pat_id')
            self.notes_ts = self._read_indexed(
                'notes_ts.csv', ['pat_id', 'hours_in'])
        except FileNotFoundError as e:
            # Fail fast with the actual cause. Merely printing here would let
            # construction continue and crash later with an unrelated
            # AttributeError on the attributes that were never assigned.
            raise FileNotFoundError(
                f"Make sure data has been processed: {e}") from e

        # Not every patient has notes; cache the ids that do so __getitem__
        # can test membership without triggering a KeyError on .loc.
        self.nst_ids = set(self.notes_static.index.values)
        self.nts_ids = set(self.notes_ts.index.get_level_values(0).values)

    def _read_indexed(self, filename, index_cols):
        """Read ``filename`` from ``self.data_path`` indexed by ``index_cols``."""
        table = pd.read_csv(os.path.join(self.data_path, filename))
        return table.set_index(index_cols)

    def __len__(self):
        """Return the number of patient ids in the loaded index file."""
        return len(self.idxs)

    def __getitem__(self, idx):
        """Return all table slices for the patient at position ``idx``.

        Returns:
            A 5-tuple ``(dem, vit, itv, nst, nts)`` of pandas objects selected
            with ``.loc[pat_id]`` from the demographics, vitals,
            interventions, static-notes and time-series-notes tables.
            ``nst`` / ``nts`` are ``None`` when the patient has no entry in
            the corresponding notes table.
        """
        pat_id = self.idxs[idx]

        dem = self.demographics.loc[pat_id]
        vit = self.vitals.loc[pat_id]
        itv = self.interventions.loc[pat_id]

        nst = self.notes_static.loc[pat_id] if pat_id in self.nst_ids else None
        nts = self.notes_ts.loc[pat_id] if pat_id in self.nts_ids else None

        return dem, vit, itv, nst, nts


In [33]:
# Build one dataset per split from the same processed-data directory.
train_ds = MIMICDataset(processed_dir='../data/processed', train=True)
test_ds = MIMICDataset(processed_dir='../data/processed', train=False)

In [34]:
# Number of samples per batch; passed as `batch_size` to both DataLoaders.
BATCH_SIZE: int = 32

In [36]:
from torch.utils.data import DataLoader

# Shuffle only the training split; the test split is iterated in a fixed
# order so evaluation runs are reproducible and comparable.
# NOTE(review): MIMICDataset.__getitem__ returns pandas objects and None;
# the default collate function presumably cannot batch those, so a custom
# `collate_fn` is likely needed before iterating -- confirm.
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)