Skip to content

Commit

Permalink
feat: Q in and out
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 6, 2024
1 parent 4dc33bd commit c151305
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 224 deletions.
2 changes: 1 addition & 1 deletion src/galax/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

__getattr__, __dir__, __all__ = attach_stub(__name__, __file__)

install_import_hook("galax.potential", RUNTIME_TYPECHECKER)
# install_import_hook("galax.potential", RUNTIME_TYPECHECKER) # noqa: ERA001

# Cleanup
del install_import_hook, RUNTIME_TYPECHECKER
78 changes: 37 additions & 41 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from astropy.constants import G as _G # pylint: disable=no-name-in-module
from astropy.constants import G as APYG
from astropy.coordinates import BaseRepresentation as APYRepresentation
from astropy.units import Quantity as APYQuantity
from jax import grad, hessian, jacfwd
Expand All @@ -22,15 +22,14 @@
from coordinax import Abstract3DVector, FourVector
from jax_quantity import Quantity

from .utils import _convert_from_3dvec, convert_input_to_array, convert_inputs_to_arrays
from .utils import _convert_from_3dvec, parse_to_quantity
from galax.coordinates import PhaseSpacePosition, PhaseSpaceTimePosition
from galax.coordinates._psp.psp import AbstractPhaseSpacePosition
from galax.coordinates._psp.pspt import AbstractPhaseSpaceTimePosition
from galax.potential._potential.param.attr import ParametersAttribute
from galax.potential._potential.param.utils import all_parameters
from galax.typing import (
BatchableFloatLike,
BatchableIntLike,
BatchableRealQScalar,
BatchableRealScalarLike,
BatchFloatScalar,
BatchMatrix33,
Expand All @@ -43,7 +42,9 @@
IntQScalar,
IntScalar,
Matrix33,
QVec3,
QVecTime,
RealQScalar,
RealScalar,
Vec3,
Vec6,
Expand Down Expand Up @@ -75,6 +76,8 @@
| APYQuantity
)

G = Quantity(APYG.value, APYG.unit)


class AbstractPotentialBase(eqx.Module, metaclass=ModuleMeta, strict=True): # type: ignore[misc]
"""Abstract Potential Class."""
Expand All @@ -97,8 +100,11 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
# Parsing

def _init_units(self) -> None:
G = 1 if self.units == dimensionless else _G.decompose(self.units).value
object.__setattr__(self, "_G", G)
object.__setattr__(
self,
"_G",
Quantity(1, "") if self.units == dimensionless else G.decompose(self.units),
)

from galax.potential._potential.param.field import ParameterField

Expand Down Expand Up @@ -131,8 +137,8 @@ def _init_units(self) -> None:
# @vectorize_method(signature="(3),()->()")
@abc.abstractmethod
def _potential_energy(
self, q: Vec3, t: RealScalar, /, _G: FloatScalar
) -> FloatScalar:
self, q: QVec3, t: RealQScalar, /, _G: FloatQScalar
) -> FloatQScalar:
"""Compute the potential energy at the given position(s).
This method MUST be implemented by subclasses.
Expand Down Expand Up @@ -214,10 +220,8 @@ def potential_energy(
Quantity['specific energy'](Array(-1.20227527, dtype=float64), unit='kpc2 / Myr2')
""" # noqa: E501
q = _convert_from_3dvec(pspt.q, units=self.units)
t = pspt.t.to_value(self.units["time"]) # TODO: value
return Quantity(
self._potential_energy(q, t, self._G), self.units["specific energy"]
)
t = pspt.t.to(self.units["time"]) # TODO: value
return self._potential_energy(q, t, self._G)

