diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..a35bb9dc92 --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,13 @@ +RELEASE_TYPE: minor + +In preparation for `future versions of the Array API standard +`__, +:func:`~hypothesis.extra.array_api.make_strategies_namespace` now accepts an +optional ``api_version`` argument, which determines the version conformed to by +the returned strategies namespace. If ``None``, the version of the passed array +module ``xp`` is inferred. + +This release also introduces :func:`xps.real_dtypes`. This is currently +equivalent to the existing :func:`xps.numeric_dtypes` strategy, but exists +because the latter is expected to include complex numbers in the next version of +the standard. diff --git a/hypothesis-python/docs/conf.py b/hypothesis-python/docs/conf.py index 22117c4336..2498328bc9 100644 --- a/hypothesis-python/docs/conf.py +++ b/hypothesis-python/docs/conf.py @@ -57,10 +57,16 @@ def setup(app): app.tags.add("has_release_file") # patch in mock array_api namespace so we can autodoc it - from hypothesis.extra.array_api import make_strategies_namespace, mock_xp + from hypothesis.extra.array_api import ( + RELEASED_VERSIONS, + make_strategies_namespace, + mock_xp, + ) mod = types.ModuleType("xps") - mod.__dict__.update(make_strategies_namespace(mock_xp).__dict__) + mod.__dict__.update( + make_strategies_namespace(mock_xp, api_version=RELEASED_VERSIONS[-1]).__dict__ + ) assert "xps" not in sys.modules sys.modules["xps"] = mod diff --git a/hypothesis-python/docs/numpy.rst b/hypothesis-python/docs/numpy.rst index 8f592e1107..c7e2be099b 100644 --- a/hypothesis-python/docs/numpy.rst +++ b/hypothesis-python/docs/numpy.rst @@ -75,6 +75,9 @@ The resulting namespace contains all our familiar strategies like :func:`~xps.arrays` and :func:`~xps.from_dtype`, but based on the Array API standard semantics and returning objects from the ``xp`` module: +.. + TODO: for next released xp version, include complex_dtypes here + .. automodule:: xps :members: from_dtype, @@ -83,6 +86,7 @@ standard semantics and returning objects from the ``xp`` module: scalar_dtypes, boolean_dtypes, numeric_dtypes, + real_dtypes, integer_dtypes, unsigned_integer_dtypes, floating_dtypes, diff --git a/hypothesis-python/src/hypothesis/extra/array_api.py b/hypothesis-python/src/hypothesis/extra/array_api.py index 0b849d89b2..6cd05d47c4 100644 --- a/hypothesis-python/src/hypothesis/extra/array_api.py +++ b/hypothesis-python/src/hypothesis/extra/array_api.py @@ -17,6 +17,7 @@ Iterable, Iterator, List, + Literal, Mapping, NamedTuple, Optional, @@ -25,8 +26,10 @@ Type, TypeVar, Union, + get_args, ) from warnings import warn +from weakref import WeakValueDictionary from hypothesis import strategies as st from hypothesis.errors import HypothesisWarning, InvalidArgument @@ -62,11 +65,20 @@ ] +RELEASED_VERSIONS = ("2021.12",) +NOMINAL_VERSIONS = RELEASED_VERSIONS + ("draft",) +assert sorted(NOMINAL_VERSIONS) == list(NOMINAL_VERSIONS) # sanity check +NominalVersion = Literal["2021.12", "draft"] +assert get_args(NominalVersion) == NOMINAL_VERSIONS # sanity check + + INT_NAMES = ("int8", "int16", "int32", "int64") UINT_NAMES = ("uint8", "uint16", "uint32", "uint64") ALL_INT_NAMES = INT_NAMES + UINT_NAMES FLOAT_NAMES = ("float32", "float64") -NUMERIC_NAMES = ALL_INT_NAMES + FLOAT_NAMES +REAL_NAMES = ALL_INT_NAMES + FLOAT_NAMES +COMPLEX_NAMES = ("complex64", "complex128") +NUMERIC_NAMES = REAL_NAMES + COMPLEX_NAMES DTYPE_NAMES = ("bool",) + NUMERIC_NAMES DataType = TypeVar("DataType") @@ -106,8 +118,8 @@ def warn_on_missing_dtypes(xp: Any, stubs: List[str]) -> None: def find_castable_builtin_for_dtype( - xp: Any, dtype: DataType -) -> Type[Union[bool, int, float]]: + xp: Any, api_version: NominalVersion, dtype: DataType +) -> Type[Union[bool, int, float, complex]]: """Returns builtin type which can have values that are castable to the given dtype, according to :xp-ref:`type promotion rules `. @@ -136,6 +148,15 @@ def find_castable_builtin_for_dtype( stubs.extend(int_stubs) stubs.extend(float_stubs) + + if api_version > "2021.12": + complex_dtypes, complex_stubs = partition_attributes_and_stubs( + xp, COMPLEX_NAMES + ) + if dtype in complex_dtypes: + return complex + stubs.extend(complex_stubs) + if len(stubs) > 0: warn_on_missing_dtypes(xp, stubs) raise InvalidArgument(f"dtype={dtype} not recognised in {xp.__name__}") @@ -159,6 +180,7 @@ def dtype_from_name(xp: Any, name: str) -> DataType: def _from_dtype( xp: Any, + api_version: NominalVersion, dtype: Union[DataType, str], *, min_value: Optional[Union[int, float]] = None, @@ -184,11 +206,12 @@ def _from_dtype( passed through from :func:`arrays()`, as it seamlessly handles the ``width`` or other representable bounds for you. """ + # TODO: for next released xp version, add note for complex dtype support check_xp_attributes(xp, ["iinfo", "finfo"]) if isinstance(dtype, str): dtype = dtype_from_name(xp, dtype) - builtin = find_castable_builtin_for_dtype(xp, dtype) + builtin = find_castable_builtin_for_dtype(xp, api_version, dtype) def check_valid_minmax(prefix, val, info_obj): name = f"{prefix}_value" @@ -218,7 +241,7 @@ def check_valid_minmax(prefix, val, info_obj): check_valid_minmax("max", max_value, iinfo) check_valid_interval(min_value, max_value, "min_value", "max_value") return st.integers(min_value=min_value, max_value=max_value) - else: + elif builtin is float: finfo = xp.finfo(dtype) kw = {} @@ -269,10 +292,37 @@ def check_valid_minmax(prefix, val, info_obj): kw["exclude_max"] = exclude_max return st.floats(width=finfo.bits, **kw) + else: + # A less-inelegant solution to support complex dtypes exists, but as + # this is currently a draft feature, we might as well wait for + # discussion of complex inspection to resolve first - a better method + # might become available soon enough. + # See https://github.com/data-apis/array-api/issues/433 + for attr in ["float32", "float64", "complex64"]: + if not hasattr(xp, attr): + raise NotImplementedError( + f"Array module {xp.__name__} has no dtype {attr}, which is " + "currently required for xps.from_dtype() to work with " + "any complex dtype." + ) + component_dtype = xp.float32 if dtype == xp.complex64 else xp.float64 + + floats = _from_dtype( + xp, + api_version, + component_dtype, + allow_nan=allow_nan, + allow_infinity=allow_infinity, + allow_subnormal=allow_subnormal, + ) + + return st.builds(complex, floats, floats) # type: ignore[arg-type] class ArrayStrategy(st.SearchStrategy): - def __init__(self, xp, elements_strategy, dtype, shape, fill, unique): + def __init__( + self, *, xp, api_version, elements_strategy, dtype, shape, fill, unique + ): self.xp = xp self.elements_strategy = elements_strategy self.dtype = dtype @@ -280,12 +330,11 @@ def __init__(self, xp, elements_strategy, dtype, shape, fill, unique): self.fill = fill self.unique = unique self.array_size = math.prod(shape) - self.builtin = find_castable_builtin_for_dtype(xp, dtype) + self.builtin = find_castable_builtin_for_dtype(xp, api_version, dtype) self.finfo = None if self.builtin is not float else xp.finfo(self.dtype) def check_set_value(self, val, val_0d, strategy): - finite = self.builtin is bool or self.xp.isfinite(val_0d) - if finite and self.builtin(val_0d) != val: + if val == val and self.builtin(val_0d) != val: if self.builtin is float: assert self.finfo is not None # for mypy try: @@ -418,6 +467,7 @@ def do_draw(self, data): def _arrays( xp: Any, + api_version: NominalVersion, dtype: Union[DataType, str, st.SearchStrategy[DataType], st.SearchStrategy[str]], shape: Union[int, Shape, st.SearchStrategy[Shape]], *, @@ -500,14 +550,18 @@ def _arrays( if isinstance(dtype, st.SearchStrategy): return dtype.flatmap( - lambda d: _arrays(xp, d, shape, elements=elements, fill=fill, unique=unique) + lambda d: _arrays( + xp, api_version, d, shape, elements=elements, fill=fill, unique=unique + ) ) elif isinstance(dtype, str): dtype = dtype_from_name(xp, dtype) if isinstance(shape, st.SearchStrategy): return shape.flatmap( - lambda s: _arrays(xp, dtype, s, elements=elements, fill=fill, unique=unique) + lambda s: _arrays( + xp, api_version, dtype, s, elements=elements, fill=fill, unique=unique + ) ) elif isinstance(shape, int): shape = (shape,) @@ -519,9 +573,9 @@ def _arrays( ) if elements is None: - elements = _from_dtype(xp, dtype) + elements = _from_dtype(xp, api_version, dtype) elif isinstance(elements, Mapping): - elements = _from_dtype(xp, dtype, **elements) + elements = _from_dtype(xp, api_version, dtype, **elements) check_strategy(elements, "elements") if fill is None: @@ -532,7 +586,15 @@ def _arrays( fill = elements check_strategy(fill, "fill") - return ArrayStrategy(xp, elements, dtype, shape, fill, unique) + return ArrayStrategy( + xp=xp, + api_version=api_version, + elements_strategy=elements, + dtype=dtype, + shape=shape, + fill=fill, + unique=unique, + ) @check_function @@ -548,9 +610,9 @@ def check_dtypes(xp: Any, dtypes: List[DataType], stubs: List[str]) -> None: warn_on_missing_dtypes(xp, stubs) -def _scalar_dtypes(xp: Any) -> st.SearchStrategy[DataType]: +def _scalar_dtypes(xp: Any, api_version: NominalVersion) -> st.SearchStrategy[DataType]: """Return a strategy for all :xp-ref:`valid dtype ` objects.""" - return st.one_of(_boolean_dtypes(xp), _numeric_dtypes(xp)) + return st.one_of(_boolean_dtypes(xp), _numeric_dtypes(xp, api_version)) def _boolean_dtypes(xp: Any) -> st.SearchStrategy[DataType]: @@ -563,8 +625,8 @@ def _boolean_dtypes(xp: Any) -> st.SearchStrategy[DataType]: ) from None -def _numeric_dtypes(xp: Any) -> st.SearchStrategy[DataType]: - """Return a strategy for all numeric dtype objects.""" +def _real_dtypes(xp: Any) -> st.SearchStrategy[DataType]: + """Return a strategy for all real-valued dtype objects.""" return st.one_of( _integer_dtypes(xp), _unsigned_integer_dtypes(xp), @@ -572,6 +634,16 @@ def _numeric_dtypes(xp: Any) -> st.SearchStrategy[DataType]: ) +def _numeric_dtypes( + xp: Any, api_version: NominalVersion +) -> st.SearchStrategy[DataType]: + """Return a strategy for all numeric dtype objects.""" + strat: st.SearchStrategy[DataType] = _real_dtypes(xp) + if api_version > "2021.12": + strat |= _complex_dtypes(xp) + return strat + + @check_function def check_valid_sizes( category: str, sizes: Sequence[int], valid_sizes: Sequence[int] @@ -634,7 +706,7 @@ def _unsigned_integer_dtypes( def _floating_dtypes( xp: Any, *, sizes: Union[int, Sequence[int]] = (32, 64) ) -> st.SearchStrategy[DataType]: - """Return a strategy for floating-point dtype objects. + """Return a strategy for real-valued floating-point dtype objects. ``sizes`` contains the floating-point sizes in bits, defaulting to ``(32, 64)`` which covers all valid sizes. @@ -649,6 +721,24 @@ def _floating_dtypes( return st.sampled_from(dtypes) +def _complex_dtypes( + xp: Any, *, sizes: Union[int, Sequence[int]] = (64, 128) +) -> st.SearchStrategy[DataType]: + """Return a strategy for complex dtype objects. + + ``sizes`` contains the complex sizes in bits, defaulting to ``(64, 128)`` + which covers all valid sizes. + """ + if isinstance(sizes, int): + sizes = (sizes,) + check_valid_sizes("complex", sizes, (64, 128)) + dtypes, stubs = partition_attributes_and_stubs( + xp, numeric_dtype_names("complex", sizes) + ) + check_dtypes(xp, dtypes, stubs) + return st.sampled_from(dtypes) + + @proxies(_valid_tuple_axes) def valid_tuple_axes(*args, **kwargs): return _valid_tuple_axes(*args, **kwargs) @@ -760,10 +850,21 @@ def indices( ) -def make_strategies_namespace(xp: Any) -> SimpleNamespace: +# Cache for make_strategies_namespace() +_args_to_xps: WeakValueDictionary = WeakValueDictionary() + + +def make_strategies_namespace( + xp: Any, *, api_version: Optional[NominalVersion] = None +) -> SimpleNamespace: """Creates a strategies namespace for the given array module. * ``xp`` is the Array API library to automatically pass to the namespaced methods. + * ``api_version`` is the version of the Array API which the returned + strategies namespace should conform to. If ``None``, the latest API + version which ``xp`` supports will be inferred from ``xp.__array_api_version__``. + If a version string in the ``YYYY.MM`` format, the strategies namespace + will conform to that version if supported. A :obj:`python:types.SimpleNamespace` is returned which contains all the strategy methods in this module but without requiring the ``xp`` argument. @@ -772,8 +873,11 @@ def make_strategies_namespace(xp: Any) -> SimpleNamespace: .. code-block:: pycon - >>> from numpy import array_api as xp + >>> xp.__array_api_version__ # xp is your desired array library + '2021.12' >>> xps = make_strategies_namespace(xp) + >>> xps.api_version + '2021.12' >>> x = xps.arrays(xp.int8, (2, 3)).example() >>> x Array([[-8, 6, 3], @@ -782,6 +886,34 @@ def make_strategies_namespace(xp: Any) -> SimpleNamespace: True """ + not_available_msg = ( + "If the standard version you want is not available, please ensure " + "you're using the latest version of Hypothesis, then open an issue if " + "one doesn't already exist." + ) + if api_version is None: + check_argument( + hasattr(xp, "__array_api_version__"), + f"Array module {xp.__name__} has no attribute __array_api_version__, " + "which is required when inferring api_version. If you believe " + f"{xp.__name__} is indeed an Array API module, try explicitly " + "passing an api_version.", + ) + check_argument( + isinstance(xp.__array_api_version__, str) + and xp.__array_api_version__ in RELEASED_VERSIONS, + f"{xp.__array_api_version__=}, but xp.__array_api_version__ must " + f"be a valid version string {RELEASED_VERSIONS}. {not_available_msg}", + ) + api_version = xp.__array_api_version__ + inferred_version = True + else: + check_argument( + isinstance(api_version, str) and api_version in NOMINAL_VERSIONS, + f"{api_version=}, but api_version must be None, or a valid version " + f"string {RELEASED_VERSIONS}. {not_available_msg}", + ) + inferred_version = False try: array = xp.zeros(1) array.__array_namespace__() @@ -791,6 +923,13 @@ def make_strategies_namespace(xp: Any) -> SimpleNamespace: HypothesisWarning, ) + try: + namespace = _args_to_xps[(xp, api_version)] + except (KeyError, TypeError): + pass + else: + return namespace + @defines_strategy(force_reusable_values=True) def from_dtype( dtype: Union[DataType, str], @@ -805,6 +944,7 @@ def from_dtype( ) -> st.SearchStrategy[Union[bool, int, float]]: return _from_dtype( xp, + api_version, # type: ignore[arg-type] dtype, min_value=min_value, max_value=max_value, @@ -828,6 +968,7 @@ def arrays( ) -> st.SearchStrategy: return _arrays( xp, + api_version, # type: ignore[arg-type] dtype, shape, elements=elements, @@ -837,15 +978,19 @@ def arrays( @defines_strategy() def scalar_dtypes() -> st.SearchStrategy[DataType]: - return _scalar_dtypes(xp) + return _scalar_dtypes(xp, api_version) # type: ignore[arg-type] @defines_strategy() def boolean_dtypes() -> st.SearchStrategy[DataType]: return _boolean_dtypes(xp) + @defines_strategy() + def real_dtypes() -> st.SearchStrategy[DataType]: + return _real_dtypes(xp) + @defines_strategy() def numeric_dtypes() -> st.SearchStrategy[DataType]: - return _numeric_dtypes(xp) + return _numeric_dtypes(xp, api_version) # type: ignore[arg-type] @defines_strategy() def integer_dtypes( @@ -869,21 +1014,45 @@ def floating_dtypes( arrays.__doc__ = _arrays.__doc__ scalar_dtypes.__doc__ = _scalar_dtypes.__doc__ boolean_dtypes.__doc__ = _boolean_dtypes.__doc__ + real_dtypes.__doc__ = _real_dtypes.__doc__ numeric_dtypes.__doc__ = _numeric_dtypes.__doc__ integer_dtypes.__doc__ = _integer_dtypes.__doc__ unsigned_integer_dtypes.__doc__ = _unsigned_integer_dtypes.__doc__ floating_dtypes.__doc__ = _floating_dtypes.__doc__ - class PrettySimpleNamespace(SimpleNamespace): - def __repr__(self): - return f"make_strategies_namespace({xp.__name__})" + class StrategiesNamespace(SimpleNamespace): + def __init__(self, **kwargs): + for attr in ["name", "api_version"]: + if attr not in kwargs.keys(): + raise ValueError(f"'{attr}' kwarg required") + super().__init__(**kwargs) + + @property + def complex_dtypes(self): + try: + return self.__dict__["complex_dtypes"] + except KeyError as e: + raise AttributeError( + "You attempted to access 'complex_dtypes', but it is not " + f"available for api_version='{self.api_version}' of " + f"xp={self.name}." + ) from e - return PrettySimpleNamespace( + def __repr__(self): + f_args = self.name + if not inferred_version: + f_args += f", api_version='{self.api_version}'" + return f"make_strategies_namespace({f_args})" + + kwargs = dict( + name=xp.__name__, + api_version=api_version, from_dtype=from_dtype, arrays=arrays, array_shapes=array_shapes, scalar_dtypes=scalar_dtypes, boolean_dtypes=boolean_dtypes, + real_dtypes=real_dtypes, numeric_dtypes=numeric_dtypes, integer_dtypes=integer_dtypes, unsigned_integer_dtypes=unsigned_integer_dtypes, @@ -894,6 +1063,25 @@ def __repr__(self): indices=indices, ) + if api_version > "2021.12": + + @defines_strategy() + def complex_dtypes( + *, sizes: Union[int, Sequence[int]] = (64, 128) + ) -> st.SearchStrategy[DataType]: + return _complex_dtypes(xp, sizes=sizes) + + complex_dtypes.__doc__ = _complex_dtypes.__doc__ + kwargs["complex_dtypes"] = complex_dtypes + + namespace = StrategiesNamespace(**kwargs) + try: + _args_to_xps[(xp, api_version)] = namespace + except TypeError: + pass + + return namespace + try: import numpy as np @@ -935,7 +1123,8 @@ def mock_finfo(dtype: DataType) -> FloatInfo: ) mock_xp = SimpleNamespace( - __name__="mockpy", + __name__="mock", + __array_api_version__="2021.12", # Data types int8=np.int8, int16=np.int16, @@ -947,6 +1136,8 @@ def mock_finfo(dtype: DataType) -> FloatInfo: uint64=np.uint64, float32=np.float32, float64=np.float64, + complex64=np.complex64, + complex128=np.complex128, bool=np.bool_, # Constants nan=np.nan, diff --git a/hypothesis-python/tests/array_api/README.md b/hypothesis-python/tests/array_api/README.md index e3dcefa03f..bc2f4d1aa2 100644 --- a/hypothesis-python/tests/array_api/README.md +++ b/hypothesis-python/tests/array_api/README.md @@ -1,16 +1,19 @@ This folder contains tests for `hypothesis.extra.array_api`. -## Running against different array modules +## Mocked array module + +A mock of the Array API namespace exists as `mock_xp` in `extra.array_api`. This +wraps NumPy-proper to conform it to the *draft* spec, where `numpy.array_api` +might not. This is not a fully compliant wrapper, but conforms enough for the +purposes of testing. -By default it will run against `numpy.array_api`. If that's not available -(likely because an older NumPy version is installed), these tests will fallback -to using the mock defined at the bottom of `src/hypothesis/extra/array_api.py`. +## Running against different array modules You can test other array modules which adopt the Array API via the `HYPOTHESIS_TEST_ARRAY_API` environment variable. There are two recognized options: -* `"default"`: only uses `numpy.array_api`, or if not available, fallbacks to the mock. +* `"default"`: uses the mock. * `"all"`: uses all array modules found via entry points, _and_ the mock. If neither of these, the test suite will then try resolve the variable like so: @@ -30,3 +33,18 @@ or use the import path (**2.**), The former method is more ergonomic, but as entry points are optional for adopting the Array API, you will need to use the latter method for libraries that opt-out. + +## Running against different API versions + +You can specify the `api_version` to use when testing array modules via the +`HYPOTHESIS_TEST_ARRAY_API_VERSION` environment variable. There is one +recognized option: + +* `"default"`: infers the latest API version for each array module. + +Otherwise the test suite will use the variable as the `api_version` argument for +`make_strategies_namespace()`. + +In the future we intend to support running tests against multiple API versioned +namespaces, likely with an additional recognized option that infers all +supported versions. diff --git a/hypothesis-python/tests/array_api/common.py b/hypothesis-python/tests/array_api/common.py index 72c674b435..cf706d970a 100644 --- a/hypothesis-python/tests/array_api/common.py +++ b/hypothesis-python/tests/array_api/common.py @@ -11,14 +11,30 @@ from importlib.metadata import EntryPoint, entry_points # type: ignore from typing import Dict +import pytest + +from hypothesis.extra.array_api import ( + COMPLEX_NAMES, + REAL_NAMES, + RELEASED_VERSIONS, + NominalVersion, +) from hypothesis.internal.floats import next_up __all__ = [ + "MIN_VER_FOR_COMPLEX:", "installed_array_modules", "flushes_to_zero", + "dtype_name_params", ] +# This should be updated to the next spec release, which should include complex numbers +MIN_VER_FOR_COMPLEX: NominalVersion = "draft" +if len(RELEASED_VERSIONS) > 1: + assert MIN_VER_FOR_COMPLEX == RELEASED_VERSIONS[1] + + def installed_array_modules() -> Dict[str, EntryPoint]: """Returns a dictionary of array module names paired to their entry points @@ -48,3 +64,9 @@ def flushes_to_zero(xp, width: int) -> bool: raise ValueError(f"{width=}, but should be either 32 or 64") dtype = getattr(xp, f"float{width}") return bool(xp.asarray(next_up(0.0, width=width), dtype=dtype) == 0) + + +dtype_name_params = ["bool"] + list(REAL_NAMES) +for name in COMPLEX_NAMES: + param = pytest.param(name, marks=pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX)) + dtype_name_params.append(param) diff --git a/hypothesis-python/tests/array_api/conftest.py b/hypothesis-python/tests/array_api/conftest.py index afa59d40a8..7ee6a1f4eb 100644 --- a/hypothesis-python/tests/array_api/conftest.py +++ b/hypothesis-python/tests/array_api/conftest.py @@ -11,62 +11,109 @@ import warnings from importlib import import_module from os import getenv +from types import ModuleType, SimpleNamespace +from typing import Tuple import pytest -from hypothesis.errors import HypothesisWarning -from hypothesis.extra.array_api import make_strategies_namespace, mock_xp +from hypothesis.errors import HypothesisWarning, InvalidArgument +from hypothesis.extra.array_api import ( + NOMINAL_VERSIONS, + NominalVersion, + make_strategies_namespace, + mock_xp, +) from tests.array_api.common import installed_array_modules +# See README.md in regards to the env variables +test_xp_option = getenv("HYPOTHESIS_TEST_ARRAY_API", "default") + +test_version_option = getenv("HYPOTHESIS_TEST_ARRAY_API_VERSION", "default") +if test_version_option != "default" and test_version_option not in NOMINAL_VERSIONS: + raise ValueError( + f"HYPOTHESIS_TEST_ARRAY_API_VERSION='{test_version_option}' is not " + f"'default' or a valid api_version {NOMINAL_VERSIONS}." + ) with pytest.warns(HypothesisWarning): - mock_xps = make_strategies_namespace(mock_xp) + mock_version = "draft" if test_version_option == "default" else test_version_option + mock_xps = make_strategies_namespace(mock_xp, api_version=mock_version) +api_version = None if test_version_option == "default" else test_version_option + + +class InvalidArgumentWarning(UserWarning): + """Custom warning so we can bypass our global capturing""" + -# See README.md in regards to the HYPOTHESIS_TEST_ARRAY_API env variable -test_xp_option = getenv("HYPOTHESIS_TEST_ARRAY_API", "default") name_to_entry_point = installed_array_modules() +xp_and_xps_pairs: Tuple[ModuleType, SimpleNamespace] = [] with warnings.catch_warnings(): - # We ignore all warnings here as many array modules warn on import + # We ignore all warnings here as many array modules warn on import. Ideally + # we would just ignore ImportWarning, but no one seems to use it! warnings.simplefilter("ignore") - # We go through the steps described in README.md to define `params`, which - # contains the array module(s) to be ran against the test suite. - # Specifically `params` is a list of pytest parameters, with each parameter - # containing the array module and its respective strategies namespace. + warnings.simplefilter("default", category=InvalidArgumentWarning) + # We go through the steps described in README.md to define `xp_xps_pairs`, + # which contains the array module(s) to be run against the test suite, along + # with their respective strategy namespaces. if test_xp_option == "default": - try: - xp = name_to_entry_point["numpy"].load() - xps = make_strategies_namespace(xp) - params = [pytest.param(xp, xps, id="numpy")] - except KeyError: - params = [pytest.param(mock_xp, mock_xps, id="mock")] + xp_and_xps_pairs = [(mock_xp, mock_xps)] elif test_xp_option == "all": if len(name_to_entry_point) == 0: raise ValueError( "HYPOTHESIS_TEST_ARRAY_API='all', but no entry points where found" ) - params = [pytest.param(mock_xp, mock_xps, id="mock")] + xp_and_xps_pairs = [(mock_xp, mock_xps)] for name, ep in name_to_entry_point.items(): xp = ep.load() - xps = make_strategies_namespace(xp) - params.append(pytest.param(xp, xps, id=name)) + try: + xps = make_strategies_namespace(xp, api_version=api_version) + except InvalidArgument as e: + warnings.warn(str(e), InvalidArgumentWarning) + else: + xp_and_xps_pairs.append((xp, xps)) elif test_xp_option in name_to_entry_point.keys(): ep = name_to_entry_point[test_xp_option] xp = ep.load() - xps = make_strategies_namespace(xp) - params = [pytest.param(xp, xps, id=test_xp_option)] + xps = make_strategies_namespace(xp, api_version=api_version) + xp_and_xps_pairs = [(xp, xps)] else: try: xp = import_module(test_xp_option) - xps = make_strategies_namespace(xp) - params = [pytest.param(xp, xps, id=test_xp_option)] except ImportError as e: raise ValueError( f"HYPOTHESIS_TEST_ARRAY_API='{test_xp_option}' is not a valid " "option ('default' or 'all'), name of an available entry point, " "or a valid import path." ) from e + else: + xps = make_strategies_namespace(xp, api_version=api_version) + xp_and_xps_pairs = [(xp, xps)] def pytest_generate_tests(metafunc): - if "xp" in metafunc.fixturenames and "xps" in metafunc.fixturenames: - metafunc.parametrize("xp, xps", params) + xp_params = [] + xp_and_xps_params = [] + for xp, xps in xp_and_xps_pairs: + xp_params.append(pytest.param(xp, id=xp.__name__)) + xp_and_xps_params.append( + pytest.param(xp, xps, id=f"{xp.__name__}-{xps.api_version}") + ) + if "xp" in metafunc.fixturenames: + if "xps" in metafunc.fixturenames: + metafunc.parametrize("xp, xps", xp_and_xps_params) + else: + metafunc.parametrize("xp", xp_params) + + +def pytest_collection_modifyitems(config, items): + for item in items: + if "xps" in item.fixturenames: + markers = [m for m in item.own_markers if m.name == "xp_min_version"] + if markers: + assert len(markers) == 1 # sanity check + min_version: NominalVersion = markers[0].args[0] + xps_version: NominalVersion = item.callspec.params["xps"].api_version + if xps_version < min_version: + item.add_marker( + pytest.mark.skip(reason=f"requires api_version=>{min_version}") + ) diff --git a/hypothesis-python/tests/array_api/test_argument_validation.py b/hypothesis-python/tests/array_api/test_argument_validation.py index 0ea07d20ce..3a14ee9987 100644 --- a/hypothesis-python/tests/array_api/test_argument_validation.py +++ b/hypothesis-python/tests/array_api/test_argument_validation.py @@ -8,14 +8,24 @@ # v. 2.0. If a copy of the MPL was not distributed with this file, You can # obtain one at https://mozilla.org/MPL/2.0/. +from typing import Optional + import pytest from hypothesis.errors import InvalidArgument +from hypothesis.extra.array_api import NominalVersion, make_strategies_namespace + +from tests.array_api.common import MIN_VER_FOR_COMPLEX -def e(name, **kwargs): +def e(name, *, _min_version: Optional[NominalVersion] = None, **kwargs): kw = ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) - return pytest.param(name, kwargs, id=f"{name}({kw})") + id_ = f"{name}({kw})" + if _min_version is None: + marks = () + else: + marks = pytest.mark.xp_min_version(_min_version) + return pytest.param(name, kwargs, id=id_, marks=marks) @pytest.mark.parametrize( @@ -61,6 +71,8 @@ def e(name, **kwargs): e("unsigned_integer_dtypes", sizes=(3,)), e("floating_dtypes", sizes=()), e("floating_dtypes", sizes=(3,)), + e("complex_dtypes", _min_version=MIN_VER_FOR_COMPLEX, sizes=()), + e("complex_dtypes", _min_version=MIN_VER_FOR_COMPLEX, sizes=(3,)), e("valid_tuple_axes", ndim=-1), e("valid_tuple_axes", ndim=2, min_size=-1), e("valid_tuple_axes", ndim=2, min_size=3, max_size=10), @@ -214,3 +226,10 @@ def test_raise_invalid_argument(xp, xps, strat_name, kwargs): strat = strat_func(**kwargs) with pytest.raises(InvalidArgument): strat.example() + + +@pytest.mark.parametrize("api_version", [..., "latest", "1970.01", 42]) +def test_make_strategies_namespace_raise_invalid_argument(xp, api_version): + """Function raises helpful error with invalid arguments.""" + with pytest.raises(InvalidArgument): + make_strategies_namespace(xp, api_version=api_version) diff --git a/hypothesis-python/tests/array_api/test_arrays.py b/hypothesis-python/tests/array_api/test_arrays.py index 6d9feff3d6..b36520688d 100644 --- a/hypothesis-python/tests/array_api/test_arrays.py +++ b/hypothesis-python/tests/array_api/test_arrays.py @@ -12,10 +12,14 @@ from hypothesis import given, strategies as st from hypothesis.errors import InvalidArgument -from hypothesis.extra.array_api import DTYPE_NAMES, NUMERIC_NAMES +from hypothesis.extra.array_api import COMPLEX_NAMES, REAL_NAMES from hypothesis.internal.floats import width_smallest_normals -from tests.array_api.common import flushes_to_zero +from tests.array_api.common import ( + MIN_VER_FOR_COMPLEX, + dtype_name_params, + flushes_to_zero, +) from tests.common.debug import assert_all_examples, find_any, minimal from tests.common.utils import flaky @@ -38,14 +42,14 @@ def xfail_on_indistinct_nans(xp): pytest.xfail("NaNs not distinct") -@pytest.mark.parametrize("dtype_name", DTYPE_NAMES) +@pytest.mark.parametrize("dtype_name", dtype_name_params) def test_draw_arrays_from_dtype(xp, xps, dtype_name): """Draw arrays from dtypes.""" dtype = getattr(xp, dtype_name) assert_all_examples(xps.arrays(dtype, ()), lambda x: x.dtype == dtype) -@pytest.mark.parametrize("dtype_name", DTYPE_NAMES) +@pytest.mark.parametrize("dtype_name", dtype_name_params) def test_draw_arrays_from_scalar_names(xp, xps, dtype_name): """Draw arrays from dtype names.""" dtype = getattr(xp, dtype_name) @@ -77,6 +81,10 @@ def test_draw_arrays_from_int_shapes(xp, xps, data): "integer_dtypes", "unsigned_integer_dtypes", "floating_dtypes", + "real_dtypes", + pytest.param( + "complex_dtypes", marks=pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX) + ), ], ) def test_draw_arrays_from_dtype_strategies(xp, xps, strat_name): @@ -86,14 +94,16 @@ def test_draw_arrays_from_dtype_strategies(xp, xps, strat_name): find_any(xps.arrays(strat, ())) -@given( - strat=st.lists(st.sampled_from(DTYPE_NAMES), min_size=1, unique=True).flatmap( - st.sampled_from - ) -) -def test_draw_arrays_from_dtype_name_strategies(xp, xps, strat): +@given(data=st.data()) +def test_draw_arrays_from_dtype_name_strategies(xp, xps, data): """Draw arrays from dtype name strategies.""" - find_any(xps.arrays(strat, ())) + all_names = ("bool",) + REAL_NAMES + if xps.api_version > "2021.12": + all_names += COMPLEX_NAMES + sample_names = data.draw( + st.lists(st.sampled_from(all_names), min_size=1, unique=True) + ) + find_any(xps.arrays(st.sampled_from(sample_names), ())) def test_generate_arrays_from_shapes_strategy(xp, xps): @@ -156,7 +166,7 @@ def test_minimize_arrays_with_0d_shape_strategy(xp, xps): assert smallest.shape == () -@pytest.mark.parametrize("dtype", NUMERIC_NAMES) +@pytest.mark.parametrize("dtype", dtype_name_params[1:]) def test_minimizes_numeric_arrays(xp, xps, dtype): """Strategies with numeric dtypes minimize to zero-filled arrays.""" smallest = minimal(xps.arrays(dtype, (2, 2))) @@ -296,6 +306,11 @@ def test_may_not_use_overflowing_integers(xp, xps, kwargs): [ ("float32", st.floats(min_value=10**40, allow_infinity=False)), ("float64", st.floats(min_value=10**40, allow_infinity=False)), + pytest.param( + "complex64", + st.complex_numbers(min_magnitude=10**300, allow_infinity=False), + marks=pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX), + ), ], ) def test_may_not_use_unrepresentable_elements(xp, xps, fill, dtype, strat): diff --git a/hypothesis-python/tests/array_api/test_from_dtype.py b/hypothesis-python/tests/array_api/test_from_dtype.py index 1ea61a5e2d..18f6d081f1 100644 --- a/hypothesis-python/tests/array_api/test_from_dtype.py +++ b/hypothesis-python/tests/array_api/test_from_dtype.py @@ -12,10 +12,10 @@ import pytest -from hypothesis.extra.array_api import DTYPE_NAMES, find_castable_builtin_for_dtype +from hypothesis.extra.array_api import find_castable_builtin_for_dtype from hypothesis.internal.floats import width_smallest_normals -from tests.array_api.common import flushes_to_zero +from tests.array_api.common import dtype_name_params, flushes_to_zero from tests.common.debug import ( assert_all_examples, assert_no_examples, @@ -24,32 +24,32 @@ ) -@pytest.mark.parametrize("dtype_name", DTYPE_NAMES) +@pytest.mark.parametrize("dtype_name", dtype_name_params) def test_strategies_have_reusable_values(xp, xps, dtype_name): """Inferred strategies have reusable values.""" strat = xps.from_dtype(dtype_name) assert strat.has_reusable_values -@pytest.mark.parametrize("dtype_name", DTYPE_NAMES) +@pytest.mark.parametrize("dtype_name", dtype_name_params) def test_produces_castable_instances_from_dtype(xp, xps, dtype_name): """Strategies inferred by dtype generate values of a builtin type castable to the dtype.""" dtype = getattr(xp, dtype_name) - builtin = find_castable_builtin_for_dtype(xp, dtype) + builtin = find_castable_builtin_for_dtype(xp, xps.api_version, dtype) assert_all_examples(xps.from_dtype(dtype), lambda v: isinstance(v, builtin)) -@pytest.mark.parametrize("dtype_name", DTYPE_NAMES) +@pytest.mark.parametrize("dtype_name", dtype_name_params) def test_produces_castable_instances_from_name(xp, xps, dtype_name): """Strategies inferred by dtype name generate values of a builtin type castable to the dtype.""" dtype = getattr(xp, dtype_name) - builtin = find_castable_builtin_for_dtype(xp, dtype) + builtin = find_castable_builtin_for_dtype(xp, xps.api_version, dtype) assert_all_examples(xps.from_dtype(dtype_name), lambda v: isinstance(v, builtin)) -@pytest.mark.parametrize("dtype_name", DTYPE_NAMES) +@pytest.mark.parametrize("dtype_name", dtype_name_params) def test_passing_inferred_strategies_in_arrays(xp, xps, dtype_name): """Inferred strategies usable in arrays strategy.""" elements = xps.from_dtype(dtype_name) diff --git a/hypothesis-python/tests/array_api/test_partial_adoptors.py b/hypothesis-python/tests/array_api/test_partial_adoptors.py index 6ae47a220d..2f0f410626 100644 --- a/hypothesis-python/tests/array_api/test_partial_adoptors.py +++ b/hypothesis-python/tests/array_api/test_partial_adoptors.py @@ -18,6 +18,7 @@ from hypothesis import given, strategies as st from hypothesis.errors import HypothesisWarning, InvalidArgument from hypothesis.extra.array_api import ( + COMPLEX_NAMES, DTYPE_NAMES, FLOAT_NAMES, INT_NAMES, @@ -30,8 +31,9 @@ @lru_cache() -def make_mock_xp(exclude: Tuple[str, ...] = ()) -> SimpleNamespace: +def make_mock_xp(*, exclude: Tuple[str, ...] = ()) -> SimpleNamespace: xp = copy(mock_xp) + assert isinstance(exclude, tuple) # sanity check for attr in exclude: delattr(xp, attr) return xp @@ -41,7 +43,7 @@ def test_warning_on_noncompliant_xp(): """Using non-compliant array modules raises helpful warning""" xp = make_mock_xp() with pytest.warns(HypothesisWarning, match=MOCK_WARN_MSG): - make_strategies_namespace(xp) + make_strategies_namespace(xp, api_version="draft") @pytest.mark.filterwarnings(f"ignore:.*{MOCK_WARN_MSG}.*") @@ -53,7 +55,7 @@ def test_error_on_missing_attr(stratname, args, attr): """Strategies raise helpful error when using array modules that lack required attributes.""" xp = make_mock_xp(exclude=(attr,)) - xps = make_strategies_namespace(xp) + xps = make_strategies_namespace(xp, api_version="draft") func = getattr(xps, stratname) with pytest.raises(InvalidArgument, match=f"{mock_xp.__name__}.*required.*{attr}"): func(*args).example() @@ -61,7 +63,7 @@ def test_error_on_missing_attr(stratname, args, attr): dtypeless_xp = make_mock_xp(exclude=tuple(DTYPE_NAMES)) with pytest.warns(HypothesisWarning): - dtypeless_xps = make_strategies_namespace(dtypeless_xp) + dtypeless_xps = make_strategies_namespace(dtypeless_xp, api_version="draft") @pytest.mark.parametrize( @@ -73,6 +75,8 @@ def test_error_on_missing_attr(stratname, args, attr): "integer_dtypes", "unsigned_integer_dtypes", "floating_dtypes", + "real_dtypes", + "complex_dtypes", ], ) def test_error_on_missing_dtypes(stratname): @@ -88,10 +92,12 @@ def test_error_on_missing_dtypes(stratname): "stratname, keep_anys", [ ("scalar_dtypes", [INT_NAMES, UINT_NAMES, FLOAT_NAMES]), - ("numeric_dtypes", [INT_NAMES, UINT_NAMES, FLOAT_NAMES]), + ("numeric_dtypes", [INT_NAMES, UINT_NAMES, FLOAT_NAMES, COMPLEX_NAMES]), ("integer_dtypes", [INT_NAMES]), ("unsigned_integer_dtypes", [UINT_NAMES]), ("floating_dtypes", [FLOAT_NAMES]), + ("real_dtypes", [INT_NAMES, UINT_NAMES, FLOAT_NAMES]), + ("complex_dtypes", [COMPLEX_NAMES]), ], ) @given(st.data()) @@ -111,7 +117,24 @@ def test_warning_on_partial_dtypes(stratname, keep_anys, data): ) ) xp = make_mock_xp(exclude=tuple(exclude)) - xps = make_strategies_namespace(xp) + xps = make_strategies_namespace(xp, api_version="draft") func = getattr(xps, stratname) with pytest.warns(HypothesisWarning, match=f"{mock_xp.__name__}.*dtype.*namespace"): data.draw(func()) + + +def test_raises_on_inferring_with_no_dunder_version(): + """When xp has no __array_api_version__, inferring api_version raises + helpful error.""" + xp = make_mock_xp(exclude=("__array_api_version__",)) + with pytest.raises(InvalidArgument, match="has no attribute"): + make_strategies_namespace(xp) + + +def test_raises_on_invalid_dunder_version(): + """When xp has invalid __array_api_version__, inferring api_version raises + helpful error.""" + xp = make_mock_xp() + xp.__array_api_version__ = None + with pytest.raises(InvalidArgument): + make_strategies_namespace(xp) diff --git a/hypothesis-python/tests/array_api/test_pretty.py b/hypothesis-python/tests/array_api/test_pretty.py index 79a7c6bb60..bb38f6427f 100644 --- a/hypothesis-python/tests/array_api/test_pretty.py +++ b/hypothesis-python/tests/array_api/test_pretty.py @@ -12,6 +12,11 @@ import pytest +from hypothesis.errors import InvalidArgument +from hypothesis.extra.array_api import make_strategies_namespace + +from tests.array_api.common import MIN_VER_FOR_COMPLEX + @pytest.mark.parametrize( "name", @@ -25,6 +30,10 @@ "integer_dtypes", "unsigned_integer_dtypes", "floating_dtypes", + "real_dtypes", + pytest.param( + "complex_dtypes", marks=pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX) + ), "valid_tuple_axes", "broadcastable_shapes", "mutually_broadcastable_shapes", @@ -55,6 +64,10 @@ def test_namespaced_methods_meta(xp, xps, name): ("integer_dtypes", []), ("unsigned_integer_dtypes", []), ("floating_dtypes", []), + ("real_dtypes", []), + pytest.param( + "complex_dtypes", [], marks=pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX) + ), ("valid_tuple_axes", [0]), ("broadcastable_shapes", [()]), ("mutually_broadcastable_shapes", [3]), @@ -70,8 +83,22 @@ def test_namespaced_strategies_repr(xp, xps, name, valid_args): assert xp.__name__ not in repr(strat), f"{xp.__name__} in strat repr" -def test_strategies_namespace_repr(xp, xps): - """Strategies namespace has good repr.""" +@pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning") +def test_inferred_version_strategies_namespace_repr(xp): + """Strategies namespace has good repr when api_version=None.""" + try: + xps = make_strategies_namespace(xp) + except InvalidArgument as e: + pytest.skip(str(e)) expected = f"make_strategies_namespace({xp.__name__})" assert repr(xps) == expected assert str(xps) == expected + + +@pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning") +def test_specified_version_strategies_namespace_repr(xp): + """Strategies namespace has good repr when api_version is specified.""" + xps = make_strategies_namespace(xp, api_version="2021.12") + expected = f"make_strategies_namespace({xp.__name__}, api_version='2021.12')" + assert repr(xps) == expected + assert str(xps) == expected diff --git a/hypothesis-python/tests/array_api/test_scalar_dtypes.py b/hypothesis-python/tests/array_api/test_scalar_dtypes.py index 4f733a1796..09e294bd1d 100644 --- a/hypothesis-python/tests/array_api/test_scalar_dtypes.py +++ b/hypothesis-python/tests/array_api/test_scalar_dtypes.py @@ -11,48 +11,86 @@ import pytest from hypothesis.extra.array_api import ( + COMPLEX_NAMES, DTYPE_NAMES, FLOAT_NAMES, INT_NAMES, NUMERIC_NAMES, + REAL_NAMES, UINT_NAMES, ) +from tests.array_api.common import MIN_VER_FOR_COMPLEX from tests.common.debug import assert_all_examples, find_any, minimal -def test_can_generate_scalar_dtypes(xp, xps): - dtypes = [getattr(xp, name) for name in DTYPE_NAMES] - assert_all_examples(xps.scalar_dtypes(), lambda dtype: dtype in dtypes) - - -def test_can_generate_boolean_dtypes(xp, xps): - assert_all_examples(xps.boolean_dtypes(), lambda dtype: dtype == xp.bool) +@pytest.mark.parametrize( + ("strat_name", "dtype_names"), + [ + ("integer_dtypes", INT_NAMES), + ("unsigned_integer_dtypes", UINT_NAMES), + ("floating_dtypes", FLOAT_NAMES), + ("real_dtypes", REAL_NAMES), + pytest.param( + "complex_dtypes", + COMPLEX_NAMES, + marks=pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX), + ), + ], +) +def test_all_generated_dtypes_are_of_group(xp, xps, strat_name, dtype_names): + """Strategy only generates expected dtypes.""" + strat_func = getattr(xps, strat_name) + dtypes = [getattr(xp, n) for n in dtype_names] + assert_all_examples(strat_func(), lambda dtype: dtype in dtypes) -def test_can_generate_numeric_dtypes(xp, xps): - numeric_dtypes = [getattr(xp, name) for name in NUMERIC_NAMES] - assert_all_examples(xps.numeric_dtypes(), lambda dtype: dtype in numeric_dtypes) +def test_all_generated_scalar_dtypes_are_scalar(xp, xps): + """Strategy only generates scalar dtypes.""" + if xps.api_version > "2021.12": + dtypes = [getattr(xp, n) for n in DTYPE_NAMES] + else: + dtypes = [getattr(xp, n) for n in ("bool",) + REAL_NAMES] + assert_all_examples(xps.scalar_dtypes(), lambda dtype: dtype in dtypes) -def test_can_generate_integer_dtypes(xp, xps): - int_dtypes = [getattr(xp, name) for name in INT_NAMES] - assert_all_examples(xps.integer_dtypes(), lambda dtype: dtype in int_dtypes) +def test_all_generated_numeric_dtypes_are_numeric(xp, xps): + """Strategy only generates numeric dtypes.""" + if xps.api_version > "2021.12": + dtypes = [getattr(xp, n) for n in NUMERIC_NAMES] + else: + dtypes = [getattr(xp, n) for n in REAL_NAMES] + assert_all_examples(xps.numeric_dtypes(), lambda dtype: dtype in dtypes) -def test_can_generate_unsigned_integer_dtypes(xp, xps): - uint_dtypes = [getattr(xp, name) for name in UINT_NAMES] - assert_all_examples( - xps.unsigned_integer_dtypes(), lambda dtype: dtype in uint_dtypes - ) +def skipif_unsupported_complex(strat_name, dtype_name): + if not dtype_name.startswith("complex"): + return strat_name, dtype_name + mark = pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX) + return pytest.param(strat_name, dtype_name, marks=mark) -def test_can_generate_floating_dtypes(xp, xps): - float_dtypes = [getattr(xp, name) for name in FLOAT_NAMES] - assert_all_examples(xps.floating_dtypes(), lambda dtype: dtype in float_dtypes) +@pytest.mark.parametrize( + ("strat_name", "dtype_name"), + [ + *[skipif_unsupported_complex("scalar_dtypes", n) for n in DTYPE_NAMES], + *[skipif_unsupported_complex("numeric_dtypes", n) for n in NUMERIC_NAMES], + *[("integer_dtypes", n) for n in INT_NAMES], + *[("unsigned_integer_dtypes", n) for n in UINT_NAMES], + *[("floating_dtypes", n) for n in FLOAT_NAMES], + *[("real_dtypes", n) for n in REAL_NAMES], + *[skipif_unsupported_complex("complex_dtypes", n) for n in COMPLEX_NAMES], + ], +) +def test_strategy_can_generate_every_dtype(xp, xps, strat_name, dtype_name): + """Strategy generates every expected dtype.""" + strat_func = getattr(xps, strat_name) + dtype = getattr(xp, dtype_name) + find_any(strat_func(), lambda d: d == dtype) def test_minimise_scalar_dtypes(xp, xps): + """Strategy minimizes to bool dtype.""" assert minimal(xps.scalar_dtypes()) == xp.bool @@ -62,9 +100,13 @@ def test_minimise_scalar_dtypes(xp, xps): ("integer_dtypes", 8), ("unsigned_integer_dtypes", 8), ("floating_dtypes", 32), + pytest.param( + "complex_dtypes", 64, marks=pytest.mark.xp_min_version(MIN_VER_FOR_COMPLEX) + ), ], ) def test_can_specify_sizes_as_an_int(xp, xps, strat_name, sizes): + """Strategy treats ints as a single size.""" strat_func = getattr(xps, strat_name) strat = strat_func(sizes=sizes) find_any(strat) diff --git a/hypothesis-python/tests/array_api/test_strategies_namespace.py b/hypothesis-python/tests/array_api/test_strategies_namespace.py new file mode 100644 index 0000000000..d1f18a6c25 --- /dev/null +++ b/hypothesis-python/tests/array_api/test_strategies_namespace.py @@ -0,0 +1,86 @@ +# This file is part of Hypothesis, which may be found at +# https://github.com/HypothesisWorks/hypothesis/ +# +# Copyright the Hypothesis Authors. +# Individual contributors are listed in AUTHORS.rst and the git log. +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v. 2.0. If a copy of the MPL was not distributed with this file, You can +# obtain one at https://mozilla.org/MPL/2.0/. + +from types import SimpleNamespace +from weakref import WeakValueDictionary + +import pytest + +from hypothesis.extra import array_api +from hypothesis.extra.array_api import ( + NOMINAL_VERSIONS, + make_strategies_namespace, + mock_xp, +) +from hypothesis.strategies import SearchStrategy + +pytestmark = pytest.mark.filterwarnings("ignore::hypothesis.errors.HypothesisWarning") + + +class HashableArrayModuleFactory: + """ + mock_xp cannot be hashed and thus cannot be used in our cache. So just for + the purposes of testing the cache, we wrap it with an unsafe hash method. + """ + + def __getattr__(self, name): + return getattr(mock_xp, name) + + def __hash__(self): + return hash(tuple(sorted(mock_xp.__dict__))) + + +@pytest.mark.parametrize("api_version", ["2021.12", None]) +def test_caching(api_version, monkeypatch): + """Caches namespaces respective to arguments.""" + xp = HashableArrayModuleFactory() + assert isinstance(array_api._args_to_xps, WeakValueDictionary) # sanity check + monkeypatch.setattr(array_api, "_args_to_xps", WeakValueDictionary()) + assert len(array_api._args_to_xps) == 0 # sanity check + xps1 = array_api.make_strategies_namespace(xp, api_version=api_version) + assert len(array_api._args_to_xps) == 1 + xps2 = array_api.make_strategies_namespace(xp, api_version=api_version) + assert len(array_api._args_to_xps) == 1 + assert isinstance(xps2, SimpleNamespace) + assert xps2 is xps1 + del xps1 + del xps2 + assert len(array_api._args_to_xps) == 0 + + +@pytest.mark.parametrize( + "api_version1, api_version2", [(None, "2021.12"), ("2021.12", None)] +) +def test_inferred_namespace_shares_cache(api_version1, api_version2, monkeypatch): + """Results from inferred versions share the same cache key as results + from specified versions.""" + xp = HashableArrayModuleFactory() + xp.__array_api_version__ = "2021.12" + assert isinstance(array_api._args_to_xps, WeakValueDictionary) # sanity check + monkeypatch.setattr(array_api, "_args_to_xps", WeakValueDictionary()) + assert len(array_api._args_to_xps) == 0 # sanity check + xps1 = array_api.make_strategies_namespace(xp, api_version=api_version1) + assert xps1.api_version == "2021.12" # sanity check + assert len(array_api._args_to_xps) == 1 + xps2 = array_api.make_strategies_namespace(xp, api_version=api_version2) + assert xps2.api_version == "2021.12" # sanity check + assert len(array_api._args_to_xps) == 1 + assert xps2 is xps1 + + +def test_complex_dtypes_raises_on_2021_12(): + """Accessing complex_dtypes() for 2021.12 strategy namespace raises helpful + error, but accessing on future versions returns expected strategy.""" + first_xps = make_strategies_namespace(mock_xp, api_version="2021.12") + with pytest.raises(AttributeError, match="attempted to access"): + first_xps.complex_dtypes() + for api_version in NOMINAL_VERSIONS[1:]: + xps = make_strategies_namespace(mock_xp, api_version=api_version) + assert isinstance(xps.complex_dtypes(), SearchStrategy) diff --git a/hypothesis-python/tests/conftest.py b/hypothesis-python/tests/conftest.py index bcbd024903..cb015988fb 100644 --- a/hypothesis-python/tests/conftest.py +++ b/hypothesis-python/tests/conftest.py @@ -40,6 +40,10 @@ def pytest_configure(config): config.addinivalue_line("markers", "slow: pandas expects this marker to exist.") + config.addinivalue_line( + "markers", + "xp_min_version(api_version): run when greater or equal to api_version", + ) def pytest_addoption(parser):