Skip to content

Commit

Permalink
galax <-> gala potential conversion
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 21, 2024
1 parent b025e83 commit 3be4fd6
Show file tree
Hide file tree
Showing 10 changed files with 429 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""galax: Galactic Dynamix in Jax."""


from . import base, builtin, composite, core, param
from . import base, builtin, composite, core, io, param
from .base import *
from .builtin import *
from .composite import *
from .core import *
from .param import *

__all__: list[str] = []
__all__: list[str] = ["io"]
__all__ += base.__all__
__all__ += core.__all__
__all__ += composite.__all__
Expand Down
17 changes: 17 additions & 0 deletions src/galax/potential/_potential/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Input/output/conversion of potential objects.
This module contains the machinery for I/O and conversion of potential objects.
Conversion is useful for e.g. converting a
:class:`galax.potential.AbstractPotential` object to a
:class:`gala.potential.PotentialBase` object.
"""

__all__: list[str] = ["galax_to_gala", "gala_to_galax"]


from galax.utils._optional_deps import HAS_GALA

if HAS_GALA:
from .gala import gala_to_galax, galax_to_gala
else:
from .gala_noop import gala_to_galax, galax_to_gala # type: ignore[assignment]
257 changes: 257 additions & 0 deletions src/galax/potential/_potential/io/gala.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
"""Interoperability."""

__all__ = ["galax_to_gala", "gala_to_galax"]

from functools import singledispatch

import numpy as np
from gala.potential import (
CompositePotential as GalaCompositePotential,
HernquistPotential as GalaHernquistPotential,
IsochronePotential as GalaIsochronePotential,
KeplerPotential as GalaKeplerPotential,
MiyamotoNagaiPotential as GalaMiyamotoNagaiPotential,
NFWPotential as GalaNFWPotential,
NullPotential as GalaNullPotential,
PotentialBase as GalaPotentialBase,
)

from galax.potential._potential.base import AbstractPotentialBase
from galax.potential._potential.builtin import (
BarPotential,
HernquistPotential,
IsochronePotential,
KeplerPotential,
MiyamotoNagaiPotential,
NFWPotential,
NullPotential,
)
from galax.potential._potential.composite import CompositePotential
from galax.potential._potential.param import ConstantParameter

##############################################################################
# GALAX -> GALA


# TODO: this can be removed when AbstractPotential gets a `parameters`
# attribute that is a dict whose keys are the names of the parameters.
def _all_constant_parameters(
pot: "AbstractPotentialBase",
*params: str,
) -> bool:
return all(isinstance(getattr(pot, name), ConstantParameter) for name in params)


# TODO: add an argument to specify how to handle time-dependent parameters.
# Gala potentials are not time-dependent, so we need to specify how to
# handle time-dependent Galax parameters.
@singledispatch
def galax_to_gala(pot: AbstractPotentialBase, /) -> GalaPotentialBase:
"""Convert a Galax potential to a Gala potential.
Parameters
----------
pot : :class:`~galax.potential.AbstractPotentialBase`
Galax potential.
Returns
-------
gala_pot : :class:`~gala.potential.PotentialBase`
Gala potential.
"""
msg = (
"`galax_to_gala` does not have a registered function to convert "
f"{pot.__class__.__name__!r} to a `gala.PotentialBase` instance."
)
raise NotImplementedError(msg)


@galax_to_gala.register
def _galax_to_gala_composite(pot: CompositePotential, /) -> GalaCompositePotential:
"""Convert a Galax CompositePotential to a Gala potential."""
return GalaCompositePotential(**{k: galax_to_gala(p) for k, p in pot.items()})


@galax_to_gala.register
def _galax_to_gala_bar(pot: BarPotential, /) -> GalaPotentialBase:
"""Convert a Galax BarPotential to a Gala potential."""
raise NotImplementedError # TODO: implement


