Skip to content

Commit

Permalink
Faster unique arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Aug 30, 2021
1 parent cbdcc7c commit 2aa87ee
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 25 deletions.
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

This patch makes unique :func:`~hypothesis.extra.numpy.arrays` much more
efficient, especially when there are only a few valid elements, such as
for eight-bit integers (:issue:`3066`).
32 changes: 10 additions & 22 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,7 @@ def __init__(self, element_strategy, shape, dtype, fill, unique):
self.unique = unique
self._check_elements = dtype.kind not in ("O", "V")

def set_element(self, data, result, idx, strategy=None):
strategy = strategy or self.element_strategy
val = data.draw(strategy)
def set_element(self, val, result, idx, *, fill=False):
try:
result[idx] = val
except TypeError as err:
Expand All @@ -198,6 +196,7 @@ def set_element(self, data, result, idx, strategy=None):
f"{result.dtype!r} - possible mismatch of time units in dtypes?"
) from err
if self._check_elements and val != result[idx] and val == val:
strategy = self.fill if fill else self.element_strategy
raise InvalidArgument(
"Generated array element %r from %r cannot be represented as "
"dtype %r - instead it becomes %r (type %r). Consider using a more "
Expand Down Expand Up @@ -230,28 +229,17 @@ def do_draw(self, data):
# generate a fully dense array with a freshly drawn value for each
# entry.
if self.unique:
seen = set()
elements = cu.many(
data,
elems = st.lists(
self.element_strategy,
min_size=self.array_size,
max_size=self.array_size,
average_size=self.array_size,
unique=True,
)
i = 0
while elements.more():
# We assign first because this means we check for
# uniqueness after numpy has converted it to the relevant
# type for us. Because we don't increment the counter on
# a duplicate we will overwrite it on the next draw.
self.set_element(data, result, i)
if result[i] not in seen:
seen.add(result[i])
i += 1
else:
elements.reject()
for i, v in enumerate(data.draw(elems)):
self.set_element(v, result, i)
else:
for i in range(len(result)):
self.set_element(data, result, i)
self.set_element(data.draw(self.element_strategy), result, i)
else:
# We draw numpy arrays as "sparse with an offset". We draw a
# collection of index assignments within the array and assign
Expand All @@ -278,7 +266,7 @@ def do_draw(self, data):
if not needs_fill[i]:
elements.reject()
continue
self.set_element(data, result, i)
self.set_element(data.draw(self.element_strategy), result, i)
if self.unique:
if result[i] in seen:
elements.reject()
Expand All @@ -301,7 +289,7 @@ def do_draw(self, data):
one_element = np.zeros(
shape=1, dtype=object if unsized_string_dtype else self.dtype
)
self.set_element(data, one_element, 0, self.fill)
self.set_element(data.draw(self.fill), one_element, 0, fill=True)
if unsized_string_dtype:
one_element = one_element.astype(self.dtype)
fill_value = one_element[0]
Expand Down
21 changes: 20 additions & 1 deletion hypothesis-python/src/hypothesis/strategies/_internal/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,12 @@
from hypothesis.strategies._internal.functions import FunctionStrategy
from hypothesis.strategies._internal.lazy import LazyStrategy
from hypothesis.strategies._internal.misc import just, none, nothing
from hypothesis.strategies._internal.numbers import Real, floats, integers
from hypothesis.strategies._internal.numbers import (
IntegersStrategy,
Real,
floats,
integers,
)
from hypothesis.strategies._internal.recursive import RecursiveStrategy
from hypothesis.strategies._internal.shared import SharedStrategy
from hypothesis.strategies._internal.strategies import (
Expand Down Expand Up @@ -283,6 +288,20 @@ def lists(
tuple_suffixes = TupleStrategy(elements.element_strategies[1:])
elements = elements.element_strategies[0]

# UniqueSampledListStrategy offers a substantial performance improvement for
# unique arrays with few possible elements, e.g. of eight-bit integer types.
if (
isinstance(elements, IntegersStrategy)
and None not in (elements.start, elements.end)
and (elements.end - elements.start) <= 255
):
elements = SampledFromStrategy(
sorted(range(elements.start, elements.end + 1), key=abs)
if elements.end < 0 or elements.start > 0
else list(range(0, elements.end + 1))
+ list(range(-1, elements.start - 1, -1))
)

if isinstance(elements, SampledFromStrategy):
element_count = len(elements.elements)
if min_size > element_count:
Expand Down
21 changes: 19 additions & 2 deletions hypothesis-python/tests/numpy/test_gen_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytest

from hypothesis import HealthCheck, assume, given, note, settings, strategies as st
from hypothesis.errors import InvalidArgument, Unsatisfiable
from hypothesis.errors import InvalidArgument
from hypothesis.extra import numpy as nps

from tests.common.debug import find_any, minimal
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_array_values_are_unique(arr):

def test_cannot_generate_unique_array_of_too_many_elements():
strat = nps.arrays(dtype=int, elements=st.integers(0, 5), shape=10, unique=True)
with pytest.raises(Unsatisfiable):
with pytest.raises(InvalidArgument):
strat.example()


Expand All @@ -274,6 +274,23 @@ def test_generates_all_values_for_unique_array(arr):
assert len(set(arr)) == len(arr)


@given(nps.arrays(dtype="int8", shape=255, unique=True))
def test_efficiently_generates_all_unique_array(arr):
    """Check that a 255-element unique int8 array has no repeated values.

    The array is nearly as long as the number of distinct eight-bit values,
    so this only passes quickly when generation sidesteps the birthday
    paradox (via UniqueSampledListStrategy).
    """
    distinct_values = set(arr)
    assert len(arr) == len(distinct_values)


@given(st.data(), st.integers(-100, 100), st.integers(1, 100))
def test_array_element_rewriting(data, start, size):
    """Check that a unique array over exactly ``size`` integers uses them all.

    The element strategy spans ``size`` consecutive integers beginning at
    ``start`` and the array has ``size`` unique slots, so the drawn array
    must contain every integer in that span.
    """
    stop = start + size
    array_strategy = nps.arrays(
        dtype=np.dtype("int64"),
        shape=size,
        elements=st.integers(start, stop - 1),
        unique=True,
    )
    drawn = data.draw(array_strategy)
    assert set(drawn) == set(range(start, stop))


def test_may_fill_with_nan_when_unique_is_set():
find_any(
nps.arrays(
Expand Down

0 comments on commit 2aa87ee

Please sign in to comment.