Skip to content

Commit

Permalink
Shared cache between ArrayStrategy instances
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Sep 26, 2021
1 parent da5a173 commit 06e2593
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 35 deletions.
23 changes: 13 additions & 10 deletions hypothesis-python/src/hypothesis/extra/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

import math
import sys
from collections import defaultdict
from numbers import Real
from types import SimpleNamespace
from typing import (
Any,
DefaultDict,
Iterable,
Iterator,
List,
Expand Down Expand Up @@ -256,6 +258,13 @@ def check_valid_minmax(prefix, val, info_obj):


class ArrayStrategy(st.SearchStrategy):
# Checking value assignment to arrays is slightly expensive due to us
# casting 0d arrays to builtin objects, so we cache these values in
# check_hist to skip redundant checks. Any new value will be checked
# *before* being added to the cache, meaning we do not store disallowed
# elements. See https://github.com/HypothesisWorks/hypothesis/pull/3105
check_hist: DefaultDict[DataType, set] = defaultdict(set)

def __init__(self, xp, elements_strategy, dtype, shape, fill, unique):
self.xp = xp
self.elements_strategy = elements_strategy
Expand All @@ -265,15 +274,9 @@ def __init__(self, xp, elements_strategy, dtype, shape, fill, unique):
self.unique = unique
self.array_size = math.prod(shape)
self.builtin = find_castable_builtin_for_dtype(xp, dtype)
# Checking value assignment to arrays is slightly expensive due to us
# casting 0d arrays to builtin objects, so we cache these values in
# check_hist to skip redundant checks. Any new value will be checked
# *before* being added to the cache, meaning we do not store disallowed
# elements. See https://github.com/HypothesisWorks/hypothesis/pull/3105
self.check_hist = set()

def check_set_value(self, val, val_0d, strategy):
if val in self.check_hist:
if val in self.check_hist[self.dtype]:
return
finite = self.builtin is bool or self.xp.isfinite(val_0d)
if finite and self.builtin(val_0d) != val:
Expand All @@ -285,16 +288,16 @@ def check_set_value(self, val, val_0d, strategy):
"Consider using a more precise elements strategy, "
"for example passing the width argument to floats()."
)
self.check_hist.add(val)
self.check_hist[self.dtype].add(val)

def do_draw(self, data):
if 0 in self.shape:
return self.xp.zeros(self.shape, dtype=self.dtype)

# We reset check_hist when it reaches an arbitrarily large size to
# prevent unbounded memory usage.
if len(self.check_hist) >= 100_000:
self.check_hist = set()
if len(self.check_hist[self.dtype]) >= 100_000:
self.check_hist[self.dtype] = set()

if self.fill.is_empty:
# We have no fill value (either because the user explicitly
Expand Down
95 changes: 70 additions & 25 deletions hypothesis-python/tests/array_api/test_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
#
# END HEADER

from copy import copy

import pytest

from hypothesis import assume, given, settings, strategies as st
from hypothesis import HealthCheck, assume, given, settings, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra.array_api import DTYPE_NAMES, NUMERIC_NAMES, ArrayStrategy

Expand Down Expand Up @@ -485,22 +487,75 @@ def test_may_reuse_distinct_integers_if_asked():
)


def test_check_hist_not_shared_between_test_cases():
"""Strategy does not share its cache of checked values between test cases."""
def arrays_lite(dtype, shape, elements=None):
    """Construct an unwrapped ArrayStrategy, loosely mirroring xps.arrays.

    An int ``shape`` is treated as a one-dimensional shape of that length.
    ``elements`` may be omitted (falsy), in which case ``xps.from_dtype(dtype)``
    is used; a dict, which is forwarded as keyword arguments to
    ``xps.from_dtype``; or anything else, which is passed through untouched.
    The resulting strategy doubles as the fill strategy, and uniqueness is
    always disabled.
    """
    dims = (shape,) if isinstance(shape, int) else shape
    if not elements:
        strategy = xps.from_dtype(dtype)
    elif isinstance(elements, dict):
        strategy = xps.from_dtype(dtype, **elements)
    else:
        strategy = elements
    return ArrayStrategy(
        xp=xp,
        elements_strategy=strategy,
        dtype=dtype,
        shape=dims,
        fill=strategy,
        unique=False,
    )


