In [3]:
import torch
import sklearn

In [None]:

class TIHMDataset(torch.utils.data.Dataset):
    """
    PyTorch map-style dataset over the TIHM data, split by date into train/test.

    On construction the TIHM data is loaded with ``TIHM(root=root)`` and split
    at ``TEST_START`` into a training part (dates before it) and a testing
    part (dates on or after it). The features are then imputed, optionally
    normalised, and optionally reshaped into ``n_days``-long inputs. Both
    splits are kept as attributes; ``train`` selects which one
    ``__getitem__`` and ``__len__`` serve.

    Arguments
    ---------
    - root: str, optional:
        Passed to ``TIHM`` as its ``root``.
        Defaults to ``"./"``.
    - train: bool, optional:
        If ``True``, items come from the training split, otherwise from the
        testing split.
        Defaults to ``True``.
    - imputer: optional:
        Object with ``fit_transform`` and ``transform`` methods (such as a
        sklearn imputer). It is fitted on the training features and used to
        transform the testing features. If ``None``, a new
        ``impute.SimpleImputer()`` is created for this instance.
        Defaults to ``None``.
    - n_days: int, optional:
        If greater than 1, each item holds ``n_days`` days of data built per
        patient with ``make_input_roll`` (see ``_reformat_n_days``).
        Defaults to ``1``.
    - normalise: str or None, optional:
        ``"global"`` scales with a ``StandardScaler`` fitted on all training
        features, ``"id"`` with a ``StandardGroupScaler`` grouped by patient
        ID, and ``None`` skips scaling. Any other value raises ``ValueError``.
        Defaults to ``"global"``.

    Attributes
    ----------
    - train_data, test_data, train_target, test_target,
      train_patient_id, test_patient_id, train_date, test_date:
        Arrays for each split after all preprocessing steps.
    """

    # Identifier columns present in both the data and the target frames;
    # they are separated out and never used as model features.
    _ID_COLUMNS = ("patient_id", "date")

    def __init__(
        self,
        root: str = "./",
        train: bool = True,
        imputer=None,
        n_days: int = 1,
        normalise: typing.Union[str, None] = "global",
    ):
        self.train = train  # selects the split served by __getitem__/__len__
        self._dataset = TIHM(root=root)

        # Build the default imputer per instance: an instance created in the
        # signature would be shared (and re-fitted) by every dataset object.
        if imputer is None:
            imputer = impute.SimpleImputer()

        # split the data by date to get the train-test split
        train_df, test_df, train_target_df, test_target_df = self._train_test_split(
            data=self._dataset.data,
            target=self._dataset.target,
            test_start=TEST_START,
        )

        # convert the data frames into feature / target / ID / date arrays
        (
            train_data,
            train_target,
            train_patient_id,
            train_date,
        ) = self._frames_to_arrays(train_df, train_target_df)
        (
            test_data,
            test_target,
            test_patient_id,
            test_date,
        ) = self._frames_to_arrays(test_df, test_target_df)

        # impute: fitted on the training features only, applied to both
        train_data, test_data = self._impute(
            train_data=train_data, test_data=test_data, imputer=imputer
        )

        if normalise is not None:
            train_data, test_data = self._normalise(
                train_data=train_data,
                test_data=test_data,
                train_patient_id=train_patient_id,
                test_patient_id=test_patient_id,
                normalise=normalise,
            )

        # reformat both splits so that each data point contains n_days days
        if n_days > 1:
            (
                train_data,
                train_target,
                train_patient_id,
                train_date,
            ) = self._reformat_n_days(
                data=train_data,
                target=train_target,
                patient_id=train_patient_id,
                date=train_date,
                n_days=n_days,
            )
            (
                test_data,
                test_target,
                test_patient_id,
                test_date,
            ) = self._reformat_n_days(
                data=test_data,
                target=test_target,
                patient_id=test_patient_id,
                date=test_date,
                n_days=n_days,
            )

        # saving data to class attributes
        self.train_data, self.test_data = train_data, test_data
        self.train_target, self.test_target = train_target, test_target
        self.train_patient_id, self.test_patient_id = train_patient_id, test_patient_id
        self.train_date, self.test_date = train_date, test_date

    @property
    def feature_names(self) -> typing.List[str]:
        """
        The names of the features in the x data.
        """
        return list(self._dataset.data.drop(list(self._ID_COLUMNS), axis=1).columns)

    @property
    def target_names(self) -> typing.List[str]:
        """
        The names of the features in the y data.
        """
        return list(self._dataset.target.drop(list(self._ID_COLUMNS), axis=1).columns)

    @classmethod
    def _frames_to_arrays(
        cls, data: pd.DataFrame, target: pd.DataFrame
    ) -> typing.Tuple[np.ndarray, ...]:
        """
        Convert a data/target frame pair into arrays.

        Patient IDs and dates (as ``datetime.date`` objects via ``.dt.date``)
        are read from ``data``. The identifier columns are dropped from both
        frames before taking their values.

        Returns ``(features, targets, patient_ids, dates)``.
        """
        id_columns = list(cls._ID_COLUMNS)
        patient_id = data["patient_id"].values
        date = data["date"].dt.date.values
        features = data.drop(id_columns, axis=1).values
        targets = target.drop(id_columns, axis=1).values
        return features, targets, patient_id, date

    def _train_test_split(
        self,
        data: pd.DataFrame,
        target: pd.DataFrame,
        test_start: str,
    ) -> typing.Tuple[pd.DataFrame, ...]:
        """
        Split ``data`` and ``target`` on their ``"date"`` column.

        Rows dated before ``test_start`` form the training part and rows
        dated on or after it form the testing part. Both comparisons are
        written out explicitly, so rows with a missing date (NaT) fall into
        neither part.

        Returns ``(train_data, test_data, train_target, test_target)``.
        """
        # NOTE(review): data and target rows are later paired by position,
        # which assumes both frames hold the same (patient_id, date) rows in
        # the same order -- confirm against TIHM.
        cutoff = pd.to_datetime(test_start)

        # train
        train_data = data[data["date"] < cutoff]
        train_target = target[target["date"] < cutoff]

        # test
        test_data = data[data["date"] >= cutoff]
        test_target = target[target["date"] >= cutoff]

        return train_data, test_data, train_target, test_target

    def _impute(
        self, train_data: np.ndarray, test_data: np.ndarray, imputer
    ) -> typing.Tuple[np.ndarray, ...]:
        """
        Fit ``imputer`` on ``train_data`` and use it to transform both arrays.

        Returns ``(imputed_train_data, imputed_test_data)``.

        Raises
        ------
        - TypeError:
            If ``imputer`` lacks a ``fit_transform`` or ``transform`` method.
            Errors raised while fitting or transforming propagate unchanged.
        """
        # Check the interface up front instead of catching AttributeError
        # around the calls, which would also turn AttributeErrors raised
        # *inside* a valid imputer into a misleading TypeError.
        if not (hasattr(imputer, "fit_transform") and hasattr(imputer, "transform")):
            raise TypeError(
                "Please ensure that the imputer is a sklearn imputer, "
                + "or implements the fit_transform and transform methods."
            )
        train_data = imputer.fit_transform(train_data)  # fit + transform on train
        test_data = imputer.transform(test_data)  # transform only on test
        return train_data, test_data

    def _normalise(
        self,
        train_data: np.ndarray,
        test_data: np.ndarray,
        train_patient_id: np.ndarray,
        test_patient_id: np.ndarray,
        normalise: str,
    ) -> typing.Tuple[np.ndarray, ...]:
        """
        Scale both splits with a scaler fitted on the training split.

        ``"global"`` uses ``preprocessing.StandardScaler``; ``"id"`` uses
        ``StandardGroupScaler`` with the patient IDs passed as ``groups``.

        Returns ``(scaled_train_data, scaled_test_data)``.

        Raises
        ------
        - ValueError:
            If ``normalise`` is neither ``"global"`` nor ``"id"``.
        """
        if normalise == "global":
            scaler = preprocessing.StandardScaler()
            train_data = scaler.fit_transform(train_data)
            test_data = scaler.transform(test_data)

        elif normalise == "id":
            scaler = StandardGroupScaler()
            train_data = scaler.fit_transform(train_data, groups=train_patient_id)
            test_data = scaler.transform(test_data, groups=test_patient_id)

        else:
            raise ValueError(
                f"normalise must be None, 'global' or 'id', not {normalise}."
            )

        return train_data, test_data

    def _reformat_n_days(
        self,
        data: np.ndarray,
        target: np.ndarray,
        patient_id: np.ndarray,
        date: np.ndarray,
        n_days: int,
    ) -> typing.Tuple[np.ndarray, ...]:
        """
        Build ``n_days``-long inputs per patient.

        For each patient, rows are ordered by date (stable sort) and split
        into runs wherever two consecutive records are more than one day
        apart, so no run spans such a gap. Runs shorter than ``n_days`` are
        skipped. Every other run has its data, target, patient IDs and dates
        passed through ``make_input_roll(..., n_days)`` -- presumably
        producing sliding windows of ``n_days`` rows; confirm in its
        definition. The results from all runs are stacked with ``np.vstack``.

        Returns ``(data, target, patient_id, date)``.

        Raises
        ------
        - ValueError:
            If no run is at least ``n_days`` long, so nothing can be stacked.
        """
        data_out, target_out, patient_id_out, date_out = [], [], [], []
        one_day = dt.timedelta(days=1)

        for id_val in np.unique(patient_id):
            # row indices of this patient, ordered by date
            idx_id = np.flatnonzero(patient_id == id_val)
            idx_id = idx_id[np.argsort(date[idx_id], kind="stable")]

            # gather this patient's rows once instead of once per run
            data_p, target_p = data[idx_id], target[idx_id]
            patient_id_p, date_p = patient_id[idx_id], date[idx_id]

            # start a new run after every gap of more than one day
            gaps = np.flatnonzero(date_p[1:] - date_p[:-1] > one_day)
            for run in np.split(np.arange(idx_id.shape[0]), gaps + 1):
                # a run shorter than n_days cannot form a single sequence
                if run.shape[0] < n_days:
                    continue

                data_out.append(make_input_roll(data_p[run], n_days))
                target_out.append(make_input_roll(target_p[run], n_days))
                patient_id_out.append(make_input_roll(patient_id_p[run], n_days))
                date_out.append(make_input_roll(date_p[run], n_days))

        if not data_out:
            raise ValueError(
                f"Could not build any {n_days}-day sequences: no patient has "
                f"{n_days} records without a gap of more than one day."
            )

        return (
            np.vstack(data_out),
            np.vstack(target_out),
            np.vstack(patient_id_out),
            np.vstack(date_out),
        )

    def __getitem__(self, index: int):
        """
        Return the ``(x, y)`` pair at ``index`` from the split chosen by
        ``self.train``.
        """
        if self.train:
            return self.train_data[index], self.train_target[index]
        return self.test_data[index], self.test_target[index]

    def __len__(self):
        """Number of data points in the split chosen by ``self.train``."""
        return len(self.train_data) if self.train else len(self.test_data)