In [9]:
import pandas as pd
import os
from ray.rllib.offline import JsonReader
from evaluation.off_policy_evaluation.rllib_policy_from_table import reconstitute_case_admission_id
from data_loaders.rllib_data_io import rllib_gsu_dataset_creation
from evaluation.off_policy_evaluation.rllib_policy_from_table import PolicyFromTable


In [2]:
gsu_features_path = '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851/splits/val_features_split_0.csv'
gsu_final_outcomes_path = '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851/splits/val_final_outcomes_split_0.csv'
output_path = '/Users/jk1/temp/ope_tests/custom_data_out/val_split0_temp'

## Testing data conversion to Rllib batch format

In [3]:
rllib_gsu_dataset_creation(gsu_features_path = gsu_features_path, gsu_final_outcomes_path=gsu_final_outcomes_path,
                               output_path=output_path,
                                save_index_columns=True, verbose=False)

100%|██████████| 356/356 [03:06<00:00,  1.91it/s]


In [4]:
# get file in output_path
available_files = [file for file in os.listdir(output_path) if file.endswith('.json')]
if len (available_files) != 1:
    raise ValueError('output_path should contain only one file')
else:
    saved_data_file = available_files[0]
saved_data_file

'output-2023-12-21_11-31-41_worker-0_0.json'

In [5]:
reader = JsonReader(os.path.join(output_path, saved_data_file))

In [6]:
features_df = pd.read_csv(gsu_features_path)
outcomes_df = pd.read_csv(gsu_final_outcomes_path)

In [7]:
num_batches = sum(1 for _ in reader.read_all_files())

n_not_found = 0
cids_from_saved_batchs = []
for _ in range(num_batches):
    batch = reader.next()
    cid = reconstitute_case_admission_id(batch['obs'][0][0], batch['obs'][0][1])
    cids_from_saved_batchs.append(cid)
    # check if cid is in features_df
    if cid not in features_df.case_admission_id.values:
        print(f'{cid} not found in feature_df')
        print(batch['obs'][0][0], batch['obs'][0][1])
        n_not_found += 1
        
print(n_not_found)

0


## Testing policy creation

In [8]:
# create treatment dataframe 
# Features data
features_df.drop(columns=['impute_missing_as'], inplace=True)

pivoted_features_df = features_df.pivot(index=['case_admission_id', 'relative_sample_date_hourly_cat'],
                                        columns='sample_label', values='value')

# get rid of multiindex
pivoted_features_df = pivoted_features_df.rename_axis(None, axis=1).reset_index()

# seperate out treatment features
treatment_df = pivoted_features_df[
    ['case_admission_id', 'relative_sample_date_hourly_cat', 'anti_hypertensive_strategy']]
treatment_df.rename(columns={'relative_sample_date_hourly_cat': 'timestep'}, inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  treatment_df.rename(columns={'relative_sample_date_hourly_cat': 'timestep'}, inplace=True)


In [10]:
# create policy
treatment_policy = PolicyFromTable({}, {}, {}, lookup_table=treatment_df)

## test policy evaluation

In [12]:
from evaluation.off_policy_evaluation.weighted_importance_sampling import weighted_importance_sampling

results_df = weighted_importance_sampling(
    baseline_data_path=os.path.join(output_path, saved_data_file),
    target_treatment_df=treatment_df,
    verbose=True
)

Weighted Importance Sampling Estimation: 100%|██████████| 356/356 [01:12<00:00,  4.92it/s]


In [13]:
results_df.head()

Unnamed: 0,v_behavior,v_behavior_std,v_target,v_target_std,v_gain,v_delta,case_admission_id
0,0.0,0.0,0.0,0.0,0.0,0.0,10189_1690
0,0.494839,0.0,0.494839,0.0,1.0,0.0,1025279_1586
0,0.494839,0.0,0.494839,0.0,1.0,0.0,1025830_4031
0,0.0,0.0,0.0,0.0,0.0,0.0,10338_5096
0,0.494839,0.0,0.494839,0.0,1.0,0.0,1042770_4046


In [14]:
results_df.v_behavior.median(), results_df.v_target.median()

(0.49483865960020695, 0.49483865960020695)

In [15]:
results_df.v_delta.mean(), results_df.v_delta.median()

(0.0, 0.0)