Skip to content

Commit

Permalink
Add particlewise keyword to particle_collections.is_category (#2648)
Browse files Browse the repository at this point in the history
* add particlewise keyword to particle_collections.is_category

* add changelog

* fix precommit

* remove brackets in test input

* add noqa

* attempt to fix decorators with particlewise=True

* attempt to fix decorators with particlewise=True

* fix iterable bool issue

* fix is_category error in formulary/collisions/helio/collsional_analysis

* np.allclose -> all(), and remove # noqa

* overload is_category to improve type hints -- does this fix mypy static check?

* refactor overload to use Literal[]

* revert the changes to particles.decorator.verify_allowed_types

* modify mypy checks

* modify docstring

* add info to docs about particlewise

* fix changelog + typo

* fix docs

* fix changelog link

* update changelog

---------

Co-authored-by: Jeffrey Reep <reep@kahiku.local>
Co-authored-by: Jeffrey Reep <reep@atrcw30.IfA.Hawaii.Edu>
  • Loading branch information
3 people committed May 4, 2024
1 parent 820c300 commit ca92246
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 24 deletions.
3 changes: 3 additions & 0 deletions changelog/2648.breaking.rst
@@ -0,0 +1,3 @@
Added a new keyword ``particlewise`` to the method `~plasmapy.particles.particle_collections.ParticleList.is_category` of |ParticleList|,
which now causes the function to return a `bool` for the whole list by default. The old functionality is still available
by setting ``particlewise`` to `True`.
12 changes: 12 additions & 0 deletions docs/particles/particle_class.rst
Expand Up @@ -259,6 +259,18 @@ and/or |ParticleList| objects together.
>>> helium_ions + cp + proton
ParticleList(['He-4 0+', 'He-4 1+', 'He-4 2+', 'Fe 9.5+', 'p+'])

As with an individual |Particle| and |CustomParticle|, we can check whether
all the particles in a list fall within a category using |is_category|.

>>> helium_ions.is_category("ion")
False

We may also check each particle in the list individually by setting
the keyword ``particlewise`` to `True`.

>>> helium_ions.is_category("ion", particlewise=True)
[False, True, True]

The machinery contained with |ParticleList| lets us calculate plasma
parameters from `plasmapy.formulary` for multiple particles at once.

Expand Down
Expand Up @@ -211,7 +211,7 @@ def temp_ratio( # noqa: C901
f"Instead received {len(ions)} input values."
)

if not all(ions.is_category("ion")):
if not ions.is_category("ion"):
raise ValueError(
f"Particle(s) in 'ions' must be ions, received {ions=} "
"instead. Please renter the 'ions' input parameter."
Expand Down
40 changes: 30 additions & 10 deletions src/plasmapy/particles/decorators.py
Expand Up @@ -351,12 +351,18 @@ def verify_charge_categorization(
must_be_charged = self.require is not None and "charged" in self.require
must_have_charge_info = self.any_of == {"charged", "uncharged"}

uncharged = particle.is_category("uncharged")
lacks_charge_info = particle.is_category(exclude={"charged", "uncharged"})
if isinstance(particle, ParticleList):
uncharged = particle.is_category("uncharged", particlewise=True)
lacks_charge_info = particle.is_category(
exclude={"charged", "uncharged"}, particlewise=True
)
else:
uncharged = particle.is_category("uncharged")
lacks_charge_info = particle.is_category(exclude={"charged", "uncharged"})

if isinstance(uncharged, Iterable):
uncharged = any(uncharged)
lacks_charge_info = any(lacks_charge_info) # type: ignore[arg-type]
lacks_charge_info = any(lacks_charge_info)

if must_be_charged and (uncharged or must_have_charge_info):
raise ChargeError(f"{self.callable_} can only accept charged particles.")
Expand Down Expand Up @@ -425,11 +431,20 @@ def verify_particle_categorization(
--------
~plasmapy.particles.particle_class.Particle.is_category
"""
if not particle.is_category(
require=self.require,
any_of=self.any_of,
exclude=self.exclude,
):
if isinstance(particle, ParticleList):
particle_in_category = particle.is_category(
require=self.require,
any_of=self.any_of,
exclude=self.exclude,
particlewise=True,
)
else:
particle_in_category = particle.is_category(
require=self.require,
any_of=self.any_of,
exclude=self.exclude,
)
if not particle_in_category:
errmsg = self.category_errmsg(
particle,
self.require,
Expand Down Expand Up @@ -471,7 +486,12 @@ def verify_particle_name_criteria(
if parameter != name or particle is None:
continue

meets_name_criteria = particle.is_category(**categorization)
if isinstance(particle, ParticleList):
meets_name_criteria = particle.is_category(
**categorization, particlewise=True
)
else:
meets_name_criteria = particle.is_category(**categorization)

if isinstance(particle, Iterable) and not isinstance(particle, str):
meets_name_criteria = all(meets_name_criteria) # type: ignore[arg-type]
Expand Down Expand Up @@ -505,7 +525,7 @@ def verify_allowed_types(
if (
not self.allow_custom_particles
and isinstance(particle, ParticleList)
and any(particle.is_category("custom"))
and any(particle.is_category("custom", particlewise=True)) # type: ignore[arg-type]
):
raise InvalidParticleError(
f"{self.callable_.__name__} does not accept CustomParticle "
Expand Down
57 changes: 48 additions & 9 deletions src/plasmapy/particles/particle_collections.py
Expand Up @@ -5,7 +5,7 @@
import collections
import contextlib
from collections.abc import Callable, Iterable, Sequence
from typing import TypeAlias, Union
from typing import Literal, TypeAlias, Union, overload

import astropy.units as u
import numpy as np
Expand Down Expand Up @@ -314,41 +314,79 @@ def insert(self, index, particle: ParticleLike) -> None:
particle = Particle(particle)
self.data.insert(index, particle)

@overload
def is_category(
self,
*category_tuple,
require: str | Iterable[str] | None = None,
any_of: str | Iterable[str] | None = None,
exclude: str | Iterable[str] | None = None,
) -> list[bool]:
particlewise: Literal[True] = ...,
) -> bool: ...

@overload
def is_category(
self,
*category_tuple,
require: str | Iterable[str] | None = None,
any_of: str | Iterable[str] | None = None,
exclude: str | Iterable[str] | None = None,
particlewise: Literal[False],
) -> list[bool]: ...

def is_category(
self,
*category_tuple,
require: str | Iterable[str] | None = None,
any_of: str | Iterable[str] | None = None,
exclude: str | Iterable[str] | None = None,
particlewise: bool = False,
) -> bool | list[bool]:
"""
Determine element-wise if the particles in the |ParticleList|
Determine if the particles in the |ParticleList|
meet categorization criteria.
Return a `list` in which each element will be `True` if the
corresponding particle is consistent with the categorization
criteria, and `False` otherwise.
Please refer to the documentation of
`~plasmapy.particles.particle_class.Particle.is_category`
for information on the parameters and categories, as well as
more extensive examples.
Parameters
----------
particlewise : `bool`, default: `False`
If `True`, return a `list` of `bool` in which an element will be `True`
if the corresponding particle is consistent with the categorization
criteria, and `False` otherwise. If `False`, return a `bool` which
will be `True` if all particles are consistent with the categorization
criteria and `False` otherwise.
Returns
-------
`list` of `bool`
`bool` or `list` of `bool`
See Also
--------
`~plasmapy.particles.particle_class.Particle.is_category`
Examples
--------
>>> particles = ParticleList(["proton", "electron", "tau neutrino"])
>>> particles.is_category("lepton")
False
>>> particles.is_category("lepton", particlewise=True)
[False, True, True]
>>> particles.is_category(require="lepton", exclude="neutrino")
False
>>> particles.is_category(
... require="lepton", exclude="neutrino", particlewise=True
... )
[False, True, False]
>>> particles.is_category(any_of=["lepton", "charged"])
True
>>> particles.is_category(any_of=["lepton", "charged"], particlewise=True)
[True, True, True]
"""
return [
category_list = [
particle.is_category(
*category_tuple,
require=require,
Expand All @@ -357,6 +395,7 @@ def is_category(
)
for particle in self
]
return category_list if particlewise else all(category_list)

@property
def charge_number(self) -> np.array:
Expand Down
38 changes: 34 additions & 4 deletions tests/particles/test_particle_collections.py
Expand Up @@ -338,27 +338,57 @@ def test_particle_multiplication(method, particle) -> None:
[
["electron", "proton", "neutron"],
["lepton"],
{},
{"particlewise": True},
[True, False, False],
],
[
["electron", "proton", "neutron"],
[],
{"require": "lepton"},
{"require": "lepton", "particlewise": True},
[True, False, False],
],
[
["electron", "proton", "neutron"],
[],
{"exclude": "lepton"},
{"exclude": "lepton", "particlewise": True},
[False, True, True],
],
[
["electron", "proton", "neutron"],
[],
{"any_of": {"lepton", "charged"}},
{"any_of": {"lepton", "charged"}, "particlewise": True},
[True, True, False],
],
[
["electron", "proton", "neutron"],
["lepton"],
{},
False,
],
[
["electron", "proton", "neutron"],
[],
{"require": "lepton"},
False,
],
[
["electron", "proton", "neutron"],
[],
{"exclude": "lepton"},
False,
],
[
["electron", "proton", "neutron"],
[],
{"any_of": {"lepton", "charged"}},
False,
],
[
["electron", "proton", "tau neutrino"],
[],
{"any_of": {"lepton", "charged"}},
True,
],
],
)
def test_particle_list_is_category(particles, args, kwargs, expected) -> None:
Expand Down

0 comments on commit ca92246

Please sign in to comment.