Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify time api #67

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/galax/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def _run_scan( # TODO: output shape depends on the input shape
def scan_fn(carry: Carry, idx: IntScalar) -> tuple[Carry, tuple[VecN, VecN]]:
i, qp0_lead_i, qp0_trail_i = carry
qp0_lead_trail = xp.vstack([qp0_lead_i, qp0_trail_i])
t_i, t_f = ts[i], ts[-1]
tstep = xp.asarray([ts[i], ts[-1]])

def integ_ics(ics: Vec6) -> VecN:
return self.potential.integrate_orbit(
ics, t_i, t_f, None, integrator=self.stream_integrator
ics, tstep, integrator=self.stream_integrator
).qp[0]

# vmap over leading and trailing arm
Expand Down Expand Up @@ -106,12 +106,12 @@ def _run_vmap( # TODO: output shape depends on the input shape
def single_particle_integrate(
i: IntScalar, qp0_lead_i: Vec6, qp0_trail_i: Vec6
) -> tuple[Vec6, Vec6]:
t_i = ts[i]
tstep = xp.asarray([ts[i], t_f])
qp_lead = self.potential.integrate_orbit(
qp0_lead_i, t_i, t_f, None, integrator=self.stream_integrator
qp0_lead_i, tstep, integrator=self.stream_integrator
).qp[0]
qp_trail = self.potential.integrate_orbit(
qp0_trail_i, t_i, t_f, None, integrator=self.stream_integrator
qp0_trail_i, tstep, integrator=self.stream_integrator
).qp[0]
return qp_lead, qp_trail

Expand All @@ -130,9 +130,39 @@ def run(
seed_num: int,
vmapped: bool = False,
) -> tuple[tuple[MockStream, MockStream], Orbit]:
# Integrate the progenitor orbit
"""Generate mock stellar stream.

Parameters
----------
ts : Array[float, (time,)]
Stripping times.
prog_w0 : Array[float, (6,)]
Initial conditions of the progenitor.
prog_mass : float
Mass of the progenitor.

seed_num : int, keyword-only
Seed number for the random number generator.

:todo: a better way to handle PRNG

vmapped : bool, optional keyword-only
Whether to use `jax.vmap` (`True`) or `jax.lax.scan` (`False`) to
parallelize the integration. ``vmapped=True`` is recommended for GPU
usage, while ``vmapped=False`` is recommended for CPU usage.

Returns
-------
lead_arm, trail_arm : tuple[MockStream, MockStream]
Leading and trailing arms of the mock stream.
prog_o : Orbit
Orbit of the progenitor.
"""
# TODO: a discussion about the stripping times

# Integrate the progenitor orbit to the stripping times
prog_o = self.potential.integrate_orbit(
prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator
prog_w0, ts, integrator=self.progenitor_integrator
)

# Generate stream initial conditions along the integrated progenitor orbit
Expand Down
16 changes: 3 additions & 13 deletions src/galax/integrate/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,18 @@ def run(
self,
F: FCallable,
qp0: Vec6,
t0: FloatScalar,
t1: FloatScalar,
ts: Float[Array, "T"] | None,
) -> Float[Array, "R 7"]:
ts: Float[Array, "T"],
) -> Float[Array, "T 7"]:
"""Run the integrator.

.. todo::

Have a better time parser.

Parameters
----------
F : FCallable
The function to integrate.
qp0 : Array[float, (6,)]
Initial conditions ``[q, p]``.
t0 : float
Initial time.
t1 : float
Final time.
ts : Array[float, (T,)] | None
Times for the computation.
Times to return the computation.

Returns
-------
Expand Down
10 changes: 4 additions & 6 deletions src/galax/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from jaxtyping import Array, Float

from galax.integrate._base import AbstractIntegrator
from galax.typing import FloatScalar, Vec6
from galax.typing import Vec6
from galax.utils import ImmutableDict

from ._base import FCallable
Expand Down Expand Up @@ -49,15 +49,13 @@ def run(
self,
F: FCallable,
qp0: Vec6,
t0: FloatScalar,
t1: FloatScalar,
ts: Float[Array, "T"] | None,
ts: Float[Array, "T"],
) -> Float[Array, "R 7"]:
solution = diffeqsolve(
terms=ODETerm(F),
solver=self.Solver(**self.solver_kw),
t0=t0,
t1=t1,
t0=ts[0],
t1=ts[-1],
y0=qp0,
dt0=None,
args=(),
Expand Down
76 changes: 72 additions & 4 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,21 +291,89 @@ def tidal_tensor(

@partial_jit()
def _integrator_F(self, t: FloatScalar, qp: Vec6, args: tuple[Any, ...]) -> Vec6:
"""Return the derivative of the phase-space position."""
return xp.hstack([qp[3:6], self.acceleration(qp[0:3], t)]) # v, a

@partial_jit(static_argnames=("integrator"))
def integrate_orbit(
self,
qp0: Vec6,
t0: FloatScalar, # TODO: better time parsing
t1: FloatScalar,
ts: Float[Array, "time"] | None,
t: Float[Array, "time"],
*,
integrator: AbstractIntegrator | None = None,
) -> "Orbit":
"""Integrate an orbit in the potential.

Parameters
----------
qp0 : Array[float, (6,)]
Initial position and velocity.
t: Array[float, (T,)]
Array of times at which to compute the orbit. The first element
should be the initial time and the last element should be the final
time and the array should be monotonically moving from the first to
final time. See the Examples section for options when constructing
this argument.

.. note::

To integrate backwards in time, ...

.. warning::

This is NOT the timesteps to use for integration, which are
controlled by the `integrator`; the default integrator
:class:`~galax.integrator.DiffraxIntegrator` uses adaptive
timesteps.

integrator : AbstractIntegrator, keyword-only
Integrator to use. If `None`, the default integrator
:class:`~galax.integrator.DiffraxIntegrator` is used.

Returns
-------
orbit : Orbit
The integrated orbit evaluated at the given times.

Examples
--------
We start by integrating a single orbit in the potential of a point mass.
A few standard imports are needed:

>>> import astropy.units as u
>>> import jax.experimental.array_api as xp # preferred over `jax.numpy`
>>> import galax.potential as gp
>>> from galax.units import galactic

We can then create the point-mass' potential, with galactic units:

>>> potential = gp.KeplerPotential(m=1e12 * u.Msun, units=galactic)

We can then integrate an initial phase-space position in this potential
to get an orbit:

>>> xv0 = xp.asarray([10., 0., 0., 0., 0.1, 0.]) # (x, v) galactic units
>>> ts = xp.linspace(0., 1000, 4) # (1 Gyr, 4 steps)
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[4,3], p=f64[4,3], t=f64[4], potential=KeplerPotential(...)
)

