Skip to content

Commit

Permalink
Refactor numpy resolution a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
jobh committed Jun 8, 2023
1 parent 47613b7 commit d256282
Showing 1 changed file with 97 additions and 74 deletions.
171 changes: 97 additions & 74 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,111 +999,134 @@ def array_for(index_shape, size):
)


def _unpack_generic(thing):
    """Split a possibly-parameterised generic into ``(origin, args)``.

    If ``thing`` has a non-None ``__origin__`` attribute, return that origin
    together with ``thing.__args__`` (or ``()`` when ``__args__`` is absent).
    Otherwise ``thing`` is returned unchanged, paired with an empty tuple.
    """
    # get_origin and get_args fail on python<3.9 because (some of) the
    # relevant types do not inherit from _GenericAlias. So just pick the
    # values out directly.
    origin = getattr(thing, "__origin__", None)
    if origin is None:
        return (thing, ())
    return (origin, getattr(thing, "__args__", ()))


def _unpack_dtype(dtype):
    """Return the scalar type carried by a (possibly generic) dtype annotation.

    * No ``__args__`` (or an empty one): ``dtype`` is returned unchanged.
    * A single TypeVar argument, e.g. ``numpy.dtype[+ScalarType]``: the
      TypeVar must be bound to ``np.generic`` and ``Any`` is returned.
    * A single concrete argument: that argument is returned.

    More than one argument, or a TypeVar with a different bound, trips an
    assertion.
    """
    params = getattr(dtype, "__args__", ())
    if not params:
        # plain, unparameterised dtype
        return dtype
    assert len(params) == 1
    param = params[0]
    if not isinstance(param, TypeVar):
        # parameterised with a concrete scalar type
        return param
    # numpy.dtype[+ScalarType]
    assert param.__bound__ == np.generic
    return Any


def _dtype_and_shape_from_args(args):
    """Turn the type arguments of an array annotation into ``(dtype, shape)``.

    ``args`` is the ``__args__`` tuple of the annotation:

    * zero args (``ndarray``, ``_SupportsArray``): dtype and shape are
      both unconstrained;
    * one arg (``ndarray[type]``, ``_SupportsArray[type]``): the arg is the
      dtype, shape is unconstrained;
    * two args (``ndarray[shape, type]``, ``NDArray[*]``): the first is the
      shape, which is asserted to be ``Any``, and the second is the dtype.

    An unconstrained dtype becomes ``scalar_dtypes()`` and an unconstrained
    shape becomes ``array_shapes(max_dims=2)``; a concrete dtype is wrapped
    in ``np.dtype``.
    """
    if len(args) > 1:
        # Two args: ndarray[shape, type], NDArray[*]
        assert len(args) == 2
        shape = args[0]
        assert shape is Any
        dtype = _unpack_dtype(args[1])
    else:
        # Zero args: ndarray, _SupportsArray
        # One arg: ndarray[type], _SupportsArray[type]
        shape = Any
        dtype = _unpack_dtype(args[0]) if args else Any

    # Build the dtype part before the shape part, as the original tuple did.
    resolved_dtype = scalar_dtypes() if dtype is Any else np.dtype(dtype)
    resolved_shape = array_shapes(max_dims=2) if shape is Any else shape
    return (resolved_dtype, resolved_shape)


def _from_type(thing: Type[Ex]) -> Optional[st.SearchStrategy[Ex]]:
"""Called by st.from_type to try to infer a strategy for thing using numpy.
If we can infer a dtype strategy for thing, we return that; otherwise,
returns None (or raises).
If we can infer a numpy-specific strategy for thing, we return that; otherwise,
we return None.
"""

def unpack_generic(thing):
# get_origin and get_args fail on python<3.9 because (some of) the
# relevant types do not inherit from _GenericAlias. So just pick the
# value out directly.
real_thing = getattr(thing, "__origin__", None)
if real_thing is not None:
return (real_thing, getattr(thing, "__args__", ()))
else:
return (thing, ())

def unpack_dtype(dtype):
dtype_args = getattr(dtype, "__args__", ())
if dtype_args:
assert len(dtype_args) == 1
if isinstance(dtype_args[0], TypeVar):
# numpy.dtype[+ScalarType]
assert dtype_args[0].__bound__ == np.generic
dtype = Any
else:
# plain dtype
dtype = dtype_args[0]
return dtype

