In [2]:
import numpy as np

In [3]:
# Gridworld size: 16 states (a 4x4 board flattened row-major) and 4 actions.
nS, nA = 16, 4

# Per-action offsets, indexed by action id: dx is added to the row index and
# dy to the column index, i.e. 0 = row-1, 1 = col+1, 2 = row+1, 3 = col-1.
dx = [-1, 0, 1, 0]
dy = [0, 1, 0, -1]

def is_inside(i, j):
    """Return True when cell (i, j) lies on the 4x4 grid, False otherwise."""
    return 0 <= i < 4 and 0 <= j < 4

In [4]:
def policy_eval(policy, discount_factor=1.0, theta=1e-4, num_iter=100):
    """Iteratively evaluate ``policy`` on the 4x4 gridworld.

    Values are updated in place, state by state, so later states in a sweep
    already see the new values of earlier ones. The first and last states are
    never updated and keep value 0. Every action yields a reward of -1, and an
    action that would leave the grid keeps the agent in its current state.

    Args:
        policy: indexable as ``policy[s][k]``, the weight given to action
            ``k`` in state ``s``.
        discount_factor: multiplier applied to the successor state's value.
        theta: sweeping stops once the largest absolute value change within
            a sweep is no greater than this threshold.
        num_iter: upper bound on the number of sweeps.

    Returns:
        A length-``nS`` numpy array with the estimated value of each state.
    """
    v = np.zeros(nS)
    for _ in range(num_iter):
        # The delta is reset on every sweep and measured in absolute terms;
        # a running maximum that never resets can never fall below theta.
        max_delta = 0.0
        for s in range(1, nS - 1):
            sr, sc = s // 4, s % 4
            new_val = 0.0
            for k in range(4):
                nsr, nsc = sr + dx[k], sc + dy[k]
                # Bumping into the border leaves the agent where it was.
                ns = nsr * 4 + nsc if is_inside(nsr, nsc) else s
                new_val += policy[s][k] * (-1 + discount_factor * v[ns])
            max_delta = max(max_delta, abs(new_val - v[s]))
            v[s] = new_val
        if max_delta <= theta:
            break
    return v

In [5]:
# Uniform random policy: every action gets weight 1/nA in every state.
policy1 = np.full((nS, nA), 1.0 / nA)
v = policy_eval(policy1)
print(v)

[  0.         -13.99765839 -19.99663362 -21.99629468 -13.99765839
 -17.99712654 -19.99688008 -19.99691576 -19.99663362 -19.99688008
 -17.99736736 -13.99803444 -21.99629468 -19.99691576 -13.99803444
   0.        ]


In [6]:
# Check: the state values computed for the random policy should match the
# reference values below (laid out as the 4x4 grid) to two decimals.
expected_v = np.array([
      0, -14, -20, -22,
    -14, -18, -20, -20,
    -20, -20, -18, -14,
    -22, -20, -14,   0,
])
np.testing.assert_array_almost_equal(v, expected_v, decimal=2)

In [11]:
def policy_improve(discount_factor=1.0):
    """Alternate evaluation and greedy improvement until the policy is stable.

    Starts from the uniform random policy. Each round scores the current
    policy with ``policy_eval`` and then, for every state except the first
    and last, puts all weight on the first action whose one-step lookahead
    ``-1 + discount_factor * v[next_state]`` is largest. A move that would
    leave the 4x4 grid keeps the agent in its current state.

    Args:
        discount_factor: multiplier applied to successor values, both in the
            evaluation step and in the greedy lookahead.

    Returns:
        ``(policy, v)``: the first policy that greedy improvement leaves
        unchanged, and the value estimate it was checked against.
    """
    policy = np.ones([nS, nA]) / nA
    while True:
        v = policy_eval(policy, discount_factor=discount_factor, theta=0.1)
        greedy = policy.copy()
        for state in range(1, nS - 1):
            row, col = divmod(state, 4)
            top_action, top_value = -1, -1e9
            for action in range(4):
                next_row, next_col = row + dx[action], col + dy[action]
                if is_inside(next_row, next_col):
                    successor = next_row * 4 + next_col
                else:
                    successor = state
                lookahead = -1 + discount_factor * v[successor]
                # Strict comparison: ties go to the lowest action index.
                if lookahead > top_value:
                    top_action, top_value = action, lookahead
            # One-hot row: all weight on the chosen action.
            greedy[state] = [1 if a == top_action else 0 for a in range(nA)]
        if (greedy == policy).all():
            return policy, v
        policy = greedy

In [13]:
# Run policy iteration from the uniform random start, then print the resulting
# policy table (one row per state, one column per action) followed by the
# state values it was checked against.
policy, v = policy_improve()
print(policy, v)

[[0.25 0.25 0.25 0.25]
 [0.   0.   0.   1.  ]
 [0.   0.   0.   1.  ]
 [0.   0.   1.   0.  ]
 [1.   0.   0.   0.  ]
 [1.   0.   0.   0.  ]
 [1.   0.   0.   0.  ]
 [0.   0.   1.   0.  ]
 [1.   0.   0.   0.  ]
 [1.   0.   0.   0.  ]
 [0.   1.   0.   0.  ]
 [0.   0.   1.   0.  ]
 [1.   0.   0.   0.  ]
 [0.   1.   0.   0.  ]
 [0.   1.   0.   0.  ]
 [0.25 0.25 0.25 0.25]] [ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.]


In [14]:
# Check: the values returned by policy iteration should equal minus the number
# of steps listed below for each cell (laid out as the 4x4 grid).
expected_v = np.array([
     0, -1, -2, -3,
    -1, -2, -3, -2,
    -2, -3, -2, -1,
    -3, -2, -1,  0,
])
np.testing.assert_array_almost_equal(v, expected_v, decimal=2)