Skip to content

Commit

Permalink
refactor: gather astropy interop (#134)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jul 9, 2024
1 parent f678ef1 commit bd9a7d2
Show file tree
Hide file tree
Showing 15 changed files with 318 additions and 148 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@
[tool.pylint]
ignore-paths = [".*/_version.py", ".*/_compat.py"]
messages_control.disable = [
"cyclic-import", # TODO: resolve
"design",
"fixme",
"function-redefined", # plum-dispatch
Expand Down
3 changes: 3 additions & 0 deletions src/unxt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from ._version import version as __version__
from .unitsystems import AbstractUnitSystem, UnitSystem, unitsystem

# isort: split
from . import _unxt_interop_astropy # noqa: F401

if HAS_GALA:
from . import _unxt_interop_gala # noqa: F401

Expand Down
4 changes: 1 addition & 3 deletions src/unxt/_quantity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
Copyright (c) 2023 Galactic Dynamics. All rights reserved.
"""

from . import base, base_parametric, compat, core, distance, fast, functional, utils
from . import base, base_parametric, core, distance, fast, functional, utils
from .base import *
from .base_parametric import *
from .compat import *
from .core import *
from .distance import *
from .fast import *
Expand All @@ -25,4 +24,3 @@
__all__ += fast.__all__
__all__ += functional.__all__
__all__ += utils.__all__
__all__ += compat.__all__
55 changes: 7 additions & 48 deletions src/unxt/_quantity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
import jax.core
import jax.numpy as jnp
import numpy as np
from astropy.units import (
CompositeUnit,
Quantity as AstropyQuantity,
Unit,
UnitConversionError,
)
from astropy.units import CompositeUnit, UnitConversionError
from jax.numpy import dtype as DType # noqa: N812
from jaxtyping import Array, ArrayLike, Shaped
from plum import add_promotion_rule
Expand All @@ -28,6 +23,8 @@
import quaxed.operator as qoperator
from quaxed.array_api._dispatch import dispatcher

from unxt._units import Unit

if TYPE_CHECKING:
from array_api import ArrayAPINamespace

Expand Down Expand Up @@ -486,17 +483,6 @@ def __hash__(self) -> int:
"""
return hash(tuple(getattr(self, f.name) for f in fields(self)))

# ===============================================================
# I/O

def convert_to(self, format: type[FMT], /) -> FMT:
"""Convert to a type."""
if format is AstropyQuantity:
return AstropyQuantity(self.value, self.unit)

msg = f"Unknown format {format}."
raise TypeError(msg)


# -----------------------------------------------
# Register additional constructors
Expand Down Expand Up @@ -543,33 +529,6 @@ def constructor(
return cls(value.value, unit)


@AbstractQuantity.constructor._f.register # type: ignore[no-redef] # noqa: SLF001
def constructor(
cls: type[AbstractQuantity], value: AstropyQuantity, /, *, dtype: Any = None
) -> AbstractQuantity:
"""Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
"""
return cls(xp.asarray(value.value, dtype=dtype), value.unit)


@AbstractQuantity.constructor._f.register # type: ignore[no-redef] # noqa: SLF001
def constructor(
cls: type[AbstractQuantity],
value: AstropyQuantity,
unit: Any,
/,
*,
dtype: Any = None,
) -> AbstractQuantity:
"""Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
"""
return cls(xp.asarray(value.to_value(unit), dtype=dtype), unit)


# -----------------------------------------------
# Promotion rules

Expand All @@ -579,14 +538,14 @@ def constructor(
# ===============================================================


def can_convert_unit(from_: AbstractQuantity | Unit, to: Unit) -> bool:
def can_convert_unit(from_unit: AbstractQuantity | Unit, to_unit: Unit, /) -> bool:
"""Check if a unit can be converted to another unit.
Parameters
----------
from_ : :clas:`unxt.AbstractQuantity` | Unit
from_unit : :clas:`unxt.AbstractQuantity` | Unit
The unit to convert from.
to : Unit
to_unit : Unit
The unit to convert to.
Returns
Expand All @@ -596,7 +555,7 @@ def can_convert_unit(from_: AbstractQuantity | Unit, to: Unit) -> bool:
"""
try:
from_.to(to)
from_unit.to(to_unit)
except UnitConversionError:
return False
return True
5 changes: 3 additions & 2 deletions src/unxt/_quantity/base_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import equinox as eqx
import jax
import jax.core
from astropy.units import PhysicalType, Unit, UnitBase, get_physical_type
from astropy.units import PhysicalType, Unit, get_physical_type
from jaxtyping import Array, ArrayLike, Shaped
from plum import parametric

from quaxed.array_api._dispatch import dispatcher

from .base import AbstractQuantity
from unxt._typing import Unit as UnitTypes


@parametric
Expand Down Expand Up @@ -66,7 +67,7 @@ def __init_type_parameter__(cls, dimensions: str) -> tuple[PhysicalType]:

@classmethod # type: ignore[no-redef]
@dispatcher
def __init_type_parameter__(cls, unit: UnitBase) -> tuple[PhysicalType]:
def __init_type_parameter__(cls, unit: UnitTypes) -> tuple[PhysicalType]:
"""Infer the type parameter from the arguments."""
if unit.physical_type != "unknown":
return (unit.physical_type,)
Expand Down
31 changes: 0 additions & 31 deletions src/unxt/_quantity/compat.py

This file was deleted.

6 changes: 2 additions & 4 deletions src/unxt/_quantity/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import astropy.units as u
import equinox as eqx
import jax.numpy as jnp
from plum import add_conversion_method, add_promotion_rule
from plum import add_promotion_rule, conversion_method

import quaxed.array_api as xp
import quaxed.numpy as qnp
Expand Down Expand Up @@ -57,14 +57,12 @@ def distance_modulus(self) -> Quantity:
add_promotion_rule(AbstractDistance, Quantity, Quantity)


@conversion_method(type_from=AbstractDistance, type_to=Quantity) # type: ignore[misc]
def _convert_distance_to_quantity(x: AbstractDistance) -> Quantity:
"""Convert a distance to a quantity."""
return Quantity(x.value, x.unit)


add_conversion_method(AbstractDistance, Quantity, _convert_distance_to_quantity)


##############################################################################


Expand Down
43 changes: 0 additions & 43 deletions src/unxt/_quantity/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing import Any

from astropy.units import Quantity as AstropyQuantity
from jax.typing import ArrayLike
from plum import dispatch

Expand Down Expand Up @@ -66,27 +65,6 @@ def to_units(value: ArrayLike, units: Unit | str, /) -> Quantity:
return Quantity.constructor(value, units)


# ---------------------------
# Compat


@dispatch # type: ignore[no-redef]
def to_units(value: AstropyQuantity, units: Unit, /) -> Quantity:
"""Convert an Astropy Quantity to the given units.
Examples
--------
>>> from unxt import to_units
>>> import astropy.units as u
>>> q = u.Quantity(1, "m")
>>> to_units(q, "cm")
Quantity['length'](Array(100., dtype=float32), unit='cm')
"""
return Quantity.constructor(value, units)


# ============================================================================
# to_units_value

Expand Down Expand Up @@ -139,24 +117,3 @@ def to_units_value(value: ArrayLike, units: Unit | str, /) -> ArrayLike:
"""
return value


# ---------------------------
# Compat


@dispatch # type: ignore[no-redef]
def to_units_value(value: AstropyQuantity, units: Unit | str, /) -> ArrayLike:
"""Convert an Astropy Quantity to an array with the given units.
Examples
--------
>>> from unxt import to_units_value
>>> import astropy.units as u
>>> q = u.Quantity(1, "m")
>>> to_units_value(q, "cm")
np.float64(100.0)
"""
return value.to_value(units)
4 changes: 1 addition & 3 deletions src/unxt/_quantity/register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import equinox as eqx
import jax
from astropy.units import ( # pylint: disable=no-name-in-module
Unit,
UnitBase,
dimensionless_unscaled as one,
radian,
)
Expand All @@ -30,11 +28,11 @@
from .core import Quantity
from .distance import AbstractDistance
from .utils import type_unparametrized as type_np
from unxt._units import Unit

T = TypeVar("T")

Axes: TypeAlias = tuple[int, ...]
UnitClasses: TypeAlias = UnitBase


def register(primitive: Primitive) -> Callable[[T], T]:
Expand Down
10 changes: 10 additions & 0 deletions src/unxt/_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# pylint: disable=import-error

"""unxt: Quantities in JAX.
Copyright (c) 2023 Galactic Dynamics. All rights reserved.
"""

__all__: list[str] = ["Unit"]

from astropy.units import Unit
8 changes: 8 additions & 0 deletions src/unxt/_unxt_interop_astropy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Tools for representing systems of units using ``astropy.units``."""

__all__: list[str] = []

from . import ( # noqa: F401
quantity,
unitsystems,
)
Loading

0 comments on commit bd9a7d2

Please sign in to comment.