Skip to content

Commit

Permalink
Merge pull request #3899 from tybug/shrinker-ir
Browse files Browse the repository at this point in the history
migrate `Float` shrinker to the ir
  • Loading branch information
tybug committed Mar 14, 2024
2 parents 5eeb51a + 7595339 commit 749c8dd
Show file tree
Hide file tree
Showing 9 changed files with 569 additions and 111 deletions.
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

This patch starts work on refactoring our shrinker internals. There is no user-visible change.
71 changes: 68 additions & 3 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,6 +920,50 @@ class IRNode:
kwargs: IRKWargsType = attr.ib()
was_forced: bool = attr.ib()

def copy(self, *, with_value: IRType) -> "IRNode":
# we may want to allow this combination in the future, but for now it's
# a footgun.
assert not self.was_forced, "modifying a forced node doesn't make sense"
return IRNode(
ir_type=self.ir_type,
value=with_value,
kwargs=self.kwargs,
was_forced=self.was_forced,
)


def ir_value_permitted(value, ir_type, kwargs):
if ir_type == "integer":
if kwargs["min_value"] is not None and value < kwargs["min_value"]:
return False
if kwargs["max_value"] is not None and value > kwargs["max_value"]:
return False

return True
elif ir_type == "float":
if math.isnan(value):
return kwargs["allow_nan"]
return (
sign_aware_lte(kwargs["min_value"], value)
and sign_aware_lte(value, kwargs["max_value"])
) and not (0 < abs(value) < kwargs["smallest_nonzero_magnitude"])
elif ir_type == "string":
if len(value) < kwargs["min_size"]:
return False
if kwargs["max_size"] is not None and len(value) > kwargs["max_size"]:
return False
return all(ord(c) in kwargs["intervals"] for c in value)
elif ir_type == "bytes":
return len(value) == kwargs["size"]
elif ir_type == "boolean":
if kwargs["p"] <= 2 ** (-64):
return value is False
if kwargs["p"] >= (1 - 2 ** (-64)):
return value is True
return True

raise NotImplementedError(f"unhandled type {type(value)} of ir value {value}")


@dataclass_transform()
@attr.s(slots=True)
Expand Down Expand Up @@ -1991,8 +2035,8 @@ def draw_boolean(
p: float = 0.5,
*,
forced: Optional[bool] = None,
observe: bool = True,
fake_forced: bool = False,
observe: bool = True,
) -> bool:
# Internally, we treat probabilities lower than 1 / 2**64 as
# unconditionally false.
Expand Down Expand Up @@ -2049,9 +2093,30 @@ def _pooled_kwargs(self, ir_type, kwargs):

def _pop_ir_tree_node(self, ir_type: IRTypeName, kwargs: IRKWargsType) -> IRNode:
assert self.ir_tree_nodes is not None

if self.ir_tree_nodes == []:
self.mark_overrun()

node = self.ir_tree_nodes.pop(0)
assert node.ir_type == ir_type
assert kwargs == node.kwargs
# If we're trying to draw a different ir type at the same location, then
# this ir tree has become badly misaligned. We don't have many good/simple
# options here for realigning beyond giving up.
#
# This is more of an issue for ir nodes while shrinking than it was for
# buffers: misaligned buffers are still usually valid, just interpreted
# differently. This would be somewhat like drawing a random value for
# the new ir type here. For what it's worth, misaligned buffers are
# rather unlikely to be *useful* buffers, so giving up isn't a big downgrade.
# (in fact, it is possible that giving up early here results in more time
# for useful shrinks to run).
if node.ir_type != ir_type:
self.mark_invalid()

# if a node has different kwargs (and so is misaligned), but has a value
# that is allowed by the expected kwargs, then we can coerce this node
# into an aligned one by using its value. It's unclear how useful this is.
if not ir_value_permitted(node.value, node.ir_type, kwargs):
self.mark_invalid()

return node

Expand Down
61 changes: 54 additions & 7 deletions hypothesis-python/src/hypothesis/internal/conjecture/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
Status,
StringKWargs,
)
from hypothesis.internal.floats import count_between_floats, float_to_int, int_to_float
from hypothesis.internal.floats import (
count_between_floats,
float_to_int,
int_to_float,
sign_aware_lte,
)


class PreviouslyUnseenBehaviour(HypothesisException):
Expand Down Expand Up @@ -184,7 +189,35 @@ def compute_max_children(ir_type, kwargs):
return sum(len(intervals) ** k for k in range(min_size, max_size + 1))

