Install (prerequisites for the interactive widgets in JupyterLab; `brew` assumes macOS with Homebrew)
  * brew install node
  * jupyter labextension install @jupyter-widgets/jupyterlab-manager
  * jupyter labextension install jupyter-matplotlib

In [1]:
%matplotlib widget

import os
import tempfile
import traceback
from threading import Thread

import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display
from matplotlib.animation import FuncAnimation
from numpy.random import RandomState

from rlai.agents.mdp import StochasticMdpAgent
from rlai.environments.gridworld import Gridworld
from rlai.gpi.temporal_difference.evaluation import Mode
from rlai.gpi.temporal_difference.iteration import iterate_value_q_pi
from rlai.gpi.utils import plot_policy_iteration, update_policy_iteration_plot
from rlai.runners.trainer import get_argument_parser_for_train_function
from rlai.utils import get_help
from rlai.value_estimation.tabular import TabularStateActionValueEstimator

# Training

In [5]:
def get_tab_children(
    names_classes
):
    """
    Build the widget layout shown inside one tab: a dropdown for choosing a type, a text area for entering its
    arguments, and a text area that shows help for the current selection.

    :param names_classes: Iterable of `(display name, fully-qualified name)` pairs offered in the dropdown. The pairs
    are sorted before being passed to the dropdown.
    :return: A `widgets.VBox` stacking the dropdown, the arguments text area, and the help text area.
    """

    names_classes = sorted(names_classes)

    class_dropdown = widgets.Dropdown(
        options=names_classes,
        description='Type',
        value=None
    )

    args = widgets.Textarea(
        description='Arguments',
        layout=widgets.Layout(width='90%', height='100px')
    )

    help_text = widgets.Textarea(
        description='Help',
        layout=widgets.Layout(width='90%', height='500px', overflow_y='auto')
    )

    def on_change(change):
        """
        Refresh the help text area whenever the dropdown's selected value changes.

        :param change: Change dictionary supplied by the widget; only its `new` entry is read.
        """

        selected = change['new']

        # The dropdown is created without a selection and its value can be cleared again, in which case there is no
        # name to look up. Blank the help text instead of failing on `None.endswith`.
        if selected is None:
            help_text.value = ''
        elif selected.endswith('iterate_value_q_pi'):
            # Training functions are described by the help output of their argument parser.
            help_text.value = get_argument_parser_for_train_function(selected).format_help()
        else:
            help_text.value = get_help(selected)

    # Only listen for changes to the `value` trait, rather than receiving every trait change and filtering by hand.
    class_dropdown.observe(on_change, names='value')

    layout = widgets.VBox([
        class_dropdown,
        args,
        help_text
    ])

    return layout

# One entry per tab, in display order: the tab title followed by the `(display name, fully-qualified name)` options
# offered in that tab's dropdown.
tab_specifications = [
    ('Agent', [
        ('Stochastic MDP Agent', 'rlai.agents.mdp.StochasticMdpAgent')
    ]),
    ('Environment', [
        ("Gambler's Problem", 'rlai.environments.gamblers_problem.GamblersProblem'),
        ('Mancala', 'rlai.environments.mancala.Mancala'),
        ('OpenAI Gym', 'rlai.environments.openai_gym.Gym')
    ]),
    ('Learning Algorithm', [
        ('Monte Carlo Q-Value', 'rlai.gpi.monte_carlo.iteration.iterate_value_q_pi'),
        ('Temporal-Difference Q-Value', 'rlai.gpi.temporal_difference.iteration.iterate_value_q_pi')
    ]),
    ('Value Estimator', [
        ('Tabular State-Action Value Estimator', 'rlai.value_estimation.tabular.TabularStateActionValueEstimator'),
        ('Approximate State-Action Value Estimator', 'rlai.value_estimation.function_approximation.estimators.ApproximateStateActionValueEstimator')
    ]),
    ('Approximation Model', [
        ('Stochastic Gradient Descent', 'rlai.value_estimation.function_approximation.models.sklearn.SKLearnSGD')
    ]),
    ('Feature Extractor', [
        ('Cartpole Feature Extractor', 'rlai.environments.openai_gym.CartpoleFeatureExtractor')
    ])
]

# Pair each tab's widget layout with its title, preserving the order above.
layout_names = [
    (get_tab_children(names_classes=tab_options), tab_title)
    for tab_title, tab_options in tab_specifications
]

