Skip to content

Commit

Permalink
Merge pull request #3852 from jobh/generate-empty-flags
Browse files Browse the repository at this point in the history
Generate empty flag enums
  • Loading branch information
Zac-HD committed Jan 25, 2024
2 parents 349c726 + 28afe87 commit fb58bcc
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 35 deletions.
6 changes: 6 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
RELEASE_TYPE: minor

Changes the distribution of :func:`~hypothesis.strategies.sampled_from` when
sampling from a :class:`~python:enum.Flag`. Previously, no-flags-set values would
never be generated, and all-flags-set values would be unlikely for large enums.
With this change, the distribution is more uniform in the number of flags set.
21 changes: 20 additions & 1 deletion hypothesis-python/src/hypothesis/internal/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sys
import typing
from functools import partial
from typing import Any, ForwardRef, get_args
from typing import Any, ForwardRef, List, Optional, get_args

try:
BaseExceptionGroup = BaseExceptionGroup
Expand Down Expand Up @@ -180,6 +180,25 @@ def ceil(x):
return y


def extract_bits(x: int, /, width: Optional[int] = None) -> List[int]:
assert x >= 0
result = []
while x:
result.append(x & 1)
x >>= 1
if width is not None:
result = (result + [0] * width)[:width]
result.reverse()
return result


# int.bit_count was added sometime around python 3.9
try:
bit_count = int.bit_count
except AttributeError: # pragma: no cover
bit_count = lambda self: sum(extract_bits(abs(self)))


