Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gala -> galax conversion #79

Merged
merged 3 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ repos:
- id: mixed-line-ending
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: '^tests\/.*_helper\.py$'
- id: requirements-txt-fixer
- id: trailing-whitespace

Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""galax: Galactic Dynamix in Jax."""


from . import base, builtin, composite, core, param, special, utils
from . import base, builtin, composite, core, io, param, special, utils
from .base import *
from .builtin import *
from .composite import *
Expand All @@ -10,7 +10,7 @@
from .special import *
from .utils import *

__all__: list[str] = []
__all__: list[str] = ["io"]
__all__ += base.__all__
__all__ += core.__all__
__all__ += composite.__all__
Expand Down
17 changes: 17 additions & 0 deletions src/galax/potential/_potential/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Input/output/conversion of potential objects.
This module contains the machinery for I/O and conversion of potential objects.
Conversion is useful for e.g. converting a
:class:`galax.potential.AbstractPotential` object to a
:class:`gala.potential.PotentialBase` object.
"""

__all__: list[str] = ["gala_to_galax"]


from galax.utils._optional_deps import HAS_GALA

if HAS_GALA:
from .gala import gala_to_galax
else:
from .gala_noop import gala_to_galax # type: ignore[assignment]
156 changes: 156 additions & 0 deletions src/galax/potential/_potential/io/gala.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Interoperability."""

__all__ = ["gala_to_galax"]

from functools import singledispatch

import numpy as np
from gala.potential import (
CompositePotential as GalaCompositePotential,
HernquistPotential as GalaHernquistPotential,
IsochronePotential as GalaIsochronePotential,
KeplerPotential as GalaKeplerPotential,
MilkyWayPotential as GalaMilkyWayPotential,
MiyamotoNagaiPotential as GalaMiyamotoNagaiPotential,
NFWPotential as GalaNFWPotential,
NullPotential as GalaNullPotential,
PotentialBase as GalaPotentialBase,
)

from galax.potential._potential.base import AbstractPotentialBase
from galax.potential._potential.builtin import (
HernquistPotential,
IsochronePotential,
KeplerPotential,
MiyamotoNagaiPotential,
NFWPotential,
NullPotential,
)
from galax.potential._potential.composite import CompositePotential
from galax.potential._potential.special import MilkyWayPotential

##############################################################################
# GALA -> GALAX


def _static_at_origin(pot: GalaPotentialBase, /) -> bool:
return pot.R is None and np.array_equal(pot.origin, (0, 0, 0))


@singledispatch
def gala_to_galax(pot: GalaPotentialBase, /) -> AbstractPotentialBase:
"""Convert a :mod:`gala` potential to a :mod:`galax` potential.

Parameters
----------
pot : :class:`~gala.potential.PotentialBase`
:mod:`gala` potential.