elif ir_type == "float":
return count_between_floats(kwargs["min_value"], kwargs["max_value"])
min_value = kwargs["min_value"]
max_value = kwargs["max_value"]
smallest_nonzero_magnitude = kwargs["smallest_nonzero_magnitude"]

count = count_between_floats(min_value, max_value)

# we have two intervals:
# a. [min_value, max_value]
# b. [-smallest_nonzero_magnitude, smallest_nonzero_magnitude]
#
# which could be subsets (in either order), overlapping, or disjoint. We
# want the interval difference a - b.

# next_down because endpoints are ok with smallest_nonzero_magnitude
min_point = max(min_value, -flt.next_down(smallest_nonzero_magnitude))
max_point = min(max_value, flt.next_down(smallest_nonzero_magnitude))

if min_point > max_point:
# case: disjoint intervals.
return count

count -= count_between_floats(min_point, max_point)
if sign_aware_lte(min_value, -0.0) and sign_aware_lte(-0.0, max_value):
# account for -0.0
count += 1
if sign_aware_lte(min_value, 0.0) and sign_aware_lte(0.0, max_value):
# account for 0.0
count += 1
return count

raise NotImplementedError(f"unhandled ir_type {ir_type}")

Expand Down Expand Up @@ -247,16 +280,30 @@ def floats_between(a, b):

min_value = kwargs["min_value"]
max_value = kwargs["max_value"]
smallest_nonzero_magnitude = kwargs["smallest_nonzero_magnitude"]

# handle zeroes separately so smallest_nonzero_magnitude can think of
# itself as a complete interval (instead of a hole at ±0).
if sign_aware_lte(min_value, -0.0) and sign_aware_lte(-0.0, max_value):
yield -0.0
if sign_aware_lte(min_value, 0.0) and sign_aware_lte(0.0, max_value):
yield 0.0

if flt.is_negative(min_value):
if flt.is_negative(max_value):
# if both are negative, have to invert order
yield from floats_between(max_value, min_value)
# case: both negative.
max_point = min(max_value, -smallest_nonzero_magnitude)
# float_to_int increases as negative magnitude increases, so
# invert order.
yield from floats_between(max_point, min_value)
else:
yield from floats_between(-0.0, min_value)
yield from floats_between(0.0, max_value)
# case: straddles midpoint (which is between -0.0 and 0.0).
yield from floats_between(-smallest_nonzero_magnitude, min_value)
yield from floats_between(smallest_nonzero_magnitude, max_value)
else:
yield from floats_between(min_value, max_value)
# case: both positive.
min_point = max(min_value, smallest_nonzero_magnitude)
yield from floats_between(min_point, max_value)


@attr.s(slots=True)
Expand Down
67 changes: 37 additions & 30 deletions hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import math
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, Dict, Optional

