From 177bdb9e31bea27b409ba1aa8b61d91044296217 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sat, 27 Jan 2024 20:25:46 -0500 Subject: [PATCH] feat(PhaseSpacePosition): add slicing via getitem Signed-off-by: nstarman --- src/galax/dynamics/_base.py | 15 ++++- src/galax/dynamics/_core.py | 6 +- src/galax/dynamics/_orbit.py | 12 ++-- src/galax/dynamics/_utils.py | 46 +++++++++++++++ src/galax/dynamics/mockstream/_core.py | 17 ++---- src/galax/potential/_potential/base.py | 4 +- src/galax/typing.py | 1 + tests/unit/dynamics/test_core.py | 80 ++++++++++++++++++++------ 8 files changed, 141 insertions(+), 40 deletions(-) create mode 100644 src/galax/dynamics/_utils.py diff --git a/src/galax/dynamics/_base.py b/src/galax/dynamics/_base.py index 40612213..5992311d 100644 --- a/src/galax/dynamics/_base.py +++ b/src/galax/dynamics/_base.py @@ -3,8 +3,9 @@ __all__ = ["AbstractPhaseSpacePosition"] from abc import abstractmethod +from dataclasses import replace from functools import partial -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import equinox as eqx import jax @@ -16,7 +17,11 @@ from galax.units import UnitSystem from galax.utils._shape import atleast_batched +from ._utils import getitem_time_index + if TYPE_CHECKING: + from typing import Self + from galax.potential._potential.base import AbstractPotentialBase @@ -57,6 +62,14 @@ def __len__(self) -> int: """Return the number of particles.""" return self.shape[0] + def __getitem__(self, index: Any) -> "Self": + """Return a new object with the given slice applied.""" + # This is the default implementation, but subclasses can override this + # Compute subindex + subindex = getitem_time_index(index, self.t) + # Apply slice + return replace(self, q=self.q[index], p=self.p[index], t=self.t[subindex]) + # ========================================================================== @property diff --git a/src/galax/dynamics/_core.py b/src/galax/dynamics/_core.py index c29c642c..9ee9e4f8 100644 --- a/src/galax/dynamics/_core.py +++ b/src/galax/dynamics/_core.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from galax.typing import BatchFloatScalar, BroadBatchVec1, BroadBatchVec3, Vec1 +from galax.typing import BatchFloatScalar, BroadBatchFloatScalar, BroadBatchVec3, Vec1 from galax.utils._shape import batched_shape, expand_batch_dims from galax.utils.dataclasses import converter_float_array @@ -35,7 +35,7 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): This is a 3-vector with a batch shape allowing for vector inputs. """ - t: BroadBatchVec1 | Vec1 = eqx.field( + t: BroadBatchFloatScalar | Vec1 = eqx.field( default=(0.0,), converter=converter_float_array ) """The time corresponding to the positions. @@ -47,7 +47,7 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): def __post_init__(self) -> None: """Post-initialization.""" - # Need to ensure t shape is correct + # Need to ensure t shape is correct. Can be Vec0. if self.t.ndim == 0: t = expand_batch_dims(self.t, ndim=self.q.ndim) object.__setattr__(self, "t", t) diff --git a/src/galax/dynamics/_orbit.py b/src/galax/dynamics/_orbit.py index 8fa28966..cea68f5b 100644 --- a/src/galax/dynamics/_orbit.py +++ b/src/galax/dynamics/_orbit.py @@ -7,10 +7,9 @@ import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import Array, Float from galax.potential._potential.base import AbstractPotentialBase -from galax.typing import BatchFloatScalar, BatchVecTime +from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime from galax.utils._shape import batched_shape from galax.utils.dataclasses import converter_float_array @@ -25,13 +24,14 @@ class Orbit(AbstractPhaseSpacePosition): """ - q: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array) + q: BroadBatchVec3 = eqx.field(converter=converter_float_array) """Positions (x, y, z).""" - p: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array) + p: BroadBatchVec3 = eqx.field(converter=converter_float_array) r"""Conjugate momenta ($v_x$, $v_y$, $v_z$).""" - t: BatchVecTime = eqx.field(converter=converter_float_array) + # TODO: consider how this should be vectorized + t: VecTime = eqx.field(converter=converter_float_array) """Array of times corresponding to the positions.""" potential: AbstractPotentialBase @@ -45,7 +45,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: """Batch, component shape.""" qbatch, qshape = batched_shape(self.q, expect_ndim=1) pbatch, pshape = batched_shape(self.p, expect_ndim=1) - tbatch, _ = batched_shape(self.t, expect_ndim=0) + tbatch, _ = batched_shape(self.t, expect_ndim=1) batch_shape = jnp.broadcast_shapes(qbatch, pbatch, tbatch) array_shape = qshape + pshape + (1,) return batch_shape, array_shape diff --git a/src/galax/dynamics/_utils.py b/src/galax/dynamics/_utils.py new file mode 100644 index 00000000..9d6bd752 --- /dev/null +++ b/src/galax/dynamics/_utils.py @@ -0,0 +1,46 @@ +"""galax: Galactic Dynamix in Jax.""" + +__all__: list[str] = [] + +from typing import Any, Protocol, cast, runtime_checkable + +import jax.experimental.array_api as xp +from jaxtyping import Array, Float + + +@runtime_checkable +class Shaped(Protocol): + """Protocol for a shaped object.""" + + shape: tuple[int, ...] + + +def _getitem_time_index_tuple(index: tuple[Any, ...], t: Float[Array, "..."]) -> Any: + """Get the time index from a slice.""" + if len(index) == 0: # slice is an empty tuple + return slice(None) + if t.ndim == 1: # slicing a Vec1 + return slice(None) + if len(index) >= t.ndim: + msg = f"Index {index} has too many dimensions for time array of shape {t.shape}" + raise IndexError(msg) + return index + + +def _getitem_time_index_shaped(index: Shaped, t: Float[Array, "..."]) -> Shaped: + """Get the time index from a slice.""" + if t.ndim == 1: # Vec1 + return cast(Shaped, xp.asarray([True])) + if len(index.shape) >= t.ndim: + msg = f"Index {index} has too many dimensions for time array of shape {t.shape}" + raise IndexError(msg) + return index + + +def getitem_time_index(index: Any, t: Float[Array, "..."]) -> Any: + """Get the time index from an index.""" + if isinstance(index, tuple): + return _getitem_time_index_tuple(index, t) + if isinstance(index, Shaped): + return _getitem_time_index_shaped(index, t) + return index diff --git a/src/galax/dynamics/mockstream/_core.py b/src/galax/dynamics/mockstream/_core.py index efe53480..435d2ac4 100644 --- a/src/galax/dynamics/mockstream/_core.py +++ b/src/galax/dynamics/mockstream/_core.py @@ -8,10 +8,9 @@ import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import Array, Float from galax.dynamics._base import AbstractPhaseSpacePosition -from galax.typing import BatchFloatScalar, VecTime +from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime from galax.utils._shape import batched_shape from galax.utils.dataclasses import converter_float_array @@ -32,19 +31,12 @@ class MockStream(AbstractPhaseSpacePosition): Array of times corresponding to the positions. release_time : Array[float, (*batch,)] Release time of the stream particles [Myr]. - - Todo: - ---- - - units stuff - - change this to be a collection of sub-objects: progenitor, leading arm, - trailing arm, 3-body ejecta, etc. - - GR 4-vector stuff """ - q: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array) + q: BroadBatchVec3 = eqx.field(converter=converter_float_array) """Positions (x, y, z).""" - p: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array) + p: BroadBatchVec3 = eqx.field(converter=converter_float_array) r"""Conjugate momenta (v_x, v_y, v_z).""" t: VecTime = eqx.field(converter=converter_float_array) @@ -53,6 +45,9 @@ class MockStream(AbstractPhaseSpacePosition): release_time: VecTime = eqx.field(converter=converter_float_array) """Release time of the stream particles [Myr].""" + # ========================================================================== + # Array properties + @property def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: """Batch .""" diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index c57a4412..fa452956 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -411,7 +411,7 @@ def integrate_orbit( >>> orbit = potential.integrate_orbit(xv0, ts) >>> orbit Orbit( - q=f64[2,10,3], p=f64[2,10,3], t=f64[2,10], potential=KeplerPotential(...) + q=f64[2,10,3], p=f64[2,10,3], t=f64[10], potential=KeplerPotential(...) ) """ # TODO: ꜛ get NORMALIZE_WHITESPACE to work correctly so Orbit is 1 line @@ -421,4 +421,4 @@ def integrate_orbit( ws = integrator_(self._integrator_F, w0, t) # TODO: ꜛ reduce repeat dimensions of `time`. - return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=ws[..., -1], potential=self) + return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=t, potential=self) diff --git a/src/galax/typing.py b/src/galax/typing.py index 93dbbd87..054ffadb 100644 --- a/src/galax/typing.py +++ b/src/galax/typing.py @@ -77,6 +77,7 @@ BatchableIntLike = BatchIntScalar | IntLike +BroadBatchFloatScalar = Shaped[FloatScalar, "*#batch"] BatchFloatScalar = Shaped[FloatScalar, "*batch"] BatchableFloatLike = BatchFloatScalar | FloatLike diff --git a/tests/unit/dynamics/test_core.py b/tests/unit/dynamics/test_core.py index 5dbcd2e2..8976083f 100644 --- a/tests/unit/dynamics/test_core.py +++ b/tests/unit/dynamics/test_core.py @@ -12,19 +12,62 @@ class TestPhaseSpacePosition: """Test :class:`~galax.dynamics.PhaseSpacePosition`.""" + def test_slice(self) -> None: + """Test slicing.""" + _, *_subkeys = random.split(random.PRNGKey(0), num=9) + subkeys = iter(_subkeys) + + # Simple + x = random.uniform(next(subkeys), shape=(10, 3)) + v = random.uniform(next(subkeys), shape=(10, 3)) + o = PhaseSpacePosition(x, v) + new_o = o[:5] + assert new_o.shape == (5,) + + # 1d slice on 3d + x = random.uniform(next(subkeys), shape=(10, 8, 3)) + v = random.uniform(next(subkeys), shape=(10, 8, 3)) + o = PhaseSpacePosition(x, v) + new_o = o[:5] + assert new_o.shape == (5, 8) + + # 3d slice on 3d + o = PhaseSpacePosition(x, v) + new_o = o[:5, :4] + assert new_o.shape == (5, 4) + + # Boolean array + x = random.uniform(next(subkeys), shape=(10, 3)) + v = random.uniform(next(subkeys), shape=(10, 3)) + o = PhaseSpacePosition(x, v) + ix = xp.asarray([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]).astype(bool) + new_o = o[ix] + assert new_o.shape == (sum(ix),) + + # Integer array + x = random.uniform(next(subkeys), shape=(10, 3)) + v = random.uniform(next(subkeys), shape=(10, 3)) + o = PhaseSpacePosition(x, v) + ix = xp.asarray([0, 3, 5]) + new_o = o[ix] + assert new_o.shape == (len(ix),) + + # ------------------------------------------------------------------------ + def test_len(self) -> None: """Test length.""" - _, *subkeys = random.split(random.PRNGKey(0), num=5) + _, *_subkeys = random.split(random.PRNGKey(0), num=5) + subkeys = iter(_subkeys) # Simple - q = random.uniform(subkeys[0], shape=(10, 3)) - p = random.uniform(subkeys[1], shape=(10, 3)) + q = random.uniform(next(subkeys), shape=(10, 3)) + p = random.uniform(next(subkeys), shape=(10, 3)) psp = PhaseSpacePosition(q, p) assert len(psp) == 10 # Complex shape - q = random.uniform(subkeys[2], shape=(4, 10, 3)) - p = random.uniform(subkeys[3], shape=(4, 10, 3)) + q = random.uniform(next(subkeys), shape=(4, 10, 3)) + p = random.uniform(next(subkeys), shape=(4, 10, 3)) psp = PhaseSpacePosition(q, p) assert len(psp) == 4 @@ -32,10 +75,11 @@ def test_len(self) -> None: def test_w(self) -> None: """Test :attr:`~galax.dynamics.PhaseSpacePosition.w`.""" - _, *subkeys = random.split(random.PRNGKey(0), num=3) + _, *_subkeys = random.split(random.PRNGKey(0), num=3) + subkeys = iter(_subkeys) - q = random.uniform(subkeys[0], shape=(10, 3)) - p = random.uniform(subkeys[1], shape=(10, 3)) + q = random.uniform(next(subkeys), shape=(10, 3)) + p = random.uniform(next(subkeys), shape=(10, 3)) psp = PhaseSpacePosition(q, p) # units = None @@ -50,11 +94,12 @@ def test_w(self) -> None: # `wt()` def test_wt_notime(self) -> None: - """Test :attr:`~galax.dynamics.core.PhaseSpacePosition.wt`.""" - _, *subkeys = random.split(random.PRNGKey(0), num=3) + """Test :attr:`~galax.dynamics.PhaseSpacePosition.wt`.""" + _, *_subkeys = random.split(random.PRNGKey(0), num=3) + subkeys = iter(_subkeys) - q = random.uniform(subkeys[0], shape=(10, 3)) - p = random.uniform(subkeys[1], shape=(10, 3)) + q = random.uniform(next(subkeys), shape=(10, 3)) + p = random.uniform(next(subkeys), shape=(10, 3)) psp = PhaseSpacePosition(q, p) # units = None @@ -67,12 +112,13 @@ def test_wt_notime(self) -> None: _ = psp.wt(units=galactic) def test_wt_time(self) -> None: - """Test :attr:`~galax.dynamics.core.AbstractPhaseSpacePositionBase.wt`.""" - _, *subkeys = random.split(random.PRNGKey(0), num=4) + """Test :attr:`~galax.dynamics.PhaseSpacePosition.wt`.""" + _, *_subkeys = random.split(random.PRNGKey(0), num=4) + subkeys = iter(_subkeys) - q = random.uniform(subkeys[0], shape=(10, 3)) - p = random.uniform(subkeys[1], shape=(10, 3)) - t = random.uniform(subkeys[2], shape=(10, 1)) + q = random.uniform(next(subkeys), shape=(10, 3)) + p = random.uniform(next(subkeys), shape=(10, 3)) + t = random.uniform(next(subkeys), shape=(10, 1)) psp = PhaseSpacePosition(q, p, t=t) # units = None