Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Orbit interpolation #212

Merged
merged 1 commit into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/galax/dynamics/_dynamics/integrate/_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ["Integrator"]

from typing import Any, Protocol, TypeAlias, runtime_checkable
from typing import Any, Literal, Protocol, TypeAlias, runtime_checkable

from unxt import AbstractUnitSystem

Expand Down Expand Up @@ -59,7 +59,8 @@ def __call__(
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
) -> gc.PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Integrate.

Parameters
Expand Down
4 changes: 3 additions & 1 deletion src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = ["AbstractIntegrator"]

import abc
from typing import Literal

import equinox as eqx

Expand Down Expand Up @@ -37,7 +38,8 @@ def __call__(
) = None,
*,
units: AbstractUnitSystem,
) -> gc.PhaseSpacePosition:
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Run the integrator.

Parameters
Expand Down
169 changes: 156 additions & 13 deletions src/galax/dynamics/_dynamics/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
from collections.abc import Callable, Mapping
from dataclasses import KW_ONLY
from functools import partial
from typing import Any, ParamSpec, TypeVar, final
from typing import Any, Literal, ParamSpec, TypeVar, final

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
from diffrax import DenseInterpolation
from jax._src.numpy.vectorize import _parse_gufunc_signature, _parse_input_dimensions
from plum import overload

import quaxed.array_api as xp
from unxt import AbstractUnitSystem, Quantity, to_units_value
from unxt import AbstractUnitSystem, Quantity, to_units_value, unitsystem

import galax.coordinates as gc
import galax.typing as gt
Expand Down Expand Up @@ -117,7 +119,10 @@
default=(("scan_kind", "bounded"),), static=True, converter=ImmutableDict
)

@partial(jax.jit, static_argnums=(0, 1))
# =====================================================
# Call

@partial(eqx.filter_jit)
def _call_implementation(
self,
F: FCallable,
Expand All @@ -126,9 +131,12 @@
t1: gt.FloatScalar,
ts: gt.BatchVecTime,
/,
) -> tuple[gt.BatchVecTime7, None]:
interpolated: Literal[False, True],
) -> tuple[gt.BatchVecTime7, DenseInterpolation | None]:
# TODO: less awkward munging of the diffrax API
kw = dict(self.diffeq_kw)
if interpolated and kw.get("max_steps") is None:
kw.pop("max_steps")

terms = diffrax.ODETerm(F)
solver = self.Solver(**self.solver_kw)
Expand All @@ -146,13 +154,13 @@
y0=w0,
dt0=None,
args=(),
saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=False),
saveat=diffrax.SaveAt(t0=False, t1=False, ts=ts, dense=interpolated),
stepsize_controller=self.stepsize_controller,
**kw,
)

# Perform the integration
solution = solve_diffeq(w0, t0, t1, ts)
solution = solve_diffeq(w0, t0, t1, jnp.atleast_2d(ts))

# Parse the solution
w = jnp.concat((solution.ys, solution.ts[..., None]), axis=-1)
Expand All @@ -161,6 +169,21 @@

return w, interp

@overload
def __call__(
self,
F: FCallable,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
interpolated: Literal[False] = False,
) -> gc.PhaseSpacePosition: ...

@overload
def __call__(
self,
F: FCallable,
Expand All @@ -171,7 +194,21 @@
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
) -> gc.PhaseSpacePosition:
interpolated: Literal[True],
) -> gc.InterpolatedPhaseSpacePosition: ...

