In [1]:
# Use value iteration to find the optimal policy according to the MDP transition distribution

In [2]:
import numpy as np
import pickle
from collections import Counter
import itertools
from tqdm import tqdm

In [3]:
# transition distribution is a dictionary of {(s, a) : [(prob, s', r, done)]}
P = pickle.load(open( "MDP_P.p", "rb" ))

In [4]:
P[1]

defaultdict(list,
            {5: [(0.367816091954023, 1, 0, False),
              (0.005747126436781609, 572, 0, False),
              (0.06896551724137931, 751, 100, True),
              (0.005747126436781609, 692, 0, False),
              (0.005747126436781609, 403, 0, False),
              (0.011494252873563218, 137, 0, False),
              (0.028735632183908046, 16, 0, False),
              (0.017241379310344827, 165, 0, False),
              (0.005747126436781609, 551, 0, False),
              (0.005747126436781609, 747, 0, False),
              (0.011494252873563218, 697, 0, False),
              (0.011494252873563218, 548, 0, False),
              (0.005747126436781609, 50, 0, False),
              (0.011494252873563218, 362, 0, False),
              (0.011494252873563218, 256, 0, False),
              (0.017241379310344827, 169, 0, False),
              (0.017241379310344827, 706, 0, False),
              (0.028735632183908046, 396, 0, False),
              (0.011494252873563

In [5]:
Q_mask = np.load('action_mask.npy') # mask out actions that clinicians never taken

In [6]:
nS, nA = 750, 25
gamma = 0.99
theta = 1e-10

In [7]:
# Value iteration
V = np.zeros(nS)
for i in tqdm(itertools.count()):
    delta = 0.0
    for s in range(nS):
        old_v = V[s]

        ## V[s] = max {a} sum {s', r} P[s', r | s, a] * (r + gamma * V[s'])
        Q_s = np.zeros(nA)
        for a in P[s]:
            Q_s[a] = sum(p * (r + 
                            (0 if done else gamma * V[s_])
                           ) for p, s_, r, done in P[s][a])
        Q_s[~Q_mask[s]] = np.nan
        new_v = np.nanmax(Q_s)
        V[s] = new_v
        delta = max(delta, np.abs(new_v - old_v))
    if delta < theta:
        break

240it [00:31,  8.12it/s]

In [8]:
V.shape

(750,)

In [9]:
V

array([ 93.47105827,  93.9628946 ,  93.60121677,  91.90413337,
        92.15992317,  92.8442038 ,  93.29051883,  92.54315378,
        93.20723155,  92.27882419,  93.17597533,  90.5357923 ,
        -3.32246808,  93.98107632,  92.96016289,  94.50009817,
        92.37299511,  94.77152656,  93.01826251,  94.39481475,
        90.87369253,  96.98410639,  90.10882742,  93.348013  ,
        93.57197821,  93.39805803,  92.89403367,  93.32187196,
        92.76844392,  93.71896026,  89.39530143,  94.41619592,
        94.69452911,  93.13375464,  94.63837649,  93.88486307,
        90.87613928,  92.57969024,  94.41559014,  91.84033959,
        93.23972941,  95.53765966,  91.08537046,  93.93862221,
        96.72197585,  91.2671434 ,  92.63383689,  88.09563982,
        93.69096136,  92.87653944,  93.79755878,  93.7213981 ,
        92.48713405,  92.43616901,  92.72103122,  91.84262324,
        92.74289282,  93.65983191,  92.6805886 ,  93.91339095,
        92.09154498,  92.61511826,  96.85417957,  93.09

In [10]:
np.save('value_iter_V.npy', V)

In [11]:
policy = np.zeros(nS, dtype=np.int)
Q = np.zeros((nS, nA))
for s in range(nS):
    for a in P[s]:
        Q[s,a] = sum(p * (r + 
                        (0 if done else gamma * V[s_])
                       ) for p, s_, r, done in P[s][a])
    Q[s,~Q_mask[s]] = np.nan
    best_action = np.nanargmax(Q[s])
    if best_action is None:
        policy[s] = 0
    else:
        policy[s] = best_action

In [12]:
Q

array([[93.19875872,         nan,         nan, ...,         nan,
                nan,         nan],
       [93.29808224,         nan,         nan, ...,         nan,
                nan,         nan],
       [93.60121677,         nan,         nan, ...,         nan,
                nan,         nan],
       ...,
       [92.23102785,         nan,         nan, ...,         nan,
                nan,         nan],
       [80.9842792 ,         nan,         nan, ...,         nan,
                nan,         nan],
       [93.41996942,         nan,         nan, ...,         nan,
                nan,         nan]])

In [13]:
np.save('value_iter_Q.npy', Q)

In [14]:
Counter(policy)

Counter({9: 22,
         13: 9,
         0: 279,
         5: 83,
         20: 12,
         10: 98,
         15: 71,
         23: 2,
         19: 8,
         7: 34,
         18: 8,
         8: 50,
         6: 40,
         14: 8,
         17: 4,
         1: 3,
         3: 3,
         24: 2,
         11: 4,
         12: 5,
         16: 4,
         4: 1})

In [15]:
policy

array([ 9, 13,  0,  5, 20, 13,  0, 10, 10,  0,  0, 15, 23,  0, 19,  7,  7,
        0,  0, 18,  5,  0, 10, 10,  8,  8,  0,  0, 15,  9,  8,  6,  6, 14,
        0,  0,  7, 10,  7, 15,  9,  0, 17,  0,  0,  0,  0,  0, 18, 19,  8,
        0,  6,  0,  8,  0, 10,  6,  0, 10, 15,  7,  0,  0, 10, 15,  0,  0,
        0, 23, 10,  0,  5, 15,  0, 18,  0,  0,  5,  0,  5, 10, 15,  0,  0,
        0,  0, 15,  6,  5,  0,  8,  0,  6, 10,  0,  5,  0,  0,  0,  0,  0,
        0,  6,  8,  8, 10,  6,  0,  0, 10, 10,  7,  5, 10,  9,  0, 10,  7,
       10, 20, 15,  0,  0,  6,  5,  8, 15,  0,  5,  0,  0, 10, 10, 15,  0,
        0, 15,  0,  0,  8,  0,  1,  6,  8, 10,  0,  5,  6,  3, 24, 15,  0,
        0,  0,  0,  0,  8, 10, 10,  5,  0,  5,  0,  0, 10, 15,  7,  0,  9,
       15, 10,  0,  9,  0,  0,  0,  5,  0,  0,  9, 15,  0, 10,  6,  5, 10,
        8,  5, 11, 10,  0,  5,  0, 10, 10, 19,  0,  7, 15,  8,  0, 20, 10,
       15,  0,  0,  0,  0, 10,  5, 15,  0, 15, 10,  0,  0, 15, 13,  0,  0,
        5, 10,  6,  8,  6

In [16]:
with open('value_iter_policy_gamma={}.p'.format(gamma), 'wb') as f:
    pickle.dump(policy, f)