Returns
-------
gala_pot : :class:`~galax.potential.AbstractPotentialBase`
:mod:`galax` potential.
"""
msg = (
"`gala_to_galax` does not have a registered function to convert "
f"{pot.__class__.__name__!r} to a `galax.AbstractPotentialBase` instance."
)
raise NotImplementedError(msg)


# -----------------------------------------------------------------------------
# General rules


@gala_to_galax.register
def _gala_to_galax_composite(pot: GalaCompositePotential, /) -> CompositePotential:
"""Convert a Gala CompositePotential to a Galax potential."""
return CompositePotential(**{k: gala_to_galax(p) for k, p in pot.items()})


# -----------------------------------------------------------------------------
# Builtin potentials


@gala_to_galax.register
def _gala_to_galax_hernquist(pot: GalaHernquistPotential, /) -> HernquistPotential:
"""Convert a Gala HernquistPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return HernquistPotential(m=params["m"], c=params["c"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_isochrone(pot: GalaIsochronePotential, /) -> IsochronePotential:
"""Convert a Gala IsochronePotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return IsochronePotential(m=params["m"], b=params["b"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_kepler(pot: GalaKeplerPotential, /) -> KeplerPotential:
"""Convert a Gala KeplerPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return KeplerPotential(m=params["m"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_miyamotonagi(
pot: GalaMiyamotoNagaiPotential, /
) -> MiyamotoNagaiPotential:
"""Convert a Gala MiyamotoNagaiPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return MiyamotoNagaiPotential(
m=params["m"], a=params["a"], b=params["b"], units=pot.units
)


@gala_to_galax.register
def _gala_to_galax_nfw(pot: GalaNFWPotential, /) -> NFWPotential:
"""Convert a Gala NFWPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return NFWPotential(
m=params["m"], r_s=params["r_s"], softening_length=0, units=pot.units
)


@gala_to_galax.register
def _gala_to_galax_nullpotential(pot: GalaNullPotential, /) -> NullPotential:
"""Convert a Gala NullPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
return NullPotential(units=pot.units)


# -----------------------------------------------------------------------------
# MW potentials


@gala_to_galax.register
def _gala_to_galax_mwpotential(pot: GalaMilkyWayPotential, /) -> MilkyWayPotential:
"""Convert a Gala MilkyWayPotential to a Galax potential."""
if not all(_static_at_origin(p) for p in pot.values()):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)

return MilkyWayPotential(
disk={k: pot["disk"].parameters[k] for k in ("m", "a", "b")},
halo={k: pot["halo"].parameters[k] for k in ("m", "r_s")},
bulge={k: pot["bulge"].parameters[k] for k in ("m", "c")},
nucleus={k: pot["nucleus"].parameters[k] for k in ("m", "c")},
)
28 changes: 28 additions & 0 deletions src/galax/potential/_potential/io/gala_noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Interoperability."""

__all__ = ["gala_to_galax"]


from typing import TYPE_CHECKING

from galax.potential._potential.base import AbstractPotentialBase

if TYPE_CHECKING:
from gala.potential import PotentialBase as GalaPotentialBase


def gala_to_galax(pot: "GalaPotentialBase", /) -> AbstractPotentialBase:
"""Convert a :mod:`gala` potential to a :mod:`galax` potential.
Parameters
----------
pot : :class:`~gala.potential.PotentialBase`
:mod:`gala` potential.
Returns
-------
gala_pot : :class:`~galax.potential.AbstractPotentialBase`
:mod:`galax` potential.
"""
msg = "The `gala` package must be installed to use this function. "
raise ImportError(msg)
5 changes: 3 additions & 2 deletions src/galax/potential/_potential/param/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
BatchableFloatOrIntScalarLike,
FloatArrayAnyShape,
FloatOrIntScalar,
FloatOrIntScalarLike,
FloatScalar,
Unit,
)
Expand All @@ -37,12 +38,12 @@ class AbstractParameter(eqx.Module, strict=True): # type: ignore[call-arg, misc
unit: Unit = eqx.field(static=True, converter=u.Unit)

@abc.abstractmethod
def __call__(self, t: FloatScalar, **kwargs: Any) -> FloatArrayAnyShape:
def __call__(self, t: FloatOrIntScalarLike, **kwargs: Any) -> FloatArrayAnyShape:
"""Compute the parameter value at the given time(s).

Parameters
----------
t : float | Array[float, ()]
t : Array[float | int, ()] | float | int
The time(s) at which to compute the parameter value.
**kwargs : Any
Additional parameters to pass to the parameter function.
Expand Down
17 changes: 9 additions & 8 deletions tests/smoke/potential/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ def test_all():
assert gp.__all__ == _potential.__all__

# Test detailed contents (not order)
assert set(gp.__all__) == set(
_potential.base.__all__
+ _potential.builtin.__all__
+ _potential.composite.__all__
+ _potential.core.__all__
+ _potential.param.__all__
+ _potential.special.__all__
)
assert set(gp.__all__) == {
"io",
*_potential.base.__all__,
*_potential.builtin.__all__,
*_potential.composite.__all__,
*_potential.core.__all__,
*_potential.param.__all__,
*_potential.special.__all__,
}
34 changes: 33 additions & 1 deletion tests/unit/potential/builtin/test_nfw.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from dataclasses import replace
from typing import Any

import astropy.units as u
import jax.experimental.array_api as xp
import jax.numpy as jnp
import pytest
from typing_extensions import override

import galax.potential as gp
from galax.potential import ConstantParameter
from galax.typing import Vec3
from galax.units import galactic
from galax.utils._optional_deps import HAS_GALA

from ..params.test_field import ParameterFieldMixin
from ..test_core import TestAbstractPotential as AbstractPotential_Test
Expand Down Expand Up @@ -47,21 +51,26 @@ def test_r_s_userfunc(self, pot_cls, fields):
assert pot.r_s(t=0) == 2


###############################################################################


class TestNFWPotential(
AbstractPotential_Test,
# Parameters
MassParameterMixin,
ScaleRadiusParameterMixin,
):
@pytest.fixture(scope="class")
@override
def pot_cls(self) -> type[gp.NFWPotential]:
return gp.NFWPotential

@pytest.fixture(scope="class")
def field_softening_length(self) -> dict[str, Any]:
def field_softening_length(self) -> float:
return 0.001

@pytest.fixture(scope="class")
@override
def fields_(
self, field_m, field_r_s, field_softening_length, field_units
) -> dict[str, Any]:
Expand Down Expand Up @@ -96,3 +105,26 @@ def test_hessian(self, pot, x):
]
),
)

# ==========================================================================
# I/O

@pytest.mark.skipif(not HAS_GALA, reason="requires gala")
def test_galax_to_gala_to_galax_roundtrip(
self, pot: gp.AbstractPotentialBase, x: Vec3
) -> None:
"""Test roundtripping ``gala_to_galax(galax_to_gala())``."""
from ..io.gala_helper import galax_to_gala

# Base is with non-zero softening
assert pot.softening_length != 0
with pytest.raises(TypeError, match="Gala does not support softening"):
_ = galax_to_gala(pot)

# Make a copy without softening
pot = replace(pot, softening_length=0)

rpot = gp.io.gala_to_galax(galax_to_gala(pot))

# quick test that the potential energies are the same
assert jnp.array_equal(pot(x, t=0), rpot(x, t=0))
2 changes: 1 addition & 1 deletion tests/unit/potential/builtin/test_null.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..test_core import TestAbstractPotential as AbstractPotential_Test


class TestBarPotential(AbstractPotential_Test):
class TestNullPotential(AbstractPotential_Test):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[gp.NullPotential]:
return gp.NullPotential
Expand Down
Empty file.
Loading
Loading