diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dbe995d6..41307fcb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,7 @@ repos: - id: mixed-line-ending - id: name-tests-test args: ["--pytest-test-first"] + exclude: '^tests\/.*_helper\.py$' - id: requirements-txt-fixer - id: trailing-whitespace diff --git a/src/galax/potential/_potential/__init__.py b/src/galax/potential/_potential/__init__.py index 40af68cc..e0d3ac8f 100644 --- a/src/galax/potential/_potential/__init__.py +++ b/src/galax/potential/_potential/__init__.py @@ -1,7 +1,7 @@ """galax: Galactic Dynamix in Jax.""" -from . import base, builtin, composite, core, param, special, utils +from . import base, builtin, composite, core, io, param, special, utils from .base import * from .builtin import * from .composite import * @@ -10,7 +10,7 @@ from .special import * from .utils import * -__all__: list[str] = [] +__all__: list[str] = ["io"] __all__ += base.__all__ __all__ += core.__all__ __all__ += composite.__all__ diff --git a/src/galax/potential/_potential/io/__init__.py b/src/galax/potential/_potential/io/__init__.py new file mode 100644 index 00000000..5671a317 --- /dev/null +++ b/src/galax/potential/_potential/io/__init__.py @@ -0,0 +1,17 @@ +"""Input/output/conversion of potential objects. + +This module contains the machinery for I/O and conversion of potential objects. +Conversion is useful for e.g. converting a +:class:`galax.potential.AbstractPotential` object to a +:class:`gala.potential.PotentialBase` object. +""" + +__all__: list[str] = ["gala_to_galax"] + + +from galax.utils._optional_deps import HAS_GALA + +if HAS_GALA: + from .gala import gala_to_galax +else: + from .gala_noop import gala_to_galax # type: ignore[assignment] diff --git a/src/galax/potential/_potential/io/gala.py b/src/galax/potential/_potential/io/gala.py new file mode 100644 index 00000000..f6dd64fa --- /dev/null +++ b/src/galax/potential/_potential/io/gala.py @@ -0,0 +1,156 @@ +"""Interoperability.""" + +__all__ = ["gala_to_galax"] + +from functools import singledispatch + +import numpy as np +from gala.potential import ( + CompositePotential as GalaCompositePotential, + HernquistPotential as GalaHernquistPotential, + IsochronePotential as GalaIsochronePotential, + KeplerPotential as GalaKeplerPotential, + MilkyWayPotential as GalaMilkyWayPotential, + MiyamotoNagaiPotential as GalaMiyamotoNagaiPotential, + NFWPotential as GalaNFWPotential, + NullPotential as GalaNullPotential, + PotentialBase as GalaPotentialBase, +) + +from galax.potential._potential.base import AbstractPotentialBase +from galax.potential._potential.builtin import ( + HernquistPotential, + IsochronePotential, + KeplerPotential, + MiyamotoNagaiPotential, + NFWPotential, + NullPotential, +) +from galax.potential._potential.composite import CompositePotential +from galax.potential._potential.special import MilkyWayPotential + +############################################################################## +# GALA -> GALAX + + +def _static_at_origin(pot: GalaPotentialBase, /) -> bool: + return pot.R is None and np.array_equal(pot.origin, (0, 0, 0)) + + +@singledispatch +def gala_to_galax(pot: GalaPotentialBase, /) -> AbstractPotentialBase: + """Convert a :mod:`gala` potential to a :mod:`galax` potential. + + Parameters + ---------- + pot : :class:`~gala.potential.PotentialBase` + :mod:`gala` potential. + + Returns + ------- + gala_pot : :class:`~galax.potential.AbstractPotentialBase` + :mod:`galax` potential. + """ + msg = ( + "`gala_to_galax` does not have a registered function to convert " + f"{pot.__class__.__name__!r} to a `galax.AbstractPotentialBase` instance." + ) + raise NotImplementedError(msg) + + +# ----------------------------------------------------------------------------- +# General rules + + +@gala_to_galax.register +def _gala_to_galax_composite(pot: GalaCompositePotential, /) -> CompositePotential: + """Convert a Gala CompositePotential to a Galax potential.""" + return CompositePotential(**{k: gala_to_galax(p) for k, p in pot.items()}) + + +# ----------------------------------------------------------------------------- +# Builtin potentials + + +@gala_to_galax.register +def _gala_to_galax_hernquist(pot: GalaHernquistPotential, /) -> HernquistPotential: + """Convert a Gala HernquistPotential to a Galax potential.""" + if not _static_at_origin(pot): + msg = "Galax does not support rotating or offset potentials." + raise TypeError(msg) + params = pot.parameters + return HernquistPotential(m=params["m"], c=params["c"], units=pot.units) + + +@gala_to_galax.register +def _gala_to_galax_isochrone(pot: GalaIsochronePotential, /) -> IsochronePotential: + """Convert a Gala IsochronePotential to a Galax potential.""" + if not _static_at_origin(pot): + msg = "Galax does not support rotating or offset potentials." + raise TypeError(msg) + params = pot.parameters + return IsochronePotential(m=params["m"], b=params["b"], units=pot.units) + + +@gala_to_galax.register +def _gala_to_galax_kepler(pot: GalaKeplerPotential, /) -> KeplerPotential: + """Convert a Gala KeplerPotential to a Galax potential.""" + if not _static_at_origin(pot): + msg = "Galax does not support rotating or offset potentials." + raise TypeError(msg) + params = pot.parameters + return KeplerPotential(m=params["m"], units=pot.units) + + +@gala_to_galax.register +def _gala_to_galax_miyamotonagi( + pot: GalaMiyamotoNagaiPotential, / +) -> MiyamotoNagaiPotential: + """Convert a Gala MiyamotoNagaiPotential to a Galax potential.""" + if not _static_at_origin(pot): + msg = "Galax does not support rotating or offset potentials." + raise TypeError(msg) + params = pot.parameters + return MiyamotoNagaiPotential( + m=params["m"], a=params["a"], b=params["b"], units=pot.units + ) + + +@gala_to_galax.register +def _gala_to_galax_nfw(pot: GalaNFWPotential, /) -> NFWPotential: + """Convert a Gala NFWPotential to a Galax potential.""" + if not _static_at_origin(pot): + msg = "Galax does not support rotating or offset potentials." + raise TypeError(msg) + params = pot.parameters + return NFWPotential( + m=params["m"], r_s=params["r_s"], softening_length=0, units=pot.units + ) + + +@gala_to_galax.register +def _gala_to_galax_nullpotential(pot: GalaNullPotential, /) -> NullPotential: + """Convert a Gala NullPotential to a Galax potential.""" + if not _static_at_origin(pot): + msg = "Galax does not support rotating or offset potentials." + raise TypeError(msg) + return NullPotential(units=pot.units) + + +# ----------------------------------------------------------------------------- +# MW potentials + + +@gala_to_galax.register +def _gala_to_galax_mwpotential(pot: GalaMilkyWayPotential, /) -> MilkyWayPotential: + """Convert a Gala MilkyWayPotential to a Galax potential.""" + if not all(_static_at_origin(p) for p in pot.values()): + msg = "Galax does not support rotating or offset potentials." + raise TypeError(msg) + + return MilkyWayPotential( + disk={k: pot["disk"].parameters[k] for k in ("m", "a", "b")}, + halo={k: pot["halo"].parameters[k] for k in ("m", "r_s")}, + bulge={k: pot["bulge"].parameters[k] for k in ("m", "c")}, + nucleus={k: pot["nucleus"].parameters[k] for k in ("m", "c")}, + ) diff --git a/src/galax/potential/_potential/io/gala_noop.py b/src/galax/potential/_potential/io/gala_noop.py new file mode 100644 index 00000000..aa151d97 --- /dev/null +++ b/src/galax/potential/_potential/io/gala_noop.py @@ -0,0 +1,28 @@ +"""Interoperability.""" + +__all__ = ["gala_to_galax"] + + +from typing import TYPE_CHECKING + +from galax.potential._potential.base import AbstractPotentialBase + +if TYPE_CHECKING: + from gala.potential import PotentialBase as GalaPotentialBase + + +def gala_to_galax(pot: "GalaPotentialBase", /) -> AbstractPotentialBase: + """Convert a :mod:`gala` potential to a :mod:`galax` potential. + + Parameters + ---------- + pot : :class:`~gala.potential.PotentialBase` + :mod:`gala` potential. + + Returns + ------- + gala_pot : :class:`~galax.potential.AbstractPotentialBase` + :mod:`galax` potential. + """ + msg = "The `gala` package must be installed to use this function. " + raise ImportError(msg) diff --git a/src/galax/potential/_potential/param/core.py b/src/galax/potential/_potential/param/core.py index 487a37e0..38c6d55a 100644 --- a/src/galax/potential/_potential/param/core.py +++ b/src/galax/potential/_potential/param/core.py @@ -13,6 +13,7 @@ BatchableFloatOrIntScalarLike, FloatArrayAnyShape, FloatOrIntScalar, + FloatOrIntScalarLike, FloatScalar, Unit, ) @@ -37,12 +38,12 @@ class AbstractParameter(eqx.Module, strict=True): # type: ignore[call-arg, misc unit: Unit = eqx.field(static=True, converter=u.Unit) @abc.abstractmethod - def __call__(self, t: FloatScalar, **kwargs: Any) -> FloatArrayAnyShape: + def __call__(self, t: FloatOrIntScalarLike, **kwargs: Any) -> FloatArrayAnyShape: """Compute the parameter value at the given time(s). Parameters ---------- - t : float | Array[float, ()] + t : Array[float | int, ()] | float | int The time(s) at which to compute the parameter value. **kwargs : Any Additional parameters to pass to the parameter function. diff --git a/tests/smoke/potential/test_package.py b/tests/smoke/potential/test_package.py index d881be0b..5ffb9f5b 100644 --- a/tests/smoke/potential/test_package.py +++ b/tests/smoke/potential/test_package.py @@ -8,11 +8,12 @@ def test_all(): assert gp.__all__ == _potential.__all__ # Test detailed contents (not order) - assert set(gp.__all__) == set( - _potential.base.__all__ - + _potential.builtin.__all__ - + _potential.composite.__all__ - + _potential.core.__all__ - + _potential.param.__all__ - + _potential.special.__all__ - ) + assert set(gp.__all__) == { + "io", + *_potential.base.__all__, + *_potential.builtin.__all__, + *_potential.composite.__all__, + *_potential.core.__all__, + *_potential.param.__all__, + *_potential.special.__all__, + } diff --git a/tests/unit/potential/builtin/test_nfw.py b/tests/unit/potential/builtin/test_nfw.py index 7c3528f5..30398255 100644 --- a/tests/unit/potential/builtin/test_nfw.py +++ b/tests/unit/potential/builtin/test_nfw.py @@ -1,13 +1,17 @@ +from dataclasses import replace from typing import Any import astropy.units as u import jax.experimental.array_api as xp import jax.numpy as jnp import pytest +from typing_extensions import override import galax.potential as gp from galax.potential import ConstantParameter +from galax.typing import Vec3 from galax.units import galactic +from galax.utils._optional_deps import HAS_GALA from ..params.test_field import ParameterFieldMixin from ..test_core import TestAbstractPotential as AbstractPotential_Test @@ -47,6 +51,9 @@ def test_r_s_userfunc(self, pot_cls, fields): assert pot.r_s(t=0) == 2 +############################################################################### + + class TestNFWPotential( AbstractPotential_Test, # Parameters @@ -54,14 +61,16 @@ class TestNFWPotential( ScaleRadiusParameterMixin, ): @pytest.fixture(scope="class") + @override def pot_cls(self) -> type[gp.NFWPotential]: return gp.NFWPotential @pytest.fixture(scope="class") - def field_softening_length(self) -> dict[str, Any]: + def field_softening_length(self) -> float: return 0.001 @pytest.fixture(scope="class") + @override def fields_( self, field_m, field_r_s, field_softening_length, field_units ) -> dict[str, Any]: @@ -96,3 +105,26 @@ def test_hessian(self, pot, x): ] ), ) + + # ========================================================================== + # I/O + + @pytest.mark.skipif(not HAS_GALA, reason="requires gala") + def test_galax_to_gala_to_galax_roundtrip( + self, pot: gp.AbstractPotentialBase, x: Vec3 + ) -> None: + """Test roundtripping ``gala_to_galax(galax_to_gala())``.""" + from ..io.gala_helper import galax_to_gala + + # Base is with non-zero softening + assert pot.softening_length != 0 + with pytest.raises(TypeError, match="Gala does not support softening"): + _ = galax_to_gala(pot) + + # Make a copy without softening + pot = replace(pot, softening_length=0) + + rpot = gp.io.gala_to_galax(galax_to_gala(pot)) + + # quick test that the potential energies are the same + assert jnp.array_equal(pot(x, t=0), rpot(x, t=0)) diff --git a/tests/unit/potential/builtin/test_null.py b/tests/unit/potential/builtin/test_null.py index a9e052d2..fc27ac4a 100644 --- a/tests/unit/potential/builtin/test_null.py +++ b/tests/unit/potential/builtin/test_null.py @@ -9,7 +9,7 @@ from ..test_core import TestAbstractPotential as AbstractPotential_Test -class TestBarPotential(AbstractPotential_Test): +class TestNullPotential(AbstractPotential_Test): @pytest.fixture(scope="class") def pot_cls(self) -> type[gp.NullPotential]: return gp.NullPotential diff --git a/tests/unit/potential/io/__init__.py b/tests/unit/potential/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/potential/io/gala_helper.py b/tests/unit/potential/io/gala_helper.py new file mode 100644 index 00000000..37c6fb52 --- /dev/null +++ b/tests/unit/potential/io/gala_helper.py @@ -0,0 +1,183 @@ +"""Interoperability.""" + +__all__ = ["galax_to_gala"] + +from functools import singledispatch + +from gala.potential import ( + CompositePotential as GalaCompositePotential, + HernquistPotential as GalaHernquistPotential, + IsochronePotential as GalaIsochronePotential, + KeplerPotential as GalaKeplerPotential, + MilkyWayPotential as GalaMilkyWayPotential, + MiyamotoNagaiPotential as GalaMiyamotoNagaiPotential, + NFWPotential as GalaNFWPotential, + NullPotential as GalaNullPotential, + PotentialBase as GalaPotentialBase, +) +from gala.units import UnitSystem as GalaUnitSystem, dimensionless as gala_dimensionless + +from galax.potential._potential.base import AbstractPotentialBase +from galax.potential._potential.builtin import ( + BarPotential, + HernquistPotential, + IsochronePotential, + KeplerPotential, + MiyamotoNagaiPotential, + NFWPotential, + NullPotential, +) +from galax.potential._potential.composite import CompositePotential +from galax.potential._potential.param import ConstantParameter +from galax.potential._potential.special import MilkyWayPotential +from galax.units import DimensionlessUnitSystem, UnitSystem + +############################################################################## +# UnitSystem + + +def galax_to_gala_units(units: UnitSystem, /) -> GalaUnitSystem: + if isinstance(units, DimensionlessUnitSystem): + return gala_dimensionless + return GalaUnitSystem(units) + + +############################################################################## +# GALAX -> GALA + + +# TODO: this can be removed when AbstractPotential gets a `parameters` +# attribute that is a dict whose keys are the names of the parameters. +def _all_constant_parameters( + pot: "AbstractPotentialBase", + *params: str, +) -> bool: + return all(isinstance(getattr(pot, name), ConstantParameter) for name in params) + + +# TODO: add an argument to specify how to handle time-dependent parameters. +# Gala potentials are not time-dependent, so we need to specify how to +# handle time-dependent Galax parameters. +@singledispatch +def galax_to_gala(pot: AbstractPotentialBase, /) -> GalaPotentialBase: + """Convert a Galax potential to a Gala potential. + + Parameters + ---------- + pot : :class:`~galax.potential.AbstractPotentialBase` + Galax potential. + + Returns + ------- + gala_pot : :class:`~gala.potential.PotentialBase` + Gala potential. + """ + msg = ( + "`galax_to_gala` does not have a registered function to convert " + f"{pot.__class__.__name__!r} to a `gala.PotentialBase` instance." + ) + raise NotImplementedError(msg) + + +@galax_to_gala.register +def _galax_to_gala_composite(pot: CompositePotential, /) -> GalaCompositePotential: + """Convert a Galax CompositePotential to a Gala potential.""" + return GalaCompositePotential(**{k: galax_to_gala(p) for k, p in pot.items()}) + + +@galax_to_gala.register +def _galax_to_gala_bar(pot: BarPotential, /) -> GalaPotentialBase: + """Convert a Galax BarPotential to a Gala potential.""" + raise NotImplementedError # TODO: implement + + +@galax_to_gala.register +def _galax_to_gala_hernquist(pot: HernquistPotential, /) -> GalaHernquistPotential: + """Convert a Galax HernquistPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "c"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return GalaHernquistPotential( + m=pot.m(0) * pot.units["mass"], + c=pot.c(0) * pot.units["length"], + units=galax_to_gala_units(pot.units), + ) + + +@galax_to_gala.register +def _galax_to_gala_isochrone(pot: IsochronePotential, /) -> GalaIsochronePotential: + """Convert a Galax IsochronePotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "b"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return GalaIsochronePotential( + m=pot.m(0) * pot.units["mass"], + b=pot.b(0) * pot.units["length"], # TODO: fix the mismatch + units=galax_to_gala_units(pot.units), + ) + + +@galax_to_gala.register +def _galax_to_gala_kepler(pot: KeplerPotential, /) -> GalaKeplerPotential: + """Convert a Galax KeplerPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return GalaKeplerPotential( + m=pot.m(0) * pot.units["mass"], units=galax_to_gala_units(pot.units) + ) + + +@galax_to_gala.register +def _galax_to_gala_miyamotonagi( + pot: MiyamotoNagaiPotential, / +) -> GalaMiyamotoNagaiPotential: + """Convert a Galax MiyamotoNagaiPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "a", "b"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + return GalaMiyamotoNagaiPotential( + m=pot.m(0) * pot.units["mass"], + a=pot.a(0) * pot.units["length"], + b=pot.b(0) * pot.units["length"], + units=galax_to_gala_units(pot.units), + ) + + +@galax_to_gala.register +def _galax_to_gala_nfw(pot: NFWPotential, /) -> GalaNFWPotential: + """Convert a Galax NFWPotential to a Gala potential.""" + if not _all_constant_parameters(pot, "m", "r_s"): + msg = "Gala does not support time-dependent parameters." + raise TypeError(msg) + + if pot.softening_length != 0: + msg = "Gala does not support softening." + raise TypeError(msg) + + return GalaNFWPotential( + m=pot.m(0) * pot.units["mass"], + r_s=pot.r_s(0) * pot.units["length"], + units=galax_to_gala_units(pot.units), + ) + + +@galax_to_gala.register +def _galax_to_gala_nullpotential(pot: NullPotential, /) -> GalaNullPotential: + """Convert a Galax NullPotential to a Gala potential.""" + return GalaNullPotential(units=galax_to_gala_units(pot.units)) + + +@galax_to_gala.register +def _gala_to_galax_mwpotential(pot: MilkyWayPotential, /) -> GalaMilkyWayPotential: + """Convert a Gala MilkyWayPotential to a Galax potential.""" + return GalaMilkyWayPotential( + disk={k: getattr(pot["disk"], k)(0) for k in ("m", "a", "b")}, + halo={k: getattr(pot["halo"], k)(0) for k in ("m", "r_s")}, + bulge={k: getattr(pot["bulge"], k)(0) for k in ("m", "c")}, + nucleus={k: getattr(pot["nucleus"], k)(0) for k in ("m", "c")}, + ) diff --git a/tests/unit/potential/io/test_gala.py b/tests/unit/potential/io/test_gala.py new file mode 100644 index 00000000..c51cb9c7 --- /dev/null +++ b/tests/unit/potential/io/test_gala.py @@ -0,0 +1,45 @@ +"""Testing the gala potential I/O module.""" + +from inspect import get_annotations +from typing import ClassVar + +import jax.numpy as xp +import pytest + +import galax.potential as gp +from galax.typing import Vec3 +from galax.utils._optional_deps import HAS_GALA + + +class GalaIOMixin: + """Mixin for testing gala potential I/O. + + This is mixed into the ``TestAbstractPotentialBase`` class. + """ + + # All the Gala-mapped potentials + _GALA_CAN_MAP_TO: ClassVar = ( + [ + get_annotations(pot)["return"] + for pot in gp.io.gala_to_galax.registry.values() + ] + if HAS_GALA + else [] + ) + + @pytest.mark.skipif(not HAS_GALA, reason="requires gala") + def test_galax_to_gala_to_galax_roundtrip( + self, pot: gp.AbstractPotentialBase, x: Vec3 + ) -> None: + """Test roundtripping ``gala_to_galax(galax_to_gala())``.""" + from .gala_helper import galax_to_gala + + # First we need to check that the potential is gala-compatible + if type(pot) not in self._GALA_CAN_MAP_TO: + pytest.skip(f"potential {pot} cannot be mapped to from gala") + + # TODO: a more robust test + rpot = gp.io.gala_to_galax(galax_to_gala(pot)) + + # quick test that the potential energies are the same + assert xp.array_equal(pot(x, t=0), rpot(x, t=0)) diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index 2d34f576..fa11d885 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -9,12 +9,19 @@ import galax.dynamics as gd import galax.potential as gp -from galax.typing import BatchableFloatOrIntScalarLike, BatchFloatScalar, BatchVec3 +from galax.typing import ( + BatchableFloatOrIntScalarLike, + BatchFloatScalar, + BatchVec3, + Vec3, +) from galax.units import UnitSystem, dimensionless from galax.utils import partial_jit, vectorize_method +from .io.test_gala import GalaIOMixin -class TestAbstractPotentialBase: + +class TestAbstractPotentialBase(GalaIOMixin): """Test the `galax.potential.AbstractPotentialBase` class.""" @pytest.fixture(scope="class") @@ -53,17 +60,17 @@ def pot( # --------------------------------- @pytest.fixture(scope="class") - def x(self) -> Float[Array, "3"]: + def x(self) -> Vec3: """Create a position vector for testing.""" return xp.asarray([1, 2, 3], dtype=float) @pytest.fixture(scope="class") - def v(self) -> Float[Array, "3"]: + def v(self) -> Vec3: """Create a velocity vector for testing.""" return xp.asarray([4, 5, 6], dtype=float) @pytest.fixture(scope="class") - def xv(self, x: Float[Array, "3"], v: Float[Array, "3"]) -> Float[Array, "6"]: + def xv(self, x: Vec3, v: Vec3) -> Float[Array, "6"]: """Create a phase-space vector for testing.""" return xp.concat([x, v]) @@ -92,10 +99,14 @@ def _potential_energy(self, q, t): # ========================================================================= + # --------------------------------- + def test_potential_energy(self, pot, x): """Test the `AbstractPotentialBase.potential_energy` method.""" assert pot.potential_energy(x, t=0) == 6 + # --------------------------------- + def test_call(self, pot, x): """Test the `AbstractPotentialBase.__call__` method.""" assert xp.equal(pot(x, t=0), pot.potential_energy(x, t=0))