Install prerequisites (needed for the interactive `%matplotlib widget` plots in JupyterLab):
  * brew install node
  * jupyter labextension install @jupyter-widgets/jupyterlab-manager
  * jupyter labextension install jupyter-matplotlib

# Agent Training

In [1]:
%matplotlib widget

from rlai.agents.mdp import StochasticMdpAgent
from rlai.gpi.temporal_difference.evaluation import Mode
from rlai.value_estimation.tabular import TabularStateActionValueEstimator
from rlai.gpi.temporal_difference.iteration import iterate_value_q_pi
from numpy.random import RandomState
from rlai.environments.gridworld import Gridworld
import matplotlib.pyplot as plt
from rlai.gpi.utils import plot_policy_iteration, update_policy_iteration_plot
from matplotlib.animation import FuncAnimation
from threading import Thread
import traceback

# Create the figure with empty data series; it is handed to FuncAnimation below.
fig = plot_policy_iteration(
    pdf=None,
    elapsed_seconds_average_rewards={},
    iteration_num_states_improved=[],
    iteration_total_states=[],
    iteration_average_reward=[]
)

def animate(
    i,
    log_dir=None
):
    """
    Animation callback:  update the policy-iteration plot for one frame.

    Calls `update_policy_iteration_plot()`. Any exception raised by that call is caught (not propagated) and the
    exception message followed by the traceback is written to `log_<i>_exception.txt` in the log directory.

    :param i: Frame value supplied by the animation. Only used to name the exception log file.
    :param log_dir: Directory in which exception logs are written. Defaults to the `Desktop` directory under the
    current user's home directory; the directory is created if it does not exist.
    """

    try:
        update_policy_iteration_plot()
    except Exception as ex:

        # Deliberate catch-all at the callback boundary:  the error is recorded to a file instead of propagating.
        import os

        # Resolve the log directory relative to the current user rather than hard-coding one user's home.
        if log_dir is None:
            log_dir = os.path.join(os.path.expanduser('~'), 'Desktop')

        # Make sure the directory exists so that writing the log cannot itself fail and mask the original error.
        os.makedirs(log_dir, exist_ok=True)

        with open(os.path.join(log_dir, f'log_{i}_exception.txt'), 'w') as f:
            # Terminate the message line so that the traceback starts on its own line.
            f.write(f'{ex}\n')
            traceback.print_exc(file=f)

# Drive `animate` from the figure's timer:  1000 frames, 1000 ms between frames, no repeat. The result is bound to
# `ani` so that a reference to the animation object is retained.
ani = FuncAnimation(fig=fig, func=animate, frames=1000, interval=1000, repeat=False)

def train():
    """
    Build a gridworld environment and agent, then run temporal-difference policy iteration on them.

    Uses a `RandomState` seeded with 12345, the `Gridworld.example_4_1` environment, a tabular state-action value
    estimator with epsilon 0.05, and a `StochasticMdpAgent` named 'test'. Iteration is run in SARSA mode with
    1-step updates, 50 episodes per improvement, and 1000000 improvements. Returns nothing.
    """

    rng = RandomState(12345)
    gridworld: Gridworld = Gridworld.example_4_1(rng, None)

    # The same epsilon is given to both the estimator and the iteration call.
    exploration_rate = 0.05
    estimator = TabularStateActionValueEstimator(gridworld, exploration_rate, None)

    agent = StochasticMdpAgent('test', rng, estimator.get_initial_policy(), 1)

    settings = {
        'agent': agent,
        'environment': gridworld,
        'num_improvements': 1000000,
        'num_episodes_per_improvement': 50,
        'num_updates_per_improvement': None,
        'alpha': 0.1,
        'mode': Mode.SARSA,
        'n_steps': 1,
        'epsilon': exploration_rate,
        'planning_environment': None,
        'make_final_policy_greedy': True,
        'q_S_A': estimator,
        'num_improvements_per_plot': 10
    }

    iterate_value_q_pi(**settings)
    
# Run training on a separate thread so that `start()` returns immediately instead of blocking until `train` finishes.
train_t = Thread(target=train)
train_t.start()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Value iteration 1:  Running temporal-difference evaluation of q_pi for 50 episode(s).
Finished 2 of 50 episode(s).
Finished 4 of 50 episode(s).
Finished 6 of 50 episode(s).
Finished 8 of 50 episode(s).
Finished 10 of 50 episode(s).
Finished 12 of 50 episode(s).
Finished 14 of 50 episode(s).
Finished 16 of 50 episode(s).
Finished 18 of 50 episode(s).
Finished 20 of 50 episode(s).
Finished 22 of 50 episode(s).
Finished 24 of 50 episode(s).
Finished 26 of 50 episode(s).
Finished 28 of 50 episode(s).
Finished 30 of 50 episode(s).
Finished 32 of 50 episode(s).
Finished 34 of 50 episode(s).
Finished 36 of 50 episode(s).
Finished 38 of 50 episode(s).
Finished 40 of 50 episode(s).
Finished 42 of 50 episode(s).
Finished 44 of 50 episode(s).
Finished 46 of 50 episode(s).
Finished 48 of 50 episode(s).
Finished 50 of 50 episode(s).
Value iteration 2:  Running temporal-difference evaluation of q_pi for 50 episode(s).
Finished 2 of 50 episode(s).
Finished 4 of 50 episode(s).
