diff --git a/hypothesis-python/src/hypothesis/extra/numpy.py b/hypothesis-python/src/hypothesis/extra/numpy.py index f7eda5fc20..a9e5628a47 100644 --- a/hypothesis-python/src/hypothesis/extra/numpy.py +++ b/hypothesis-python/src/hypothesis/extra/numpy.py @@ -999,62 +999,74 @@ def array_for(index_shape, size): ) +def _unpack_generic(thing): + # get_origin and get_args fail on python<3.9 because (some of) the + # relevant types do not inherit from _GenericAlias. So just pick the + # value out directly. + real_thing = getattr(thing, "__origin__", None) + if real_thing is not None: + return (real_thing, getattr(thing, "__args__", ())) + else: + return (thing, ()) + + +def _unpack_dtype(dtype): + dtype_args = getattr(dtype, "__args__", ()) + if dtype_args: + assert len(dtype_args) == 1 + if isinstance(dtype_args[0], TypeVar): + # numpy.dtype[+ScalarType] + assert dtype_args[0].__bound__ == np.generic + dtype = Any + else: + # plain dtype + dtype = dtype_args[0] + return dtype + + +def _dtype_and_shape_from_args(args): + if len(args) <= 1: + # Zero args: ndarray, _SupportsArray + # One arg: ndarray[type], _SupportsArray[type] + shape = Any + dtype = _unpack_dtype(args[0]) if args else Any + else: + # Two args: ndarray[shape, type], NDArray[*] + assert len(args) == 2 + shape = args[0] + assert shape is Any + dtype = _unpack_dtype(args[1]) + return ( + scalar_dtypes() if dtype is Any else np.dtype(dtype), + array_shapes(max_dims=2) if shape is Any else shape, + ) + + def _from_type(thing: Type[Ex]) -> Optional[st.SearchStrategy[Ex]]: """Called by st.from_type to try to infer a strategy for thing using numpy. - If we can infer a dtype strategy for thing, we return that; otherwise, - returns None (or raises). + If we can infer a numpy-specific strategy for thing, we return that; otherwise, + we return None. """ - def unpack_generic(thing): - # get_origin and get_args fail on python<3.9 because (some of) the - # relevant types do not inherit from _GenericAlias. So just pick the - # value out directly. - real_thing = getattr(thing, "__origin__", None) - if real_thing is not None: - return (real_thing, getattr(thing, "__args__", ())) - else: - return (thing, ()) - - def unpack_dtype(dtype): - dtype_args = getattr(dtype, "__args__", ()) - if dtype_args: - assert len(dtype_args) == 1 - if isinstance(dtype_args[0], TypeVar): - # numpy.dtype[+ScalarType] - assert dtype_args[0].__bound__ == np.generic - dtype = Any - else: - # plain dtype - dtype = dtype_args[0] - return dtype - - def find_dtype_shape(args): - if len(args) <= 1: - # Zero args: ndarray, _SupportsArray - # One arg: ndarray[type], _SupportsArray[type] - shape = Any - dtype = unpack_dtype(args[0]) if args else Any - else: - # Two args: ndarray[shape, type], NDArray[*] - assert len(args) == 2 - shape = args[0] - assert shape is Any - dtype = unpack_dtype(args[1]) - return ( - scalar_dtypes() if dtype is Any else np.dtype(dtype), - array_shapes(max_dims=2) if shape is Any else shape, - ) - - def base_strats(): - return [ - st.booleans(), - st.integers(), - st.floats(), - st.complex_numbers(), - st.text(), - st.binary(), - ] + base_strats = st.one_of([ + st.booleans(), + st.integers(), + st.floats(), + st.complex_numbers(), + st.text(), + st.binary(), + ]) + # np.array(arr_like) (1.24.3) fails if mixing strings and non-ascii + # bytestrings (ex: ['', b'\x80']) + base_strats_ascii = st.one_of([ + st.booleans(), + st.integers(), + st.floats(), + st.complex_numbers(), + st.text(), + st.binary().filter(bytes.isascii), + ]) if thing == np.dtype: return array_dtypes() @@ -1062,48 +1074,59 @@ def base_strats(): if thing == ArrayLike: # We override the default type resolution to ensure the "coercible to # array" contract is honoured. See - # https://github.com/HypothesisWorks/hypothesis/pull/3670#issuecomment-1578140422 - base_strat = st.one_of(base_strats()) - base_strat_ex = st.one_of(base_strats()[:-1]) + # https://github.com/HypothesisWorks/hypothesis/pull/3670#issuecomment-1578140422. + # The actual type is (as of np 1.24), with + # scalars:=[bool, int, float, complex, str, bytes]: + # Union[ + # _SupportsArray, + # _NestedSequence[_SupportsArray], + # *scalars, + # _NestedSequence[Union[*scalars]] + # ] return st.one_of( - base_strat, - # Exclude binary() from mixed lists because it can fail when - # combined with text() - see linked comment above - st.lists(st.one_of(base_strat_ex)), - # Recurse on tuples to get (up to size-2) nested equal-length - # sequencess - st.recursive(st.tuples(), st.tuples), - st.recursive(st.tuples(base_strat), st.tuples), - st.recursive(st.tuples(base_strat_ex, base_strat_ex), st.tuples), + # *scalars + base_strats, + # _NestedSequence[Union[*scalars]], but excluding non-ascii binary + st.lists(base_strats_ascii), + # _SupportsArray, use plain ndarrays st.from_type(np.ndarray), + # _NestedSequence[_SupportsArray], but guaranteeing equal size + st.integers(min_value=0, max_value=4).flatmap( + lambda s: st.one_of( + st.recursive(st.lists( + base_strats_ascii, + min_size=s, max_size=s + ), extend=st.tuples), + st.recursive(arrays( + scalar_dtypes(), + array_shapes(min_dims=s, max_dims=s, min_side=s, max_side=s), + ), extend=st.tuples), + ), + ), ) if isinstance(thing, type) and issubclass(thing, np.generic): dtype = np.dtype(thing) return from_dtype(dtype) if dtype.kind not in "OV" else None - real_thing, args = unpack_generic(thing) + real_thing, args = _unpack_generic(thing) if real_thing == _NestedSequence: # We have to override the default resolution to ensure sequences are of - # equal length. Actually they are still not, due to base_strat possibly - # returning arbitrary-shaped arrays - hence the even more special + # equal length. Actually they are still not, if the arg specialization + # returns arbitrary-shaped sequences or arrays - hence the even more special # resolution of ArrayLike, above. assert len(args) <= 1 - base_strat = st.from_type(args[0]) if args else st.one_of(base_strats()) + base_strat = st.from_type(args[0]) if args else base_strats return st.one_of( + st.lists(base_strat), st.recursive(st.tuples(), st.tuples), st.recursive(st.tuples(base_strat), st.tuples), st.recursive(st.tuples(base_strat, base_strat), st.tuples), - st.lists(base_strat), ) - if real_thing == _SupportsArray: - dtype, shape = find_dtype_shape(args) - return arrays(dtype, shape) - - if isinstance(real_thing, type) and issubclass(real_thing, np.ndarray): - dtype, shape = find_dtype_shape(args) + if real_thing in [np.ndarray, _SupportsArray]: + dtype, shape = _dtype_and_shape_from_args(args) return arrays(dtype, shape) # We didn't find a type to resolve, continue