Note how there are 4 points in the orbit, corresponding to the 4 steps.
Changing the number of steps is easy:

>>> ts = xp.linspace(0., 1000, 10) # (1 Gyr, 4 steps)
>>> orbit = potential.integrate_orbit(xv0, ts)
>>> orbit
Orbit(
q=f64[10,3], p=f64[10,3], t=f64[10], potential=KeplerPotential(...)
)
"""
# TODO: ꜛ get NORMALIZE_WHITESPACE to work correctly so Orbit is 1 line
from galax.dynamics._orbit import Orbit

integrator_ = default_integrator if integrator is None else replace(integrator)

ws = integrator_.run(self._integrator_F, qp0, t0, t1, ts)
ws = integrator_.run(self._integrator_F, qp0, t)
return Orbit(q=ws[:, :3], p=ws[:, 3:-1], t=ws[:, -1], potential=self)
2 changes: 1 addition & 1 deletion tests/unit/potential/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_integrate_orbit(self, pot, xv):
"""Test the `AbstractPotentialBase.integrate_orbit` method."""
ts = xp.linspace(0.0, 1.0, 100)

orbit = pot.integrate_orbit(xv, t0=min(ts), t1=max(ts), ts=ts)
orbit = pot.integrate_orbit(xv, ts)
assert isinstance(orbit, gd.Orbit)
assert orbit.shape == (len(ts), 7)
assert xp.array_equal(orbit.t, ts)