Expand All @@ -19,14 +20,9 @@
prefix_selection_order,
random_selection_order,
)
from hypothesis.internal.conjecture.data import (
DRAW_FLOAT_LABEL,
ConjectureData,
ConjectureResult,
Status,
)
from hypothesis.internal.conjecture.data import ConjectureData, ConjectureResult, Status
from hypothesis.internal.conjecture.dfa import ConcreteDFA
from hypothesis.internal.conjecture.floats import float_to_lex, lex_to_float
from hypothesis.internal.conjecture.floats import is_simple
from hypothesis.internal.conjecture.junkdrawer import (
binary_search,
find_integer,
Expand Down Expand Up @@ -379,6 +375,12 @@ def calls(self):
test function."""
return self.engine.call_count

def consider_new_tree(self, tree):
data = ConjectureData.for_ir_tree(tree)
self.engine.test_function(data)

return self.consider_new_buffer(data.buffer)

def consider_new_buffer(self, buffer):
"""Returns True if after running this buffer the result would be
the current shrink_target."""
Expand Down Expand Up @@ -774,6 +776,10 @@ def buffer(self):
def blocks(self):
return self.shrink_target.blocks

@property
def nodes(self):
return self.shrink_target.examples.ir_tree_nodes

@property
def examples(self):
return self.shrink_target.examples
Expand Down Expand Up @@ -1207,31 +1213,32 @@ def minimize_floats(self, chooser):
anything particularly meaningful for non-float values.
"""

ex = chooser.choose(
self.examples,
lambda ex: (
ex.label == DRAW_FLOAT_LABEL
and len(ex.children) == 2
and ex.children[1].length == 8
),
node = chooser.choose(
self.nodes,
lambda node: node.ir_type == "float" and not node.was_forced
# avoid shrinking integer-valued floats. In our current ordering, these
# are already simpler than all other floats, so it's better to shrink
# them in other passes.
and not is_simple(node.value),
)

u = ex.children[1].start
v = ex.children[1].end
buf = self.shrink_target.buffer
b = buf[u:v]
f = lex_to_float(int_from_bytes(b))
b2 = int_to_bytes(float_to_lex(f), 8)
if b == b2 or self.consider_new_buffer(buf[:u] + b2 + buf[v:]):
Float.shrink(
f,
lambda x: self.consider_new_buffer(
self.shrink_target.buffer[:u]
+ int_to_bytes(float_to_lex(x), 8)
+ self.shrink_target.buffer[v:]
),
random=self.random,
)
i = self.nodes.index(node)
# the Float shrinker was only built to handle positive floats. We'll
# shrink the positive portion and reapply the sign after, which is
# equivalent to this shrinker's previous behavior. We'll want to refactor
# Float to handle negative floats natively in the future. (likely a pure
# code quality change, with no shrinking impact.)
sign = math.copysign(1.0, node.value)
Float.shrink(
abs(node.value),
lambda val: self.consider_new_tree(
self.nodes[:i]
+ [node.copy(with_value=sign * val)]
+ self.nodes[i + 1 :]
),
random=self.random,
node=node,
)

@defines_shrink_pass()
def redistribute_block_pairs(self, chooser):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def incorporate(self, value):
def consider(self, value):
"""Returns True if make_immutable(value) == self.current after calling
self.incorporate(value)."""
self.debug(f"considering {value}")
value = self.make_immutable(value)
if value == self.current:
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import math
import sys

from hypothesis.internal.conjecture.data import ir_value_permitted
from hypothesis.internal.conjecture.floats import float_to_lex
from hypothesis.internal.conjecture.shrinking.common import Shrinker
from hypothesis.internal.conjecture.shrinking.integer import Integer
Expand All @@ -19,9 +20,16 @@


class Float(Shrinker):
def setup(self):
def setup(self, node):
self.NAN = math.nan
self.debugging_enabled = True
self.node = node

def consider(self, value):
if not ir_value_permitted(value, "float", self.node.kwargs):
self.debug(f"rejecting {value} as disallowed for {self.node.kwargs}")
return False
return super().consider(value)

def make_immutable(self, f):
f = float(f)
Expand Down
30 changes: 28 additions & 2 deletions hypothesis-python/tests/conjecture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from hypothesis.internal.conjecture.engine import BUFFER_SIZE, ConjectureRunner
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.entropy import deterministic_PRNG
from hypothesis.internal.floats import sign_aware_lte
from hypothesis.internal.floats import SMALLEST_SUBNORMAL, sign_aware_lte
from hypothesis.strategies._internal.strings import OneCharStringStrategy, TextStrategy

from tests.common.strategies import intervals
Expand Down Expand Up @@ -220,6 +220,8 @@ def draw_float_kwargs(
pivot = forced if (use_forced and not math.isnan(forced)) else None
min_value = -math.inf
max_value = math.inf
smallest_nonzero_magnitude = SMALLEST_SUBNORMAL
allow_nan = True if (use_forced and math.isnan(forced)) else draw(st.booleans())

if use_min_value:
min_value = draw(st.floats(max_value=pivot, allow_nan=False))
Expand All @@ -231,7 +233,31 @@ def draw_float_kwargs(
min_val = pivot if sign_aware_lte(min_value, pivot) else min_value
max_value = draw(st.floats(min_value=min_val, allow_nan=False))

return {"min_value": min_value, "max_value": max_value, "forced": forced}
largest_magnitude = max(abs(min_value), abs(max_value))
# can't force something smaller than our smallest magnitude.
if pivot is not None and pivot != 0.0:
largest_magnitude = min(largest_magnitude, pivot)

# avoid drawing from an empty range
if largest_magnitude > 0:
smallest_nonzero_magnitude = draw(
st.floats(
min_value=0,
# smallest_nonzero_magnitude breaks internal clamper invariants if
# it is allowed to be larger than the magnitude of {min, max}_value.
max_value=largest_magnitude,
allow_nan=False,
exclude_min=True,
allow_infinity=False,
)
)
return {
"min_value": min_value,
"max_value": max_value,
"forced": forced,
"allow_nan": allow_nan,
"smallest_nonzero_magnitude": smallest_nonzero_magnitude,
}


@st.composite
Expand Down

0 comments on commit 749c8dd

Please sign in to comment.