Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean-up and slightly optimise extra.array_api.ArrayStrategy #3105

Closed
wants to merge 9 commits into from
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

This patch makes :func:`~xps.arrays()` from the
:ref:`Array API extra <array-api>` slightly faster by not repeating internal
checks done on generated elements.
96 changes: 58 additions & 38 deletions hypothesis-python/src/hypothesis/extra/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

import math
import sys
from collections import defaultdict
from numbers import Real
from types import SimpleNamespace
from typing import (
Any,
DefaultDict,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -133,7 +135,7 @@ def find_castable_builtin_for_dtype(
# None equals NumPy's xp.float64 object, so we specifically skip it here to
# ensure that InvalidArgument is still raised. xp.float64 is in fact an
# alias of np.dtype('float64'), and its equality with None is meant to be
# deprecated at some point - see https://github.com/numpy/numpy/issues/18434.
# deprecated at some point. See https://github.com/numpy/numpy/issues/18434
if dtype is not None and dtype in float_dtypes:
return float

Expand Down Expand Up @@ -256,6 +258,13 @@ def check_valid_minmax(prefix, val, info_obj):


class ArrayStrategy(st.SearchStrategy):
# Checking value assignment to arrays is slightly expensive due to us
# casting 0d arrays to builtin objects, so we cache these values in
# check_hist to skip redundant checks. Any new value will be checked
# *before* being added to the cache, meaning we do not store disallowed
# elements. See https://github.com/HypothesisWorks/hypothesis/pull/3105
check_hist: DefaultDict[DataType, set] = defaultdict(set)

def __init__(self, xp, elements_strategy, dtype, shape, fill, unique):
self.xp = xp
self.elements_strategy = elements_strategy
Expand All @@ -266,62 +275,65 @@ def __init__(self, xp, elements_strategy, dtype, shape, fill, unique):
self.array_size = math.prod(shape)
self.builtin = find_castable_builtin_for_dtype(xp, dtype)

def set_value(self, result, i, val, strategy=None):
strategy = strategy or self.elements_strategy
try:
result[i] = val
except TypeError as e:
raise InvalidArgument(
f"Could not add generated array element {val!r} "
f"of dtype {type(val)} to array of dtype {result.dtype}."
) from e
self.check_set_value(val, result[i], strategy)

def check_set_value(self, val, val_0d, strategy):
if self.builtin is bool:
finite = True
else:
finite = self.xp.isfinite(val_0d)
if val in ArrayStrategy.check_hist[self.dtype]:
return
finite = self.builtin is bool or self.xp.isfinite(val_0d)
if finite and self.builtin(val_0d) != val:
raise InvalidArgument(
f"Generated array element {val!r} from strategy {strategy} "
f"cannot be represented as dtype {self.dtype}. "
f"cannot be represented with dtype {self.dtype}. "
f"Array module {self.xp.__name__} instead "
f"represents the element as {val_0d!r}. "
f"represents the element as {val_0d}. "
"Consider using a more precise elements strategy, "
"for example passing the width argument to floats()."
)
ArrayStrategy.check_hist[self.dtype].add(val)

def do_draw(self, data):
if 0 in self.shape:
return self.xp.zeros(self.shape, dtype=self.dtype)

# We reset a dtype's cache when it reaches a certain size to prevent
# unbounded memory usage. The limit 75_000 is under a set's reallocation
# size of 78_642, but is otherwise chosen as an arbitrarily large number.
if len(ArrayStrategy.check_hist[self.dtype]) >= 75_000:
ArrayStrategy.check_hist[self.dtype] = set()

if self.fill.is_empty:
# We have no fill value (either because the user explicitly
# disabled it or because the default behaviour was used and our
# elements strategy does not produce reusable values), so we must
# generate a fully dense array with a freshly drawn value for each
# entry.

# This could legitimately be a xp.empty, but the performance gains
# for that are likely marginal, so there's really not much point
# risking undefined behaviour shenanigans.
result = self.xp.zeros(self.array_size, dtype=self.dtype)

if self.unique:
seen = set()
elems = st.lists(
elems = data.draw(
st.lists(
self.elements_strategy,
min_size=self.array_size,
max_size=self.array_size,
unique=True,
unique=self.unique,
)
for i, v in enumerate(data.draw(elems)):
self.set_value(result, i, v)
else:
for i in range(self.array_size):
val = data.draw(self.elements_strategy)
self.set_value(result, i, val)
)
try:
result = self.xp.asarray(elems, dtype=self.dtype)
except Exception as e:
if len(elems) <= 6:
f_elems = str(elems)
else:
f_elems = f"[{elems[0]}, {elems[1]}, ..., {elems[-2]}, {elems[-1]}]"
types = tuple(
sorted({type(e) for e in elems}, key=lambda t: t.__name__)
)
f_types = f"type {types[0]}" if len(types) == 1 else f"types {types}"
raise InvalidArgument(
f"Generated elements {f_elems} from strategy "
f"{self.elements_strategy} could not be converted "
f"to array of dtype {self.dtype}. "
f"Consider if elements of {f_types} "
f"are compatible with {self.dtype}."
) from e
for i in range(self.array_size):
self.check_set_value(elems[i], result[i], self.elements_strategy)
else:
# We draw arrays as "sparse with an offset". We assume not every
# element will be assigned and so first draw a single value from our
Expand All @@ -338,7 +350,7 @@ def do_draw(self, data):
f"with fill value {fill_val!r}"
) from e
sample = result[0]
self.check_set_value(fill_val, sample, strategy=self.fill)
self.check_set_value(fill_val, sample, self.fill)
if self.unique and not self.xp.all(self.xp.isnan(result)):
raise InvalidArgument(
f"Array module {self.xp.__name__} did not recognise fill "
Expand Down Expand Up @@ -371,7 +383,14 @@ def do_draw(self, data):
continue
else:
seen.add(val)
self.set_value(result, i, val)
try:
result[i] = val
except Exception as e:
raise InvalidArgument(
f"Could not add generated array element {val!r} "
f"of type {type(val)} to array of dtype {result.dtype}."
) from e
self.check_set_value(val, result[i], self.elements_strategy)
assigned.add(i)

result = self.xp.reshape(result, self.shape)
Expand Down Expand Up @@ -459,8 +478,9 @@ def _arrays(
hundreds or more elements, having a fill value is essential if you want
your tests to run in reasonable time.
"""

check_xp_attributes(xp, ["zeros", "full", "all", "isnan", "isfinite", "reshape"])
check_xp_attributes(
xp, ["asarray", "zeros", "full", "all", "isnan", "isfinite", "reshape"]
)

if isinstance(dtype, st.SearchStrategy):
return dtype.flatmap(
Expand Down
95 changes: 93 additions & 2 deletions hypothesis-python/tests/array_api/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
#
# END HEADER

from contextlib import contextmanager
from copy import copy

import pytest

from hypothesis import assume, given, strategies as st
from hypothesis import assume, given, settings, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra.array_api import DTYPE_NAMES, NUMERIC_NAMES
from hypothesis.extra.array_api import DTYPE_NAMES, NUMERIC_NAMES, ArrayStrategy

from tests.array_api.common import COMPLIANT_XP, xp, xps
from tests.common.debug import find_any, minimal
Expand Down Expand Up @@ -201,6 +204,7 @@ def count_unique(x):
# TODO: The Array API makes boolean indexing optional, so in the future this
# will need to be reworked if we want to test libraries other than NumPy.
# If not possible, errors should be caught and the test skipped.
# See https://github.com/data-apis/array-api/issues/249
filtered_x = x[~nan_index]
unique_x = xp.unique(filtered_x)
n_unique += unique_x.size
Expand Down Expand Up @@ -482,3 +486,90 @@ def test_may_reuse_distinct_integers_if_asked():
),
lambda x: count_unique(x) < x.size,
)


@contextmanager
def suspend_cache():
    """Empty the shared check_hist cache, restoring it on teardown.

    The restore is wrapped in try/finally so the original cache is
    reinstated even when the body raises — e.g. a failing assertion, or a
    test that deliberately triggers InvalidArgument. Without the finally,
    an exception would leave ArrayStrategy.check_hist emptied for every
    subsequently-run test.
    """
    tmp_check_hist = copy(ArrayStrategy.check_hist)
    ArrayStrategy.check_hist.clear()
    try:
        yield
    finally:
        ArrayStrategy.check_hist = tmp_check_hist


def test_check_hist_persists_between_instances():
    """Multiple instances of the strategy, with the same arguments, update the
    same cache of checked values.

    Hypothesis caches its strategies, so calling xps.arrays() with identical
    arguments will return the same underlying ArrayStrategy. Therefore, if
    check_hist was bound per-instance to ArrayStrategy, the cache of one
    instance would not carry over to other instances (and this test would
    expectedly fail).
    """
    with suspend_cache():

        @given(xps.arrays(dtype=xp.uint8, shape=5))
        def case_one(_):
            pass

        case_one()
        # The first run must have populated the uint8 cache; snapshot it.
        assert ArrayStrategy.check_hist[xp.uint8]
        snapshot = copy(ArrayStrategy.check_hist[xp.uint8])

        @given(xps.arrays(dtype=xp.uint8, shape=5))
        def case_two(_):
            pass

        case_two()
        # The second run sees (and extends) the cache left by the first.
        assert ArrayStrategy.check_hist[xp.uint8]
        assert snapshot.issubset(ArrayStrategy.check_hist[xp.uint8])


def test_check_hist_not_shared_between_different_dtypes():
    """Strategy does not share its cache of checked values between test cases
    using different dtypes."""
    with suspend_cache():
        # 300 fits in a uint16, so it passes its check and is then cached
        # under the uint16 key of check_hist.
        @given(xps.arrays(dtype=xp.uint16, shape=5, elements=st.just(300)))
        def uint16_case(_):
            pass

        uint16_case()

        # 300 overflows a uint8, so InvalidArgument must be raised. Were the
        # uint16 cache consulted here, either no error would surface at all,
        # or the array library would raise its own error on assignment -
        # overflow behaviour sits outside the Array API spec but is something
        # we always want to prevent.
        @given(xps.arrays(dtype=xp.uint8, shape=5, elements=st.just(300)))
        @settings(max_examples=1)
        def uint8_case(_):
            pass

        with pytest.raises(InvalidArgument):
            uint8_case()


@given(st.data())
@settings(max_examples=1)
def test_check_hist_resets_when_too_large(data):
    """Strategy resets its cache of checked values once it gets too large.

    At the start of a draw, xps.arrays() should check the size of the cache.
    If it contains 75_000 or more values, it should be completely reset.
    """
    with suspend_cache():
        # Seed the cache with 74_999 values - one below the reset threshold.
        ArrayStrategy.check_hist[xp.uint64] = set(range(74_999))
        # Elements are all >=75_000, so freshly drawn values never collide
        # with the seeded entries and every draw grows the cache.
        strat = xps.arrays(dtype=xp.uint64, shape=5, elements={"min_value": 75_000})
        data.draw(strat)
        # The first draw started below the threshold, so no reset happened;
        # the values it added pushed the cache to (at least) the threshold.
        assert len(ArrayStrategy.check_hist[xp.uint64]) >= 75_000
        data.draw(strat)
        # The second draw saw an over-full cache and reset it before adding
        # its own (at most five distinct) values.
        assert 1 <= len(ArrayStrategy.check_hist[xp.uint64]) <= 5