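Install (one-time setup, run from a terminal; needed for the interactive widgets and `%matplotlib widget` plots used below)
  * brew install node  (macOS/Homebrew; on other platforms install Node.js with your package manager)
  * jupyter labextension install @jupyter-widgets/jupyterlab-manager
  * jupyter labextension install jupyter-matplotlib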

In [1]:
%matplotlib widget

from rlai.runners.trainer import run
from rlai.agents.mdp import StochasticMdpAgent
from rlai.gpi.temporal_difference.evaluation import Mode
from rlai.value_estimation.tabular import TabularStateActionValueEstimator
from rlai.gpi.temporal_difference.iteration import iterate_value_q_pi
from numpy.random import RandomState
from rlai.environments.gridworld import Gridworld
import matplotlib.pyplot as plt
from rlai.gpi.utils import plot_policy_iteration, update_policy_iteration_plot
from matplotlib.animation import FuncAnimation
from threading import Thread
import traceback
import ipywidgets as widgets
from IPython.display import display
from rlai.utils import get_argument_parser
from rlai.runners.trainer import get_argument_parser_for_train_function
from argparse import RawTextHelpFormatter
import re
from os import listdir, path
import pickle
import shlex
from rlai.runners.trainer import get_argument_parser_for_run

# Directory in which saved command files (*.rlai) are written and from which they are listed and loaded.
COMMANDS_DIR = './commands/'

# Training

In [2]:
def get_and_trim_help(
    selected_identifier
):
    """
    Build a condensed summary of the command-line arguments accepted by the selected component.

    The component's argparse help is reduced to its optional-argument entries, entries whose name starts with one of
    the excluded names are dropped, each remaining entry is collapsed onto a single line, and the first upper-case
    token in an entry (normally the metavar) is displayed as VALUE.

    :param selected_identifier: Fully-qualified name of the selected runner, train function, or class, or None when
    nothing is selected.
    :return: '' when nothing is selected; '(no arguments)' when no entries remain; otherwise the remaining entries,
    each prefixed with '--' and separated by blank lines.
    """

    if selected_identifier is None:
        return ''

    # Arguments hidden from the summary. Several of them (--agent, --environment, --train-function, --q-S-A,
    # --function-approximation-model, --feature-extractor) are generated from the tab selections instead of being
    # typed by the user.
    excluded_args = {
        'help',
        'q-S-A',
        'function-approximation-model',
        'feature-extractor',
        'agent',
        'environment',
        'planning-environment',
        'train-function',
        'resume'
    }
    excluded_prefixes = tuple(excluded_args)

    if selected_identifier.endswith('.run'):
        parser = get_argument_parser_for_run()
    elif selected_identifier.endswith('.iterate_value_q_pi'):
        parser = get_argument_parser_for_train_function(selected_identifier)
    else:
        parser = get_argument_parser(selected_identifier)

    # Keep the help strings' own line breaks so that entries can be re-joined predictably below.
    parser.formatter_class = RawTextHelpFormatter
    lines = [line for line in parser.format_help().split('\n') if line]

    # argparse titles the optional-argument section 'optional arguments:' before Python 3.10 and 'options:' from
    # 3.10 on, and omits the section entirely when the parser defines no optional arguments. Fall back to an empty
    # listing instead of raising StopIteration when neither heading is present.
    section_headings = ('optional arguments:', 'options:')
    start_i = next(
        (i + 1 for i, line in enumerate(lines) if line.startswith(section_headings)),
        len(lines)
    )
    help_text = '\n'.join(lines[start_i:])

    # Each entry starts at a two-space-indented '--'. The text before the first one is an empty string that survives
    # the filter; it only contributes the leading separator that is stripped after the join.
    arg_entries = [
        re.sub(' [A-Z_0-9]+ ', ' VALUE ', ' '.join(s.strip() for s in arg_entry.split('\n')), count=1)
        for arg_entry in help_text.split('  --')
        if not arg_entry.startswith(excluded_prefixes)
    ]

    help_text = '\n\n--'.join(arg_entries).strip('\n')

    # collapse the column-alignment padding that argparse inserts
    help_text = re.sub(' +', ' ', help_text)

    if help_text == '':
        help_text = '(no arguments)'

    return help_text
    

def get_tab_children(
    names_classes
):
    """
    Assemble the widgets shown inside one configuration tab.

    :param names_classes: Iterable of (display name, fully-qualified identifier) pairs offered by the type dropdown.
    :return: A VBox whose children are, in order, the type dropdown, the arguments textarea, and the disabled help
    textarea.
    """

    # a blank first option (value None) represents "nothing selected"
    type_selector = widgets.Dropdown(
        options=[('', None), *sorted(names_classes)],
        description='Type',
        value=None
    )

    arguments_box = widgets.Textarea(
        description='Arguments',
        layout=widgets.Layout(width='90%', height='100px')
    )

    help_box = widgets.Textarea(
        description='Help',
        layout=widgets.Layout(width='90%', height='400px'),
        disabled=True
    )

    def refresh_help(change):
        # The handler is registered without a name filter, so ignore everything except changes of the value trait.
        if change['type'] != 'change' or change['name'] != 'value':
            return

        help_box.value = get_and_trim_help(change['new'])

    type_selector.observe(refresh_help)

    return widgets.VBox([
        type_selector,
        arguments_box,
        help_box
    ])

# (layout, tab title) pairs, built from a table of tab titles and the (display name, fully-qualified identifier)
# choices offered in each tab. Order matters: the first tab is the runner, and the remaining tabs line up with the
# command-line flags assembled in update_command_text.
tab_layouts_names = [
    (get_tab_children(names_classes=choices), title)
    for title, choices in (
        (
            'Runner',
            [
                ('Train', 'rlai.runners.trainer.run')
            ]
        ),
        (
            'Agent',
            [
                ('Stochastic MDP Agent', 'rlai.agents.mdp.StochasticMdpAgent')
            ]
        ),
        (
            'Environment',
            [
                ("Gambler's Problem", 'rlai.environments.gamblers_problem.GamblersProblem'),
                ('Mancala', 'rlai.environments.mancala.Mancala'),
                ('OpenAI Gym', 'rlai.environments.openai_gym.Gym'),
                ('Gridworld', 'rlai.environments.gridworld.Gridworld')
            ]
        ),
        (
            'Learning Algorithm',
            [
                ('Monte Carlo Q-Value', 'rlai.gpi.monte_carlo.iteration.iterate_value_q_pi'),
                ('Temporal-Difference Q-Value', 'rlai.gpi.temporal_difference.iteration.iterate_value_q_pi')
            ]
        ),
        (
            'Value Estimator',
            [
                ('Tabular State-Action Value Estimator', 'rlai.value_estimation.tabular.TabularStateActionValueEstimator'),
                ('Approximate State-Action Value Estimator', 'rlai.value_estimation.function_approximation.estimators.ApproximateStateActionValueEstimator')
            ]
        ),
        (
            'Approximation Model',
            [
                ('Stochastic Gradient Descent', 'rlai.value_estimation.function_approximation.models.sklearn.SKLearnSGD')
            ]
        ),
        (
            'Feature Extractor',
            [
                ('Cartpole Feature Extractor', 'rlai.environments.openai_gym.CartpoleFeatureExtractor'),
                ('Gridworld Feature Extractor', 'rlai.environments.gridworld.GridworldFeatureExtractor')
            ]
        )
    )
]

# One tab per entry of tab_layouts_names, titled in the same order as the layouts.
tab = widgets.Tab()
tab.children = [pair[0] for pair in tab_layouts_names]
for i, pair in enumerate(tab_layouts_names):
    tab.set_title(i, pair[1])

# Disabled (read-only) preview of the command assembled from the tabs.
command_text = widgets.Textarea(
    disabled=True,
    description='Command',
    layout=widgets.Layout(height='100px', width='90%')
)

def format_args(layout):
    """
    Collapse the multi-line text of a tab's arguments box into one space-separated line.

    :param layout: Tab layout whose second child holds the arguments text in its `value`.
    :return: The non-empty lines of that text joined with single spaces.
    """

    raw_text = layout.children[1].value
    non_empty_lines = [line for line in raw_text.split('\n') if line != '']

    return ' '.join(non_empty_lines)

def update_command_text(change):
    """
    Rebuild the command preview from the current tab selections and argument boxes.

    :param change: Change notification passed by the observed widget (unused).
    """

    # flags for the tabs that follow the runner tab, in tab order
    component_flags = (
        '--agent',
        '--environment',
        '--train-function',
        '--q-S-A',
        '--function-approximation-model',
        '--feature-extractor'
    )

    pieces = []
    for flag, layout in zip(component_flags, tab.children[1:]):  # first layout is the trainer
        selected = layout.children[0].value
        if selected is None:
            continue

        # strip removes the trailing space left when the tab has no arguments of its own
        pieces.append((f'{flag} {selected} ' + format_args(layout)).strip())

    component_args = ' '.join(pieces)

    # arguments typed into the runner tab go first
    runner_args = format_args(tab.children[0])
    if runner_args == '':
        full_args = component_args
    else:
        full_args = runner_args + ' ' + component_args

    if full_args == '':
        command_text.value = ''
    else:
        command_text.value = f'rlai train {full_args}'

# update the command text when any tab's type or arguments change
# (children[0] is the tab's type dropdown and children[1] its arguments textarea; see get_tab_children)
for layout in tab.children:
    layout.children[0].observe(update_command_text)
    layout.children[1].observe(update_command_text)

# Start/stop control. It is created disabled; command_text_changed sets its disabled state from the command text.
start_stop_button = widgets.Button(
    description='Start',
    disabled=True
)

def command_text_changed(change):
    """
    Enable the start/stop button only while the command text is non-empty.

    :param change: Change notification from the command textarea, with at least the keys 'type', 'name' and 'new'.
    """

    # The handler is registered without a name filter, so notifications for traits other than the text value can
    # arrive. Their 'new' payload is not the command text and must not drive the button state (this mirrors the
    # filtering done by on_change in get_tab_children).
    if change['type'] != 'change' or change['name'] != 'value':
        return

    start_stop_button.disabled = change['new'] == ''
    
# keep the start/stop button's enabled state in sync with the command text
command_text.observe(command_text_changed)

# load command
# (the dropdown starts without options; refresh_available_commands fills it with the saved command file names)
load_command_dropdown = widgets.Dropdown()

def refresh_available_commands():
    """
    Reload the options of the load-command dropdown from the saved command files in COMMANDS_DIR.

    Only entries ending in '.rlai' are offered, in sorted order. If COMMANDS_DIR does not exist (e.g., nothing has
    been saved yet), the dropdown is left without options instead of raising.
    """

    if path.isdir(COMMANDS_DIR):
        # listdir returns entries in arbitrary order; sort so the dropdown is stable across refreshes
        command_files = sorted(
            command_path
            for command_path in listdir(COMMANDS_DIR)
            if command_path.endswith('.rlai')
        )
    else:
        command_files = []

    load_command_dropdown.options = command_files
    
# populate the load dropdown once when the cell runs
refresh_available_commands()

# button that applies the command file selected in the dropdown (handler: load_command)
load_command_button = widgets.Button(
    description='Load Command'
)

def load_command(b):
    """
    Restore the type selection and arguments of the tabs from the command file selected in the load dropdown.

    Errors while reading or applying the file are reported with a printed message rather than raised.

    :param b: The clicked button (unused).
    """

    selected_file = load_command_dropdown.value
    if selected_file is None:
        return

    try:
        # path.join works whether or not COMMANDS_DIR ends with a path separator.
        # NOTE(security): pickle.load can execute arbitrary code contained in the file. Only load command files
        # that were saved from this notebook or that come from a trusted source.
        with open(path.join(COMMANDS_DIR, selected_file), 'rb') as f:
            command = pickle.load(f)

        # saved entries are (type value, arguments text) pairs, one per tab, in tab order
        for tab_values, layout in zip(command, tab.children):
            layout.children[0].value = tab_values[0]
            layout.children[1].value = tab_values[1]

    except Exception as ex:
        print(f'Error loading command:  {ex}')

load_command_button.on_click(load_command)

# save command
# (the text box supplies the file name, without the '.rlai' extension, used by save_command)
save_command_text = widgets.Text(
    description='Name'
)

save_command_button = widgets.Button(
    description='Save Command'
)

def save_command(b):
    """
    Save the current type selection and arguments of every tab to '<name>.rlai' in COMMANDS_DIR, then refresh the
    list of loadable commands.

    Nothing is saved when the name box is empty or contains only whitespace. The directory is created if missing.

    :param b: The clicked button (unused).
    """

    from os import makedirs

    # surrounding whitespace is not part of the intended file name
    command_name = (save_command_text.value or '').strip()
    if command_name == '':
        return

    # create the commands directory on first use so that saving works before any command file exists
    if COMMANDS_DIR:
        makedirs(COMMANDS_DIR, exist_ok=True)

    # path.join works whether or not COMMANDS_DIR ends with a path separator
    save_command_path = path.join(COMMANDS_DIR, f'{command_name}.rlai')
    with open(save_command_path, 'wb') as f:
        # one (type value, arguments text) pair per tab, in tab order; this is the structure load_command reads
        pickle.dump(
            [
                (
                    layout.children[0].value,
                    layout.children[1].value
                )
                for layout in tab.children
            ],
            f
        )

    refresh_available_commands()

save_command_button.on_click(save_command)

# top-level layout
# (stacked vertically: tabs, command preview, start/stop button, save row, load row)
top_level_layout = widgets.VBox([
    tab,
    command_text,
    start_stop_button,
    widgets.HBox([save_command_text, save_command_button]),
    widgets.HBox([load_command_dropdown, load_command_button])
])

# primary output
# (the layout is displayed inside the Output widget, which is itself displayed at the end of the cell)
output = widgets.Output()
with output:
    display(top_level_layout)

def train():
    """
    Run training with the arguments currently shown in the command text.
    """

    command_tokens = shlex.split(command_text.value)

    # the leading 'rlai' and 'train' tokens are not passed to run
    run(command_tokens[2:])
    
# Background thread that runs train(); it is created here but only started from start_stop_button_clicked.
train_t = Thread(target=train)
   
def animate(i):
    """
    Animation callback that refreshes the policy-iteration plot.

    :param i: Frame index supplied by the animation (unused).
    """

    update_policy_iteration_plot()

# Create the policy-iteration figure inside the output widget, starting from empty histories.
with output:
    fig = plot_policy_iteration(
        iteration_average_reward=[],
        iteration_total_states=[],
        iteration_num_states_improved=[],
        elapsed_seconds_average_rewards={},
        pdf=None
    )

# Drive the plot refresh: animate is called for up to 1000 frames at an interval of 1000 (milliseconds, per
# Matplotlib's FuncAnimation), without repeating. The start/stop button starts and stops its event source.
ani = FuncAnimation(
    fig, 
    animate, 
    frames=1000, 
    interval=1000,
    repeat=False
)

def start_stop_button_clicked(b):
    """
    Toggle between the 'Start' and 'Stop' states of the start/stop button.

    Starting resumes the plot animation and makes sure a training thread is running; stopping only halts the plot
    animation (the training thread itself keeps running).

    :param b: The clicked button (unused).
    """

    global train_t

    if start_stop_button.description == 'Start':

        ani.event_source.start()

        if train_t.is_alive():
            # todo:  pause thread with Event passed into the iteration method
            pass
        else:
            # A Thread object can be started at most once. If the previous training thread has already run to
            # completion (ident is set once a thread has been started), replace it with a fresh one instead of
            # calling start() again, which would raise RuntimeError.
            if train_t.ident is not None:
                train_t = Thread(target=train)

            train_t.start()

        start_stop_button.description = 'Stop'
    else:
        ani.event_source.stop()
        start_stop_button.description = 'Start'
    
start_stop_button.on_click(start_stop_button_clicked)
    
# show the Output widget that contains the controls and the plot
display(output)

Output()