Skip to content

Commit

Permalink
Add HernquistPotential (#75)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 21, 2024
1 parent 03c6add commit f440431
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = [
"BarPotential",
"HernquistPotential",
"IsochronePotential",
"KeplerPotential",
"MiyamotoNagaiPotential",
Expand Down Expand Up @@ -85,6 +86,23 @@ def _potential_energy(self, q: Vec3, /, t: FloatOrIntScalarLike) -> FloatScalar:
# -------------------------------------------------------------------


class HernquistPotential(AbstractPotential):
"""Hernquist Potential."""

m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment]
a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment]

@partial_jit()
def _potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
r = xp.linalg.norm(q, axis=-1)
return -self._G * self.m(t) / (r + self.a(t))


# -------------------------------------------------------------------


class IsochronePotential(AbstractPotential):
m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment]
a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment]
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/potential/builtin/test_hernquist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any

import jax.numpy as xp
import pytest

from galax.potential import HernquistPotential

from ..test_core import TestAbstractPotential
from .test_common import MassParameterMixin, ShapeAParameterMixin


class TestHernquistPotential(
TestAbstractPotential,
# Parameters
MassParameterMixin,
ShapeAParameterMixin,
):
@pytest.fixture(scope="class")
def pot_cls(self) -> type[HernquistPotential]:
return HernquistPotential

@pytest.fixture(scope="class")
def fields_(self, field_m, field_a, field_units) -> dict[str, Any]:
return {"m": field_m, "a": field_a, "units": field_units}

# ==========================================================================

def test_potential_energy(self, pot, x) -> None:
assert xp.isclose(pot.potential_energy(x, t=0), xp.array(-0.94871936))

def test_gradient(self, pot, x):
assert xp.allclose(
pot.gradient(x, t=0), xp.array([0.05347411, 0.10694822, 0.16042233])
)

def test_density(self, pot, x):
assert xp.isclose(pot.density(x, t=0), 3.989933e08)

def test_hessian(self, pot, x):
assert xp.allclose(
pot.hessian(x, t=0),
xp.array(
[
[0.04362645, -0.01969533, -0.02954299],
[-0.01969533, 0.01408345, -0.05908599],
[-0.02954299, -0.05908599, -0.03515487],
]
),
)

def test_acceleration(self, pot, x):
assert xp.allclose(pot.acceleration(x, t=0), -pot.gradient(x, t=0))

0 comments on commit f440431

Please sign in to comment.