diff --git a/pyproject.toml b/pyproject.toml index ed3d2327..35d7b08f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,12 @@ dependencies = [ ] [project.optional-dependencies] +compat-gala = [ + "gala", +] +compat-all = [ + "galax[compat-gala]", +] test = [ "hypothesis[numpy]", "pytest >=6", @@ -57,6 +63,7 @@ docs = [ "furo>=2023.08.17", ] all = [ + "galax[compat-all]", "galax[test]", "galax[docs]", "galax[dev]", diff --git a/src/galax/potential/_potential/utils.py b/src/galax/potential/_potential/utils.py index 75f95b38..6292cf0d 100644 --- a/src/galax/potential/_potential/utils.py +++ b/src/galax/potential/_potential/utils.py @@ -10,7 +10,13 @@ from astropy.units import Quantity from jax import Array -from galax.units import UnitSystem, dimensionless, galactic, solarsystem +from galax.units import ( + DimensionlessUnitSystem, + UnitSystem, + dimensionless, + galactic, + solarsystem, +) @singledispatch @@ -141,3 +147,29 @@ def _convert_from_representation( ) ) return _convert_from_baserep(value, units=units) + + +############################################################################## +# Gala compatibility +# TODO: move this to an interoperability module + +# isort: split +from galax.utils._optional_deps import HAS_GALA # noqa: E402 + +if HAS_GALA: + from gala.units import ( + DimensionlessUnitSystem as GalaDimensionlessUnitSystem, + UnitSystem as GalaUnitSystem, + ) + + @converter_to_usys.register + def _from_gala(value: GalaUnitSystem, /) -> UnitSystem: + usys = UnitSystem(*value._core_units) # noqa: SLF001 + usys._registry = value._registry # noqa: SLF001 + return usys + + @converter_to_usys.register + def _from_gala_dimensionless( + value: GalaDimensionlessUnitSystem, / + ) -> DimensionlessUnitSystem: + return dimensionless diff --git a/src/galax/utils/_optional_deps.py b/src/galax/utils/_optional_deps.py new file mode 100644 index 00000000..60c37c14 --- /dev/null +++ b/src/galax/utils/_optional_deps.py @@ -0,0 +1,7 @@ +"""Optional dependencies.""" + +__all__ = ["HAS_GALA"] + +from importlib.util import find_spec + +HAS_GALA = find_spec("gala") is not None diff --git a/tests/unit/potential/test_utils.py b/tests/unit/potential/test_utils.py index 2a809096..09390703 100644 --- a/tests/unit/potential/test_utils.py +++ b/tests/unit/potential/test_utils.py @@ -12,6 +12,7 @@ galactic, solarsystem, ) +from galax.utils._optional_deps import HAS_GALA class TestConverterToUtils: @@ -45,6 +46,23 @@ def test_from_name(self): with pytest.raises(NotImplementedError): converter_to_usys("invalid_value") + @pytest.mark.skipif(not HAS_GALA, reason="requires gala") + def test_from_gala(self): + """Test conversion from gala.""" + # ------------------------------- + # UnitSystem + from gala.units import UnitSystem as GalaUnitSystem + + value = GalaUnitSystem(u.km, u.s, u.Msun, u.radian) + assert converter_to_usys(value) == UnitSystem(*value._core_units) + + # ------------------------------- + # DimensionlessUnitSystem + from gala.units import DimensionlessUnitSystem as GalaDimensionlessUnitSystem + + value = GalaDimensionlessUnitSystem() + assert converter_to_usys(value) == dimensionless + # ============================================================================