Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: unit system refactor #142

Merged
merged 3 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
"__init__.py" = ["F403"]
"docs/conf.py" = ["A001", "INP001"]
"noxfile.py" = ["T20"]
"tests/**" = ["ANN", "S101", "T20"]
"tests/**" = ["ANN", "S101", "SLF001", "T20"]

[tool.ruff.lint.isort]
combine-as-imports = true
Expand Down
3 changes: 1 addition & 2 deletions src/unxt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from ._unxt.quantity import *
from ._version import version as __version__
from .unitsystems import AbstractUnitSystem, UnitSystem, unitsystem
from .unitsystems import AbstractUnitSystem, unitsystem

# isort: split
from . import _interop # noqa: F401 # register interop
Expand All @@ -20,7 +20,6 @@
# units systems
"unitsystems", # module
"AbstractUnitSystem", # base class
"UnitSystem", # main user-facing class
"unitsystem", # convenience constructor
]
__all__ += quantity.__all__
Expand Down
36 changes: 31 additions & 5 deletions src/unxt/_interop/unxt_interop_gala/unitsystems.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,42 @@
)
from plum import dispatch

from unxt.unitsystems import DimensionlessUnitSystem, UnitSystem, dimensionless
from unxt.unitsystems import AbstractUnitSystem, DimensionlessUnitSystem, dimensionless


@dispatch
def unitsystem(value: GalaUnitSystem, /) -> UnitSystem:
usys = UnitSystem(*value._core_units) # noqa: SLF001
usys._registry = value._registry # noqa: SLF001
return usys
def unitsystem(value: GalaUnitSystem, /) -> AbstractUnitSystem:
"""Return a `gala.units.UnitSystem` as a `unxt.AbstractUnitSystem`.

Examples
--------
>>> from gala.units import UnitSystem
>>> import astropy.units as u
>>> usys = UnitSystem(u.km, u.s, u.Msun, u.radian)

>>> from unxt import unitsystem
>>> unitsystem(usys)
LTMAUnitSystem(length=Unit("km"), time=Unit("s"),
mass=Unit("solMass"), angle=Unit("rad"))

"""
# Create a new unit system instance, and possibly class.
return unitsystem(*value._core_units) # noqa: SLF001


@dispatch # type: ignore[no-redef]
def unitsystem(_: GalaDimensionlessUnitSystem, /) -> DimensionlessUnitSystem:
"""Return a `gala.units.DimensionlessUnitSystem` as a `unxt.DimensionlessUnitSystem`.

Examples
--------
>>> from gala.units import DimensionlessUnitSystem
>>> import astropy.units as u
>>> usys = DimensionlessUnitSystem()

>>> from unxt import unitsystem
>>> unitsystem(usys)
DimensionlessUnitSystem()

""" # noqa: E501
return dimensionless
239 changes: 156 additions & 83 deletions src/unxt/_unxt/unitsystems/base.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,168 @@
"""Tools for representing systems of units using ``astropy.units``."""

__all__ = ["AbstractUnitSystem"]
__all__ = ["AbstractUnitSystem", "UNITSYSTEMS_REGISTRY"]

from collections.abc import Iterator
from typing import ClassVar, Union, cast
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar, get_args, get_type_hints

import astropy.units as u
from astropy.units import PhysicalType as Dimension
from astropy.units.physical import _physical_unit_mapping

from unxt._unxt.quantity.base import AbstractQuantity
from unxt._unxt.typing_ext import Unit
from .utils import get_dimension_name, is_annotated
from unxt._unxt.typing_ext import Unit as UnitT

Unit = u.UnitBase

_UNITSYSTEMS_REGISTRY: dict[tuple[Dimension, ...], type["AbstractUnitSystem"]] = {}
UNITSYSTEMS_REGISTRY = MappingProxyType(_UNITSYSTEMS_REGISTRY)


