Skip to content

Commit

Permalink
feat: move units to unxt
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 16, 2024
1 parent 8955ef9 commit 33079c9
Show file tree
Hide file tree
Showing 35 changed files with 48 additions and 464 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"Typing :: Typed",
]
dependencies = [
"astropy >= 5.3",
"astropy >= 6.0",
"beartype",
"coordinax @ git+https://github.com/GalacticDynamics/coordinax.git",
"diffrax",
Expand Down
3 changes: 1 addition & 2 deletions src/galax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
"coordinates",
"potential",
"dynamics",
"units",
"utils",
"typing",
]

from jax import config

from . import coordinates, dynamics, potential, typing, units, utils
from . import coordinates, dynamics, potential, typing, utils
from ._version import __version__

config.update("jax_enable_x64", True) # noqa: FBT003
Expand Down
2 changes: 0 additions & 2 deletions src/galax/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@ __all__ = [
"dynamics",
"potential",
"typing",
"units",
"utils",
]

from . import (
dynamics as dynamics,
potential as potential,
typing as typing,
units as units,
utils as utils,
)
from ._version import ( # type: ignore[attr-defined]
Expand Down
11 changes: 5 additions & 6 deletions src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Cartesian3DVector,
represent_as as vector_represent_as,
)
from unxt import Quantity
from unxt import Quantity, unitsystem

from .utils import getitem_broadscalartime_index
from galax.typing import (
Expand All @@ -30,7 +30,6 @@
BatchVec7,
BroadBatchFloatQScalar,
)
from galax.units import unitsystem

if TYPE_CHECKING:
from typing import Self
Expand Down Expand Up @@ -188,8 +187,8 @@ def w(self, *, units: Any) -> BatchVec6:
Parameters
----------
units : `galax.units.UnitSystem`, optional keyword-only
The unit system. :func:`~galax.units.unitsystem` is used to
units : `unxt.UnitSystem`, optional keyword-only
The unit system. :func:`~unxt.unitsystem` is used to
convert the input to a unit system.
Returns
Expand Down Expand Up @@ -228,8 +227,8 @@ def wt(self, *, units: Any) -> BatchVec7:
Parameters
----------
units : `galax.units.UnitSystem`, keyword-only
The unit system. :func:`~galax.units.unitsystem` is used to
units : `unxt.UnitSystem`, keyword-only
The unit system. :func:`~unxt.unitsystem` is used to
convert the input to a unit system.
Returns
Expand Down
3 changes: 1 addition & 2 deletions src/galax/coordinates/_psp/psp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as jnp

from coordinax import Abstract3DVector, Abstract3DVectorDifferential
from unxt import Quantity
from unxt import Quantity, UnitSystem

from .base import AbstractPhaseSpacePosition
from .utils import _p_converter, _q_converter
Expand All @@ -20,7 +20,6 @@
QVec1,
VecTime,
)
from galax.units import UnitSystem
from galax.utils._shape import batched_shape, expand_batch_dims, vector_batched_shape


Expand Down
3 changes: 2 additions & 1 deletion src/galax/dynamics/_dynamics/integrate/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from typing import Any, Protocol, runtime_checkable

from unxt import UnitSystem

import galax.typing as gt
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.units import UnitSystem
from galax.utils.dataclasses import _DataclassInstance


Expand Down
3 changes: 2 additions & 1 deletion src/galax/dynamics/_dynamics/integrate/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import equinox as eqx

from unxt import UnitSystem

from ._api import FCallable
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.typing import BatchQVecTime, BatchVec6, BatchVecTime, QVecTime, VecTime
from galax.units import UnitSystem


class AbstractIntegrator(eqx.Module, strict=True): # type: ignore[call-arg, misc]
Expand Down
3 changes: 1 addition & 2 deletions src/galax/dynamics/_dynamics/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
import jax

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem

import galax.typing as gt
from ._api import FCallable
from ._base import AbstractIntegrator
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.units import UnitSystem
from galax.utils import ImmutableDict
from galax.utils._jax import vectorize_method

