Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove private jit utils #104

Merged
merged 4 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/galax/dynamics/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
__all__ = ["AbstractPhaseSpacePosition", "PhaseSpacePosition"]

from abc import abstractmethod
from functools import partial
from typing import TYPE_CHECKING, final

import equinox as eqx
import jax
import jax.experimental.array_api as xp
import jax.numpy as jnp
from jaxtyping import Array, Float

from galax.typing import BatchFloatScalar, BatchVec3, BatchVec6, BatchVec7
from galax.utils import partial_jit
from galax.utils._shape import atleast_batched, batched_shape
from galax.utils.dataclasses import converter_float_array

Expand Down Expand Up @@ -49,7 +50,7 @@ def shape(self) -> tuple[int, ...]:
# Convenience properties

@property
@partial_jit()
@partial(jax.jit)
def qp(self) -> BatchVec6:
"""Return as a single Array[float, (*batch, Q + P),]."""
batch_shape, component_shapes = self._shape_tuple
Expand All @@ -66,7 +67,7 @@ def __len__(self) -> int:
# ==========================================================================
# Dynamical quantities

@partial_jit()
@partial(jax.jit)
def kinetic_energy(self) -> BatchFloatScalar:
r"""Return the specific kinetic energy.

Expand Down Expand Up @@ -100,7 +101,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
# Convenience properties

@property
@partial_jit()
@partial(jax.jit)
def w(self) -> BatchVec7:
"""Return as a single Array[float, (*batch, Q + P + T)]."""
batch_shape, component_shapes = self._shape_tuple
Expand All @@ -112,7 +113,7 @@ def w(self) -> BatchVec7:
return xp.concat((q, p, t), axis=-1)

@property
@partial_jit()
@partial(jax.jit)
def angular_momentum(self) -> BatchVec3:
r"""Compute the angular momentum.

Expand Down Expand Up @@ -144,7 +145,7 @@ def angular_momentum(self) -> BatchVec3:
# ==========================================================================
# Dynamical quantities