def find_dtype_shape(args):
if len(args) <= 1:
# Zero args: ndarray, _SupportsArray
# One arg: ndarray[type], _SupportsArray[type]
shape = Any
dtype = unpack_dtype(args[0]) if args else Any
else:
# Two args: ndarray[shape, type], NDArray[*]
assert len(args) == 2
shape = args[0]
assert shape is Any
dtype = unpack_dtype(args[1])
return (
scalar_dtypes() if dtype is Any else np.dtype(dtype),
array_shapes(max_dims=2) if shape is Any else shape,
)

def base_strats():
return [
st.booleans(),
st.integers(),
st.floats(),
st.complex_numbers(),
st.text(),
st.binary(),
]
base_strats = st.one_of([
st.booleans(),
st.integers(),
st.floats(),
st.complex_numbers(),
st.text(),
st.binary(),
])
# np.array(arr_like) (1.24.3) fails if mixing strings and non-ascii
# bytestrings (ex: ['', b'\x80'])
base_strats_ascii = st.one_of([
st.booleans(),
st.integers(),
st.floats(),
st.complex_numbers(),
st.text(),
st.binary().filter(bytes.isascii),
])

if thing == np.dtype:
return array_dtypes()

if thing == ArrayLike:
# We override the default type resolution to ensure the "coercible to
# array" contract is honoured. See
# https://github.com/HypothesisWorks/hypothesis/pull/3670#issuecomment-1578140422
base_strat = st.one_of(base_strats())
base_strat_ex = st.one_of(base_strats()[:-1])
# https://github.com/HypothesisWorks/hypothesis/pull/3670#issuecomment-1578140422.
# The actual type is (as of np 1.24), with
# scalars:=[bool, int, float, complex, str, bytes]:
# Union[
# _SupportsArray,
# _NestedSequence[_SupportsArray],
# *scalars,
# _NestedSequence[Union[*scalars]]
# ]
return st.one_of(
base_strat,
# Exclude binary() from mixed lists because it can fail when
# combined with text() - see linked comment above
st.lists(st.one_of(base_strat_ex)),
# Recurse on tuples to get (up to size-2) nested equal-length
# sequences
st.recursive(st.tuples(), st.tuples),
st.recursive(st.tuples(base_strat), st.tuples),
st.recursive(st.tuples(base_strat_ex, base_strat_ex), st.tuples),
# *scalars
base_strats,
# _NestedSequence[Union[*scalars]], but excluding non-ascii binary
st.lists(base_strats_ascii),
# _SupportsArray, use plain ndarrays
st.from_type(np.ndarray),
# _NestedSequence[_SupportsArray], but guaranteeing equal size
st.integers(min_value=0, max_value=4).flatmap(
lambda s: st.one_of(
st.recursive(st.lists(
base_strats_ascii,
min_size=s, max_size=s
), extend=st.tuples),
st.recursive(arrays(
scalar_dtypes(),
array_shapes(min_dims=s, max_dims=s, min_side=s, max_side=s),
), extend=st.tuples),
),
),
)

if isinstance(thing, type) and issubclass(thing, np.generic):
dtype = np.dtype(thing)
return from_dtype(dtype) if dtype.kind not in "OV" else None

real_thing, args = unpack_generic(thing)
real_thing, args = _unpack_generic(thing)

if real_thing == _NestedSequence:
# We have to override the default resolution to ensure sequences are of
# equal length. Actually they are still not, due to base_strat possibly
# returning arbitrary-shaped arrays - hence the even more special
# equal length. Actually they are still not, if the arg specialization
# returns arbitrary-shaped sequences or arrays - hence the even more special
# resolution of ArrayLike, above.
assert len(args) <= 1
base_strat = st.from_type(args[0]) if args else st.one_of(base_strats())
base_strat = st.from_type(args[0]) if args else base_strats
return st.one_of(
st.lists(base_strat),
st.recursive(st.tuples(), st.tuples),
st.recursive(st.tuples(base_strat), st.tuples),
st.recursive(st.tuples(base_strat, base_strat), st.tuples),
st.lists(base_strat),
)

if real_thing == _SupportsArray:
dtype, shape = find_dtype_shape(args)
return arrays(dtype, shape)

if isinstance(real_thing, type) and issubclass(real_thing, np.ndarray):
dtype, shape = find_dtype_shape(args)
if real_thing in [np.ndarray, _SupportsArray]:
dtype, shape = _dtype_and_shape_from_args(args)
return arrays(dtype, shape)

# We didn't find a type to resolve, continue
Expand Down

0 comments on commit d256282

Please sign in to comment.