Skip to content

Commit

Permalink
galax <-> gala potential conversion
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>

optional dependency

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 25, 2024
1 parent b563c9b commit 1339f2f
Show file tree
Hide file tree
Showing 12 changed files with 449 additions and 17 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ repos:
- id: mixed-line-ending
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: '^tests\/.*_helper\.py$'
- id: requirements-txt-fixer
- id: trailing-whitespace

Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""galax: Galactic Dynamix in Jax."""


from . import base, builtin, composite, core, param, special, utils
from . import base, builtin, composite, core, io, param, special, utils
from .base import *
from .builtin import *
from .composite import *
Expand All @@ -10,7 +10,7 @@
from .special import *
from .utils import *

__all__: list[str] = []
__all__: list[str] = ["io"]
__all__ += base.__all__
__all__ += core.__all__
__all__ += composite.__all__
Expand Down
17 changes: 17 additions & 0 deletions src/galax/potential/_potential/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Input/output/conversion of potential objects.
This module contains the machinery for I/O and conversion of potential objects.
Conversion is useful for e.g. converting a
:class:`galax.potential.AbstractPotential` object to a
:class:`gala.potential.PotentialBase` object.
"""

__all__: list[str] = ["gala_to_galax"]


from galax.utils._optional_deps import HAS_GALA

