diff --git a/tests/smoke/integrate/test_package.py b/tests/smoke/integrate/test_package.py index 4227a3d7..216e5323 100644 --- a/tests/smoke/integrate/test_package.py +++ b/tests/smoke/integrate/test_package.py @@ -4,6 +4,6 @@ from galax.integrate import _base, _builtin -def test_all(): +def test_all() -> None: """Test the API.""" assert set(integrate.__all__) == set(_base.__all__ + _builtin.__all__) diff --git a/tests/smoke/potential/test_package.py b/tests/smoke/potential/test_package.py index 5ffb9f5b..12a670a5 100644 --- a/tests/smoke/potential/test_package.py +++ b/tests/smoke/potential/test_package.py @@ -2,7 +2,7 @@ from galax.potential import _potential -def test_all(): +def test_all() -> None: """Test the `galax.potential` package contents.""" # Test correct dumping of contents assert gp.__all__ == _potential.__all__ diff --git a/tests/smoke/utils/test_package.py b/tests/smoke/utils/test_package.py index e2b2c110..3c58d9b7 100644 --- a/tests/smoke/utils/test_package.py +++ b/tests/smoke/utils/test_package.py @@ -1,20 +1,10 @@ -import pytest - from galax import utils -def test__all__(): +def test__all__() -> None: """Test that `galax.utils` has the expected `__all__`.""" assert utils.__all__ == [ "dataclasses", *utils._jax.__all__, *utils._collections.__all__, ] - - -@pytest.mark.skip(reason="TODO") -def test_public_modules(): - """Test which modules are publicly importable.""" - # IDK how to discover all submodules of a package, even if they aren't - # imported without relying on the filesystem. The filesystem is generally - # safe, but I'd rather solve this generically. Low priority. diff --git a/tests/unit/potential/builtin/test_bar.py b/tests/unit/potential/builtin/test_bar.py index 30b1c1e4..8e30acbd 100644 --- a/tests/unit/potential/builtin/test_bar.py +++ b/tests/unit/potential/builtin/test_bar.py @@ -1,10 +1,13 @@ from typing import Any +import astropy.units as u import jax.experimental.array_api as xp import jax.numpy as jnp import pytest -import galax.potential as gp +from galax.potential import BarPotential +from galax.typing import Vec3 +from galax.units import UnitSystem from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import ( @@ -24,8 +27,8 @@ class TestBarPotential( ShapeCParameterMixin, ): @pytest.fixture(scope="class") - def pot_cls(self) -> type[gp.BarPotential]: - return gp.BarPotential + def pot_cls(self) -> type[BarPotential]: + return BarPotential @pytest.fixture(scope="class") def field_Omega(self) -> dict[str, Any]: @@ -33,7 +36,13 @@ def field_Omega(self) -> dict[str, Any]: @pytest.fixture(scope="class") def fields_( - self, field_m, field_a, field_b, field_c, field_Omega, field_units + self, + field_m: u.Quantity, + field_a: u.Quantity, + field_b: u.Quantity, + field_c: u.Quantity, + field_Omega: u.Quantity, + field_units: UnitSystem, ) -> dict[str, Any]: return { "m": field_m, @@ -46,18 +55,18 @@ def fields_( # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: BarPotential, x: Vec3) -> None: assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.94601574)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: BarPotential, x: Vec3) -> None: assert jnp.allclose( pot.gradient(x, t=0), xp.asarray([0.04011905, 0.08383918, 0.16552719]) ) - def test_density(self, pot, x): + def test_density(self, pot: BarPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0), 1.94669274e08) - def test_hessian(self, pot, x): + def test_hessian(self, pot: BarPotential, x: Vec3) -> None: assert jnp.allclose( pot.hessian(x, t=0), xp.asarray( diff --git a/tests/unit/potential/builtin/test_hernquist.py b/tests/unit/potential/builtin/test_hernquist.py index e2730283..afec23ca 100644 --- a/tests/unit/potential/builtin/test_hernquist.py +++ b/tests/unit/potential/builtin/test_hernquist.py @@ -5,6 +5,7 @@ import pytest from galax.potential import HernquistPotential +from galax.typing import Vec3 from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import MassParameterMixin, ShapeCParameterMixin @@ -26,18 +27,18 @@ def fields_(self, field_m, field_c, field_units) -> dict[str, Any]: # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: HernquistPotential, x: Vec3) -> None: assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.94871936)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: HernquistPotential, x: Vec3) -> None: assert jnp.allclose( pot.gradient(x, t=0), xp.asarray([0.05347411, 0.10694822, 0.16042233]) ) - def test_density(self, pot, x): + def test_density(self, pot: HernquistPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0), 3.989933e08) - def test_hessian(self, pot, x): + def test_hessian(self, pot: HernquistPotential, x: Vec3) -> None: assert jnp.allclose( pot.hessian(x, t=0), xp.asarray( diff --git a/tests/unit/potential/builtin/test_isochrone.py b/tests/unit/potential/builtin/test_isochrone.py index 29f3b534..0dc6c712 100644 --- a/tests/unit/potential/builtin/test_isochrone.py +++ b/tests/unit/potential/builtin/test_isochrone.py @@ -5,6 +5,8 @@ import pytest import galax.potential as gp +from galax.potential import IsochronePotential +from galax.typing import Vec3 from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import MassParameterMixin, ShapeBParameterMixin @@ -26,18 +28,18 @@ def fields_(self, field_m, field_b, field_units) -> dict[str, Any]: # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: IsochronePotential, x: Vec3) -> None: assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.9231515)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: IsochronePotential, x: Vec3) -> None: assert jnp.allclose( pot.gradient(x, t=0), xp.asarray([0.04891392, 0.09782784, 0.14674175]) ) - def test_density(self, pot, x): + def test_density(self, pot: IsochronePotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0), 5.04511665e08) - def test_hessian(self, pot, x): + def test_hessian(self, pot: IsochronePotential, x: Vec3) -> None: assert jnp.allclose( pot.hessian(x, t=0), xp.asarray( diff --git a/tests/unit/potential/builtin/test_miyamotonagai.py b/tests/unit/potential/builtin/test_miyamotonagai.py index 8963b906..4f6cf2a5 100644 --- a/tests/unit/potential/builtin/test_miyamotonagai.py +++ b/tests/unit/potential/builtin/test_miyamotonagai.py @@ -1,10 +1,14 @@ from typing import Any +import astropy.units as u import jax.experimental.array_api as xp import jax.numpy as jnp import pytest import galax.potential as gp +from galax.potential import MiyamotoNagaiPotential +from galax.typing import Vec3 +from galax.units import UnitSystem from ..test_core import TestAbstractPotential as AbstractPotential_Test from .test_common import MassParameterMixin, ShapeAParameterMixin, ShapeBParameterMixin @@ -24,23 +28,29 @@ def pot_cls(self) -> type[gp.MiyamotoNagaiPotential]: return gp.MiyamotoNagaiPotential @pytest.fixture(scope="class") - def fields_(self, field_m, field_a, field_b, field_units) -> dict[str, Any]: + def fields_( + self, + field_m: u.Quantity, + field_a: u.Quantity, + field_b: u.Quantity, + field_units: UnitSystem, + ) -> dict[str, Any]: return {"m": field_m, "a": field_a, "b": field_b, "units": field_units} # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None: assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.95208676)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None: assert jnp.allclose( pot.gradient(x, t=0), xp.asarray([0.04264751, 0.08529503, 0.16840152]) ) - def test_density(self, pot, x): + def test_density(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0), 1.9949418e08) - def test_hessian(self, pot, x): + def test_hessian(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None: assert jnp.allclose( pot.hessian(x, t=0), xp.asarray( diff --git a/tests/unit/potential/builtin/test_mwpotential.py b/tests/unit/potential/builtin/test_mwpotential.py index 300bc889..d8e0abc1 100644 --- a/tests/unit/potential/builtin/test_mwpotential.py +++ b/tests/unit/potential/builtin/test_mwpotential.py @@ -4,6 +4,7 @@ import pytest from galax.potential import MilkyWayPotential +from galax.typing import Vec3 from galax.units import galactic from ..test_core import TestAbstractPotential @@ -34,18 +35,18 @@ def test_init_units_from_args(self, pot_cls, fields_unitless): # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: MilkyWayPotential, x: Vec3) -> None: assert xp.isclose(pot.potential_energy(x, t=0), xp.array(-0.19386052)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: MilkyWayPotential, x: Vec3) -> None: assert xp.allclose( pot.gradient(x, t=0), xp.array([0.00256403, 0.00512806, 0.01115272]) ) - def test_density(self, pot, x): + def test_density(self, pot: MilkyWayPotential, x: Vec3) -> None: assert xp.isclose(pot.density(x, t=0), 33_365_858.46361218) - def test_hessian(self, pot, x): + def test_hessian(self, pot: MilkyWayPotential, x: Vec3) -> None: assert xp.allclose( pot.hessian(x, t=0), xp.array( diff --git a/tests/unit/potential/builtin/test_nfw.py b/tests/unit/potential/builtin/test_nfw.py index a5d88d77..38b33ef0 100644 --- a/tests/unit/potential/builtin/test_nfw.py +++ b/tests/unit/potential/builtin/test_nfw.py @@ -8,9 +8,9 @@ from typing_extensions import override import galax.potential as gp -from galax.potential import ConstantParameter +from galax.potential import AbstractPotential, ConstantParameter, NFWPotential from galax.typing import Vec3 -from galax.units import galactic +from galax.units import UnitSystem, galactic from galax.utils._optional_deps import HAS_GALA from ..param.test_field import ParameterFieldMixin @@ -21,7 +21,7 @@ class ScaleRadiusParameterMixin(ParameterFieldMixin): """Test the mass parameter.""" - pot_cls: type[gp.AbstractPotential] + pot_cls: type[AbstractPotential] @pytest.fixture(scope="class") def field_r_s(self) -> float: @@ -29,7 +29,9 @@ def field_r_s(self) -> float: # ===================================================== - def test_r_s_units(self, pot_cls, fields): + def test_r_s_units( + self, pot_cls: type[AbstractPotential], fields: dict[str, Any] + ) -> None: """Test the mass parameter.""" fields["r_s"] = 1.0 * u.Unit(10 * u.kpc) fields["units"] = galactic @@ -37,14 +39,18 @@ def test_r_s_units(self, pot_cls, fields): assert isinstance(pot.r_s, ConstantParameter) assert jnp.isclose(pot.r_s.value, 10) - def test_r_s_constant(self, pot_cls, fields): + def test_r_s_constant( + self, pot_cls: type[AbstractPotential], fields: dict[str, Any] + ): """Test the mass parameter.""" fields["r_s"] = 1.0 pot = pot_cls(**fields) assert pot.r_s(t=0) == 1.0 @pytest.mark.xfail(reason="TODO: user function doesn't have units") - def test_r_s_userfunc(self, pot_cls, fields): + def test_r_s_userfunc( + self, pot_cls: type[AbstractPotential], fields: dict[str, Any] + ): """Test the mass parameter.""" fields["r_s"] = lambda t: t + 2 pot = pot_cls(**fields) @@ -62,8 +68,8 @@ class TestNFWPotential( ): @pytest.fixture(scope="class") @override - def pot_cls(self) -> type[gp.NFWPotential]: - return gp.NFWPotential + def pot_cls(self) -> type[NFWPotential]: + return NFWPotential @pytest.fixture(scope="class") def field_softening_length(self) -> float: @@ -72,7 +78,11 @@ def field_softening_length(self) -> float: @pytest.fixture(scope="class") @override def fields_( - self, field_m, field_r_s, field_softening_length, field_units + self, + field_m: u.Quantity, + field_r_s: u.Quantity, + field_softening_length: float, + field_units: UnitSystem, ) -> dict[str, Any]: return { "m": field_m, @@ -83,18 +93,18 @@ def fields_( # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: NFWPotential, x: Vec3) -> None: assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-1.87117234)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: NFWPotential, x: Vec3) -> None: assert jnp.allclose( pot.gradient(x, t=0), xp.asarray([0.0658867, 0.1317734, 0.19766011]) ) - def test_density(self, pot, x): + def test_density(self, pot: NFWPotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0), 9.46039849e08) - def test_hessian(self, pot, x): + def test_hessian(self, pot: NFWPotential, x: Vec3) -> None: assert jnp.allclose( pot.hessian(x, t=0), xp.asarray( @@ -110,9 +120,7 @@ def test_hessian(self, pot, x): # I/O @pytest.mark.skipif(not HAS_GALA, reason="requires gala") - def test_galax_to_gala_to_galax_roundtrip( - self, pot: gp.AbstractPotentialBase, x: Vec3 - ) -> None: + def test_galax_to_gala_to_galax_roundtrip(self, pot: NFWPotential, x: Vec3) -> None: """Test roundtripping ``gala_to_galax(galax_to_gala())``.""" from ..io.gala_helper import galax_to_gala diff --git a/tests/unit/potential/builtin/test_null.py b/tests/unit/potential/builtin/test_null.py index fc27ac4a..7061b562 100644 --- a/tests/unit/potential/builtin/test_null.py +++ b/tests/unit/potential/builtin/test_null.py @@ -4,32 +4,38 @@ import jax.numpy as jnp import pytest -import galax.potential as gp +from galax.potential import NullPotential +from galax.typing import Vec3 +from galax.units import UnitSystem from ..test_core import TestAbstractPotential as AbstractPotential_Test class TestNullPotential(AbstractPotential_Test): @pytest.fixture(scope="class") - def pot_cls(self) -> type[gp.NullPotential]: - return gp.NullPotential + def pot_cls(self) -> type[NullPotential]: + return NullPotential @pytest.fixture(scope="class") - def fields_(self, field_units) -> dict[str, Any]: + def fields_(self, field_units: UnitSystem) -> dict[str, Any]: return {"units": field_units} # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: NullPotential, x: Vec3) -> None: + """Test :meth:`NullPotential.potential_energy`.""" assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(0.0)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: NullPotential, x: Vec3) -> None: + """Test :meth:`NullPotential.gradient`.""" assert jnp.allclose(pot.gradient(x, t=0), xp.asarray([0.0, 0.0, 0.0])) - def test_density(self, pot, x): + def test_density(self, pot: NullPotential, x: Vec3) -> None: + """Test :meth:`NullPotential.density`.""" assert jnp.isclose(pot.density(x, t=0), 0.0) - def test_hessian(self, pot, x): + def test_hessian(self, pot: NullPotential, x: Vec3) -> None: + """Test :meth:`NullPotential.hessian`.""" assert jnp.allclose( pot.hessian(x, t=0), xp.asarray([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), diff --git a/tests/unit/potential/test_composite.py b/tests/unit/potential/test_composite.py index 0c579912..d1677bf5 100644 --- a/tests/unit/potential/test_composite.py +++ b/tests/unit/potential/test_composite.py @@ -7,7 +7,14 @@ import pytest from typing_extensions import override -import galax.potential as gp +from galax.potential import ( + AbstractPotentialBase, + CompositePotential, + KeplerPotential, + MiyamotoNagaiPotential, + NFWPotential, +) +from galax.typing import Vec3 from galax.units import UnitSystem, dimensionless, galactic, solarsystem from galax.utils._misc import first @@ -24,18 +31,18 @@ class TestCompositePotential(AbstractCompositePotential_Test): """Test the `galax.potential.CompositePotential` class.""" @pytest.fixture(scope="class") - def pot_cls(self) -> type[gp.CompositePotential]: + def pot_cls(self) -> type[CompositePotential]: """Composite potential class.""" - return gp.CompositePotential + return CompositePotential @pytest.fixture(scope="class") - def pot_map(self) -> Mapping[str, gp.AbstractPotentialBase]: + def pot_map(self) -> Mapping[str, AbstractPotentialBase]: """Composite potential.""" return { - "disk": gp.MiyamotoNagaiPotential( + "disk": MiyamotoNagaiPotential( m=1e10 * u.solMass, a=6.5 * u.kpc, b=4.5 * u.kpc, units=galactic ), - "halo": gp.NFWPotential( + "halo": NFWPotential( m=1e12 * u.solMass, r_s=5 * u.kpc, softening_length=0, units=galactic ), } @@ -43,18 +50,18 @@ def pot_map(self) -> Mapping[str, gp.AbstractPotentialBase]: @pytest.fixture(scope="class") def pot( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], - ) -> gp.CompositePotential: + pot_cls: type[CompositePotential], + pot_map: Mapping[str, AbstractPotentialBase], + ) -> CompositePotential: """Composite potential.""" return pot_cls(**pot_map) @pytest.fixture(scope="class") - def pot_map_unitless(self) -> Mapping[str, gp.AbstractPotentialBase]: + def pot_map_unitless(self) -> Mapping[str, AbstractPotentialBase]: """Composite potential.""" return { - "disk": gp.MiyamotoNagaiPotential(m=1e10, a=6.5, b=4.5, units=None), - "halo": gp.NFWPotential(m=1e12, r_s=5, softening_length=0, units=None), + "disk": MiyamotoNagaiPotential(m=1e10, a=6.5, b=4.5, units=None), + "halo": NFWPotential(m=1e12, r_s=5, softening_length=0, units=None), } # ========================================================================== @@ -64,8 +71,8 @@ def pot_map_unitless(self) -> Mapping[str, gp.AbstractPotentialBase]: @override def test_init_units_invalid( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[CompositePotential], + pot_map: Mapping[str, AbstractPotentialBase], ) -> None: """Test invalid unit system.""" # TODO: raise a specific error. The type depends on whether beartype is @@ -76,8 +83,8 @@ def test_init_units_invalid( @override def test_init_units_from_usys( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[CompositePotential], + pot_map: Mapping[str, AbstractPotentialBase], ) -> None: """Test unit system from UnitSystem.""" usys = UnitSystem(u.km, u.s, u.Msun, u.radian) @@ -87,8 +94,8 @@ def test_init_units_from_usys( @override def test_init_units_from_args( self, - pot_cls: type[gp.CompositePotential], - pot_map_unitless: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[CompositePotential], + pot_map_unitless: Mapping[str, AbstractPotentialBase], ) -> None: """Test unit system from None.""" pot = pot_cls(**pot_map_unitless, units=None) @@ -97,8 +104,8 @@ def test_init_units_from_args( @override def test_init_units_from_tuple( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[CompositePotential], + pot_map: Mapping[str, AbstractPotentialBase], ) -> None: """Test unit system from tuple.""" units = (u.km, u.s, u.Msun, u.radian) @@ -108,9 +115,9 @@ def test_init_units_from_tuple( @override def test_init_units_from_name( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], - pot_map_unitless: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[CompositePotential], + pot_map: Mapping[str, AbstractPotentialBase], + pot_map_unitless: Mapping[str, AbstractPotentialBase], ) -> None: """Test unit system from named string.""" units = "dimensionless" @@ -137,31 +144,31 @@ def test_init_units_from_name( # -------------------------- # `__or__` - def test_or_incorrect(self, pot): + def test_or_incorrect(self, pot: CompositePotential) -> None: """Test the `__or__` method with incorrect inputs.""" with pytest.raises(TypeError, match="unsupported operand type"): _ = pot | 1 - def test_or_pot(self, pot: gp.CompositePotential) -> None: + def test_or_pot(self, pot: CompositePotential) -> None: """Test the `__or__` method with a single potential.""" - single_pot = gp.KeplerPotential(m=1e12 * u.solMass, units=galactic) + single_pot = KeplerPotential(m=1e12 * u.solMass, units=galactic) newpot = pot | single_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-1] assert isinstance(newkey, str) assert newvalue is single_pot - def test_or_compot(self, pot: gp.CompositePotential) -> None: + def test_or_compot(self, pot: CompositePotential) -> None: """Test the `__or__` method with a composite potential.""" - comp_pot = gp.CompositePotential( - kep1=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), - kep2=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), + comp_pot = CompositePotential( + kep1=KeplerPotential(m=1e12 * u.solMass, units=galactic), + kep2=KeplerPotential(m=1e12 * u.solMass, units=galactic), ) newpot = pot | comp_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-2] assert newkey == "kep1" @@ -174,31 +181,31 @@ def test_or_compot(self, pot: gp.CompositePotential) -> None: # -------------------------- # `__ror__` - def test_ror_incorrect(self, pot): + def test_ror_incorrect(self, pot: CompositePotential) -> None: """Test the `__or__` method with incorrect inputs.""" with pytest.raises(TypeError, match="unsupported operand type"): _ = 1 | pot - def test_ror_pot(self, pot: gp.CompositePotential) -> None: + def test_ror_pot(self, pot: CompositePotential) -> None: """Test the `__ror__` method with a single potential.""" - single_pot = gp.KeplerPotential(m=1e12 * u.solMass, units=galactic) + single_pot = KeplerPotential(m=1e12 * u.solMass, units=galactic) newpot = single_pot | pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = first(newpot.items()) assert isinstance(newkey, str) assert newvalue is single_pot - def test_ror_compot(self, pot: gp.CompositePotential) -> None: + def test_ror_compot(self, pot: CompositePotential) -> None: """Test the `__ror__` method with a composite potential.""" - comp_pot = gp.CompositePotential( - kep1=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), - kep2=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), + comp_pot = CompositePotential( + kep1=KeplerPotential(m=1e12 * u.solMass, units=galactic), + kep2=KeplerPotential(m=1e12 * u.solMass, units=galactic), ) newpot = comp_pot | pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = first(newpot.items()) assert newkey == "kep1" @@ -211,32 +218,32 @@ def test_ror_compot(self, pot: gp.CompositePotential) -> None: # -------------------------- # `__add__` - def test_add_incorrect(self, pot): + def test_add_incorrect(self, pot: CompositePotential) -> None: """Test the `__add__` method with incorrect inputs.""" # TODO: specific error with pytest.raises(Exception): # noqa: B017, PT011 _ = pot + 1 - def test_add_pot(self, pot: gp.CompositePotential) -> None: + def test_add_pot(self, pot: CompositePotential) -> None: """Test the `__add__` method with a single potential.""" - single_pot = gp.KeplerPotential(m=1e12 * u.solMass, units=galactic) + single_pot = KeplerPotential(m=1e12 * u.solMass, units=galactic) newpot = pot + single_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-1] assert isinstance(newkey, str) assert newvalue is single_pot - def test_add_compot(self, pot: gp.CompositePotential) -> None: + def test_add_compot(self, pot: CompositePotential) -> None: """Test the `__add__` method with a composite potential.""" - comp_pot = gp.CompositePotential( - kep1=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), - kep2=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), + comp_pot = CompositePotential( + kep1=KeplerPotential(m=1e12 * u.solMass, units=galactic), + kep2=KeplerPotential(m=1e12 * u.solMass, units=galactic), ) newpot = pot + comp_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-2] assert newkey == "kep1" @@ -248,18 +255,18 @@ def test_add_compot(self, pot: gp.CompositePotential) -> None: # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: CompositePotential, x: Vec3) -> None: assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.6753781)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: CompositePotential, x: Vec3) -> None: assert jnp.allclose( pot.gradient(x, t=0), xp.asarray([0.01124388, 0.02248775, 0.03382281]) ) - def test_density(self, pot, x): + def test_density(self, pot: CompositePotential, x: Vec3) -> None: assert jnp.isclose(pot.density(x, t=0), 2.7958598e08) - def test_hessian(self, pot, x): + def test_hessian(self, pot: CompositePotential, x: Vec3) -> None: assert jnp.allclose( pot.hessian(x, t=0), xp.asarray( diff --git a/tests/unit/potential/test_core.py b/tests/unit/potential/test_core.py index c42860a1..5354699c 100644 --- a/tests/unit/potential/test_core.py +++ b/tests/unit/potential/test_core.py @@ -50,7 +50,7 @@ def fields_(self, field_units) -> dict[str, Any]: ########################################################################### - def test_init(self): + def test_init(self) -> None: """Test the initialization of `AbstractPotentialBase`.""" # Test that the abstract class cannot be instantiated with pytest.raises(TypeError): diff --git a/tests/unit/potential/test_special.py b/tests/unit/potential/test_special.py index f96b9b12..80d0d970 100644 --- a/tests/unit/potential/test_special.py +++ b/tests/unit/potential/test_special.py @@ -6,7 +6,13 @@ import pytest from typing_extensions import override -import galax.potential as gp +from galax.potential import ( + AbstractPotentialBase, + CompositePotential, + KeplerPotential, + MilkyWayPotential, +) +from galax.typing import Vec3 from galax.units import UnitSystem, dimensionless, galactic, solarsystem from galax.utils._misc import first @@ -17,12 +23,12 @@ class TestMilkyWayPotential(AbstractCompositePotential_Test): """Test the `galax.potential.CompositePotential` class.""" @pytest.fixture(scope="class") - def pot_cls(self) -> type[gp.MilkyWayPotential]: + def pot_cls(self) -> type[MilkyWayPotential]: """Composite potential class.""" - return gp.MilkyWayPotential + return MilkyWayPotential @pytest.fixture(scope="class") - def pot_map(self) -> Mapping[str, gp.AbstractPotentialBase]: + def pot_map(self) -> Mapping[str, AbstractPotentialBase]: """Composite potential.""" return { "disk": {"m": 6.8e10 * u.Msun, "a": 3.0 * u.kpc, "b": 0.28 * u.kpc}, @@ -34,14 +40,14 @@ def pot_map(self) -> Mapping[str, gp.AbstractPotentialBase]: @pytest.fixture(scope="class") def pot( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], - ) -> gp.CompositePotential: + pot_cls: type[MilkyWayPotential], + pot_map: Mapping[str, AbstractPotentialBase], + ) -> MilkyWayPotential: """Composite potential.""" return pot_cls(**pot_map) @pytest.fixture(scope="class") - def pot_map_unitless(self, pot_map) -> Mapping[str, gp.AbstractPotentialBase]: + def pot_map_unitless(self, pot_map) -> Mapping[str, AbstractPotentialBase]: """Composite potential.""" return {k: {kk: vv.value for kk, vv in v.items()} for k, v in pot_map.items()} @@ -52,8 +58,8 @@ def pot_map_unitless(self, pot_map) -> Mapping[str, gp.AbstractPotentialBase]: @override def test_init_units_invalid( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[MilkyWayPotential], + pot_map: Mapping[str, AbstractPotentialBase], ) -> None: """Test invalid unit system.""" # TODO: raise a specific error. The type depends on whether beartype is @@ -64,8 +70,8 @@ def test_init_units_invalid( @override def test_init_units_from_usys( self, - pot_cls: type[gp.CompositePotential], - pot_map: gp.MilkyWayPotential, + pot_cls: type[MilkyWayPotential], + pot_map: MilkyWayPotential, ) -> None: """Test unit system from UnitSystem.""" usys = UnitSystem(u.km, u.s, u.Msun, u.radian) @@ -74,8 +80,8 @@ def test_init_units_from_usys( @override def test_init_units_from_args( self, - pot_cls: type[gp.CompositePotential], - pot_map_unitless: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[MilkyWayPotential], + pot_map_unitless: Mapping[str, AbstractPotentialBase], ) -> None: """Test unit system from None.""" pot = pot_cls(**pot_map_unitless, units=None) @@ -84,8 +90,8 @@ def test_init_units_from_args( @override def test_init_units_from_tuple( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[MilkyWayPotential], + pot_map: Mapping[str, AbstractPotentialBase], ) -> None: """Test unit system from tuple.""" units = (u.km, u.s, u.Msun, u.radian) @@ -94,9 +100,9 @@ def test_init_units_from_tuple( @override def test_init_units_from_name( self, - pot_cls: type[gp.CompositePotential], - pot_map: Mapping[str, gp.AbstractPotentialBase], - pot_map_unitless: Mapping[str, gp.AbstractPotentialBase], + pot_cls: type[MilkyWayPotential], + pot_map: Mapping[str, AbstractPotentialBase], + pot_map_unitless: Mapping[str, AbstractPotentialBase], ) -> None: """Test unit system from named string.""" units = "dimensionless" @@ -120,31 +126,31 @@ def test_init_units_from_name( # -------------------------- # `__or__` - def test_or_incorrect(self, pot): + def test_or_incorrect(self, pot: MilkyWayPotential) -> None: """Test the `__or__` method with incorrect inputs.""" with pytest.raises(TypeError, match="unsupported operand type"): _ = pot | 1 - def test_or_pot(self, pot: gp.CompositePotential) -> None: + def test_or_pot(self, pot: MilkyWayPotential) -> None: """Test the `__or__` method with a single potential.""" - single_pot = gp.KeplerPotential(m=1e12 * u.solMass, units=galactic) + single_pot = KeplerPotential(m=1e12 * u.solMass, units=galactic) newpot = pot | single_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-1] assert isinstance(newkey, str) assert newvalue is single_pot - def test_or_compot(self, pot: gp.CompositePotential) -> None: + def test_or_compot(self, pot: MilkyWayPotential) -> None: """Test the `__or__` method with a composite potential.""" - comp_pot = gp.CompositePotential( - kep1=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), - kep2=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), + comp_pot = CompositePotential( + kep1=KeplerPotential(m=1e12 * u.solMass, units=galactic), + kep2=KeplerPotential(m=1e12 * u.solMass, units=galactic), ) newpot = pot | comp_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-2] assert newkey == "kep1" @@ -157,31 +163,31 @@ def test_or_compot(self, pot: gp.CompositePotential) -> None: # -------------------------- # `__ror__` - def test_ror_incorrect(self, pot): + def test_ror_incorrect(self, pot: CompositePotential) -> None: """Test the `__or__` method with incorrect inputs.""" with pytest.raises(TypeError, match="unsupported operand type"): _ = 1 | pot - def test_ror_pot(self, pot: gp.CompositePotential) -> None: + def test_ror_pot(self, pot: CompositePotential) -> None: """Test the `__ror__` method with a single potential.""" - single_pot = gp.KeplerPotential(m=1e12 * u.solMass, units=galactic) + single_pot = KeplerPotential(m=1e12 * u.solMass, units=galactic) newpot = single_pot | pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = first(newpot.items()) assert isinstance(newkey, str) assert newvalue is single_pot - def test_ror_compot(self, pot: gp.CompositePotential) -> None: + def test_ror_compot(self, pot: CompositePotential) -> None: """Test the `__ror__` method with a composite potential.""" - comp_pot = gp.CompositePotential( - kep1=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), - kep2=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), + comp_pot = CompositePotential( + kep1=KeplerPotential(m=1e12 * u.solMass, units=galactic), + kep2=KeplerPotential(m=1e12 * u.solMass, units=galactic), ) newpot = comp_pot | pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = first(newpot.items()) assert newkey == "kep1" @@ -194,32 +200,32 @@ def test_ror_compot(self, pot: gp.CompositePotential) -> None: # -------------------------- # `__add__` - def test_add_incorrect(self, pot): + def test_add_incorrect(self, pot: CompositePotential) -> None: """Test the `__add__` method with incorrect inputs.""" # TODO: specific error with pytest.raises(Exception): # noqa: B017, PT011 _ = pot + 1 - def test_add_pot(self, pot: gp.CompositePotential) -> None: + def test_add_pot(self, pot: CompositePotential) -> None: """Test the `__add__` method with a single potential.""" - single_pot = gp.KeplerPotential(m=1e12 * u.solMass, units=galactic) + single_pot = KeplerPotential(m=1e12 * u.solMass, units=galactic) newpot = pot + single_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-1] assert isinstance(newkey, str) assert newvalue is single_pot - def test_add_compot(self, pot: gp.CompositePotential) -> None: + def test_add_compot(self, pot: CompositePotential) -> None: """Test the `__add__` method with a composite potential.""" - comp_pot = gp.CompositePotential( - kep1=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), - kep2=gp.KeplerPotential(m=1e12 * u.solMass, units=galactic), + comp_pot = CompositePotential( + kep1=KeplerPotential(m=1e12 * u.solMass, units=galactic), + kep2=KeplerPotential(m=1e12 * u.solMass, units=galactic), ) newpot = pot + comp_pot - assert isinstance(newpot, gp.CompositePotential) + assert isinstance(newpot, CompositePotential) newkey, newvalue = tuple(newpot.items())[-2] assert newkey == "kep1" @@ -231,18 +237,22 @@ def test_add_compot(self, pot: gp.CompositePotential) -> None: # ========================================================================== - def test_potential_energy(self, pot, x) -> None: + def test_potential_energy(self, pot: MilkyWayPotential, x: Vec3) -> None: + """Test the :meth:`MilkyWayPotential.potential_energy` method.""" assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.19386052)) - def test_gradient(self, pot, x): + def test_gradient(self, pot: MilkyWayPotential, x: Vec3) -> None: + """Test the :meth:`MilkyWayPotential.gradient` method.""" assert jnp.allclose( pot.gradient(x, t=0), xp.asarray([0.00256403, 0.00512806, 0.01115272]) ) - def test_density(self, pot, x): + def test_density(self, pot: MilkyWayPotential, x: Vec3) -> None: + """Test the :meth:`MilkyWayPotential.density` method.""" assert jnp.isclose(pot.density(x, t=0), 33365858.46361218) - def test_hessian(self, pot, x): + def test_hessian(self, pot: MilkyWayPotential, x: Vec3) -> None: + """Test the :meth:`MilkyWayPotential.hessian` method.""" assert jnp.allclose( pot.hessian(x, t=0), xp.asarray( diff --git a/tests/unit/potential/test_utils.py b/tests/unit/potential/test_utils.py index 09390703..1b0588be 100644 --- a/tests/unit/potential/test_utils.py +++ b/tests/unit/potential/test_utils.py @@ -1,10 +1,13 @@ """Tests for `galax.potential._potential.utils` package.""" from dataclasses import replace +from typing import Any import astropy.units as u import pytest +from jax import Array +from galax.potential import AbstractPotentialBase from galax.potential._potential.utils import ( UnitSystem, converter_to_usys, @@ -18,26 +21,26 @@ class TestConverterToUtils: """Tests for `galax.potential._potential.utils.converter_to_usys`.""" - def test_invalid(self): + def test_invalid(self) -> None: """Test conversion from unsupported value.""" with pytest.raises(NotImplementedError): converter_to_usys(1234567890) - def test_from_usys(self): + def test_from_usys(self) -> None: """Test conversion from UnitSystem.""" usys = UnitSystem(u.km, u.s, u.Msun, u.radian) assert converter_to_usys(usys) == usys - def test_from_none(self): + def test_from_none(self) -> None: """Test conversion from None.""" assert converter_to_usys(None) == dimensionless - def test_from_args(self): + def test_from_args(self) -> None: """Test conversion from tuple.""" value = UnitSystem(u.km, u.s, u.Msun, u.radian) assert converter_to_usys(value) == value - def test_from_name(self): + def test_from_name(self) -> None: """Test conversion from named string.""" assert converter_to_usys("dimensionless") == dimensionless assert converter_to_usys("solarsystem") == solarsystem @@ -47,7 +50,7 @@ def test_from_name(self): converter_to_usys("invalid_value") @pytest.mark.skipif(not HAS_GALA, reason="requires gala") - def test_from_gala(self): + def test_from_gala(self) -> None: """Test conversion from gala.""" # ------------------------------- # UnitSystem @@ -71,7 +74,7 @@ class FieldUnitSystemMixin: """Mixin for testing the ``units`` field on a ``Potential``.""" @pytest.fixture() - def fields_unitless(self, fields): + def fields_unitless(self, fields: dict[str, Any]) -> dict[str, Array]: """Fields with no units.""" return { k: (v.value if isinstance(v, u.Quantity) else v) for k, v in fields.items() @@ -79,18 +82,20 @@ def fields_unitless(self, fields): # =========================================== - def test_init_units_invalid(self, pot): + def test_init_units_invalid(self, pot: AbstractPotentialBase) -> None: """Test invalid unit system.""" msg = "cannot convert 1234567890 to a UnitSystem" with pytest.raises(NotImplementedError, match=msg): replace(pot, units=1234567890) - def test_init_units_from_usys(self, pot): + def test_init_units_from_usys(self, pot: AbstractPotentialBase) -> None: """Test unit system from UnitSystem.""" usys = UnitSystem(u.km, u.s, u.Msun, u.radian) assert replace(pot, units=usys).units == usys - def test_init_units_from_args(self, pot_cls, fields_unitless): + def test_init_units_from_args( + self, pot_cls: type[AbstractPotentialBase], fields_unitless: dict[str, Array] + ) -> None: """Test unit system from None.""" # strip the units from the fields otherwise the test will fail # because the units are not equal and we just want to check that @@ -100,12 +105,14 @@ def test_init_units_from_args(self, pot_cls, fields_unitless): pot = pot_cls(**fields_unitless, units=None) assert pot.units == dimensionless - def test_init_units_from_tuple(self, pot): + def test_init_units_from_tuple(self, pot: AbstractPotentialBase) -> None: """Test unit system from tuple.""" units = (u.km, u.s, u.Msun, u.radian) assert replace(pot, units=units).units == UnitSystem(*units) - def test_init_units_from_name(self, pot_cls, fields_unitless): + def test_init_units_from_name( + self, pot_cls: type[AbstractPotentialBase], fields_unitless: dict[str, Array] + ) -> None: """Test unit system from named string.""" fields_unitless.pop("units") diff --git a/tests/unit/test_units.py b/tests/unit/test_units.py index 3b940e87..90bb2908 100644 --- a/tests/unit/test_units.py +++ b/tests/unit/test_units.py @@ -1,91 +1,99 @@ -# Standard library import pickle +from pathlib import Path -# Third party import astropy.units as u import numpy as np import pytest -# This package from galax.units import UnitSystem, dimensionless -def test_init(): - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) +class TestUnitSystem: + """Test :class:`~galax.units.UnitSystem`.""" - with pytest.raises( - ValueError, match="must specify a unit for the physical type .*mass" - ): - UnitSystem(u.kpc, u.Myr, u.radian) # no mass + def test_constructor(self) -> None: + """Test the :class:`~galax.units.UnitSystem` constructor.""" + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) - with pytest.raises( - ValueError, match="must specify a unit for the physical type .*angle" - ): - UnitSystem(u.kpc, u.Myr, u.Msun) + match = "must specify a unit for the physical type .*mass" + with pytest.raises(ValueError, match=match): + UnitSystem(u.kpc, u.Myr, u.radian) # no mass - with pytest.raises( - ValueError, match="must specify a unit for the physical type .*time" - ): - UnitSystem(u.kpc, u.radian, u.Msun) + match = "must specify a unit for the physical type .*angle" + with pytest.raises(ValueError, match=match): + UnitSystem(u.kpc, u.Myr, u.Msun) - with pytest.raises( - ValueError, match="must specify a unit for the physical type .*length" - ): - UnitSystem(u.Myr, u.radian, u.Msun) + match = "must specify a unit for the physical type .*time" + with pytest.raises(ValueError, match=match): + UnitSystem(u.kpc, u.radian, u.Msun) - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) - usys = UnitSystem(usys) + match = "must specify a unit for the physical type .*length" + with pytest.raises(ValueError, match=match): + UnitSystem(u.Myr, u.radian, u.Msun) + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) + usys = UnitSystem(usys) -def test_quantity_init(): - usys = UnitSystem(5 * u.kpc, 50 * u.Myr, 1e5 * u.Msun, u.rad) - assert np.isclose((8 * u.Myr).decompose(usys).value, 8 / 50) + def test_constructor_quantity(self) -> None: + """Test the :class:`~galax.units.UnitSystem` constructor with quantities.""" + usys = UnitSystem(5 * u.kpc, 50 * u.Myr, 1e5 * u.Msun, u.rad) + assert np.isclose((8 * u.Myr).decompose(usys).value, 8 / 50) + def test_preferred(self) -> None: + """Test the :meth:`~galax.units.UnitSystem.preferred` method.""" + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.km / u.s) + q = 15.0 * u.km / u.s + assert usys.preferred("velocity") == u.km / u.s + assert q.decompose(usys).unit == u.kpc / u.Myr + assert usys.as_preferred(q).unit == u.km / u.s -def test_preferred(): - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.km / u.s) - q = 15.0 * u.km / u.s - assert usys.preferred("velocity") == u.km / u.s - assert q.decompose(usys).unit == u.kpc / u.Myr - assert usys.as_preferred(q).unit == u.km / u.s + # =============================================================== + def test_compare(self) -> None: + """Test the :meth:`~galax.units.UnitSystem.compare` method.""" + usys1 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) + usys1_clone = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) -def test_dimensionless(): - assert dimensionless["dimensionless"] == u.one - assert dimensionless["length"] == u.one + usys2 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.kiloarcsecond / u.yr) + usys3 = UnitSystem(u.kpc, u.Myr, u.radian, u.kg, u.mas / u.yr) - with pytest.raises(ValueError, match="can not be decomposed into"): - (15 * u.kpc).decompose(dimensionless) + assert usys1 == usys1_clone + assert usys1_clone == usys1 - with pytest.raises(ValueError, match="are not convertible"): - dimensionless.as_preferred(15 * u.kpc) + assert usys1 != usys2 + assert usys2 != usys1 + assert usys1 != usys3 + assert usys3 != usys1 -def test_compare(): - usys1 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) - usys1_clone = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) + def test_pickle(self, tmpdir: Path) -> None: + """Test pickling and unpickling a :class:`~galax.units.UnitSystem`.""" + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) - usys2 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.kiloarcsecond / u.yr) - usys3 = UnitSystem(u.kpc, u.Myr, u.radian, u.kg, u.mas / u.yr) + path = tmpdir / "test.pkl" + with path.open(mode="wb") as f: + pickle.dump(usys, f) - assert usys1 == usys1_clone - assert usys1_clone == usys1 + with path.open(mode="rb") as f: + usys2 = pickle.load(f) - assert usys1 != usys2 - assert usys2 != usys1 + assert usys == usys2 - assert usys1 != usys3 - assert usys3 != usys1 +class TestDimensionlessUnitSystem: + """Test :class:`~galax.units.DimensionlessUnitSystem`.""" -def test_pickle(tmpdir): - usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) + def test_getitem(self) -> None: + """Test :meth:`~galax.units.DimensionlessUnitSystem.__getitem__`.""" + assert dimensionless["dimensionless"] == u.one + assert dimensionless["length"] == u.one - path = tmpdir / "test.pkl" - with path.open(mode="wb") as f: - pickle.dump(usys, f) + def test_decompose(self) -> None: + """Test that dimensionless unitsystem can be decomposed.""" + with pytest.raises(ValueError, match="can not be decomposed into"): + (15 * u.kpc).decompose(dimensionless) - with path.open(mode="rb") as f: - usys2 = pickle.load(f) - - assert usys == usys2 + def test_preferred(self) -> None: + """Test the :meth:`~galax.units.DimensionlessUnitSystem.preferred` method.""" + with pytest.raises(ValueError, match="are not convertible"): + dimensionless.as_preferred(15 * u.kpc) diff --git a/tests/unit/utils/test_shape.py b/tests/unit/utils/test_shape.py index c85f7add..5afc1ee6 100644 --- a/tests/unit/utils/test_shape.py +++ b/tests/unit/utils/test_shape.py @@ -1,7 +1,9 @@ """Test the `galax.utils._shape` module.""" import re +from typing import Any +import jax import jax.experimental.array_api as xp import jax.numpy as jnp import pytest @@ -13,7 +15,7 @@ class TestAtleastBatched: """Test the `atleast_batched` function.""" - def test_atleast_batched_no_args(self): + def test_atleast_batched_no_args(self) -> None: """Test the `atleast_batched` function with no arguments.""" with pytest.raises( ValueError, @@ -21,7 +23,7 @@ def test_atleast_batched_no_args(self): ): _ = atleast_batched() - def test_atleast_batched_example(self): + def test_atleast_batched_example(self) -> None: """Test the `atleast_batched` function with an example.""" x = xp.asarray([1, 2, 3]) # `atleast_batched` versus `atleast_2d` @@ -38,13 +40,13 @@ def test_atleast_batched_example(self): ([1, 2, 3], [[1], [2], [3]]), ], ) - def test_atleast_batched_one_arg(self, x, expect): + def test_atleast_batched_one_arg(self, x: Any, expect: Any) -> None: """Test the `atleast_batched` function with one argument.""" got = atleast_batched(xp.asarray(x)) assert array_equal(got, xp.asarray(expect)) assert got.ndim >= 2 - def test_atleast_batched_multiple_args(self): + def test_atleast_batched_multiple_args(self) -> None: """Test the `atleast_batched` function with multiple arguments.""" x = xp.asarray([1, 2, 3]) y = xp.asarray([4, 5, 6]) @@ -68,7 +70,12 @@ class TestBatchedShape: (xp.asarray([[1, 2], [3, 4]]), 2, ((), (2, 2))), ], ) - def test_batched_shape(self, arr, expect_ndim, expect): + def test_batched_shape( + self, + arr: jax.Array, + expect_ndim: int, + expect: tuple[tuple[int, ...], tuple[int, ...]], + ) -> None: """Test the `galax.utils._shape.batched_shape` function.""" batch, shape = batched_shape(arr, expect_ndim=expect_ndim) assert batch == expect[0]