In [1]:
import gym
import numpy
import numpy as np

In [2]:
# For details, see demo_policyEvaluation.ipynb
def policy_eval(policy, env, discount = 1.0, tolerance = 0.00001):
    '''
    Iteratively estimate the state values of `policy` on `env`.

    policy:    indexable by state; policy[state] yields one probability per action
    env:       object exposing `nS` (number of states) and `P[state][action]`,
               an iterable of (probability, next_state, reward, done) tuples
    discount:  factor applied to the value of the successor state
    tolerance: sweeps stop once the largest per-state change falls below this

    Returns a 1-D numpy array with one value estimate per state.
    '''
    n_states = env.nS
    value = np.zeros(n_states)

    converged = False
    while not converged:
        largest_change = 0
        for s in range(n_states):
            # Expected one-step return from s, averaged over actions and transitions.
            backup = 0
            for a, pi_sa in enumerate(policy[s]):
                for p, s_next, r, _ in env.P[s][a]:
                    backup += pi_sa * p * (r + discount * value[s_next])
            largest_change = max(largest_change, np.abs(backup - value[s]))
            # In-place update: later states in this sweep already see the new value.
            value[s] = backup
        converged = largest_change < tolerance
    return value
        

In [3]:
def policy_improvement(env, value_function, state, discount):
    '''
    Return the greedy action for `state` under the given value function.

    For every action the expected one-step return is accumulated over the
    (probability, next_state, reward, done) tuples in env.P[state][action];
    the index of the largest return is returned (first index on ties).
    '''
    def expected_return(action):
        # One-step lookahead: sum of p * (r + discount * V(s')) over transitions.
        total = 0.0
        for p, s_next, r, _ in env.P[state][action]:
            total += p * (r + discount * value_function[s_next])
        return total

    returns = np.array([expected_return(a) for a in range(env.nA)], dtype=float)
    return np.argmax(returns)

In [4]:
def policy_iteration(env, discount = 1):
    '''
    Policy iteration: alternate policy evaluation and greedy improvement
    until the greedy action stops changing in every state.

    env:      object exposing `nS`, `nA` and `P[state][action]` as consumed by
              policy_eval and policy_improvement
    discount: discount factor, used for BOTH evaluation and improvement

    Returns (policy, value_function): `policy` is an (nS, nA) array whose rows
    are one-hot on the greedy action, and `value_function` is the evaluation
    of that final policy.
    '''
    # Random initial stochastic policy. The raw uniform samples are normalised
    # per row so that each state's action probabilities sum to 1; without this
    # the first evaluation is not the value of any valid policy and can blow up.
    policy = np.random.rand(env.nS, env.nA)
    policy /= policy.sum(axis=1, keepdims=True)

    # Row i is the one-hot distribution that puts all mass on action i.
    one_hot = np.eye(env.nA)

    while True:

        # Evaluate the current policy with the caller's discount (previously
        # the default of policy_eval was silently used instead).
        value_function = policy_eval(policy, env, discount)
        stable = True

        for state in range(env.nS):
            curr_action = np.argmax(policy[state])
            optimal_action = policy_improvement(env, value_function, state, discount)

            if optimal_action != curr_action:
                stable = False
            # Make the policy deterministic and greedy in this state.
            policy[state] = one_hot[optimal_action]

        if stable:
            return policy, value_function
            

In [5]:
# Create the FrozenLake-v1 environment and reset it before use.
env = gym.make('FrozenLake-v1')
env.reset()
# Run policy iteration with its default discount; `policy` is the resulting
# per-state action table and `value` the matching state-value estimates.
# NOTE(review): policy_iteration reads env.nS / env.nA / env.P — presumably
# exposed by the installed gym version's FrozenLake env; confirm.
policy, value = policy_iteration(env)

  # This is added back by InteractiveShellApp.init_path()
  if sys.path[0] == '':


In [6]:
# Print the policy table, then the highest-probability action of each state
# arranged as a 4x4 grid, and finally render and close the environment.
print("The policy distribution: ", policy, sep="\n")
print("The action in each state: ", np.argmax(policy, axis=1).reshape(4, 4), sep="\n")
env.render()
env.close()

The policy distribution: 
[[1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [0. 0. 0. 1.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [1. 0. 0. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]
The action in each state: 
[[0 3 3 3]
 [0 0 0 0]
 [3 1 0 0]
 [0 2 1 0]]

[41mS[0mFFF
FHFH
FFFH
HFFG


In [7]:
# Show the state-value estimates arranged as a 4x4 grid.
print(value.reshape(4, 4))

[[0.8233628  0.82330813 0.82327014 0.82325081]
 [0.82337956 0.         0.52929815 0.        ]
 [0.8234058  0.82343946 0.76462706 0.        ]
 [0.         0.88229042 0.94114466 0.        ]]
