Skip to content

Commit

Permalink
refactor: unxt api (#246)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Apr 2, 2024
1 parent 5f9d3b0 commit 5a3f7a7
Show file tree
Hide file tree
Showing 13 changed files with 164 additions and 45 deletions.
8 changes: 4 additions & 4 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ we could compute the potential energy or the acceleration at a Cartesian
position near the Sun::

>>> xyz = [-8., 0, 0] * u.kpc
>>> mw.potential_energy(xyz, t=0).to("kpc2 / Myr2")
>>> mw.potential_energy(xyz, t=0).to_units("kpc2 / Myr2")
Quantity['specific energy'](Array(-0.16440296, dtype=float64), unit='kpc2 / Myr2')
>>> mw.acceleration(xyz, t=0).to("kpc/Myr2")
>>> mw.acceleration(xyz, t=0).to_units("kpc/Myr2")
Quantity['acceleration'](Array([ 0.00702262, -0. , -0. ], dtype=float64), unit='kpc / Myr2')

The values that are returned by most methods in :mod:`galax` are provided as
Expand All @@ -77,9 +77,9 @@ with associated physical units. :class:`~astropy.units.Quantity` objects can be
re-represented in any equivalent units, so, for example, we could display the
energy or acceleration in other units::

>>> mw.potential_energy(xyz, t=0).to("kpc2/Myr2")
>>> mw.potential_energy(xyz, t=0).to_units("kpc2/Myr2")
Quantity['specific energy'](Array(-0.16440296, dtype=float64), unit='kpc2 / Myr2')
>>> mw.acceleration(xyz, t=0).to("kpc/Myr2")
>>> mw.acceleration(xyz, t=0).to_units("kpc/Myr2")
Quantity['acceleration'](Array([ 0.00702262, -0. , -0. ], dtype=float64), unit='kpc / Myr2')

Now that we have a potential model, if we want to compute an orbit, we need to
Expand Down
7 changes: 6 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ def lint(session: nox.Session) -> None:
"""Run the linter."""
session.install("pre-commit")
session.run(
"pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs
"pre-commit",
"run",
"--all-files",
"--show-diff-on-failure",
*session.posargs,
)


Expand Down Expand Up @@ -45,6 +49,7 @@ def doctests(session: nox.Session) -> None:
"--doctest-modules",
'--doctest-glob="*.rst"',
'--doctest-glob="*.md"',
'--doctest-glob="*.py"',
"docs",
"src/galax",
*session.posargs,
Expand Down
2 changes: 1 addition & 1 deletion src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def to_units(self, units: Any) -> "Self":
self,
q=self.q.to_units(usys),
p=self.p.to_units(usys),
t=self.t.to(usys["time"]) if self.t is not None else None,
t=self.t.to_units(usys["time"]) if self.t is not None else None,
)

