In [3]:
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
from pathlib import Path
from sklearn.model_selection import train_test_split
import numpy as np

In [98]:
WINSIZE = 101   # window length along dim 1 of each session's X, for this dataset


class NursingDatasetV1(Dataset):
    """
        Dataset class to handle the nursingv1_dataset.

        Each session is a sub-directory of ``dir`` holding ``X.pt`` (session
        data), ``y.pt`` (per-step labels) and ``Xshape.pt`` (the shape of X).
        Every length-``WINSIZE`` window along dim 1 of X is one item, and its
        label is ``y`` at the window's first index. Items are served in a
        random order that is fixed at construction time (drawn from numpy's
        global RNG, so seed ``np.random`` for reproducibility).
    """

    def __init__(
        self,
        dir: Path,
        session_idxs: list,
        *,
        cache_sessions: bool = False,
    ) -> None:
        """
        Args:
            dir (Path): root directory of the nursingv1 dataset
            session_idxs (list): names of the session sub-directories to use
            cache_sessions (bool, optional): keep each session's X and y in
                memory after its first load instead of re-reading both files
                for every window. Much faster, but holds every used session
                in RAM, and returned windows are views into the cached
                tensors (do not modify them in place). Defaults to False.
        """
        super().__init__()
        self._dir = Path(dir)
        self._session_idxs = list(session_idxs)

        # Number of windows in each session, in session_idxs order
        self._lengths = []
        # Internal window index -> (session name, window start in that session)
        self._idx_to_session = []

        for session_idx in self._session_idxs:
            # Only the shape is needed here, so read the small Xshape.pt file
            session_shape = torch.load(self._dir / f'{session_idx}' / 'Xshape.pt')

            # A session of length L yields L - WINSIZE + 1 windows. Sessions
            # shorter than WINSIZE yield none; clamping keeps __len__ equal to
            # the number of entries in _idx_to_session.
            n_windows = max(0, int(session_shape[1]) - WINSIZE + 1)
            self._lengths.append(n_windows)
            self._idx_to_session.extend(
                (session_idx, window_idx) for window_idx in range(n_windows)
            )

        # Random permutation mapping external indices to internal window indices
        self._idxs = list(range(len(self._idx_to_session)))
        np.random.shuffle(self._idxs)

        # Optional in-memory cache of loaded sessions: session name -> (X, y)
        self._cache = {} if cache_sessions else None

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Return one (window, label) pair.

        Args:
            index (int): position in the dataset's pre-shuffled order;
                negative values count from the end, as for a list.

        Returns:
            tuple: ``(X[:, start:start + WINSIZE], y[start])`` for the
                session and start position that ``index`` maps to.

        Raises:
            TypeError: if index is not an integer.
            IndexError: if index is out of range (this is also what ends a
                plain ``for x, y in dataset`` loop).
        """
        if not isinstance(index, (int, np.integer)):
            raise TypeError(
                f'{type(self).__name__} indices must be integers, '
                f'not {type(index).__name__}'
            )

        n_items = len(self._idxs)
        index = int(index)
        if index < 0:
            index += n_items
        # Explicit bounds check: a still-negative index would otherwise wrap
        # around silently when used on the list below.
        if not 0 <= index < n_items:
            raise IndexError(f'{type(self).__name__} index out of range')

        return self._get_one_window_and_label(self._idxs[index])

    def _get_one_window_and_label(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the window with internal index ``idx`` and its label."""
        session_idx, window_idx = self._idx_to_session[idx]
        X, y = self._load_session(session_idx)

        # Windows slide along dim 1 of X; the label is y at the window start
        window = X[:, window_idx:window_idx + WINSIZE]
        label = y[window_idx]
        return (window, label)

    def _load_session(self, session_idx) -> tuple[torch.Tensor, torch.Tensor]:
        """Load one session's full X and y, from the cache when it is enabled."""
        if self._cache is not None and session_idx in self._cache:
            return self._cache[session_idx]

        session_dir = self._dir / f'{session_idx}'
        X = torch.load(session_dir / 'X.pt')
        y = torch.load(session_dir / 'y.pt')

        if self._cache is not None:
            self._cache[session_idx] = (X, y)
        return (X, y)

    def __len__(self) -> int:
        # Total number of windows across every session
        return len(self._idxs)

    def get_one_session(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        # Get one unwindowed session from session_idxs and its labels
        raise NotImplementedError('get_one_session is not implemented yet')

    def get_all_sessions(self) -> list[tuple[torch.Tensor, torch.Tensor]]:
        # Return list of all unwindowed sessions and their labels
        raise NotImplementedError('get_all_sessions is not implemented yet')

    def get_one_windowed_session(self, idx) -> TensorDataset:
        # Return one windowed session and its labels as tensor dataset
        raise NotImplementedError('get_one_windowed_session is not implemented yet')

    def get_all_windowed_sessions(self) -> list[TensorDataset]:
        # Return all windowed sessions and their labels as list of tensor datasets
        raise NotImplementedError('get_all_windowed_sessions is not implemented yet')

    @staticmethod
    def nursingv1_train_dev_test_split(
        dir: Path,
        train_size: float,
        dev_size: float,
        test_size: float,
        shuffle: bool = False,
        *,
        cache_sessions: bool = False,
    ) -> tuple:
        """
            Creates and returns three NursingDatasetV1 objects for train,
                dev, and test purposes. Each object gets a disjoint subset of
                the session sub-directories of ``dir``. Train and dev get
                round(size * n_sessions) sessions each (in that order), and
                test gets whatever remains. Sessions are listed in a stable,
                sorted order (numeric names numerically), so the split does
                not depend on filesystem listing order.
        Args:
            dir (Path): filepath to nursingv1 dataset in filesystem
            train_size (float): fraction of sessions for train dataset
            dev_size (float): fraction of sessions for dev dataset
            test_size (float): fraction of sessions for test dataset (only
                used to validate that the three sizes sum to 1)
            shuffle (bool, optional): shuffle sessions with numpy's global
                RNG before splitting. Defaults to False.
            cache_sessions (bool, optional): forwarded to each
                NursingDatasetV1. Defaults to False.

        Returns:
            tuple: Three NursingDatasetV1 objects (train, dev, test)

        Raises:
            NotADirectoryError: if ``dir`` is not an existing directory.
            ValueError: if a size is negative or the sizes do not sum to 1.
        """
        dir = Path(dir)

        ## Check parameters:
        if not dir.is_dir():
            raise NotADirectoryError(f"Error: directory does not exist: {dir}")

        sizes = (train_size, dev_size, test_size)
        if any(size < 0 for size in sizes):
            raise ValueError("Error: train_size, dev_size and test_size must be >= 0")
        # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 == 0.9999999999999999
        if not np.isclose(sum(sizes), 1.0):
            raise ValueError("Error: train_size + dev_size + test_size != 1")

        def natural_key(name: str) -> tuple:
            # Numeric session names sort numerically ('2' < '10'); others after
            return (0, int(name), name) if name.isdigit() else (1, 0, name)

        ## Sessions are the sub-directories only (skip stray files)
        sessions = sorted(
            (entry.name for entry in dir.iterdir() if entry.is_dir()),
            key=natural_key,
        )

        ## Split sessions into train, dev, and test
        if shuffle:
            np.random.shuffle(sessions)

        n_train = round(train_size * len(sessions))
        n_dev = round(dev_size * len(sessions))

        train_idxs = sessions[:n_train]
        dev_idxs = sessions[n_train:n_train + n_dev]
        test_idxs = sessions[n_train + n_dev:]

        return (
            NursingDatasetV1(dir, train_idxs, cache_sessions=cache_sessions),
            NursingDatasetV1(dir, dev_idxs, cache_sessions=cache_sessions),
            NursingDatasetV1(dir, test_idxs, cache_sessions=cache_sessions),
        )

In [99]:
nursingv1_dir = Path('../data/nursingv1_dataset')
# Seed numpy's global RNG: each dataset uses it to shuffle its window order
np.random.seed(0)
train_dataset, dev_dataset, test_dataset = NursingDatasetV1.nursingv1_train_dev_test_split(
    nursingv1_dir, 0.5, 0.2, 0.3
)

# Time one full pass over the training set; reading the shape just touches each item
for X, y in tqdm(train_dataset):
    x = X.shape

# this version takes 5 minutes to iterate through the entire train dataset

100%|██████████| 894100/894100 [05:04<00:00, 2936.93it/s]
