Get the dataset and indices

In [11]:
import os
import sys

# Resolve the directory three levels above the current working directory and
# make it importable for the cells below.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../../.."))

# Guard the append so re-running this cell does not pile up duplicate
# sys.path entries in a long-lived kernel.
if project_root not in sys.path:
    sys.path.append(project_root)

Load the train config

In [12]:
import yaml

# Read the training configuration that sits next to this notebook.
with open('train_config.yaml', 'r', encoding='utf-8') as file:
    train_config = yaml.safe_load(file)

# The training method selects which preprocessed variant of the data is used.
use_LR = train_config['train']['training_method'] == 'LR'
data_path = train_config['data']['data_dir']

# os.path.join inserts a separator only when `data_dir` does not already end
# with one, so both "./data" and "./data/" resolve to the same directory.
# The trailing "" component keeps `path` ending in a separator, as it did when
# it was built by string concatenation.
path = os.path.join(data_path, "LR_data" if use_LR else "GRUD_data", "")

dataset_path = os.path.join(path, "dataset.pkl")

`preprocess_data` builds the features and labels for the length-of-stay task. It keeps only ICU stays whose `max_hours` exceeds 30 (a 24-hour observation window plus a 6-hour gap), and from those stays it keeps only the rows with `hours_in < 24`, i.e. the first 24 hours. It then creates a binary target variable, `los_3`, which is 1 when `los_icu` is greater than 3 (the patient stayed in the ICU for more than three days) and 0 otherwise. The function returns the filtered time-series data and the corresponding target values.


In [13]:
def preprocess_data(statics, data, window_size=24, gap_time=6):
    """
    Select eligible ICU stays and build the binary `los_3` label for them.

    Args:
        statics: DataFrame with `max_hours` and `los_icu` columns whose index
            carries an `icustay_id` level.
        data: Time-series DataFrame whose index carries `icustay_id` and
            `hours_in` levels.
        window_size: Number of initial hours of each stay to keep (default 24).
        gap_time: Buffer, in hours, required between the end of the window and
            the prediction (default 6).

    Returns:
        A `(data, y)` tuple. `data` holds only rows of eligible stays with
        `hours_in < window_size`. `y` is a single-column float DataFrame
        `los_3`, indexed like the eligible rows of `statics`, equal to 1.0
        where `los_icu > 3` and 0.0 otherwise.

    Notes:
        - Only ICU stays with `max_hours > window_size + gap_time` are considered.
        - Assumes `los_icu` is expressed in days, so `los_3` flags stays longer
          than three days -- TODO confirm against the data dictionary.
        - `statics` and `data` are not modified.
    """
    min_hours = window_size + gap_time

    # Define target labels for stays long enough to cover window + gap.
    eligible_los = statics.loc[statics["max_hours"] > min_hours, "los_icu"]
    y = (eligible_los > 3).astype(float).to_frame(name="los_3")

    # Filter data: keep only ICU stays present in y and within the first
    # `window_size` hours.
    stay_ids = data.index.get_level_values("icustay_id")
    hours_in = data.index.get_level_values("hours_in")
    keep = stay_ids.isin(y.index.get_level_values("icustay_id")) & (hours_in < window_size)

    return data[keep], y

In [14]:
import pandas as pd
from torch import from_numpy
from tqdm import tqdm
import numpy as np
from torch import cat, tensor
import pickle
from mimic_handler import MIMICInputHandler
        
# Build the dataset once and cache it as a pickle; later runs reuse the cache.
if os.path.exists(dataset_path):
    print("Loading dataset...")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only load a dataset.pkl that this notebook (or another trusted
    # source) produced.
    with open(dataset_path, "rb") as f:
        dataset = pickle.load(f)  # Load the dataset
    print(f"Loaded dataset from {dataset_path}")
else:
    print("Creating dataset...")
    data_file_path = os.path.join(data_path, "all_hourly_data.h5")
    if not os.path.exists(data_file_path):
        # Fail fast, before any expensive work, when the raw file is missing.
        msg = "Please download the MIMIC-III dataset from https://physionet.org/content/mimiciii/1.4/ and save it in the specified path."
        raise FileNotFoundError(msg)

    print("Loading data...")
    data = pd.read_hdf(data_file_path, "vitals_labs")
    statics = pd.read_hdf(data_file_path, "patients")

    ID_COLS = ["subject_id", "hadm_id", "icustay_id"]
    data, y = preprocess_data(statics, data)

    if use_LR:
        print("Flattening data for LR...")
        # One row per stay, one column per (measurement, hour) pair.
        flat_data = data.pivot_table(index=ID_COLS, columns=["hours_in"])

        # Once the index is dropped below, features and labels are paired
        # purely by row position, so unequal row counts would silently attach
        # labels to the wrong stays.
        if len(flat_data) != len(y):
            raise ValueError(
                f"Feature/label row mismatch after flattening: "
                f"{len(flat_data)} feature rows vs {len(y)} label rows."
            )

        print("Flattening data...")
        data, y = [
            df.reset_index(drop=True)
            for df in tqdm((flat_data, y), desc="Flattening Index")
        ]

    assert np.issubdtype(data.values.dtype, np.number), "Non-numeric data found in features."
    assert np.issubdtype(y.values.dtype, np.number), "Non-numeric data found in labels."

    print("Creating dataset...")
    y_tensor = from_numpy(y.values).float()

    if use_LR:
        data_tensor = from_numpy(data.values).float()
        dataset = MIMICInputHandler.UserDataset(data_tensor, y_tensor)
    else:
        # The non-LR path hands the long-format frame to the handler's own
        # 3-D conversion instead of flattening it here.
        data_x = MIMICInputHandler.to_3D_tensor(data)
        dataset = MIMICInputHandler.UserDataset(data_x, y_tensor)

    os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
    # Save the dataset to dataset.pkl
    print("Saving dataset and indices...")
    with open(dataset_path, "wb") as file:
        pickle.dump(dataset, file)
        print(f"Saved dataset to {dataset_path}")



Creating dataset...
Loading data...
Flattening data for LR...
Flattening data...


Flattening Index: 100%|██████████| 2/2 [00:00<00:00,  4.42it/s]


Creating dataset...
Saving dataset and indices...
Saved dataset to ./data/LR_data/dataset.pkl
