In [33]:
import gym
import numpy as np
import random 
import time
from IPython import display

In [34]:
# Alternative environment id kept for reference:
# env = gym.make('FrozenLake-v0')
# Shared environment instance; the functions in later cells read it as the global `env`.
env = gym.make('FrozenLake8x8-v0')

In [35]:
# Sizes of the discrete state space (ns) and action space (na);
# the policy-iteration helpers in later cells read both as globals.
ns, na = env.observation_space.n, env.action_space.n
print(f'number of state: {ns}\nnumber of action: {na}')

number of state: 64
number of action: 4


In [36]:
# evaluate the policy
def policy_evaluation(policy, v_values, gamma, max_steps=None):
    """Iteratively estimate the state values of a fixed deterministic policy.

    Each sweep recomputes every state's value from the previous sweep's
    estimates (a synchronous backup) using the transition model ``env.P``,
    whose entries unpack as ``(prob, next_state, reward, done)``. Sweeping
    stops once two consecutive estimates agree under ``np.isclose``.

    Relies on the notebook-level globals ``env`` and ``ns``.

    Args:
        policy: sequence indexed by state giving the action to take there.
        v_values: initial value estimate, one entry per state. It is copied
            as floats, so the caller's array is left untouched and an
            integer-typed input no longer truncates the updates.
        gamma: discount factor applied to the next state's value.
        max_steps: optional cap on the number of sweeps. ``None`` (default)
            keeps sweeping until convergence; with a cap, the current
            estimate is returned once the cap is reached.

    Returns:
        A new float ``numpy`` array with the estimated value of every state.
    """
    # work on a private float copy: no side effects on the caller's data
    values = np.array(v_values, dtype=float)
    steps = 0
    while True:
        steps += 1
        previous = values.copy()
        for state in range(ns):
            # expected one-step return of the action the policy picks here
            q_value = 0
            for prob, next_state, reward, _done in env.P[state][policy[state]]:
                q_value += prob * (reward + gamma * previous[next_state])
            values[state] = q_value
        # check convergence
        converged = np.all(np.isclose(values, previous))
        if converged or (max_steps is not None and steps >= max_steps):
            print("steps: ", steps)
            break
    return values

In [37]:
# improve policy
def policy_improvement(policy, v_values, gamma=0.9):
    """Build the policy that acts greedily with respect to ``v_values``.

    For every state the expected one-step return of each action is computed
    from the transition model ``env.P`` and the action with the largest
    return (first one on ties, per ``np.argmax``) is selected.

    Relies on the notebook-level globals ``env``, ``ns`` and ``na``.

    Args:
        policy: current policy; only used as the template that is copied
            and then overwritten state by state.
        v_values: value estimate indexed by state.
        gamma: discount factor applied to the next state's value.

    Returns:
        A copy of ``policy`` whose entries are the greedy actions.
    """
    def action_value(state, action):
        # expected one-step return of taking `action` in `state`
        total = 0
        for prob, next_state, reward, _done in env.P[state][action]:
            total += prob * (reward + gamma * v_values[next_state])
        return total

    greedy_policy = policy.copy()
    for state in range(ns):
        candidates = [action_value(state, action) for action in range(na)]
        # keep the action with the highest expected return
        greedy_policy[state] = np.argmax(candidates)
    return greedy_policy

In [38]:
def policy_iteration(gamma = 0.9):
    """Run policy iteration and return the policy it converges to.

    Starting from a random policy and an all-zero value estimate, the loop
    alternates ``policy_evaluation`` and ``policy_improvement`` until the
    improvement step leaves every state's action unchanged.

    Relies on the notebook-level globals ``ns`` and ``na`` and draws the
    initial policy from the module-level ``random`` generator, so seeding
    ``random`` beforehand makes the run reproducible.

    Args:
        gamma: discount factor forwarded to evaluation and improvement.

    Returns:
        The final policy: one action per state.
    """
    # random initial policy; the upper bound follows the action-space size
    # instead of assuming exactly four actions
    policy = [random.randint(0, na - 1) for _ in range(ns)]
    # initial v_value for all states: 0
    v_values = np.zeros(ns)
    step = 0
    while True:
        step += 1
        v_values = policy_evaluation(policy, v_values, gamma)
        improved_policy = policy_improvement(policy, v_values, gamma)
        # actions are discrete labels, so compare them exactly
        stable = list(improved_policy) == list(policy)
        policy = improved_policy
        if stable:  # converged
            print(f'convergence after: {step} steptimes')
            break
    return policy

In [39]:
# Seed the RNG that draws the initial random policy so that a fresh
# Restart & Run All reproduces the same iteration trace.
SEED = 0
random.seed(SEED)

gamma = 0.9
policy = policy_iteration(gamma)
print(policy)

steps:  2
steps:  32
steps:  90
steps:  69
steps:  43
steps:  35
steps:  54
steps:  60
steps:  50
convergence after: 9 steptimes
[3, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 1, 3, 3, 0, 0, 2, 3, 2, 1, 3, 3, 3, 1, 0, 0, 2, 1, 3, 3, 0, 0, 2, 1, 3, 2, 0, 0, 0, 1, 3, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 2, 0, 1, 0, 0, 1, 1, 1, 0]


In [40]:
def play(policy):
    """Run one rendered episode in ``env`` following ``policy``.

    The environment is reset, then stepped with the action the policy
    assigns to the current state until ``env.step`` reports ``done``.
    Each step prints its counter and renders the environment; the cell
    output is cleared between frames, but not after the final one, so the
    last frame stays visible.

    Relies on the notebook-level global ``env``.

    Args:
        policy: sequence indexed by state giving the action to take.

    Returns:
        The sum of the rewards collected during the episode.
    """
    state = env.reset()
    time.sleep(1)
    display.clear_output(wait=True)

    episode_return = 0
    step_count = 0
    while True:
        state, reward, done, _info = env.step(policy[state])
        episode_return += reward
        step_count += 1
        print(f'Step {step_count}')
        env.render()
        time.sleep(0.2)
        if done:
            break
        # wipe the previous frame before drawing the next one
        display.clear_output(wait=True)

    return episode_return

In [44]:
# Play one rendered episode with the computed policy; `wi` holds the
# total reward that play() returns for that episode.
wi  = play(policy)
print(wi)

Step 77
  (Right)
SFFFFFFF
FFFFFFFF
FFFHFFFF
FFFFFHFF
FFFHFFFF
FHHFFFHF
FHFFHFHF
FFFHFFF[41mG[0m
1.0
