In [3]:
import numpy as np
import sys
sys.path.append("..")
from reference.lib.envs.gridworld import GridworldEnv

In [4]:
env = GridworldEnv()

In [7]:
def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):
    """
    Evaluate a policy given an environment and a full description of the environment's dynamics.
    
    Args:
        policy: [S, A] shaped matrix representing the policy.
        env: OpenAI env. env.P represents the transition probabilities of the environment.
            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).
            env.nS is a number of states in the environment. 
            env.nA is a number of actions in the environment.
        theta: We stop evaluation once our value function change is less than theta for all states.
        discount_factor: Gamma discount factor.
    
    Returns:
        Vector of length env.nS representing the value function.
    """
    # Start with a random (all 0) value function
    V = np.zeros(env.nS)
    
    # loop until convergence - system is stable
    while True:
        # this is the change value
        delta_value = 0
        
        # for every state value
        for i in range(len(V)):
            current_state = 0
            
            for policy_index in range(4):
                action = policy_index
                action_prob = policy[i,policy_index]
                
                for prob, next_state, reward, done in env.P[i][action]:
                    current_state = current_state + action_prob * prob * (reward + discount_factor * V[next_state]) 
                    
            delta_value = max(delta_value,np.abs(current_state-V[i]))
            V[i] = current_state
            
        if delta_value < theta: 
            break
            
        print(V)
        input()
        
    return np.array(V)

In [8]:
random_policy = np.ones([env.nS, env.nA]) / env.nA
v = policy_eval(random_policy, env)

print(v.reshape((4,4)))


[ 0.        -1.        -1.25      -1.3125    -1.        -1.5
 -1.6875    -1.75      -1.25      -1.6875    -1.84375   -1.8984375
 -1.3125    -1.75      -1.8984375  0.       ]

[ 0.         -1.9375     -2.546875   -2.73046875 -1.9375     -2.8125
 -3.23828125 -3.40429688 -2.546875   -3.23828125 -3.56835938 -3.21777344
 -2.73046875 -3.40429688 -3.21777344  0.        ]

[ 0.         -2.82421875 -3.83496094 -4.17504883 -2.82421875 -4.03125
 -4.7097168  -4.87670898 -3.83496094 -4.7097168  -4.96374512 -4.26455688
 -4.17504883 -4.87670898 -4.26455688  0.        ]

[ 0.         -3.67260742 -5.0980835  -5.58122253 -3.67260742 -5.19116211
 -6.03242493 -6.18872833 -5.0980835  -6.03242493 -6.14849091 -5.15044403
 -5.58122253 -6.18872833 -5.15044403  0.        ]

[ 0.         -4.49046326 -6.30054855 -6.91293049 -4.49046326 -6.26144409
 -7.22480297 -7.36922646 -6.30054855 -7.22480297 -7.1876235  -5.9268235
 -6.91293049 -7.36922646 -5.9268235   0.        ]

[ 0.         -5.26311398 -7.425349   -8.15510


[  0.         -13.59042627 -19.41118296 -21.35189769 -13.59042627
 -17.49739929 -19.45429113 -19.46053098 -19.41118296 -19.45429113
 -17.53952163 -13.65620118 -21.35189769 -19.46053098 -13.65620118
   0.        ]

[  0.         -13.62475213 -19.46053098 -21.40621434 -13.62475213
 -17.53952163 -19.5000263  -19.5057432  -19.46053098 -19.5000263
 -17.57811374 -13.68501453 -21.40621434 -19.5057432  -13.68501453
   0.        ]

[  0.         -13.65620118 -19.5057432  -21.45597877 -13.65620118
 -17.57811374 -19.54192847 -19.54716624 -19.5057432  -19.54192847
 -17.6134715  -13.71141307 -21.45597877 -19.54716624 -13.71141307
   0.        ]

[  0.         -13.68501453 -19.54716624 -21.50157251 -13.68501453
 -17.6134715  -19.58031887 -19.58511767 -19.54716624 -19.58031887
 -17.64586597 -13.73559918 -21.50157251 -19.58511767 -13.73559918
   0.        ]

[  0.         -13.71141307 -19.58511767 -21.54334509 -13.71141307
 -17.64586597 -19.61549182 -19.61988844 -19.58511767 -19.61549182
 -17.6755455


[  0.         -13.98651684 -19.98061616 -21.97866449 -13.98651684
 -17.9834544  -19.98203528 -19.98224069 -19.98061616 -19.98203528
 -17.98484106 -13.98868215 -21.97866449 -19.98224069 -13.98868215
   0.        ]

[  0.         -13.98764685 -19.98224069 -21.98045259 -13.98764685
 -17.98484106 -19.98354088 -19.98372908 -19.98224069 -19.98354088
 -17.98611151 -13.98963069 -21.98045259 -19.98372908 -13.98963069
   0.        ]

[  0.         -13.98868215 -19.98372908 -21.98209083 -13.98868215
 -17.98611151 -19.9849203  -19.98509272 -19.98372908 -19.9849203
 -17.98727549 -13.99049973 -21.98209083 -19.98509272 -13.99049973
   0.        ]

[  0.         -13.98963069 -19.98509272 -21.98359178 -13.98963069
 -17.98727549 -19.98618411 -19.98634208 -19.98509272 -19.98618411
 -17.98834192 -13.99129593 -21.98359178 -19.98634208 -13.99129593
   0.        ]

[  0.         -13.99049973 -19.98634208 -21.98496693 -13.99049973
 -17.98834192 -19.987342   -19.98748674 -19.98634208 -19.987342
 -17.98931897 


[  0.         -13.99955613 -19.99936188 -21.99929764 -13.99955613
 -17.99945532 -19.9994086  -19.99941536 -19.99936188 -19.9994086
 -17.99950097 -13.99962742 -21.99929764 -19.99941536 -13.99962742
   0.        ]

[  0.         -13.99959333 -19.99941536 -21.9993565  -13.99959333
 -17.99950097 -19.99945817 -19.99946436 -19.99941536 -19.99945817
 -17.99954279 -13.99965864 -21.9993565  -19.99946436 -13.99965864
   0.        ]

[  0.         -13.99962742 -19.99946436 -21.99941043 -13.99962742
 -17.99954279 -19.99950358 -19.99950925 -19.99946436 -19.99950358
 -17.99958111 -13.99968725 -21.99941043 -19.99950925 -13.99968725
   0.        ]

[  0.         -13.99965864 -19.99950925 -21.99945984 -13.99965864
 -17.99958111 -19.99954518 -19.99955038 -19.99950925 -19.99954518
 -17.99961622 -13.99971346 -21.99945984 -19.99955038 -13.99971346
   0.        ]

[  0.         -13.99968725 -19.99955038 -21.99950511 -13.99968725
 -17.99961622 -19.9995833  -19.99958806 -19.99955038 -19.9995833
 -17.99964838

In [None]:
# Test: Make sure the evaluated policy is what we expected
expected_v = np.array([0, -14, -20, -22, -14, -18, -20, -20, -20, -20, -18, -14, -22, -20, -14, 0])
np.testing.assert_array_almost_eq










ual(v, expected_v, decimal=2)

In [36]:
! start .