def __call__(
self,
F: FCallable,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
interpolated: Literal[False, True] = False,
) -> gc.PhaseSpacePosition | gc.InterpolatedPhaseSpacePosition:
"""Run the integrator.

Parameters
Expand All @@ -189,6 +226,8 @@

units : `unxt.AbstractUnitSystem`
The unit system to use.
interpolated : bool, keyword-only
Whether to return an interpolated solution.

Returns
-------
Expand Down Expand Up @@ -261,6 +300,49 @@
>>> ws.shape
(2,)

A cool feature of the integrator is that it can return an interpolated
solution.

>>> w = integrator(pot._integrator_F, w0, t0, t1, savet=ts, units=usx.galactic,
... interpolated=True)
>>> type(w)
<class 'galax.coordinates...InterpolatedPhaseSpacePosition'>

The interpolated solution can be evaluated at any time in the domain to get
the phase-space position at that time:

>>> t = Quantity(xp.e, "Gyr")
>>> w(t)
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[PhysicalType('time')](value=f64[1,1], unit=Unit("Gyr"))
)

The interpolant is vectorized:

>>> t = Quantity(xp.linspace(0, 1, 100), "Gyr")
>>> w(t)
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr"))
)

And it works on batches:

>>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [11., 0, 0]], "kpc"),
... p=Quantity([[0, 200, 0], [0, 210, 0]], "km/s"))
>>> ws = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic,
... interpolated=True)
>>> ws.shape
(2,)
>>> w(t)
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[PhysicalType('time')](value=f64[1,100], unit=Unit("Gyr"))
)
"""
# Parse inputs
w0_: gt.Vec6 = (
Expand All @@ -273,12 +355,73 @@
)

# Perform the integration
w, interp = self._call_implementation(F, w0_, t0_, t1_, savet_)
w = w[..., -1, :] if savet is None else w
w, interp = self._call_implementation(F, w0_, t0_, t1_, savet_, interpolated)
w = w[..., -1, :] if savet is None else w # TODO: undo this

# Return
return gc.PhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(w[..., -1], units["time"]),
if interpolated:
# Determine if an extra dimension was added to the output
added_ndim = int(w0_.shape[:-1] == () or w0_.shape[0] == 1)
# If one was, then the interpolant must be reshaped since the input
# was squeezed beforehand and the dimension must be added back.
if added_ndim == 1:
arr, narr = eqx.partition(interp, eqx.is_array)
arr = jax.tree_util.tree_map(lambda x: x[None], arr)
interp = eqx.combine(arr, narr)

Check warning on line 370 in src/galax/dynamics/_dynamics/integrate/_builtin.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_dynamics/integrate/_builtin.py#L368-L370

Added lines #L368 - L370 were not covered by tests

out = gc.InterpolatedPhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(savet_, units["time"]),
interpolant=DiffraxInterpolant(
interp, units=units, added_ndim=added_ndim
),
)
else:
out = gc.PhaseSpacePosition( # shape = (*batch, T)
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(w[..., -1], units["time"]),
)

return out


class DiffraxInterpolant(eqx.Module): # type: ignore[misc]#
"""Wrapper for ``diffrax.DenseInterpolation``."""

interpolant: DenseInterpolation
""":class:`diffrax.DenseInterpolation` object.

This object is the result of the integration and can be used to evaluate the
interpolated solution at any time. However it does not understand units, so
the input is the time in ``units["time"]``. The output is a 6-vector of
(q, p) values in the units of the integrator.
"""

units: AbstractUnitSystem = eqx.field(static=True, converter=unitsystem)
"""The :class:`unxt.AbstractUnitSystem`.

This is used to convert the time input to the interpolant and the phase-space
position output.
"""

added_ndim: tuple[int, ...] = eqx.field(static=True)
"""The number of dimensions added to the output of the interpolation.

This is used to reshape the output of the interpolation to match the batch
shape of the input to the integrator. The means of vectorizing the
interpolation means that the input must always be a batched array, resulting
in an extra dimension when the integration was on a scalar input.
"""

def __call__(self, t: gt.QVecTime, **_: Any) -> gc.PhaseSpacePosition:
"""Evaluate the interpolation."""
t_ = jnp.atleast_1d(t.to_units_value(self.units["time"]))
ys = jax.vmap(lambda s: jax.vmap(s.evaluate)(t_))(self.interpolant)
ys = ys[(0,) * (ys.ndim - 3 + self.added_ndim)]
return gc.PhaseSpacePosition(
q=Quantity(ys[..., 0:3], self.units["length"]),
p=Quantity(ys[..., 3:6], self.units["speed"]),
t=t,
)
32 changes: 24 additions & 8 deletions src/galax/dynamics/_dynamics/integrate/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dataclasses import replace
from functools import partial
from typing import Literal