@partial_jit()
@partial(jax.jit)
def potential_energy(
self, potential: "AbstractPotentialBase", /
) -> BatchFloatScalar:
Expand All @@ -166,7 +167,7 @@ def potential_energy(
"""
return potential.potential_energy(self.q, t=self.t)

@partial_jit()
@partial(jax.jit)
def energy(self, potential: "AbstractPotentialBase", /) -> BatchFloatScalar:
r"""Return the specific total energy.

Expand Down Expand Up @@ -204,7 +205,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
return batch_shape, array_shape

@property
@partial_jit()
@partial(jax.jit)
def w(self) -> BatchVec7:
"""Return as a single Array[float, (*batch, Q + P + T)]."""
batch_shape, component_shapes = self._shape_tuple
Expand Down
6 changes: 4 additions & 2 deletions src/galax/dynamics/_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

__all__ = ["Orbit"]

from functools import partial

import equinox as eqx
import jax
from jaxtyping import Array, Float
from typing_extensions import override

from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchFloatScalar, TimeVector
from galax.utils._jax import partial_jit
from galax.utils.dataclasses import converter_float_array

from ._core import AbstractPhaseSpacePosition
Expand Down Expand Up @@ -38,7 +40,7 @@ class Orbit(AbstractPhaseSpacePosition):
# Dynamical quantities

@override
@partial_jit()
@partial(jax.jit)
def potential_energy(
self, potential: AbstractPotentialBase | None = None, /
) -> BatchFloatScalar:
Expand Down
6 changes: 4 additions & 2 deletions src/galax/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

__all__ = ["MockStream"]

from functools import partial

import equinox as eqx
import jax
import jax.experimental.array_api as xp
import jax.numpy as jnp
from jaxtyping import Array, Float

from galax.dynamics._core import AbstractPhaseSpacePositionBase
from galax.typing import BatchVec7, TimeVector
from galax.utils import partial_jit
from galax.utils._shape import atleast_batched, batched_shape
from galax.utils.dataclasses import converter_float_array

Expand Down Expand Up @@ -44,7 +46,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
return batch_shape, qshape + pshape + (1,)

@property
@partial_jit()
@partial(jax.jit)
def w(self) -> BatchVec7:
"""Return as a single Array[float, (*batch, Q + P + T)]."""
batch_shape, component_shapes = self._shape_tuple
Expand Down
4 changes: 2 additions & 2 deletions src/galax/dynamics/mockstream/_df/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
__all__ = ["AbstractStreamDF"]

import abc
from functools import partial
from typing import TypeAlias

import equinox as eqx
Expand All @@ -14,7 +15,6 @@
from galax.dynamics.mockstream._core import MockStream
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchVec3, FloatScalar, IntLike, Vec3, Vec6
from galax.utils import partial_jit

Wif: TypeAlias = tuple[Vec3, Vec3, Vec3, Vec3]
Carry: TypeAlias = tuple[IntLike, Vec3, Vec3, Vec3, Vec3]
Expand All @@ -31,7 +31,7 @@ def __post_init__(self) -> None:
msg = "You must generate either leading or trailing tails (or both!)"
raise ValueError(msg)

@partial_jit(static_argnames=("seed_num",))
@partial(jax.jit, static_argnames=("seed_num",))
def sample(
self,
# <\ parts of gala's ``prog_orbit``
Expand Down
18 changes: 10 additions & 8 deletions src/galax/dynamics/mockstream/_df/fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
__all__ = ["FardalStreamDF"]


from functools import partial

import jax
import jax.experimental.array_api as xp
import jax.numpy as jnp
from jax import grad, random
Expand All @@ -16,7 +19,6 @@
Vec3,
Vec6,
)
from galax.utils import partial_jit

from .base import AbstractStreamDF

Expand All @@ -29,7 +31,7 @@ class FardalStreamDF(AbstractStreamDF):
https://ui.adsabs.harvard.edu/abs/2015MNRAS.452..301F/abstract
"""

@partial_jit(static_argnums=(0,), static_argnames=("seed_num",))
@partial(jax.jit, static_argnums=(0,), static_argnames=("seed_num",))
def _sample(
self,
potential: AbstractPotentialBase,
Expand Down Expand Up @@ -115,7 +117,7 @@ def _sample(
# TODO: move this to a more general location.


@partial_jit()
@partial(jax.jit)
def dphidr(potential: AbstractPotentialBase, x: Vec3, t: FloatScalar) -> Vec3:
"""Compute the derivative of the potential at a position x.

Expand All @@ -137,7 +139,7 @@ def dphidr(potential: AbstractPotentialBase, x: Vec3, t: FloatScalar) -> Vec3:
return xp.sum(potential.gradient(x, t) * r_hat)


@partial_jit()
@partial(jax.jit)
def d2phidr2(
potential: AbstractPotentialBase, x: Vec3, /, t: FloatOrIntScalarLike
) -> FloatScalar:
Expand Down Expand Up @@ -172,7 +174,7 @@ def d2phidr2(
return xp.sum(grad(dphi_dr_func)(x) * r_hat)


@partial_jit()
@partial(jax.jit)
def orbital_angular_velocity(x: Vec3, v: Vec3, /) -> Vec3:
"""Compute the orbital angular velocity about the origin.

Expand All @@ -199,7 +201,7 @@ def orbital_angular_velocity(x: Vec3, v: Vec3, /) -> Vec3:
return jnp.cross(x, v) / r**2


@partial_jit()
@partial(jax.jit)
def orbital_angular_velocity_mag(x: Vec3, v: Vec3, /) -> FloatScalar:
"""Compute the magnitude of the angular momentum in the simulation frame.

Expand All @@ -225,7 +227,7 @@ def orbital_angular_velocity_mag(x: Vec3, v: Vec3, /) -> FloatScalar:
return xp.linalg.vector_norm(orbital_angular_velocity(x, v))


@partial_jit()
@partial(jax.jit)
def tidal_radius(
potential: AbstractPotentialBase,
x: Vec3,
Expand Down Expand Up @@ -271,7 +273,7 @@ def tidal_radius(
) ** (1.0 / 3.0)


@partial_jit()
@partial(jax.jit)
def lagrange_points(
potential: AbstractPotentialBase,
x: Vec3,
Expand Down
8 changes: 4 additions & 4 deletions src/galax/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__all__ = ["MockStreamGenerator"]

from dataclasses import KW_ONLY
from functools import partial
from typing import Any, TypeAlias

import equinox as eqx
Expand All @@ -23,7 +24,6 @@
Vec6,
VecN,
)
from galax.utils import partial_jit
from galax.utils._collections import ImmutableDict

from ._core import MockStream
Expand Down Expand Up @@ -59,7 +59,7 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc]

# ==========================================================================

@partial_jit()
@partial(jax.jit)
def _run_scan( # TODO: output shape depends on the input shape
self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream
) -> tuple[BatchVec6, BatchVec6]:
Expand Down Expand Up @@ -92,7 +92,7 @@ def integ_ics(ics: Vec6) -> VecN:

return lead_arm_qp, trail_arm_qp

@partial_jit()
@partial(jax.jit)
def _run_vmap( # TODO: output shape depends on the input shape
self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream
) -> tuple[BatchVec6, BatchVec6]:
Expand Down Expand Up @@ -123,7 +123,7 @@ def single_particle_integrate(
lead_arm_qp, trail_arm_qp = integrator(particle_ids, qp0_lead, mock0_trail.qp)
return lead_arm_qp, trail_arm_qp

@partial_jit(static_argnames=("seed_num", "vmapped"))
@partial(jax.jit, static_argnames=("seed_num", "vmapped"))
def run(
self,
ts: TimeVector,
Expand Down
20 changes: 11 additions & 9 deletions src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import abc
from dataclasses import KW_ONLY, fields, replace
from functools import partial
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, ClassVar

import equinox as eqx
import jax
import jax.experimental.array_api as xp
import jax.numpy as jnp
from astropy.constants import G as _G # pylint: disable=no-name-in-module
Expand All @@ -31,7 +33,7 @@
Vec6,
)
from galax.units import UnitSystem, dimensionless
from galax.utils import partial_jit, vectorize_method
from galax.utils._jax import vectorize_method
from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims
from galax.utils.dataclasses import ModuleMeta

Expand Down Expand Up @@ -64,7 +66,7 @@ def __init_subclass__(cls) -> None:
###########################################################################
# Abstract methods that must be implemented by subclasses

# @partial_jit()
# @partial(jax.jit)
# @vectorize_method(signature="(3),()->()")
@abc.abstractmethod
def _potential_energy(self, q: Vec3, /, t: FloatOrIntScalar) -> FloatScalar:
Expand Down Expand Up @@ -133,7 +135,7 @@ def potential_energy(
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
return self._potential_energy(q, t)

@partial_jit()
@partial(jax.jit)
def __call__(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchFloatScalar:
Expand All @@ -160,7 +162,7 @@ def __call__(
# ---------------------------------------
# Gradient

@partial_jit()
@partial(jax.jit)
@vectorize_method(signature="(3),()->(3)")
def _gradient(self, q: Vec3, /, t: FloatOrIntScalar) -> Vec3:
"""See ``gradient``."""
Expand Down Expand Up @@ -194,7 +196,7 @@ def gradient(
# ---------------------------------------
# Density

@partial_jit()
@partial(jax.jit)
@vectorize_method(signature="(3),()->()")
def _density(self, q: Vec3, /, t: FloatOrIntScalar) -> FloatScalar:
"""See ``density``."""
Expand Down Expand Up @@ -227,7 +229,7 @@ def density(
# ---------------------------------------
# Hessian

@partial_jit()
@partial(jax.jit)
@vectorize_method(signature="(3),()->(3,3)")
def _hessian(self, q: Vec3, /, t: FloatOrIntScalar) -> Matrix33:
"""See ``hessian``."""
Expand Down Expand Up @@ -285,7 +287,7 @@ def acceleration(
q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True)
return -self._gradient(q, t)

@partial_jit()
@partial(jax.jit)
def tidal_tensor(
self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike
) -> BatchMatrix33:
Expand Down Expand Up @@ -322,7 +324,7 @@ def tidal_tensor(
# =========================================================================
# Integrating orbits

@partial_jit()
@partial(jax.jit)
def _integrator_F(
self,
t: FloatScalar,
Expand All @@ -332,7 +334,7 @@ def _integrator_F(
"""Return the derivative of the phase-space position."""
return jnp.hstack([qp[3:6], self.acceleration(qp[0:3], t)]) # v, a

@partial_jit(static_argnames=("integrator",))
@partial(jax.jit, static_argnames=("integrator",))
def integrate_orbit(
self,
qp0: BatchVec6,
Expand Down
Loading
Loading