# Graphical Interface to RLAI
This JupyterLab notebook provides a graphical interface to most of the functionality implemented in the `rlai` package. In essence, it is a tool for composing command-line interface (CLI) commands, which are then passed to `rlai` for processing (see [here](https://matthewgerber.github.io/rlai/cli_guide.html) for details of the CLI). The notebook adds a few features like pausing/resuming `rlai` execution, displaying interactive figures of training progress, and saving/loading commands to file. Beyond these features, the notebook does not add much to the core functionality of `rlai`. See the `rlai` [website](https://matthewgerber.github.io/rlai) for more information, including other ways to use `rlai` (e.g., incorporating it as a package into another project).

To get started, select Run -> Run All Cells from the menu. The interface supports the following command operations, organized from top to bottom.

1. Load:  A list of previously saved commands that can be loaded. This includes example commands provided by `rlai` as well as commands that you create and save.
1. Compose:  A sequence of tabs used to build a new `rlai` command or to edit one that you have loaded.
1. View:  A readout of the command. The command can either be executed directly in this notebook or copy/pasted into a terminal window.
1. Save:  Save the currently composed command to file.
1. Run:  Run the current command, with options for pausing, resumption, termination, and graphical display.

In [1]:
%matplotlib widget

from rlai.runners.trainer import run
from rlai.agents.mdp import StochasticMdpAgent
from rlai.gpi.temporal_difference.evaluation import Mode
from rlai.value_estimation.tabular import TabularStateActionValueEstimator
from rlai.gpi.temporal_difference.iteration import iterate_value_q_pi
from numpy.random import RandomState
from rlai.environments.gridworld import Gridworld
import matplotlib.pyplot as plt
from rlai.gpi.utils import plot_policy_iteration, update_policy_iteration_plot
from matplotlib.animation import FuncAnimation
from threading import Thread, Event
import traceback
import ipywidgets as widgets
from IPython.display import display
from rlai.utils import get_argument_parser
from rlai.runners.trainer import get_argument_parser_for_train_function
from argparse import RawTextHelpFormatter
import re
from os import listdir, path
import pickle
import shlex
import time
from rlai.runners.trainer import get_argument_parser_for_run
from rlai.utils import RunThreadManager

COMMANDS_DIR = './commands/'

def get_and_trim_help(
    selected_identifier
):
    """
    Build a condensed help string for the arguments accepted by the selected item.

    The argument parser for the identifier is formatted, everything up to and including the heading of
    the options section is discarded, and each remaining argument entry is collapsed onto one line with
    its first upper-case metavar replaced by VALUE. Arguments that are supplied by the tabs themselves
    (the excluded names below) are left out.

    :param selected_identifier: Fully-qualified name selected in a tab's dropdown, or None if nothing is
    selected.
    :return: Trimmed help text, '' if nothing is selected, or '(no arguments)' if there is nothing to show.
    """

    if selected_identifier is None:
        return ''

    # arguments that are composed by the interface itself and therefore not shown in the help readout.
    excluded_args = (
        'help',
        'q-S-A',
        'function-approximation-model',
        'feature-extractor',
        'agent',
        'environment',
        'planning-environment',
        'train-function',
        'resume'
    )

    if selected_identifier.endswith('.run'):
        parser = get_argument_parser_for_run()
    elif selected_identifier.endswith('.iterate_value_q_pi'):
        parser = get_argument_parser_for_train_function(selected_identifier)
    else:
        parser = get_argument_parser(selected_identifier)

    parser.formatter_class = RawTextHelpFormatter
    help_text = parser.format_help()

    lines = list(filter(None, help_text.split('\n')))

    # argparse titles the section 'optional arguments:' before python 3.10 and 'options:' from 3.10
    # onward, so accept either heading. a parser without any options has no such section at all.
    section_headings = ('optional arguments:', 'options:')
    start_i = next(
        (i + 1 for i, line in enumerate(lines) if line.startswith(section_headings)),
        None
    )
    if start_i is None:
        return '(no arguments)'

    help_text = '\n'.join(lines[start_i:])

    # each entry starts at a long option; join its wrapped lines and generalize the metavar.
    arg_entries = [
        re.sub(r' [A-Z_0-9]+ ', ' VALUE ', ' '.join(s.strip() for s in arg_entry.split('\n')), count=1)
        for arg_entry in help_text.split('  --')
        if not arg_entry.startswith(excluded_args)
    ]

    help_text = '\n\n--'.join(arg_entries).strip('\n')
    help_text = re.sub(r' +', ' ', help_text)

    if help_text == '':
        help_text = '(no arguments)'

    return help_text
    

def get_tab_children(
    names_classes
):
    """
    Build the contents of a single tab: a type dropdown, an arguments box, and a read-only help box.

    :param names_classes: Iterable of (display name, fully-qualified identifier) pairs for the dropdown.
    :return: VBox whose children are the dropdown, the arguments text area, and the help text area, in
    that order.
    """

    # a blank entry comes first so that nothing is selected initially; the choices follow in sorted order.
    dropdown_options = [('', None)]
    dropdown_options.extend(sorted(names_classes))

    type_selector = widgets.Dropdown(
        options=dropdown_options,
        description='Type',
        value=None,
        layout=widgets.Layout(width='90%')
    )

    arguments_box = widgets.Textarea(
        description='Arguments',
        layout=widgets.Layout(width='90%', height='250px')
    )

    help_box = widgets.Textarea(
        description='Help',
        layout=widgets.Layout(width='90%', height='250px'),
        disabled=True
    )

    def show_help_for_selection(change):
        # the handler is registered without a trait filter, so ignore everything but a new value.
        if change['type'] != 'change' or change['name'] != 'value':
            return

        help_box.value = get_and_trim_help(change['new'])

    type_selector.observe(show_help_for_selection)

    return widgets.VBox([type_selector, arguments_box, help_box])

# one (layout, title) pair per tab, in display order. each tab is described below by its title and
# the (display name, fully-qualified identifier) choices offered in its dropdown.
tab_layouts_names = [
    (get_tab_children(names_classes=choices), title)
    for title, choices in [
        (
            'Runner',
            [
                ('Train', 'rlai.runners.trainer.run')
            ]
        ),
        (
            'Agent',
            [
                ('Stochastic MDP Agent', 'rlai.agents.mdp.StochasticMdpAgent')
            ]
        ),
        (
            'Environment',
            [
                ("Gambler's Problem", 'rlai.environments.gamblers_problem.GamblersProblem'),
                ('Gridworld', 'rlai.environments.gridworld.Gridworld'),
                ('Mancala', 'rlai.environments.mancala.Mancala'),
                ('OpenAI Gym', 'rlai.environments.openai_gym.Gym'),
                ('Robocode', 'rlai.environments.robocode.RobocodeEnvironment')
            ]
        ),
        (
            'Learning Algorithm',
            [
                ('Monte Carlo Q-Value', 'rlai.gpi.monte_carlo.iteration.iterate_value_q_pi'),
                ('Temporal-Difference Q-Value', 'rlai.gpi.temporal_difference.iteration.iterate_value_q_pi')
            ]
        ),
        (
            'Value Estimator',
            [
                ('Approximate State-Action Value Estimator', 'rlai.value_estimation.function_approximation.estimators.ApproximateStateActionValueEstimator'),
                ('Tabular State-Action Value Estimator', 'rlai.value_estimation.tabular.TabularStateActionValueEstimator')
            ]
        ),
        (
            'Approximation Model',
            [
                ('Stochastic Gradient Descent', 'rlai.value_estimation.function_approximation.models.sklearn.SKLearnSGD')
            ]
        ),
        (
            'Feature Extractor',
            [
                ('Cartpole Feature Extractor', 'rlai.environments.openai_gym.CartpoleFeatureExtractor'),
                ('Gridworld Feature Extractor', 'rlai.environments.gridworld.GridworldFeatureExtractor'),
                ('Robocode Feature Extractor', 'rlai.environments.robocode.RobocodeFeatureExtractor')
            ]
        )
    ]
]

##############################
# initialize all ui controls #
##############################

# tab container with one child layout per entry of tab_layouts_names, titled in the same order.
tab = widgets.Tab()
tab.children = [tab_layout for tab_layout, _ in tab_layouts_names]
for i, (_, name) in enumerate(tab_layouts_names):
    tab.set_title(i, name)
    
# dropdown listing the saved command files. its options are filled in by refresh_available_commands.
load_command_dropdown = widgets.Dropdown(
    description='Available commands',
    layout=widgets.Layout(width='auto'),
    style={'description_width': 'auto'}
)

# loads the command selected in the dropdown into the tabs.
load_command_button = widgets.Button(
    description='Load'
)

# read-only readout of the composed command. it is written by update_command_text.
command_text = widgets.Textarea(
    description='Command',
    layout=widgets.Layout(width='90%', height='100px'),
    disabled=True
)

# name (without extension) under which the current command is saved.
save_command_text = widgets.Text(
    description='Name'
)

save_command_button = widgets.Button(
    description='Save'
)

# a single button that cycles through Start / Pause / Resume via its description. it starts out
# disabled and is enabled once there is command text.
start_pause_resume_button = widgets.Button(
    description='Start',
    disabled=True
)

# disabled until a training run has been started.
terminate_button = widgets.Button(
    description='Terminate',
    disabled=True
)

#############################
# ui control event handlers #
#############################

# refresh commands available in folder
def refresh_available_commands():
    """
    Repopulate the load dropdown with the `.rlai` command files found in COMMANDS_DIR.

    File names are offered in sorted order so that the list (and the initially selected entry) does
    not depend on the order in which the operating system happens to list the directory. A missing
    commands directory is treated as containing no saved commands instead of raising.
    """

    try:
        file_names = listdir(COMMANDS_DIR)
    except FileNotFoundError:
        file_names = []

    load_command_dropdown.options = sorted(
        file_name
        for file_name in file_names
        if file_name.endswith('.rlai')
    )

refresh_available_commands()

# load selected command
def load_command(b):
    """
    Populate the tabs from the saved command that is currently selected in the load dropdown.

    Any error while reading or applying the file is reported with a printed message rather than raised.

    :param b: Button that was clicked (unused).
    """

    selected_file = load_command_dropdown.value
    if selected_file is None:
        return

    try:
        # NOTE: command files are pickles, and unpickling can execute arbitrary code. only load
        # command files that come from a trusted source.
        with open(f'{COMMANDS_DIR}{selected_file}', 'rb') as f:
            saved_tabs = pickle.load(f)

        # saved entries are stored in tab order, each as (dropdown selection, argument text).
        for saved_tab, tab_layout in zip(saved_tabs, tab.children):
            tab_layout.children[0].value = saved_tab[0]
            tab_layout.children[1].value = saved_tab[1]

    except Exception as ex:
        print(f'Error loading command:  {ex}')

load_command_button.on_click(load_command)

# update the command text when any tab's dropdown or arguments change
def format_args(layout):
    """
    Collapse the contents of a tab's arguments box onto a single line.

    :param layout: Tab layout whose second child holds the argument text in its `value`.
    :return: The non-empty lines of the argument text, joined with single spaces.
    """

    argument_lines = layout.children[1].value.split('\n')

    return ' '.join(filter(None, argument_lines))

def update_command_text(change):
    """
    Rebuild the command readout from the current state of every tab.

    :param change: Change notification from the observed widget (unused).
    """

    # cli option contributed by each tab after the first, in tab order. the first tab (the runner)
    # has no option of its own and is handled separately below.
    tab_options = [
        '--agent',
        '--environment',
        '--train-function',
        '--q-S-A',
        '--function-approximation-model',
        '--feature-extractor'
    ]

    option_clauses = []
    for option, tab_layout in zip(tab_options, tab.children[1:]):

        # tabs without a dropdown selection contribute nothing to the command.
        selection = tab_layout.children[0].value
        if selection is None:
            continue

        option_clauses.append((f'{option} {selection} ' + format_args(tab_layout)).strip())

    args = ' '.join(option_clauses)

    # runner arguments from the first tab go in front of everything else.
    runner_args = format_args(tab.children[0])
    if runner_args != '':
        args = runner_args + ' ' + args

    if args == '':
        command_text.value = ''
    else:
        command_text.value = f'rlai train {args}'

# recompute the command whenever a tab's dropdown (first child) or arguments box (second child) changes.
for layout in tab.children:
    layout.children[0].observe(update_command_text)
    layout.children[1].observe(update_command_text)

# enable the start button when there is command text
def command_text_changed(change):
    """
    Enable the start button exactly when the command readout is non-empty.

    The handler is registered without a trait filter, so it is also invoked for traits other than
    the text itself. Those notifications are ignored; otherwise their (non-empty) new value would
    enable the start button even though no command has been composed.

    :param change: Change notification from the command text widget.
    """

    if change['type'] == 'change' and change['name'] == 'value':
        start_pause_resume_button.disabled = change['new'] == ''
    
command_text.observe(command_text_changed)
    
# save current command
def save_command(b):
    """
    Pickle the state of every tab to `<name>.rlai` in COMMANDS_DIR and refresh the load dropdown.

    Nothing happens if no name has been entered.

    :param b: Button that was clicked (unused).
    """

    command_name = save_command_text.value
    if command_name is None or command_name == '':
        return

    with open(f'{COMMANDS_DIR}{command_name}.rlai', 'wb') as f:

        # one (dropdown selection, argument text) pair per tab, in tab order.
        tab_states = [
            (tab_layout.children[0].value, tab_layout.children[1].value)
            for tab_layout in tab.children
        ]
        pickle.dump(tab_states, f)

    refresh_available_commands()

save_command_button.on_click(save_command)

###################
# display layouts #
###################

# construct and display top-level layout. rows from top to bottom:  load, compose (tabs), view
# (command readout), save, and run controls.
top_level_layout = widgets.VBox([
    widgets.HBox([load_command_dropdown, load_command_button]),
    tab,
    command_text,
    widgets.HBox([save_command_text, save_command_button]),
    widgets.HBox([start_pause_resume_button, terminate_button])
])

# the layout is rendered into its own output widget, which is displayed at the end of the cell.
widget_output = widgets.Output()

with widget_output:
    display(top_level_layout)

##########################
# primary thread objects #
##########################

# both are assigned when the start button is clicked and a training run begins.
train_thread = None
thread_manager = None

################################
# set up policy iteration plot #
################################

# construct initial plot to animate and display it in its own output widget
policy_iteration_plot_output = widgets.Output()

# the plot starts out with no data.
with policy_iteration_plot_output:
    policy_iteration_fig = plot_policy_iteration(
        iteration_average_reward=[],
        iteration_total_states=[],
        iteration_num_states_improved=[],
        elapsed_seconds_average_rewards={},
        pdf=None
    )

# placeholder so that the animation function can reference the animator before it is created below.
policy_iteration_plot_animator = None

# animation for policy iteration plot
def animate_policy_iteration_fig(
    i
):
    """
    Refresh the policy iteration plot, and stop the animator once there is no running training thread.

    :param i: Frame value supplied by the animator (unused).
    """

    update_policy_iteration_plot()

    # the animator might not be assigned right away, depending on whether FuncAnimation waits an
    # interval before calling the current function. nothing can be stopped in that case.
    animator = policy_iteration_plot_animator
    if animator is None:
        return

    # there is nothing to animate if the training thread is not yet initialized (start hasn't been
    # clicked) or if the thread is no longer running (it finished).
    training_is_running = train_thread is not None and train_thread.is_alive()
    if not training_is_running:
        animator.event_source.stop()

# refresh the plot every 5 seconds. the frame count is simply large enough to never run out.
policy_iteration_plot_animator = FuncAnimation(
    policy_iteration_fig,
    animate_policy_iteration_fig,
    frames=100_000_000_000,
    interval=5000,
    repeat=False
)

#########################
# set up estimator plot #
#########################

# the pattern is slightly different here, because the estimator and its associated plotting
# figure don't become available until after the training thread has started and kicked them
# back to us. initialize all to None.

# all three are assigned by the start button handler once a training run is underway.
estimator = None
estimator_fig = None
estimator_plot_animator = None

# create slider to select the estimator iteration to display. the range starts out as the single
# value zero, and the maximum is raised by the animation function.
estimator_plot_slider = widgets.IntSlider(value=0, min=0, max=0, step=1, description='Iteration')

# display estimator plot in its own output
estimator_plot_output = widgets.Output()

# animation for estimator plot
def animate_estimator_fig(
    i
):
    """
    Redraw the estimator plot for the iteration selected on the slider, extend the slider's range, and
    stop the animator once there is no running training thread.

    :param i: Frame value supplied by the animator (unused).
    """

    # by the time the current function has been called by the animator, the estimator and its
    # figure will have been initialized, so there's no need to check them for None. the slider's
    # minimum value will always be zero, so subtracting 1 will pass -1 when the slider is at its
    # minimum. per the documentation, passing -1 will update the plot to show the most recent
    # iteration.
    iteration_to_show = estimator_plot_slider.value - 1
    estimator.update_plot(iteration_to_show)

    # update maximum value of slider. to make this consistent with the previous call, we set the
    # max to be the previous iteration value. if the environment is very slow to complete
    # evaluations (e.g., robocode), it's possible that no evaluations will have been completed by
    # the time the current function is called. we can't set max < min, so default max to zero if
    # we haven't yet finished an evaluation.
    previous_iteration = estimator.evaluation_policy_improvement_count - 1
    estimator_plot_slider.max = max(previous_iteration, 0)

    # the animator might not be assigned right away, depending on how FuncAnimation works.
    animator = estimator_plot_animator
    if animator is None:
        return

    # stop animator if the training thread is not yet initialized (start hasn't been clicked)
    # or if the thread is no longer running (it finished).
    training_is_running = train_thread is not None and train_thread.is_alive()
    if not training_is_running:
        animator.event_source.stop()

# update the estimator plot when the slider is changed
def estimator_plot_slider_changed(change):
    """
    Redraw the estimator plot when the user moves the slider to a different iteration.

    The handler is registered without a trait filter, so it is also invoked when the animation
    function raises the slider's maximum. Only a new slider value calls for a redraw; reacting to the
    other notifications would re-enter the plot update from within itself.

    :param change: Change notification from the slider.
    """

    if change['type'] != 'change' or change['name'] != 'value':
        return

    # there is nothing to draw until a training run has provided the estimator.
    if estimator is not None:
        animate_estimator_fig(None)

estimator_plot_slider.observe(estimator_plot_slider_changed)

##########################
# set up training thread #
##########################

# arguments reported back by the training run, and the event that signals their arrival.
train_args = None
train_args_wait = Event()

def train_args_callback(args):
    """
    Receive the training arguments from the training run and signal that they are available.

    :param args: Training arguments. The start button handler reads the 'q_S_A' entry from them.
    """
    
    global train_args
    train_args = args
    
    # let the main thread know that the training arguments have arrived from the 
    # training thread. the arguments are stored before the event is set so that a waiter
    # never observes the event without them.
    train_args_wait.set()

def train(thread_manager):
    """
    Body of the training thread: run the composed command and restore the run controls afterwards.

    Failures are reported with a traceback instead of being silently discarded. Whatever the outcome,
    the run control buttons are reset and the argument event is set, so that a caller waiting for the
    training arguments is released even if the run failed before reporting them.

    :param thread_manager: Run thread manager passed through to the runner for pause/resume/abort.
    """

    global train_args

    # discard arguments left over from an earlier run, so that a run that fails before reporting its
    # arguments cannot be mistaken for one that started successfully.
    train_args = None

    try:
        # the first two arguments will be 'rlai train', which aren't passed to run.
        args = shlex.split(command_text.value)[2:]
        run(args, thread_manager, train_args_callback)
    except Exception:
        # show the user why training stopped. exits requested via SystemExit (e.g., by argparse
        # upon invalid arguments, after it has printed its own message) are not exceptions of this
        # kind; they end the thread quietly after the cleanup below.
        traceback.print_exc()
    finally:
        # training finished. reset run control buttons. the animators will stop themselves
        # the next time they check the state of the thread.
        start_pause_resume_button.description = 'Start'
        start_pause_resume_button.disabled = False
        terminate_button.disabled = True

        # release anyone still waiting for the training arguments. this has no effect if the
        # arguments were already reported.
        train_args_wait.set()

# hook up start and terminate buttons
def start_pause_resume_button_clicked(b):
    """
    Handle a click on the combined start/pause/resume button, dispatching on its current description.

    Start:  launch a training thread for the composed command, wait for it to report its training
    arguments, and set up the estimator plot (if the estimator provides one).
    Pause:  clear the thread manager and stop the plot animators.
    Resume:  set the thread manager and restart the plot animators.

    :param b: Button that was clicked (unused).
    """
    
    global train_thread
    global thread_manager
    global train_args
    global estimator
    global estimator_fig
    global estimator_plot_animator
    
    if start_pause_resume_button.description == 'Start':
        
        # reset all to None, as we're about to begin a new training run.
        estimator = None
        estimator_fig = None
        estimator_plot_animator = None
        estimator_plot_slider.value = estimator_plot_slider.max = 0
        
        # forget the arguments of any earlier run, then start the thread.
        train_args = None
        train_args_wait.clear()
        thread_manager = RunThreadManager(True)
        train_thread = Thread(target=train, args=(thread_manager,))
        train_thread.start()
        
        # start the policy iteration animator. it's already been configured above.
        policy_iteration_plot_animator.event_source.start()
        
        # wait for training arguments to come back. the wait is done in short slices so that we
        # notice if the training thread ends without ever reporting them (e.g., because the
        # command is invalid); waiting unconditionally would block the interface forever.
        while not train_args_wait.wait(timeout=0.1):
            if not train_thread.is_alive():
                break

        # the run ended before it got going. the training thread has already reset the run
        # control buttons, and the animator stops itself on its next frame.
        if train_args is None:
            return
    
        estimator = train_args['q_S_A']
        
        # clear the plot's display widget and create the initial estimator plot
        estimator_plot_output.clear_output()
        with estimator_plot_output:

            # hack the `final` parameter to force the plot to go through. we can't rely on the
            # rendering schedule to be in any particular state, since it's off in another thread.
            estimator_fig = estimator.plot(True, None)  

        # not all estimators return a plot (e.g., tabular currently does not)
        if estimator_fig is not None:
            
            # display slider
            with estimator_plot_output:
                display(estimator_plot_slider)
            
            estimator_plot_animator = FuncAnimation(
                estimator_fig, 
                animate_estimator_fig, 
                frames=100000000000, 
                interval=5000,
                repeat=False
            )
        
        # only offer pause/terminate while the run is still going. a run that has already finished
        # has reset the buttons itself, and overriding that would leave the button stuck on 'Pause'.
        if train_thread.is_alive():
            start_pause_resume_button.description = 'Pause'
            terminate_button.disabled = False

    elif start_pause_resume_button.description == 'Pause':
        
        thread_manager.clear()
        policy_iteration_plot_animator.event_source.stop()
        
        # we won't have an animator if the estimator didn't return a plot
        if estimator_plot_animator is not None:
            estimator_plot_animator.event_source.stop()
            
        start_pause_resume_button.description = 'Resume'
        
    elif start_pause_resume_button.description == 'Resume':
        
        thread_manager.set()
        policy_iteration_plot_animator.event_source.start()
        
        # we won't have an animator if the estimator didn't return a plot
        if estimator_plot_animator is not None:
            estimator_plot_animator.event_source.start()
            
        start_pause_resume_button.description = 'Pause'
    
start_pause_resume_button.on_click(start_pause_resume_button_clicked)

def terminate_button_clicked(b):
    """
    Ask the running training thread to abort, and disable the run controls.

    :param b: Button that was clicked (unused).
    """
    # flag the abort first, then set the manager (the same call used to resume), presumably so
    # that a paused run wakes up and sees the flag -- TODO confirm against RunThreadManager.
    thread_manager.abort = True
    thread_manager.set()

    # the training thread re-enables the start button when it finishes.
    terminate_button.disabled = True
    start_pause_resume_button.disabled = True
    
terminate_button.on_click(terminate_button_clicked)
    
# show the interface, followed by the two plot areas.
display(widget_output)
display(policy_iteration_plot_output)
display(estimator_plot_output)

Output()

Output()

Output()