Skip to content

Commit

Permalink
Improve type hint annotations for @particle_input (#2443)
Browse files Browse the repository at this point in the history
* Improve type hints from decorators.py

* Add changelog entry

* Update mypy.ini

* Update type hint annotations in decorators.py

* Update mypy.ini per-file ignores

* Minor edits

* Fix location of type: ignore comment

* Add type hints to particles/test_decorators.py

* Update 2443.trivial.rst

* Update comments

* Update return annotations to be more general

* typing.Callable → collections.abc.Callable
  • Loading branch information
namurphy committed Jan 10, 2024
1 parent 30e5c91 commit 6918db6
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 121 deletions.
2 changes: 2 additions & 0 deletions changelog/2443.trivial.rst
@@ -0,0 +1,2 @@
Improved type hint annotations for `plasmapy.particles.decorators`,
which includes |particle_input|, and the corresponding tests.
6 changes: 0 additions & 6 deletions mypy.ini
Expand Up @@ -326,9 +326,6 @@ disable_error_code = type-arg,var-annotated
[mypy-plasmapy.particles.atomic]
disable_error_code = arg-type,assignment,misc,no-any-return,no-untyped-call,no-untyped-def,union-attr

[mypy-plasmapy.particles.decorators]
disable_error_code = arg-type,assignment,call-arg,index,misc,no-any-return,no-untyped-call,no-untyped-def,operator,return-value,type-arg,union-attr,var-annotated

[mypy-plasmapy.particles.ionization_state]
disable_error_code = arg-type,assignment,call-overload,misc,no-any-return,no-untyped-call,no-untyped-def,operator,return-value,type-arg

Expand Down Expand Up @@ -356,9 +353,6 @@ disable_error_code = attr-defined,no-untyped-def
[mypy-plasmapy.particles.tests.test_atomic]
disable_error_code = arg-type,attr-defined,no-untyped-def

[mypy-plasmapy.particles.tests.test_decorators]
disable_error_code = attr-defined,misc,no-untyped-call,no-untyped-def,return-value

[mypy-plasmapy.particles.tests.test_exceptions]
disable_error_code = attr-defined,no-untyped-def,operator,var-annotated

Expand Down
143 changes: 96 additions & 47 deletions plasmapy/particles/decorators.py
Expand Up @@ -2,16 +2,17 @@

__all__ = ["particle_input"]


import functools
import inspect
import numpy as np
import warnings
import wrapt

from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, MutableMapping
from inspect import BoundArguments
from numbers import Integral, Real
from typing import Any, Optional, Union
from typing import Any, Optional, TypedDict, Union

from plasmapy.particles._factory import _physical_particle_factory
from plasmapy.particles.exceptions import (
Expand All @@ -26,6 +27,19 @@
from plasmapy.particles.particle_collections import ParticleList, ParticleListLike
from plasmapy.utils.exceptions import PlasmaPyDeprecationWarning


class _CallableDataDict(TypedDict, total=False):
allow_custom_particles: bool
allow_particle_lists: bool
annotations: dict[str, Any]
any_of: Optional[Union[str, Iterable[str]]]
callable_: Callable[..., Any]
exclude: Optional[Union[str, Iterable[str]]]
parameters_to_process: list[str]
require: Optional[Union[str, Iterable[str]]]
signature: inspect.Signature


_basic_particle_input_annotations = (
Particle, # deprecated
ParticleLike,
Expand All @@ -43,7 +57,7 @@
)


def _get_annotations(callable_: Callable):
def _get_annotations(callable_: Callable[..., Any]) -> dict[str, Any]:
"""
Access the annotations of a callable.
Expand All @@ -53,10 +67,10 @@ def _get_annotations(callable_: Callable):
`inspect.get_annotations`.
"""
# Python 3.10: Replace this with inspect.get_annotations
return getattr(callable_, "__annotations__", None)
return getattr(callable_, "__annotations__", {})


def _make_into_set_or_none(obj) -> Optional[set]:
def _make_into_set_or_none(obj: Any) -> Optional[Iterable[str]]:
"""
Return `None` if ``obj`` is `None`, and otherwise convert ``obj``
into a `set`.
Expand All @@ -71,10 +85,10 @@ def _make_into_set_or_none(obj) -> Optional[set]:

def _bind_arguments(
wrapped_signature: inspect.Signature,
callable_: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict[str, Any]] = None,
instance=None,
callable_: Callable[..., Any],
args: Iterable[Any],
kwargs: MutableMapping[str, Any],
instance: Any = None,
) -> inspect.BoundArguments:
"""
Bind the arguments provided by ``args`` and ``kwargs`` to
Expand Down Expand Up @@ -177,36 +191,36 @@ class _ParticleInput:

def __init__(
self,
callable_: Callable,
callable_: Callable[..., Any],
*,
require: Optional[Union[str, set, list, tuple]] = None,
any_of: Optional[Union[str, set, list, tuple]] = None,
exclude: Optional[Union[str, set, list, tuple]] = None,
require: Optional[Union[str, Iterable[str]]] = None,
any_of: Optional[Union[str, Iterable[str]]] = None,
exclude: Optional[Union[str, Iterable[str]]] = None,
allow_custom_particles: bool = True,
allow_particle_lists: bool = True,
) -> None:
self._data = {}
self.callable_ = callable_
self._data: _CallableDataDict = {}
self.callable_: Callable[..., Any] = callable_
self.require = require
self.any_of = any_of
self.exclude = exclude
self.allow_custom_particles = allow_custom_particles
self.allow_particle_lists = allow_particle_lists

@property
def callable_(self) -> Callable:
def callable_(self) -> Callable[..., Any]:
"""
The callable that is being decorated.
Returns
-------
callable
"""
return self._data["callable"]
return self._data["callable_"]

@callable_.setter
def callable_(self, callable_: Callable) -> None:
self._data["callable"] = callable_
def callable_(self, callable_: Callable[..., Any]) -> None:
self._data["callable_"] = callable_
self._data["annotations"] = _get_annotations(callable_)
self._data["parameters_to_process"] = self.find_parameters_to_process()
self._data["signature"] = inspect.signature(callable_)
Expand Down Expand Up @@ -240,10 +254,10 @@ def annotations(self) -> dict[str, Any]:
-------
`dict` of `str` to `object`
"""
return self._data.get("annotations")
return self._data.get("annotations") # type: ignore[return-value]

@property
def require(self) -> Optional[set]:
def require(self) -> Optional[Iterable[str]]:
"""
Categories that the particle must belong to.
Expand All @@ -254,11 +268,11 @@ def require(self) -> Optional[set]:
return self._data["require"]

@require.setter
def require(self, require_: Optional[Union[str, set, list, tuple]]) -> None:
def require(self, require_: Optional[Union[str, Iterable[str]]]) -> None:
self._data["require"] = _make_into_set_or_none(require_)

@property
def any_of(self) -> Optional[set]:
def any_of(self) -> Optional[Iterable[str]]:
"""
Categories of which the particle must belong to at least one.
Expand All @@ -269,11 +283,11 @@ def any_of(self) -> Optional[set]:
return self._data["any_of"]

@any_of.setter
def any_of(self, any_of_: Optional[Union[str, set, list, tuple]]) -> None:
def any_of(self, any_of_: Optional[Union[str, Iterable[str]]]) -> None:
self._data["any_of"] = _make_into_set_or_none(any_of_)

@property
def exclude(self) -> Optional[set]:
def exclude(self) -> Optional[Iterable[str]]:
"""
Categories that the particle cannot belong to.
Expand All @@ -284,7 +298,7 @@ def exclude(self) -> Optional[set]:
return self._data["exclude"]

@exclude.setter
def exclude(self, exclude_) -> None:
def exclude(self, exclude_: Optional[Union[str, Iterable[str]]]) -> None:
self._data["exclude"] = _make_into_set_or_none(exclude_)

@property
Expand Down Expand Up @@ -331,7 +345,9 @@ def parameters_to_process(self) -> list[str]:
"""
return self._data["parameters_to_process"]

def verify_charge_categorization(self, particle) -> None:
def verify_charge_categorization(
self, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Raise an exception if the particle does not meet charge
categorization criteria.
Expand All @@ -351,7 +367,7 @@ def verify_charge_categorization(self, particle) -> None:

if isinstance(uncharged, Iterable):
uncharged = any(uncharged)
lacks_charge_info = any(lacks_charge_info)
lacks_charge_info = any(lacks_charge_info) # type: ignore[arg-type]

if must_be_charged and (uncharged or must_have_charge_info):
raise ChargeError(f"{self.callable_} can only accept charged particles.")
Expand All @@ -363,7 +379,13 @@ def verify_charge_categorization(self, particle) -> None:
)

@staticmethod
def category_errmsg(particle, require, exclude, any_of, callable_name) -> str:
def category_errmsg(
particle: Union[Particle, CustomParticle, ParticleList],
require: Optional[Union[str, Iterable[str]]],
exclude: Optional[Union[str, Iterable[str]]],
any_of: Optional[Union[str, Iterable[str]]],
callable_name: str,
) -> str:
"""
Return an error message for when a particle does not meet
categorization criteria.
Expand Down Expand Up @@ -391,15 +413,25 @@ def category_errmsg(particle, require, exclude, any_of, callable_name) -> str:

return category_errmsg

def verify_particle_categorization(self, particle) -> None:
def verify_particle_categorization(
self, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Verify that the particle meets the categorization criteria.
Parameters
----------
particle : Particle | CustomParticle
Raises
------
|ParticleError|
If the particle does not meet the categorization criteria.
Notes
-----
This method does not yet work with |ParticleList| objects.
See Also
--------
~plasmapy.particles.particle_class.Particle.is_category
Expand All @@ -418,7 +450,9 @@ def verify_particle_categorization(self, particle) -> None:
)
raise ParticleError(errmsg)

def verify_particle_name_criteria(self, parameter, particle):
def verify_particle_name_criteria(
self, parameter: str, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Check that parameters with special names meet the expected
categorization criteria.
Expand All @@ -432,7 +466,9 @@ def verify_particle_name_criteria(self, parameter, particle):
):
return

name_categorization_exception = [
name_categorization_exception: list[
tuple[str, dict[str, Optional[Union[str, Iterable[str]]]], type]
] = [
("element", {"require": "element"}, InvalidElementError),
("isotope", {"require": "isotope"}, InvalidIsotopeError),
(
Expand All @@ -449,7 +485,7 @@ def verify_particle_name_criteria(self, parameter, particle):
meets_name_criteria = particle.is_category(**categorization)

if isinstance(particle, Iterable) and not isinstance(particle, str):
meets_name_criteria = all(meets_name_criteria)
meets_name_criteria = all(meets_name_criteria) # type: ignore[arg-type]

if not meets_name_criteria:
raise exception(
Expand All @@ -458,7 +494,9 @@ def verify_particle_name_criteria(self, parameter, particle):
f"valid {parameter}."
)

def verify_allowed_types(self, particle):
def verify_allowed_types(
self, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Verify that the particle object contains only the allowed types
of particles.
Expand Down Expand Up @@ -489,8 +527,8 @@ def process_argument(
self,
parameter: str,
argument: Any,
Z: Optional[Integral],
mass_numb: Optional[Integral],
Z: Optional[float],
mass_numb: Optional[int],
) -> Any:
"""
Process an argument that has an appropriate annotation.
Expand Down Expand Up @@ -562,7 +600,9 @@ def process_argument(

parameters_to_skip = ("Z", "mass_numb")

def perform_pre_validations(self, Z, mass_numb):
def perform_pre_validations(
self, Z: Optional[float], mass_numb: Optional[int]
) -> None:
"""
Perform a variety of pre-checks on the arguments.
Expand Down Expand Up @@ -594,7 +634,10 @@ def perform_pre_validations(self, Z, mass_numb):
)

def process_arguments(
self, args: tuple, kwargs: dict[str, Any], instance=None
self,
args: Iterable[Any],
kwargs: MutableMapping[str, Any],
instance: Any = None,
) -> BoundArguments:
"""
Process the arguments passed to the callable_ callable.
Expand Down Expand Up @@ -639,14 +682,14 @@ def process_arguments(


def particle_input(
callable_: Optional[Callable] = None,
callable_: Optional[Callable[..., Any]] = None,
*,
require: Optional[Union[str, set, list, tuple]] = None,
any_of: Optional[Union[str, set, list, tuple]] = None,
exclude: Optional[Union[str, set, list, tuple]] = None,
require: Optional[Union[str, Iterable[str]]] = None,
any_of: Optional[Union[str, Iterable[str]]] = None,
exclude: Optional[Union[str, Iterable[str]]] = None,
allow_custom_particles: bool = True,
allow_particle_lists: bool = True,
) -> Callable:
) -> Callable[..., Any]:
r"""
Convert |particle-like| |arguments| into particle objects.
Expand Down Expand Up @@ -918,9 +961,15 @@ def instance_method(self, particle: ParticleLike, B: u.Quantity[u.T]):

@wrapt.decorator
def wrapper(
callable__: Callable, instance: Any, args: tuple, kwargs: dict[str, Any]
):
callable__: Callable[..., Any],
instance: Any,
args: Iterable[Any],
kwargs: MutableMapping[str, Any],
) -> Callable[..., Any]:
bound_arguments = particle_validator.process_arguments(args, kwargs, instance)
return callable__(*bound_arguments.args, **bound_arguments.kwargs)
return callable__( # type: ignore[no-any-return]
*bound_arguments.args,
**bound_arguments.kwargs,
)

return wrapper(callable_)
return wrapper(callable_, instance=None, args=(), kwargs={})

0 comments on commit 6918db6

Please sign in to comment.