In [2]:
import os

import pandas as pd

from src.data.mimic_iii.real_dataset import MIMIC3RealDataset

2024-01-06 14:11:28.545214: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Directory holding the preprocessed GSU extraction files. The default is the
# original local location; set the GSU_PREPROCESSING_DIR environment variable to
# point the notebook at another copy without editing this cell.
GSU_PREPROCESSING_DIR = os.environ.get(
    'GSU_PREPROCESSING_DIR',
    '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851',
)
# Both paths stay plain strings, as before, and are built from the shared directory.
gsu_features_path = f'{GSU_PREPROCESSING_DIR}/preprocessed_features_25112023_213851.csv'
gsu_continuous_outcomes_path = f'{GSU_PREPROCESSING_DIR}/preprocessed_outcomes_continuous_25112023_213851.csv'

In [4]:
# Read the feature table first, then the continuous-outcome table, each with
# pandas' default CSV parsing.
features_df, outcomes_df = (
    pd.read_csv(csv_path)
    for csv_path in (gsu_features_path, gsu_continuous_outcomes_path)
)

# Preprocessing

In [5]:
# Optional cutoff used by the preprocessing cell: when set, only rows with
# relative_sample_date_hourly_cat < last_timestep + 1 are kept; None keeps every timestep.
last_timestep = None

In [6]:
# Key columns shared by the feature and outcome tables; they become the 2-level index.
index_cols = ['case_admission_id', 'relative_sample_date_hourly_cat']

# Features data
# Non-in-place drop with errors='ignore': re-running this cell no longer raises a
# KeyError once the column has already been removed.
features_df = features_df.drop(columns=['impute_missing_as'], errors='ignore')
# Long -> wide: one row per (admission, timestep), one column per sample_label.
pivoted_features_df = features_df.pivot(index=index_cols, columns='sample_label', values='value')

# Drop the 'sample_label' name from the column axis and turn the index levels
# back into ordinary columns.
pivoted_features_df = pivoted_features_df.rename_axis(None, axis=1).reset_index()

if last_timestep is not None:
    pivoted_features_df = pivoted_features_df[
        pivoted_features_df.relative_sample_date_hourly_cat < last_timestep + 1]

# separate out treatment features
# New frames are built instead of mutating a column subset in place.
treatment_df = pivoted_features_df[index_cols + ['anti_hypertensive_strategy']].set_index(index_cols)
pivoted_features_df = pivoted_features_df.drop(columns=['anti_hypertensive_strategy']).set_index(index_cols)

# Outcome data
# Index only when the key columns are still columns, so a second run of this cell
# (with outcomes_df already indexed) does not fail.
if set(index_cols).issubset(outcomes_df.columns):
    outcomes_df = outcomes_df.set_index(index_cols)
if last_timestep is not None:
    # Same cutoff as for the features, applied on the timestep index level.
    outcomes_df = outcomes_df[
        outcomes_df.index.get_level_values('relative_sample_date_hourly_cat') < last_timestep + 1]
# Explicit copy so later index changes on this frame never refer back to outcomes_df.
reformatted_outcomes_df = outcomes_df[['nihss_delta_at_next_ts']].copy()

# Load into Dataset

In [7]:
# rename index 'case_admission_id' to 'subject_id'
# Rename the first index level from 'case_admission_id' to 'subject_id' on all
# three frames; the timestep level keeps its name.
subject_index_names = ['subject_id', 'relative_sample_date_hourly_cat']
for indexed_frame in (treatment_df, pivoted_features_df, reformatted_outcomes_df):
    indexed_frame.index = indexed_frame.index.rename(subject_index_names)

In [9]:
# New frame that keeps the row index of the features table but none of its columns.
static_features = pivoted_features_df.drop(columns=pivoted_features_df.columns)

In [10]:
# Display the frame: every column was dropped above, so only the row index remains.
static_features

subject_id,relative_sample_date_hourly_cat
100023_4784,0
100023_4784,1
100023_4784,2
100023_4784,3
100023_4784,4
...,...
9996_3256,67
9996_3256,68
9996_3256,69
9996_3256,70


In [7]:
# Per-column standard deviation and mean of the outcome frame, as NumPy arrays.
scaling_params = {
    'output_stds': reformatted_outcomes_df.std().values,
    'output_means': reformatted_outcomes_df.mean().values,
}

In [8]:
# Collect the constructor arguments first, then build the dataset from them.
# NOTE(review): `outcomes` and `outcomes_unscaled` receive the same frame, and
# `static_features` is an empty DataFrame rather than the `static_features`
# frame built earlier - confirm both are intended.
dataset_kwargs = dict(
    treatments=treatment_df,
    outcomes=reformatted_outcomes_df,
    vitals=pivoted_features_df,
    static_features=pd.DataFrame(None),
    outcomes_unscaled=reformatted_outcomes_df,
    scaling_params=scaling_params,
    subset_name='train',
)
ds = MIMIC3RealDataset(**dataset_kwargs)

In [9]:
# Display the dataset's `data` attribute (the recorded output is a dict of arrays).
ds.data

{'sequence_lengths': array([71, 71, 71, ..., 71, 71, 71]),
 'prev_treatments': array([[[4.],
         [4.],
         [4.],
         ...,
         [0.],
         [0.],
         [0.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [4.],
         [4.],
         [4.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]],
 
        ...,
 
        [[7.],
         [7.],
         [7.],
         ...,
         [3.],
         [3.],
         [3.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]]]),
 'vitals': array([[[-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 

In [10]:
from src.data import RealDatasetCollection


class MIMIC3RealDatasetCollection(RealDatasetCollection):
    """
    Dataset collection (train_f, val_f, test_f) backed by a single dataset.

    The same ``ds`` object is stored as all three splits; no splitting is
    performed by this class.
    """
    def __init__(self,
                 ds: MIMIC3RealDataset,
                 projection_horizon: int = 5,
                 autoregressive=True,
                 **kwargs):
        """
        Args:
            ds: Already constructed dataset, reused as train, validation and test split
            projection_horizon: Range of tau-step-ahead prediction (tau = projection_horizon + 1)
            autoregressive: Stored unchanged on the collection
            **kwargs: Accepted but not used by this constructor
        """
        super().__init__()

        # One dataset object is shared by all three splits.
        self.train_f = self.val_f = self.test_f = ds

        self.projection_horizon = projection_horizon
        # Hard-coded flag for this collection.
        self.has_vitals = True
        self.autoregressive = autoregressive
        # Hard-coded flag; its meaning is defined by the code consuming the
        # collection - TODO confirm against RealDatasetCollection.
        self.processed_data_encoder = True

In [11]:
# Wrap the single dataset in a collection, using the constructor's default settings.
ds_collection = MIMIC3RealDatasetCollection(ds)

In [12]:
# Display the data of the collection's train split (the same `ds` object passed in above).
ds_collection.train_f.data

{'sequence_lengths': array([71, 71, 71, ..., 71, 71, 71]),
 'prev_treatments': array([[[4.],
         [4.],
         [4.],
         ...,
         [0.],
         [0.],
         [0.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [4.],
         [4.],
         [4.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]],
 
        ...,
 
        [[7.],
         [7.],
         [7.],
         ...,
         [3.],
         [3.],
         [3.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]]]),
 'vitals': array([[[-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 