In [1]:
import numpy as np
import pickle
from tqdm import tqdm
from sklearn import utils
from joblib import Parallel, delayed

In [2]:
# MDP size: number of discrete states and actions
nS = 750
nA = 25
# Discount factor
gamma = 0.99

In [3]:
# Per-state action mask used to exclude actions that clinicians never took
Q_mask = np.load('action_mask.npy')

In [4]:
# Behaviour (clinician) policy pi_0 as a dense (nS, nA) table. The pickle maps
# state -> {action: probability}; (s, a) pairs absent from it stay 0.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('clinician_policy.p', 'rb') as fh:
    clinician_policy = pickle.load(fh)

pi_0 = np.zeros((nS, nA))
for s, probs in clinician_policy.items():
    for a, p in probs.items():
        pi_0[s, a] = p

In [5]:
pi_0[0, :]  # behaviour-policy action probabilities in state 0

array([0.3056872 , 0.        , 0.        , 0.        , 0.        ,
       0.50236967, 0.02369668, 0.        , 0.03791469, 0.02132701,
       0.09241706, 0.        , 0.        , 0.        , 0.        ,
       0.01658768, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ])

In [6]:
tol = 1e-10  # slack subtracted from (1 - zeta) in construct_SVP's cutoff
zeta = 0.05  # near-greedy threshold for the set-valued policy
# Build the file name from `gamma` instead of hard-coding 0.99, so changing the
# discount in the config cell cannot silently load a Q_pi for another gamma.
Q_pi = np.load('output_ql/svp_near-greedy_Q_gamma={}_zeta={}.npy'.format(gamma, zeta))
Q_star = np.load('qlearn_Q.npy')
V_star = np.nanmax(Q_star, axis=1)  # greedy state values, NaN entries ignored

In [7]:
def construct_SVP(Q_pi, Q_star):
    """Build a near-greedy set-valued policy (SVP).

    For each state s, keeps every action allowed by the global ``Q_mask`` whose
    ``Q_pi[s, a]`` strictly exceeds the cutoff
    ``min(V*(s), (1 - zeta - tol) * V*(s))``, where
    ``V*(s) = nanmax_a Q_star[s, a]``. If no action qualifies, falls back to the
    greedy action ``nanargmax(Q_star[s])``.

    Args:
        Q_pi: (nS, nA) array of action values compared against the cutoff.
        Q_star: (nS, nA) array of action values defining V*(s) and the greedy
            fallback. NaN entries are ignored by nanmax / nanargmax.

    Returns:
        pi_svp: dict {s: [a1, a2, ...]} of the actions kept in each state.
        PI_svp: (nS, nA) int array, 1 where action a is included in pi(s).

    Reads the notebook globals ``Q_mask``, ``zeta`` and ``tol``.
    """
    n_states, n_actions = Q_star.shape
    # Derive V* from the Q_star argument. Previously the global V_star was used,
    # silently ignoring this argument whenever it differed from the global Q_star.
    V_opt = np.nanmax(Q_star, axis=1)

    # Dictionary of {s: [a1, a2, ...]}
    pi_svp = {}
    for s in range(n_states):
        # Lower bound for future return.
        # NOTE(review): for V*(s) <= 0 the min() makes the cutoff equal V*(s)
        # itself. When Q_pi == Q_star, no action can then pass the strict '>'
        # and the greedy fallback is used. Confirm this is intended.
        Q_cutoff = min(V_opt[s], (1 - zeta - tol) * V_opt[s])
        # Masked actions get -inf so they can never pass the cutoff
        Pi_s = np.argwhere(
            np.where(Q_mask[s], Q_pi[s], -np.inf) > Q_cutoff
        )
        if len(Pi_s) > 0:
            assert not np.isnan(Q_pi[s][Pi_s]).all()
            pi_svp[s] = list(Pi_s.flatten())
        else:
            pi_svp[s] = [np.nanargmax(Q_star[s])]  # fall back to the greedy action

    # Tabular form, SxA, (s,a)=1 if a is included in pi(s)
    PI_svp = np.zeros((n_states, n_actions), dtype=int)
    for s, pi_s in pi_svp.items():
        PI_svp[s, pi_s] = 1

    return pi_svp, PI_svp