if HAS_GALA:
from .gala import gala_to_galax
else:
from .gala_noop import gala_to_galax # type: ignore[assignment]
127 changes: 127 additions & 0 deletions src/galax/potential/_potential/io/gala.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""Interoperability."""

__all__ = ["gala_to_galax"]

from functools import singledispatch

import numpy as np
from gala.potential import (
CompositePotential as GalaCompositePotential,
HernquistPotential as GalaHernquistPotential,
IsochronePotential as GalaIsochronePotential,
KeplerPotential as GalaKeplerPotential,
MiyamotoNagaiPotential as GalaMiyamotoNagaiPotential,
NFWPotential as GalaNFWPotential,
NullPotential as GalaNullPotential,
PotentialBase as GalaPotentialBase,
)

from galax.potential._potential.base import AbstractPotentialBase
from galax.potential._potential.builtin import (
HernquistPotential,
IsochronePotential,
KeplerPotential,
MiyamotoNagaiPotential,
NFWPotential,
NullPotential,
)
from galax.potential._potential.composite import CompositePotential

##############################################################################
# GALA -> GALAX


def _static_at_origin(pot: GalaPotentialBase, /) -> bool:
return pot.R is None and np.array_equal(pot.origin, (0, 0, 0))


@singledispatch
def gala_to_galax(pot: GalaPotentialBase, /) -> AbstractPotentialBase:
"""Convert a :mod:`gala` potential to a :mod:`galax` potential.
Parameters
----------
pot : :class:`~gala.potential.PotentialBase`
:mod:`gala` potential.
Returns
-------
gala_pot : :class:`~galax.potential.AbstractPotentialBase`
:mod:`galax` potential.
"""
msg = (
"`gala_to_galax` does not have a registered function to convert "
f"{pot.__class__.__name__!r} to a `galax.AbstractPotentialBase` instance."
)
raise NotImplementedError(msg)


@gala_to_galax.register
def _gala_to_galax_composite(pot: GalaCompositePotential, /) -> CompositePotential:
"""Convert a Gala CompositePotential to a Galax potential."""
return CompositePotential(**{k: gala_to_galax(p) for k, p in pot.items()})


@gala_to_galax.register
def _gala_to_galax_hernquist(pot: GalaHernquistPotential, /) -> HernquistPotential:
"""Convert a Gala HernquistPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return HernquistPotential(m=params["m"], c=params["c"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_isochrone(pot: GalaIsochronePotential, /) -> IsochronePotential:
"""Convert a Gala IsochronePotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return IsochronePotential(m=params["m"], b=params["b"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_kepler(pot: GalaKeplerPotential, /) -> KeplerPotential:
"""Convert a Gala KeplerPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return KeplerPotential(m=params["m"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_miyamotonagi(
pot: GalaMiyamotoNagaiPotential, /
) -> MiyamotoNagaiPotential:
"""Convert a Gala MiyamotoNagaiPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return MiyamotoNagaiPotential(
m=params["m"], a=params["a"], b=params["b"], units=pot.units
)


@gala_to_galax.register
def _gala_to_galax_nfw(pot: GalaNFWPotential, /) -> NFWPotential:
"""Convert a Gala NFWPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return NFWPotential(
m=params["m"], r_s=params["r_s"], softening_length=0, units=pot.units
)


@gala_to_galax.register
def _gala_to_galax_nullpotential(pot: GalaNullPotential, /) -> NullPotential:
"""Convert a Gala NullPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
return NullPotential(units=pot.units)
28 changes: 28 additions & 0 deletions src/galax/potential/_potential/io/gala_noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Interoperability."""

__all__ = ["gala_to_galax"]


from typing import TYPE_CHECKING

from galax.potential._potential.base import AbstractPotentialBase

if TYPE_CHECKING:
from gala.potential import PotentialBase as GalaPotentialBase


def gala_to_galax(pot: "GalaPotentialBase", /) -> AbstractPotentialBase:
"""Convert a :mod:`gala` potential to a :mod:`galax` potential.
Parameters
----------
pot : :class:`~gala.potential.PotentialBase`
:mod:`gala` potential.
Returns
-------
gala_pot : :class:`~galax.potential.AbstractPotentialBase`
:mod:`galax` potential.
"""
msg = "The `gala` package must be installed to use this function. "
raise ImportError(msg)
17 changes: 9 additions & 8 deletions tests/smoke/potential/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ def test_all():
assert gp.__all__ == _potential.__all__

# Test detailed contents (not order)
assert set(gp.__all__) == set(
_potential.base.__all__
+ _potential.builtin.__all__
+ _potential.composite.__all__
+ _potential.core.__all__
+ _potential.param.__all__
+ _potential.special.__all__
)
assert set(gp.__all__) == {
"io",
*_potential.base.__all__,
*_potential.builtin.__all__,
*_potential.composite.__all__,
*_potential.core.__all__,
*_potential.param.__all__,
*_potential.special.__all__,
}
34 changes: 33 additions & 1 deletion tests/unit/potential/builtin/test_nfw.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from dataclasses import replace
from typing import Any

import astropy.units as u
import jax.experimental.array_api as xp
import jax.numpy as jnp
import pytest
from typing_extensions import override

import galax.potential as gp
from galax.potential import ConstantParameter
from galax.typing import Vec3
from galax.units import galactic
from galax.utils._optional_deps import HAS_GALA

from ..params.test_field import ParameterFieldMixin
from ..test_core import TestAbstractPotential as AbstractPotential_Test
Expand Down Expand Up @@ -47,21 +51,26 @@ def test_r_s_userfunc(self, pot_cls, fields):
assert pot.r_s(t=0) == 2


###############################################################################


class TestNFWPotential(
AbstractPotential_Test,
# Parameters
MassParameterMixin,
ScaleRadiusParameterMixin,
):
@pytest.fixture(scope="class")
@override
def pot_cls(self) -> type[gp.NFWPotential]:
return gp.NFWPotential

@pytest.fixture(scope="class")
def field_softening_length(self) -> dict[str, Any]:
def field_softening_length(self) -> float:
return 0.001

@pytest.fixture(scope="class")
@override
def fields_(
self, field_m, field_r_s, field_softening_length, field_units
) -> dict[str, Any]:
Expand Down Expand Up @@ -96,3 +105,26 @@ def test_hessian(self, pot, x):
]
),
)

# ==========================================================================
# I/O

@pytest.mark.skipif(not HAS_GALA, reason="requires gala")
def test_galax_to_gala_to_galax_roundtrip(
self, pot: gp.AbstractPotentialBase, x: Vec3
) -> None:
"""Test roundtripping ``gala_to_galax(galax_to_gala())``."""
from ..io.gala_helper import galax_to_gala

# Base is with non-zero softening
assert pot.softening_length != 0
with pytest.raises(TypeError, match="Gala does not support softening"):
_ = galax_to_gala(pot)

# Make a copy without softening
pot = replace(pot, softening_length=0)

rpot = gp.io.gala_to_galax(galax_to_gala(pot))

# quick test that the potential energies are the same
assert jnp.array_equal(pot(x, t=0), rpot(x, t=0))
2 changes: 1 addition & 1 deletion tests/unit/potential/builtin/test_null.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..test_core import TestAbstractPotential as AbstractPotential_Test


class TestBarPotential(AbstractPotential_Test):
class TestNullPotential(AbstractPotential_Test):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[gp.NullPotential]:
return gp.NullPotential
Expand Down
Empty file.
Loading

0 comments on commit 1339f2f

Please sign in to comment.