Skip to content

Commit

Permalink
Simple predicate rewriting
Browse files Browse the repository at this point in the history
This is a minimal proof-of-concept for predicate rewriting, sufficient to demonstrate that it works without breaking anything else.  Equally importantly, it adds clear hooks for future work to allow open collaboration.
  • Loading branch information
Zac-HD committed Feb 8, 2021
1 parent e9ac4f5 commit 1767cbb
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 4 deletions.
11 changes: 11 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,11 @@
RELEASE_TYPE: patch

This release lays the groundwork for automatic rewriting of simple filters,
for example converting ``integers().filter(lambda x: x > 9)`` to
``integers(min_value=10)``.

Note that this is **not supported yet**, and we will continue to recommend
writing the efficient form directly wherever possible - predicate rewriting
is provided mainly for the benefit of downstream libraries which would
otherwise have to implement it for themselves (e.g. :pypi:`pandera` and
:pypi:`icontract-hypothesis`). See :issue:`2701` for details.
108 changes: 108 additions & 0 deletions hypothesis-python/src/hypothesis/internal/filtering.py
@@ -0,0 +1,108 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Most of this work is copyright (C) 2013-2021 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# CONTRIBUTING.rst for a full list of people who may hold copyright, and
# consult the git log if you need to determine who owns an individual
# contribution.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
#
# END HEADER

"""Tools for understanding predicates, to satisfy them by construction.
For example::
integers().filter(lambda x: x >= 0) -> integers(min_value=0)
This is intractable in general, but reasonably easy for simple cases involving
numeric bounds, strings with length or regex constraints, and collection lengths -
and those are precisely the most common cases. When they arise in e.g. Pandas
dataframes, it's also pretty painful to do the constructive version by hand in
a library; so we prefer to share all the implementation effort here.
See https://github.com/HypothesisWorks/hypothesis/issues/2701 for details.
"""

import operator
from decimal import Decimal
from fractions import Fraction
from functools import partial
from typing import Any, Callable, Mapping, Optional, Tuple, TypeVar

from hypothesis.internal.compat import ceil, floor

Ex = TypeVar("Ex")
Predicate = Callable[[Ex], bool]

ConstructivePredicate = Tuple[Mapping[str, Any], Optional[Predicate]]
"""Return kwargs to the appropriate strategy, and the predicate if needed.
For example::
integers().filter(lambda x: x >= 0)
-> {"min_value": 0}, None
integers().filter(lambda x: x >= 0 and x % 7)
-> {"min_value": 0}, lambda x: x % 7
At least in principle - for now we usually return the predicate unchanged
if needed.
We have a separate get-predicate frontend for each "group" of strategies; e.g.
for each numeric type, for strings, for bytes, for collection sizes, etc.
"""


def get_numeric_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
    """Shared logic for understanding numeric bounds.

    Only ``functools.partial(op, arg)`` objects are understood, where ``op``
    is one of ``operator.lt``, ``le``, ``eq``, ``ge`` or ``gt`` and ``arg`` is
    a single positional int, float, Fraction or Decimal.  Every other
    predicate is handed back untouched.

    Returns a ``(kwargs, predicate)`` pair.  ``kwargs`` may contain
    ``min_value``, ``max_value``, ``exclude_min`` and ``exclude_max``.  The
    second element is ``None`` when the bounds fully express the predicate,
    and the original predicate when it still has to be applied as a filter.

    We then specialise this in the other functions below, to ensure that e.g.
    all the values are representable in the types that we're planning to
    generate so that the strategy validation doesn't complain.
    """
    if (
        type(predicate) is partial
        and len(predicate.args) == 1
        and not predicate.keywords
    ):
        arg = predicate.args[0]
        if (
            # Signalling NaNs raise on comparison, so check for them first and
            # never compare them ourselves.
            (isinstance(arg, Decimal) and Decimal.is_snan(arg))
            or not isinstance(arg, (int, float, Fraction, Decimal))
            # NaN is unordered, so it cannot be expressed as a bound.  Leave
            # the predicate to be evaluated as an ordinary filter.  ``x != x``
            # is true only for NaN and, unlike ``math.isnan``, cannot overflow
            # on huge integers.
            or arg != arg
        ):
            return {}, predicate
        options = {
            # We're talking about op(arg, x) - the reverse of our usual intuition!
            operator.lt: {"min_value": arg, "exclude_min": True},  # lambda x: arg < x
            operator.le: {"min_value": arg},  # lambda x: arg <= x
            operator.eq: {"min_value": arg, "max_value": arg},  # lambda x: arg == x
            operator.ge: {"max_value": arg},  # lambda x: arg >= x
            operator.gt: {"max_value": arg, "exclude_max": True},  # lambda x: arg > x
        }
        if predicate.func in options:
            return options[predicate.func], None

    # TODO: handle lambdas by AST analysis

    return {}, predicate


