In [1]:
%load_ext autoreload
%autoreload 2

import os, sys, warnings

import matplotlib.pyplot as plt
%matplotlib inline
plt.rc('font', size=24, family='serif')
plt.rcParams["figure.figsize"] =(15, 12)
plt.style.use('tableau-colorblind10')


sys.path.append('../')

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    
from healthy_gym.environments.adcb import *

## Initializing the environment. 

policy – Behavior policy ∈ {DX_Based, 
                                    Santiago_Based}
                                    
𝜖 – Overlap parameter ∈ [0, 1]

𝛾 – Treatment Effect Heterogeneity ≥ 0

n_buffer – Number of Samples ≥ 0

horizon – Sample trajectory length (history length) ∈ {0, 1, …, 11}

regenerate – If True, fit the data again; this is a bit slow for large n_buffer

In [26]:
# Number of samples; handed to the environment as n_buffer.
N = 1000

# Build the environment. Parameter meanings follow the list in the
# "Initializing the environment" section of this notebook.
e = ADCBEnvironment(
    gamma=2,            # treatment effect heterogeneity
    epsilon=0.1,        # overlap parameter
    policy='DX_Based',  # behavior policy
    regenerate=False,   # do not fit the data again
    horizon=6,          # sample trajectory (history) length
    n_buffer=N,
    sequential=True,
    z_dim=6,
)

# Reset once and report the fitting time recorded on the environment.
cs = e.reset()
print('Time spent fitting env: %.2fs' % e.fit_time)

Time spent fitting env: 0.03s


### Whole set of Generated Data

In [27]:
# e.model_ is presented in this notebook as the whole set of generated data.
# Bind it to a name and show summary statistics for its columns.
gen_data = e.model_
gen_data.describe()

Unnamed: 0,RID,AGE,PTETHCAT,PTRACCAT,PTGENDER,PTEDUCAT,PTMARRY,TAU,PTAU,FDG,...,Delta,Y_hat,Y_0,Y_1,Y_2,Y_3,Y_4,Y_5,Y_6,Y_7
count,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,...,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0,6000.0
mean,499.5,74.9974,0.966,0.158,0.511,5.0425,0.382333,332.245344,33.034588,1.175871,...,-1.672128,21.70245,23.374578,22.114878,18.771698,20.586978,20.104178,23.141418,20.650098,19.995938
std,288.69905,7.27608,0.207005,0.707898,0.499921,3.013338,0.859425,180.054492,19.451746,0.203103,...,3.202457,14.248495,14.468602,14.632707,15.203124,15.115042,15.392536,14.844198,14.562601,14.520686
min,0.0,48.6,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.401119,...,-9.6,-9.6,0.0,-5.85,-7.44,-9.09,-9.6,-6.03,-3.87,-8.07
25%,249.75,70.3,1.0,0.0,0.0,2.0,0.0,201.241349,18.627784,1.037728,...,-3.03,10.514733,11.740336,10.754987,6.693737,8.54389,8.259096,11.741699,8.946855,8.638976
50%,499.5,75.0,1.0,0.0,1.0,5.0,0.0,317.313721,31.454407,1.17922,...,0.0,20.960803,22.863668,21.736676,18.163191,20.164295,19.73495,22.664777,20.149922,19.291513
75%,749.25,79.9,1.0,0.0,1.0,7.0,0.0,451.644618,45.845225,1.31761,...,0.0,31.858211,34.135837,32.856076,29.880286,31.671842,31.065759,33.913513,31.54224,30.616945
max,999.0,99.3,2.0,6.0,1.0,12.0,4.0,994.94831,105.43998,1.92115,...,3.2,65.885532,66.387704,66.606354,65.335532,65.885532,69.587704,68.397704,63.043287,64.443287


### Buffer

Buffer contains one patient's data

In [28]:
# List the column names available in the environment's buffer.
e.buffer_.columns

Index(['RID', 'AGE', 'PTETHCAT', 'PTRACCAT', 'PTGENDER', 'PTEDUCAT', 'PTMARRY',
       'TAU', 'PTAU', 'FDG', 'AV45', 'Z', 'VISCODE', 'ADAS13', 'DX', 'A',
       'Y_hat', 'Y_0', 'Y_1', 'Y_2', 'Y_3', 'Y_4', 'Y_5', 'Y_6', 'Y_7',
       'prev_ADAS13', 'prev_DX', 'prev_AV45', 'prev_FDG', 'prev_TAU',
       'prev_PTAU', 'prev_Y_hat'],
      dtype='object')

In [29]:
# Reset the environment, then display the Y_0..Y_7 columns of the buffer.
# NOTE(review): these look like one outcome column per action 0-7 — confirm
# against the environment implementation.
cs = e.reset()
e.buffer_[['Y_0', 'Y_1', 'Y_2', 'Y_3', 'Y_4', 'Y_5', 'Y_6', 'Y_7']].head(10)

Unnamed: 0,Y_0,Y_1,Y_2,Y_3,Y_4,Y_5,Y_6,Y_7
559,12.699636,14.649636,10.219636,3.609636,9.499636,14.709636,11.409636,10.009636
1559,0.0,1.95,-2.48,-9.09,-3.2,2.01,-1.29,-2.69
2559,12.843565,14.793565,10.363565,3.753565,9.643565,14.853565,11.553565,10.153565
3559,21.770748,23.720748,19.290748,12.680748,18.570748,23.780748,20.480748,19.080748
4559,29.689607,31.639607,27.209607,20.599607,26.489607,31.699607,28.399607,26.999607
5559,48.098628,50.048628,45.618628,39.008628,44.898628,50.108628,46.808628,45.408628


### Stepping

### $a \in \{0, 1, ..., 7\}$
### $r = - (Y_a - Y_0) + N(\mu, \sigma)$

For non-sequential stepping, we can step multiple times, each time getting a random reward ($r = -(Y_a - Y_0) + N(\mu=0, \sigma=1)$) for one patient:

In [30]:
# Reset the environment before stepping; the value returned by reset() is kept in cs.
cs = e.reset()

In [31]:
# Take action 3. Of the four values returned by step(), only the second
# (reward r) and the fourth (info dict) are kept.
_, r, _, info  = e.step(3)
print(r)
# Display the 'context' entry of the info dict as the cell output.
info['context']

[3.04184965]


Unnamed: 0,RID,AGE,PTETHCAT,PTRACCAT,PTGENDER,PTEDUCAT,PTMARRY,TAU,PTAU,FDG,...,DX,A,Y_hat,prev_ADAS13,prev_DX,prev_AV45,prev_FDG,prev_TAU,prev_PTAU,prev_Y_hat
629,629,65.4,1,2,0,10,0,179.057767,10.003374,1.442447,...,0,0,6.800581,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0


### Reset

Resetting yields data from a new patient

In [32]:
# Reset the environment again; the value returned by reset() is kept in cs.
cs = e.reset()

In [33]:
# Take action 3 again, keeping only the reward and the info dict, then
# display the 'outcomes' entry of the info dict as the cell output.
_, r, _, info = e.step(3)
info['outcomes']

array([[ 0.01139678],
       [ 5.86139678],
       [ 2.49139678],
       [ 9.10139678],
       [ 9.61139678],
       [ 6.04139678],
       [-1.27860322],
       [-2.67860322]])