In [1]:
from typing import Tuple, Dict, Optional, Iterable

import gym
from gym import spaces
from gym.error import DependencyNotInstalled

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.animation
from matplotlib import animation

from IPython.display import HTML

%matplotlib inline

In [2]:
def display_video(frames):
    orig_backend = matplotlib.get_backend()
    matplotlib.use('Agg')
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    matplotlib.use(orig_backend)
    ax.set_axis_off()
    ax.set_aspect('equal')
    ax.set_position([0, 0, 1, 1])
    im = ax.imshow(frames[0])
    def update(frame):
        im.set_data(frame)
        return [im]
    anim = animation.FuncAnimation(fig=fig, func=update, frames=frames,
                                    interval=500, blit=True, repeat=False)
    return HTML(anim.to_html5_video())


def test_agent(env, policy):
    frames = []
    state = env.reset()
    done = False
    frames.append(env.render(mode="rgb_array"))

    while not done:
        action_probs = policy[state]
        action = np.random.choice(range(len(action_probs)), p=action_probs)
        next_state, reward, terminated, truncated, info = env.step(action)
        frame = env.render(mode="rgb_array")
        frames.append(frame)
        state = next_state
        done = terminated or truncated

    return display_video(frames)

  and should_run_async(code)


In [3]:
def policy_evaluation(policy, env, discount_factor=0.99, theta=1e-5):

    V = np.zeros(env.observation_space.n)
    while True:
        delta = 0
        for state in range(env.observation_space.n):
            v = 0
            for action, action_prob in enumerate(policy[state]):
                for prob, next_state, reward, done in env.P[state][action]:
                    v += action_prob * prob * (reward + discount_factor * V[next_state])
            delta = max(delta, np.abs(v - V[state]))
            V[state] = v
        if delta < theta:
            break
    return np.array(V)

def policy_improvement(env, V, discount_factor=0.99):

    policy = np.zeros([env.observation_space.n, env.action_space.n])
    for state in range(env.observation_space.n):
        action_values = np.zeros(env.action_space.n)
        for action in range(env.action_space.n):
            for prob, next_state, reward, done in env.P[state][action]:
                action_values[action] += prob * (reward + discount_factor * V[next_state])
        best_action = np.argmax(action_values)
        policy[state, best_action] = 1.0
    return policy

def policy_iteration(env, discount_factor=0.99, theta=1e-5):

    policy = np.ones([env.observation_space.n, env.action_space.n]) / env.action_space.n
    while True:
        V = policy_evaluation(policy, env, discount_factor, theta)
        new_policy = policy_improvement(env, V, discount_factor)
        if np.all(policy == new_policy):
            break
        policy = new_policy
    return policy, V

env = gym.make('Taxi-v3', new_step_api=True)
optimal_policy, optimal_value = policy_iteration(env)

print("Optimal Policy:")
print(optimal_policy)
print("\nOptimal Value Function:")
print(optimal_value)


Optimal Policy:
[[0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 ...
 [0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]]

Optimal Value Function:
[944.72316569 864.01270478 903.55686813 873.75019718 789.53759087
 864.01270478 789.53757231 816.76645272 864.01272333 826.02673924
 903.55686813 835.38053503 807.59881631 826.02673924 807.59879776
 873.75019718 955.27593403 873.75021631 913.69381566 883.58606752
 934.27593403 854.37257773 893.52129945 864.01269521 798.52282815
 873.75021631 798.52280978 826.02672968 854.3725961  816.76647185
 893.52129945 826.02672968 816.76649022 835.38055415 816.76647185
 883.58606752 944.72317469 883.58608645 903.5568775  893.52128998
 883.58610464 807.59880713 844.82885195 816.76646238 844.82887014
 923.9331565  844.82885195 873.75020684 844.82887014 807.59880713
 883.58608645 816.76646238 826.0267668  844.82885195 826.02674861
 893.52128998 893.52132673 934.27592494 893.52130873 903.55686813
 873.7502436  798.52281906 835.3805

In [4]:
test_agent(env,optimal_policy)

See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  if not isinstance(terminated, (bool, np.bool8)):


In [5]:
def value_iteration(env, discount_factor=0.99, theta=1e-5):

    V = np.zeros(env.observation_space.n)
    while True:
        delta = 0
        for state in range(env.observation_space.n):
            action_values = np.zeros(env.action_space.n)
            for action in range(env.action_space.n):
                for prob, next_state, reward, done in env.P[state][action]:
                    action_values[action] += prob * (reward + discount_factor * V[next_state])
            best_action_value = np.max(action_values)
            delta = max(delta, np.abs(best_action_value - V[state]))
            V[state] = best_action_value
        if delta < theta:
            break

    policy = np.zeros([env.observation_space.n, env.action_space.n])
    for state in range(env.observation_space.n):
        action_values = np.zeros(env.action_space.n)
        for action in range(env.action_space.n):
            for prob, next_state, reward, done in env.P[state][action]:
                action_values[action] += prob * (reward + discount_factor * V[next_state])
        best_action = np.argmax(action_values)
        policy[state, best_action] = 1.0
    return policy, V

env = gym.make('Taxi-v3', new_step_api=True)
optimal_policy, optimal_value = value_iteration(env)

print("Optimal Policy:")
print(optimal_policy)
print("\nOptimal Value Function:")
print(optimal_value)


  and should_run_async(code)


Optimal Policy:
[[0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0.]
 ...
 [0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0.]]

Optimal Value Function:
[944.72316569 864.01270478 903.55686813 873.75019718 789.53759087
 864.01270478 789.53757231 816.76645272 864.01272333 826.02673924
 903.55686813 835.38053503 807.59881631 826.02673924 807.59879776
 873.75019718 955.27593403 873.75021631 913.69381566 883.58606752
 934.27593403 854.37257773 893.52129945 864.01269521 798.52282815
 873.75021631 798.52280978 826.02672968 854.3725961  816.76647185
 893.52129945 826.02672968 816.76649022 835.38055415 816.76647185
 883.58606752 944.72317469 883.58608645 903.5568775  893.52128998
 883.58610464 807.59880713 844.82885195 816.76646238 844.82887014
 923.9331565  844.82885195 873.75020684 844.82887014 807.59880713
 883.58608645 816.76646238 826.0267668  844.82885195 826.02674861
 893.52128998 893.52132673 934.27592494 893.52130873 903.55686813
 873.7502436  798.52281906 835.3805

In [6]:
test_agent(env,optimal_policy)

See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  if not isinstance(terminated, (bool, np.bool8)):
