Skip to content

Commit

Permalink
test: add annotations (#117)
Browse files Browse the repository at this point in the history
* test: add annotations
* test: simplify imports

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 31, 2024
1 parent 936b71c commit 90c6a06
Show file tree
Hide file tree
Showing 16 changed files with 309 additions and 243 deletions.
2 changes: 1 addition & 1 deletion tests/smoke/integrate/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from galax.integrate import _base, _builtin


def test_all():
def test_all() -> None:
"""Test the API."""
assert set(integrate.__all__) == set(_base.__all__ + _builtin.__all__)
2 changes: 1 addition & 1 deletion tests/smoke/potential/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from galax.potential import _potential


def test_all():
def test_all() -> None:
"""Test the `galax.potential` package contents."""
# Test correct dumping of contents
assert gp.__all__ == _potential.__all__
Expand Down
12 changes: 1 addition & 11 deletions tests/smoke/utils/test_package.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
import pytest

from galax import utils


def test__all__():
def test__all__() -> None:
"""Test that `galax.utils` has the expected `__all__`."""
assert utils.__all__ == [
"dataclasses",
*utils._jax.__all__,
*utils._collections.__all__,
]


@pytest.mark.skip(reason="TODO")
def test_public_modules():
"""Test which modules are publicly importable."""
# IDK how to discover all submodules of a package, even if they aren't
# imported without relying on the filesystem. The filesystem is generally
# safe, but I'd rather solve this generically. Low priority.
25 changes: 17 additions & 8 deletions tests/unit/potential/builtin/test_bar.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any

import astropy.units as u
import jax.experimental.array_api as xp
import jax.numpy as jnp
import pytest

import galax.potential as gp
from galax.potential import BarPotential
from galax.typing import Vec3
from galax.units import UnitSystem

from ..test_core import TestAbstractPotential as AbstractPotential_Test
from .test_common import (
Expand All @@ -24,16 +27,22 @@ class TestBarPotential(
ShapeCParameterMixin,
):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[gp.BarPotential]:
return gp.BarPotential
def pot_cls(self) -> type[BarPotential]:
return BarPotential

@pytest.fixture(scope="class")
def field_Omega(self) -> dict[str, Any]:
return 0

@pytest.fixture(scope="class")
def fields_(
self, field_m, field_a, field_b, field_c, field_Omega, field_units
self,
field_m: u.Quantity,
field_a: u.Quantity,
field_b: u.Quantity,
field_c: u.Quantity,
field_Omega: u.Quantity,
field_units: UnitSystem,
) -> dict[str, Any]:
return {
"m": field_m,
Expand All @@ -46,18 +55,18 @@ def fields_(

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
def test_potential_energy(self, pot: BarPotential, x: Vec3) -> None:
assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.94601574))

def test_gradient(self, pot, x):
def test_gradient(self, pot: BarPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.gradient(x, t=0), xp.asarray([0.04011905, 0.08383918, 0.16552719])
)

def test_density(self, pot, x):
def test_density(self, pot: BarPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0), 1.94669274e08)

def test_hessian(self, pot, x):
def test_hessian(self, pot: BarPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.hessian(x, t=0),
xp.asarray(
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/potential/builtin/test_hernquist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from galax.potential import HernquistPotential
from galax.typing import Vec3

from ..test_core import TestAbstractPotential as AbstractPotential_Test
from .test_common import MassParameterMixin, ShapeCParameterMixin
Expand All @@ -26,18 +27,18 @@ def fields_(self, field_m, field_c, field_units) -> dict[str, Any]:

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
def test_potential_energy(self, pot: HernquistPotential, x: Vec3) -> None:
assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.94871936))

def test_gradient(self, pot, x):
def test_gradient(self, pot: HernquistPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.gradient(x, t=0), xp.asarray([0.05347411, 0.10694822, 0.16042233])
)

def test_density(self, pot, x):
def test_density(self, pot: HernquistPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0), 3.989933e08)

