In [1]:
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.offline.estimators import (
    ImportanceSampling,
    WeightedImportanceSampling,
    DirectMethod,
    DoublyRobust,
)
from tqdm import tqdm
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel
import numpy as np

## Demo data

In [2]:
# Offline DQN trained on pre-recorded CartPole experiences and evaluated every
# iteration with four off-policy estimators (IS, WIS and FQE-backed DM / DR).
train_data_dir = "/Users/jk1/temp/ope_tests/cartpole-out"
eval_data_dir = "/Users/jk1/temp/ope_tests/cartpole-eval"

# DM and DR each get their own q_model_config dict literal (never one shared
# object), in case an estimator modifies the dict it is given.
ope_methods = {
    "is": {"type": ImportanceSampling},
    "wis": {"type": WeightedImportanceSampling},
    "dm_fqe": {
        "type": DirectMethod,
        "q_model_config": {"type": FQETorchModel, "polyak_coef": 0.05},
    },
    "dr_fqe": {
        "type": DoublyRobust,
        "q_model_config": {"type": FQETorchModel, "polyak_coef": 0.05},
    },
}

config = (
    DQNConfig()
    .environment(env="CartPole-v1")
    .framework("torch")
    .offline_data(input_=train_data_dir)
    .evaluation(
        evaluation_interval=1,
        evaluation_duration=10,
        evaluation_duration_unit="timesteps",
        evaluation_num_workers=1,
        evaluation_config={"input": eval_data_dir},
        off_policy_estimation_methods=ope_methods,
    )
)


In [3]:
# Build the DQN Algorithm from `config`; on first use this also starts a local Ray instance.
algo = config.build()