@pytest.fixture
def fresh_arrays():
    """Yield arrays_lite with ArrayStrategy.check_hist emptied beforehand.

    Why this fixture exists:

    * The strategy returned by xps.arrays is wrapped, which makes reaching
      the cache through it tricky.
    * Tests should still read as though they were using xps.arrays, rather
      than building ArrayStrategy directly.
    * Every test using this fixture starts from an empty cache, yet the
      previous contents are put back afterwards so the rest of the test
      suite can keep reaping the performance gains.
    """
    # Shallow snapshot: the per-dtype sets are kept alive by this copy even
    # after the class-level mapping is emptied below.
    saved_hist = copy(ArrayStrategy.check_hist)
    # Empty the class-level mapping in place.
    ArrayStrategy.check_hist.clear()
    yield arrays_lite
    # Rebind the class attribute to the snapshot taken before the test ran.
    ArrayStrategy.check_hist = saved_hist


def test_check_hist_shared_between_instances(fresh_arrays):
    """Separate strategy instances read from and write to one shared cache of
    checked values."""
    strat_a = fresh_arrays(dtype=xp.uint8, shape=5)

    @given(strat_a)
    def run_with_first_strategy(_):
        pass

    run_with_first_strategy()
    cached_after_first = strat_a.check_hist[xp.uint8]
    # Drawing from the first strategy must have recorded at least one value.
    assert cached_after_first
    # Snapshot the cache so later additions cannot mask a lost entry.
    snapshot = set(cached_after_first)

    strat_b = fresh_arrays(dtype=xp.uint8, shape=5)

    @given(strat_b)
    def run_with_second_strategy(_):
        pass

    run_with_second_strategy()
    cached_after_second = strat_b.check_hist[xp.uint8]
    assert cached_after_second
    # Everything the first instance cached is still visible via the second.
    assert snapshot <= cached_after_second


def test_check_hist_not_shared_between_different_dtypes(fresh_arrays):
"""Strategy does not share its cache of checked values between test cases
using different dtypes."""
# The element 300 is valid for uint16 arrays, so it will pass its check to
# subsequently be cached in check_hist.
@given(xps.arrays(dtype=xp.uint16, shape=5, elements=st.just(300)))
@given(fresh_arrays(dtype=xp.uint16, shape=5, elements=st.just(300)))
def valid_test_case(_):
pass

valid_test_case()

# This should raise InvalidArgument, as the element 300 is too large for a
# uint8. If the cache from running valid_test_case above was shared to
# this test case, either no error would raise, or an array library would
# raise their own when assigning 300 to an array - overflow behaviour is
# outside the Array API spec but something we want to always prevent.
@given(xps.arrays(dtype=xp.uint8, shape=5, elements=st.just(300)))
# uint8. If the cache from running valid_test_case above were used in this
# test case, either no error would be raised, or the array library would
# raise its own when assigning 300 to an array - overflow behaviour is
# outside the Array API spec but something we want to always prevent.
@given(fresh_arrays(dtype=xp.uint8, shape=5, elements=st.just(300)))
@settings(max_examples=1)
def overflow_test_case(_):
pass
Expand All @@ -510,32 +565,22 @@ def overflow_test_case(_):


@given(st.data())
@settings(max_examples=1)
def test_check_hist_resets_when_too_large(data):
@settings(max_examples=1, suppress_health_check=(HealthCheck.function_scoped_fixture,))
def test_check_hist_resets_when_too_large(fresh_arrays, data):
"""Strategy resets its cache of checked values once it gets too large.
At the start of a draw, xps.arrays() should check the size of the cache.
If it contains 100_000 or more values, it should be completely reset.
"""
# Our elements/fill strategy generates values >=100_000 so that it won't
# collide with our mocked cached values later.
elements = xps.from_dtype(xp.uint64, min_value=100_000)
# We test with the private ArrayStrategy, as xps.arrays() returns a wrapped
# strategy which would make injection of our mocked cache tricky.
strat = ArrayStrategy(
xp=xp,
elements_strategy=elements,
dtype=xp.uint64,
shape=(5,),
fill=elements,
unique=False,
)
strat = fresh_arrays(dtype=xp.uint64, shape=5, elements={"min_value": 100_000})
# We inject the mocked cache containing all non-negative integers below 99_999.
strat.check_hist = set(range(99_999))
strat.check_hist[xp.uint64] = set(range(99_999))
# We then call the strategy's do_draw() method.
data.draw(strat)
# The cache should *not* reset here, as the check is done at the start of a draw.
assert len(strat.check_hist) >= 100_000
assert len(strat.check_hist[xp.uint64]) >= 100_000
# But another call of do_draw() should reset the cache.
data.draw(strat)
assert 1 <= len(strat.check_hist) <= 5
assert 1 <= len(strat.check_hist[xp.uint64]) <= 5

0 comments on commit 06e2593

Please sign in to comment.