def test_hessian(self, pot, x):
def test_hessian(self, pot: HernquistPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.hessian(x, t=0),
xp.asarray(
Expand Down
10 changes: 6 additions & 4 deletions tests/unit/potential/builtin/test_isochrone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pytest

import galax.potential as gp
from galax.potential import IsochronePotential
from galax.typing import Vec3

from ..test_core import TestAbstractPotential as AbstractPotential_Test
from .test_common import MassParameterMixin, ShapeBParameterMixin
Expand All @@ -26,18 +28,18 @@ def fields_(self, field_m, field_b, field_units) -> dict[str, Any]:

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
def test_potential_energy(self, pot: IsochronePotential, x: Vec3) -> None:
assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.9231515))

def test_gradient(self, pot, x):
def test_gradient(self, pot: IsochronePotential, x: Vec3) -> None:
assert jnp.allclose(
pot.gradient(x, t=0), xp.asarray([0.04891392, 0.09782784, 0.14674175])
)

def test_density(self, pot, x):
def test_density(self, pot: IsochronePotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0), 5.04511665e08)

def test_hessian(self, pot, x):
def test_hessian(self, pot: IsochronePotential, x: Vec3) -> None:
assert jnp.allclose(
pot.hessian(x, t=0),
xp.asarray(
Expand Down
20 changes: 15 additions & 5 deletions tests/unit/potential/builtin/test_miyamotonagai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Any

import astropy.units as u
import jax.experimental.array_api as xp
import jax.numpy as jnp
import pytest

import galax.potential as gp
from galax.potential import MiyamotoNagaiPotential
from galax.typing import Vec3
from galax.units import UnitSystem

from ..test_core import TestAbstractPotential as AbstractPotential_Test
from .test_common import MassParameterMixin, ShapeAParameterMixin, ShapeBParameterMixin
Expand All @@ -24,23 +28,29 @@ def pot_cls(self) -> type[gp.MiyamotoNagaiPotential]:
return gp.MiyamotoNagaiPotential

@pytest.fixture(scope="class")
def fields_(self, field_m, field_a, field_b, field_units) -> dict[str, Any]:
def fields_(
self,
field_m: u.Quantity,
field_a: u.Quantity,
field_b: u.Quantity,
field_units: UnitSystem,
) -> dict[str, Any]:
return {"m": field_m, "a": field_a, "b": field_b, "units": field_units}

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
def test_potential_energy(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None:
assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-0.95208676))

def test_gradient(self, pot, x):
def test_gradient(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.gradient(x, t=0), xp.asarray([0.04264751, 0.08529503, 0.16840152])
)

def test_density(self, pot, x):
def test_density(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0), 1.9949418e08)

