From 625988b98280eb521f45bf028530b435c14f5106 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 5 Dec 2023 13:33:54 -0500 Subject: [PATCH] PhaseSpacePosition (#6) * PhaseSpacePosition Signed-off-by: nstarman * convenience methods and docstrings Signed-off-by: nstarman * convenience functions Signed-off-by: nstarman --------- Signed-off-by: nstarman Co-authored-by: Jacob Nibauer --- src/galdynamix/dynamics/__init__.py | 4 +- src/galdynamix/dynamics/_core.py | 160 ++++++++++++++++++ src/galdynamix/dynamics/_orbit.py | 90 +++++----- src/galdynamix/potential/_potential/base.py | 115 ++++++++++++- .../potential/_potential/composite.py | 52 +++++- src/galdynamix/utils/_collections.py | 17 ++ 6 files changed, 384 insertions(+), 54 deletions(-) create mode 100644 src/galdynamix/dynamics/_core.py diff --git a/src/galdynamix/dynamics/__init__.py b/src/galdynamix/dynamics/__init__.py index f7001a91..cd1f19c1 100644 --- a/src/galdynamix/dynamics/__init__.py +++ b/src/galdynamix/dynamics/__init__.py @@ -2,10 +2,12 @@ from __future__ import annotations -from . import _orbit, mockstream +from . import _core, _orbit, mockstream +from ._core import * from ._orbit import * from .mockstream import * __all__: list[str] = [] +__all__ += _core.__all__ __all__ += _orbit.__all__ __all__ += mockstream.__all__ diff --git a/src/galdynamix/dynamics/_core.py b/src/galdynamix/dynamics/_core.py new file mode 100644 index 00000000..20a7edb5 --- /dev/null +++ b/src/galdynamix/dynamics/_core.py @@ -0,0 +1,160 @@ +"""galdynamix: Galactic Dynamix in Jax.""" + +from __future__ import annotations + +__all__ = ["PhaseSpacePosition"] + +from typing import TYPE_CHECKING, cast + +import equinox as eqx +import jax.numpy as xp +import jax.typing as jt + +from galdynamix.utils._jax import partial_jit + +if TYPE_CHECKING: + from galdynamix.potential._potential.base import AbstractPotentialBase + + +class PhaseSpacePosition(eqx.Module): # type: ignore[misc] + """Orbit. + + Todo: + ---- + - Units stuff + - GR stuff + """ + + q: jt.Array + """Position of the stream particles (x, y, z) [kpc].""" + + p: jt.Array + """Position of the stream particles (x, y, z) [kpc/Myr].""" + + t: jt.Array + """Array of times [Myr].""" + + @property + @partial_jit() + def qp(self) -> jt.Array: + """Return as a single Array[(N, Q + P),].""" + # Determine output shape + qd = self.q.shape[1] # dimensionality of q + shape = (self.q.shape[0], qd + self.p.shape[1]) + # Create output array (jax will fuse these ops) + out = xp.empty(shape) + out = out.at[:, :qd].set(self.q) + out = out.at[:, qd:].set(self.p) + return out # noqa: RET504 + + @property + @partial_jit() + def w(self) -> jt.Array: + """Return as a single Array[(N, Q + P + T),].""" + qp = self.qp + qpd = qp.shape[1] # dimensionality of qp + # Reshape t to (N, 1) if necessary + t = self.t[:, None] if self.t.ndim == 1 else self.t + # Determine output shape + shape = (qp.shape[0], qpd + t.shape[1]) + # Create output array (jax will fuse these ops) + out = xp.empty(shape) + out = out.at[:, :qpd].set(qp) + out = out.at[:, qpd:].set(t) + return out # noqa: RET504 + + # ========================================================================== + # Array stuff + + @property + def shape(self) -> tuple[int, ...]: + """Shape of the position and velocity arrays.""" + return cast( + tuple[int, ...], + xp.broadcast_shapes(self.q.shape, self.p.shape, self.t.shape), + ) + + # ========================================================================== + # Dynamical quantities + + @partial_jit() + def kinetic_energy(self) -> jt.Array: + r"""Return the specific kinetic energy. + + .. math:: + + E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2 + + Returns + ------- + E : :class:`~astropy.units.Quantity` + The kinetic energy. + """ + # TODO: use a ``norm`` function + return 0.5 * xp.sum(self.p**2, axis=-1) + + @partial_jit() + def potential_energy(self, potential: AbstractPotentialBase, /) -> jt.Array: + r"""Return the specific potential energy. + + .. math:: + + E_\Phi = \Phi(\boldsymbol{q}) + + Parameters + ---------- + potential : `galdynamix.potential.AbstractPotentialBase` + The potential object to compute the energy from. + + Returns + ------- + E : :class:`~jax.Array` + The specific potential energy. + """ + return potential.potential_energy(self, self.t) + + @partial_jit() + def energy(self, potential: AbstractPotentialBase, /) -> jt.Array: + r"""Return the specific total energy. + + .. math:: + + E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2 + E_\Phi = \Phi(\boldsymbol{q}) + E = E_K + E_\Phi + + Returns + ------- + E : :class:`~astropy.units.Quantity` + The kinetic energy. + """ + return self.kinetic_energy() + self.potential_energy(potential) + + @partial_jit() + def angular_momentum(self) -> jt.Array: + r"""Compute the angular momentum. + + .. math:: + + \boldsymbol{{L}} = \boldsymbol{{q}} \times \boldsymbol{{p}} + + See :ref:`shape-conventions` for more information about the shapes of + input and output objects. + + Returns + ------- + L : :class:`~astropy.units.Quantity` + Array of angular momentum vectors. + + Examples + -------- + >>> import numpy as np + >>> import astropy.units as u + >>> pos = np.array([1., 0, 0]) * u.au + >>> vel = np.array([0, 2*np.pi, 0]) * u.au/u.yr + >>> w = PhaseSpacePosition(pos, vel) + >>> w.angular_momentum() # doctest: +FLOAT_CMP + + """ + # TODO: when q, p are not Cartesian. + return xp.cross(self.q, self.p) diff --git a/src/galdynamix/dynamics/_orbit.py b/src/galdynamix/dynamics/_orbit.py index 9d63fe4d..cdf0181d 100644 --- a/src/galdynamix/dynamics/_orbit.py +++ b/src/galdynamix/dynamics/_orbit.py @@ -4,61 +4,65 @@ __all__ = ["Orbit"] - -import equinox as eqx -import jax.numpy as xp import jax.typing as jt from galdynamix.potential._potential.base import AbstractPotentialBase from galdynamix.utils._jax import partial_jit +from ._core import PhaseSpacePosition -class Orbit(eqx.Module): # type: ignore[misc] - """Orbit. - - Todo: - ---- - - Units stuff - - GR stuff - """ - q: jt.Array - """Position of the stream particles (x, y, z) [kpc].""" +class Orbit(PhaseSpacePosition): + """Represents an orbit. - p: jt.Array - """Position of the stream particles (x, y, z) [kpc/Myr].""" + Represents an orbit: positions and velocities (conjugate momenta) as a + function of time. - t: jt.Array - """Array of times [Myr].""" + """ potential: AbstractPotentialBase """Potential in which the orbit was integrated.""" - @property + # ========================================================================== + # Dynamical quantities + @partial_jit() - def qp(self) -> jt.Array: - """Return as a single Array[(N, Q + P),].""" - # Determine output shape - qd = self.q.shape[1] # dimensionality of q - shape = (self.q.shape[0], qd + self.p.shape[1]) - # Create output array (jax will fuse these ops) - out = xp.empty(shape) - out = out.at[:, :qd].set(self.q) - out = out.at[:, qd:].set(self.p) - return out # noqa: RET504 - - @property + def potential_energy( + self, potential: AbstractPotentialBase | None = None, / + ) -> jt.Array: + r"""Return the specific potential energy. + + .. math:: + + E_\Phi = \Phi(\boldsymbol{q}) + + Parameters + ---------- + potential : `galdynamix.potential.AbstractPotentialBase` + The potential object to compute the energy from. + + Returns + ------- + E : :class:`~jax.Array` + The specific potential energy. + """ + if potential is None: + return self.potential.potential_energy(self, self.t) + return potential.potential_energy(self, self.t) + @partial_jit() - def w(self) -> jt.Array: - """Return as a single Array[(N, Q + P + T),].""" - qp = self.qp - qpd = qp.shape[1] # dimensionality of qp - # Reshape t to (N, 1) if necessary - t = self.t[:, None] if self.t.ndim == 1 else self.t - # Determine output shape - shape = (qp.shape[0], qpd + t.shape[1]) - # Create output array (jax will fuse these ops) - out = xp.empty(shape) - out = out.at[:, :qpd].set(qp) - out = out.at[:, qpd:].set(t) - return out # noqa: RET504 + def energy(self, potential: AbstractPotentialBase | None = None, /) -> jt.Array: + r"""Return the specific total energy. + + .. math:: + + E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2 + E_\Phi = \Phi(\boldsymbol{q}) + E = E_K + E_\Phi + + Returns + ------- + E : :class:`~astropy.units.Quantity` + The kinetic energy. + """ + return self.kinetic_energy() + self.potential_energy(potential) diff --git a/src/galdynamix/potential/_potential/base.py b/src/galdynamix/potential/_potential/base.py index 68759cc2..7ffdae1f 100644 --- a/src/galdynamix/potential/_potential/base.py +++ b/src/galdynamix/potential/_potential/base.py @@ -3,6 +3,7 @@ __all__ = ["AbstractPotentialBase", "AbstractPotential"] import abc +import uuid from dataclasses import KW_ONLY, fields from typing import TYPE_CHECKING, Any @@ -20,6 +21,7 @@ if TYPE_CHECKING: from galdynamix.integrate._base import AbstractIntegrator + from galdynamix.potential._potential.composite import CompositePotential class AbstractPotentialBase(eqx.Module): # type: ignore[misc] @@ -32,7 +34,20 @@ class AbstractPotentialBase(eqx.Module): # type: ignore[misc] @abc.abstractmethod def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: - """Compute the potential energy at the given position(s).""" + """Compute the potential energy at the given position(s). + + Parameters + ---------- + q : :class:`~jax.Array` + The position to compute the value of the potential. + t : :class:`~jax.Array` + The time at which to compute the value of the potential. + + Returns + ------- + E : :class:`~jax.Array` + The potential energy per unit mass or value of the potential. + """ raise NotImplementedError ########################################################################### @@ -65,11 +80,53 @@ def _init_units(self) -> None: @partial_jit() def __call__(self, q: jt.Array, /, t: jt.Array) -> jt.Array: - """Compute the potential energy at the given position(s).""" + """Compute the potential energy at the given position(s). + + See Also + -------- + potential_energy + """ return self.potential_energy(q, t) @partial_jit() - def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: + def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: phasespaceposition + """Compute the gradient of the potential at the given position(s). + + Parameters + ---------- + q : :class:`~jax.Array` + The position to compute the value of the potential. If the + input position object has no units (i.e. is an `~numpy.ndarray`), + it is assumed to be in the same unit system as the potential. + t : :class:`~jax.Array` + The time at which to compute the value of the potential. + + Returns + ------- + :class:`~jax.Array` + The gradient of the potential. + """ + return jax.grad(self.potential_energy)(q, t) + + @partial_jit() + def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array: + """Compute the density value at the given position(s). + + Parameters + ---------- + q : :class:`~jax.Array` + The position to compute the value of the potential. If the + input position object has no units (i.e. is an `~numpy.ndarray`), + it is assumed to be in the same unit system as the potential. + t : :class:`~jax.Array` + The time at which to compute the value of the potential. + + Returns + ------- + :class:`~jax.Array` + The potential energy or value of the potential. + """ + lap = xp.trace(jax.hessian(self.potential_energy)(q, t)) """Compute the gradient.""" return jax.grad(self.potential_energy, argnums=0)(q, t) @@ -77,18 +134,52 @@ def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array: # Note: trace(jacobian(gradient)) is faster than trace(hessian(energy)) lap = xp.trace(jax.jacfwd(self.gradient)(q, t)) + return lap / (4 * xp.pi * self._G) @partial_jit() def hessian(self, q: jt.Array, /, t: jt.Array) -> jt.Array: + """Compute the Hessian of the potential at the given position(s). + + Parameters + ---------- + q : :class:`~jax.Array` + The position to compute the value of the potential. If the + input position object has no units (i.e. is an `~numpy.ndarray`), + it is assumed to be in the same unit system as the potential. + t : :class:`~jax.Array` + The time at which to compute the value of the potential. + + Returns + ------- + :class:`~jax.Array` + The Hessian matrix of second derivatives of the potential. + """ return jax.hessian(self.potential_energy)(q, t) + ########################################################################### + # Convenience methods + @partial_jit() def acceleration(self, q: jt.Array, /, t: jt.Array) -> jt.Array: + """Compute the acceleration due to the potential at the given position(s). + + Parameters + ---------- + q : :class:`~jax.Array` + Position to compute the acceleration at. + t : :class:`~jax.Array` + Time at which to compute the acceleration. + + Returns + ------- + :class:`~jax.Array` + The acceleration. Will have the same shape as the input + position array, ``q``. + """ return -self.gradient(q, t) - ########################################################################### - # Convenience methods + # ========================================================================= @partial_jit() def _integrator_F( @@ -113,6 +204,20 @@ def integrate_orbit( ws = integrator.run(w0, t0, t1, ts) return Orbit(q=ws[:, :3], p=ws[:, 3:-1], t=ws[:, -1], potential=self) + ########################################################################### + # Composite potentials + + def __add__(self, other: Any) -> CompositePotential: + if not isinstance(other, AbstractPotentialBase): + return NotImplemented + + from galdynamix.potential._potential.composite import CompositePotential + + if isinstance(other, CompositePotential): + return other.__ror__(self) + + return CompositePotential({str(uuid.uuid4()): self, str(uuid.uuid4()): other}) + # =========================================================================== diff --git a/src/galdynamix/potential/_potential/composite.py b/src/galdynamix/potential/_potential/composite.py index 3310e8e8..94d37a5b 100644 --- a/src/galdynamix/potential/_potential/composite.py +++ b/src/galdynamix/potential/_potential/composite.py @@ -3,8 +3,9 @@ __all__ = ["CompositePotential"] +import uuid from dataclasses import KW_ONLY -from typing import TypeVar, final +from typing import Any, TypeVar, final import equinox as eqx import jax.numpy as xp @@ -26,7 +27,9 @@ class CompositePotential(ImmutableDict[AbstractPotentialBase], AbstractPotential _data: dict[str, AbstractPotentialBase] _: KW_ONLY units: UnitSystem = eqx.field( - static=True, converter=lambda x: dimensionless if x is None else UnitSystem(x) + init=False, + static=True, + converter=lambda x: dimensionless if x is None else UnitSystem(x), ) _G: float = eqx.field(init=False, static=True) @@ -35,13 +38,20 @@ def __init__( potentials: dict[str, AbstractPotentialBase] | tuple[tuple[str, AbstractPotentialBase], ...] = (), /, - units: UnitSystem | None = None, **kwargs: AbstractPotentialBase, ) -> None: super().__init__(potentials, **kwargs) # type: ignore[arg-type] - self.units = self.__dataclass_fields__["units"].metadata["converter"](units) - # TODO: check unit systems of contained potentials to make sure they match. + self.__post_init__() + def __post_init__(self) -> None: + # Check that all potentials have the same unit system + units = next(iter(self.values())).units + if not all(p.units == units for p in self.values()): + msg = "all potentials must have the same unit system" + raise ValueError(msg) + object.__setattr__(self, "units", units) + + # Apply the unit system to any parameters. self._init_units() # === Potential === @@ -53,3 +63,35 @@ def potential_energy( t: jt.Array, ) -> jt.Array: return xp.sum(xp.array([p.potential_energy(q, t) for p in self.values()])) + + ########################################################################### + # Composite potentials + + def __or__(self, other: Any) -> CompositePotential: + if not isinstance(other, AbstractPotentialBase): + return NotImplemented + + return CompositePotential( # combine the two dictionaries + self._data + | ( # make `other` into a compatible dictionary. + other._data + if isinstance(other, CompositePotential) + else {str(uuid.uuid4()): other} + ) + ) + + def __ror__(self, other: Any) -> CompositePotential: + if not isinstance(other, AbstractPotentialBase): + return NotImplemented + + return CompositePotential( # combine the two dictionaries + ( # make `other` into a compatible dictionary. + other._data + if isinstance(other, CompositePotential) + else {str(uuid.uuid4()): other} + ) + | self._data + ) + + def __add__(self, other: AbstractPotentialBase) -> CompositePotential: + return self | other diff --git a/src/galdynamix/utils/_collections.py b/src/galdynamix/utils/_collections.py index a6ffc190..eb2f63e9 100644 --- a/src/galdynamix/utils/_collections.py +++ b/src/galdynamix/utils/_collections.py @@ -17,6 +17,23 @@ @register_pytree_node_class class ImmutableDict(Mapping[str, V]): + """Immutable string-keyed dictionary. + + Parameters + ---------- + *args : tuple[str, V] + Key-value pairs. + **kwargs : V + Key-value pairs. + + Examples + -------- + >>> from galdynamix.utils import ImmutableDict + >>> d = ImmutableDict(a=1, b=2) + >>> d + ImmutableDict({'a': 1, 'b': 2}) + """ + def __init__(self, /, *args: tuple[str, V], **kwargs: V) -> None: self._data: dict[str, V] = dict(*args, **kwargs)