Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type hint annotations for @particle_input #2443

Merged
merged 15 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/2443.trivial.rst
@@ -0,0 +1,2 @@
Improved type hint annotations for `plasmapy.particles.decorators`,
which includes |particle_input|, and the corresponding tests.
6 changes: 0 additions & 6 deletions mypy.ini
Expand Up @@ -326,9 +326,6 @@ disable_error_code = type-arg,var-annotated
[mypy-plasmapy.particles.atomic]
disable_error_code = arg-type,assignment,misc,no-any-return,no-untyped-call,no-untyped-def,union-attr

[mypy-plasmapy.particles.decorators]
disable_error_code = arg-type,assignment,call-arg,index,misc,no-any-return,no-untyped-call,no-untyped-def,operator,return-value,type-arg,union-attr,var-annotated
Comment on lines -329 to -330
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the mypy errors that existed in plasmapy.particles.decorators before this pull request. My process was essentially to:

  1. Remove one error code and find the errors that mypy was reporting
  2. Either fix those errors, or add a #type: ignore comment if the reported error wasn't actually a problem with the code
  3. GOTO 1


[mypy-plasmapy.particles.ionization_state]
disable_error_code = arg-type,assignment,call-overload,misc,no-any-return,no-untyped-call,no-untyped-def,operator,return-value,type-arg

Expand Down Expand Up @@ -356,9 +353,6 @@ disable_error_code = attr-defined,no-untyped-def
[mypy-plasmapy.particles.tests.test_atomic]
disable_error_code = arg-type,attr-defined,no-untyped-def

[mypy-plasmapy.particles.tests.test_decorators]
disable_error_code = attr-defined,misc,no-untyped-call,no-untyped-def,return-value

[mypy-plasmapy.particles.tests.test_exceptions]
disable_error_code = attr-defined,no-untyped-def,operator,var-annotated

Expand Down
143 changes: 96 additions & 47 deletions plasmapy/particles/decorators.py
Expand Up @@ -2,16 +2,17 @@

__all__ = ["particle_input"]


import functools
import inspect
import numpy as np
import warnings
import wrapt

from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, MutableMapping
ejohnson-96 marked this conversation as resolved.
Show resolved Hide resolved
from inspect import BoundArguments
from numbers import Integral, Real
from typing import Any, Optional, Union
from typing import Any, Optional, TypedDict, Union