import jax
import jax.numpy as jnp
Expand All @@ -16,7 +17,7 @@
from ._api import Integrator
from ._builtin import DiffraxIntegrator
from galax.coordinates import PhaseSpacePosition
from galax.dynamics._dynamics.orbit import Orbit
from galax.dynamics._dynamics.orbit import InterpolatedOrbit, Orbit
from galax.potential._potential.base import AbstractPotentialBase

##############################################################################
Expand All @@ -29,20 +30,21 @@
_select_w0 = jnp.vectorize(jax.lax.select, signature="(),(6),(6)->(6)")


@partial(jax.jit, static_argnames=("integrator",))
@partial(jax.jit, static_argnames=("integrator", "interpolated"))
def evaluate_orbit(
pot: AbstractPotentialBase,
w0: PhaseSpacePosition | gt.BatchVec6,
t: gt.QVecTime | gt.VecTime | APYQuantity,
*,
integrator: Integrator | None = None,
) -> Orbit:
interpolated: Literal[True, False] = False,
) -> Orbit | InterpolatedOrbit:
"""Compute an orbit in a potential.

:class:`~galax.coordinates.PhaseSpacePosition` includes a time in
addition to the position (and velocity) information, enabling the orbit to
be evaluated over a time range that is different from the initial time of
the position.
:class:`~galax.coordinates.PhaseSpacePosition` includes a time in addition
to the position (and velocity) information, enabling the orbit to be
evaluated over a time range that is different from the initial time of the
position.

Parameters
----------
Expand Down Expand Up @@ -82,6 +84,10 @@
is used twice: once to integrate from `w0.t` to `t[0]` and then from
`t[0]` to `t[1]`.

interpolated: bool, optional keyword-only
If `True`, return an interpolated orbit. If `False`, return the orbit
at the requested times. Default is `False`.

Returns
-------
orbit : :class:`~galax.dynamics.Orbit`
Expand Down Expand Up @@ -225,7 +231,17 @@
t[-1],
savet=t,
units=units,
interpolated=interpolated,
)
wt = t

# Construct the orbit object
return Orbit(q=ws.q, p=ws.p, t=t, potential=pot)
# TODO: easier construction from the (Interpolated)PhaseSpacePosition
if interpolated:
out = InterpolatedOrbit(

Check warning on line 241 in src/galax/dynamics/_dynamics/integrate/_funcs.py

View check run for this annotation

Codecov / codecov/patch

src/galax/dynamics/_dynamics/integrate/_funcs.py#L241

Added line #L241 was not covered by tests
q=ws.q, p=ws.p, t=wt, interpolant=ws.interpolant, potential=pot
)
else:
out = Orbit(q=ws.q, p=ws.p, t=wt, potential=pot)

return out
15 changes: 13 additions & 2 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import KW_ONLY, fields
from functools import partial
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeAlias, cast

import equinox as eqx
import jax
Expand Down Expand Up @@ -1896,6 +1896,7 @@ def evaluate_orbit(
t: gt.QVecTime | gt.VecTime | APYQuantity, # TODO: must be a Quantity
*,
integrator: "Integrator | None" = None,
interpolated: Literal[True, False] = False,
) -> "Orbit":
"""Compute an orbit in a potential.

Expand Down Expand Up @@ -1943,6 +1944,11 @@ def evaluate_orbit(
Integrator to use. If `None`, the default integrator
:class:`~galax.integrator.DiffraxIntegrator` is used.

interpolated: bool, optional keyword-only
If `True`, return an interpolated orbit. If `False`, return the orbit
at the requested times. Default is `False`.


Returns
-------
orbit : :class:`~galax.dynamics.Orbit`
Expand All @@ -1956,4 +1962,9 @@ def evaluate_orbit(
"""
from galax.dynamics import evaluate_orbit

return cast("Orbit", evaluate_orbit(self, w0, t, integrator=integrator))
return cast(
"Orbit",
evaluate_orbit(
self, w0, t, integrator=integrator, interpolated=interpolated
),
)