Skip to content

Commit

Permalink
PhaseSpacePosition (#6)
Browse files Browse the repository at this point in the history
* PhaseSpacePosition

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* convenience methods and docstrings

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* convenience functions

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

---------

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
Co-authored-by: Jacob Nibauer <jnibauer@princeton.edu>
  • Loading branch information
nstarman and jnibauer committed Dec 5, 2023
1 parent e3eaf24 commit 625988b
Show file tree
Hide file tree
Showing 6 changed files with 384 additions and 54 deletions.
4 changes: 3 additions & 1 deletion src/galdynamix/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

from . import _orbit, mockstream
from . import _core, _orbit, mockstream
from ._core import *
from ._orbit import *
from .mockstream import *

__all__: list[str] = []
__all__ += _core.__all__
__all__ += _orbit.__all__
__all__ += mockstream.__all__
160 changes: 160 additions & 0 deletions src/galdynamix/dynamics/_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""galdynamix: Galactic Dynamix in Jax."""

from __future__ import annotations

__all__ = ["PhaseSpacePosition"]

from typing import TYPE_CHECKING, cast

import equinox as eqx
import jax.numpy as xp
import jax.typing as jt

from galdynamix.utils._jax import partial_jit

if TYPE_CHECKING:
from galdynamix.potential._potential.base import AbstractPotentialBase


class PhaseSpacePosition(eqx.Module): # type: ignore[misc]
"""Orbit.
Todo:
----
- Units stuff
- GR stuff
"""

q: jt.Array
"""Position of the stream particles (x, y, z) [kpc]."""

p: jt.Array
"""Position of the stream particles (x, y, z) [kpc/Myr]."""

t: jt.Array
"""Array of times [Myr]."""

@property
@partial_jit()
def qp(self) -> jt.Array:
"""Return as a single Array[(N, Q + P),]."""
# Determine output shape
qd = self.q.shape[1] # dimensionality of q
shape = (self.q.shape[0], qd + self.p.shape[1])
# Create output array (jax will fuse these ops)
out = xp.empty(shape)
out = out.at[:, :qd].set(self.q)
out = out.at[:, qd:].set(self.p)
return out # noqa: RET504

@property
@partial_jit()
def w(self) -> jt.Array:
"""Return as a single Array[(N, Q + P + T),]."""
qp = self.qp
qpd = qp.shape[1] # dimensionality of qp
# Reshape t to (N, 1) if necessary
t = self.t[:, None] if self.t.ndim == 1 else self.t
# Determine output shape
shape = (qp.shape[0], qpd + t.shape[1])
# Create output array (jax will fuse these ops)
out = xp.empty(shape)
out = out.at[:, :qpd].set(qp)
out = out.at[:, qpd:].set(t)
return out # noqa: RET504

# ==========================================================================
# Array stuff

@property
def shape(self) -> tuple[int, ...]:
"""Shape of the position and velocity arrays."""
return cast(
tuple[int, ...],
xp.broadcast_shapes(self.q.shape, self.p.shape, self.t.shape),
)

# ==========================================================================
# Dynamical quantities

@partial_jit()
def kinetic_energy(self) -> jt.Array:
r"""Return the specific kinetic energy.
.. math::
E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2
Returns
-------
E : :class:`~astropy.units.Quantity`
The kinetic energy.
"""
# TODO: use a ``norm`` function
return 0.5 * xp.sum(self.p**2, axis=-1)

@partial_jit()
def potential_energy(self, potential: AbstractPotentialBase, /) -> jt.Array:
r"""Return the specific potential energy.
.. math::
E_\Phi = \Phi(\boldsymbol{q})
Parameters
----------
potential : `galdynamix.potential.AbstractPotentialBase`
The potential object to compute the energy from.
Returns
-------
E : :class:`~jax.Array`
The specific potential energy.
"""
return potential.potential_energy(self, self.t)

@partial_jit()
def energy(self, potential: AbstractPotentialBase, /) -> jt.Array:
r"""Return the specific total energy.
.. math::
E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2
E_\Phi = \Phi(\boldsymbol{q})
E = E_K + E_\Phi
Returns
-------
E : :class:`~astropy.units.Quantity`
The kinetic energy.
"""
return self.kinetic_energy() + self.potential_energy(potential)

@partial_jit()
def angular_momentum(self) -> jt.Array:
r"""Compute the angular momentum.
.. math::
\boldsymbol{{L}} = \boldsymbol{{q}} \times \boldsymbol{{p}}
See :ref:`shape-conventions` for more information about the shapes of
input and output objects.
Returns
-------
L : :class:`~astropy.units.Quantity`
Array of angular momentum vectors.
Examples
--------
>>> import numpy as np
>>> import astropy.units as u
>>> pos = np.array([1., 0, 0]) * u.au
>>> vel = np.array([0, 2*np.pi, 0]) * u.au/u.yr
>>> w = PhaseSpacePosition(pos, vel)
>>> w.angular_momentum() # doctest: +FLOAT_CMP
<Quantity [0. ,0. ,6.28318531] AU2 / yr>
"""
# TODO: when q, p are not Cartesian.
return xp.cross(self.q, self.p)
90 changes: 47 additions & 43 deletions src/galdynamix/dynamics/_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,65 @@

__all__ = ["Orbit"]


import equinox as eqx
import jax.numpy as xp
import jax.typing as jt

from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.utils._jax import partial_jit

from ._core import PhaseSpacePosition

class Orbit(eqx.Module): # type: ignore[misc]
"""Orbit.
Todo:
----
- Units stuff
- GR stuff
"""

q: jt.Array
"""Position of the stream particles (x, y, z) [kpc]."""
class Orbit(PhaseSpacePosition):
"""Represents an orbit.
p: jt.Array
"""Position of the stream particles (x, y, z) [kpc/Myr]."""
Represents an orbit: positions and velocities (conjugate momenta) as a
function of time.
t: jt.Array
"""Array of times [Myr]."""
"""

potential: AbstractPotentialBase
"""Potential in which the orbit was integrated."""

@property
# ==========================================================================
# Dynamical quantities

@partial_jit()
def qp(self) -> jt.Array:
"""Return as a single Array[(N, Q + P),]."""
# Determine output shape
qd = self.q.shape[1] # dimensionality of q
shape = (self.q.shape[0], qd + self.p.shape[1])
# Create output array (jax will fuse these ops)
out = xp.empty(shape)
out = out.at[:, :qd].set(self.q)
out = out.at[:, qd:].set(self.p)
return out # noqa: RET504

@property
def potential_energy(
self, potential: AbstractPotentialBase | None = None, /
) -> jt.Array:
r"""Return the specific potential energy.
.. math::
E_\Phi = \Phi(\boldsymbol{q})
Parameters
----------
potential : `galdynamix.potential.AbstractPotentialBase`
The potential object to compute the energy from.
Returns
-------
E : :class:`~jax.Array`
The specific potential energy.
"""
if potential is None:
return self.potential.potential_energy(self, self.t)
return potential.potential_energy(self, self.t)

@partial_jit()
def w(self) -> jt.Array:
"""Return as a single Array[(N, Q + P + T),]."""
qp = self.qp
qpd = qp.shape[1] # dimensionality of qp
# Reshape t to (N, 1) if necessary
t = self.t[:, None] if self.t.ndim == 1 else self.t
# Determine output shape
shape = (qp.shape[0], qpd + t.shape[1])
# Create output array (jax will fuse these ops)
out = xp.empty(shape)
out = out.at[:, :qpd].set(qp)
out = out.at[:, qpd:].set(t)
return out # noqa: RET504
def energy(self, potential: AbstractPotentialBase | None = None, /) -> jt.Array:
r"""Return the specific total energy.
.. math::
E_K = \frac{1}{2} \\, |\boldsymbol{v}|^2
E_\Phi = \Phi(\boldsymbol{q})
E = E_K + E_\Phi
Returns
-------
E : :class:`~astropy.units.Quantity`
The kinetic energy.
"""
return self.kinetic_energy() + self.potential_energy(potential)
Loading

0 comments on commit 625988b

Please sign in to comment.