# ==========================================================================
Expand Down
2 changes: 1 addition & 1 deletion src/galax/coordinates/_psp/operator_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def call(
>>> newpsp.q.x
Quantity['length'](Array(2., dtype=float64), unit='kpc')
>>> newpsp.t.to("Myr")
>>> newpsp.t.to_units("Myr")
Quantity['time'](Array(6.52312732, dtype=float64), unit='Myr')
This spatial translation is time independent.
Expand Down
14 changes: 9 additions & 5 deletions src/galax/dynamics/_dynamics/integrate/_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
__all__ = ["Integrator"]

from typing import Any, Protocol, runtime_checkable
from typing import Any, Protocol, TypeAlias, runtime_checkable

from unxt import AbstractUnitSystem

import galax.typing as gt
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.utils.dataclasses import _DataclassInstance

SaveT: TypeAlias = gt.BatchQVecTime | gt.QVecTime | gt.BatchVecTime | gt.VecTime


@runtime_checkable
class FCallable(Protocol):
Expand Down Expand Up @@ -39,7 +41,11 @@ class Integrator(_DataclassInstance, Protocol):
The integrators are classes that are used to integrate the equations of
motion.
They must not be stateful since they are used in a functional way.
.. note::
Integrators should NOT be stateful (i.e., they must not have attributes
that change).
"""

# TODO: shape hint of the return type
Expand All @@ -50,9 +56,7 @@ def __call__(
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
savet: (
gt.BatchQVecTime | gt.QVecTime | gt.BatchVecTime | gt.VecTime | None
) = None,
savet: SaveT | None = None,
*,
units: AbstractUnitSystem,
) -> PhaseSpacePosition:
Expand Down
6 changes: 5 additions & 1 deletion src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def __call__(
savet : (Quantity | Array)[float, (T,)] | None, optional
Times to return the computation. If `None`, the computation is
returned at the final time.
returned only at the final time.
units : `unxt.AbstractUnitSystem`
The unit system to use.
interpolated : bool, keyword-only
Whether to return an interpolated solution.
Returns
-------
Expand All @@ -81,6 +83,8 @@ def __call__(
>>> w0 = gc.PhaseSpacePosition(q=Quantity([10., 0., 0.], "kpc"),
... p=Quantity([0., 200., 0.], "km/s"))
(Note that the ``t`` attribute is not used.)
Now we can integrate the phase-space position for 1 Gyr, getting the
final position. The integrator accepts any function for the equations
of motion. Here we will reproduce what happens with orbit integrations.
Expand Down
140 changes: 123 additions & 17 deletions src/galax/dynamics/_dynamics/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,44 @@
import diffrax
import equinox as eqx
import jax
from jaxtyping import Array, Float, Shaped

import quaxed.array_api as xp
from unxt import AbstractUnitSystem, Quantity
from unxt import AbstractUnitSystem, Quantity, to_units_value

import galax.coordinates as gc
import galax.typing as gt
from ._api import FCallable
from ._base import AbstractIntegrator
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.utils import ImmutableDict
from galax.utils._jax import vectorize_method


def _to_value(
x: Shaped[Quantity, "*shape"] | Float[Array, "*shape"], unit: gt.Unit, /
) -> Float[Array, "*shape"]:
return x.to_value(unit) if isinstance(x, Quantity) else x


@final
class DiffraxIntegrator(AbstractIntegrator):
"""Thin wrapper around ``diffrax.diffeqsolve``."""
"""Integrator using :func:`diffrax.diffeqsolve`.
This integrator uses the :func:`diffrax.diffeqsolve` function to integrate
the equations of motion. :func:`diffrax.diffeqsolve` supports a wide range
of solvers and options. See the documentation of :func:`diffrax.diffeqsolve`
for more information.
Parameters
----------
Solver : type[diffrax.AbstractSolver], optional
The solver to use. Default is :class:`diffrax.Dopri5`.
stepsize_controller : diffrax.AbstractStepSizeController, optional
The stepsize controller to use. Default is a PID controller with
relative and absolute tolerances of 1e-7.
diffeq_kw : Mapping[str, Any], optional
Keyword arguments to pass to :func:`diffrax.diffeqsolve`. Default is
``{"max_steps": None, "discrete_terminating_event": None}``. The
``"max_steps"`` key is removed if ``interpolated=True`` in the
:meth`DiffraxIntegrator.__call__` method.
solver_kw : Mapping[str, Any], optional
Keyword arguments to pass to the solver. Default is ``{"scan_kind":
"bounded"}``.
"""

_: KW_ONLY
Solver: type[diffrax.AbstractSolver] = eqx.field(
Expand Down Expand Up @@ -76,7 +92,7 @@ def _call_implementation(
def __call__(
self,
F: FCallable,
w0: AbstractPhaseSpacePosition | gt.BatchVec6,
w0: gc.AbstractPhaseSpacePosition | gt.BatchVec6,
t0: gt.FloatQScalar | gt.FloatScalar,
t1: gt.FloatQScalar | gt.FloatScalar,
/,
Expand All @@ -85,22 +101,112 @@ def __call__(
) = None,
*,
units: AbstractUnitSystem,
) -> PhaseSpacePosition:
) -> gc.PhaseSpacePosition:
"""Run the integrator.
Parameters
----------
F : FCallable, positional-only
The function to integrate.
w0 : AbstractPhaseSpacePosition | Array[float, (6,)], positional-only
Initial conditions ``[q, p]``.
t0, t1 : Quantity, positional-only
Initial and final times.
savet : (Quantity | Array)[float, (T,)] | None, optional
Times to return the computation. If `None`, the computation is
returned only at the final time.
units : `unxt.AbstractUnitSystem`
The unit system to use.
Returns
-------
PhaseSpacePosition[float, (time, 7)]
The solution of the integrator [q, p, t], where q, p are the
generalized 3-coordinates.
Examples
--------
For this example, we will use the
:class:`~galax.integrate.DiffraxIntegrator`
First some imports:
>>> import quaxed.array_api as xp
>>> from unxt import Quantity
>>> import unxt.unitsystems as usx
>>> import galax.coordinates as gc
>>> import galax.dynamics as gd
>>> import galax.potential as gp
Then we define initial conditions:
>>> w0 = gc.PhaseSpacePosition(q=Quantity([10., 0., 0.], "kpc"),
... p=Quantity([0., 200., 0.], "km/s"))
(Note that the ``t`` attribute is not used.)
Now we can integrate the phase-space position for 1 Gyr, getting the
final position. The integrator accepts any function for the equations
of motion. Here we will reproduce what happens with orbit integrations.
>>> pot = gp.HernquistPotential(m=Quantity(1e12, "Msun"), c=Quantity(5, "kpc"),
... units="galactic")
>>> integrator = gd.integrate.DiffraxIntegrator()
>>> t0, t1 = Quantity(0, "Gyr"), Quantity(1, "Gyr")
>>> w = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic)
>>> w
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[...](value=f64[], unit=Unit("Myr"))
)
>>> w.shape
()
We can also request the orbit at specific times:
>>> ts = Quantity(xp.linspace(0, 1, 10), "Myr") # 10 steps
>>> ws = integrator(pot._integrator_F, w0, t0, t1, savet=ts, units=usx.galactic)
>>> ws
PhaseSpacePosition(
q=Cartesian3DVector( ... ),
p=CartesianDifferential3D( ... ),
t=Quantity[...](value=f64[10], unit=Unit("Myr"))
)
>>> ws.shape
(10,)
The integrator can also be used to integrate a batch of initial
conditions at once, returning a batch of final conditions (or a batch
of conditions at the requested times):
>>> w0 = gc.PhaseSpacePosition(q=Quantity([[10., 0, 0], [10., 0, 0]], "kpc"),
... p=Quantity([[0, 200, 0], [0, 200, 0]], "km/s"))
>>> ws = integrator(pot._integrator_F, w0, t0, t1, units=usx.galactic)
>>> ws.shape
(2,)
"""
# Parse inputs
t0_: gt.VecTime = _to_value(t0, units["time"])
t1_: gt.VecTime = _to_value(t1, units["time"])
savet_ = xp.asarray([t1_]) if savet is None else _to_value(savet, units["time"])
t0_: gt.VecTime = to_units_value(t0, units["time"])
t1_: gt.VecTime = to_units_value(t1, units["time"])
savet_ = (
xp.asarray([t1_]) if savet is None else to_units_value(savet, units["time"])
)

w0_: gt.Vec6 = (
w0.w(units=units) if isinstance(w0, AbstractPhaseSpacePosition) else w0
w0.w(units=units) if isinstance(w0, gc.AbstractPhaseSpacePosition) else w0
)

# Perform the integration
w = self._call_implementation(F, w0_, t0_, t1_, savet_)
w = w[..., -1, :] if savet is None else w

# Return
return PhaseSpacePosition(
return gc.PhaseSpacePosition(
q=Quantity(w[..., 0:3], units["length"]),
p=Quantity(w[..., 3:6], units["speed"]),
t=Quantity(w[..., -1], units["time"]),
Expand Down
16 changes: 8 additions & 8 deletions src/galax/dynamics/_dynamics/mockstream/df/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,16 @@ def scan_fn(carry: Carry, t: gt.FloatQScalar) -> tuple[Carry, Wif]:
x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts)[1]

mock_lead = MockStream(
q=x_lead.to(pot.units["length"]),
p=v_lead.to(pot.units["speed"]),
t=ts.to(pot.units["time"]),
release_time=ts.to(pot.units["time"]),
q=x_lead.to_units(pot.units["length"]),
p=v_lead.to_units(pot.units["speed"]),
t=ts.to_units(pot.units["time"]),
release_time=ts.to_units(pot.units["time"]),
)
mock_trail = MockStream(
q=x_trail.to(pot.units["length"]),
p=v_trail.to(pot.units["speed"]),
t=ts.to(pot.units["time"]),
release_time=ts.to(pot.units["time"]),
q=x_trail.to_units(pot.units["length"]),
p=v_trail.to_units(pot.units["speed"]),
t=ts.to_units(pot.units["time"]),
release_time=ts.to_units(pot.units["time"]),
)

return mock_lead, mock_trail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def run(
w0 = PhaseSpacePosition(
q=Quantity(prog_w0[0:3], self.units["length"]),
p=Quantity(prog_w0[3:6], self.units["speed"]),
t=ts[0].to(self.potential.units["time"]),
t=ts[0].to_units(self.potential.units["time"]),
)
w0 = eqx.error_if(w0, w0.ndim > 0, "prog_w0 must be scalar")

Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _init_units(self) -> None:
elif "dimensions" in f.metadata:
value = getattr(self, f.name)
if isinstance(value, APYQuantity):
value = value.to_value(
value = value.to_units_value(
self.units[f.metadata.get("dimensions")],
equivalencies=f.metadata.get("equivalencies", None),
)
Expand Down Expand Up @@ -1887,7 +1887,7 @@ def _integrator_F(
args: tuple[Any, ...], # noqa: ARG002
) -> gt.Vec6:
"""Return the derivative of the phase-space position."""
a = self.acceleration(w[0:3], t).to_value(self.units["acceleration"])
a = self.acceleration(w[0:3], t).to_units_value(self.units["acceleration"])
return jnp.hstack([w[3:6], a]) # v, a

def evaluate_orbit(
Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class PotentialFrame(AbstractPotentialBase):
Now we define a triaxial Hernquist potential with a time-dependent mass:
>>> mfunc = gp.UserParameter(lambda t: 1e12 * (1 + t.to_value("Gyr") / 10), unit="Msun")
>>> mfunc = gp.UserParameter(lambda t: 1e12 * (1 + t.to_units_value("Gyr") / 10), unit="Msun")
>>> pot = gp.TriaxialHernquistPotential(m=mfunc, c=Quantity(1, "kpc"),
... q1=1, q2=0.5, units="galactic")
Expand Down Expand Up @@ -88,7 +88,7 @@ class PotentialFrame(AbstractPotentialBase):
We can also apply a time translation to the potential:
>>> op2 = cxo.GalileanTranslationOperator(Quantity([1_000, 0, 0, 0], "kpc"))
>>> op2.translation.t.to("Myr")
>>> op2.translation.t.to_units("Myr")
Quantity['time'](Array(3.26156366, dtype=float64), unit='Myr')
>>> framedpot2 = gp.PotentialFrame(potential=pot, operator=op2)
Expand Down
2 changes: 1 addition & 1 deletion src/galax/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def converter_float_array(x: Any, /) -> Float[Array, "*shape"]:
@converter_float_array.register
def _converter_float_quantity(x: Quantity, /) -> Float[Array, "*shape"]:
"""Convert to a batched vector."""
return converter_float_array(x.to_value(u.dimensionless_unscaled))
return converter_float_array(x.to_units_value(u.dimensionless_unscaled))


##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/potential/param/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def param_cls(self) -> type[T]:
@pytest.fixture(scope="class")
def field_func(self) -> ParameterCallable:
def func(t: Quantity["time"], **kwargs: Any) -> Any:
return Quantity(t.to_value("Gyr"), "kpc")
return Quantity(t.to_units_value("Gyr"), "kpc")

return func

Expand Down

0 comments on commit 5a3f7a7

Please sign in to comment.