@galax_to_gala.register
def _galax_to_gala_hernquist(pot: HernquistPotential, /) -> GalaHernquistPotential:
"""Convert a Galax HernquistPotential to a Gala potential."""
if not _all_constant_parameters(pot, "m", "c"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

return GalaHernquistPotential(
m=pot.m(0) * pot.units["mass"],
c=pot.c(0) * pot.units["length"],
units=pot.units,
)


@galax_to_gala.register
def _galax_to_gala_isochrone(pot: IsochronePotential, /) -> GalaIsochronePotential:
"""Convert a Galax IsochronePotential to a Gala potential."""
if not _all_constant_parameters(pot, "m", "a"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

return GalaIsochronePotential(
m=pot.m(0) * pot.units["mass"],
b=pot.a(0) * pot.units["length"], # TODO: fix the mismatch
units=pot.units,
)


@galax_to_gala.register
def _galax_to_gala_kepler(pot: KeplerPotential, /) -> GalaKeplerPotential:
"""Convert a Galax KeplerPotential to a Gala potential."""
if not _all_constant_parameters(pot, "m"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

return GalaKeplerPotential(m=pot.m(0) * pot.units["mass"], units=pot.units)


@galax_to_gala.register
def _galax_to_gala_miyamotonagi(
pot: MiyamotoNagaiPotential, /
) -> GalaMiyamotoNagaiPotential:
"""Convert a Galax MiyamotoNagaiPotential to a Gala potential."""
if not _all_constant_parameters(pot, "m", "a", "b"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

return GalaMiyamotoNagaiPotential(
m=pot.m(0) * pot.units["mass"],
a=pot.a(0) * pot.units["length"],
b=pot.b(0) * pot.units["length"],
units=pot.units,
)


@galax_to_gala.register
def _galax_to_gala_nfw(pot: NFWPotential, /) -> GalaNFWPotential:
"""Convert a Galax NFWPotential to a Gala potential."""
if not _all_constant_parameters(pot, "m", "r_s"):
msg = "Gala does not support time-dependent parameters."
raise TypeError(msg)

if pot.softening_length != 0:
msg = "Gala does not support softening."
raise TypeError(msg)

return GalaNFWPotential(
m=pot.m(0) * pot.units["mass"],
r_s=pot.r_s(0) * pot.units["length"],
units=pot.units,
)


@galax_to_gala.register
def _galax_to_gala_nullpotential(pot: NullPotential, /) -> GalaNullPotential:
"""Convert a Galax NullPotential to a Gala potential."""
return GalaNullPotential(units=pot.units)


##############################################################################
# GALA -> GALAX


def _static_at_origin(pot: GalaPotentialBase, /) -> bool:
return pot.R is None and np.array_equal(pot.origin, (0, 0, 0))


@singledispatch
def gala_to_galax(pot: GalaPotentialBase, /) -> AbstractPotentialBase:
"""Convert a :mod:`gala` potential to a :mod:`galax` potential.
Parameters
----------
pot : :class:`~gala.potential.PotentialBase`
:mod:`gala` potential.
Returns
-------
gala_pot : :class:`~galax.potential.AbstractPotentialBase`
:mod:`galax` potential.
"""
msg = (
"`gala_to_galax` does not have a registered function to convert "
f"{pot.__class__.__name__!r} to a `galax.AbstractPotentialBase` instance."
)
raise NotImplementedError(msg)


@gala_to_galax.register
def _gala_to_galax_composite(pot: GalaCompositePotential, /) -> CompositePotential:
"""Convert a Gala CompositePotential to a Galax potential."""
return CompositePotential(**{k: gala_to_galax(p) for k, p in pot.items()})


@gala_to_galax.register
def _gala_to_galax_hernquist(pot: GalaHernquistPotential, /) -> HernquistPotential:
"""Convert a Gala HernquistPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return HernquistPotential(m=params["m"], c=params["c"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_isochrone(pot: GalaIsochronePotential, /) -> IsochronePotential:
"""Convert a Gala IsochronePotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return IsochronePotential(m=params["m"], a=params["a"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_kepler(pot: GalaKeplerPotential, /) -> KeplerPotential:
"""Convert a Gala KeplerPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return KeplerPotential(m=params["m"], units=pot.units)


@gala_to_galax.register
def _gala_to_galax_miyamotonagi(
pot: GalaMiyamotoNagaiPotential, /
) -> MiyamotoNagaiPotential:
"""Convert a Gala MiyamotoNagaiPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return MiyamotoNagaiPotential(
m=params["m"], a=params["a"], b=params["b"], units=pot.units
)


@gala_to_galax.register
def _gala_to_galax_mnfw(pot: GalaNFWPotential, /) -> NFWPotential:
"""Convert a Gala NFWPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
params = pot.parameters
return NFWPotential(
m=params["m"], r_s=params["r_s"], softening_length=0, units=pot.units
)


@gala_to_galax.register
def _gala_to_galax_nullpotential(pot: GalaNullPotential, /) -> NullPotential:
"""Convert a Gala NullPotential to a Galax potential."""
if not _static_at_origin(pot):
msg = "Galax does not support rotating or offset potentials."
raise TypeError(msg)
return NullPotential(units=pot.units)
54 changes: 54 additions & 0 deletions src/galax/potential/_potential/io/gala_noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Interoperability."""

__all__ = ["galax_to_gala", "gala_to_galax"]


from typing import TYPE_CHECKING

from galax.potential._potential.base import AbstractPotentialBase

if TYPE_CHECKING:
from gala.potential import PotentialBase as GalaPotentialBase

##############################################################################
# GALAX -> GALA


def galax_to_gala(pot: AbstractPotentialBase, /) -> "GalaPotentialBase":
"""Convert a Galax potential to a Gala potential.
Parameters
----------
pot : :class:`~galax.potential.AbstractPotentialBase`
Galax potential.
Returns
-------
gala_pot : :class:`~gala.potential.PotentialBase`
Gala potential.
"""
msg = (
"`galax_to_gala` does not have a registered function to convert "
f"{pot.__class__.__name__!r} to a `gala.PotentialBase` instance."
)
raise NotImplementedError(msg)


def gala_to_galax(pot: "GalaPotentialBase", /) -> AbstractPotentialBase:
"""Convert a :mod:`gala` potential to a :mod:`galax` potential.
Parameters
----------
pot : :class:`~gala.potential.PotentialBase`
:mod:`gala` potential.
Returns
-------
gala_pot : :class:`~galax.potential.AbstractPotentialBase`
:mod:`galax` potential.
"""
msg = (
"`gala_to_galax` does not have a registered function to convert "
f"{pot.__class__.__name__!r} to a `gala.PotentialBase` instance."
)
raise NotImplementedError(msg)
7 changes: 7 additions & 0 deletions src/galax/utils/_optional_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Optional dependencies."""

__all__ = ["HAS_GALA"]

from importlib.util import find_spec

HAS_GALA = find_spec("gala") is not None
15 changes: 8 additions & 7 deletions tests/smoke/potential/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ def test_all():
assert gp.__all__ == _potential.__all__

# Test detailed contents (not order)
assert set(gp.__all__) == set(
_potential.base.__all__
+ _potential.builtin.__all__
+ _potential.composite.__all__
+ _potential.core.__all__
+ _potential.param.__all__
)
assert set(gp.__all__) == {
"io",
*_potential.base.__all__,
*_potential.builtin.__all__,
*_potential.composite.__all__,
*_potential.core.__all__,
*_potential.param.__all__,
}
Empty file.
Loading

0 comments on commit 3be4fd6

Please sign in to comment.