Skip to content

Commit

Permalink
Support using @ex.capture on methods.
Browse files Browse the repository at this point in the history
When constructing arguments for a bound method invocation, the first
argument of the signature ('self') should be ignored, as it is passed
automatically.
  • Loading branch information
Treora committed Jun 17, 2015
1 parent da0449e commit 9037b38
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
4 changes: 3 additions & 1 deletion sacred/config/captured_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def captured_function(wrapped, instance, args, kwargs):
options['_seed'] = get_seed(wrapped.rnd)
options['_rnd'] = create_rnd(options['_seed'])

args, kwargs = wrapped.signature.construct_arguments(args, kwargs, options)
bound = (instance is not None)
args, kwargs = wrapped.signature.construct_arguments(args, kwargs, options,
bound)
wrapped.logger.debug("Started")
start_time = time.time()
# =================== run actual function =================================
Expand Down
50 changes: 31 additions & 19 deletions sacred/config/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ def __init__(self, f):
self.positional_args = args[:len(args) - len(defaults)]
self.kwargs = OrderedDict(zip(args[-len(defaults):], defaults))

def get_free_parameters(self, args, kwargs):
return [a for a in self.arguments[len(args):] if a not in kwargs]
def get_free_parameters(self, args, kwargs, bound=False):
expected_args = self._get_expected_args(bound)
return [a for a in expected_args[len(args):] if a not in kwargs]

def construct_arguments(self, args, kwargs, options):
def construct_arguments(self, args, kwargs, options, bound=False):
"""
Construct args list and kwargs dictionary for this signature.
Expand All @@ -47,13 +48,15 @@ def construct_arguments(self, args, kwargs, options):
* conflicting values for a parameter in both args and kwargs
* there is an unfilled parameter at the end of this process
"""
self._assert_no_unexpected_args(args)
self._assert_no_unexpected_kwargs(kwargs)
self._assert_no_duplicate_args(args, kwargs)

args, kwargs = self._fill_in_options(args, kwargs, options)
expected_args = self._get_expected_args(bound)
self._assert_no_unexpected_args(expected_args, args)
self._assert_no_unexpected_kwargs(expected_args, kwargs)
self._assert_no_duplicate_args(expected_args, args, kwargs)

self._assert_no_missing_args(args, kwargs)
args, kwargs = self._fill_in_options(args, kwargs, options, bound)

self._assert_no_missing_args(args, kwargs, bound)
return args, kwargs

def __unicode__(self):
Expand All @@ -67,36 +70,45 @@ def __unicode__(self):
def __repr__(self):
return "<Signature at 0x{1:x} for '{0}'>".format(self.name, id(self))

def _assert_no_unexpected_args(self, args):
if not self.vararg_name and len(args) > len(self.arguments):
unexpected_args = args[len(self.arguments):]
def _get_expected_args(self, bound):
if bound:
# When called as instance method, the instance ('self') will be
# passed as first argument automatically, so the first argument
# should be excluded from the signature during this invocation.
return self.arguments[1:]
else:
return self.arguments

def _assert_no_unexpected_args(self, expected_args, args):
if not self.vararg_name and len(args) > len(expected_args):
unexpected_args = args[len(expected_args):]
raise TypeError("{} got unexpected argument(s): {}".format(
self.name, unexpected_args))

def _assert_no_unexpected_kwargs(self, kwargs):
def _assert_no_unexpected_kwargs(self, expected_args, kwargs):
if self.kw_wildcard_name:
return
unexpected_kwargs = set(kwargs) - set(self.arguments)
unexpected_kwargs = set(kwargs) - set(expected_args)
if unexpected_kwargs:
raise TypeError("{} got unexpected kwarg(s): {}".format(
self.name, sorted(unexpected_kwargs)))

def _assert_no_duplicate_args(self, args, kwargs):
positional_arguments = self.arguments[:len(args)]
def _assert_no_duplicate_args(self, expected_args, args, kwargs):
positional_arguments = expected_args[:len(args)]
duplicate_arguments = [v for v in positional_arguments if v in kwargs]
if duplicate_arguments:
raise TypeError("{} got multiple values for argument(s) {}".format(
self.name, duplicate_arguments))

def _fill_in_options(self, args, kwargs, options):
free_params = self.get_free_parameters(args, kwargs)
def _fill_in_options(self, args, kwargs, options, bound):
free_params = self.get_free_parameters(args, kwargs, bound)
for param in free_params:
if param in options:
kwargs[param] = options[param]
return args, kwargs

def _assert_no_missing_args(self, args, kwargs):
free_params = self.get_free_parameters(args, kwargs)
def _assert_no_missing_args(self, args, kwargs, bound):
free_params = self.get_free_parameters(args, kwargs, bound)
missing_args = [m for m in free_params if m not in self.kwargs]
if missing_args:
raise TypeError("{} is missing value(s) for {}".format(
Expand Down

0 comments on commit 9037b38

Please sign in to comment.