Expand Down
2 changes: 1 addition & 1 deletion src/galax/dynamics/_dynamics/integrate/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def evaluate_orbit(
>>> import quaxed.array_api as xp # preferred over `jax.numpy`
>>> import galax.coordinates as gc
>>> import galax.potential as gp
>>> from galax.units import galactic
>>> from unxt.unitsystems import galactic
We can then create the point-mass' potential, with galactic units:
Expand Down
4 changes: 2 additions & 2 deletions src/galax/dynamics/_dynamics/mockstream/df/fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def d2phidr2(
Examples
--------
>>> from galax.potential import NFWPotential
>>> from galax.units import galactic
>>> from unxt.unitsystems import galactic
>>> pot = NFWPotential(m=1e12, r_s=20.0, units=galactic)
>>> d2phidr2(pot, xp.asarray([8.0, 0.0, 0.0]), t=0)
Array(-0.00017469, dtype=float64)
Expand Down Expand Up @@ -249,7 +249,7 @@ def tidal_radius(
Examples
--------
>>> from galax.potential import NFWPotential
>>> from galax.units import galactic
>>> from unxt.unitsystems import galactic
>>> pot = NFWPotential(m=1e12, r_s=20.0, units=galactic)
>>> x=xp.asarray([8.0, 0.0, 0.0])
>>> v=xp.asarray([8.0, 0.0, 0.0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from jax.lib.xla_bridge import get_backend

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem

from .core import MockStream
from .df import AbstractStreamDF
Expand All @@ -24,7 +24,6 @@
from galax.dynamics._dynamics.integrate._funcs import evaluate_orbit
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchVec6, FloatScalar, IntScalar, QVecTime, Vec6, VecN
from galax.units import UnitSystem

Carry: TypeAlias = tuple[IntScalar, VecN, VecN]

Expand Down
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
import quaxed.array_api as xp
from coordinax import Abstract3DVector, FourVector
from unxt import Quantity
from unxt.unitsystems import UnitSystem, dimensionless

import galax.typing as gt
from .utils import _convert_from_3dvec, convert_input_to_array, convert_inputs_to_arrays
from galax.coordinates import AbstractPhaseSpacePosition, PhaseSpacePosition
from galax.potential._potential.param.attr import ParametersAttribute
from galax.potential._potential.param.utils import all_parameters
from galax.units import UnitSystem, dimensionless
from galax.utils._collections import ImmutableDict
from galax.utils._jax import vectorize_method
from galax.utils._shape import batched_shape, expand_arr_dims, expand_batch_dims
Expand Down
9 changes: 4 additions & 5 deletions src/galax/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@
import jax

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem, unitsystem

import galax.typing as gt
from galax.potential._potential.base import default_constants
from galax.potential._potential.core import AbstractPotential
from galax.potential._potential.param import AbstractParameter, ParameterField
from galax.units import UnitSystem, unitsystem
from galax.utils import ImmutableDict
from galax.utils._jax import vectorize_method
from galax.utils.dataclasses import field
Expand Down Expand Up @@ -269,10 +268,10 @@ class TriaxialHernquistPotential(AbstractPotential):
or constant, like a Quantity. See
:class:`~galax.potential.ParameterField` for details.
units : :class:`~galax.units.UnitSystem`, keyword-only
units : :class:`~unxt.UnitSystem`, keyword-only
The unit system to use for the potential. This parameter accepts a
:class:`~galax.units.UnitSystem` or anything that can be converted to a
:class:`~galax.units.UnitSystem` using :func:`~galax.units.unitsystem`.
:class:`~unxt.UnitSystem` or anything that can be converted to a
:class:`~unxt.UnitSystem` using :func:`~unxt.unitsystem`.
Examples
--------
Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_potential/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import jax

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Quantity, UnitSystem, unitsystem

from .base import AbstractPotentialBase, default_constants
from galax.typing import BatchableRealScalarLike, BatchFloatScalar, BatchVec3
from galax.units import UnitSystem, unitsystem
from galax.utils import ImmutableDict
from galax.utils._misc import first

Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_potential/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@

import equinox as eqx

from unxt import Quantity
from unxt import Quantity, UnitSystem, unitsystem

from .base import AbstractPotentialBase, default_constants
from .composite import CompositePotential
from galax.typing import FloatScalar, RealScalar, Vec3
from galax.units import UnitSystem, unitsystem
from galax.utils import ImmutableDict


Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_potential/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import equinox as eqx

from coordinax.operators import OperatorSequence, simplify_op
from unxt import Quantity
from unxt import Quantity, UnitSystem

from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import (
Expand All @@ -18,7 +18,6 @@
BatchVec3,
RealScalar,
)
from galax.units import UnitSystem
from galax.utils import ImmutableDict


Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_potential/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import equinox as eqx

from unxt import Quantity
from unxt.unitsystems import UnitSystem, dimensionless, galactic, unitsystem

from .base import AbstractPotentialBase, default_constants
from .builtin import HernquistPotential, MiyamotoNagaiPotential, NFWPotential
from .composite import AbstractCompositePotential
from galax.units import UnitSystem, dimensionless, galactic, unitsystem
from galax.utils import ImmutableDict

T = TypeVar("T", bound=AbstractPotentialBase)
Expand Down Expand Up @@ -53,7 +53,7 @@ class MilkyWayPotential(AbstractCompositePotential):
Parameters
----------
units : `~galax.units.UnitSystem` (optional)
units : `~unxt.UnitSystem` (optional)
Set of non-reducable units that specify (at minimum) the
length, mass, time, and angle units.
disk : dict (optional)
Expand Down
3 changes: 1 addition & 2 deletions src/galax/potential/_potential/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

from coordinax import Abstract3DVector, Cartesian3DVector
from unxt import Quantity

from galax.units import DimensionlessUnitSystem, UnitSystem, dimensionless
from unxt.unitsystems import DimensionlessUnitSystem, UnitSystem, dimensionless


def convert_inputs_to_arrays(
Expand Down
Loading

0 comments on commit 33079c9

Please sign in to comment.