diff --git a/src/galax/potential/_potential/builtin.py b/src/galax/potential/_potential/builtin.py index 72f5a13e..f68671de 100644 --- a/src/galax/potential/_potential/builtin.py +++ b/src/galax/potential/_potential/builtin.py @@ -2,6 +2,7 @@ __all__ = [ "BarPotential", + "HernquistPotential", "IsochronePotential", "KeplerPotential", "MiyamotoNagaiPotential", @@ -85,6 +86,23 @@ def _potential_energy(self, q: Vec3, /, t: FloatOrIntScalarLike) -> FloatScalar: # ------------------------------------------------------------------- +class HernquistPotential(AbstractPotential): + """Hernquist Potential.""" + + m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] + a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] + + @partial_jit() + def _potential_energy( + self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike + ) -> BatchFloatScalar: + r = xp.linalg.norm(q, axis=-1) + return -self._G * self.m(t) / (r + self.a(t)) + + +# ------------------------------------------------------------------- + + class IsochronePotential(AbstractPotential): m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] diff --git a/tests/unit/potential/builtin/test_hernquist.py b/tests/unit/potential/builtin/test_hernquist.py new file mode 100644 index 00000000..19aa7b71 --- /dev/null +++ b/tests/unit/potential/builtin/test_hernquist.py @@ -0,0 +1,52 @@ +from typing import Any + +import jax.numpy as xp +import pytest + +from galax.potential import HernquistPotential + +from ..test_core import TestAbstractPotential +from .test_common import MassParameterMixin, ShapeAParameterMixin + + +class TestHernquistPotential( + TestAbstractPotential, + # Parameters + MassParameterMixin, + ShapeAParameterMixin, +): + @pytest.fixture(scope="class") + def pot_cls(self) -> type[HernquistPotential]: + return HernquistPotential + + @pytest.fixture(scope="class") + def fields_(self, field_m, field_a, field_units) -> dict[str, Any]: + return {"m": field_m, "a": field_a, "units": field_units} + + # ========================================================================== + + def test_potential_energy(self, pot, x) -> None: + assert xp.isclose(pot.potential_energy(x, t=0), xp.array(-0.94871936)) + + def test_gradient(self, pot, x): + assert xp.allclose( + pot.gradient(x, t=0), xp.array([0.05347411, 0.10694822, 0.16042233]) + ) + + def test_density(self, pot, x): + assert xp.isclose(pot.density(x, t=0), 3.989933e08) + + def test_hessian(self, pot, x): + assert xp.allclose( + pot.hessian(x, t=0), + xp.array( + [ + [0.04362645, -0.01969533, -0.02954299], + [-0.01969533, 0.01408345, -0.05908599], + [-0.02954299, -0.05908599, -0.03515487], + ] + ), + ) + + def test_acceleration(self, pot, x): + assert xp.allclose(pot.acceleration(x, t=0), -pot.gradient(x, t=0))