Skip to content

Commit

Permalink
Merge pull request #3889 from JonathanPlasse/fix-extra-numpy-arrays-t…
Browse files Browse the repository at this point in the history
…ype-signature

Fix and improve hypothesis.extra.numpy typing
  • Loading branch information
Zac-HD committed Mar 4, 2024
2 parents 30c2b96 + 162d013 commit b9f3d75
Show file tree
Hide file tree
Showing 13 changed files with 206 additions and 37 deletions.
1 change: 1 addition & 0 deletions hypothesis-python/.coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,4 @@ exclude_lines =
if sys\.version_info
if "[\w\.]+" in sys\.modules:
if .+ := sys\.modules\.get\("[\w\.]+"\)
@overload
7 changes: 7 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
RELEASE_TYPE: patch

This patch improves the type annotations in :mod:`hypothesis.extra.numpy`,
which makes inferred types more precise for both :pypi:`mypy` and
:pypi:`pyright`, and fixes some strict-mode errors on the latter.

Thanks to Jonathan Plasse for reporting and fixing this in :pull:`3889`!
4 changes: 2 additions & 2 deletions hypothesis-python/src/hypothesis/extra/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ def complex_dtypes(

np = Mock()
else:
np = None
np = None # type: ignore[assignment]
if np is not None:

class FloatInfo(NamedTuple):
Expand All @@ -1112,7 +1112,7 @@ def mock_finfo(dtype: DataType) -> FloatInfo:
introduced it in v1.21.1, so we just use the equivalent tiny attribute
to keep mocking with older versions working.
"""
_finfo = np.finfo(dtype)
_finfo = np.finfo(dtype) # type: ignore[call-overload]
return FloatInfo(
int(_finfo.bits),
float(_finfo.eps),
Expand Down
89 changes: 70 additions & 19 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
Type,
TypeVar,
Union,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -136,7 +138,7 @@ def from_dtype(
kwargs = {k: v for k, v in locals().items() if k != "dtype" and v is not None}

# Compound datatypes, eg 'f4,f4,f4'
if dtype.names is not None:
if dtype.names is not None and dtype.fields is not None:
# mapping np.void.type over a strategy is nonsense, so return now.
subs = [from_dtype(dtype.fields[name][0], **kwargs) for name in dtype.names]
return st.tuples(*subs)
Expand Down Expand Up @@ -164,7 +166,7 @@ def compat_kw(*args, **kw):
result: st.SearchStrategy[Any] = st.booleans()
elif dtype.kind == "f":
result = st.floats(
width=min(8 * dtype.itemsize, 64),
width=cast(Literal[16, 32, 64], min(8 * dtype.itemsize, 64)),
**compat_kw(
"min_value",
"max_value",
Expand All @@ -177,7 +179,9 @@ def compat_kw(*args, **kw):
)
elif dtype.kind == "c":
result = st.complex_numbers(
width=min(8 * dtype.itemsize, 128), # convert from bytes to bits
width=cast(
Literal[32, 64, 128], min(8 * dtype.itemsize, 128)
), # convert from bytes to bits
**compat_kw(
"min_magnitude",
"max_magnitude",
Expand Down Expand Up @@ -411,6 +415,31 @@ def fill_for(elements, unique, fill, name=""):


D = TypeVar("D", bound="DTypeLike")
G = TypeVar("G", bound="np.generic")


@overload
@defines_strategy(force_reusable_values=True)
def arrays(
dtype: Union["np.dtype[G]", st.SearchStrategy["np.dtype[G]"]],
shape: Union[int, st.SearchStrategy[int], Shape, st.SearchStrategy[Shape]],
*,
elements: Optional[Union[st.SearchStrategy[Any], Mapping[str, Any]]] = None,
fill: Optional[st.SearchStrategy[Any]] = None,
unique: bool = False,
) -> "st.SearchStrategy[NDArray[G]]": ...


@overload
@defines_strategy(force_reusable_values=True)
def arrays(
dtype: Union[D, st.SearchStrategy[D]],
shape: Union[int, st.SearchStrategy[int], Shape, st.SearchStrategy[Shape]],
*,
elements: Optional[Union[st.SearchStrategy[Any], Mapping[str, Any]]] = None,
fill: Optional[st.SearchStrategy[Any]] = None,
unique: bool = False,
) -> "st.SearchStrategy[NDArray[Any]]": ...


@defines_strategy(force_reusable_values=True)
Expand All @@ -421,7 +450,7 @@ def arrays(
elements: Optional[Union[st.SearchStrategy[Any], Mapping[str, Any]]] = None,
fill: Optional[st.SearchStrategy[Any]] = None,
unique: bool = False,
) -> "st.SearchStrategy[NDArray[D]]":
) -> "st.SearchStrategy[NDArray[Any]]":
r"""Returns a strategy for generating :class:`numpy:numpy.ndarray`\ s.
* ``dtype`` may be any valid input to :class:`~numpy:numpy.dtype`
Expand Down Expand Up @@ -498,7 +527,7 @@ def arrays(
lambda s: arrays(dtype, s, elements=elements, fill=fill, unique=unique)
)
# From here on, we're only dealing with values and it's relatively simple.
dtype = np.dtype(dtype)
dtype = np.dtype(dtype) # type: ignore[arg-type,assignment]
assert isinstance(dtype, np.dtype) # help mypy out a bit...
if elements is None or isinstance(elements, Mapping):
if dtype.kind in ("m", "M") and "[" not in dtype.str:
Expand Down Expand Up @@ -554,8 +583,8 @@ def inner(*args, **kwargs):


@defines_dtype_strategy
def boolean_dtypes() -> st.SearchStrategy[np.dtype]:
return st.just("?")
def boolean_dtypes() -> st.SearchStrategy["np.dtype[np.bool_]"]:
return st.just("?") # type: ignore[arg-type]


def dtype_factory(kind, sizes, valid_sizes, endianness):
Expand Down Expand Up @@ -592,7 +621,7 @@ def unsigned_integer_dtypes(
*,
endianness: str = "?",
sizes: Sequence[Literal[8, 16, 32, 64]] = (8, 16, 32, 64),
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.unsignedinteger[Any]]"]:
"""Return a strategy for unsigned integer dtypes.
endianness may be ``<`` for little-endian, ``>`` for big-endian,
Expand All @@ -610,7 +639,7 @@ def integer_dtypes(
*,
endianness: str = "?",
sizes: Sequence[Literal[8, 16, 32, 64]] = (8, 16, 32, 64),
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.signedinteger[Any]]"]:
"""Return a strategy for signed integer dtypes.
endianness and sizes are treated as for
Expand All @@ -624,7 +653,7 @@ def floating_dtypes(
*,
endianness: str = "?",
sizes: Sequence[Literal[16, 32, 64, 96, 128]] = (16, 32, 64),
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.floating[Any]]"]:
"""Return a strategy for floating-point dtypes.
sizes is the size in bits of floating-point number. Some machines support
Expand All @@ -642,7 +671,7 @@ def complex_number_dtypes(
*,
endianness: str = "?",
sizes: Sequence[Literal[64, 128, 192, 256]] = (64, 128),
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.complexfloating[Any, Any]]"]:
"""Return a strategy for complex-number dtypes.
sizes is the total size in bits of a complex number, which consists
Expand Down Expand Up @@ -681,7 +710,7 @@ def validate_time_slice(max_period, min_period):
@defines_dtype_strategy
def datetime64_dtypes(
*, max_period: str = "Y", min_period: str = "ns", endianness: str = "?"
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.datetime64]"]:
"""Return a strategy for datetime64 dtypes, with various precisions from
year to attosecond."""
return dtype_factory(
Expand All @@ -695,7 +724,7 @@ def datetime64_dtypes(
@defines_dtype_strategy
def timedelta64_dtypes(
*, max_period: str = "Y", min_period: str = "ns", endianness: str = "?"
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.timedelta64]"]:
"""Return a strategy for timedelta64 dtypes, with various precisions from
year to attosecond."""
return dtype_factory(
Expand All @@ -709,7 +738,7 @@ def timedelta64_dtypes(
@defines_dtype_strategy
def byte_string_dtypes(
*, endianness: str = "?", min_len: int = 1, max_len: int = 16
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.bytes_]"]:
"""Return a strategy for generating bytestring dtypes, of various lengths
and byteorder.
Expand All @@ -724,7 +753,7 @@ def byte_string_dtypes(
@defines_dtype_strategy
def unicode_string_dtypes(
*, endianness: str = "?", min_len: int = 1, max_len: int = 16
) -> st.SearchStrategy[np.dtype]:
) -> st.SearchStrategy["np.dtype[np.str_]"]:
"""Return a strategy for generating unicode string dtypes, of various
lengths and byteorder.
Expand Down Expand Up @@ -771,7 +800,7 @@ def array_dtypes(
elements |= st.tuples(
name_titles, subtype_strategy, array_shapes(max_dims=2, max_side=2)
)
return st.lists(
return st.lists( # type: ignore[return-value]
elements=elements,
min_size=min_size,
max_size=max_size,
Expand Down Expand Up @@ -948,13 +977,35 @@ def basic_indices(
)


I = TypeVar("I", bound=np.integer)


@overload
@defines_strategy()
def integer_array_indices(
shape: Shape,
*,
result_shape: st.SearchStrategy[Shape] = array_shapes(),
) -> "st.SearchStrategy[Tuple[NDArray[np.signedinteger[Any]], ...]]": ...


@overload
@defines_strategy()
def integer_array_indices(
shape: Shape,
*,
result_shape: st.SearchStrategy[Shape] = array_shapes(),
dtype: "np.dtype[I]",
) -> "st.SearchStrategy[Tuple[NDArray[I], ...]]": ...


@defines_strategy()
def integer_array_indices(
shape: Shape,
*,
result_shape: st.SearchStrategy[Shape] = array_shapes(),
dtype: D = np.dtype(int),
) -> "st.SearchStrategy[Tuple[NDArray[D], ...]]":
dtype: "np.dtype[I] | np.dtype[np.signedinteger[Any]]" = np.dtype(int),
) -> "st.SearchStrategy[Tuple[NDArray[I], ...]]":
"""Return a search strategy for tuples of integer-arrays that, when used
to index into an array of shape ``shape``, given an array whose shape
was drawn from ``result_shape``.
Expand Down Expand Up @@ -1146,7 +1197,7 @@ def _from_type(thing: Type[Ex]) -> Optional[st.SearchStrategy[Ex]]:

if real_thing in [np.ndarray, _SupportsArray]:
dtype, shape = _dtype_and_shape_from_args(args)
return arrays(dtype, shape)
return arrays(dtype, shape) # type: ignore[return-value]

# We didn't find a type to resolve, continue
return None
5 changes: 4 additions & 1 deletion hypothesis-python/src/hypothesis/internal/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def add_note(exc, note):
exc.add_note(note)
except AttributeError:
if not hasattr(exc, "__notes__"):
exc.__notes__ = []
try:
exc.__notes__ = []
except AttributeError:
return # give up, might be e.g. a frozen dataclass
exc.__notes__.append(note)


Expand Down
10 changes: 10 additions & 0 deletions hypothesis-python/tests/cover/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest

from hypothesis.internal.compat import (
add_note,
ceil,
dataclass_asdict,
extract_bits,
Expand Down Expand Up @@ -143,3 +144,12 @@ def test_extract_bits_roundtrip(width, x):
if width is not None:
assert len(bits) == width
assert x == sum(v << p for p, v in enumerate(reversed(bits)))


@dataclass(frozen=True)
class ImmutableError:
msg: str


def test_add_note_fails_gracefully_on_frozen_instance():
add_note(ImmutableError("msg"), "some note")
2 changes: 1 addition & 1 deletion requirements/coverage.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ ptyprocess==0.7.0
# via pexpect
pyarrow==15.0.0
# via -r requirements/coverage.in
pytest==8.0.2
pytest==8.1.0
# via
# -r requirements/test.in
# pytest-xdist
Expand Down
2 changes: 1 addition & 1 deletion requirements/fuzzing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ pyarrow==15.0.0
# via -r requirements/coverage.in
pygments==2.17.2
# via rich
pytest==8.0.2
pytest==8.1.0
# via
# -r requirements/test.in
# hypofuzz
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pluggy==1.4.0
# via pytest
ptyprocess==0.7.0
# via pexpect
pytest==8.0.2
pytest==8.1.0
# via
# -r requirements/test.in
# pytest-xdist
Expand Down
1 change: 1 addition & 0 deletions requirements/tools.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ipython
lark
libcst
mypy
numpy
pelican[markdown]
pip-tools
pyright
Expand Down
6 changes: 4 additions & 2 deletions requirements/tools.txt
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ nh3==0.2.15
# via readme-renderer
nodeenv==1.8.0
# via pyright
numpy==1.26.4
# via -r requirements/tools.in
ordered-set==4.1.0
# via pelican
packaging==23.2
Expand All @@ -164,7 +166,7 @@ pexpect==4.9.0
# via ipython
pip-tools==7.4.0
# via -r requirements/tools.in
pkginfo==1.9.6
pkginfo==1.10.0
# via twine
platformdirs==4.2.0
# via
Expand Down Expand Up @@ -200,7 +202,7 @@ pyproject-hooks==1.0.0
# pip-tools
pyright==1.1.352
# via -r requirements/tools.in
pytest==8.0.2
pytest==8.1.0
# via -r requirements/tools.in
python-dateutil==2.9.0.post0
# via
Expand Down

0 comments on commit b9f3d75

Please sign in to comment.