Skip to content

Commit

Permalink
Merge pull request #89 from Treora/support_method_capture
Browse files Browse the repository at this point in the history
Support using @ex.capture on methods.
  • Loading branch information
Qwlouse committed Jun 17, 2015
2 parents da0449e + 8f30dac commit f1979f7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 21 deletions.
4 changes: 3 additions & 1 deletion sacred/config/captured_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def captured_function(wrapped, instance, args, kwargs):
options['_seed'] = get_seed(wrapped.rnd)
options['_rnd'] = create_rnd(options['_seed'])

args, kwargs = wrapped.signature.construct_arguments(args, kwargs, options)
bound = (instance is not None)
args, kwargs = wrapped.signature.construct_arguments(args, kwargs, options,
bound)
wrapped.logger.debug("Started")
start_time = time.time()
# =================== run actual function =================================
Expand Down
50 changes: 31 additions & 19 deletions sacred/config/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ def __init__(self, f):
self.positional_args = args[:len(args) - len(defaults)]
self.kwargs = OrderedDict(zip(args[-len(defaults):], defaults))

def get_free_parameters(self, args, kwargs):
return [a for a in self.arguments[len(args):] if a not in kwargs]
def get_free_parameters(self, args, kwargs, bound=False):
expected_args = self._get_expected_args(bound)
return [a for a in expected_args[len(args):] if a not in kwargs]

def construct_arguments(self, args, kwargs, options):
def construct_arguments(self, args, kwargs, options, bound=False):
"""
Construct args list and kwargs dictionary for this signature.
Expand All @@ -47,13 +48,15 @@ def construct_arguments(self, args, kwargs, options):
* conflicting values for a parameter in both args and kwargs
* there is an unfilled parameter at the end of this process
"""
self._assert_no_unexpected_args(args)
self._assert_no_unexpected_kwargs(kwargs)
self._assert_no_duplicate_args(args, kwargs)

args, kwargs = self._fill_in_options(args, kwargs, options)
expected_args = self._get_expected_args(bound)
self._assert_no_unexpected_args(expected_args, args)
self._assert_no_unexpected_kwargs(expected_args, kwargs)
self._assert_no_duplicate_args(expected_args, args, kwargs)

self._assert_no_missing_args(args, kwargs)
args, kwargs = self._fill_in_options(args, kwargs, options, bound)

self._assert_no_missing_args(args, kwargs, bound)
return args, kwargs

def __unicode__(self):
Expand All @@ -67,36 +70,45 @@ def __unicode__(self):
def __repr__(self):
return "<Signature at 0x{1:x} for '{0}'>".format(self.name, id(self))

def _assert_no_unexpected_args(self, args):
if not self.vararg_name and len(args) > len(self.arguments):
unexpected_args = args[len(self.arguments):]
def _get_expected_args(self, bound):
if bound:
# When called as instance method, the instance ('self') will be
# passed as first argument automatically, so the first argument
# should be excluded from the signature during this invocation.
return self.arguments[1:]
else:
return self.arguments

def _assert_no_unexpected_args(self, expected_args, args):
if not self.vararg_name and len(args) > len(expected_args):
unexpected_args = args[len(expected_args):]
raise TypeError("{} got unexpected argument(s): {}".format(
self.name, unexpected_args))

def _assert_no_unexpected_kwargs(self, kwargs):
def _assert_no_unexpected_kwargs(self, expected_args, kwargs):
if self.kw_wildcard_name:
return
unexpected_kwargs = set(kwargs) - set(self.arguments)
unexpected_kwargs = set(kwargs) - set(expected_args)
if unexpected_kwargs:
raise TypeError("{} got unexpected kwarg(s): {}".format(
self.name, sorted(unexpected_kwargs)))

def _assert_no_duplicate_args(self, args, kwargs):
positional_arguments = self.arguments[:len(args)]
def _assert_no_duplicate_args(self, expected_args, args, kwargs):
positional_arguments = expected_args[:len(args)]
duplicate_arguments = [v for v in positional_arguments if v in kwargs]
if duplicate_arguments:
raise TypeError("{} got multiple values for argument(s) {}".format(
self.name, duplicate_arguments))

def _fill_in_options(self, args, kwargs, options):
free_params = self.get_free_parameters(args, kwargs)
def _fill_in_options(self, args, kwargs, options, bound):
free_params = self.get_free_parameters(args, kwargs, bound)
for param in free_params:
if param in options:
kwargs[param] = options[param]
return args, kwargs

def _assert_no_missing_args(self, args, kwargs):
free_params = self.get_free_parameters(args, kwargs)
def _assert_no_missing_args(self, args, kwargs, bound):
free_params = self.get_free_parameters(args, kwargs, bound)
missing_args = [m for m in free_params if m not in self.kwargs]
if missing_args:
raise TypeError("{} is missing value(s) for {}".format(
Expand Down
13 changes: 12 additions & 1 deletion tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def onlykwrgs(**kwargs):
kwarg_list = [{}, {}, {'a': 1, 'b': 'fo', 'c': 9}, {'c': 3}, {}, {}, {}, {}]


class SomeClass:
def bla(self, a, b, c):
return a, b, c


# ####################### Tests #############################################

@pytest.mark.parametrize("function, name", zip(functions, names), ids=ids)
Expand Down Expand Up @@ -237,7 +242,6 @@ def test_construct_arguments_completes_kwargs_from_options():
args, kwargs = s.construct_arguments([2, 4], {}, {'c': 6})
assert args == [2, 4]
assert kwargs == {'c': 6}

s = Signature(complex_function_name)
args, kwargs = s.construct_arguments([], {'c': 6, 'b': 7}, {'a': 1})
assert args == []
Expand Down Expand Up @@ -302,6 +306,13 @@ def test_construct_arguments_does_not_raise_for_missing_defaults():
s.construct_arguments([], {}, {})


def test_construct_arguments_for_bound_method():
s = Signature(SomeClass.bla)
args, kwargs = s.construct_arguments([1], {'b': 2}, {'c': 3}, bound=True)
assert args == [1]
assert kwargs == {'b': 2, 'c': 3}


@pytest.mark.parametrize('func,expected', [
(foo, "foo()"),
(bariza, "bariza(a, b, c)"),
Expand Down

0 comments on commit f1979f7

Please sign in to comment.