Skip to content

Commit

Permalink
feat: interpolated integration
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Apr 2, 2024
1 parent b99f209 commit 3b51686
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 64 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ messages_control.disable = [
"protected-access", # ruff SLF001
"unnecessary-ellipsis",
"unnecessary-lambda-assignment", # ruff E731
"unnecessary-pass", # handled by ruff
"wrong-import-position",
"wrong-import-order", # handled by ruff
]
Expand Down
27 changes: 17 additions & 10 deletions src/galax/coordinates/_psp/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,22 @@ class PhaseSpacePositionInterpolant(Protocol):
"""Protocol for interpolating phase-space positions."""

units: AbstractUnitSystem
"""The unit system for the interpolation."""

def __call__(self, t: gt.VecTime) -> gt.BatchVecTime6:
pass
def __call__(self, t: gt.QVecTime) -> PhaseSpacePosition:
"""Evaluate the interpolation.
Parameters
----------
t : Quantity[float, (time,), 'time']
The times at which to evaluate the interpolation.
Returns
-------
:class:`galax.coordinates.PhaseSpacePosition`
The interpolated phase-space positions.
"""
...


@final
Expand Down Expand Up @@ -53,7 +66,7 @@ class InterpolatedPhaseSpacePosition(AbstractPhaseSpacePosition):
as `q` and `p`.
"""

interpolation: PhaseSpacePositionInterpolant
interpolant: PhaseSpacePositionInterpolant
"""The interpolation function."""

def __post_init__(self) -> None:
Expand All @@ -65,13 +78,7 @@ def __post_init__(self) -> None:

def __call__(self, t: gt.BatchFloatQScalar) -> PhaseSpacePosition:
"""Call the interpolation."""
qp = self.interpolation(t)
units = self.interpolation.units
return PhaseSpacePosition(
q=Quantity(qp[..., 0:3], units["length"]),
p=Quantity(qp[..., 3:6], units["speed"]),
t=t,
)
return self.interpolant(t)

# ==========================================================================
# Array properties
Expand Down
9 changes: 5 additions & 4 deletions src/galax/dynamics/_dynamics/integrate/_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__all__ = ["Integrator"]

from typing import Any, Protocol, TypeAlias, runtime_checkable
from typing import Any, Literal, Protocol, TypeAlias, runtime_checkable

from unxt import AbstractUnitSystem

import galax.coordinates as gc
import galax.typing as gt
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.utils.dataclasses import _DataclassInstance

SaveT: TypeAlias = gt.BatchQVecTime | gt.QVecTime | gt.BatchVecTime | gt.VecTime
Expand Down Expand Up @@ -52,14 +52,15 @@ class Integrator(_DataclassInstance, Protocol):
def __call__(
self,
F: FCallable,
w0: AbstractPhaseSpacePosition | gt.BatchVec6,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
) -> PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Integrate.
Parameters
Expand Down
8 changes: 5 additions & 3 deletions src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
__all__ = ["AbstractIntegrator"]

import abc
from typing import Literal

import equinox as eqx

from unxt import AbstractUnitSystem

import galax.coordinates as gc
import galax.typing as gt
from ._api import FCallable
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition


class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc]
Expand All @@ -28,7 +29,7 @@ class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, mis
def __call__(
self,
F: FCallable,
w0: AbstractPhaseSpacePosition | gt.BatchVec6,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
Expand All @@ -37,7 +38,8 @@ def __call__(
) = None,
*,
units: AbstractUnitSystem,
) -> PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Run the integrator.
Parameters
Expand Down
Loading

0 comments on commit 3b51686

Please sign in to comment.