In [8]:
def soften_policy(svp, eps=0.01):
    """Turn a set-valued policy into a stochastic evaluation policy pi_e.

    For every state in the global ``clinician_policy``, with A_s the clinician
    actions recorded for that state:
      * if the SVP actions are exactly A_s, probability is split uniformly;
      * otherwise the SVP actions share ``1 - eps`` uniformly and the remaining
        clinician actions share ``eps`` uniformly.
    Actions outside A_s, and states absent from ``clinician_policy``, get 0.

    Args:
        svp: (nS, nA) 0/1 array, 1 where the action belongs to the SVP.
        eps: probability mass given to clinician actions outside the SVP.
            Defaults to 0.01, the previously hard-coded value.

    Returns:
        (nS, nA) float array pi_e.

    Raises:
        AssertionError: if an SVP action is not a clinician action in that
            state, or a clinician state has no SVP action (its row could not
            sum to 1).
    """
    pi_e = np.zeros((nS, nA))
    for s, probs in clinician_policy.items():
        A_s = list(probs.keys())
        a_star = list(np.argwhere(svp[s] == 1).flatten())
        assert all(a_ in A_s for a_ in a_star)
        # Without this the row would silently sum to eps instead of 1
        assert len(a_star) > 0, 'state {} has no SVP action'.format(s)
        if len(A_s) == len(a_star):
            # SVP covers every clinician action: uniform over them
            for a in A_s:
                pi_e[s, a] = 1.0 / len(a_star)
        else:
            for a in A_s:
                if a in a_star:
                    pi_e[s, a] = (1.0 - eps) / len(a_star)
                else:
                    pi_e[s, a] = eps / (len(A_s) - len(a_star))
    return pi_e

In [9]:
# Near-greedy SVP built from Q_star, then softened into the evaluation policy pi_e.
# NOTE(review): both arguments are Q_star, so the Q_pi loaded from
# output_ql/svp_near-greedy_Q_* is not used here -- confirm this is intended.
pi_svp, PI_svp = construct_SVP(Q_pi=Q_star, Q_star=Q_star)
pi_e = soften_policy(svp=PI_svp)

In [10]:
pi_e[0, :]  # evaluation-policy action probabilities in state 0 (compare pi_0[0])

array([0.14285714, 0.        , 0.        , 0.        , 0.        ,
       0.14285714, 0.14285714, 0.        , 0.14285714, 0.14285714,
       0.14285714, 0.        , 0.        , 0.        , 0.        ,
       0.14285714, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ])

In [11]:
# Sanity check: every row of pi_e must be a probability distribution.
# Asserting (not just displaying) makes Restart & Run All fail loudly.
pi_e_rows_normalized = np.isclose(pi_e.sum(axis=1), 1.0).all()
assert pi_e_rows_normalized, 'some rows of pi_e do not sum to 1'
pi_e_rows_normalized

True

In [12]:
# Test-set trajectories; later cells read each transition's 's', 'a' and 'r' keys.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('trajDr_te.pkl', 'rb') as fh:
    traj_te = pickle.load(fh)

In [13]:
# Keep only test trajectories whose every (s, a) pair has nonzero probability
# under pi_0 (i.e. was observed in the training set); otherwise the importance
# ratio pi_e / pi_0 used below would be undefined.
trajectories = [
    traj
    for traj in traj_te
    if not any(np.isclose(pi_0[step['s'], step['a']], 0.0) for step in traj)
]

In [14]:
# Count of usable test trajectories after filtering
N = len(trajectories)
print('Effective sample size of test set', N)

Effective sample size of test set 2801


## Doubly robust (DR) and weighted doubly robust (WDR) estimators

In [15]:
# Per-step importance ratios pi_e(a|s) / pi_0(a|s), one array per trajectory
rho_all = [
    np.array([pi_e[step['s'], step['a']] / pi_0[step['s'], step['a']] for step in traj])
    for traj in trajectories
]

# Longest trajectory length in the filtered test set
max_H = max(map(len, trajectories))

# Cumulative ratios rho_{1:t}. Rows are padded with ratio 1 beyond each
# trajectory's end, so the cumulative ratio stays at its final value.
rho_cum = np.ones((N, max_H))
for i, rho in enumerate(rho_all):
    rho_cum[i, :len(rho)] = rho
rho_cum = np.cumprod(rho_cum, axis=1)

# Mean cumulative ratio over trajectories at each timestep (WDR normaliser)
weights = rho_cum.mean(axis=0)

