Skip to content

Commit

Permalink
first draft of a pre_run feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Qwlouse committed May 4, 2015
1 parent aa270ee commit 612ec04
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
27 changes: 27 additions & 0 deletions examples/pre_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals
from sacred import Experiment

ex = Experiment('hello_config_scope')


# A ConfigScope is a function like this decorated with @ex.config
# All local variables of this function will be put into the configuration
@ex.config
def cfg():
recipient = "world"
message = "Hello %s!" % recipient


@ex.pre_run
def foobar(command_name, config_updates, named_configs):
print('FOOOOBAR:', command_name)
config_updates['recipient'] = 'FOOO'
return command_name, config_updates, named_configs


# again we can access the message here by taking it as an argument
@ex.automain
def main(message):
print(message)
7 changes: 7 additions & 0 deletions sacred/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Ingredient(object):
def __init__(self, path, ingredients=(), _generate_seed=False,
_caller_globals=None):
self.path = path
self._pre_run = None
self.cfgs = []
self.named_configs = dict()
self.ingredients = list(ingredients)
Expand All @@ -49,6 +50,12 @@ def __init__(self, path, ingredients=(), _generate_seed=False,
self.current_run = None

# =========================== Decorators ==================================
def pre_run(self, func):
if self._pre_run is None:
self._pre_run = func
else:
raise RuntimeError('Can only have one pre_run!')

def command(self, function=None, prefix=None):
"""
Decorator to define a new command for this Ingredient or Experiment.
Expand Down
19 changes: 16 additions & 3 deletions sacred/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,7 @@ def initialize_logging(experiment, scaffolding, loglevel=None):
return root_logger.getChild(experiment.name)


def create_scaffolding(experiment):
sorted_ingredients = gather_ingredients_topological(experiment)
def create_scaffolding(experiment, sorted_ingredients):
scaffolding = OrderedDict()
for ingredient in sorted_ingredients:
scaffolding[ingredient] = Scaffold(
Expand Down Expand Up @@ -259,9 +258,23 @@ def get_command(scaffolding, command_path):
raise KeyError('Command "%s" not found' % command_name)


def execute_pre_runs(ingredients, command_name, config_updates, named_configs):
args = (command_name, config_updates, named_configs)
for ingred in ingredients:
if ingred._pre_run:
args = ingred._pre_run(*args)
return args


def create_run(experiment, command_name, config_updates=None, log_level=None,
named_configs=()):
scaffolding = create_scaffolding(experiment)

sorted_ingredients = gather_ingredients_topological(experiment)
scaffolding = create_scaffolding(experiment, sorted_ingredients)

command_name, config_updates, named_configs = \
execute_pre_runs(sorted_ingredients, command_name, config_updates,
named_configs)

distribute_config_updates(scaffolding, config_updates)
distribute_named_configs(scaffolding, named_configs)
Expand Down

0 comments on commit 612ec04

Please sign in to comment.