Skip to content
This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Commit

Permalink
extract functions display_all_presets_and_exit, expand_preset
Browse files Browse the repository at this point in the history
  • Loading branch information
zach-nervana committed Oct 23, 2018
1 parent 21f8ca3 commit 3714d8e
Showing 1 changed file with 39 additions and 26 deletions.
65 changes: 39 additions & 26 deletions rl_coach/coach.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,39 @@ def get_graph_manager_from_args(args: argparse.Namespace) -> 'GraphManager':
return graph_manager


def display_all_presets_and_exit():
# list available presets
if args.list:
screen.log_title("Available Presets:")
for preset in sorted(list_all_presets()):
print(preset)
sys.exit(0)

def expand_preset(preset):
if preset.lower() in [p.lower() for p in list_all_presets()]:
preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', preset))
else:
preset = "{}".format(preset)
# if a graph manager variable was not specified, try the default of :graph_manager
if len(preset.split(":")) == 1:
preset += ":graph_manager"

# verify that the preset exists
preset_path = preset.split(":")[0]
if not os.path.exists(preset_path):
screen.error("The given preset ({}) cannot be found.".format(preset))

# verify that the preset can be instantiated
try:
short_dynamic_import(preset, ignore_module_case=True)
except TypeError as e:
traceback.print_exc()
screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated."
.format(preset))

return preset


def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
"""
Parse the arguments that the user entered
Expand All @@ -103,38 +136,15 @@ def parse_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
# if no arg is given
if len(sys.argv) == 1:
parser.print_help()
exit(0)
sys.exit(0)

# list available presets
preset_names = list_all_presets()
if args.list:
screen.log_title("Available Presets:")
for preset in sorted(preset_names):
print(preset)
sys.exit(0)
display_all_presets_and_exit()

# replace a short preset name with the full path
if args.preset is not None:
if args.preset.lower() in [p.lower() for p in preset_names]:
args.preset = "{}.py:graph_manager".format(os.path.join(get_base_dir(), 'presets', args.preset))
else:
args.preset = "{}".format(args.preset)
# if a graph manager variable was not specified, try the default of :graph_manager
if len(args.preset.split(":")) == 1:
args.preset += ":graph_manager"

# verify that the preset exists
preset_path = args.preset.split(":")[0]
if not os.path.exists(preset_path):
screen.error("The given preset ({}) cannot be found.".format(args.preset))

# verify that the preset can be instantiated
try:
short_dynamic_import(args.preset, ignore_module_case=True)
except TypeError as e:
traceback.print_exc()
screen.error('Internal Error: ' + str(e) + "\n\nThe given preset ({}) cannot be instantiated."
.format(args.preset))
args.preset = expand_preset(args.preset)

# validate the checkpoints args
if args.checkpoint_restore_dir is not None and not os.path.exists(args.checkpoint_restore_dir):
Expand Down Expand Up @@ -179,6 +189,9 @@ def add_items_to_dict(target_dict, source_dict):


def open_dashboard(experiment_path):
"""
open X11 based dashboard in a new process (nonblocking)
"""
dashboard_path = 'python {}/dashboard.py'.format(get_base_dir())
cmd = "{} --experiment_dir {}".format(dashboard_path, experiment_path)
screen.log_title("Opening dashboard - experiment path: {}".format(experiment_path))
Expand Down

0 comments on commit 3714d8e

Please sign in to comment.