# Assemble the tabbed container: one child layout per entry of `layout_names`, titled in the same order.
tab_layouts, tab_titles = zip(*layout_names)
tab = widgets.Tab()
tab.children = list(tab_layouts)
for position, title in enumerate(tab_titles):
    tab.set_title(position, title)

# Render the tabs inside an Output widget, then display that widget as the cell's output.
output = widgets.Output()
with output:
    display(tab)

display(output)

Output()

# Agent Training

In [1]:


def train():
    """
    Run temporal-difference Q-value iteration (SARSA mode, `n_steps=1`) for a stochastic MDP agent on the gridworld
    returned by `Gridworld.example_4_1`, using a tabular state-action value estimator. All settings are fixed inside
    this function and the random state is seeded with 12345.
    """

    rng = RandomState(12345)
    gridworld: Gridworld = Gridworld.example_4_1(rng, None)

    # The same epsilon is given to the value estimator and to the iteration call below.
    exploration_epsilon = 0.05
    value_estimator = TabularStateActionValueEstimator(gridworld, exploration_epsilon, None)

    # The agent starts from the estimator's initial policy.
    # NOTE(review): the trailing positional `1` is presumably the discount factor -- confirm against StochasticMdpAgent.
    agent = StochasticMdpAgent('test', rng, value_estimator.get_initial_policy(), 1)

    iterate_value_q_pi(
        agent=agent,
        environment=gridworld,
        num_improvements=1000000,
        num_episodes_per_improvement=50,
        num_updates_per_improvement=None,
        alpha=0.1,
        mode=Mode.SARSA,
        n_steps=1,
        epsilon=exploration_epsilon,
        planning_environment=None,
        make_final_policy_greedy=True,
        q_S_A=value_estimator,
        num_improvements_per_plot=10
    )
    
# Start/stop control for this cell, rendered inside a fresh Output widget that is then displayed.
start_stop_button = widgets.Button(description='Start')
output = widgets.Output()
with output:
    display(start_stop_button)

display(output)

# Worker thread that will run `train`; it is only created here, not started.
train_t = Thread(target=train)

def animate(
    i,
    log_dir=None
):
    """
    Animation callback: refresh the policy iteration plot for one frame.

    Exceptions raised by the plot update are not propagated. Instead, the exception message and traceback are written
    to a file named `log_<i>_exception.txt`, so that one failed frame does not abort the animation.

    :param i: Frame value passed by the animation; used only to name the exception log file.
    :param log_dir: Directory in which to write exception logs. Defaults to the system temporary directory, so the
    callback works on any machine rather than depending on one user's desktop folder.
    """

    try:
        update_policy_iteration_plot()
    except Exception as ex:
        target_dir = tempfile.gettempdir() if log_dir is None else log_dir
        log_path = os.path.join(target_dir, f'log_{i}_exception.txt')
        with open(log_path, 'w') as f:
            f.write(f'{ex}')
            traceback.print_exc(file=f)

# Create the policy iteration figure with no data yet, rendering it inside the Output widget.
empty_plot_data = dict(
    iteration_average_reward=[],
    iteration_total_states=[],
    iteration_num_states_improved=[],
    elapsed_seconds_average_rewards={},
    pdf=None
)
with output:
    fig = plot_policy_iteration(**empty_plot_data)

# Drive `animate` from the figure's timer: 1000 frames at a 1000 ms interval, without repeating.
ani = FuncAnimation(fig, animate, frames=1000, interval=1000, repeat=False)

def start_stop_button_clicked(b):
    """
    Toggle between the started and stopped states when the start/stop button is clicked.

    When the button reads 'Start', the animation's event source is started, the training thread is launched if it has
    never been started, and the button is relabelled 'Stop'. Otherwise the animation's event source is stopped and the
    button is relabelled 'Start'. Stopping does not pause the training thread.

    :param b: Button instance passed by the click callback (unused; the module-level button is read and updated).
    """

    if start_stop_button.description == 'Start':

        ani.event_source.start()

        # A thread can be started at most once, and `ident` remains None until `start()` has been called. Checking
        # `is_alive()` instead would attempt a second `start()` (raising RuntimeError) whenever the training thread
        # had already finished or died before the button was clicked again.
        if train_t.ident is None:
            train_t.start()

        # todo:  pause/resume a running thread with Event passed into the iteration method

        start_stop_button.description = 'Stop'
    else:
        ani.event_source.stop()
        start_stop_button.description = 'Start'
    
# Register the toggle handler so that each click on the button calls `start_stop_button_clicked`.
start_stop_button.on_click(start_stop_button_clicked)


Output()