Skip to content

Commit

Permalink
integrator instance, not class
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 2, 2024
1 parent bef3275 commit 4a41c15
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 62 deletions.
13 changes: 12 additions & 1 deletion src/galax/dynamics/_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

__all__ = ["Orbit"]

import equinox as eqx
from typing_extensions import override

from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchFloatScalar
from galax.typing import BatchFloatScalar, BatchVec3
from galax.utils._jax import partial_jit
from galax.utils.dataclasses import converter_float_array

from ._core import AbstractPhaseSpacePosition

Expand All @@ -19,6 +21,15 @@ class Orbit(AbstractPhaseSpacePosition):
"""

q: BatchVec3 = eqx.field(converter=converter_float_array)
"""Positions (x, y, z)."""

p: BatchVec3 = eqx.field(converter=converter_float_array)
r"""Conjugate momenta (v_x, v_y, v_z)."""

t: BatchFloatScalar = eqx.field(converter=converter_float_array)
"""Array of times."""

potential: AbstractPotentialBase
"""Potential in which the orbit was integrated."""

Expand Down
55 changes: 12 additions & 43 deletions src/galax/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,14 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc]
"""Potential in which the progenitor orbits and creates a stream."""

_: KW_ONLY
progenitor_integrator_cls: type[AbstractIntegrator] = eqx.field(
default=DiffraxIntegrator, static=True
progenitor_integrator: AbstractIntegrator = eqx.field(
default=DiffraxIntegrator(), static=True
)
"""Integrator class for integrating the progenitor orbit."""
"""Integrator for the progenitor orbit."""

progenitor_integrator_kw: Mapping[str, Any] | None = eqx.field(
default=None, static=True, converter=_converter_immutabledict_or_none
)
"""Keyword arguments for the progenitor integrator."""

stream_integrator_cls: type[AbstractIntegrator] = eqx.field(
default=DiffraxIntegrator, static=True
stream_integrator: AbstractIntegrator = eqx.field(
default=DiffraxIntegrator(), static=True
)
"""Integrator class for integrating the stream."""

stream_integrator_kw: Mapping[str, Any] | None = eqx.field(
default=None, static=True, converter=_converter_immutabledict_or_none
Expand All @@ -77,12 +71,7 @@ def _run_scan(
"""
# Integrate the progenitor orbit
prog_o = self.potential.integrate_orbit(
prog_w0,
xp.min(ts),
xp.max(ts),
ts,
Integrator=self.progenitor_integrator_cls,
integrator_kw=self.progenitor_integrator_kw,
prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator
)

# Generate stream initial conditions along the integrated progenitor orbit
Expand All @@ -99,12 +88,7 @@ def scan_fn(carry: Carry, idx: IntScalar) -> tuple[Carry, tuple[VecN, VecN]]:

def integ_ics(ics: Vec6) -> VecN:
return self.potential.integrate_orbit(
ics,
t_i,
t_f,
None,
Integrator=self.stream_integrator_cls,
integrator_kw=self.stream_integrator_kw,
ics, t_i, t_f, None, integrator=self.stream_integrator
).qp[0]