def test_hessian(self, pot, x):
def test_hessian(self, pot: MiyamotoNagaiPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.hessian(x, t=0),
xp.asarray(
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/potential/builtin/test_mwpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from galax.potential import MilkyWayPotential
from galax.typing import Vec3
from galax.units import galactic

from ..test_core import TestAbstractPotential
Expand Down Expand Up @@ -34,18 +35,18 @@ def test_init_units_from_args(self, pot_cls, fields_unitless):

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
def test_potential_energy(self, pot: MilkyWayPotential, x: Vec3) -> None:
assert xp.isclose(pot.potential_energy(x, t=0), xp.array(-0.19386052))

def test_gradient(self, pot, x):
def test_gradient(self, pot: MilkyWayPotential, x: Vec3) -> None:
assert xp.allclose(
pot.gradient(x, t=0), xp.array([0.00256403, 0.00512806, 0.01115272])
)

def test_density(self, pot, x):
def test_density(self, pot: MilkyWayPotential, x: Vec3) -> None:
assert xp.isclose(pot.density(x, t=0), 33_365_858.46361218)

def test_hessian(self, pot, x):
def test_hessian(self, pot: MilkyWayPotential, x: Vec3) -> None:
assert xp.allclose(
pot.hessian(x, t=0),
xp.array(
Expand Down
40 changes: 24 additions & 16 deletions tests/unit/potential/builtin/test_nfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing_extensions import override

import galax.potential as gp
from galax.potential import ConstantParameter
from galax.potential import AbstractPotential, ConstantParameter, NFWPotential
from galax.typing import Vec3
from galax.units import galactic
from galax.units import UnitSystem, galactic
from galax.utils._optional_deps import HAS_GALA

from ..param.test_field import ParameterFieldMixin
Expand All @@ -21,30 +21,36 @@
class ScaleRadiusParameterMixin(ParameterFieldMixin):
"""Test the mass parameter."""

pot_cls: type[gp.AbstractPotential]
pot_cls: type[AbstractPotential]

@pytest.fixture(scope="class")
def field_r_s(self) -> float:
return 1.0 * u.kpc

# =====================================================

def test_r_s_units(self, pot_cls, fields):
def test_r_s_units(
self, pot_cls: type[AbstractPotential], fields: dict[str, Any]
) -> None:
"""Test the mass parameter."""
fields["r_s"] = 1.0 * u.Unit(10 * u.kpc)
fields["units"] = galactic
pot = pot_cls(**fields)
assert isinstance(pot.r_s, ConstantParameter)
assert jnp.isclose(pot.r_s.value, 10)

def test_r_s_constant(self, pot_cls, fields):
def test_r_s_constant(
self, pot_cls: type[AbstractPotential], fields: dict[str, Any]
):
"""Test the mass parameter."""
fields["r_s"] = 1.0
pot = pot_cls(**fields)
assert pot.r_s(t=0) == 1.0

@pytest.mark.xfail(reason="TODO: user function doesn't have units")
def test_r_s_userfunc(self, pot_cls, fields):
def test_r_s_userfunc(
self, pot_cls: type[AbstractPotential], fields: dict[str, Any]
):
"""Test the mass parameter."""
fields["r_s"] = lambda t: t + 2
pot = pot_cls(**fields)
Expand All @@ -62,8 +68,8 @@ class TestNFWPotential(
):
@pytest.fixture(scope="class")
@override
def pot_cls(self) -> type[gp.NFWPotential]:
return gp.NFWPotential
def pot_cls(self) -> type[NFWPotential]:
return NFWPotential

@pytest.fixture(scope="class")
def field_softening_length(self) -> float:
Expand All @@ -72,7 +78,11 @@ def field_softening_length(self) -> float:
@pytest.fixture(scope="class")
@override
def fields_(
self, field_m, field_r_s, field_softening_length, field_units
self,
field_m: u.Quantity,
field_r_s: u.Quantity,
field_softening_length: float,
field_units: UnitSystem,
) -> dict[str, Any]:
return {
"m": field_m,
Expand All @@ -83,18 +93,18 @@ def fields_(

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
def test_potential_energy(self, pot: NFWPotential, x: Vec3) -> None:
assert jnp.isclose(pot.potential_energy(x, t=0), xp.asarray(-1.87117234))

def test_gradient(self, pot, x):
def test_gradient(self, pot: NFWPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.gradient(x, t=0), xp.asarray([0.0658867, 0.1317734, 0.19766011])
)

def test_density(self, pot, x):
def test_density(self, pot: NFWPotential, x: Vec3) -> None:
assert jnp.isclose(pot.density(x, t=0), 9.46039849e08)

def test_hessian(self, pot, x):
def test_hessian(self, pot: NFWPotential, x: Vec3) -> None:
assert jnp.allclose(
pot.hessian(x, t=0),
xp.asarray(
Expand All @@ -110,9 +120,7 @@ def test_hessian(self, pot, x):
# I/O

@pytest.mark.skipif(not HAS_GALA, reason="requires gala")
def test_galax_to_gala_to_galax_roundtrip(
self, pot: gp.AbstractPotentialBase, x: Vec3
) -> None:
def test_galax_to_gala_to_galax_roundtrip(self, pot: NFWPotential, x: Vec3) -> None:
"""Test roundtripping ``gala_to_galax(galax_to_gala())``."""
from ..io.gala_helper import galax_to_gala

Expand Down
Loading

0 comments on commit 90c6a06

Please sign in to comment.