In [1]:
import numpy as np
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.envs.gridworld import GridworldEnv

In [2]:
# `env` is an OpenAI-style environment. The attributes used below:
#   env.P[s][a] -- list of transition tuples (prob, next_state, reward, done),
#                  i.e. the environment's transition probabilities.
#   env.nS      -- the number of states in the environment.
#   env.nA      -- the number of actions in the environment.
#   env.shape   -- grid shape used to lay per-state arrays out for display.
env = GridworldEnv()

In [3]:
def policy_evaluation(policy, env, discount_rate=1.0, theta=0.00001, max_iterations=None):
    """
    Iterative policy evaluation: estimate the state-value function of `policy`.

    Repeatedly sweeps over all states, replacing V[s] with its one-step Bellman
    expectation under `policy`, until the largest change in a sweep is below
    `theta`. Updates are made in place, so later states in a sweep already see
    the new values of earlier states.

    policy: [S, A] shaped matrix representing the policy; policy[s][a] is the
        probability of taking action a in state s.
    env: environment exposing nS, nA and P[s][a], a list of
        (prob, next_state, reward, done) transition tuples.
    discount_rate: discount factor applied to the value of the next state.
    theta: stop evaluation once no state's value changes by theta or more
        during a full sweep.
    max_iterations: optional cap on the number of sweeps. With discount_rate=1.0
        and a policy that never reaches a terminal state, the values can keep
        changing forever and the theta test alone never stops the loop.
        None (the default) means no cap.

    Returns: numpy array of length env.nS holding the value of each state.
    """
    V = np.zeros(env.nS)
    sweeps = 0
    while True:
        delta = 0
        for s in range(env.nS):
            old_value = V[s]
            new_value = 0
            for a, a_prob in enumerate(policy[s]):
                # `done` is not consulted: terminal behaviour must be encoded in
                # env.P itself (presumably as zero-reward self-loops -- confirm
                # for environments other than this gridworld).
                for prob, next_state, reward, done in env.P[s][a]:
                    new_value += a_prob * prob * (reward + discount_rate * V[next_state])
            V[s] = new_value
            delta = max(delta, np.abs(old_value - new_value))
        sweeps += 1
        if delta < theta:
            break
        if max_iterations is not None and sweeps >= max_iterations:
            break
    return V
        

In [27]:
def policy_iteration(env, policy_eval_fn=policy_evaluation, discount_rate=1.0, tie_tolerance=0.00001):
    """
    Policy iteration: alternate policy evaluation and greedy policy improvement
    until the policy stops changing.

    env: environment exposing nS, nA and P[s][a], a list of
        (prob, next_state, reward, done) transition tuples.
    policy_eval_fn: callable(policy, env, discount_rate) returning a value
        array indexable by state.
    discount_rate: discount factor used for both evaluation and improvement.
    tie_tolerance: the current action is kept while its one-step value is within
        this amount of the best action's value. Without this, approximate
        evaluation can make argmax flip forever between equally good actions,
        and the loop would never report a stable policy. The default matches
        policy_evaluation's default theta.

    Returns: (policy, V). policy is an [S, A] matrix with a one-hot row per state.
    V is the value function computed for exactly that policy.
    """
    # Start from the uniform random policy.
    policy = np.ones([env.nS, env.nA]) / env.nA
    one_hot = np.eye(env.nA)  # row a is the deterministic "always take a" row
    while True:
        # Policy evaluation
        V = policy_eval_fn(policy, env, discount_rate)

        # Policy improvement
        policy_stable = True
        for s in range(env.nS):
            # One-step lookahead: expected return of each action from s under V.
            A = np.zeros(env.nA)
            for a in range(env.nA):
                for prob, next_state, reward, done in env.P[s][a]:
                    A[a] += prob * (reward + discount_rate * V[next_state])

            current_a = np.argmax(policy[s])
            best_a = np.argmax(A)
            # Prefer the incumbent action when it is (near-)optimal so ties do
            # not cause endless switching.
            if A[current_a] >= A[best_a] - tie_tolerance:
                best_a = current_a
            new_row = one_hot[best_a]
            # Compare whole rows: the initial stochastic policy must never be
            # mistaken for a stable deterministic one.
            if not np.array_equal(policy[s], new_row):
                policy_stable = False
            policy[s] = new_row
        if policy_stable:
            return policy, V

In [28]:
# Run policy iteration on the gridworld: `policy` holds one one-hot row per state, `v` the state values.
policy, v = policy_iteration(env, discount_rate=1.0)


In [37]:
# Header, a blank line, then the state values laid out on the grid.
print("Optimal Value Function:", "", v.reshape(env.shape), sep="\n")

Optimal Value Function:

[[ 0. -1. -2. -3.]
 [-1. -2. -3. -2.]
 [-2. -3. -2. -1.]
 [-3. -2. -1.  0.]]


In [36]:
# Header, action legend, then the greedy action of each state laid out on the grid.
print(
    "Optimal Policy:",
    "",
    "0=up, 1=right, 2=down, 3=left",
    "",
    np.argmax(policy, axis=1).reshape(env.shape),
    sep="\n",
)

Optimal Policy:

0=up, 1=right, 2=down, 3=left

[[0 3 3 2]
 [0 0 0 2]
 [0 0 1 2]
 [0 1 1 0]]
