In [30]:
from typing import SupportsFloat, Any

import gymnasium as gym
from gymnasium.core import ActType, ObsType
from gymnasium.spaces import Box, Discrete
import pandas as pd
import numpy as np

In [31]:
# Single place to repoint the notebook at another machine / preprocessing run.
# All three input files live in the same `splits` directory and share a split index.
SPLITS_DIR = '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851/splits'
SPLIT_INDEX = 0

data_path = f'{SPLITS_DIR}/train_features_split_{SPLIT_INDEX}.csv'
continuous_outcomes_path = f'{SPLITS_DIR}/train_continuous_outcomes_split_{SPLIT_INDEX}.csv'
outcomes_path = f'{SPLITS_DIR}/train_final_outcomes_split_{SPLIT_INDEX}.csv'

In [32]:
# Load the three training-split tables (features, per-timestep outcomes, final outcomes)
# in one pass; the files are read in the same order as the targets on the left.
features_df, continuous_outcomes_df, outcomes_df = (
    pd.read_csv(csv_path)
    for csv_path in (data_path, continuous_outcomes_path, outcomes_path)
)

In [17]:
# Preview the first rows of the raw feature table.
features_df.head(n=5)

Unnamed: 0.1,Unnamed: 0,relative_sample_date_hourly_cat,case_admission_id,sample_label,source,value,impute_missing_as,patient_id
0,6120,0,1002417_9090,ALAT,EHR_pop_imputed,-0.164766,,1002417
1,6121,1,1002417_9090,ALAT,EHR_pop_imputed_locf_imputed,-0.164766,,1002417
2,6122,2,1002417_9090,ALAT,EHR_pop_imputed_locf_imputed,-0.164766,,1002417
3,6123,3,1002417_9090,ALAT,EHR_pop_imputed_locf_imputed,-0.164766,,1002417
4,6124,4,1002417_9090,ALAT,EHR_pop_imputed_locf_imputed,-0.164766,,1002417


In [18]:
# Preview the first rows of the per-timestep outcomes table.
continuous_outcomes_df.head(n=5)

Unnamed: 0.1,Unnamed: 0,relative_sample_date_hourly_cat,case_admission_id,source,nihss,nihss_delta,nihss_delta_to_best_prior_state,nihss_delta_to_start_state,nihss_delta_at_next_ts,nihss_delta_to_best_prior_state_at_next_ts,nihss_delta_to_start_state_at_next_ts,patient_id
0,9936,0,1002417_9090,EHR,6.0,0.0,0.0,0.0,6.0,6.0,6.0,1002417
1,9937,1,1002417_9090,EHR,12.0,6.0,6.0,6.0,0.0,6.0,6.0,1002417
2,9938,2,1002417_9090,EHR_locf_imputed,12.0,0.0,6.0,6.0,0.0,6.0,6.0,1002417
3,9939,3,1002417_9090,EHR_locf_imputed,12.0,0.0,6.0,6.0,-1.0,5.0,5.0,1002417
4,9940,4,1002417_9090,EHR,11.0,-1.0,5.0,5.0,-1.0,4.0,4.0,1002417


In [33]:
# Preview the first rows of the final (per-admission) outcomes table.
outcomes_df.head(n=5)

Unnamed: 0.1,Unnamed: 0,case_admission_id,Symptomatic ICH,Symptomatic ICH date,Recurrent stroke,Recurrent stroke date,Orolingual angioedema,Death in hospital,Death at hospital date,Death at hospital time,...,3M ICH date,3M Death,3M Death date,3M Death cause,3M Epileptic seizure,3M Epileptic seizure date,3M delta mRS,3M mRS 0-1,3M mRS 0-2,patient_id
0,3,100503_0884,no,,no,,,0.0,,,...,,0.0,,,,,0.0,0.0,0.0,100503
1,6,1005798_9217,no,,no,,,0.0,,,...,,0.0,,,no,,1.0,1.0,1.0,1005798
2,13,1012915_7747,no,,no,,,0.0,,,...,,0.0,,,no,,0.0,1.0,1.0,1012915
3,15,1018635_4340,no,,no,,,0.0,,,...,,1.0,20201120.0,unknown,no,,2.0,0.0,0.0,1018635
4,16,1020314_6753,no,,no,,,0.0,,,...,,0.0,,,no,,3.0,0.0,0.0,1020314


