Skip to content

Commit

Permalink
Merge c229931 into a181df5
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse committed Aug 2, 2019
2 parents a181df5 + c229931 commit fe8cd74
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 45 deletions.
20 changes: 8 additions & 12 deletions sacred/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sacred.ingredient import Ingredient
from sacred.initialize import create_run
from sacred.utils import print_filtered_stacktrace, ensure_wellformed_argv, \
SacredError, format_sacred_error, PathType
SacredError, format_sacred_error, PathType, join_paths

__all__ = ('Experiment',)

Expand Down Expand Up @@ -403,17 +403,13 @@ def log_scalar(self, name: str,
# The same as Run.log_scalar
self.current_run.log_scalar(name, value, step)

def _gather(self, func):
"""
Removes the experiment's path (prefix) from the names of the gathered
items. This means that, for example, 'experiment.print_config' becomes
'print_config'.
"""
for ingredient, _ in self.traverse_ingredients():
for name, item in func(ingredient):
if ingredient == self:
name = name[len(self.path) + 1:]
yield name, item
def post_process_name(self, name, ingredient):
if ingredient == self:
# Removes the experiment's path (prefix) from the names
# of the gathered items. This means that, for example,
# 'experiment.print_config' becomes 'print_config'.
return name[len(self.path) + 1:]
return name

def get_default_options(self):
"""Get a dictionary of default options as used with run.
Expand Down
48 changes: 15 additions & 33 deletions sacred/ingredient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,6 @@ def collect_repositories(sources):
for s in sources if s.repo]


@wrapt.decorator
def gather_from_ingredients(wrapped, instance=None, args=None, kwargs=None):
"""
Decorator that calls `_gather` on the instance the wrapped function is
bound to (should be an `Ingredient`) and yields from the returned
generator.
This function is necessary, because `Ingredient._gather` cannot directly be
used as a decorator inside of `Ingredient`.
"""
yield from instance._gather(wrapped)


class Ingredient:
"""
Ingredients are reusable parts of experiments.
Expand Down Expand Up @@ -285,21 +272,11 @@ def add_package_dependency(self, package_name, version):
raise ValueError('Invalid Version: "{}"'.format(version))
self.dependencies.add(PackageDependency(package_name, version))

def _gather(self, func):
"""
Function needed and used by gathering functions through the decorator
`gather_from_ingredients` in `Ingredient`. Don't use this function by
itself outside of the decorator!
def post_process_name(self, name, ingredient):
""" Can be overridden to change the command name."""
return name

By overwriting this function you can filter what is visible when
gathering something (e.g. commands). See `Experiment._gather` for an
example.
"""
for ingredient, _ in self.traverse_ingredients():
yield from func(ingredient)

@gather_from_ingredients
def gather_commands(self, ingredient):
def gather_commands(self):
"""Collect all commands from this ingredient and its sub-ingredients.
Yields
Expand All @@ -309,11 +286,13 @@ def gather_commands(self, ingredient):
cmd: function
The corresponding captured function.
"""
for command_name, command in ingredient.commands.items():
yield join_paths(ingredient.path, command_name), command
for ingredient, _ in self.traverse_ingredients():
for command_name, command in ingredient.commands.items():
cmd_name = join_paths(ingredient.path, command_name)
cmd_name = self.post_process_name(cmd_name, ingredient)
yield cmd_name, command

@gather_from_ingredients
def gather_named_configs(self, ingredient):
def gather_named_configs(self):
"""Collect all named configs from this ingredient and its
sub-ingredients.
Expand All @@ -324,8 +303,11 @@ def gather_named_configs(self, ingredient):
config: ConfigScope or ConfigDict or basestring
The corresponding named config.
"""
for config_name, config in ingredient.named_configs.items():
yield join_paths(ingredient.path, config_name), config
for ingredient, _ in self.traverse_ingredients():
for config_name, config in ingredient.named_configs.items():
config_name = join_paths(ingredient.path, config_name)
config_name = self.post_process_name(config_name, ingredient)
yield config_name, config

def get_experiment_info(self):
"""Get a dictionary with information about this experiment.
Expand Down

0 comments on commit fe8cd74

Please sign in to comment.