## Import Modules

In [29]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import torch.nn as nn
import math
import numpy as np
from sklearn.model_selection import KFold
import pandas as pd

In [30]:
# Global seed for reproducible runs of this notebook.
# NOTE(review): only NumPy's global RNG is seeded here; torch's RNG is not.
RANDOM_SEED = 42
np.random.seed(seed=RANDOM_SEED)

## Dataset

In [58]:
from torch.utils.data import Dataset
from enum import Enum
import pandas as pd
import numpy as np
import os


class NotesDataset(Dataset):
    """Map-style dataset yielding (note text, label values) per patient id.

    Reads, relative to ``processed_dir`` (``split`` is ``'train'`` or ``'test'``):

    * ``{split}_idxs.npy`` -- the patient ids that make up this split, one per sample;
    * ``{split}/notes_static.csv`` -- must have ``pat_id`` and ``TEXT`` columns;
    * ``{split}/labels.csv`` -- ``pat_id`` plus one column per class.

    Patients listed in the index file but missing from ``notes_static.csv``
    get an empty string as their note.
    """

    def __init__(self, processed_dir: str, train: bool):
        """Load the index, notes and labels for the train or test split.

        Args:
            processed_dir: Directory containing the processed data.
            train: If True load the ``train`` split, otherwise ``test``.

        Raises:
            FileNotFoundError: If any of the processed files is missing.
        """
        self.processed_dir = processed_dir

        split = 'train' if train else 'test'
        self.data_path = os.path.join(self.processed_dir, f'{split}/')
        self.index_path = os.path.join(self.processed_dir, f'{split}_idxs.npy')

        try:
            self.idxs = np.load(self.index_path)
            self.notes = pd.read_csv(
                os.path.join(self.data_path, 'notes_static.csv'))
            self.labels = pd.read_csv(os.path.join(self.data_path, 'labels.csv'))
        except FileNotFoundError as e:
            # Fail fast with the real cause; continuing would only crash below
            # with an unrelated AttributeError on the missing attributes.
            raise FileNotFoundError(
                f"Make sure data has been processed: {e}") from e

        self.notes.set_index('pat_id', inplace=True)
        self.labels.set_index('pat_id', inplace=True)

        # Ids that actually have a note; others fall back to ''.
        self.ids = set(self.notes.index.values)

        # pat_id -> {column: value}; converted once for fast per-item lookup.
        self.notes = self.notes.to_dict(orient='index')

        # Every remaining label column (pat_id is now the index) is a class.
        self.num_classes = self.labels.shape[1]

    def __len__(self):
        return len(self.idxs)

    def _note_for(self, pat_id):
        """Return the ``TEXT`` of ``pat_id``'s note, or '' if it has none."""
        if pat_id not in self.ids:
            return ''
        return self.notes[pat_id]['TEXT']

    def __getitem__(self, idx):
        """Return ``(note, labels)`` for the patient(s) at position ``idx``.

        For a scalar ``idx`` the note is a single string and ``labels`` is
        that patient's label row. If ``idx`` selects several entries (so the
        indexed ids form an ndarray), a list of notes is returned together
        with the corresponding label rows.
        """
        pat_id = self.idxs[idx]
        lbl = self.labels.loc[pat_id].values

        if isinstance(pat_id, np.ndarray):
            return [self._note_for(p) for p in pat_id], lbl
        return self._note_for(pat_id), lbl

In [64]:
# Training split of the processed notes data.
train_ds = NotesDataset(processed_dir='../data/processed', train=True)