# Prepare data

In [19]:
# Features data
# Drop the unused column by reassignment rather than in place; `errors='ignore'`
# keeps this cell re-runnable once the column is already gone (the in-place
# version raised KeyError on a second run).
features_df = features_df.drop(columns=['impute_missing_as'], errors='ignore')

# Long -> wide: one row per (case_admission_id, relative_sample_date_hourly_cat),
# one column per sample_label.
pivoted_features_df = (
    features_df
    .pivot(index=['case_admission_id', 'relative_sample_date_hourly_cat'],
           columns='sample_label', values='value')
    # get rid of multiindex
    .rename_axis(None, axis=1)
    .reset_index()
)

# separate out treatment features
# `.copy()` makes treatment_df independent of the frame it was sliced from.
treatment_df = pivoted_features_df[
    ['case_admission_id', 'relative_sample_date_hourly_cat', 'anti_hypertensive_strategy']].copy()
pivoted_features_df = pivoted_features_df.drop(columns=['anti_hypertensive_strategy'])

In [24]:
# Every pivoted column is a feature except the two leading identifier columns
# (case_admission_id, relative_sample_date_hourly_cat).
n_features = pivoted_features_df.shape[1] - 2
n_features

84

In [26]:
# Number of distinct values observed in the treatment column.
n_treatment_actions = treatment_df['anti_hypertensive_strategy'].nunique()
n_treatment_actions

8

In [29]:
# Global (max, min) across all feature columns, i.e. everything after the two id columns.
(
    pivoted_features_df.iloc[:, 2:].max().max(),
    pivoted_features_df.iloc[:, 2:].min().min(),
)

(20.57745506446808, -7.048474126005384)