from plasmapy.particles._factory import _physical_particle_factory
from plasmapy.particles.exceptions import (
Expand All @@ -26,6 +27,19 @@
from plasmapy.particles.particle_collections import ParticleList, ParticleListLike
from plasmapy.utils.exceptions import PlasmaPyDeprecationWarning


class _CallableDataDict(TypedDict, total=False):
allow_custom_particles: bool
allow_particle_lists: bool
annotations: dict[str, Any]
any_of: Optional[Union[str, Iterable[str]]]
callable_: Callable[..., Any]
exclude: Optional[Union[str, Iterable[str]]]
parameters_to_process: list[str]
require: Optional[Union[str, Iterable[str]]]
signature: inspect.Signature
ejohnson-96 marked this conversation as resolved.
Show resolved Hide resolved


_basic_particle_input_annotations = (
Particle, # deprecated
ParticleLike,
Expand All @@ -43,7 +57,7 @@
)


def _get_annotations(callable_: Callable):
def _get_annotations(callable_: Callable[..., Any]) -> dict[str, Any]:
ejohnson-96 marked this conversation as resolved.
Show resolved Hide resolved
"""
Access the annotations of a callable.

Expand All @@ -53,10 +67,10 @@ def _get_annotations(callable_: Callable):
`inspect.get_annotations`.
"""
# Python 3.10: Replace this with inspect.get_annotations
return getattr(callable_, "__annotations__", None)
return getattr(callable_, "__annotations__", {})


def _make_into_set_or_none(obj) -> Optional[set]:
def _make_into_set_or_none(obj: Any) -> Optional[Iterable[str]]:
"""
Return `None` if ``obj`` is `None`, and otherwise convert ``obj``
into a `set`.
Expand All @@ -71,10 +85,10 @@ def _make_into_set_or_none(obj) -> Optional[set]:

def _bind_arguments(
wrapped_signature: inspect.Signature,
callable_: Callable,
args: Optional[tuple] = None,
kwargs: Optional[dict[str, Any]] = None,
instance=None,
callable_: Callable[..., Any],
args: Iterable[Any],
kwargs: MutableMapping[str, Any],
instance: Any = None,
) -> inspect.BoundArguments:
"""
Bind the arguments provided by ``args`` and ``kwargs`` to
Expand Down Expand Up @@ -177,36 +191,36 @@ class _ParticleInput:

def __init__(
self,
callable_: Callable,
callable_: Callable[..., Any],
*,
require: Optional[Union[str, set, list, tuple]] = None,
any_of: Optional[Union[str, set, list, tuple]] = None,
exclude: Optional[Union[str, set, list, tuple]] = None,
require: Optional[Union[str, Iterable[str]]] = None,
any_of: Optional[Union[str, Iterable[str]]] = None,
exclude: Optional[Union[str, Iterable[str]]] = None,
allow_custom_particles: bool = True,
allow_particle_lists: bool = True,
) -> None:
self._data = {}
self.callable_ = callable_
self._data: _CallableDataDict = {}
self.callable_: Callable[..., Any] = callable_
self.require = require
self.any_of = any_of
self.exclude = exclude
self.allow_custom_particles = allow_custom_particles
self.allow_particle_lists = allow_particle_lists

@property
def callable_(self) -> Callable:
def callable_(self) -> Callable[..., Any]:
"""
The callable that is being decorated.

Returns
-------
callable
"""
return self._data["callable"]
return self._data["callable_"]

@callable_.setter
def callable_(self, callable_: Callable) -> None:
self._data["callable"] = callable_
def callable_(self, callable_: Callable[..., Any]) -> None:
self._data["callable_"] = callable_
self._data["annotations"] = _get_annotations(callable_)
self._data["parameters_to_process"] = self.find_parameters_to_process()
self._data["signature"] = inspect.signature(callable_)
Expand Down Expand Up @@ -240,10 +254,10 @@ def annotations(self) -> dict[str, Any]:
-------
`dict` of `str` to `object`
"""
return self._data.get("annotations")
return self._data.get("annotations") # type: ignore[return-value]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to bother with fixing this since there's a new way to get annotations in Python 3.10, and we're dropping Python 3.9 support soon anyway.


@property
def require(self) -> Optional[set]:
def require(self) -> Optional[Iterable[str]]:
"""
Categories that the particle must belong to.

Expand All @@ -254,11 +268,11 @@ def require(self) -> Optional[set]:
return self._data["require"]

@require.setter
def require(self, require_: Optional[Union[str, set, list, tuple]]) -> None:
def require(self, require_: Optional[Union[str, Iterable[str]]]) -> None:
self._data["require"] = _make_into_set_or_none(require_)

@property
def any_of(self) -> Optional[set]:
def any_of(self) -> Optional[Iterable[str]]:
"""
Categories of which the particle must belong to at least one.

Expand All @@ -269,11 +283,11 @@ def any_of(self) -> Optional[set]:
return self._data["any_of"]

@any_of.setter
def any_of(self, any_of_: Optional[Union[str, set, list, tuple]]) -> None:
def any_of(self, any_of_: Optional[Union[str, Iterable[str]]]) -> None:
self._data["any_of"] = _make_into_set_or_none(any_of_)

@property
def exclude(self) -> Optional[set]:
def exclude(self) -> Optional[Iterable[str]]:
"""
Categories that the particle cannot belong to.

Expand All @@ -284,7 +298,7 @@ def exclude(self) -> Optional[set]:
return self._data["exclude"]

@exclude.setter
def exclude(self, exclude_) -> None:
def exclude(self, exclude_: Optional[Union[str, Iterable[str]]]) -> None:
self._data["exclude"] = _make_into_set_or_none(exclude_)

@property
Expand Down Expand Up @@ -331,7 +345,9 @@ def parameters_to_process(self) -> list[str]:
"""
return self._data["parameters_to_process"]

def verify_charge_categorization(self, particle) -> None:
def verify_charge_categorization(
self, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Raise an exception if the particle does not meet charge
categorization criteria.
Expand All @@ -351,7 +367,7 @@ def verify_charge_categorization(self, particle) -> None:

if isinstance(uncharged, Iterable):
uncharged = any(uncharged)
lacks_charge_info = any(lacks_charge_info)
lacks_charge_info = any(lacks_charge_info) # type: ignore[arg-type]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason, the built-in any and all annotations aren't matching the types that we're using here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔


if must_be_charged and (uncharged or must_have_charge_info):
raise ChargeError(f"{self.callable_} can only accept charged particles.")
Expand All @@ -363,7 +379,13 @@ def verify_charge_categorization(self, particle) -> None:
)

@staticmethod
def category_errmsg(particle, require, exclude, any_of, callable_name) -> str:
def category_errmsg(
particle: Union[Particle, CustomParticle, ParticleList],
require: Optional[Union[str, Iterable[str]]],
exclude: Optional[Union[str, Iterable[str]]],
any_of: Optional[Union[str, Iterable[str]]],
callable_name: str,
) -> str:
"""
Return an error message for when a particle does not meet
categorization criteria.
Expand Down Expand Up @@ -391,15 +413,25 @@ def category_errmsg(particle, require, exclude, any_of, callable_name) -> str:

return category_errmsg

def verify_particle_categorization(self, particle) -> None:
def verify_particle_categorization(
self, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Verify that the particle meets the categorization criteria.

Parameters
----------
particle : Particle | CustomParticle

Raises
------
|ParticleError|
If the particle does not meet the categorization criteria.

Notes
-----
This method does not yet work with |ParticleList| objects.

See Also
--------
~plasmapy.particles.particle_class.Particle.is_category
Expand All @@ -418,7 +450,9 @@ def verify_particle_categorization(self, particle) -> None:
)
raise ParticleError(errmsg)

def verify_particle_name_criteria(self, parameter, particle):
def verify_particle_name_criteria(
self, parameter: str, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Check that parameters with special names meet the expected
categorization criteria.
Expand All @@ -432,7 +466,9 @@ def verify_particle_name_criteria(self, parameter, particle):
):
return

name_categorization_exception = [
name_categorization_exception: list[
tuple[str, dict[str, Optional[Union[str, Iterable[str]]]], type]
] = [
("element", {"require": "element"}, InvalidElementError),
("isotope", {"require": "isotope"}, InvalidIsotopeError),
(
Expand All @@ -449,7 +485,7 @@ def verify_particle_name_criteria(self, parameter, particle):
meets_name_criteria = particle.is_category(**categorization)

if isinstance(particle, Iterable) and not isinstance(particle, str):
meets_name_criteria = all(meets_name_criteria)
meets_name_criteria = all(meets_name_criteria) # type: ignore[arg-type]

if not meets_name_criteria:
raise exception(
Expand All @@ -458,7 +494,9 @@ def verify_particle_name_criteria(self, parameter, particle):
f"valid {parameter}."
)

def verify_allowed_types(self, particle):
def verify_allowed_types(
self, particle: Union[Particle, CustomParticle, ParticleList]
) -> None:
"""
Verify that the particle object contains only the allowed types
of particles.
Expand Down Expand Up @@ -489,8 +527,8 @@ def process_argument(
self,
parameter: str,
argument: Any,
Z: Optional[Integral],
mass_numb: Optional[Integral],
Z: Optional[float],
mass_numb: Optional[int],
) -> Any:
"""
Process an argument that has an appropriate annotation.
Expand Down Expand Up @@ -562,7 +600,9 @@ def process_argument(

parameters_to_skip = ("Z", "mass_numb")

def perform_pre_validations(self, Z, mass_numb):
def perform_pre_validations(
self, Z: Optional[float], mass_numb: Optional[int]
) -> None:
"""
Perform a variety of pre-checks on the arguments.

Expand Down Expand Up @@ -594,7 +634,10 @@ def perform_pre_validations(self, Z, mass_numb):
)

def process_arguments(
self, args: tuple, kwargs: dict[str, Any], instance=None
self,
args: Iterable[Any],
kwargs: MutableMapping[str, Any],
instance: Any = None,
) -> BoundArguments:
"""
Process the arguments passed to the callable_ callable.
Expand Down Expand Up @@ -639,14 +682,14 @@ def process_arguments(


def particle_input(
callable_: Optional[Callable] = None,
callable_: Optional[Callable[..., Any]] = None,
*,
require: Optional[Union[str, set, list, tuple]] = None,
any_of: Optional[Union[str, set, list, tuple]] = None,
exclude: Optional[Union[str, set, list, tuple]] = None,
require: Optional[Union[str, Iterable[str]]] = None,
any_of: Optional[Union[str, Iterable[str]]] = None,
exclude: Optional[Union[str, Iterable[str]]] = None,
allow_custom_particles: bool = True,
allow_particle_lists: bool = True,
) -> Callable:
) -> Callable[..., Any]:
r"""
Convert |particle-like| |arguments| into particle objects.

Expand Down Expand Up @@ -918,9 +961,15 @@ def instance_method(self, particle: ParticleLike, B: u.Quantity[u.T]):

@wrapt.decorator
def wrapper(
callable__: Callable, instance: Any, args: tuple, kwargs: dict[str, Any]
):
callable__: Callable[..., Any],
instance: Any,
args: Iterable[Any],
kwargs: MutableMapping[str, Any],
) -> Callable[..., Any]:
bound_arguments = particle_validator.process_arguments(args, kwargs, instance)
return callable__(*bound_arguments.args, **bound_arguments.kwargs)
return callable__( # type: ignore[no-any-return]
*bound_arguments.args,
**bound_arguments.kwargs,
)

return wrapper(callable_)
return wrapper(callable_, instance=None, args=(), kwargs={})