def bad_django_TestCase(runner):
if runner is None or "django.test" not in sys.modules:
return False
Expand Down
58 changes: 43 additions & 15 deletions hypothesis-python/src/hypothesis/strategies/_internal/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from hypothesis.internal.compat import (
Concatenate,
ParamSpec,
bit_count,
ceil,
floor,
get_type_hints,
Expand Down Expand Up @@ -202,6 +203,48 @@ def sampled_from(
that behaviour, use ``sampled_from(seq) if seq else nothing()``.
"""
values = check_sample(elements, "sampled_from")
try:
if isinstance(elements, type) and issubclass(elements, enum.Enum):
repr_ = f"sampled_from({elements.__module__}.{elements.__name__})"
else:
repr_ = f"sampled_from({elements!r})"
except Exception: # pragma: no cover
repr_ = None
if isclass(elements) and issubclass(elements, enum.Flag):
# Combinations of enum.Flag members (including empty) are also members. We generate these
# dynamically, because static allocation takes O(2^n) memory. LazyStrategy is used for the
# ease of force_repr.
# Add all named values, both flag bits (== list(elements)) and aliases. The aliases are
# necessary for full coverage for flags that would fail enum.NAMED_FLAGS check, and they
# are also nice values to shrink to.
flags = sorted(
set(elements.__members__.values()),
key=lambda v: (bit_count(v.value), v.value),
)
# Finally, try to construct the empty state if it is not named. It's placed at the
# end so that we shrink to named values.
flags_with_empty = flags
if not flags or flags[0].value != 0:
try:
flags_with_empty = [*flags, elements(0)]
except TypeError: # pragma: no cover
# Happens on some python versions (at least 3.12) when there are no named values
pass
inner = [
# Consider one or no named flags set, with shrink-to-named-flag behaviour.
# Special cases (length zero or one) are handled by the inner sampled_from.
sampled_from(flags_with_empty),
]
if len(flags) > 1:
inner += [
# Uniform distribution over number of named flags or combinations set. The overlap
# at r=1 is intentional, it may lead to oversampling but gives consistent shrinking
# behaviour.
integers(min_value=1, max_value=len(flags))
.flatmap(lambda r: sets(sampled_from(flags), min_size=r, max_size=r))
.map(lambda s: elements(reduce(operator.or_, s))),
]
return LazyStrategy(one_of, args=inner, kwargs={}, force_repr=repr_)
if not values:
if (
isinstance(elements, type)
Expand All @@ -217,21 +260,6 @@ def sampled_from(
raise InvalidArgument("Cannot sample from a length-zero sequence.")
if len(values) == 1:
return just(values[0])
try:
if isinstance(elements, type) and issubclass(elements, enum.Enum):
repr_ = f"sampled_from({elements.__module__}.{elements.__name__})"
else:
repr_ = f"sampled_from({elements!r})"
except Exception: # pragma: no cover
repr_ = None
if isclass(elements) and issubclass(elements, enum.Flag):
# Combinations of enum.Flag members are also members. We generate
# these dynamically, because static allocation takes O(2^n) memory.
# LazyStrategy is used for the ease of force_repr.
inner = sets(sampled_from(list(values)), min_size=1).map(
lambda s: reduce(operator.or_, s)
)
return LazyStrategy(lambda: inner, args=[], kwargs={}, force_repr=repr_)
return SampledFromStrategy(values, repr_)


Expand Down
12 changes: 2 additions & 10 deletions hypothesis-python/tests/conjecture/test_float_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest

from hypothesis import HealthCheck, assume, example, given, settings, strategies as st
from hypothesis.internal.compat import ceil, floor, int_to_bytes
from hypothesis.internal.compat import ceil, extract_bits, floor, int_to_bytes
from hypothesis.internal.conjecture import floats as flt
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.conjecture.engine import ConjectureRunner
Expand Down Expand Up @@ -115,16 +115,8 @@ def test_fractional_floats_are_worse_than_one(f):


def test_reverse_bits_table_reverses_bits():
def bits(x):
result = []
for _ in range(8):
result.append(x & 1)
x >>= 1
result.reverse()
return result

for i, b in enumerate(flt.REVERSE_BITS_TABLE):
assert bits(i) == list(reversed(bits(b)))
assert extract_bits(i, width=8) == list(reversed(extract_bits(b, width=8)))


def test_reverse_bits_table_has_right_elements():
Expand Down
17 changes: 16 additions & 1 deletion hypothesis-python/tests/cover/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@

import pytest

from hypothesis.internal.compat import ceil, dataclass_asdict, floor, get_type_hints
from hypothesis.internal.compat import (
ceil,
dataclass_asdict,
extract_bits,
floor,
get_type_hints,
)

floor_ceil_values = [
-10.7,
Expand Down Expand Up @@ -128,3 +134,12 @@ def test_dataclass_asdict():
"d": {4: 5},
"e": {},
}


@pytest.mark.parametrize("width", [None, 8])
@pytest.mark.parametrize("x", [0, 2, 123])
def test_extract_bits_roundtrip(width, x):
bits = extract_bits(x, width=width)
if width is not None:
assert len(bits) == width
assert x == sum(v << p for p, v in enumerate(reversed(bits)))
3 changes: 2 additions & 1 deletion hypothesis-python/tests/cover/test_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pytest

from hypothesis import assume, given, strategies as st
from hypothesis import HealthCheck, assume, given, settings, strategies as st
from hypothesis.errors import (
HypothesisDeprecationWarning,
HypothesisWarning,
Expand Down Expand Up @@ -47,6 +47,7 @@ def draw_ordered_with_assume(draw):


@given(draw_ordered_with_assume())
@settings(suppress_health_check=[HealthCheck.filter_too_much])
def test_can_assume_in_draw(xy):
assert xy[0] < xy[1]

Expand Down
4 changes: 2 additions & 2 deletions hypothesis-python/tests/cover/test_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,9 +522,9 @@ def test_resolves_flag_enum(resolver):
# Storing all combinations takes O(2^n) memory. Using an enum of 52
# members in this test ensures that we won't try!
F = enum.Flag("F", " ".join(string.ascii_letters))
# Filter to check that we can generate compound members of enum.Flags

@given(resolver(F).filter(lambda ex: ex not in tuple(F)))
# Checks for combination coverage are found in nocover/test_sampled_from
@given(resolver(F))
def inner(ex):
assert isinstance(ex, F)

Expand Down
10 changes: 7 additions & 3 deletions hypothesis-python/tests/cover/test_sampled_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@
filter_not_satisfied,
)

from tests.common.debug import assert_all_examples
from tests.common.utils import fails_with

an_enum = enum.Enum("A", "a b c")
a_flag = enum.Flag("A", "a b c")
# named zero state is required for empty flags from around py3.11/3.12
an_empty_flag = enum.Flag("EmptyFlag", {"a": 0})

an_ordereddict = collections.OrderedDict([("a", 1), ("b", 2), ("c", 3)])

Expand All @@ -48,9 +52,9 @@ def test_can_sample_ordereddict_without_warning():
sampled_from(an_ordereddict).example()


@given(sampled_from(an_enum))
def test_can_sample_enums(member):
assert isinstance(member, an_enum)
@pytest.mark.parametrize("enum_class", [an_enum, a_flag, an_empty_flag])
def test_can_sample_enums(enum_class):
assert_all_examples(sampled_from(enum_class), lambda x: isinstance(x, enum_class))


@fails_with(FailedHealthCheck)
Expand Down
17 changes: 15 additions & 2 deletions hypothesis-python/tests/nocover/test_regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import pytest

from hypothesis import assume, given, reject, strategies as st
from hypothesis.strategies._internal.regex import base_regex_strategy
from hypothesis.strategies._internal.regex import (
IncompatibleWithAlphabet,
base_regex_strategy,
)


@st.composite
Expand Down Expand Up @@ -85,7 +88,17 @@ def test_fuzz_stuff(data):
# Possible nested sets, e.g. "[[", trigger a FutureWarning
reject()

ex = data.draw(st.from_regex(regex))
try:
ex = data.draw(st.from_regex(regex))
except IncompatibleWithAlphabet:
if isinstance(pattern, str) and flags & re.ASCII:
with pytest.raises(UnicodeEncodeError):
pattern.encode("ascii")
regex = re.compile(pattern, flags=flags ^ re.ASCII)
ex = data.draw(st.from_regex(regex))
else:
raise

assert regex.search(ex)


Expand Down
56 changes: 56 additions & 0 deletions hypothesis-python/tests/nocover/test_sampled_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import enum
import functools
import itertools
import operator

import pytest

from hypothesis import given, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.compat import bit_count
from hypothesis.strategies._internal.strategies import SampledFromStrategy

from tests.common.debug import find_any, minimal
from tests.common.utils import fails_with


Expand Down Expand Up @@ -79,8 +84,59 @@ def test_enum_repr_uses_class_not_a_list():
class AFlag(enum.Flag):
a = enum.auto()
b = enum.auto()
c = enum.auto()


LargeFlag = enum.Flag("LargeFlag", {f"bit{i}": enum.auto() for i in range(64)})


class UnnamedFlag(enum.Flag):
# Would fail under EnumCheck.NAMED_FLAGS
a = 0
b = 7


def test_flag_enum_repr_uses_class_not_a_list():
lazy_repr = repr(st.sampled_from(AFlag))
assert lazy_repr == "sampled_from(tests.nocover.test_sampled_from.AFlag)"


def test_exhaustive_flags():
# Generate powerset of flag combinations. There are only 2^3 of them, so
# we can reasonably expect that they are all are found.
unseen_flags = {
functools.reduce(operator.or_, flaglist, AFlag(0))
for r in range(len(AFlag) + 1)
for flaglist in itertools.combinations(AFlag, r)
}

@given(st.sampled_from(AFlag))
def accept(flag):
unseen_flags.discard(flag)

accept()

assert not unseen_flags


def test_flags_minimize_to_first_named_flag():
assert minimal(st.sampled_from(LargeFlag)) == LargeFlag.bit0


def test_flags_minimizes_bit_count():
shrunk = minimal(st.sampled_from(LargeFlag), lambda f: bit_count(f.value) > 1)
# Ideal would be (bit0 | bit1), but:
# minimal(st.sets(st.sampled_from(range(10)), min_size=3)) == {0, 8, 9} # not {0, 1, 2}
assert shrunk == LargeFlag.bit0 | LargeFlag.bit63 # documents actual behaviour


def test_flags_finds_all_bits_set():
assert find_any(st.sampled_from(LargeFlag), lambda f: f == ~LargeFlag(0))


def test_sample_unnamed_alias():
assert find_any(st.sampled_from(UnnamedFlag), lambda f: f == UnnamedFlag.b)


def test_shrink_to_named_empty():
assert minimal(st.sampled_from(UnnamedFlag)) == UnnamedFlag(0)

0 comments on commit fb58bcc

Please sign in to comment.