In [16]:
def doubly_robust_estimator(trajectory, Q, pi_0, pi_e, rho_cumulative, gamma):
    """Per-trajectory doubly robust (DR) off-policy value estimate.

    Accumulates, for t = 0 .. T-1,
        gamma^t * (rho_{1:t} r_t - (rho_{1:t} Q(s_t, a_t) - rho_{1:t-1} V(s_t)))
    where V(s) = nansum_a Q(s, a) * pi_e(s, a) and rho_{1:t-1} = 1 at t = 0.

    Args:
        trajectory: sequence of transitions, each indexable by 's', 'a', 'r'.
        Q: action-value table used as the control variate.
        pi_0: behaviour-policy table; only used to assert pi_0[s, a] != 0.
        pi_e: evaluation-policy table.
        rho_cumulative: cumulative importance ratios indexed by timestep.
        gamma: discount factor.

    Returns:
        The DR estimate for this trajectory (0 for an empty trajectory).
    """
    estimate = 0
    rho_prev = 1.0
    for t, step in enumerate(trajectory):
        s, a, r = step['s'], step['a'], step['r']
        q_sa = Q[s, a]
        v_s = np.nansum(Q[s] * pi_e[s])
        assert not np.isclose(pi_0[s, a], 0.0)
        rho_t = rho_cumulative[t]
        estimate = estimate + np.power(gamma, t) * (rho_t * r - (rho_t * q_sa - rho_prev * v_s))
        rho_prev = rho_t
    return estimate

In [17]:
def weighted_doubly_robust_estimator(trajectory, Q, pi_0, pi_e, rho_cumulative, weight_t, gamma):
    """Per-trajectory weighted doubly robust (WDR) off-policy value estimate.

    Same recursion as the DR estimator, but every cumulative ratio rho_{1:t} is
    divided by ``weight_t[t]`` (the mean cumulative ratio at step t across
    trajectories). The ratio preceding t = 0 is fixed at 1.

    Args:
        trajectory: sequence of transitions, each indexable by 's', 'a', 'r'.
        Q: action-value table used as the control variate.
        pi_0: behaviour-policy table; only used to assert pi_0[s, a] != 0.
        pi_e: evaluation-policy table.
        rho_cumulative: cumulative importance ratios indexed by timestep.
        weight_t: per-timestep normalisers for the cumulative ratios.
        gamma: discount factor.

    Returns:
        The WDR estimate for this trajectory (0 for an empty trajectory).
    """
    estimate = 0
    w_prev = 1.0
    for t, step in enumerate(trajectory):
        s, a, r = step['s'], step['a'], step['r']
        q_sa = Q[s, a]
        v_s = np.nansum(Q[s] * pi_e[s])
        assert not np.isclose(pi_0[s, a], 0.0)
        w_t = rho_cumulative[t] / weight_t[t]
        estimate = estimate + np.power(gamma, t) * (w_t * r - (w_t * q_sa - w_prev * v_s))
        w_prev = w_t
    return estimate

## Doubly robust (DR) estimate

In [18]:
# Per-trajectory DR estimates, with Q_star as the control variate
V_DR = [
    doubly_robust_estimator(
        trajectory=traj, Q=Q_star, pi_0=pi_0, pi_e=pi_e,
        rho_cumulative=rho_row, gamma=gamma,
    )
    for traj, rho_row in zip(trajectories, rho_cum)
]

In [19]:
np.clip(V_DR, -100, 100).mean()  # mean DR estimate, per-trajectory values clipped to [-100, 100]

84.32283157244169

In [20]:
# 1000 bootstrap resamples (seeded by iteration index) of the per-trajectory DR
# estimates; each resample is clipped to [-100, 100] and averaged.
V_DR_b = [
    np.mean(np.clip(utils.resample(V_DR, replace=True, random_state=i), -100, 100))
    for i in tqdm(range(1000))
]

100%|██████████| 1000/1000 [00:01<00:00, 713.98it/s]


In [21]:
(np.mean(V_DR_b), np.std(V_DR_b))  # bootstrap mean and standard deviation of the DR estimate

(84.32747698783487, 0.6343827995823593)

## Weighted doubly robust (WDR) estimate

In [22]:
# Per-trajectory WDR estimates, normalised by the per-timestep mean ratios
V_WDR = [
    weighted_doubly_robust_estimator(
        trajectory=traj, Q=Q_star, pi_0=pi_0, pi_e=pi_e,
        rho_cumulative=rho_row, weight_t=weights, gamma=gamma,
    )
    for traj, rho_row in zip(trajectories, rho_cum)
]

In [23]:
np.clip(V_WDR, -100, 100).mean()  # mean WDR estimate, per-trajectory values clipped to [-100, 100]

89.7260389134681

In [24]:
# 1000 bootstrap resamples (seeded by iteration index) of the per-trajectory WDR
# estimates; each resample is clipped to [-100, 100] and averaged.
V_WDR_b = [
    np.mean(np.clip(utils.resample(V_WDR, replace=True, random_state=i), -100, 100))
    for i in tqdm(range(1000))
]

100%|██████████| 1000/1000 [00:01<00:00, 790.44it/s]


In [25]:
(np.mean(V_WDR_b), np.std(V_WDR_b))  # bootstrap mean and standard deviation of the WDR estimate

(89.73435102771438, 0.3184995008518237)