In [1]:
import numpy as np
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv
import matplotlib.pylab as plt

In [2]:
# Instantiate the gridworld environment; its transition model `env.P` and its
# sizes `env.nS` / `env.nA` are used by the cells below.
env = GridworldEnv()

In [36]:
def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """
    Evaluate a policy given an environment and a full description of the environment's dynamics.

    Performs iterative policy evaluation: repeatedly sweeps over all states,
    replacing V[s] with its Bellman expectation under `policy`, until the
    largest per-state change within one sweep is below `theta`. Updates are
    made in place, so states later in a sweep already see the new values of
    states earlier in the same sweep.

    Args:
        policy: [S, A] shaped matrix representing the policy; policy[s][a] is
            the probability of taking action a in state s.
        env: OpenAI env. env.P represents the transition probabilities of the environment.
            env.P[s][a] is a list of (prob, next_state, reward, done) tuples.
            env.nS is the number of states.
        discount_factor: Gamma discount factor.
        theta: We stop evaluation once our value function change is less than theta for all states.

    Returns:
        Vector of length env.nS representing the value function.

    Note:
        The loop only exits once a sweep changes no value by `theta` or more.
        With discount_factor=1.0 this requires the policy to eventually reach
        zero-reward absorbing states; otherwise values diverge and the loop
        does not terminate.
    """
    # Start with an all-zero value function
    V = np.zeros(env.nS)
    while True:
        delta = 0
        # For each state, perform a "full backup"
        for s in range(env.nS):
            v = 0
            # Look at the possible next actions
            for a, action_prob in enumerate(policy[s]):
                # For each action, look at the possible next states...
                for prob, next_state, reward, _done in env.P[s][a]:
                    # Bellman expectation: the whole one-step return
                    # (reward + discounted next-state value) is weighted by the
                    # probability of choosing the action AND of this outcome.
                    # Weighting only the bootstrap term by `prob` would count
                    # the reward once per outcome under stochastic transitions.
                    v += action_prob * prob * (reward + discount_factor * V[next_state])
            # How much our value function changed (across any states)
            delta = max(delta, np.abs(v - V[s]))
            V[s] = v
        # Stop evaluating once our value function change is below a threshold
        if delta < theta:
            break
    return np.array(V)

In [39]:
# Display the transition model: env.P[s][a] is a list of
# (prob, next_state, reward, done) tuples.
env.P

{0: {0: [(1.0, 0, 0.0, True)],
  1: [(1.0, 0, 0.0, True)],
  2: [(1.0, 0, 0.0, True)],
  3: [(1.0, 0, 0.0, True)]},
 1: {0: [(1.0, 1, -1.0, False)],
  1: [(1.0, 2, -1.0, False)],
  2: [(1.0, 5, -1.0, False)],
  3: [(1.0, 0, -1.0, True)]},
 2: {0: [(1.0, 2, -1.0, False)],
  1: [(1.0, 3, -1.0, False)],
  2: [(1.0, 6, -1.0, False)],
  3: [(1.0, 1, -1.0, False)]},
 3: {0: [(1.0, 3, -1.0, False)],
  1: [(1.0, 3, -1.0, False)],
  2: [(1.0, 7, -1.0, False)],
  3: [(1.0, 2, -1.0, False)]},
 4: {0: [(1.0, 0, -1.0, True)],
  1: [(1.0, 5, -1.0, False)],
  2: [(1.0, 8, -1.0, False)],
  3: [(1.0, 4, -1.0, False)]},
 5: {0: [(1.0, 1, -1.0, False)],
  1: [(1.0, 6, -1.0, False)],
  2: [(1.0, 9, -1.0, False)],
  3: [(1.0, 4, -1.0, False)]},
 6: {0: [(1.0, 2, -1.0, False)],
  1: [(1.0, 7, -1.0, False)],
  2: [(1.0, 10, -1.0, False)],
  3: [(1.0, 5, -1.0, False)]},
 7: {0: [(1.0, 3, -1.0, False)],
  1: [(1.0, 7, -1.0, False)],
  2: [(1.0, 11, -1.0, False)],
  3: [(1.0, 6, -1.0, False)]},
 8: {0: [(1.0, 4

In [37]:
# Uniform random policy: every action in every state has probability 1/nA.
random_policy = np.full((env.nS, env.nA), 1.0 / env.nA)
v = policy_eval(random_policy, env)

0 0 0.0
0 1 0.0
0 2 0.0
0 3 0.0
1 0 -0.25
1 1 -0.5
1 2 -0.75
1 3 -1.0
2 0 -0.25
2 1 -0.75
2 2 -1.0
2 3 -1.25
3 0 -0.25
3 1 -0.75
3 2 -1.3125
3 3 -1.5625
4 0 -0.25
4 1 -0.75
4 2 -1.3125
4 3 -1.953125
5 0 -0.25
5 1 -0.75
5 2 -1.3125
5 3 -1.953125
6 0 -0.25
6 1 -0.75
6 2 -1.3125
6 3 -1.953125
7 0 -0.25
7 1 -0.75
7 2 -1.3125
7 3 -1.953125
8 0 -0.25
8 1 -0.75
8 2 -1.3125
8 3 -1.953125
9 0 -0.25
9 1 -0.75
9 2 -1.3125
9 3 -1.953125
10 0 -0.25
10 1 -0.75
10 2 -1.3125
10 3 -1.953125
11 0 -0.25
11 1 -0.75
11 2 -1.3125
11 3 -1.953125
12 0 -0.25
12 1 -0.75
12 2 -1.3125
12 3 -1.953125
13 0 -0.25
13 1 -0.75
13 2 -1.3125
13 3 -1.953125
14 0 -0.25
14 1 -0.75
14 2 -1.3125
14 3 -1.953125
15 0 0.0
15 1 -0.25
15 2 -0.5625
15 3 -0.953125
0 0 0.0
0 1 -0.25
0 2 -0.5625
0 3 -0.953125
1 0 -0.48828125
1 1 -0.98828125
1 2 -1.55078125
1 3 -2.19140625
2 0 -0.48828125
2 1 -1.2861328125
2 2 -1.8486328125
2 3 -2.4892578125
3 0 -0.48828125
3 1 -1.2861328125
3 2 -2.15844726562
3 3 -2.79907226562
4 0 -0.48828125
4 1 -1.

In [35]:
# Test: Make sure the evaluated policy is what we expected.
# Expected values written as a 4x4 grid (one row per line), then flattened
# into the 16-state vector returned by policy_eval.
expected_v = np.array([
    [  0, -14, -20, -22],
    [-14, -18, -20, -20],
    [-20, -20, -18, -14],
    [-22, -20, -14,   0],
]).ravel()
np.testing.assert_array_almost_equal(v, expected_v, decimal=2)

AssertionError: 
Arrays are not almost equal to 2 decimals

(mismatch 100.0%)
 x: array([-4.5, -5.8, -6.1, -6.4, -6.7, -6.7, -6.7, -6.7, -6.7, -6.7, -6.7,
       -6.7, -6.7, -6.7, -6.7, -5.7])
 y: array([  0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22,
       -20, -14,   0])