Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: split integrator into separate module #118

Merged
merged 2 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from jaxtyping import Array, Shaped

from galax.dynamics._dynamics.orbit import Orbit
from galax.integrate._base import Integrator
from galax.integrate._api import Integrator
from galax.integrate._builtin import DiffraxIntegrator
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchVec6, FloatScalar, IntScalar, Vec6, VecN, VecTime
Expand Down
4 changes: 3 additions & 1 deletion src/galax/integrate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""galax: Galactic Dynamix in Jax."""


from galax.integrate import _base, _builtin
from galax.integrate import _api, _base, _builtin
from galax.integrate._api import *
from galax.integrate._base import *
from galax.integrate._builtin import *

__all__: list[str] = []
__all__ += _api.__all__
__all__ += _base.__all__
__all__ += _builtin.__all__

Expand Down
63 changes: 63 additions & 0 deletions src/galax/integrate/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
__all__ = ["Integrator"]

from typing import Any, Protocol, runtime_checkable

from galax.typing import FloatScalar, Vec6, VecTime, VecTime7
from galax.utils.dataclasses import _DataclassInstance


@runtime_checkable
class FCallable(Protocol):
"""Protocol for the integration callable."""

def __call__(self, t: FloatScalar, w: Vec6, args: tuple[Any, ...]) -> Vec6:
"""Integration function.

Parameters
----------
t : float
The time. This is the integration variable.
w : Array[float, (6,)]
The position and velocity.
args : tuple[Any, ...]
Additional arguments.

Returns
-------
Array[float, (6,)]
Velocity and acceleration [v (3,), a (3,)].
"""
...


@runtime_checkable
class Integrator(_DataclassInstance, Protocol):
""":class:`typing.Protocol` for integrators.

The integrators are classes that are used to integrate the equations of
motion.
They must not be stateful since they are used in a functional way.
"""

def __call__(self, F: FCallable, w0: Vec6, /, ts: VecTime) -> VecTime7:
"""Integrate.

Parameters
----------
F : FCallable, positional-only
The function to integrate.
(t, w, args) -> (v, a).
w0 : Array[float, (6,)], positional-only
Initial conditions ``[q, p]``.

ts : Array[float, (T,)]
Times to return the computation.
It's necessary to at least provide the initial and final times.

Returns
-------
Array[float, (T, 7)]
The solution of the integrator [q, p, t], where q, p are the
generalized 3-coordinates.
"""
...
78 changes: 6 additions & 72 deletions src/galax/integrate/_base.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,12 @@
__all__ = ["Integrator", "AbstractIntegrator"]
__all__ = ["AbstractIntegrator"]

import abc
from typing import Any, Protocol, runtime_checkable

import equinox as eqx
from jaxtyping import Array, Float

from galax.typing import FloatScalar, Vec6
from galax.utils.dataclasses import _DataclassInstance
from galax.typing import Vec6, VecTime, VecTime7


@runtime_checkable
class FCallable(Protocol):
"""Protocol for the integration callable."""

def __call__(self, t: FloatScalar, w: Vec6, args: tuple[Any, ...]) -> Vec6:
"""Integration function.

Parameters
----------
t : float
The time.
w : Array[float, (6,)]
The position and velocity.
args : tuple
Additional arguments.

Returns
-------
Array[float, (6,)]
[v (3,), a (3,)].
"""
...


@runtime_checkable
class Integrator(_DataclassInstance, Protocol):
""":class:`typing.Protocol` for integrators.

The integrators are classes that are used to integrate the equations of
motion.
They must not be stateful since they are used in a functional way.
"""

def __call__(
self, F: FCallable, w0: Vec6, /, ts: Float[Array, "T"] | None
) -> Float[Array, "R 7"]:
"""Integrate.

Parameters
----------
F : FCallable, positional-only
The function to integrate.
(t, w, args) -> (v, a).
w0 : Array[float, (6,)], positional-only
Initial conditions ``[q, p]``.

ts : Array[float, (T,)] | None
Times to return the computation.
It's necessary to at least provide the initial and final times.

Returns
-------
Array[float, (R, 7)]
The solution of the integrator [q, p, t], where q, p are the
generalized 3-coordinates.
"""
...
from ._api import FCallable


class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc]
Expand All @@ -82,13 +22,7 @@ class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, mis
"""

@abc.abstractmethod
def __call__(
self,
F: FCallable,
w0: Vec6,
/,
ts: Float[Array, "T"],
) -> Float[Array, "T 7"]:
def __call__(self, F: FCallable, w0: Vec6, /, ts: VecTime) -> VecTime7:
"""Run the integrator.

Parameters
Expand All @@ -98,13 +32,13 @@ def __call__(
w0 : Array[float, (6,)], positional-only
Initial conditions ``[q, p]``.

ts : Array[float, (T,)] | None
ts : Array[float, (time,)]
Times to return the computation.
It's necessary to at least provide the initial and final times.

Returns
-------
Array[float, (R, 7)]
Array[float, (time, 7)]
The solution of the integrator [q, p, t], where q, p are the
generalized 3-coordinates.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/galax/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from galax.utils import ImmutableDict
from galax.utils._jax import vectorize_method

from ._base import FCallable
from ._api import FCallable


@final
Expand Down
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from jax import grad, hessian, jacfwd
from jaxtyping import Array, Float

from galax.integrate._base import Integrator
from galax.integrate._api import Integrator
from galax.integrate._builtin import DiffraxIntegrator
from galax.potential._potential.param.attr import ParametersAttribute
from galax.potential._potential.param.utils import all_parameters
Expand Down
1 change: 1 addition & 0 deletions src/galax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
# Specific

VecTime = Float[Array, "time"]
VecTime7 = Float[Array, "time 7"]
"""A time vector."""

# -----------------------------------------------------------------------------
Expand Down
6 changes: 4 additions & 2 deletions tests/smoke/integrate/test_package.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Test the :mod:`galax.integrate` module."""

from galax import integrate
from galax.integrate import _base, _builtin
from galax.integrate import _api, _base, _builtin


def test_all() -> None:
"""Test the API."""
assert set(integrate.__all__) == set(_base.__all__ + _builtin.__all__)
assert set(integrate.__all__) == set(
_api.__all__ + _base.__all__ + _builtin.__all__
)
Loading