diff --git a/src/galax/potential/_potential/builtin.py b/src/galax/potential/_potential/builtin.py index b6be3dea..72f5a13e 100644 --- a/src/galax/potential/_potential/builtin.py +++ b/src/galax/potential/_potential/builtin.py @@ -1,12 +1,12 @@ """galax: Galactic Dynamix in Jax.""" __all__ = [ - "KeplerPotential", - "MiyamotoNagaiPotential", - "NullPotential", "BarPotential", "IsochronePotential", + "KeplerPotential", + "MiyamotoNagaiPotential", "NFWPotential", + "NullPotential", ] from dataclasses import KW_ONLY @@ -35,60 +35,6 @@ # ------------------------------------------------------------------- -class KeplerPotential(AbstractPotential): - r"""The Kepler potential for a point mass. - - .. math:: - \Phi = -\frac{G M(t)}{r} - """ - - m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] - - @partial_jit() - def _potential_energy( - self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike - ) -> BatchFloatScalar: - r = xp.linalg.norm(q, axis=-1) - return -self._G * self.m(t) / r - - -# ------------------------------------------------------------------- - - -class MiyamotoNagaiPotential(AbstractPotential): - m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] - a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] - b: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] - - @partial_jit() - def _potential_energy( - self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike - ) -> BatchFloatScalar: - x, y, z = q[..., 0], q[..., 1], q[..., 2] - R2 = x**2 + y**2 - return ( - -self._G - * self.m(t) - / xp.sqrt(R2 + xp.square(xp.sqrt(z**2 + self.b(t) ** 2) + self.a(t))) - ) - - -# ------------------------------------------------------------------- - - -class NullPotential(AbstractPotential): - """Null potential, i.e. no potential.""" - - @partial_jit() - def _potential_energy( - self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike - ) -> BatchFloatScalar: - return xp.zeros(q.shape[:-1], dtype=q.dtype) - - -# ------------------------------------------------------------------- - - class BarPotential(AbstractPotential): """Rotating bar potentil, with hard-coded rotation. @@ -155,6 +101,47 @@ def _potential_energy( # ------------------------------------------------------------------- +class KeplerPotential(AbstractPotential): + r"""The Kepler potential for a point mass. + + .. math:: + \Phi = -\frac{G M(t)}{r} + """ + + m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] + + @partial_jit() + def _potential_energy( + self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike + ) -> BatchFloatScalar: + r = xp.linalg.norm(q, axis=-1) + return -self._G * self.m(t) / r + + +# ------------------------------------------------------------------- + + +class MiyamotoNagaiPotential(AbstractPotential): + m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] + a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] + b: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] + + @partial_jit() + def _potential_energy( + self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike + ) -> BatchFloatScalar: + x, y, z = q[..., 0], q[..., 1], q[..., 2] + R2 = x**2 + y**2 + return ( + -self._G + * self.m(t) + / xp.sqrt(R2 + xp.square(xp.sqrt(z**2 + self.b(t) ** 2) + self.a(t))) + ) + + +# ------------------------------------------------------------------- + + class NFWPotential(AbstractPotential): """NFW Potential.""" @@ -171,3 +158,16 @@ def _potential_energy( r2 = q[..., 0] ** 2 + q[..., 1] ** 2 + q[..., 2] ** 2 m = xp.sqrt(r2 + self.softening_length) / self.r_s(t) return v_h2 * xp.log(1.0 + m) / m + + +# ------------------------------------------------------------------- + + +class NullPotential(AbstractPotential): + """Null potential, i.e. no potential.""" + + @partial_jit() + def _potential_energy( + self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike + ) -> BatchFloatScalar: + return xp.zeros(q.shape[:-1], dtype=q.dtype)