In [3]:
import gym.spaces
import matplotlib.pyplot as plt
import numpy as np
from qlearning_template import QLearningAgent


def play_and_train(env, agent, t_max=10 ** 4):
    """Play a single episode with ``agent`` acting in ``env``, learning online.

    The episode starts from ``env.reset()``. It ends as soon as the
    environment reports ``done``, or after ``t_max`` steps, whichever comes
    first. After each step the agent is updated with the observed transition
    ``(state, action, next_state, reward)``.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``; ``step``
            is unpacked as a 4-tuple ``(next_state, reward, done, info)``.
        agent: learner exposing ``get_action(state)`` and
            ``update(state, action, next_state, reward)``.
        t_max: maximum number of environment steps in the episode.

    Returns:
        Sum of the rewards collected during the episode (starts from 0.0).
    """
    episode_return = 0.0
    s = env.reset()

    for _step in range(t_max):
        action = agent.get_action(s)
        s_next, r, is_terminal, _info = env.step(action)

        episode_return += r
        # Learn from this transition before moving on to the next state.
        agent.update(s, action, s_next, r)
        s = s_next

        if is_terminal:
            break

    return episode_return


if __name__ == '__main__':
    max_iterations = 5000
    visualize = False
    # Every `log_every` iterations (including i == 0) epsilon is multiplied
    # by `epsilon_decay` and a progress line is printed.
    log_every = 100
    epsilon_decay = 0.9

    # Create Taxi-v2 env. `.env` presumably strips gym's TimeLimit wrapper,
    # leaving play_and_train's t_max as the episode-length cap (TODO confirm
    # against the installed gym version).
    env = gym.make('Taxi-v2').env
    env.reset()
    env.render()

    n_states = env.observation_space.n
    n_actions = env.action_space.n

    print('States number = %i, Actions number = %i' % (n_states, n_actions))

    # Q-learning agent: every action is legal in every state.
    agent = QLearningAgent(alpha=0.5,
                           get_legal_actions=lambda s: range(n_actions),
                           epsilon=0.1,
                           discount=0.99)

    # Only create a figure when we will actually draw into it; otherwise an
    # empty figure is left behind (shows up as "<Figure ... 0 Axes>").
    if visualize:
        plt.figure(figsize=[10, 4])
    rewards = []

    # Training loop: one episode per iteration.
    for i in range(max_iterations):
        rewards.append(play_and_train(env, agent))

        if i % log_every == 0:
            # Decay happens before logging, so the printed epsilon is the
            # value used for the upcoming episodes.
            agent.epsilon *= epsilon_decay
            print('Iteration {}, Average reward {:.2f}, Epsilon {:.3f}'.format(i, np.mean(rewards), agent.epsilon))

        if visualize:
            # Clear the whole figure before redrawing. plt.cla() would only
            # clear the current (right-hand) axes, letting the left subplot
            # accumulate one full-history line per iteration.
            plt.clf()

            plt.subplot(1, 2, 1)
            plt.plot(rewards, color='r')
            plt.xlabel('Iterations')
            plt.ylabel('Total Reward')

            plt.subplot(1, 2, 2)
            plt.hist(rewards, bins=20, range=[-700, +20], color='blue', label='Rewards distribution')
            plt.xlabel('Reward')
            plt.ylabel('p(Reward)')
            plt.draw()
            plt.pause(0.05)

    if visualize:
        # Keep the final frame on screen instead of closing with the script.
        plt.show()

+---------+
|R: | : :[34;1mG[0m|
| : : : : |
| : : : : |
| | : | :[43m [0m|
|Y| : |[35mB[0m: |
+---------+

States number = 500, Actions number = 6
Iteration 0, Average reward -2500.00, Epsilon 0.090
Iteration 100, Average reward -288.00, Epsilon 0.081
Iteration 200, Average reward -161.79, Epsilon 0.073
Iteration 300, Average reward -108.69, Epsilon 0.066
Iteration 400, Average reward -81.12, Epsilon 0.059
Iteration 500, Average reward -64.08, Epsilon 0.053
Iteration 600, Average reward -52.52, Epsilon 0.048
Iteration 700, Average reward -44.21, Epsilon 0.043
Iteration 800, Average reward -37.96, Epsilon 0.039
Iteration 900, Average reward -33.02, Epsilon 0.035
Iteration 1000, Average reward -29.02, Epsilon 0.031
Iteration 1100, Average reward -25.78, Epsilon 0.028
Iteration 1200, Average reward -23.03, Epsilon 0.025
Iteration 1300, Average reward -20.73, Epsilon 0.023
Iteration 1400, Average reward -18.81, Epsilon 0.021
Iteration 1500, Average reward -17.07, Epsilon 0.019
Itera

<Figure size 720x288 with 0 Axes>