def parse_field_names_and_dimensions(
cls: type,
) -> tuple[tuple[str, ...], tuple[Dimension, ...]]:
# Register class with a tuple of it's dimensions.
# This requires processing the type hints, not the dataclass fields
# since those are made after the original class is defined.
type_hints = get_type_hints(cls, include_extras=True)

field_names = []
dimensions = []
for name, type_hint in type_hints.items():
# Check it's Annotated
if not is_annotated(type_hint):
continue

# Get the arguments to Annotated
origin, *f_args = get_args(type_hint)

# Check that the first argument is a UnitBase
if not issubclass(origin, Unit):
continue

# Need for one of the arguments to be a PhysicalType
f_dims = [x for x in f_args if isinstance(x, Dimension)]
if not f_dims:
msg = f"Field {name!r} must be an Annotated with a dimension."
raise TypeError(msg)
if len(f_dims) > 1:
msg = (
f"Field {name!r} must be an Annotated with only one dimension; "
f"got {f_dims}"
)
raise TypeError(msg)

field_names.append(get_dimension_name(name))
dimensions.append(f_dims[0])

if len(set(dimensions)) < len(dimensions):
msg = "Some dimensions are repeated."
raise ValueError(msg)

return tuple(field_names), tuple(dimensions)


@dataclass(frozen=True, slots=True, eq=True)
class AbstractUnitSystem:
"""Represents a system of units.

This class behaves like a dictionary with keys set by physical types (i.e. "length",
"velocity", "energy", etc.). If a unit for a particular physical type is not
specified on creation, a composite unit will be created with the base units. See the
examples below for some demonstrations.
This class behaves like a dictionary with keys set by physical types (i.e.
"length", "velocity", "energy", etc.). If a unit for a particular physical
type is not specified on creation, a composite unit will be created with the
base units. See the examples below for some demonstrations.

Examples
--------
If only base units are specified, any physical type specified as a key to
this object will be composed out of the base units::

>>> from unxt import unitsystem
>>> import astropy.units as u
>>> usys = unitsystem(u.m, u.s, u.kg, u.radian)
>>> usys
LTMAUnitSystem(length=Unit("m"), time=Unit("s"), mass=Unit("kg"), angle=Unit("rad"))

>>> usys["velocity"]
Unit("m / s")

This unit system defines energy::

>>> usys = unitsystem(u.m, u.s, u.kg, u.radian, u.erg)
>>> usys["energy"]
Unit("erg")

This is useful for Galactic dynamics where lengths and times are usually
given in terms of ``kpc`` and ``Myr``, but velocities are often specified in
``km/s``::

>>> usys = unitsystem(u.kpc, u.Myr, u.Msun, u.radian, u.km/u.s)
>>> usys["velocity"]
Unit("km / s")

Unit systems can be hashed:

>>> isinstance(hash(usys), int)
True

And iterated over:

>>> [x for x in usys]
[Unit("kpc"), Unit("Myr"), Unit("solMass"), Unit("rad"), Unit("km / s")]

With length equal to the number of base units

>>> len(usys)
5

"""

_core_units: list[Unit]
_registry: dict[u.PhysicalType, Unit]

_required_dimensions: ClassVar[list[u.PhysicalType]] # do in subclass

def __init__(
self,
units: Union[Unit, u.Quantity, AbstractQuantity, "AbstractUnitSystem"],
*args: Unit | u.Quantity | AbstractQuantity,
) -> None:
if isinstance(units, AbstractUnitSystem):
if len(args) > 0:
msg = (
"If passing in a AbstractUnitSystem, "
"cannot pass in additional units."
)
raise ValueError(msg)

self._registry = units._registry.copy() # noqa: SLF001
self._core_units = units._core_units # noqa: SLF001
return

units = (units, *args)

self._registry = {}
for unit in units:
unit_ = ( # TODO: better detection of allowed unit base classes
unit if isinstance(unit, u.UnitBase) else u.def_unit(f"{unit!s}", unit)
)
if unit_.physical_type in self._registry:
msg = f"Multiple units passed in with type {unit_.physical_type!r}"
raise ValueError(msg)
self._registry[unit_.physical_type] = unit_

