Skip to content

Commit

Permalink
sort the potentials (#74)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 21, 2024
1 parent b703c69 commit 03c6add
Showing 1 changed file with 57 additions and 57 deletions.
114 changes: 57 additions & 57 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = [
"KeplerPotential",
"MiyamotoNagaiPotential",
"NullPotential",
"BarPotential",
"IsochronePotential",
"KeplerPotential",
"MiyamotoNagaiPotential",
"NFWPotential",
"NullPotential",
]

from dataclasses import KW_ONLY
Expand Down Expand Up @@ -35,60 +35,6 @@
# -------------------------------------------------------------------


class KeplerPotential(AbstractPotential):
r"""The Kepler potential for a point mass.
.. math::
\Phi = -\frac{G M(t)}{r}
"""

m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment]

@partial_jit()
def _potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
r = xp.linalg.norm(q, axis=-1)
return -self._G * self.m(t) / r


# -------------------------------------------------------------------


class MiyamotoNagaiPotential(AbstractPotential):
m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment]
a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment]
b: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment]

@partial_jit()
def _potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
x, y, z = q[..., 0], q[..., 1], q[..., 2]
R2 = x**2 + y**2
return (
-self._G
* self.m(t)
/ xp.sqrt(R2 + xp.square(xp.sqrt(z**2 + self.b(t) ** 2) + self.a(t)))
)


# -------------------------------------------------------------------


class NullPotential(AbstractPotential):
"""Null potential, i.e. no potential."""

@partial_jit()
def _potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
return xp.zeros(q.shape[:-1], dtype=q.dtype)


# -------------------------------------------------------------------


class BarPotential(AbstractPotential):
"""Rotating bar potentil, with hard-coded rotation.
Expand Down Expand Up @@ -155,6 +101,47 @@ def _potential_energy(
# -------------------------------------------------------------------


class KeplerPotential(AbstractPotential):
r"""The Kepler potential for a point mass.
.. math::
\Phi = -\frac{G M(t)}{r}
"""

m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment]

@partial_jit()
def _potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
r = xp.linalg.norm(q, axis=-1)
return -self._G * self.m(t) / r


# -------------------------------------------------------------------


class MiyamotoNagaiPotential(AbstractPotential):
m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment]
a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment]
b: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment]

@partial_jit()
def _potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
x, y, z = q[..., 0], q[..., 1], q[..., 2]
R2 = x**2 + y**2
return (
-self._G
* self.m(t)
/ xp.sqrt(R2 + xp.square(xp.sqrt(z**2 + self.b(t) ** 2) + self.a(t)))
)


# -------------------------------------------------------------------


class NFWPotential(AbstractPotential):
"""NFW Potential."""

Expand All @@ -171,3 +158,16 @@ def _potential_energy(
r2 = q[..., 0] ** 2 + q[..., 1] ** 2 + q[..., 2] ** 2
m = xp.sqrt(r2 + self.softening_length) / self.r_s(t)
return v_h2 * xp.log(1.0 + m) / m


# -------------------------------------------------------------------


class NullPotential(AbstractPotential):
"""Null potential, i.e. no potential."""

@partial_jit()
def _potential_energy(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
return xp.zeros(q.shape[:-1], dtype=q.dtype)

0 comments on commit 03c6add

Please sign in to comment.