Skip to content

Commit

Permalink
feat(PhaseSpacePosition): add slicing via getitem
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 29, 2024
1 parent e86e23e commit c30a7e7
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 40 deletions.
15 changes: 14 additions & 1 deletion src/galax/dynamics/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
__all__ = ["AbstractPhaseSpacePosition"]

from abc import abstractmethod
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import equinox as eqx
import jax
Expand All @@ -16,7 +17,11 @@
from galax.units import UnitSystem
from galax.utils._shape import atleast_batched

from ._utils import getitem_time_index

if TYPE_CHECKING:
from typing import Self

from galax.potential._potential.base import AbstractPotentialBase


Expand Down Expand Up @@ -57,6 +62,14 @@ def __len__(self) -> int:
"""Return the number of particles."""
return self.shape[0]

def __getitem__(self, index: Any) -> "Self":
"""Return a new object with the given slice applied."""
# This is the default implementation, but subclasses can override this
# Compute subindex
subindex = getitem_time_index(index, self.t)
# Apply slice
return replace(self, q=self.q[index], p=self.p[index], t=self.t[subindex])

# ==========================================================================

@property
Expand Down
6 changes: 3 additions & 3 deletions src/galax/dynamics/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax
import jax.numpy as jnp

from galax.typing import BatchFloatScalar, BroadBatchVec1, BroadBatchVec3, Vec1
from galax.typing import BatchFloatScalar, BroadBatchFloatScalar, BroadBatchVec3, Vec1
from galax.utils._shape import batched_shape, expand_batch_dims
from galax.utils.dataclasses import converter_float_array

Expand All @@ -35,7 +35,7 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
This is a 3-vector with a batch shape allowing for vector inputs.
"""

t: BroadBatchVec1 | Vec1 = eqx.field(
t: BroadBatchFloatScalar | Vec1 = eqx.field(
default=(0.0,), converter=converter_float_array
)
"""The time corresponding to the positions.
Expand All @@ -47,7 +47,7 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):

def __post_init__(self) -> None:
"""Post-initialization."""
# Need to ensure t shape is correct
# Need to ensure t shape is correct. Can be Vec0.
if self.t.ndim == 0:
t = expand_batch_dims(self.t, ndim=self.q.ndim)
object.__setattr__(self, "t", t)
Expand Down
12 changes: 6 additions & 6 deletions src/galax/dynamics/_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchFloatScalar, BatchVecTime
from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array

Expand All @@ -25,13 +24,14 @@ class Orbit(AbstractPhaseSpacePosition):
"""

q: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array)
q: BroadBatchVec3 = eqx.field(converter=converter_float_array)
"""Positions (x, y, z)."""

p: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array)
p: BroadBatchVec3 = eqx.field(converter=converter_float_array)
r"""Conjugate momenta ($v_x$, $v_y$, $v_z$)."""

t: BatchVecTime = eqx.field(converter=converter_float_array)
# TODO: consider how this should be vectorized
t: VecTime = eqx.field(converter=converter_float_array)
"""Array of times corresponding to the positions."""

potential: AbstractPotentialBase
Expand All @@ -45,7 +45,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
"""Batch, component shape."""
qbatch, qshape = batched_shape(self.q, expect_ndim=1)
pbatch, pshape = batched_shape(self.p, expect_ndim=1)
tbatch, _ = batched_shape(self.t, expect_ndim=0)
tbatch, _ = batched_shape(self.t, expect_ndim=1)
batch_shape = jnp.broadcast_shapes(qbatch, pbatch, tbatch)
array_shape = qshape + pshape + (1,)
return batch_shape, array_shape
Expand Down
46 changes: 46 additions & 0 deletions src/galax/dynamics/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""galax: Galactic Dynamix in Jax."""

__all__: list[str] = []

from typing import Any, Protocol, cast, runtime_checkable

import jax.experimental.array_api as xp
from jaxtyping import Array, Float


@runtime_checkable
class Shaped(Protocol):
"""Protocol for a shaped object."""

shape: tuple[int, ...]


def _getitem_time_index_tuple(index: tuple[Any, ...], t: Float[Array, "..."]) -> Any:
"""Get the time index from a slice."""
if len(index) == 0: # slice is an empty tuple
return slice(None)
if t.ndim == 1: # slicing a Vec1
return slice(None)
if len(index) >= t.ndim:
msg = f"Index {index} has too many dimensions for time array of shape {t.shape}"
raise IndexError(msg)
return index


def _getitem_time_index_shaped(index: Shaped, t: Float[Array, "..."]) -> Shaped:
"""Get the time index from a slice."""
if t.ndim == 1: # Vec1
return cast(Shaped, xp.asarray([True]))
if len(index.shape) >= t.ndim:
msg = f"Index {index} has too many dimensions for time array of shape {t.shape}"
raise IndexError(msg)
return index


def getitem_time_index(index: Any, t: Float[Array, "..."]) -> Any:
"""Get the time index from an index."""
if isinstance(index, tuple):
return _getitem_time_index_tuple(index, t)
if isinstance(index, Shaped):
return _getitem_time_index_shaped(index, t)
return index
17 changes: 6 additions & 11 deletions src/galax/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from galax.dynamics._base import AbstractPhaseSpacePosition
from galax.typing import BatchFloatScalar, VecTime
from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array

Expand All @@ -32,19 +31,12 @@ class MockStream(AbstractPhaseSpacePosition):
Array of times corresponding to the positions.
release_time : Array[float, (*batch,)]
Release time of the stream particles [Myr].
Todo:
----
- units stuff
- change this to be a collection of sub-objects: progenitor, leading arm,
trailing arm, 3-body ejecta, etc.
- GR 4-vector stuff
"""

q: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array)
q: BroadBatchVec3 = eqx.field(converter=converter_float_array)
"""Positions (x, y, z)."""

p: Float[Array, "*batch time 3"] = eqx.field(converter=converter_float_array)
p: BroadBatchVec3 = eqx.field(converter=converter_float_array)
r"""Conjugate momenta (v_x, v_y, v_z)."""

t: VecTime = eqx.field(converter=converter_float_array)
Expand All @@ -53,6 +45,9 @@ class MockStream(AbstractPhaseSpacePosition):
release_time: VecTime = eqx.field(converter=converter_float_array)
"""Release time of the stream particles [Myr]."""

# ==========================================================================
# Array properties

@property
def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
"""Batch ."""
Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def integrate_orbit(
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[2,10,3], p=f64[2,10,3], t=f64[2,10], potential=KeplerPotential(...)
q=f64[2,10,3], p=f64[2,10,3], t=f64[10], potential=KeplerPotential(...)
)
"""
# TODO: ꜛ get NORMALIZE_WHITESPACE to work correctly so Orbit is 1 line
Expand All @@ -421,4 +421,4 @@ def integrate_orbit(

ws = integrator_(self._integrator_F, w0, t)
# TODO: ꜛ reduce repeat dimensions of `time`.
return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=ws[..., -1], potential=self)
return Orbit(q=ws[..., 0:3], p=ws[..., 3:6], t=t, potential=self)
1 change: 1 addition & 0 deletions src/galax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@

BatchableIntLike = BatchIntScalar | IntLike

BroadBatchFloatScalar = Shaped[FloatScalar, "*#batch"]
BatchFloatScalar = Shaped[FloatScalar, "*batch"]

BatchableFloatLike = BatchFloatScalar | FloatLike
Expand Down
80 changes: 63 additions & 17 deletions tests/unit/dynamics/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,74 @@
class TestPhaseSpacePosition:
"""Test :class:`~galax.dynamics.PhaseSpacePosition`."""

def test_slice(self) -> None:
"""Test slicing."""
_, *_subkeys = random.split(random.PRNGKey(0), num=9)
subkeys = iter(_subkeys)

# Simple
x = random.uniform(next(subkeys), shape=(10, 3))
v = random.uniform(next(subkeys), shape=(10, 3))
o = PhaseSpacePosition(x, v)
new_o = o[:5]
assert new_o.shape == (5,)

# 1d slice on 3d
x = random.uniform(next(subkeys), shape=(10, 8, 3))
v = random.uniform(next(subkeys), shape=(10, 8, 3))
o = PhaseSpacePosition(x, v)
new_o = o[:5]
assert new_o.shape == (5, 8)

# 3d slice on 3d
o = PhaseSpacePosition(x, v)
new_o = o[:5, :4]
assert new_o.shape == (5, 4)

# Boolean array
x = random.uniform(next(subkeys), shape=(10, 3))
v = random.uniform(next(subkeys), shape=(10, 3))
o = PhaseSpacePosition(x, v)
ix = xp.asarray([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]).astype(bool)
new_o = o[ix]
assert new_o.shape == (sum(ix),)

# Integer array
x = random.uniform(next(subkeys), shape=(10, 3))
v = random.uniform(next(subkeys), shape=(10, 3))
o = PhaseSpacePosition(x, v)
ix = xp.asarray([0, 3, 5])
new_o = o[ix]
assert new_o.shape == (len(ix),)

# ------------------------------------------------------------------------

def test_len(self) -> None:
"""Test length."""
_, *subkeys = random.split(random.PRNGKey(0), num=5)
_, *_subkeys = random.split(random.PRNGKey(0), num=5)
subkeys = iter(_subkeys)

# Simple
q = random.uniform(subkeys[0], shape=(10, 3))
p = random.uniform(subkeys[1], shape=(10, 3))
q = random.uniform(next(subkeys), shape=(10, 3))
p = random.uniform(next(subkeys), shape=(10, 3))
psp = PhaseSpacePosition(q, p)
assert len(psp) == 10

# Complex shape
q = random.uniform(subkeys[2], shape=(4, 10, 3))
p = random.uniform(subkeys[3], shape=(4, 10, 3))
q = random.uniform(next(subkeys), shape=(4, 10, 3))
p = random.uniform(next(subkeys), shape=(4, 10, 3))
psp = PhaseSpacePosition(q, p)
assert len(psp) == 4

# ------------------------------------------------------------------------

def test_w(self) -> None:
"""Test :attr:`~galax.dynamics.PhaseSpacePosition.w`."""
_, *subkeys = random.split(random.PRNGKey(0), num=3)
_, *_subkeys = random.split(random.PRNGKey(0), num=3)
subkeys = iter(_subkeys)

q = random.uniform(subkeys[0], shape=(10, 3))
p = random.uniform(subkeys[1], shape=(10, 3))
q = random.uniform(next(subkeys), shape=(10, 3))
p = random.uniform(next(subkeys), shape=(10, 3))
psp = PhaseSpacePosition(q, p)

# units = None
Expand All @@ -50,11 +94,12 @@ def test_w(self) -> None:
# `wt()`

def test_wt_notime(self) -> None:
"""Test :attr:`~galax.dynamics.core.PhaseSpacePosition.wt`."""
_, *subkeys = random.split(random.PRNGKey(0), num=3)
"""Test :attr:`~galax.dynamics.PhaseSpacePosition.wt`."""
_, *_subkeys = random.split(random.PRNGKey(0), num=3)
subkeys = iter(_subkeys)

q = random.uniform(subkeys[0], shape=(10, 3))
p = random.uniform(subkeys[1], shape=(10, 3))
q = random.uniform(next(subkeys), shape=(10, 3))
p = random.uniform(next(subkeys), shape=(10, 3))
psp = PhaseSpacePosition(q, p)

# units = None
Expand All @@ -67,12 +112,13 @@ def test_wt_notime(self) -> None:
_ = psp.wt(units=galactic)

def test_wt_time(self) -> None:
"""Test :attr:`~galax.dynamics.core.AbstractPhaseSpacePositionBase.wt`."""
_, *subkeys = random.split(random.PRNGKey(0), num=4)
"""Test :attr:`~galax.dynamics.PhaseSpacePosition.wt`."""
_, *_subkeys = random.split(random.PRNGKey(0), num=4)
subkeys = iter(_subkeys)

q = random.uniform(subkeys[0], shape=(10, 3))
p = random.uniform(subkeys[1], shape=(10, 3))
t = random.uniform(subkeys[2], shape=(10, 1))
q = random.uniform(next(subkeys), shape=(10, 3))
p = random.uniform(next(subkeys), shape=(10, 3))
t = random.uniform(next(subkeys), shape=(10, 1))
psp = PhaseSpacePosition(q, p, t=t)

# units = None
Expand Down

0 comments on commit c30a7e7

Please sign in to comment.