self._core_units = []
for phys_type in self._required_dimensions:
if phys_type not in self._registry:
msg = f"You must specify a unit for the physical type {phys_type!r}"
raise ValueError(msg)
self._core_units.append(self._registry[phys_type])

def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase:
# ===============================================================
# Class-level

_base_field_names: ClassVar[tuple[str, ...]]
_base_dimensions: ClassVar[tuple[Dimension, ...]]

def __init_subclass__(cls) -> None:
# Register class with a tuple of it's dimensions.
# This requires processing the type hints, not the dataclass fields
# since those are made after the original class is defined.
field_names, dimensions = parse_field_names_and_dimensions(cls)

# Check the unitsystem is not already registered
# If `make_dataclass(slots=True)` then the class is made twice, the
# second time adding the `__slots__` attribute
if dimensions in _UNITSYSTEMS_REGISTRY and "__slots__" not in cls.__dict__:
msg = f"Unit system with dimensions {dimensions} already exists."
raise ValueError(msg)

# Add attributes to the class
cls._base_field_names = tuple(field_names)
cls._base_dimensions = dimensions

_UNITSYSTEMS_REGISTRY[dimensions] = cls

# ===============================================================
# Instance-level

def __post_init__(self) -> None:
pass

@property # TODO: classproperty
def base_dimensions(self) -> tuple[Dimension, ...]:
"""Dimensions required for the unit system."""
return self._base_dimensions

@property
def base_units(self) -> tuple[UnitT, ...]:
"""List of core units."""
return tuple(getattr(self, k) for k in self._base_field_names)

def __getitem__(self, key: Dimension | str) -> UnitT:
key = u.get_physical_type(key)
if key in self._required_dimensions:
return self._registry[key]
if key in self.base_dimensions:
return getattr(self, get_dimension_name(key))

unit = None
for k, v in _physical_unit_mapping.items():
Expand All @@ -78,41 +174,18 @@
msg = f"Physical type '{key}' doesn't exist in unit registry."
raise ValueError(msg)

unit = unit.decompose(self._core_units)
unit = unit.decompose(self.base_units)
unit._scale = 1.0 # noqa: SLF001
return unit

def __len__(self) -> int:
# Note: This is required for q.decompose(usys) to work, where q is a Quantity
return len(self._core_units)

def __iter__(self) -> Iterator[Unit]:
yield from self._core_units
return len(self.base_dimensions)

def __repr__(self) -> str:
return f"{type(self).__name__}({', '.join(map(str, self._core_units))})"
# TODO: should this be changed to _base_field_names -> Iterator[str]?
def __iter__(self) -> Iterator[UnitT]:
yield from self.base_units

def __eq__(self, other: object) -> bool:
if not isinstance(other, AbstractUnitSystem):
return NotImplemented
return bool(self._registry == other._registry)

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __hash__(self) -> int:
"""Hash the unit system."""
return hash(tuple(self._core_units) + tuple(self._required_dimensions))

def preferred(self, key: str | u.PhysicalType) -> Unit:
"""Return the preferred unit for a given physical type."""
key = u.get_physical_type(key)
if key in self._registry:
return self._registry[key]
return self[key]

def as_preferred(self, quantity: AbstractQuantity | u.Quantity) -> AbstractQuantity:
"""Convert a quantity to the preferred unit for this unit system."""
unit = self.preferred(quantity.unit.physical_type)
# Note that it's necessary to
return cast(AbstractQuantity, quantity.to(unit))
def __str__(self) -> str:
fs = ", ".join(map(str, self._base_field_names))
return f"{type(self).__name__}({fs})"

Check warning on line 191 in src/unxt/_unxt/unitsystems/base.py

View check run for this annotation

Codecov / codecov/patch

src/unxt/_unxt/unitsystems/base.py#L190-L191

Added lines #L190 - L191 were not covered by tests
Loading