## CSCI E-89C Deep Reinforcement Learning, Summer 2020
### Section 6

## Off-policy Monte Carlo (MC) control for estimating $\pi_*$

We consider patients with end-stage liver disease (ESLD). We assume that a patient's health condition is fully characterized by the Model for End-stage Liver Disease (MELD) score (Jae-Hyeon Ahn and John Hornberger, Involving patients in the cadaveric kidney transplant allocation process: a decision-theoretic perspective. Manage Sci. 1996;42(5):629–41).

The MELD score ranges from 6 to 40 and is derived from the probability of survival at 3 months for patients with ESLD. ESLD data are usually sparse and are often aggregated into stages. We assume that there are 18 stages of ESLD: Stage 1, Stage 2, ..., Stage 18. The time step is 1 year, and the actions available in Stages 1 through 18 are "wait" (denoted by 0) and "transplant" (denoted by 1).

We assume that the Markov property holds. There are two additional states of the Markov Decision Process: "Posttransplant Life" (denoted by 19) and "Death" (denoted by 20, which combines the so-called "Pretransplant Death" and "Posttransplant Death"). The only action available in the "Posttransplant Life" state is "wait", and "Death" is the terminal state with no actions. Assume that the length of an episode is $T=50$, unless it terminates earlier due to the transition to the absorbing state "Death".
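The state and action sets just described can be encoded as a small lookup. This is a minimal sketch that assumes the integer labels used in this section; the helper `available_actions` and the constant names are hypothetical and are not part of the notebook.

```python
# Hypothetical encoding of the MDP's states and actions (labels as in the text).
WAIT, TRANSPLANT = 0, 1
STAGES = range(1, 19)          # Stage 1, ..., Stage 18
POST_TRANSPLANT = 19           # "Posttransplant Life"
DEATH = 20                     # terminal, absorbing state


def available_actions(s):
    """Return the list of actions allowed in state s."""
    if s in STAGES:
        return [WAIT, TRANSPLANT]
    if s == POST_TRANSPLANT:
        return [WAIT]
    if s == DEATH:
        return []              # terminal state: no actions
    raise ValueError(f"unknown state {s}")


for s in (1, 18, POST_TRANSPLANT, DEATH):
    print(s, available_actions(s))
```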

We do not know the transition probabilities, but if a patient selects "wait," the possible transitions are   
1) Stage 1->Stage 1, Stage 1->Stage 2, Stage 1->Death  
2) For $k\in\{2,3,\dots,17\}$: Stage k->Stage (k-1), Stage k->Stage k, Stage k->Stage (k+1), Stage k->Death    
3) Stage 18->Stage 17, Stage 18->Stage 18, Stage 18->Death    

If a patient selects "transplant" at Stage k, k=1,2,...,18, the only possible transition is  
4) Stage k->"Posttransplant Life"

Finally, there are two more possible transitions:  
5) "Posttransplant Life"->"Posttransplant Life" and "Posttransplant Life"->"Death"  
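The admissible transitions in rules 1)-5) can be captured by a small function. This is a minimal sketch that encodes only which transitions are possible (the probabilities remain unknown); `possible_next_states` and the constant names are hypothetical.

```python
# Hypothetical helper: which next states are possible (probabilities are unknown).
WAIT, TRANSPLANT = 0, 1
POST_TRANSPLANT, DEATH = 19, 20


def possible_next_states(s, a):
    """Return the set of states reachable from state s under action a (rules 1-5)."""
    if s == DEATH:
        return set()                                  # terminal: no transitions
    if s == POST_TRANSPLANT:
        if a != WAIT:
            raise ValueError("only 'wait' is available in Posttransplant Life")
        return {POST_TRANSPLANT, DEATH}               # rule 5
    if not 1 <= s <= 18:
        raise ValueError(f"unknown state {s}")
    if a == TRANSPLANT:
        return {POST_TRANSPLANT}                      # rule 4
    if a != WAIT:
        raise ValueError(f"unknown action {a}")
    # Rules 1-3: one stage down, stay, or one stage up (within 1..18), or death.
    return {n for n in (s - 1, s, s + 1) if 1 <= n <= 18} | {DEATH}


print(possible_next_states(1, WAIT))      # Stage 1 under "wait"
print(possible_next_states(18, WAIT))     # Stage 18 under "wait"
```

A helper like this could, for instance, be used to check that every observed triple $(S_t, A_t, S_{t+1})$ in the data is admissible.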


The patient receives reward 1 in each state "Stage k" (k=1,2,...,18), reward 0.2 in the "Posttransplant Life" state, and reward 0 in "Death". Assume that the patient receives these rewards on "exit" from the states, i.e., after we observe the corresponding state, so the reward $R_{t+1}$ is determined by $S_t$. We assume the discount factor $\gamma=0.97$, one of the most common discounting rates used in medical decision making (Gold MR, Siegel JE, Russell LB, Weinstein MC. Cost-Effectiveness in Health and Medicine. Oxford University Press; New York: 1996).
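With rewards received on exit, the return from time 0 is $G_0=\sum_{t} \gamma^t r(S_t)$, which can be accumulated backwards as $G \leftarrow \gamma G + r(S_t)$. A minimal sketch; `reward_on_exit` and `discounted_return` are hypothetical helper names, and the example state sequence (three years in Stage 16, 18 years of posttransplant life, then death, 50 states in total) is only illustrative.

```python
def reward_on_exit(s):
    """Reward received when leaving state s: 1 for Stages 1-18, 0.2 for state 19, 0 for Death (20)."""
    if 1 <= s <= 18:
        return 1.0
    if s == 19:
        return 0.2
    return 0.0


def discounted_return(states, gamma=0.97):
    """G_0 = sum_t gamma**t * r(S_t), accumulated backwards as G <- gamma*G + r(S_t)."""
    G = 0.0
    for s in reversed(states):
        G = gamma * G + reward_on_exit(s)
    return G


episode_states = [16] * 3 + [19] * 18 + [20] * 29   # 50 states
print(round(discounted_return(episode_states), 4))
```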


Please consider statistics on 8,000 patients with ESLD saved in the 'ESLD_statistics.csv' file. Each row represents an episode (i.e., one patient), and the columns S0, A0, S1, A1, ..., S49, A49, S50 hold the patient's sequence of states and actions. These data were generated under the behavior policy:

$b(1|k)=0.02$ for $k\in\{1,2,3,4,5,6,7,8,9,10,11,12,13\}$;   
$b(1|14)=0.05$;   
$b(1|15)=0.10$;   
$b(1|16)=0.20$;   
$b(1|17)=0.40$;  
$b(1|18)=0.60$;  

which means that, for example, 5% of patients at Stage 14 received a transplant.
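As a sanity check of the stated behavior policy, one could estimate $b$ empirically as $\hat b(a|s)=N(s,a)/N(s)$ from the recorded episodes. This is a minimal sketch (the function name is hypothetical) that skips time steps in "Death", where the recorded action carries no information; it has not been run on the actual file here.

```python
import numpy as np


def estimate_behavior_policy(states, actions, n_states=20):
    """Empirical b_hat(a|s) = N(s, a) / N(s); row s-1 of the result holds state s.

    states, actions: integer arrays of shape (n_episodes, T).
    Rows of unvisited states are NaN.
    """
    states = np.asarray(states)
    actions = np.asarray(actions)
    counts = np.zeros((n_states, 2))
    mask = states != 20                      # ignore the terminal state "Death"
    np.add.at(counts, (states[mask] - 1, actions[mask]), 1)
    totals = counts.sum(axis=1, keepdims=True)
    with np.errstate(invalid="ignore", divide="ignore"):
        return counts / totals
```

With the notebook's `df`, a possible call is `estimate_behavior_policy(df[[f"S{t}" for t in range(50)]].values, df[[f"A{t}" for t in range(50)]].values)`; selecting columns by name avoids depending on whether an index column was read as data.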


Off-policy MC control for estimating $\pi_*$ (the weighted importance sampling case):
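For reference, the update rules of off-policy MC control with weighted importance sampling, as presented in Sutton and Barto's *Reinforcement Learning: An Introduction*, can be written as follows. Each episode is processed backwards, $t=T-1,\dots,0$, starting from $G=0$ and $W=1$, with $R_{t+1}$ the reward received on exit from $S_t$:

$$
\begin{aligned}
G &\leftarrow \gamma G + R_{t+1},\\
C(S_t,A_t) &\leftarrow C(S_t,A_t) + W,\\
Q(S_t,A_t) &\leftarrow Q(S_t,A_t) + \frac{W}{C(S_t,A_t)}\bigl[G - Q(S_t,A_t)\bigr],\\
\pi(S_t) &\leftarrow \arg\max_a Q(S_t,a).
\end{aligned}
$$

If $A_t \neq \pi(S_t)$, then $\pi(A_t|S_t)=0$ and the rest of the episode is skipped; otherwise $\pi(A_t|S_t)=1$ for the greedy target policy, so

$$
W \leftarrow W\,\frac{\pi(A_t\mid S_t)}{b(A_t\mid S_t)} = \frac{W}{b(A_t\mid S_t)}.
$$

The resulting estimate is the weighted average $Q(s,a)=\sum_i W_i G_i \big/ \sum_i W_i$ over all returns $G_i$ observed after $(s,a)$ with their ratios $W_i$.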

```python
import random
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
```

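```python
# Behavior policy b[s-1, a] = b(a|s): row s-1 holds state s (1..20),
# column 0 is "wait" and column 1 is "transplant".
b = np.zeros((20, 2))
b[:13] = [0.98, 0.02]   # Stages 1-13
b[13] = [0.95, 0.05]    # Stage 14
b[14] = [0.9, 0.1]      # Stage 15
b[15] = [0.8, 0.2]      # Stage 16
b[16] = [0.6, 0.4]      # Stage 17
b[17] = [0.4, 0.6]      # Stage 18
b[18] = [1.0, 0.0]      # Posttransplant Life: "wait" only
b[19] = [1.0, 0.0]      # Death: action 0 is recorded in the data
print(b)
```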

```
[[0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.98 0.02]
 [0.95 0.05]
 [0.9  0.1 ]
 [0.8  0.2 ]
 [0.6  0.4 ]
 [0.4  0.6 ]
 [1.   0.  ]
 [1.   0.  ]]
```


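```python
# The first CSV column (shown by pandas as "Unnamed: 0") is the saved row index;
# reading it as the index keeps df.values to S0, A0, ..., A49, S50, which the
# positional slicing df.values[k, 0:-1:2] used below relies on.
df = pd.read_csv('data/ESLD_statistics.csv', index_col=0)
df.head(5)
```

|   | S0 | A0 | S1 | A1 | S2 | A2 | S3 | A3 | S4 | A4 | ... | A45 | S46 | A46 | S47 | A47 | S48 | A48 | S49 | A49 | S50 |
|---|----|----|----|----|----|----|----|----|----|----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|
| 0 | 12 | 0 | 12 | 0 | 13 | 0 | 13 | 0 | 20 | 0 | ... | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 |
| 1 | 3 | 0 | 3 | 0 | 3 | 0 | 3 | 0 | 3 | 0 | ... | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 |
| 2 | 16 | 0 | 16 | 0 | 16 | 1 | 19 | 0 | 19 | 0 | ... | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 |
| 3 | 13 | 0 | 13 | 0 | 13 | 0 | 13 | 0 | 14 | 0 | ... | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 |
| 4 | 4 | 0 | 4 | 0 | 4 | 0 | 20 | 0 | 20 | 0 | ... | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 | 0 | 20 |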


```python
k = 2
states = df.values[k, 0:-1:2]   # S0, S1, ..., S49 of episode k
states
```

```
array([16, 16, 16, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,
       19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,
       20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20],
      dtype=int64)
```

```python
# Reward on exit: 1 in Stages 1-18, 0.2 in Posttransplant Life (19), 0 in Death (20)
rewards = (states <= 18)*1 + (states == 19)*0.2
rewards
```

```
array([1. , 1. , 1. , 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
       0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ])
```

```python
k = 2
actions = df.values[k, 1::2]    # A0, A1, ..., A49 of episode k
actions
```

```
array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0], dtype=int64)
```

```python
np.zeros(20)
```

```
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0.])
```

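```python
def get_episode(df, k):
    """Return [states, actions, rewards] of episode k;
    rewards[t] is R_{t+1}, the reward received on exit from S_t."""
    states = df.values[k, 0:-1:2]    # S0, S1, ..., S49
    actions = df.values[k, 1::2]     # A0, A1, ..., A49
    rewards = (states <= 18)*1 + (states == 19)*0.2
    return [states, actions, rewards]
```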

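```python
gamma = 0.97
T = 50
num_episodes = df.shape[0]
Q = np.zeros((20, 2))   # action-value estimates; row s-1 holds Q(s, .)
C = np.zeros((20, 2))   # cumulative importance-sampling weights
policy = np.zeros(20)   # greedy target policy pi(s) = argmax_a Q(s, a)

for k in range(num_episodes):
    states, actions, rewards = get_episode(df, k)
    G = 0.0             # return, accumulated backwards in time
    W = 1.0             # importance-sampling ratio for the tail of the episode
    for t in reversed(range(T)):
        S, A, R = states[t], actions[t], rewards[t]
        G = gamma*G + R
        # Weighted importance sampling: incremental weighted average of returns.
        C[S-1, A] += W
        Q[S-1, A] += W/C[S-1, A]*(G - Q[S-1, A])
        policy[S-1] = np.argmax(Q[S-1, :])
        # pi is greedy: if A is not the greedy action, pi(A|S) = 0 and all
        # earlier time steps of this episode get zero weight.
        if A != policy[S-1]:
            break
        W = W/b[S-1, A]  # pi(A|S) = 1 for the greedy action
```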
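The incremental update `Q += W/C*(G - Q)` with `C += W` computes a weighted average of the observed returns, with the importance-sampling ratios as weights: after $n$ updates, $Q=\sum_{i=1}^n W_i G_i / \sum_{i=1}^n W_i$. A small check on synthetic numbers; `incremental_weighted_mean` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def incremental_weighted_mean(returns, weights):
    """Apply C <- C + W; Q <- Q + (W / C) * (G - Q) to a stream of (G, W) pairs."""
    Q, C = 0.0, 0.0
    for G, W in zip(returns, weights):
        C += W
        Q += W / C * (G - Q)
    return Q


rng = np.random.default_rng(0)
G = rng.uniform(0, 15, size=100)     # synthetic returns
W = rng.uniform(0.5, 50, size=100)   # synthetic importance-sampling ratios (positive)
print(incremental_weighted_mean(G, W), np.average(G, weights=W))
```

The incremental form avoids storing all returns per state-action pair, which is why the loop keeps only `C` and `Q`.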

```python
policy
```

```
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       1., 0., 0.])
```

```python
Q
```

```
array([[14.04386711,  2.50709837],
       [12.30060841,  2.5653977 ],
       [10.11898042,  2.76031117],
       [ 8.5337964 ,  2.80124673],
       [ 7.50008863,  2.77006538],
       [ 6.44066428,  2.60816736],
       [ 5.66026235,  2.83090016],
       [ 5.34210319,  2.78590844],
       [ 4.71337471,  2.65443031],
       [ 4.08613482,  2.54168138],
       [ 3.74892683,  3.0153909 ],
       [ 3.88574364,  2.83798778],
       [ 3.24772698,  2.61107184],
       [ 3.24685664,  2.69061215],
       [ 3.1308385 ,  2.74816766],
       [ 3.20223224,  2.74999031],
       [ 2.67243911,  2.75961379],
       [ 2.42983035,  2.51717334],
       [ 1.70551512,  0.        ],
       [ 0.        ,  0.        ]])
```
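To read the learned policy off the action-value table, one can compare $Q(s,1)$ with $Q(s,0)$ stage by stage. A minimal sketch with hypothetical helper names; applied to the `Q` printed above, `transplant_stages(Q)` should return `[17, 18]`, matching the `policy` array. Note that with a greedy target policy, off-policy MC learns only from episode tails after the last non-greedy action, and transplants at early stages are rare under $b$ ($b(1|k)=0.02$), so those estimates rest on relatively few returns.

```python
import numpy as np


def transplant_stages(Q, n_stages=18):
    """Stages (1-based) where the greedy action w.r.t. Q is "transplant" (action 1).

    Ties go to "wait", because np.argmax returns the first maximum.
    """
    Q = np.asarray(Q)
    return [s + 1 for s in range(n_stages) if np.argmax(Q[s]) == 1]


def transplant_advantage(Q, n_stages=18):
    """Q(s, transplant) - Q(s, wait) for Stages 1..n_stages; positive favours transplant."""
    Q = np.asarray(Q)
    return Q[:n_stages, 1] - Q[:n_stages, 0]
```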