Skip to content

Commit

Permalink
added option_hook
Browse files Browse the repository at this point in the history
  • Loading branch information
Qwlouse committed Jan 5, 2017
1 parent e43aaaf commit b54edb3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 26 deletions.
93 changes: 67 additions & 26 deletions sacred/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sacred.ingredient import Ingredient
from sacred.initialize import create_run
from sacred.utils import print_filtered_stacktrace
from sacred.config.signature import Signature
__sacred__ = True # marks files that should be filtered from stack traces

__all__ = ('Experiment',)
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(self, name=None, ingredients=(), interactive=False):
self.command(print_dependencies, unobserved=True)
self.observers = []
self.current_run = None
self.option_hooks = []

# =========================== Decorators ==================================

Expand Down Expand Up @@ -125,12 +127,33 @@ def my_main():
self.run_commandline()
return captured

def option_hook(self, function):
"""
Decorator for adding an option hook function.
An option hook is a function that is called right before a run
is created. It receives (and potentially modifies) the options
dictionary. That is, the dictionary of commandline options used for
this run.
NOTE: The decorated function MUST have an argument called options.
NOTE: While the options still contain the COMMAND and UPDATE entries,
changing them has no effect. Only flags (starting with '--') can
be modified.
"""
sig = Signature(function)
if "options" not in sig.arguments:
raise KeyError("option_hook functions must have an argument called"
" 'options', but got {}".format(sig.arguments))
self.option_hooks.append(function)
return function

# =========================== Public Interface ============================

def run(self, command_name=None, config_updates=None, named_configs=(),
meta_info=None, options=None):
"""
Run the main function of the experiment.
Run the main function of the experiment or a given command.
Parameters
----------
Expand All @@ -154,32 +177,9 @@ def run(self, command_name=None, config_updates=None, named_configs=(),
sacred.run.Run
the Run object corresponding to the finished run
"""
command_name = command_name or self.default_command
if command_name is None:
raise RuntimeError('No command found to be run. Specify a command '
'or define a main function.')

default_options = self.get_default_options()
if options:
default_options.update(options)
options = default_options

run = create_run(self, command_name, config_updates,
named_configs=named_configs,
force=options.get(ForceOption.get_flag(), False))

if meta_info:
run.meta_info.update(meta_info)

self.current_run = run

for option in gather_command_line_options():
option_value = options.get(option.get_flag(), False)
if option_value:
option.apply(option_value, run)
run = self._create_run(command_name, config_updates, named_configs,
meta_info, options)
run()
self.current_run = None

return run

def run_command(self, command_name, config_updates=None,
Expand Down Expand Up @@ -216,6 +216,14 @@ def run_commandline(self, argv=None):
argv = sys.argv
elif isinstance(argv, basestring):
argv = shlex.split(argv)
else:
if not isinstance(argv, (list, tuple)):
raise ValueError("argv must be str or list, but was {}"
.format(type(argv)))
if not all([isinstance(a, basestring) for a in argv]):
problems = [a for a in argv if not isinstance(a, basestring)]
raise ValueError("argv must be list of str but contained the "
"following elements: {}".format(problems))

all_commands = self.gather_commands()

Expand Down Expand Up @@ -334,3 +342,36 @@ def get_default_options(self):
description=self.doc,
commands=OrderedDict(all_commands))
return {k: v for k, v in args.items() if k.startswith('--')}

# =========================== Internal Interface ==========================

def _create_run(self, command_name=None, config_updates=None,
named_configs=(), meta_info=None, options=None):
command_name = command_name or self.default_command
if command_name is None:
raise RuntimeError('No command found to be run. Specify a command '
'or define a main function.')

default_options = self.get_default_options()
if options:
default_options.update(options)
options = default_options

# call option hooks
for oh in self.option_hooks:
oh(options=options)

run = create_run(self, command_name, config_updates,
named_configs=named_configs,
force=options.get(ForceOption.get_flag(), False))

if meta_info:
run.meta_info.update(meta_info)

for option in gather_command_line_options():
option_value = options.get(option.get_flag(), False)
if option_value:
option.apply(option_value, run)

self.current_run = run
return run
20 changes: 20 additions & 0 deletions tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,23 @@ def run(a):

assert ex.run().result == 1
assert ex.run(named_configs=['ncfg']).result == 10


def test_adding_option_hooks(ex):
@ex.option_hook
def hook(options):
pass

@ex.option_hook
def hook2(options):
pass

assert hook in ex.option_hooks
assert hook2 in ex.option_hooks


def test_option_hooks_without_options_arg_raises(ex):
with pytest.raises(KeyError):
@ex.option_hook
def invalid_hook(wrong_arg_name):
pass

0 comments on commit b54edb3

Please sign in to comment.