From 96df8b3c308dfe52c9ffe4fcc9149ff78464605f Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 28 Jan 2024 16:46:40 -0500 Subject: [PATCH 1/4] refactor(jax-utils): remove partial_jit Signed-off-by: nstarman --- src/galax/dynamics/_core.py | 17 ++++----- src/galax/dynamics/_orbit.py | 6 ++-- src/galax/dynamics/mockstream/_core.py | 6 ++-- src/galax/dynamics/mockstream/_df/base.py | 4 +-- src/galax/dynamics/mockstream/_df/fardal.py | 18 +++++----- .../mockstream/_mockstream_generator.py | 8 ++--- src/galax/potential/_potential/base.py | 20 ++++++----- src/galax/potential/_potential/builtin.py | 18 +++++----- src/galax/potential/_potential/composite.py | 6 ++-- src/galax/potential/_potential/param/core.py | 10 +++--- src/galax/utils/_jax.py | 35 +------------------ src/galax/utils/_shape.py | 6 ++-- tests/unit/potential/test_base.py | 6 ++-- tests/unit/potential/test_core.py | 6 ++-- tests/unit/utils/test_jax.py | 22 +----------- 15 files changed, 77 insertions(+), 111 deletions(-) diff --git a/src/galax/dynamics/_core.py b/src/galax/dynamics/_core.py index fa519452..940f6841 100644 --- a/src/galax/dynamics/_core.py +++ b/src/galax/dynamics/_core.py @@ -3,15 +3,16 @@ __all__ = ["AbstractPhaseSpacePosition", "PhaseSpacePosition"] from abc import abstractmethod +from functools import partial from typing import TYPE_CHECKING, final import equinox as eqx +import jax import jax.experimental.array_api as xp import jax.numpy as jnp from jaxtyping import Array, Float from galax.typing import BatchFloatScalar, BatchVec3, BatchVec6, BatchVec7 -from galax.utils import partial_jit from galax.utils._shape import atleast_batched, batched_shape from galax.utils.dataclasses import converter_float_array @@ -49,7 +50,7 @@ def shape(self) -> tuple[int, ...]: # Convenience properties @property - @partial_jit() + @partial(jax.jit) def qp(self) -> BatchVec6: """Return as a single Array[float, (*batch, Q + P),].""" batch_shape, component_shapes = self._shape_tuple @@ -66,7 +67,7 @@ def __len__(self) -> int: # ========================================================================== # Dynamical quantities - @partial_jit() + @partial(jax.jit) def kinetic_energy(self) -> BatchFloatScalar: r"""Return the specific kinetic energy. @@ -100,7 +101,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: # Convenience properties @property - @partial_jit() + @partial(jax.jit) def w(self) -> BatchVec7: """Return as a single Array[float, (*batch, Q + P + T)].""" batch_shape, component_shapes = self._shape_tuple @@ -112,7 +113,7 @@ def w(self) -> BatchVec7: return xp.concat((q, p, t), axis=-1) @property - @partial_jit() + @partial(jax.jit) def angular_momentum(self) -> BatchVec3: r"""Compute the angular momentum. @@ -144,7 +145,7 @@ def angular_momentum(self) -> BatchVec3: # ========================================================================== # Dynamical quantities - @partial_jit() + @partial(jax.jit) def potential_energy( self, potential: "AbstractPotentialBase", / ) -> BatchFloatScalar: @@ -166,7 +167,7 @@ def potential_energy( """ return potential.potential_energy(self.q, t=self.t) - @partial_jit() + @partial(jax.jit) def energy(self, potential: "AbstractPotentialBase", /) -> BatchFloatScalar: r"""Return the specific total energy. @@ -204,7 +205,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: return batch_shape, array_shape @property - @partial_jit() + @partial(jax.jit) def w(self) -> BatchVec7: """Return as a single Array[float, (*batch, Q + P + T)].""" batch_shape, component_shapes = self._shape_tuple diff --git a/src/galax/dynamics/_orbit.py b/src/galax/dynamics/_orbit.py index b8bac673..4c975e0c 100644 --- a/src/galax/dynamics/_orbit.py +++ b/src/galax/dynamics/_orbit.py @@ -2,13 +2,15 @@ __all__ = ["Orbit"] +from functools import partial + import equinox as eqx +import jax from jaxtyping import Array, Float from typing_extensions import override from galax.potential._potential.base import AbstractPotentialBase from galax.typing import BatchFloatScalar, TimeVector -from galax.utils._jax import partial_jit from galax.utils.dataclasses import converter_float_array from ._core import AbstractPhaseSpacePosition @@ -38,7 +40,7 @@ class Orbit(AbstractPhaseSpacePosition): # Dynamical quantities @override - @partial_jit() + @partial(jax.jit) def potential_energy( self, potential: AbstractPotentialBase | None = None, / ) -> BatchFloatScalar: diff --git a/src/galax/dynamics/mockstream/_core.py b/src/galax/dynamics/mockstream/_core.py index eecffd96..21e4aca4 100644 --- a/src/galax/dynamics/mockstream/_core.py +++ b/src/galax/dynamics/mockstream/_core.py @@ -2,14 +2,16 @@ __all__ = ["MockStream"] +from functools import partial + import equinox as eqx +import jax import jax.experimental.array_api as xp import jax.numpy as jnp from jaxtyping import Array, Float from galax.dynamics._core import AbstractPhaseSpacePositionBase from galax.typing import BatchVec7, TimeVector -from galax.utils import partial_jit from galax.utils._shape import atleast_batched, batched_shape from galax.utils.dataclasses import converter_float_array @@ -44,7 +46,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: return batch_shape, qshape + pshape + (1,) @property - @partial_jit() + @partial(jax.jit) def w(self) -> BatchVec7: """Return as a single Array[float, (*batch, Q + P + T)].""" batch_shape, component_shapes = self._shape_tuple diff --git a/src/galax/dynamics/mockstream/_df/base.py b/src/galax/dynamics/mockstream/_df/base.py index 949b2255..611d0f7e 100644 --- a/src/galax/dynamics/mockstream/_df/base.py +++ b/src/galax/dynamics/mockstream/_df/base.py @@ -4,6 +4,7 @@ __all__ = ["AbstractStreamDF"] import abc +from functools import partial from typing import TypeAlias import equinox as eqx @@ -14,7 +15,6 @@ from galax.dynamics.mockstream._core import MockStream from galax.potential._potential.base import AbstractPotentialBase from galax.typing import BatchVec3, FloatScalar, IntLike, Vec3, Vec6 -from galax.utils import partial_jit Wif: TypeAlias = tuple[Vec3, Vec3, Vec3, Vec3] Carry: TypeAlias = tuple[IntLike, Vec3, Vec3, Vec3, Vec3] @@ -31,7 +31,7 @@ def __post_init__(self) -> None: msg = "You must generate either leading or trailing tails (or both!)" raise ValueError(msg) - @partial_jit(static_argnames=("seed_num",)) + @partial(jax.jit, static_argnames=("seed_num",)) def sample( self, # <\ parts of gala's ``prog_orbit`` diff --git a/src/galax/dynamics/mockstream/_df/fardal.py b/src/galax/dynamics/mockstream/_df/fardal.py index 4832df2e..728daa49 100644 --- a/src/galax/dynamics/mockstream/_df/fardal.py +++ b/src/galax/dynamics/mockstream/_df/fardal.py @@ -4,6 +4,9 @@ __all__ = ["FardalStreamDF"] +from functools import partial + +import jax import jax.experimental.array_api as xp import jax.numpy as jnp from jax import grad, random @@ -16,7 +19,6 @@ Vec3, Vec6, ) -from galax.utils import partial_jit from .base import AbstractStreamDF @@ -29,7 +31,7 @@ class FardalStreamDF(AbstractStreamDF): https://ui.adsabs.harvard.edu/abs/2015MNRAS.452..301F/abstract """ - @partial_jit(static_argnums=(0,), static_argnames=("seed_num",)) + @partial(jax.jit, static_argnums=(0,), static_argnames=("seed_num",)) def _sample( self, potential: AbstractPotentialBase, @@ -115,7 +117,7 @@ def _sample( # TODO: move this to a more general location. -@partial_jit() +@partial(jax.jit) def dphidr(potential: AbstractPotentialBase, x: Vec3, t: FloatScalar) -> Vec3: """Compute the derivative of the potential at a position x. @@ -137,7 +139,7 @@ def dphidr(potential: AbstractPotentialBase, x: Vec3, t: FloatScalar) -> Vec3: return xp.sum(potential.gradient(x, t) * r_hat) -@partial_jit() +@partial(jax.jit) def d2phidr2( potential: AbstractPotentialBase, x: Vec3, /, t: FloatOrIntScalarLike ) -> FloatScalar: @@ -172,7 +174,7 @@ def d2phidr2( return xp.sum(grad(dphi_dr_func)(x) * r_hat) -@partial_jit() +@partial(jax.jit) def orbital_angular_velocity(x: Vec3, v: Vec3, /) -> Vec3: """Compute the orbital angular velocity about the origin. @@ -199,7 +201,7 @@ def orbital_angular_velocity(x: Vec3, v: Vec3, /) -> Vec3: return jnp.cross(x, v) / r**2 -@partial_jit() +@partial(jax.jit) def orbital_angular_velocity_mag(x: Vec3, v: Vec3, /) -> FloatScalar: """Compute the magnitude of the angular momentum in the simulation frame. @@ -225,7 +227,7 @@ def orbital_angular_velocity_mag(x: Vec3, v: Vec3, /) -> FloatScalar: return xp.linalg.vector_norm(orbital_angular_velocity(x, v)) -@partial_jit() +@partial(jax.jit) def tidal_radius( potential: AbstractPotentialBase, x: Vec3, @@ -271,7 +273,7 @@ def tidal_radius( ) ** (1.0 / 3.0) -@partial_jit() +@partial(jax.jit) def lagrange_points( potential: AbstractPotentialBase, x: Vec3, diff --git a/src/galax/dynamics/mockstream/_mockstream_generator.py b/src/galax/dynamics/mockstream/_mockstream_generator.py index 0e1c11e8..393de9a8 100644 --- a/src/galax/dynamics/mockstream/_mockstream_generator.py +++ b/src/galax/dynamics/mockstream/_mockstream_generator.py @@ -3,6 +3,7 @@ __all__ = ["MockStreamGenerator"] from dataclasses import KW_ONLY +from functools import partial from typing import Any, TypeAlias import equinox as eqx @@ -23,7 +24,6 @@ Vec6, VecN, ) -from galax.utils import partial_jit from galax.utils._collections import ImmutableDict from ._core import MockStream @@ -59,7 +59,7 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc] # ========================================================================== - @partial_jit() + @partial(jax.jit) def _run_scan( # TODO: output shape depends on the input shape self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream ) -> tuple[BatchVec6, BatchVec6]: @@ -92,7 +92,7 @@ def integ_ics(ics: Vec6) -> VecN: return lead_arm_qp, trail_arm_qp - @partial_jit() + @partial(jax.jit) def _run_vmap( # TODO: output shape depends on the input shape self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream ) -> tuple[BatchVec6, BatchVec6]: @@ -123,7 +123,7 @@ def single_particle_integrate( lead_arm_qp, trail_arm_qp = integrator(particle_ids, qp0_lead, mock0_trail.qp) return lead_arm_qp, trail_arm_qp - @partial_jit(static_argnames=("seed_num", "vmapped")) + @partial(jax.jit, static_argnames=("seed_num", "vmapped")) def run( self, ts: TimeVector, diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index b14316f3..d1fa6b53 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -2,10 +2,12 @@ import abc from dataclasses import KW_ONLY, fields, replace +from functools import partial from types import MappingProxyType from typing import TYPE_CHECKING, Any, ClassVar import equinox as eqx +import jax import jax.experimental.array_api as xp import jax.numpy as jnp from astropy.constants import G as _G # pylint: disable=no-name-in-module @@ -31,7 +33,7 @@ Vec6, ) from galax.units import UnitSystem, dimensionless -from galax.utils import partial_jit, vectorize_method +from galax.utils import vectorize_method from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims from galax.utils.dataclasses import ModuleMeta @@ -64,7 +66,7 @@ def __init_subclass__(cls) -> None: ########################################################################### # Abstract methods that must be implemented by subclasses - # @partial_jit() + # @partial(jax.jit) # @vectorize_method(signature="(3),()->()") @abc.abstractmethod def _potential_energy(self, q: Vec3, /, t: FloatOrIntScalar) -> FloatScalar: @@ -133,7 +135,7 @@ def potential_energy( q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True) return self._potential_energy(q, t) - @partial_jit() + @partial(jax.jit) def __call__( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: @@ -160,7 +162,7 @@ def __call__( # --------------------------------------- # Gradient - @partial_jit() + @partial(jax.jit) @vectorize_method(signature="(3),()->(3)") def _gradient(self, q: Vec3, /, t: FloatOrIntScalar) -> Vec3: """See ``gradient``.""" @@ -194,7 +196,7 @@ def gradient( # --------------------------------------- # Density - @partial_jit() + @partial(jax.jit) @vectorize_method(signature="(3),()->()") def _density(self, q: Vec3, /, t: FloatOrIntScalar) -> FloatScalar: """See ``density``.""" @@ -227,7 +229,7 @@ def density( # --------------------------------------- # Hessian - @partial_jit() + @partial(jax.jit) @vectorize_method(signature="(3),()->(3,3)") def _hessian(self, q: Vec3, /, t: FloatOrIntScalar) -> Matrix33: """See ``hessian``.""" @@ -285,7 +287,7 @@ def acceleration( q, t = convert_inputs_to_arrays(q, t, units=self.units, no_differentials=True) return -self._gradient(q, t) - @partial_jit() + @partial(jax.jit) def tidal_tensor( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchMatrix33: @@ -322,7 +324,7 @@ def tidal_tensor( # ========================================================================= # Integrating orbits - @partial_jit() + @partial(jax.jit) def _integrator_F( self, t: FloatScalar, @@ -332,7 +334,7 @@ def _integrator_F( """Return the derivative of the phase-space position.""" return jnp.hstack([qp[3:6], self.acceleration(qp[0:3], t)]) # v, a - @partial_jit(static_argnames=("integrator",)) + @partial(jax.jit, static_argnames=("integrator",)) def integrate_orbit( self, qp0: BatchVec6, diff --git a/src/galax/potential/_potential/builtin.py b/src/galax/potential/_potential/builtin.py index 67c3b784..2dd46d39 100644 --- a/src/galax/potential/_potential/builtin.py +++ b/src/galax/potential/_potential/builtin.py @@ -11,9 +11,11 @@ ] from dataclasses import KW_ONLY +from functools import partial from typing import final import astropy.units as u +import jax import jax.experimental.array_api as xp from galax.potential._potential.core import AbstractPotential @@ -27,7 +29,7 @@ FloatScalar, Vec3, ) -from galax.utils import partial_jit, vectorize_method +from galax.utils import vectorize_method from galax.utils.dataclasses import field mass = u.get_physical_type("mass") @@ -51,7 +53,7 @@ class BarPotential(AbstractPotential): c: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] Omega: AbstractParameter = ParameterField(dimensions=frequency) # type: ignore[assignment] - @partial_jit() + @partial(jax.jit) @vectorize_method(signature="(3),()->()") def _potential_energy(self, q: Vec3, /, t: FloatOrIntScalarLike) -> FloatScalar: ## First take the simulation frame coordinates and rotate them by Omega*t @@ -95,7 +97,7 @@ class HernquistPotential(AbstractPotential): m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] c: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] - @partial_jit() + @partial(jax.jit) def _potential_energy( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: @@ -113,7 +115,7 @@ class IsochronePotential(AbstractPotential): m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] b: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] - @partial_jit() + @partial(jax.jit) def _potential_energy( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: @@ -135,7 +137,7 @@ class KeplerPotential(AbstractPotential): m: AbstractParameter = ParameterField(dimensions=mass) # type: ignore[assignment] - @partial_jit() + @partial(jax.jit) def _potential_energy( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: @@ -154,7 +156,7 @@ class MiyamotoNagaiPotential(AbstractPotential): a: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] b: AbstractParameter = ParameterField(dimensions=length) # type: ignore[assignment] - @partial_jit() + @partial(jax.jit) def _potential_energy( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: @@ -179,7 +181,7 @@ class NFWPotential(AbstractPotential): _: KW_ONLY softening_length: FloatLike = field(default=0.001, static=True, dimensions=length) - @partial_jit() + @partial(jax.jit) def _potential_energy( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: @@ -196,7 +198,7 @@ def _potential_energy( class NullPotential(AbstractPotential): """Null potential, i.e. no potential.""" - @partial_jit() + @partial(jax.jit) def _potential_energy( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: diff --git a/src/galax/potential/_potential/composite.py b/src/galax/potential/_potential/composite.py index 7638b7eb..a37753d3 100644 --- a/src/galax/potential/_potential/composite.py +++ b/src/galax/potential/_potential/composite.py @@ -3,9 +3,11 @@ import uuid from dataclasses import KW_ONLY +from functools import partial from typing import Any, TypeVar, final import equinox as eqx +import jax import jax.experimental.array_api as xp from galax.typing import ( @@ -14,7 +16,7 @@ BatchVec3, ) from galax.units import UnitSystem -from galax.utils import ImmutableDict, partial_jit +from galax.utils import ImmutableDict from galax.utils._misc import first from .base import AbstractPotentialBase @@ -30,7 +32,7 @@ class AbstractCompositePotential( ): # === Potential === - @partial_jit() + @partial(jax.jit) def _potential_energy( self, q: BatchVec3, /, t: BatchableFloatOrIntScalarLike ) -> BatchFloatScalar: diff --git a/src/galax/potential/_potential/param/core.py b/src/galax/potential/_potential/param/core.py index a4681ca1..1f29a47e 100644 --- a/src/galax/potential/_potential/param/core.py +++ b/src/galax/potential/_potential/param/core.py @@ -4,10 +4,12 @@ import abc from dataclasses import KW_ONLY +from functools import partial from typing import Any, Protocol, runtime_checkable import astropy.units as u import equinox as eqx +import jax from galax.typing import ( BatchableFloatOrIntScalarLike, @@ -17,7 +19,7 @@ FloatScalar, Unit, ) -from galax.utils import partial_jit, vectorize_method +from galax.utils import vectorize_method from galax.utils.dataclasses import converter_float_array @@ -66,12 +68,12 @@ class ConstantParameter(AbstractParameter): unit: Unit = eqx.field(static=True, converter=u.Unit) # This is a workaround since vectorized methods don't support kwargs. - @partial_jit() + @partial(jax.jit) @vectorize_method(signature="()->()") def _call_helper(self, _: FloatOrIntScalar) -> FloatArrayAnyShape: return self.value - @partial_jit() + @partial(jax.jit) def __call__( self, t: BatchableFloatOrIntScalarLike = 0, **kwargs: Any ) -> FloatArrayAnyShape: @@ -137,6 +139,6 @@ class UserParameter(AbstractParameter): _: KW_ONLY unit: Unit = eqx.field(static=True, converter=u.Unit) - @partial_jit() + @partial(jax.jit) def __call__(self, t: FloatOrIntScalar, **kwargs: Any) -> FloatArrayAnyShape: return self.func(t, **kwargs) diff --git a/src/galax/utils/_jax.py b/src/galax/utils/_jax.py index feb35e27..f13d8234 100644 --- a/src/galax/utils/_jax.py +++ b/src/galax/utils/_jax.py @@ -2,13 +2,12 @@ __all__ = [ - "partial_jit", "partial_vmap", "partial_vectorize", "vectorize_method", ] -from collections.abc import Callable, Hashable, Iterable, Sequence +from collections.abc import Callable, Hashable, Sequence from functools import partial from typing import Any, NotRequired, TypedDict, TypeVar @@ -19,38 +18,6 @@ R = TypeVar("R") -class JITKwargs(TypedDict): - """Keyword arguments for :func:`jax.jit`.""" - - in_shardings: NotRequired[Any] - out_shardings: NotRequired[Any] - static_argnums: NotRequired[int | Sequence[int] | None] - static_argnames: NotRequired[str | Iterable[str] | None] - donate_argnums: NotRequired[int | Sequence[int] | None] - donate_argnames: NotRequired[str | Iterable[str] | None] - keep_unused: NotRequired[bool] - device: NotRequired[jax.Device | None] - backend: NotRequired[str | None] - inline: NotRequired[bool] - - -def partial_jit(**kw: Unpack[JITKwargs]) -> Callable[[Callable[P, R]], Callable[P, R]]: - """Decorate a function with :func:`jax.jit`. - - Parameters - ---------- - **kw : Unpack[JITKwargs] - Keyword arguments for :func:`jax.jit`. - See :func:`jax.jit` for more information. - - Returns - ------- - :class:`~functools.partial` - A partial function to :func:`jax.jit` a function. - """ - return partial(jax.jit, **kw) - - # TODO: nest the definitions properly class VMapKwargs(TypedDict): """Keyword arguments for :func:`jax.vmap`.""" diff --git a/src/galax/utils/_shape.py b/src/galax/utils/_shape.py index 9b9b8f79..9fa3b478 100644 --- a/src/galax/utils/_shape.py +++ b/src/galax/utils/_shape.py @@ -2,16 +2,16 @@ __all__: list[str] = [] +from functools import partial from typing import Any, Literal, NoReturn, overload +import jax import jax.experimental.array_api as xp import jax.numpy as jnp from jaxtyping import Array, ArrayLike from galax.typing import AnyScalar, ArrayAnyShape -from ._jax import partial_jit - @overload def atleast_batched() -> NoReturn: @@ -30,7 +30,7 @@ def atleast_batched( ... -@partial_jit() +@partial(jax.jit) def atleast_batched(*arys: Any) -> Array | tuple[Array, ...]: """Convert inputs to arrays with at least two dimensions. diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index fa11d885..417187ac 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -1,7 +1,9 @@ import copy +from functools import partial from typing import Any import equinox as eqx +import jax import jax.experimental.array_api as xp import pytest from jax.numpy import array_equal @@ -16,7 +18,7 @@ Vec3, ) from galax.units import UnitSystem, dimensionless -from galax.utils import partial_jit, vectorize_method +from galax.utils import vectorize_method from .io.test_gala import GalaIOMixin @@ -33,7 +35,7 @@ class TestPotential(gp.AbstractPotentialBase): def __post_init__(self): object.__setattr__(self, "_G", 1.0) - @partial_jit() + @partial(jax.jit) @vectorize_method(signature="(3),()->()") def _potential_energy( self, q: BatchVec3, t: BatchableFloatOrIntScalarLike diff --git a/tests/unit/potential/test_core.py b/tests/unit/potential/test_core.py index fb6fa9e7..ff70dd80 100644 --- a/tests/unit/potential/test_core.py +++ b/tests/unit/potential/test_core.py @@ -1,7 +1,9 @@ from dataclasses import field +from functools import partial from typing import Any import equinox as eqx +import jax import jax.experimental.array_api as xp import pytest @@ -9,7 +11,7 @@ from galax.potential._potential.utils import converter_to_usys from galax.typing import BatchableFloatOrIntScalarLike, BatchFloatScalar, BatchVec3 from galax.units import UnitSystem, dimensionless, galactic -from galax.utils import partial_jit, vectorize_method +from galax.utils import vectorize_method from .test_base import TestAbstractPotentialBase as AbstractPotentialBase_Test from .test_utils import FieldUnitSystemMixin @@ -29,7 +31,7 @@ class TestPotential(gp.AbstractPotentialBase): def __post_init__(self): object.__setattr__(self, "_G", 1.0) - @partial_jit() + @partial(jax.jit) @vectorize_method(signature="(3),()->()") def _potential_energy( self, q: BatchVec3, t: BatchableFloatOrIntScalarLike diff --git a/tests/unit/utils/test_jax.py b/tests/unit/utils/test_jax.py index 94ea7a98..acf46f58 100644 --- a/tests/unit/utils/test_jax.py +++ b/tests/unit/utils/test_jax.py @@ -2,27 +2,7 @@ import jax.numpy as jnp from jaxtyping import Array, Float -from galax.utils import ( - partial_jit, - partial_vectorize, - partial_vmap, - vectorize_method, -) - - -def test_partial_jit(): - """Test the partial_jit function.""" - - def func(x, y): - return x + y - - jit_func = partial_jit()(func) - assert jit_func(1, 2) == 3 - - # The real test is comparing this to the output of `jax.jit`. - assert jit_func(1, 2) == jax.jit(func)(1, 2) - - # TODO: test all the kwarg options. +from galax.utils import partial_vectorize, partial_vmap, vectorize_method def test_partial_vmap(): From 795bc3a60d2b0cb076521215c571cc5e608fdadb Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 28 Jan 2024 16:51:07 -0500 Subject: [PATCH 2/4] refactor(jax-utils): remove partial_vmap --- src/galax/utils/_jax.py | 40 +++--------------------------------- tests/unit/utils/test_jax.py | 17 +-------------- 2 files changed, 4 insertions(+), 53 deletions(-) diff --git a/src/galax/utils/_jax.py b/src/galax/utils/_jax.py index f13d8234..ee8fb1c7 100644 --- a/src/galax/utils/_jax.py +++ b/src/galax/utils/_jax.py @@ -1,15 +1,11 @@ """galax: Galactic Dynamix in Jax.""" -__all__ = [ - "partial_vmap", - "partial_vectorize", - "vectorize_method", -] +__all__ = ["partial_vectorize", "vectorize_method"] -from collections.abc import Callable, Hashable, Sequence +from collections.abc import Callable, Sequence from functools import partial -from typing import Any, NotRequired, TypedDict, TypeVar +from typing import NotRequired, TypedDict, TypeVar import jax from typing_extensions import ParamSpec, Unpack @@ -18,36 +14,6 @@ R = TypeVar("R") -# TODO: nest the definitions properly -class VMapKwargs(TypedDict): - """Keyword arguments for :func:`jax.vmap`.""" - - in_axes: NotRequired[int | Sequence[Any] | dict[str, Any] | None] - out_axes: NotRequired[Any] - axis_name: NotRequired[Hashable | None] - axis_size: NotRequired[int | None] - spmd_axis_name: NotRequired[Hashable | tuple[Hashable, ...] | None] - - -def partial_vmap( - **kw: Unpack[VMapKwargs], -) -> Callable[[Callable[P, R]], Callable[P, R]]: - """Decorate a function with :func:`jax.vmap`. - - Parameters - ---------- - **kw : Unpack[VMapKwargs] - Keyword arguments for :func:`jax.vmap`. - See :func:`jax.vmap` for more information. - - Returns - ------- - :class:`~functools.partial` - A partial function to :func:`jax.vmap` a function. - """ - return partial(jax.vmap, **kw) - - class VectorizeKwargs(TypedDict): """Keyword arguments for :func:`jax.numpy.vectorize`.""" diff --git a/tests/unit/utils/test_jax.py b/tests/unit/utils/test_jax.py index acf46f58..5d43fc8f 100644 --- a/tests/unit/utils/test_jax.py +++ b/tests/unit/utils/test_jax.py @@ -1,22 +1,7 @@ -import jax import jax.numpy as jnp from jaxtyping import Array, Float -from galax.utils import partial_vectorize, partial_vmap, vectorize_method - - -def test_partial_vmap(): - """Test the partial_vmap function.""" - - def func(x: Float[Array, "batch N"]) -> Float[Array, "batch"]: - return jnp.sum(x) - - vmap_func = partial_vmap(in_axes=0)(func) - x = jnp.array([[1, 2, 3]]) - assert vmap_func(x) == 6 - - # The real test is comparing this to the output of `jax.vmap`. - assert vmap_func(x) == jax.vmap(func, in_axes=0)(x) +from galax.utils import partial_vectorize, vectorize_method def test_partial_vectorize(): From 6fcc6a8a44fe929c3d7d7bcf8b9c1bdbabf01ab3 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 28 Jan 2024 16:54:20 -0500 Subject: [PATCH 3/4] refactor(jax-utils): remove partial_vectorize --- src/galax/utils/_jax.py | 35 ++++++++--------------------------- tests/unit/utils/test_jax.py | 16 +--------------- 2 files changed, 9 insertions(+), 42 deletions(-) diff --git a/src/galax/utils/_jax.py b/src/galax/utils/_jax.py index ee8fb1c7..aeb25edb 100644 --- a/src/galax/utils/_jax.py +++ b/src/galax/utils/_jax.py @@ -1,7 +1,7 @@ """galax: Galactic Dynamix in Jax.""" -__all__ = ["partial_vectorize", "vectorize_method"] +__all__ = ["vectorize_method"] from collections.abc import Callable, Sequence from functools import partial @@ -21,45 +21,26 @@ class VectorizeKwargs(TypedDict): signature: NotRequired[str | None] -def partial_vectorize( - **kw: Unpack[VectorizeKwargs], -) -> Callable[[Callable[P, R]], Callable[P, R]]: - """Decorate a function with :func:`jax.numpy.vectorize`. - - Parameters - ---------- - **kw : Unpack[VMapKwargs] - Keyword arguments for :func:`jax.numpy.vectorize`. - See :func:`jax.numpy.vectorize` for more information. - - Returns - ------- - :class:`~functools.partial` - A partial function to :func:`jax.numpy.vectorize` a function. - """ - return partial(jax.numpy.vectorize, **kw) - - def vectorize_method( **kw: Unpack[VectorizeKwargs], ) -> Callable[[Callable[P, R]], Callable[P, R]]: """:func:`jax.numpy.vectorize` a class' method. This is a wrapper around :func:`jax.numpy.vectorize` that vectorizes a - class' method by returning a :class:`functools.partial`. It is equivalent to - :func:`partial_vectorize`, except that ``excluded`` is set to exclude the - 0th argument (``self``). As a result, the ``excluded`` tuple should start - at 0 to exclude the first 'real' argument (proceeding ``self``). + class' method by returning a :class:`functools.partial`. ``excluded`` is + set to exclude the 0th argument (``self``). As a result, the ``excluded`` + tuple should start at 0 to exclude the first 'real' argument (proceeding + ``self``). Parameters ---------- **kw : Unpack[VMapKwargs] - Keyword arguments for :func:`jax.numpy.vectorize`. - See :func:`jax.numpy.vectorize` for more information. + Keyword arguments for :func:`jax.numpy.vectorize`. See + :func:`jax.numpy.vectorize` for more information. """ # Prepend 0 to excluded to exclude the first argument (self) excluded = tuple(kw.get("excluded") or (-1,)) # (None -> (0,)) excluded = excluded if excluded[0] == -1 else (-1, *excluded) kw["excluded"] = tuple(i + 1 for i in excluded) - return partial_vectorize(**kw) + return partial(jax.numpy.vectorize, **kw) diff --git a/tests/unit/utils/test_jax.py b/tests/unit/utils/test_jax.py index 5d43fc8f..60a9c43b 100644 --- a/tests/unit/utils/test_jax.py +++ b/tests/unit/utils/test_jax.py @@ -1,21 +1,7 @@ import jax.numpy as jnp from jaxtyping import Array, Float -from galax.utils import partial_vectorize, vectorize_method - - -def test_partial_vectorize(): - """Test the partial_vectorize function.""" - - def func(x: Float[Array, "batch N"]) -> Float[Array, "batch"]: - return jnp.sum(x) - - vectorize_func = partial_vectorize(signature="(3)->()")(func) - assert vectorize_func(jnp.array([1, 2, 3])) == 6 - - # The real test is comparing this to the output of `jax.vectorize`. - x = jnp.array([1, 2, 3]) - assert vectorize_func(x) == jnp.vectorize(func, signature="(3)->()")(x) +from galax.utils import vectorize_method def test_vectorize_method(): From d53cb007dc65c8da93d6f166bf1b2f8766489f4e Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 28 Jan 2024 16:59:11 -0500 Subject: [PATCH 4/4] refactor(jax-utils): ensure vectorize_method is private Signed-off-by: nstarman --- src/galax/potential/_potential/base.py | 2 +- src/galax/potential/_potential/builtin.py | 2 +- src/galax/potential/_potential/param/core.py | 2 +- src/galax/utils/_jax.py | 2 +- tests/unit/potential/test_base.py | 2 +- tests/unit/potential/test_core.py | 2 +- tests/unit/utils/test_jax.py | 4 ++-- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/galax/potential/_potential/base.py b/src/galax/potential/_potential/base.py index d1fa6b53..2dd9e010 100644 --- a/src/galax/potential/_potential/base.py +++ b/src/galax/potential/_potential/base.py @@ -33,7 +33,7 @@ Vec6, ) from galax.units import UnitSystem, dimensionless -from galax.utils import vectorize_method +from galax.utils._jax import vectorize_method from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims from galax.utils.dataclasses import ModuleMeta diff --git a/src/galax/potential/_potential/builtin.py b/src/galax/potential/_potential/builtin.py index 2dd46d39..7d6b2d7f 100644 --- a/src/galax/potential/_potential/builtin.py +++ b/src/galax/potential/_potential/builtin.py @@ -29,7 +29,7 @@ FloatScalar, Vec3, ) -from galax.utils import vectorize_method +from galax.utils._jax import vectorize_method from galax.utils.dataclasses import field mass = u.get_physical_type("mass") diff --git a/src/galax/potential/_potential/param/core.py b/src/galax/potential/_potential/param/core.py index 1f29a47e..c9b06700 100644 --- a/src/galax/potential/_potential/param/core.py +++ b/src/galax/potential/_potential/param/core.py @@ -19,7 +19,7 @@ FloatScalar, Unit, ) -from galax.utils import vectorize_method +from galax.utils._jax import vectorize_method from galax.utils.dataclasses import converter_float_array diff --git a/src/galax/utils/_jax.py b/src/galax/utils/_jax.py index aeb25edb..5f6283d5 100644 --- a/src/galax/utils/_jax.py +++ b/src/galax/utils/_jax.py @@ -1,7 +1,7 @@ """galax: Galactic Dynamix in Jax.""" -__all__ = ["vectorize_method"] +__all__: list[str] = [] from collections.abc import Callable, Sequence from functools import partial diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index 417187ac..a331eb75 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -18,7 +18,7 @@ Vec3, ) from galax.units import UnitSystem, dimensionless -from galax.utils import vectorize_method +from galax.utils._jax import vectorize_method from .io.test_gala import GalaIOMixin diff --git a/tests/unit/potential/test_core.py b/tests/unit/potential/test_core.py index ff70dd80..c42860a1 100644 --- a/tests/unit/potential/test_core.py +++ b/tests/unit/potential/test_core.py @@ -11,7 +11,7 @@ from galax.potential._potential.utils import converter_to_usys from galax.typing import BatchableFloatOrIntScalarLike, BatchFloatScalar, BatchVec3 from galax.units import UnitSystem, dimensionless, galactic -from galax.utils import vectorize_method +from galax.utils._jax import vectorize_method from .test_base import TestAbstractPotentialBase as AbstractPotentialBase_Test from .test_utils import FieldUnitSystemMixin diff --git a/tests/unit/utils/test_jax.py b/tests/unit/utils/test_jax.py index 60a9c43b..cc9982b9 100644 --- a/tests/unit/utils/test_jax.py +++ b/tests/unit/utils/test_jax.py @@ -1,10 +1,10 @@ import jax.numpy as jnp from jaxtyping import Array, Float -from galax.utils import vectorize_method +from galax.utils._jax import vectorize_method -def test_vectorize_method(): +def test_vectorize_method() -> None: """Test the vectorize_method function.""" class A: