Skip to content

Commit

Permalink
feat: interpolated integration
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Apr 2, 2024
1 parent 5a3f7a7 commit 08aa512
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 125 deletions.
3 changes: 2 additions & 1 deletion src/galax/coordinates/_psp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@

from . import operator_compat # noqa: F401
from .base import AbstractPhaseSpacePosition
from .psp import InterpolatedPhaseSpacePosition, PhaseSpacePosition
from .interpolated import InterpolatedPhaseSpacePosition
from .psp import PhaseSpacePosition
85 changes: 85 additions & 0 deletions src/galax/coordinates/_psp/interpolated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["InterpolatedPhaseSpacePosition"]

from typing import Protocol, final, runtime_checkable

import equinox as eqx
import jax.numpy as jnp

from coordinax import Abstract3DVector, Abstract3DVectorDifferential
from unxt import AbstractUnitSystem, Quantity

import galax.typing as gt
from .base import AbstractPhaseSpacePosition
from .utils import _p_converter, _q_converter
from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape


@runtime_checkable
class Interpolation(Protocol):
"""Protocol for interpolating phase-space positions."""

units: AbstractUnitSystem

def __call__(self, t: gt.VecTime) -> gt.BatchVecTime6:
pass


@final
class InterpolatedPhaseSpacePosition(AbstractPhaseSpacePosition):
"""Interpolated phase-space position."""

q: Abstract3DVector = eqx.field(converter=_q_converter)
"""Positions, e.g Cartesian3DVector.
This is a 3-vector with a batch shape allowing for vector inputs.
"""

p: Abstract3DVectorDifferential = eqx.field(converter=_p_converter)
r"""Conjugate momenta, e.g. CartesianDifferential3D.
This is a 3-vector with a batch shape allowing for vector inputs.
"""

t: gt.BroadBatchFloatQScalar | gt.QVec1 = eqx.field(
converter=Quantity["time"].constructor
)
"""The time corresponding to the positions.
This is a Quantity with the same batch shape as the positions and
velocities. If `t` is a scalar it will be broadcast to the same batch shape
as `q` and `p`.
"""

interpolation: Interpolation
"""The interpolation function."""

def __post_init__(self) -> None:
"""Post-initialization."""
# Need to ensure t shape is correct. Can be Vec0.
if self.t.ndim in (0, 1):
t = expand_batch_dims(self.t, ndim=self.q.ndim - self.t.ndim)
object.__setattr__(self, "t", t)

def __call__(self, t: gt.BatchFloatQScalar) -> PhaseSpacePosition:
"""Call the interpolation."""
qp = self.interpolation(t)
units = self.interpolation.units
return PhaseSpacePosition(
q=Quantity(qp[..., 0:3], units["length"]),
p=Quantity(qp[..., 3:6], units["speed"]),
t=t,
)

# ==========================================================================
# Array properties

@property
def _shape_tuple(self) -> tuple[tuple[int, ...], ComponentShapeTuple]:
"""Batch, component shape."""
qbatch, qshape = vector_batched_shape(self.q)
pbatch, pshape = vector_batched_shape(self.p)
tbatch, _ = batched_shape(self.t, expect_ndim=0)
batch_shape = jnp.broadcast_shapes(qbatch, pbatch, tbatch)
return batch_shape, ComponentShapeTuple(q=qshape, p=pshape, t=1)
85 changes: 5 additions & 80 deletions src/galax/coordinates/_psp/psp.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
"""galax: Galactic Dynamix in Jax."""

__all__ = ["PhaseSpacePosition", "InterpolatedPhaseSpacePosition"]
__all__ = ["PhaseSpacePosition"]

from typing import Any, NamedTuple, Protocol, TypeAlias, final, runtime_checkable
from typing import Any, NamedTuple, final

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Shaped

from coordinax import Abstract3DVector, Abstract3DVectorDifferential
from unxt import AbstractUnitSystem, Quantity
from unxt import Quantity

import galax.typing as gt
from .base import AbstractPhaseSpacePosition
Expand Down Expand Up @@ -134,8 +133,8 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
def __post_init__(self) -> None:
"""Post-initialization."""
# Need to ensure t shape is correct. Can be Vec0.
if self.t is not None and self.t.ndim in (0, 1):
t = expand_batch_dims(self.t, ndim=self.q.ndim - self.t.ndim)
if (t := self.t) is not None and t.ndim in (0, 1):
t = expand_batch_dims(t, ndim=self.q.ndim - t.ndim)
object.__setattr__(self, "t", t)

# ==========================================================================
Expand Down Expand Up @@ -163,77 +162,3 @@ def wt(self, *, units: Any) -> gt.BatchVec7:
self.t, self.t is None, "No time defined for phase-space position"
)
return super().wt(units=units)


# ============================================================================

BatchVecTime6: TypeAlias = Shaped[gt.VecTime6, "*batch"]


@runtime_checkable
class Interpolation(Protocol):
"""Protocol for interpolating phase-space positions."""

units: AbstractUnitSystem

def __call__(self, t: gt.VecTime) -> BatchVecTime6:
pass


@final
class InterpolatedPhaseSpacePosition(AbstractPhaseSpacePosition):
"""Interpolated phase-space position."""

q: Abstract3DVector = eqx.field(converter=_q_converter)
"""Positions, e.g Cartesian3DVector.
This is a 3-vector with a batch shape allowing for vector inputs.
"""

p: Abstract3DVectorDifferential = eqx.field(converter=_p_converter)
r"""Conjugate momenta, e.g. CartesianDifferential3D.
This is a 3-vector with a batch shape allowing for vector inputs.
"""

t: gt.BroadBatchFloatQScalar | gt.QVec1 = eqx.field(
converter=Quantity["time"].constructor
)
"""The time corresponding to the positions.
This is a Quantity with the same batch shape as the positions and
velocities. If `t` is a scalar it will be broadcast to the same batch shape
as `q` and `p`.
"""

interpolation: Interpolation
"""The interpolation function."""

def __post_init__(self) -> None:
"""Post-initialization."""
# Need to ensure t shape is correct. Can be Vec0.
if self.t.ndim in (0, 1):
t = expand_batch_dims(self.t, ndim=self.q.ndim - self.t.ndim)
object.__setattr__(self, "t", t)

def __call__(self, t: gt.BatchFloatQScalar) -> PhaseSpacePosition:
"""Call the interpolation."""
qp = self.interpolation(t)
units = self.interpolation.units
return PhaseSpacePosition(
q=Quantity(qp[..., 0:3], units["length"]),
p=Quantity(qp[..., 3:6], units["speed"]),
t=t,
)

# ==========================================================================
# Array properties

@property
def _shape_tuple(self) -> tuple[tuple[int, ...], ComponentShapeTuple]:
"""Batch, component shape."""
qbatch, qshape = vector_batched_shape(self.q)
pbatch, pshape = vector_batched_shape(self.p)
tbatch, _ = batched_shape(self.t, expect_ndim=0)
batch_shape = jnp.broadcast_shapes(qbatch, pbatch, tbatch)
return batch_shape, ComponentShapeTuple(q=qshape, p=pshape, t=1)
9 changes: 5 additions & 4 deletions src/galax/dynamics/_dynamics/integrate/_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__all__ = ["Integrator"]

from typing import Any, Protocol, TypeAlias, runtime_checkable
from typing import Any, Literal, Protocol, TypeAlias, runtime_checkable

from unxt import AbstractUnitSystem

import galax.coordinates as gc
import galax.typing as gt
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.utils.dataclasses import _DataclassInstance

SaveT: TypeAlias = gt.BatchQVecTime | gt.QVecTime | gt.BatchVecTime | gt.VecTime
Expand Down Expand Up @@ -52,14 +52,15 @@ class Integrator(_DataclassInstance, Protocol):
def __call__(
self,
F: FCallable,
w0: AbstractPhaseSpacePosition | gt.BatchVec6,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
) -> PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Integrate.
Parameters
Expand Down
8 changes: 5 additions & 3 deletions src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
__all__ = ["AbstractIntegrator"]

import abc
from typing import Literal

import equinox as eqx

from unxt import AbstractUnitSystem

import galax.coordinates as gc
import galax.typing as gt
from ._api import FCallable
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition


class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc]
Expand All @@ -28,7 +29,7 @@ class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, mis
def __call__(
self,
F: FCallable,
w0: AbstractPhaseSpacePosition | gt.BatchVec6,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
Expand All @@ -37,7 +38,8 @@ def __call__(
) = None,
*,
units: AbstractUnitSystem,
) -> PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Run the integrator.
Parameters
Expand Down
Loading

0 comments on commit 08aa512

Please sign in to comment.