# vmap over leading and trailing arm
Expand All @@ -127,12 +111,7 @@ def _run_vmap(
"""
# Integrate the progenitor orbit
prog_o = self.potential.integrate_orbit(
prog_w0,
xp.min(ts),
xp.max(ts),
ts,
Integrator=self.progenitor_integrator_cls,
integrator_kw=self.progenitor_integrator_kw,
prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator
)

# Generate stream initial conditions along the integrated progenitor orbit
Expand All @@ -149,21 +128,11 @@ def single_particle_integrate(
i: int, qp0_lead_i: Vec6, qp0_trail_i: Vec6
) -> tuple[Vec6, Vec6]:
t_i = ts[i]
qp_lead = self.integrate_orbit(
qp0_lead_i,
t_i,
t_f,
None,
Integrator=self.stream_integrator_cls,
integrator_kw=self.stream_integrator_kw,
qp_lead = self.potential.integrate_orbit(
qp0_lead_i, t_i, t_f, None, integrator=self.stream_integrator
).qp[0]
qp_trail = self.integrate_orbit(
qp0_trail_i,
t_i,
t_f,
None,
Integrator=self.stream_integrator_cls,
integrator_kw=self.stream_integrator_kw,
qp_trail = self.potential.integrate_orbit(
qp0_trail_i, t_i, t_f, None, integrator=self.stream_integrator
).qp[0]
return qp_lead, qp_trail

Expand Down
24 changes: 16 additions & 8 deletions src/galax/integrate/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import equinox as eqx
from jaxtyping import Array, Float

from galax.typing import FloatScalar, Vec6, Vec7
from galax.typing import FloatScalar, Vec6


@runtime_checkable
class FCallable(Protocol):
def __call__(self, t: FloatScalar, qp: Vec6, args: tuple[Any, ...]) -> Vec7:
def __call__(self, t: FloatScalar, qp: Vec6, args: tuple[Any, ...]) -> Vec6:
"""Integration function.
Parameters
Expand All @@ -25,22 +25,28 @@ def __call__(self, t: FloatScalar, qp: Vec6, args: tuple[Any, ...]) -> Vec7:
Returns
-------
Array[float, (7,)]
[qp, t].
Array[float, (6,)]
[v (3,), a (3,)].
"""
...


class AbstractIntegrator(eqx.Module): # type: ignore[misc]
"""Integrator Class."""
"""Integrator Class.
F: FCallable
"""The function to integrate."""
# TODO: should this be moved to be the first argument of the run method?
The integrators are classes that are used to integrate the equations of
motion.
They must not be stateful since they are used in a functional way.
"""

# F: FCallable
# """The function to integrate."""
# # TODO: should this be moved to be the first argument of the run method?

@abc.abstractmethod
def run(
self,
F: FCallable,
qp0: Vec6,
t0: FloatScalar,
t1: FloatScalar,
Expand All @@ -54,6 +60,8 @@ def run(
Parameters
----------
F : FCallable
The function to integrate.
qp0 : Array[float, (6,)]
Initial conditions ``[q, p]``.
t0 : float
Expand Down
5 changes: 4 additions & 1 deletion src/galax/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from galax.typing import FloatScalar, Vec6
from galax.utils import ImmutableDict

from ._base import FCallable


class DiffraxIntegrator(AbstractIntegrator):
"""Thin wrapper around ``diffrax.diffeqsolve``."""
Expand All @@ -45,13 +47,14 @@ class DiffraxIntegrator(AbstractIntegrator):

def run(
self,
F: FCallable,
qp0: Vec6,
t0: FloatScalar,
t1: FloatScalar,
ts: Float[Array, "T"] | None,
) -> Float[Array, "R 7"]:
solution = diffeqsolve(
terms=ODETerm(self.F),
terms=ODETerm(F),
solver=self.Solver(**self.solver_kw),
t0=t0,
t1=t1,
Expand Down
18 changes: 9 additions & 9 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
__all__ = ["AbstractPotentialBase"]

import abc
from collections.abc import Mapping
from dataclasses import fields
from dataclasses import fields, replace
from typing import TYPE_CHECKING, Any

import astropy.units as u
Expand Down Expand Up @@ -34,6 +33,9 @@
from galax.dynamics._orbit import Orbit


default_integrator = DiffraxIntegrator()


class AbstractPotentialBase(eqx.Module, metaclass=ModuleMeta): # type: ignore[misc]
"""Potential Class."""

Expand Down Expand Up @@ -289,23 +291,21 @@ def tidal_tensor(

@partial_jit()
def _integrator_F(self, t: FloatScalar, xv: Vec6, args: tuple[Any, ...]) -> Vec6:
return xp.hstack([xv[3:6], self.acceleration(xv[:3], t)])
return xp.hstack([xv[3:6], self.acceleration(xv[0:3], t)]) # v, a

@partial_jit(static_argnames=("Integrator", "integrator_kw"))
def integrate_orbit(
self,
qp0: Vec6,
t0: FloatScalar,
t0: FloatScalar, # TODO: better time parsing
t1: FloatScalar,
ts: Float[Array, "time"] | None,
*,
Integrator: type[AbstractIntegrator] | None = None,
integrator_kw: Mapping[str, Any] | None = None,
integrator: AbstractIntegrator | None = None,
) -> "Orbit":
from galax.dynamics._orbit import Orbit

integrator_cls = Integrator if Integrator is not None else DiffraxIntegrator
integrator_ = default_integrator if integrator is None else replace(integrator)

integrator = integrator_cls(self._integrator_F, **(integrator_kw or {}))
ws = integrator.run(qp0, t0, t1, ts)
ws = integrator_.run(self._integrator_F, qp0, t0, t1, ts) # type: ignore[arg-type]
return Orbit(q=ws[:, :3], p=ws[:, 3:-1], t=ws[:, -1], potential=self)

0 comments on commit 4a41c15

Please sign in to comment.