def get_integer_predicate_bounds(predicate: Predicate) -> ConstructivePredicate:
    """Specialise the generic numeric bounds of ``predicate`` for integers.

    Returns ``(kwargs, predicate)`` where ``kwargs`` contains at most
    ``min_value`` and ``max_value``, both inclusive and always ``int``.
    Fractional bounds are rounded inwards and exclusive integral bounds are
    stepped by one.  The second element is whatever part of the predicate
    still has to be applied as a filter (``None`` if nothing is left).

    Bounds that cannot be converted to an integer are handled explicitly
    instead of crashing in ``int()``:

    * an infinite bound on the side where every integer passes (e.g.
      ``-inf < x``) is simply dropped;
    * a NaN bound, or an infinite bound that no integer can satisfy (e.g.
      ``inf <= x``), leaves the original predicate in place as a plain filter.
    """
    kwargs, remaining = get_numeric_predicate_bounds(predicate)

    infinity = float("inf")
    bounds = {}
    # Each row: bound key, its exclusivity flag, how to round a fractional
    # bound inwards, the step past an excluded integral bound, and the
    # infinity for which the bound accepts every integer.
    for key, exclude_key, round_inwards, step, vacuous in (
        ("min_value", "exclude_min", ceil, 1, -infinity),
        ("max_value", "exclude_max", floor, -1, infinity),
    ):
        if key not in kwargs:
            continue
        value = kwargs[key]
        if value != value or value == -vacuous:
            # NaN, or an infinity on the impossible side: don't try to build
            # bounds, just keep the predicate as an ordinary filter.
            return {}, predicate
        if value == vacuous:
            # Every integer already satisfies this bound.
            continue
        if value != int(value):
            bounds[key] = round_inwards(value)
        elif kwargs.get(exclude_key, False):
            bounds[key] = int(value) + step
        else:
            # Normalise integral floats/Decimals/Fractions to a real int.
            bounds[key] = int(value)

    return bounds, remaining
22 changes: 20 additions & 2 deletions hypothesis-python/src/hypothesis/strategies/_internal/lazy.py
Expand Up @@ -20,6 +20,7 @@
arg_string,
convert_keyword_arguments,
convert_positional_arguments,
get_pretty_function_description,
)
from hypothesis.strategies._internal.strategies import SearchStrategy

Expand Down Expand Up @@ -63,20 +64,25 @@ def unwrap_strategies(s):
assert unwrap_depth >= 0


def _repr_filter(condition):
    """Render ``condition`` as a ``.filter(...)`` suffix for a strategy repr."""
    description = get_pretty_function_description(condition)
    return f".filter({description})"


class LazyStrategy(SearchStrategy):
"""A strategy which is defined purely by conversion to and from another
strategy.
Its parameter and distribution come from that other strategy.
"""

def __init__(self, function, args, kwargs, filters=(), *, force_repr=None):
    """Record how to build the wrapped strategy without building it yet.

    ``function`` is stored together with the ``args`` and ``kwargs`` it is
    to be called with.  ``filters`` holds predicates that are applied, in
    order, to the strategy once it has been built.  ``force_repr``, if
    given, is used as the cached representation of this strategy.
    """
    SearchStrategy.__init__(self)
    self.__wrapped_strategy = None
    self.__representation = force_repr
    self.function = function
    self.__args = args
    self.__kwargs = kwargs
    # Stored as a tuple so that further filters can be appended with ``+``
    # whatever kind of iterable the caller passed in.
    self.__filters = tuple(filters)

@property
def supports_find(self):
Expand Down Expand Up @@ -110,8 +116,19 @@ def wrapped_strategy(self):
self.__wrapped_strategy = self.function(
*unwrapped_args, **unwrapped_kwargs
)
for f in self.__filters:
self.__wrapped_strategy = self.__wrapped_strategy.filter(f)
return self.__wrapped_strategy

def filter(self, condition):
    """Return a new lazy strategy that additionally applies ``condition``.

    The new strategy is built from the same function and arguments, with
    ``condition`` appended to the recorded filters and a repr that extends
    this strategy's repr with the new ``.filter(...)`` call.
    """
    extended_filters = self.__filters + (condition,)
    extended_repr = f"{self!r}{_repr_filter(condition)}"
    return LazyStrategy(
        self.function,
        self.__args,
        self.__kwargs,
        extended_filters,
        force_repr=extended_repr,
    )