`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2023-12-10 12:26:15,687	INFO worker.py:1673 -- Started a local Ray instance.
[33m(raylet)[0m [2023-12-10 12:26:25,681 E 12512 503067] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-12-10_12-26-13_065290_12500 is over 95% full, available space: 1205727232; capacity: 499963174912. Object cr

In [4]:
# Two training iterations on the offline CartPole data. The config sets
# evaluation_interval=1, so evaluation (including the OPE methods) is expected after each one.
for _train_iter in range(2):
    algo.train()



In [5]:
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.estimators import DoublyRobust
from ray.rllib.offline.estimators.fqe_torch_model import FQETorchModel

# Stand-alone Doubly-Robust estimator for the trained policy, backed by an FQE
# Q-model configured with n_iters=160.
fqe_q_model_config = dict(type=FQETorchModel, n_iters=160)
estimator = DoublyRobust(
    q_model_config=fqe_q_model_config,
    gamma=0.99,
    policy=algo.get_policy(),
)

In [6]:
# Fit the estimator's Q-model on two batches of the training data.
# This step is only required for the DM and DR estimators.
# NOTE(review): the recorded losses jump from ~246 to ~47136 between the two
# batches; check FQE convergence before trusting DR estimates.
reader = JsonReader("/Users/jk1/temp/ope_tests/cartpole-out")
for _ in range(2):
    batch = reader.next()
    train_stats = estimator.train(batch)  # e.g. {'loss': ...}
    print(train_stats)

[33m(raylet)[0m [2023-12-10 12:26:35,735 E 12512 503067] (raylet) file_system_monitor.cc:111: /tmp/ray/session_2023-12-10_12-26-13_065290_12500 is over 95% full, available space: 1204441088; capacity: 499963174912. Object creation will fail if spilling is required.


{'loss': 246.2779119031076}
{'loss': 47136.12594086613}


In [7]:
# Compute off-policy estimates on two batches of the held-out evaluation data.
# NOTE(review): both recorded estimates are identical; check whether
# reader.next() returned the same batch twice.
reader = JsonReader("/Users/jk1/temp/ope_tests/cartpole-eval")
for _ in range(2):
    batch = reader.next()
    estimates = estimator.estimate(batch)
    print(estimates)



{'v_behavior': 16.334573860702964, 'v_behavior_std': 7.473125277139228, 'v_target': 1981.9499566597967, 'v_target_std': 3193.4366097375023, 'v_gain': 121.3346594506447, 'v_delta': 1965.6153827990938}
{'v_behavior': 16.334573860702964, 'v_behavior_std': 7.473125277139228, 'v_target': 1981.9499566597967, 'v_target_std': 3193.4366097375023, 'v_gain': 121.3346594506447, 'v_delta': 1965.6153827990938}


In [8]:
from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch

# Collapse `batch` (the last batch read in the estimation loop) into a plain
# SampleBatch so it can be inspected and split by episode.
# NOTE(review): this reassigns `batch` to a function of itself; re-running the cell
# on an already-converted batch is presumably a no-op — TODO confirm.
batch = convert_ma_batch_to_sample_batch(batch)    


In [9]:
# Show the converted SampleBatch: its length and the columns it carries.
batch

SampleBatch(200: ['obs', 'new_obs', 'actions', 'prev_actions', 'rewards', 'prev_rewards', 'terminateds', 'truncateds', 'infos', 'eps_id', 'unroll_id', 'agent_index', 't', 'action_dist_inputs', 'action_prob', 'action_logp', 'advantages', 'value_targets'])

In [10]:
# Split the flat batch into one SampleBatch per episode.
all_episodes = batch.split_by_episode()

In [11]:
# Truncation flags of the last episode in `batch`.
all_episodes[-1]['truncateds']

array([False, False, False, False, False, False, False, False, False,
       False, False, False])

## Custom data

### Prepare data

In [2]:
import pandas as pd

In [3]:
# Split 0 of the preprocessed GSU extraction: per-timestep features plus the
# continuous and final outcome tables.
splits_dir = '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851/splits'
split_idx = 0
data_path = f'{splits_dir}/train_features_split_{split_idx}.csv'
continuous_outcomes_path = f'{splits_dir}/train_continuous_outcomes_split_{split_idx}.csv'
outcomes_path = f'{splits_dir}/train_final_outcomes_split_{split_idx}.csv'


In [4]:
# Load the three CSV tables of the selected split.
features_df, continuous_outcomes_df, outcomes_df = (
    pd.read_csv(csv_path)
    for csv_path in (data_path, continuous_outcomes_path, outcomes_path)
)

In [5]:
# Features data: pivot the long table (one row per case_admission_id ×
# relative_sample_date_hourly_cat × sample_label) into a wide table with one row per
# (admission, hourly step) and one column per sample_label.
# Reassign instead of `inplace=True` (and ignore a missing column) so that re-running
# this cell does not raise KeyError on the already-dropped column.
features_df = features_df.drop(columns=['impute_missing_as'], errors='ignore')

pivoted_features_df = features_df.pivot(index=['case_admission_id', 'relative_sample_date_hourly_cat'],
                                        columns='sample_label', values='value')

# drop the column-axis name and move the (admission, step) index back into columns
pivoted_features_df = pivoted_features_df.rename_axis(None, axis=1).reset_index()

# separate out the treatment feature; `.copy()` makes treatment_df an independent
# frame rather than a slice of pivoted_features_df
treatment_df = pivoted_features_df[
    ['case_admission_id', 'relative_sample_date_hourly_cat', 'anti_hypertensive_strategy']].copy()
pivoted_features_df = pivoted_features_df.drop(columns=['anti_hypertensive_strategy'])

In [42]:
# Wide per-(admission, hourly step) feature table; pandas truncates the rendered rows/columns.
pivoted_features_df

Unnamed: 0,case_admission_id,relative_sample_date_hourly_cat,ALAT,ASAT,FIO2,Glasgow Coma Scale,INR,LDL cholesterol calcule,PTT,age,...,referral_other_hospital,referral_self_referral_or_gp,sex_male,sodium,temperature,thrombocytes,triglycerides,uree,wake_up_stroke_true,weight
0,1002417_9090,0,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,-1.49435,...,0.0,0.0,1.0,-1.544155,-0.326528,0.056011,-0.084327,1.061508,1.0,0.743800
1,1002417_9090,1,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,-1.49435,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.743800
2,1002417_9090,2,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,-1.49435,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.743800
3,1002417_9090,3,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,-1.49435,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.880285
4,1002417_9090,4,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,-1.49435,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.880285
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
104467,996783_7797,67,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,-0.64130,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805
104468,996783_7797,68,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,-0.64130,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805
104469,996783_7797,69,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,-0.64130,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805
104470,996783_7797,70,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,-0.64130,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805


In [6]:
# One episode per distinct admission in the features table.
n_episodes = features_df['case_admission_id'].nunique()
n_episodes

1451

In [7]:
# Outcome column used as the terminal reward, and a switch for per-step logging.
target_outcome, verbose = '3M mRS 0-2', True

In [13]:
# As the preprocessor does not accept Text spaces, each case_admission_id ("<a>_<b>")
# is split into two integer parts.
admission_ids = pivoted_features_df[pivoted_features_df.columns[0]]
id_parts = admission_ids.str.split('_')
cid0 = id_parts.str[0].astype(int)
cid1 = id_parts.str[1].astype(int)

In [28]:
# All-float table: the two integer id parts followed by the step index and every feature.
value_columns = list(pivoted_features_df.columns[1:])
features_with_index_columns_df = pd.concat(
    [cid0, cid1, pivoted_features_df[value_columns]], axis=1
).astype(float)
features_with_index_columns_df.columns = ['cid0', 'cid1'] + value_columns

In [27]:
# Display the float-encoded feature table (id parts cid0 / cid1 first).
# Fixed: the variable name was pasted twice, which raised NameError on a fresh kernel.
features_with_index_columns_df

Unnamed: 0,cid0,cid1,relative_sample_date_hourly_cat,ALAT,ASAT,FIO2,Glasgow Coma Scale,INR,LDL cholesterol calcule,PTT,...,referral_other_hospital,referral_self_referral_or_gp,sex_male,sodium,temperature,thrombocytes,triglycerides,uree,wake_up_stroke_true,weight
0,1002417.0,9090.0,0.0,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,...,0.0,0.0,1.0,-1.544155,-0.326528,0.056011,-0.084327,1.061508,1.0,0.743800
1,1002417.0,9090.0,1.0,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.743800
2,1002417.0,9090.0,2.0,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.743800
3,1002417.0,9090.0,3.0,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.880285
4,1002417.0,9090.0,4.0,-0.164766,-0.165095,-0.28577,0.385502,-0.751824,-0.163287,-1.135252,...,0.0,0.0,1.0,-1.544155,-0.094932,0.056011,-0.084327,1.061508,1.0,0.880285
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
104467,996783.0,7797.0,67.0,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805
104468,996783.0,7797.0,68.0,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805
104469,996783.0,7797.0,69.0,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805
104470,996783.0,7797.0,70.0,-0.164766,-0.165095,-0.28577,0.385502,-0.452063,-1.115572,-1.033377,...,0.0,1.0,1.0,-0.124707,-0.558125,2.425328,1.810562,0.058614,0.0,-0.552805


In [41]:
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.offline import JsonWriter
from ray.rllib.evaluation import SampleBatchBuilder
from gymnasium.spaces import Box, Discrete, Dict, Text
import numpy as np
import os

# Export the per-admission trajectories as RLlib offline JSON (one SampleBatch per
# admission), so they can be read back with JsonReader like the CartPole demo data.
batch_builder = SampleBatchBuilder()  # or MultiAgentSampleBatchBuilder
# NOTE(review): re-running this cell writes into the same directory again; clear it
# first if earlier output files should not be read back alongside the new ones.
writer = JsonWriter(os.path.join('/Users/jk1/temp/ope_tests/', 'custom_data_out'))
n_features = len(pivoted_features_df.columns) - 3

# A Dict observation space (cid / timestep / data) was tried first. Because the
# preprocessor does not accept Text spaces, each full row of
# features_with_index_columns_df (cid0, cid1, step, features) is used as one flat Box.
# NOTE(review): this puts the raw admission-id parts into the observation and makes the
# Box bounds span the id range; confirm that is intended.
obs_space = Box(low=features_with_index_columns_df.min().min(), high=features_with_index_columns_df.max().max(), shape=(len(features_with_index_columns_df.columns),), dtype=np.float32)

prep = get_preprocessor(obs_space)(obs_space)

if verbose:
    print("The preprocessor is", prep)

# Debugging limit on how many admissions are exported; set to n_episodes for a full export.
N_EPISODES_TO_WRITE = 2
# Hoisted out of the episode loop: the list of admission ids does not change between episodes.
case_admission_ids = pivoted_features_df.case_admission_id.unique()


def _strategy_at(cid_treatment_df, ts):
    """Return the anti_hypertensive_strategy (as int) recorded at hourly step `ts`.

    `cid_treatment_df` is treatment_df already restricted to a single admission.
    Raises IndexError if there is no row for `ts`.
    """
    step_rows = cid_treatment_df[cid_treatment_df.relative_sample_date_hourly_cat == ts]
    return int(step_rows['anti_hypertensive_strategy'].values[0])


for eps_id in tqdm(range(N_EPISODES_TO_WRITE)):
    cid = case_admission_ids[eps_id]
    cid_parts = cid.split('_')
    cid_part0, cid_part1 = int(cid_parts[0]), int(cid_parts[1])
    cid_data_df = features_with_index_columns_df[(features_with_index_columns_df.cid0.astype(int) == cid_part0)
                                                 & (features_with_index_columns_df.cid1.astype(int) == cid_part1)]
    # Restrict the treatment table to this admission once per episode instead of
    # scanning every row of treatment_df twice per timestep.
    cid_treatment_df = treatment_df[treatment_df.case_admission_id == cid]

    # Each transition goes from step ts to ts + 1, so steps 0 .. last_ts - 1 produce
    # transitions; the one into the last step is terminal and carries the outcome reward.
    last_ts = int(cid_data_df.relative_sample_date_hourly_cat.max())
    if last_ts < 1:
        # A single hourly row yields no transition; skip instead of writing an empty batch.
        continue

    for ts in range(last_ts):
        obs_df = cid_data_df[cid_data_df.relative_sample_date_hourly_cat == ts]
        obs_data = obs_df.values[0]
        obs = prep.transform(obs_data)

        new_obs_df = cid_data_df[cid_data_df.relative_sample_date_hourly_cat == ts + 1]
        new_obs_data = new_obs_df.values[0]
        new_obs = prep.transform(new_obs_data)

        action = _strategy_at(cid_treatment_df, ts)
        # At the first step there is no earlier treatment, so the current one is repeated.
        prev_action = action if ts == 0 else _strategy_at(cid_treatment_df, ts - 1)

        if ts == last_ts - 1:
            terminated = True
            reward = outcomes_df[outcomes_df.case_admission_id == cid][target_outcome].values[0]
        else:
            terminated = False
            reward = 0

        truncated = False
        # Rewards are only non-zero on the terminal transition, so the previous reward is always 0.
        prev_reward = 0
        info = {}

        if verbose:
            print(f'cid: {cid}, ts: {ts}, action: {action}, terminated: {terminated}, reward: {reward}')
            print(f'prev_action: {prev_action}')
            print('---')

        batch_builder.add_values(
                t=ts,
                eps_id=eps_id,
                agent_index=0,
                obs=obs,
                actions=action,
                action_prob=1.0,  # put the true action probability here
                action_logp=0.0,
                rewards=reward,
                prev_actions=prev_action,
                prev_rewards=prev_reward,
                terminateds=terminated,
                truncateds=truncated,
                infos=info,
                new_obs=new_obs,
            )

    writer.write(batch_builder.build_and_reset())

        

  prep = get_preprocessor(obs_space)(obs_space)


The preprocessor is <ray.rllib.models.preprocessors.NoPreprocessor object at 0x7ff6482bca00>


  0%|          | 0/2 [00:00<?, ?it/s]

cid: 1002417_9090, ts: 0, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 1, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 2, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 3, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 4, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 5, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 6, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 7, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 8, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 9, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 10, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 1002417_9090, ts: 11, action: 7, terminated: False, reward:

 50%|█████     | 1/2 [00:01<00:01,  1.35s/it]

cid: 1002417_9090, ts: 69, action: 4, terminated: False, reward: 0
prev_action: 4
---
cid: 1002417_9090, ts: 70, action: 4, terminated: True, reward: 0.0
prev_action: 4
---
cid: 100447_0097, ts: 0, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 1, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 2, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 3, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 4, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 5, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 6, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 7, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 8, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 9, action: 7, terminated: False, reward: 0
prev_a

100%|██████████| 2/2 [00:02<00:00,  1.36s/it]

cid: 100447_0097, ts: 65, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 66, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 67, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 68, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 69, action: 7, terminated: False, reward: 0
prev_action: 7
---
cid: 100447_0097, ts: 70, action: 7, terminated: True, reward: 1.0
prev_action: 7
---





In [40]:
# Scratch check: last hourly step index of the last admission processed by the export loop.
int(cid_data_df.relative_sample_date_hourly_cat.max())

71

In [19]:
import gymnasium as gym
from gymnasium.spaces import Box, Discrete, Dict, Text
from ray.rllib.models.preprocessors import get_preprocessor
import numpy as np


# Scratch check of the preprocessor for a features-only Box space: pivoted_features_df
# without its first two columns (case_admission_id, relative_sample_date_hourly_cat).
n_features = len(pivoted_features_df.columns) - 2

# Fixed: `space` was previously defined only in a commented-out line, so this cell raised
# NameError on a fresh kernel (unless a stale `space` was left over in the kernel).
feature_values_df = pivoted_features_df[pivoted_features_df.columns[2:]]
space = Box(low=feature_values_df.min().min(), high=feature_values_df.max().max(),
            shape=(n_features,), dtype=np.float32)

prep = get_preprocessor(space)(space)

  prep = get_preprocessor(space)(space)
