# Policy Iteration (using iterative policy evaluation) for estimating $\pi\approx\pi_*$

In [1]:
import gymnasium as gym
import numpy as np
import time

In [2]:
def policy_evaluation(env, policy, gamma=0.99, theta=1e-6):
    """Estimate the state-value function of a fixed deterministic policy.

    Performs in-place iterative policy evaluation: each sweep updates V[s]
    for every state, and an updated value is used immediately by states
    visited later in the same sweep. Sweeps stop once the largest absolute
    change of any state's value in a sweep is below ``theta``.

    Args:
        env: Tabular environment exposing ``observation_space.n`` and a model
            ``env.P[state][action]`` that is a list of
            ``(prob, next_state, reward, done)`` tuples.
        policy: Indexable ``state -> action`` mapping (e.g. an int array of
            length ``n_states``).
        gamma: Discount factor.
        theta: Convergence threshold on the per-sweep maximum value change.

    Returns:
        numpy.ndarray of shape ``(n_states,)`` holding the value estimates.
    """
    n_states = env.observation_space.n
    V = np.zeros(n_states)

    while True:
        delta = 0.0
        for state in range(n_states):
            v_old = V[state]
            v_new = 0.0
            for prob, new_state, reward, done in env.P[state][policy[state]]:
                # A transition flagged `done` ends the episode, so there is no
                # future return to bootstrap from the next state's value.
                future = 0.0 if done else gamma * V[new_state]
                v_new += prob * (reward + future)
            V[state] = v_new
            delta = max(delta, abs(v_old - v_new))

        if delta < theta:
            return V

In [3]:
def policy_improvement(env, policy, V, gamma=0.99, tol=1e-9):
    """Make ``policy`` greedy with respect to ``V``, modifying it in place.

    For every state, the one-step lookahead value of each action is computed
    from the model ``env.P``. The state's action is replaced only when some
    action beats the current one by more than ``tol``.

    Args:
        env: Tabular environment exposing ``observation_space.n``,
            ``action_space.n`` and ``env.P[state][action]`` as a list of
            ``(prob, next_state, reward, done)`` tuples.
        policy: Mutable indexable ``state -> action`` mapping; updated in place.
        V: State-value estimates indexed by state.
        gamma: Discount factor.
        tol: An action must exceed the current action's value by more than
            this to replace it. This guards against switching between
            (numerically) tied actions.

    Returns:
        ``(policy, stable)``: the same ``policy`` object, and ``True`` if no
        state's action changed.
    """
    n_states = env.observation_space.n
    n_actions = env.action_space.n
    stable = True

    for state in range(n_states):
        old_action = policy[state]
        action_values = np.zeros(n_actions)
        for action in range(n_actions):
            for prob, new_state, reward, done in env.P[state][action]:
                # No bootstrapping past a transition that ends the episode.
                future = 0.0 if done else gamma * V[new_state]
                action_values[action] += prob * (reward + future)
        best_action = int(np.argmax(action_values))
        # Keep the current action if it is already optimal (within tol).
        # Switching between equally good actions would report "unstable" and
        # can make policy iteration cycle forever among tied policies.
        if action_values[best_action] - action_values[old_action] <= tol:
            continue
        policy[state] = best_action
        stable = False

    return policy, stable

In [4]:
def policy_iteration(env, gamma=0.99, theta=1e-3):
    """Find a (near-)optimal deterministic policy by policy iteration.

    Starts from a uniformly random deterministic policy, then alternates
    policy evaluation and greedy policy improvement until an improvement
    step leaves the policy unchanged.

    Note that ``theta`` defaults to 1e-3 here, and that value is forwarded
    to ``policy_evaluation`` (whose own default is 1e-6).

    Args:
        env: Tabular environment exposing ``observation_space.n``,
            ``action_space.n`` and the transition model ``env.P``.
        gamma: Discount factor.
        theta: Convergence threshold forwarded to ``policy_evaluation``.

    Returns:
        ``(policy, V)``: an int array of length ``n_states`` with one action
        per state, and the value estimates of that policy.
    """
    num_actions = env.action_space.n
    num_states = env.observation_space.n
    # Arbitrary starting point, drawn from NumPy's global RNG.
    current_policy = np.random.randint(0, num_actions, size=num_states)

    converged = False
    while not converged:
        values = policy_evaluation(env, current_policy, gamma, theta)
        current_policy, converged = policy_improvement(
            env, current_policy, values, gamma
        )

    return current_policy, values

In [5]:
# Training the Policy
# The solver reads the tabular model `P` directly, so it is handed the
# innermost environment (`env.unwrapped`) rather than the wrapped one
# returned by gym.make.
env = gym.make("Taxi-v3")
policy, V = policy_iteration(env=env.unwrapped)

In [None]:
# Save Policy and Value Function
# np.savez does not create missing parent directories. Ensure "results/"
# exists first; otherwise the save fails with FileNotFoundError on a fresh
# checkout.
import os

os.makedirs("results", exist_ok=True)
np.savez("results/Taxi.npz", policy=policy, V=V)

In [7]:
# Visualize Learned Policy
# Roll out one episode following the learned policy. With
# render_mode="human", Gymnasium draws a frame on every reset()/step(), so
# no explicit env.render() call is needed.
env = gym.make("Taxi-v3", render_mode="human")
try:
    state, _ = env.reset()
    terminated = truncated = False
    # The episode ends on either flag: `terminated` (a terminal state was
    # reached) or `truncated` (the episode was cut off, e.g. by a time limit).
    # Checking only `terminated` keeps stepping forever if the policy never
    # reaches a terminal state.
    while not (terminated or truncated):
        action = policy[state]
        state, _reward, terminated, truncated, _info = env.step(action)
        time.sleep(0.25)
finally:
    # Close the render window even if the rollout is interrupted.
    env.close()

  from pkg_resources import resource_stream, resource_exists