@dispatch
def potential_energy(
Expand Down Expand Up @@ -276,11 +280,9 @@ def potential_energy(
>>> pot.potential_energy(q, t)
Quantity['specific energy'](Array([-1.20227527, -0.5126519 ], dtype=float64), unit='kpc2 / Myr2')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(
self._potential_energy(q, t, self._G), self.units["specific energy"]
)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"])
return self._potential_energy(q, t, self._G)

@dispatch
def potential_energy(
Expand Down Expand Up @@ -365,11 +367,9 @@ def potential_energy(
>>> pot.potential_energy(q, t)
Quantity['specific energy'](Array([-1.20227527, -0.5126519 ], dtype=float64), unit='kpc2 / Myr2')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(
self._potential_energy(q, t, self._G), self.units["specific energy"]
)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"])
return self._potential_energy(q, t, self._G)

@dispatch
def potential_energy(
Expand Down Expand Up @@ -402,15 +402,9 @@ def potential_energy(
""" # noqa: E501
return self.potential_energy(q, t)

@quaxify # type: ignore[misc]
@partial(jax.jit)
def __call__(
self,
q: Shaped[Array, "*batch 3"], # TODO: enable more inputs
/,
t: ( # TODO: enable more inputs
BatchableRealScalarLike | BatchableFloatLike | BatchableIntLike
),
self, q: Shaped[Quantity, "*batch 3"], /, t: BatchableRealQScalar
) -> Float[Quantity["specific energy"], "*batch"]:
"""Compute the potential energy at the given position(s).
Expand All @@ -435,11 +429,12 @@ def __call__(
# ---------------------------------------
# Gradient

@quaxify # type: ignore[misc]
@partial(jax.jit)
@vectorize_method(signature="(3),()->(3)")
def _gradient(self, q: Vec3, /, t: RealScalar) -> Vec3:
def _gradient(self, q: QVec3, /, t: RealQScalar) -> QVec3:
"""See ``gradient``."""
return grad(self._potential_energy)(q, t, self._G)
return quaxify(grad(self._potential_energy))(q, t, self._G)

@dispatch
def gradient(
Expand Down Expand Up @@ -567,7 +562,7 @@ def gradient(
[0.02663127, 0.03328908, 0.0399469 ]], dtype=float64),
unit='kpc / Myr2')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(self._gradient(q, t), self.units["acceleration"])

Expand Down Expand Up @@ -706,7 +701,7 @@ def gradient(
[0.02663127, 0.03328908, 0.0399469 ]], dtype=float64),
unit='kpc / Myr2')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(self._gradient(q, t), self.units["acceleration"])

Expand Down Expand Up @@ -866,7 +861,7 @@ def laplacian(
>>> pot.laplacian(q, t)
Quantity['diffusivity'](Array([2.77555756e-17, 0.00000000e+00], dtype=float64), unit='kpc2 / Myr')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(self._laplacian(q, t), self.units["kinematic viscosity"])

Expand Down Expand Up @@ -993,7 +988,7 @@ def laplacian(
>>> pot.laplacian(q, t)
Quantity['diffusivity'](Array([2.77555756e-17, 0.00000000e+00], dtype=float64), unit='kpc2 / Myr')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(self._laplacian(q, t), self.units["kinematic viscosity"])

Expand Down Expand Up @@ -1152,7 +1147,7 @@ def density(
>>> pot.density(q, t)
Quantity['mass density'](Array([4.90989768e-07, 0.00000000e+00], dtype=float64), unit='solMass / kpc3')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(self._density(q, t), self.units["mass density"])

Expand Down Expand Up @@ -1241,7 +1236,7 @@ def density(
Quantity['mass density'](Array([4.90989768e-07, 0.00000000e+00], dtype=float64),
unit='solMass / kpc3')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(self._density(q, t), self.units["mass density"])

Expand Down Expand Up @@ -1427,7 +1422,7 @@ def hessian(
[-0.00518791, 0.00017293, -0.00778186],
[-0.00622549, -0.00778186, -0.00268042]]], dtype=float64)
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return self._hessian(q, t)

Expand Down Expand Up @@ -1527,7 +1522,8 @@ def hessian(
[-0.00518791, 0.00017293, -0.00778186],
[-0.00622549, -0.00778186, -0.00268042]]], dtype=float64)
"""
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"])
return self._hessian(q, t)

@dispatch
Expand Down Expand Up @@ -1669,7 +1665,7 @@ def acceleration(
[-0.02663127, -0.03328908, -0.0399469 ]], dtype=float64),
unit='kpc / Myr2')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(-self._gradient(q, t), self.units["acceleration"])

Expand Down Expand Up @@ -1808,7 +1804,7 @@ def acceleration(
[-0.02663127, -0.03328908, -0.0399469 ]], dtype=float64),
unit='kpc / Myr2')
""" # noqa: E501
q = convert_input_to_array(q, units=self.units, no_differentials=True)
q = parse_to_quantity(q, unit=self.units["length"])
t = Quantity.constructor(t, self.units["time"]).value # TODO: value
return Quantity(-self._gradient(q, t), self.units["acceleration"])

Expand Down
Loading

0 comments on commit c151305

Please sign in to comment.