From 0d25e4a89f95cf4b9481933be76f7738f931ee45 Mon Sep 17 00:00:00 2001 From: Julien Danjou Date: Fri, 22 Dec 2017 10:39:14 +0100 Subject: [PATCH] Allow any validator to be compiled This allows any validator to be compiled by implementing the __voluptuous_compile__ method. This avoids having voluptuous.Any and voluptuous.All defining new Schema for sub-validators: they can be compiled recursively using the same parent schema. This solves the recursive Self case. Fixes #18 --- README.md | 15 ------- voluptuous/schema_builder.py | 2 + voluptuous/tests/tests.py | 49 +++++++++++++++++++++ voluptuous/validators.py | 85 +++++++++++++++++++++--------------- 4 files changed, 101 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index d4c38c8..8b6acac 100644 --- a/README.md +++ b/README.md @@ -455,21 +455,6 @@ True ``` -This only works if `Self` is used in the `Schema` directly. If you use `Any`, -`All` or `SomeOf`, this won't work as they compile their arguments down to a -new `Schema`. In that case, you can use an external reference: - -```pycon ->>> from voluptuous import Schema, Any ->>> def s2(v): -... return s1(v) -... ->>> s1 = Schema({"key": Any(s2, "value")}) ->>> s1({"key": {"key": "value"}}) -{'key': {'key': 'value'}} - -``` - ### Extending an existing Schema Often it comes handy to have a base `Schema` that is extended with more diff --git a/voluptuous/schema_builder.py b/voluptuous/schema_builder.py index 7a4e35f..dff3e55 100644 --- a/voluptuous/schema_builder.py +++ b/voluptuous/schema_builder.py @@ -276,6 +276,8 @@ def _compile(self, schema): return lambda _, v: v if schema is Self: return lambda p, v: self._compiled(p, v) + elif hasattr(schema, "__voluptuous_compile__"): + return schema.__voluptuous_compile__(self) if isinstance(schema, Object): return self._compile_object(schema) if isinstance(schema, collections.Mapping): diff --git a/voluptuous/tests/tests.py b/voluptuous/tests/tests.py index bd1e6c9..98a82ca 100644 --- a/voluptuous/tests/tests.py +++ b/voluptuous/tests/tests.py @@ -1085,6 +1085,55 @@ def test_self_validation(): schema({"follow": {"follow": {"number": 123456}}}) +def test_self_any(): + schema = Schema({"number": int, + "follow": Any(Self, "stop")}) + try: + schema({"number": "abc"}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + try: + schema({"follow": {"number": '123456.712'}}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + schema({"follow": {"number": 123456}}) + schema({"follow": {"follow": {"number": 123456}}}) + schema({"follow": {"follow": {"number": 123456, "follow": "stop"}}}) + + +def test_self_all(): + schema = Schema({"number": int, + "follow": All(Self, + Schema({"extra_number": int}, + extra=ALLOW_EXTRA))}, + extra=ALLOW_EXTRA) + try: + schema({"number": "abc"}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + try: + schema({"follow": {"number": '123456.712'}}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + schema({"follow": {"number": 123456}}) + schema({"follow": {"follow": {"number": 123456}}}) + schema({"follow": {"number": 123456, "extra_number": 123}}) + try: + schema({"follow": {"number": 123456, "extra_number": "123"}}) + except MultipleInvalid: + pass + else: + assert False, "Did not raise Invalid" + + def test_SomeOf_on_bounds_assertion(): with raises(AssertionError, 'when using "SomeOf" you should specify at least one of min_valid and max_valid'): SomeOf(validators=[]) diff --git a/voluptuous/validators.py b/voluptuous/validators.py index 138941a..c1448e2 100644 --- a/voluptuous/validators.py +++ b/voluptuous/validators.py @@ -181,7 +181,33 @@ def Boolean(v): return bool(v) -class Any(object): +class _WithSubValidators(object): + def __init__(self, *validators, **kwargs): + self.validators = validators + self.msg = kwargs.pop('msg', None) + + def __voluptuous_compile__(self, schema): + self._compiled = [ + schema._compile(v) + for v in self.validators + ] + return self._run + + def _run(self, path, value): + return self._exec(self._compiled, value, path) + + def __call__(self, v): + return self._exec((Schema(val) for val in self.validators), v) + + def __repr__(self): + return '%s(%s, msg=%r)' % ( + self.__class__.__name__, + ", ".join(repr(v) for v in self.validators), + self.msg + ) + + +class Any(_WithSubValidators): """Use the first validated value. :param msg: Message to deliver to user if validation fails. @@ -206,16 +232,14 @@ class Any(object): ... validate(4) """ - def __init__(self, *validators, **kwargs): - self.validators = validators - self.msg = kwargs.pop('msg', None) - self._schemas = [Schema(val, **kwargs) for val in validators] - - def __call__(self, v): + def _exec(self, funcs, v, path=None): error = None - for schema in self._schemas: + for func in funcs: try: - return schema(v) + if path is None: + return func(v) + else: + return func(path, v) except Invalid as e: if error is None or len(e.path) > len(error.path): error = e @@ -224,15 +248,12 @@ def __call__(self, v): raise error if self.msg is None else AnyInvalid(self.msg) raise AnyInvalid(self.msg or 'no valid value found') - def __repr__(self): - return 'Any([%s])' % (", ".join(repr(v) for v in self.validators)) - # Convenience alias Or = Any -class All(object): +class All(_WithSubValidators): """Value must pass all validators. The output of each validator is passed as input to the next. @@ -245,25 +266,17 @@ class All(object): 10 """ - def __init__(self, *validators, **kwargs): - self.validators = validators - self.msg = kwargs.pop('msg', None) - self._schemas = [Schema(val, **kwargs) for val in validators] - - def __call__(self, v): + def _exec(self, funcs, v, path=None): try: - for schema in self._schemas: - v = schema(v) + for func in funcs: + if path is None: + v = func(v) + else: + v = func(path, v) except Invalid as e: raise e if self.msg is None else AllInvalid(self.msg) return v - def __repr__(self): - return 'All(%s, msg=%r)' % ( - ", ".join(repr(v) for v in self.validators), - self.msg - ) - # Convenience alias And = All @@ -936,7 +949,7 @@ def _get_precision_scale(self, number): return (len(decimal_num.as_tuple().digits), -(decimal_num.as_tuple().exponent), decimal_num) -class SomeOf(object): +class SomeOf(_WithSubValidators): """Value must pass at least some validations, determined by the given parameter. Optionally, number of passed validations can be capped. @@ -965,19 +978,21 @@ def __init__(self, validators, min_valid=None, max_valid=None, **kwargs): 'when using "%s" you should specify at least one of min_valid and max_valid' % (type(self).__name__,) self.min_valid = min_valid or 0 self.max_valid = max_valid or len(validators) - self.validators = validators - self.msg = kwargs.pop('msg', None) - self._schemas = [Schema(val, **kwargs) for val in validators] + super(SomeOf, self).__init__(*validators, **kwargs) - def __call__(self, v): + def _exec(self, funcs, v, path=None): errors = [] - for schema in self._schemas: + funcs = list(funcs) + for func in funcs: try: - v = schema(v) + if path is None: + v = func(v) + else: + v = func(path, v) except Invalid as e: errors.append(e) - passed_count = len(self._schemas) - len(errors) + passed_count = len(funcs) - len(errors) if self.min_valid <= passed_count <= self.max_valid: return v