def do_validate(self):
w = self.wrapped_strategy
assert isinstance(w, SearchStrategy), f"{self!r} returned non-strategy {w!r}"
Expand Down Expand Up @@ -140,9 +157,10 @@ def __repr__(self):
for k, v in defaults.items():
if k in kwargs_for_repr and kwargs_for_repr[k] is v:
del kwargs_for_repr[k]
self.__representation = "{}({})".format(
self.__representation = "{}({}){}".format(
self.function.__name__,
arg_string(self.function, _args, kwargs_for_repr, reorder=False),
"".join(map(_repr_filter, self.__filters)),
)
return self.__representation

Expand Down
17 changes: 16 additions & 1 deletion hypothesis-python/src/hypothesis/strategies/_internal/numbers.py
Expand Up @@ -18,6 +18,7 @@
from hypothesis.control import assume, reject
from hypothesis.internal.conjecture import floats as flt, utils as d
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.filtering import get_integer_predicate_bounds
from hypothesis.internal.floats import float_of
from hypothesis.strategies._internal.strategies import SearchStrategy

Expand Down Expand Up @@ -51,11 +52,25 @@ def __init__(self, start, end):
self.end = end

def __repr__(self):
    """Show the strategy as an ``integers(start, end)`` call."""
    return f"integers({self.start}, {self.end})"

def do_draw(self, data):
    """Draw a value by delegating to ``d.integer_range`` with our bounds."""
    lower, upper = self.start, self.end
    return d.integer_range(data, lower, upper)

def filter(self, condition):
    """Apply ``condition``, tightening our bounds instead where possible.

    Any integer bounds implied by ``condition`` are intersected with the
    current ``start``/``end``.  An empty intersection yields ``nothing()``.
    Whatever part of the predicate could not be expressed as bounds is
    applied as an ordinary filter on top.
    """
    bounds, leftover = get_integer_predicate_bounds(condition)
    lo = max(self.start, bounds.get("min_value", self.start))
    hi = min(self.end, bounds.get("max_value", self.end))

    narrowed = lo > self.start or hi < self.end
    if narrowed and lo > hi:
        from hypothesis.strategies._internal.core import nothing

        return nothing()
    if narrowed:
        # Rebind ``self`` so the rest of the method, including the
        # ``super()`` call, works on the narrowed strategy.
        self = type(self)(lo, hi)

    if leftover is None:
        return self
    return super().filter(leftover)


NASTY_FLOATS = sorted(
[
Expand Down
Expand Up @@ -717,6 +717,7 @@ def __init__(self, strategy, conditions):
assert not isinstance(self.filtered_strategy, FilteredStrategy)

self.__condition = None
self.__validated = False

def calc_is_empty(self, recur):
return recur(self.filtered_strategy)
Expand All @@ -737,6 +738,26 @@ def __repr__(self):

def do_validate(self):
    """Validate the inner strategy, then replay our conditions on it once.

    Replaying gives the inner strategy a chance to rewrite the predicates.
    This object is then re-initialised in place from the result.
    """
    self.filtered_strategy.validate()
    if self.__validated:
        return

    rebuilt = self.filtered_strategy
    for predicate in self.flat_conditions:
        rebuilt = rebuilt.filter(predicate)

    if isinstance(rebuilt, FilteredStrategy):
        base, conditions = rebuilt.filtered_strategy, rebuilt.flat_conditions
    else:
        # No FilteredStrategy came back, so keep a trivially-true condition.
        base, conditions = rebuilt, (lambda _: True,)
    FilteredStrategy.__init__(self, base, conditions)
    # Set the flag only after re-initialising, since ``__init__`` resets it.
    self.__validated = True

def filter(self, condition):
    """Add ``condition``, first offering it to the underlying strategy."""
    # Allow strategy rewriting to 'see through' an unhandled predicate.
    inner = self.filtered_strategy.filter(condition)
    if isinstance(inner, FilteredStrategy):
        base = inner.filtered_strategy
        conditions = self.flat_conditions + inner.flat_conditions
    else:
        base = inner
        conditions = self.flat_conditions
    return FilteredStrategy(base, conditions)

@property
def condition(self):
Expand Down
2 changes: 1 addition & 1 deletion hypothesis-python/tests/cover/test_direct_strategies.py
Expand Up @@ -484,7 +484,7 @@ def test_chained_filter(x):

def test_chained_filter_tracks_all_conditions():
    """Both chained predicates are recorded on the wrapped strategy."""
    s = ds.integers().filter(bool).filter(lambda x: x % 3)
    # The conditions are read from the strategy that the filtered result wraps.
    assert len(s.wrapped_strategy.flat_conditions) == 2


@pytest.mark.parametrize("version", [4, 6])
Expand Down
111 changes: 111 additions & 0 deletions hypothesis-python/tests/cover/test_filter_rewriting.py
@@ -0,0 +1,111 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Most of this work is copyright (C) 2013-2021 David R. MacIver
# (david@drmaciver.com), but it contains contributions by others. See
# CONTRIBUTING.rst for a full list of people who may hold copyright, and
# consult the git log if you need to determine who owns an individual
# contribution.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
#
# END HEADER

import operator
from functools import partial
import decimal

import pytest

from hypothesis import given, strategies as st
from hypothesis.errors import Unsatisfiable
from hypothesis.strategies._internal.lazy import LazyStrategy
from hypothesis.strategies._internal.numbers import BoundedIntStrategy
from hypothesis.strategies._internal.strategies import FilteredStrategy

from tests.common.utils import fails_with


@pytest.mark.parametrize(
    "strategy, predicate, start, end",
    [
        # Integers with integer bounds
        (st.integers(1, 5), partial(operator.lt, 3), 4, 5),  # lambda x: 3 < x
        (st.integers(1, 5), partial(operator.le, 3), 3, 5),  # lambda x: 3 <= x
        (st.integers(1, 5), partial(operator.eq, 3), 3, 3),  # lambda x: 3 == x
        (st.integers(1, 5), partial(operator.ge, 3), 1, 3),  # lambda x: 3 >= x
        (st.integers(1, 5), partial(operator.gt, 3), 1, 2),  # lambda x: 3 > x
        # Integers with non-integer bounds
        (st.integers(1, 5), partial(operator.lt, 3.5), 4, 5),
        (st.integers(1, 5), partial(operator.le, 3.5), 4, 5),
        (st.integers(1, 5), partial(operator.ge, 3.5), 1, 3),
        (st.integers(1, 5), partial(operator.gt, 3.5), 1, 3),
    ],
)
@given(data=st.data())
def test_filter_rewriting(data, strategy, predicate, start, end):
    """A rewritable filter narrows the integer bounds to ``[start, end]``."""
    filtered = strategy.filter(predicate)
    assert isinstance(filtered, LazyStrategy)

    inner = filtered.wrapped_strategy
    assert isinstance(inner, BoundedIntStrategy)
    assert inner.start == start
    assert inner.end == end

    # Drawn values must still satisfy the original predicate.
    drawn = data.draw(filtered)
    assert predicate(drawn)


@pytest.mark.parametrize(
    "s",
    [
        st.integers(1, 5).filter(partial(operator.lt, 6)),
        st.integers(1, 5).filter(partial(operator.eq, 3.5)),
        st.integers(1, 5).filter(partial(operator.eq, "can't compare to strings")),
        st.integers(1, 5).filter(partial(operator.ge, 0)),
    ],
)
@fails_with(Unsatisfiable)
@given(data=st.data())
def test_rewrite_unsatisfiable_filter(data, s):
    """Drawing from each of these filtered strategies is Unsatisfiable."""
    data.draw(s)


def test_rewriting_does_not_compare_decimal_snan():
    """Building the filtered strategy must not itself compare with sNaN."""
    equals_snan = partial(operator.eq, decimal.Decimal("snan"))
    s = st.integers(1, 5).filter(equals_snan)
    # Accessing the property builds the wrapped strategy; this must not raise.
    s.wrapped_strategy
    # The comparison is only expected to fail when an example is generated.
    with pytest.raises(decimal.InvalidOperation):
        s.example()


def mod2(x):
    """Return ``x`` modulo 2; as a predicate this is truthy when nonzero."""
    remainder = x % 2
    return remainder


@given(
    data=st.data(),
    predicates=st.permutations(
        [
            partial(operator.lt, 1),
            partial(operator.le, 2),
            partial(operator.ge, 4),
            partial(operator.gt, 5),
            mod2,
        ]
    ),
)
def test_rewrite_filter_chains_with_some_unhandled(data, predicates):
    """Rewritable filters fold into the bounds wherever ``mod2`` appears."""
    # Chain every predicate, in the drawn order, onto the base strategy.
    s = st.integers(1, 5)
    for p in predicates:
        s = s.filter(p)

    # Any drawn value has to satisfy each of the predicates.
    value = data.draw(s)
    for p in predicates:
        assert p(value), f"p={p!r}, value={value}"

    # Regardless of filter order, only the unhandled predicate remains on
    # top of a bounded-integer strategy.
    unwrapped = s.wrapped_strategy
    assert isinstance(unwrapped, FilteredStrategy)
    base = unwrapped.filtered_strategy
    assert isinstance(base, BoundedIntStrategy)
    assert unwrapped.flat_conditions == (mod2,)

0 comments on commit 1767cbb

Please sign in to comment.