In [None]:
class PatientEnv(gym.Env):
    """Replay-style Gymnasium environment over recorded patient admissions.

    An episode walks through the rows of one admission in ``features_df``, from its
    smallest to its largest ``relative_sample_date_hourly_cat``. The reward is 0 on
    every step except the one that reaches the last recorded timestep, where it is
    the admission's ``target_outcome`` value from ``outcomes_df``.

    NOTE: ``_take_action`` is currently a no-op, so the chosen action does not
    influence the transition; the environment only replays recorded data.
    """

    # Identifier columns of ``features_df``; every other column is an observation feature.
    _ID_COLUMNS = ('case_admission_id', 'relative_sample_date_hourly_cat')

    def __init__(self, features_df, outcomes_df, n_treatment_actions, target_outcome):
        """Build the action/observation spaces from the data that is passed in.

        Args:
            features_df: wide feature table with the columns ``case_admission_id`` and
                ``relative_sample_date_hourly_cat`` plus one numeric column per feature.
            outcomes_df: table with a ``case_admission_id`` column and a
                ``target_outcome`` column.
            n_treatment_actions: size of the discrete action space.
            target_outcome: name of the ``outcomes_df`` column used as terminal reward.
        """
        self.features_df = features_df
        self.outcomes_df = outcomes_df
        self.target_outcome = target_outcome
        # Action space is the number of unique values for the anti_hypertensive_strategy
        self.action_space = Discrete(n_treatment_actions)

        # Derive everything from the `features_df` argument (not from a notebook
        # global) so the environment is self-contained.
        self.feature_columns = [
            column for column in features_df.columns if column not in self._ID_COLUMNS
        ]
        # Bounds are computed on the float32 view so they are consistent with the
        # float32 observations returned by `_get_obs`.
        feature_values = features_df[self.feature_columns].to_numpy(dtype=np.float32)
        self.observation_space = Box(
            low=float(np.nanmin(feature_values)),
            high=float(np.nanmax(feature_values)),
            shape=(len(self.feature_columns),),
            dtype=np.float32,
        )

        # Episode state; populated by `reset`.
        self.current_case_admission_id = None
        self.current_relative_sample_date_hourly_cat = None

    def _case_timesteps(self, case_admission_id):
        """Return the Series of timesteps recorded for one admission."""
        case_mask = self.features_df.case_admission_id == case_admission_id
        return self.features_df.loc[case_mask, 'relative_sample_date_hourly_cat']

    def _get_obs(self, case_admission_id, relative_sample_date_hourly_cat):
        """Return the feature vector of one (admission, timestep) row.

        Returns:
            1-D float32 array of length ``len(self.feature_columns)``.

        Raises:
            KeyError: if ``features_df`` has no row for that admission and timestep.
        """
        row_mask = (self.features_df.case_admission_id == case_admission_id) & (
            self.features_df.relative_sample_date_hourly_cat == relative_sample_date_hourly_cat
        )
        rows = self.features_df.loc[row_mask, self.feature_columns]
        if rows.empty:
            raise KeyError(
                f"no features for case {case_admission_id!r} "
                f"at timestep {relative_sample_date_hourly_cat!r}"
            )
        # Take the single matching row as a flat vector so it matches observation_space.
        return rows.iloc[0].to_numpy(dtype=np.float32)

    def _get_info(self):
        """Return identifiers of the current episode position."""
        return {
            'case_admission_id': self.current_case_admission_id,
            'relative_sample_date_hourly_cat': self.current_relative_sample_date_hourly_cat,
        }

    def _take_action(self, action):
        """Placeholder: the action is currently ignored."""
        pass

    def _get_reward(self, case_admission_id, relative_sample_date_hourly_cat):
        """Return the target outcome at the admission's last timestep, otherwise 0.

        Raises:
            IndexError: if the admission has no row in ``outcomes_df``.
        """
        if not self._is_done(case_admission_id, relative_sample_date_hourly_cat):
            return 0.0
        case_outcomes = self.outcomes_df.loc[
            self.outcomes_df.case_admission_id == case_admission_id, self.target_outcome
        ]
        if case_outcomes.empty:
            raise IndexError(
                f"no {self.target_outcome!r} outcome recorded for case {case_admission_id!r}"
            )
        return case_outcomes.values[0]

    def _is_done(self, case_admission_id, relative_sample_date_hourly_cat):
        """Return True when the timestep is the last one recorded for the admission."""
        last_timestep = self._case_timesteps(case_admission_id).max()
        return bool(relative_sample_date_hourly_cat == last_timestep)

    def reset(self, *, seed=None, options=None):
        """Start a new episode on a randomly drawn admission.

        A row of ``features_df`` is drawn uniformly and its admission is used, so
        admissions are weighted by their number of rows (as in the original sampling).

        Args:
            seed: optional seed forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: unused; accepted for Gymnasium API compatibility.

        Returns:
            ``(observation, info)`` for the first timestep of the admission.
        """
        super().reset(seed=seed)
        row_position = int(self.np_random.integers(len(self.features_df)))
        self.current_case_admission_id = self.features_df.case_admission_id.iloc[row_position]
        self.current_relative_sample_date_hourly_cat = self._case_timesteps(
            self.current_case_admission_id
        ).min()
        obs = self._get_obs(self.current_case_admission_id, self.current_relative_sample_date_hourly_cat)
        return obs, self._get_info()

    def step(
        self, action: ActType
    ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
        """Advance the episode by one timestep.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` describing the
            timestep reached *after* the action. ``truncated`` is always False.

        Raises:
            RuntimeError: if called before ``reset``.
        """
        if self.current_case_admission_id is None:
            raise RuntimeError("reset() must be called before step()")

        self._take_action(action)

        # Move to the next timestep first, so the returned observation is the state
        # that follows the action rather than a repeat of the previous one.
        # Assumes timesteps within an admission are consecutive integers.
        if not self._is_done(self.current_case_admission_id, self.current_relative_sample_date_hourly_cat):
            self.current_relative_sample_date_hourly_cat += 1

        obs = self._get_obs(self.current_case_admission_id, self.current_relative_sample_date_hourly_cat)
        reward = self._get_reward(self.current_case_admission_id, self.current_relative_sample_date_hourly_cat)
        terminated = self._is_done(self.current_case_admission_id, self.current_relative_sample_date_hourly_cat)

        return obs, reward, terminated, False, self._get_info()