From 758bf2951d82fa54e3693099254460207d6570ce Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 13 Mar 2024 15:26:01 -0400 Subject: [PATCH] feat: interpolated integration Signed-off-by: nstarman --- .../dynamics/_dynamics/integrate/_api.py | 5 +- .../dynamics/_dynamics/integrate/_base.py | 4 +- .../dynamics/_dynamics/integrate/_builtin.py | 169 ++++++++++++++++-- .../dynamics/_dynamics/integrate/_funcs.py | 32 +++- src/galax/potential/_potential/base.py | 15 +- 5 files changed, 199 insertions(+), 26 deletions(-) diff --git a/src/galax/dynamics/_dynamics/integrate/_api.py b/src/galax/dynamics/_dynamics/integrate/_api.py index b233cc32..4b1e29fe 100644 --- a/src/galax/dynamics/_dynamics/integrate/_api.py +++ b/src/galax/dynamics/_dynamics/integrate/_api.py @@ -1,6 +1,6 @@ __all__ = ["Integrator"] -from typing import Any, Protocol, TypeAlias, runtime_checkable +from typing import Any, Literal, Protocol, TypeAlias, runtime_checkable from unxt import AbstractUnitSystem @@ -59,7 +59,8 @@ def __call__( savet: SaveT | None = None, *, units: AbstractUnitSystem, - ) -> gc.PhaseSpacePosition: + interpolated: Literal[False, True] = False, + ) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition: """Integrate. Parameters diff --git a/src/galax/dynamics/_dynamics/integrate/_base.py b/src/galax/dynamics/_dynamics/integrate/_base.py index ce54479b..102cc66e 100644 --- a/src/galax/dynamics/_dynamics/integrate/_base.py +++ b/src/galax/dynamics/_dynamics/integrate/_base.py @@ -1,6 +1,7 @@ __all__ = ["AbstractIntegrator"] import abc +from typing import Literal import equinox as eqx @@ -37,7 +38,8 @@ def __call__( ) = None, *, units: AbstractUnitSystem, - ) -> gc.PhaseSpacePosition: + interpolated: Literal[False, True] = False, + ) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition: """Run the integrator. Parameters diff --git a/src/galax/dynamics/_dynamics/integrate/_builtin.py b/src/galax/dynamics/_dynamics/integrate/_builtin.py index 8d2a7929..2f506d9d 100644 --- a/src/galax/dynamics/_dynamics/integrate/_builtin.py +++ b/src/galax/dynamics/_dynamics/integrate/_builtin.py @@ -4,16 +4,18 @@ from collections.abc import Callable, Mapping from dataclasses import KW_ONLY from functools import partial -from typing import Any, ParamSpec, TypeVar, final +from typing import Any, Literal, ParamSpec, TypeVar, final import diffrax import equinox as eqx import jax import jax.numpy as jnp +from diffrax import DenseInterpolation from jax._src.numpy.vectorize import _parse_gufunc_signature, _parse_input_dimensions +from plum import overload import quaxed.array_api as xp -from unxt import AbstractUnitSystem, Quantity, to_units_value +from unxt import AbstractUnitSystem, Quantity, to_units_value, unitsystem import galax.coordinates as gc import galax.typing as gt @@ -117,7 +119,10 @@ class DiffraxIntegrator(AbstractIntegrator): default=(("scan_kind", "bounded"),), static=True, converter=ImmutableDict ) - @partial(jax.jit, static_argnums=(0, 1)) + # ===================================================== + # Call + + @partial(eqx.filter_jit) def _call_implementation( self, F: FCallable, @@ -126,9 +131,12 @@ def _call_implementation( t1: gt.FloatScalar, ts: gt.BatchVecTime, /, - ) -> tuple[gt.BatchVecTime7, None]: + interpolated: Literal[False, True], + ) -> tuple[gt.BatchVecTime7, DenseInterpolation | None]: # TODO: less awkward munging of the diffrax API kw = dict(self.diffeq_kw) + if interpolated and kw.get("max_steps") is None: + kw.pop("max_steps") terms = diffrax.ODETerm(F) solver = self.Solver(**self.solver_kw) @@ -146,13 +154,13 @@ def solve_diffeq( y0=w0, dt0=None, args=(), - saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=False), + saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=interpolated), stepsize_controller=self.stepsize_controller, **kw, ) # Perform the integration - solution = solve_diffeq(w0, t0, t1, ts) + solution = solve_diffeq(w0, t0, t1, jnp.atleast_2d(ts)) # Parse the solution w = jnp.concat((solution.ys, solution.ts[..., None]), axis=-1) @@ -161,6 +169,21 @@ def solve_diffeq( return w, interp + @overload + def __call__( + self, + F: FCallable, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, + t0: gt.FloatQScalar | gt.FloatScalar, + t1: gt.FloatQScalar | gt.FloatScalar, + /, + savet: SaveT | None = None, + *, + units: AbstractUnitSystem, + interpolated: Literal[False] = False, + ) -> gc.PhaseSpacePosition: ... + + @overload def __call__( self, F: FCallable, @@ -171,7 +194,21 @@ def __call__( savet: SaveT | None = None, *, units: AbstractUnitSystem, - ) -> gc.PhaseSpacePosition: + interpolated: Literal[True], + ) -> gc.InterpolatedPhaseSpacePosition: ... + + def __call__( + self, + F: FCallable, + w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6, + t0: gt.FloatQScalar | gt.FloatScalar, + t1: gt.FloatQScalar | gt.FloatScalar, + /, + savet: SaveT | None = None, + *, + units: AbstractUnitSystem, + interpolated: Literal[False, True] = False, + ) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition: """Run the integrator. Parameters @@ -189,6 +226,8 @@ def __call__( units : `unxt.AbstractUnitSystem` The unit system to use. + interpolated : bool, keyword-only + Whether to return an interpolated solution. Returns ------- @@ -261,6 +300,49 @@ def __call__( >>> ws.shape (2,) + A cool feature of the integrator is that it can return an interpolated + solution. + + >>> w = integrator(pot._integrator_F, w0, t0, t1, savet=ts, units=usx.galactic, + ... interpolated=True) + >>> type(w) + + + The interpolated solution can be evaluated at any time in the domain to get + the phase-space position at that time: + + >>> t = Quantity(xp.e, "Gyr") + >>> w(t) + PhaseSpacePosition( + q=Cartesian3DVector( ... ), + p=CartesianDifferential3D( ... ), + t=Quantity[PhysicalType('time')](value=f64[1,1], unit=Unit("Gyr")) + ) + + The interpolant is vectorized: + + >>> t = Quantity(xp.linspace(0, 1, 100), "Gyr") + >>> w(t) + PhaseSpacePosition( + q=Cartesian3DVector( ... ), + p=CartesianDifferential3D( ... ), + t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr")) + ) + + And it works on batches: + + >>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [11., 0, 0]], "kpc"), + ... p=Quantity([[0, 200, 0], [0, 210, 0]], "km/s")) + >>> ws = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic, + ... interpolated=True) + >>> ws.shape + (2,) + >>> w(t) + PhaseSpacePosition( + q=Cartesian3DVector( ... ), + p=CartesianDifferential3D( ... ), + t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr")) + ) """ # Parse inputs w0_: gt.Vec6 = ( @@ -273,12 +355,73 @@ def __call__( ) # Perform the integration - w, interp = self._call_implementation(F, w0_, t0_, t1_, savet_) - w = w[..., -1, :] if savet is None else w + w, interp = self._call_implementation(F, w0_, t0_, t1_, savet_, interpolated) + w = w[..., -1, :] if savet is None else w # TODO: undo this # Return - return gc.PhaseSpacePosition( # shape = (*batch, T) - q=Quantity(w[..., 0:3], units["length"]), - p=Quantity(w[..., 3:6], units["speed"]), - t=Quantity(w[..., -1], units["time"]), + if interpolated: + # Determine if an extra dimension was added to the output + added_ndim = int(w0_.shape[:-1] == () or w0_.shape[0] == 1) + # If one was, then the interpolant must be reshaped since the input + # was squeezed beforehand and the dimension must be added back. + if added_ndim == 1: + arr, narr = eqx.partition(interp, eqx.is_array) + arr = jax.tree_util.tree_map(lambda x: x[None], arr) + interp = eqx.combine(arr, narr) + + out = gc.InterpolatedPhaseSpacePosition( # shape = (*batch, T) + q=Quantity(w[..., 0:3], units["length"]), + p=Quantity(w[..., 3:6], units["speed"]), + t=Quantity(savet_, units["time"]), + interpolant=DiffraxInterpolant( + interp, units=units, added_ndim=added_ndim + ), + ) + else: + out = gc.PhaseSpacePosition( # shape = (*batch, T) + q=Quantity(w[..., 0:3], units["length"]), + p=Quantity(w[..., 3:6], units["speed"]), + t=Quantity(w[..., -1], units["time"]), + ) + + return out + + +class DiffraxInterpolant(eqx.Module): # type: ignore[misc]# + """Wrapper for ``diffrax.DenseInterpolation``.""" + + interpolant: DenseInterpolation + """:class:`diffrax.DenseInterpolation` object. + + This object is the result of the integration and can be used to evaluate the + interpolated solution at any time. However it does not understand units, so + the input is the time in ``units["time"]``. The output is a 6-vector of + (q, p) values in the units of the integrator. + """ + + units: AbstractUnitSystem = eqx.field(static=True, converter=unitsystem) + """The :class:`unxt.AbstractUnitSystem`. + + This is used to convert the time input to the interpolant and the phase-space + position output. + """ + + added_ndim: tuple[int, ...] = eqx.field(static=True) + """The number of dimensions added to the output of the interpolation. + + This is used to reshape the output of the interpolation to match the batch + shape of the input to the integrator. The means of vectorizing the + interpolation means that the input must always be a batched array, resulting + in an extra dimension when the integration was on a scalar input. + """ + + def __call__(self, t: gt.QVecTime, **_: Any) -> gc.PhaseSpacePosition: + """Evaluate the interpolation.""" + t_ = jnp.atleast_1d(t.to_units_value(self.units["time"])) + ys = jax.vmap(lambda s: jax.vmap(s.evaluate)(t_))(self.interpolant) + ys = ys[(0,) * (ys.ndim - 3 + self.added_ndim)] + return gc.PhaseSpacePosition( + q=Quantity(ys[..., 0:3], self.units["length"]), + p=Quantity(ys[..., 3:6], self.units["speed"]), + t=t, ) diff --git a/src/galax/dynamics/_dynamics/integrate/_funcs.py b/src/galax/dynamics/_dynamics/integrate/_funcs.py index 7721dddf..abc2dddb 100644 --- a/src/galax/dynamics/_dynamics/integrate/_funcs.py +++ b/src/galax/dynamics/_dynamics/integrate/_funcs.py @@ -4,6 +4,7 @@ from dataclasses import replace from functools import partial +from typing import Literal import jax import jax.numpy as jnp @@ -16,7 +17,7 @@ from ._api import Integrator from ._builtin import DiffraxIntegrator from galax.coordinates import PhaseSpacePosition -from galax.dynamics._dynamics.orbit import Orbit +from galax.dynamics._dynamics.orbit import InterpolatedOrbit, Orbit from galax.potential._potential.base import AbstractPotentialBase ############################################################################## @@ -29,20 +30,21 @@ _select_w0 = jnp.vectorize(jax.lax.select, signature="(),(6),(6)->(6)") -@partial(jax.jit, static_argnames=("integrator",)) +@partial(jax.jit, static_argnames=("integrator", "interpolated")) def evaluate_orbit( pot: AbstractPotentialBase, w0: PhaseSpacePosition | gt.BatchVec6, t: gt.QVecTime | gt.VecTime | APYQuantity, *, integrator: Integrator | None = None, -) -> Orbit: + interpolated: Literal[True, False] = False, +) -> Orbit | InterpolatedOrbit: """Compute an orbit in a potential. - :class:`~galax.coordinates.PhaseSpacePosition` includes a time in - addition to the position (and velocity) information, enabling the orbit to - be evaluated over a time range that is different from the initial time of - the position. + :class:`~galax.coordinates.PhaseSpacePosition` includes a time in addition + to the position (and velocity) information, enabling the orbit to be + evaluated over a time range that is different from the initial time of the + position. Parameters ---------- @@ -82,6 +84,10 @@ def evaluate_orbit( is used twice: once to integrate from `w0.t` to `t[0]` and then from `t[0]` to `t[1]`. + interpolated: bool, optional keyword-only + If `True`, return an interpolated orbit. If `False`, return the orbit + at the requested times. Default is `False`. + Returns ------- orbit : :class:`~galax.dynamics.Orbit` @@ -225,7 +231,17 @@ def evaluate_orbit( t[-1], savet=t, units=units, + interpolated=interpolated, ) + wt = t # Construct the orbit object - return Orbit(q=ws.q, p=ws.p, t=t, potential=pot) + # TODO: easier construction from the (Interpolated)PhaseSpacePosition + if interpolated: + out = InterpolatedOrbit( + q=ws.q, p=ws.p, t=wt, interpolant=ws.interpolant, potential=pot + ) + else: + out = Orbit(q=ws.q, p=ws.p, t=wt, potential=pot) + + return out diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index 2075f255..14710141 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -4,7 +4,7 @@ from dataclasses import KW_ONLY, fields from functools import partial from types import MappingProxyType -from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, cast import equinox as eqx import jax @@ -1896,6 +1896,7 @@ def evaluate_orbit( t: gt.QVecTime | gt.VecTime | APYQuantity, # TODO: must be a Quantity *, integrator: "Integrator | None" = None, + interpolated: Literal[True, False] = False, ) -> "Orbit": """Compute an orbit in a potential. @@ -1943,6 +1944,11 @@ def evaluate_orbit( Integrator to use. If `None`, the default integrator :class:`~galax.integrator.DiffraxIntegrator` is used. + interpolated: bool, optional keyword-only + If `True`, return an interpolated orbit. If `False`, return the orbit + at the requested times. Default is `False`. + + Returns ------- orbit : :class:`~galax.dynamics.Orbit` @@ -1956,4 +1962,9 @@ def evaluate_orbit( """ from galax.dynamics import evaluate_orbit - return cast("Orbit", evaluate_orbit(self, w0, t, integrator=integrator)) + return cast( + "Orbit", + evaluate_orbit( + self, w0, t, integrator=integrator, interpolated=interpolated + ), + )