Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some enhancements to the unit system stuff #45

Merged
merged 12 commits into from
Jan 2, 2024
15 changes: 15 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ def tests(session: nox.Session) -> None:
session.run("pytest", *session.posargs)


@nox.session
adrn marked this conversation as resolved.
Show resolved Hide resolved
def doctests(session: nox.Session) -> None:
"""Run the regular tests and doctests."""
session.install(".[test]")
session.run(
"pytest",
"--doctest-modules",
'--doctest-glob="*.rst"',
'--doctest-glob="*.md"',
"docs",
"src/galax",
*session.posargs,
)


@nox.session(reuse_venv=True)
def docs(session: nox.Session) -> None:
"""Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F403"]
"tests/**" = [
"ANN", "D10", "E731", "INP001", "S101", "SLF001", "T20",
"ANN", "D10", "E731", "INP001", "S101", "S301", "SLF001", "T20",
"TID252", # Relative imports from parent modules are banned
]
"noxfile.py" = ["ERA001", "T20"]
Expand Down
104 changes: 70 additions & 34 deletions src/galax/units.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,4 @@
"""Paired down UnitSystem class from gala.

See gala's license below.

```
The MIT License (MIT)

Copyright (c) 2012-2023 Adrian M. Price-Whelan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
"""
"""Tools for representing systems of units using ``astropy.units``."""

__all__ = [
"UnitSystem",
Expand All @@ -36,15 +9,58 @@
]

from collections.abc import Iterator
from typing import Any, ClassVar, no_type_check
from typing import ClassVar, Union

import astropy.units as u
from astropy.units.physical import _physical_unit_mapping


@no_type_check # TODO: get beartype working with this
adrn marked this conversation as resolved.
Show resolved Hide resolved
class UnitSystem:
"""Represents a system of units."""
"""Represents a system of units.

At minimum, this consists of a set of length, time, mass, and angle units, but may
also contain preferred representations for composite units. For example, the base
unit system could be ``{kpc, Myr, Msun, radian}``, but you can also specify a
preferred velocity unit, such as ``km/s``.

This class behaves like a dictionary with keys set by physical types (i.e. "length",
"velocity", "energy", etc.). If a unit for a particular physical type is not
specified on creation, a composite unit will be created with the base units. See the
examples below for some demonstrations.

Parameters
----------
*units, **units
The units that define the unit system. At minimum, this must contain length,
time, mass, and angle units. If passing in keyword arguments, the keys must be
valid :mod:`astropy.units` physical types.

Examples
--------
If only base units are specified, any physical type specified as a key
to this object will be composed out of the base units::

>>> usys = UnitSystem(u.m, u.s, u.kg, u.radian)
>>> usys["velocity"]
Unit("m / s")

However, preferred representations for composite units can also be specified::

>>> usys = UnitSystem(u.m, u.s, u.kg, u.radian, u.erg)
>>> usys["energy"]
Unit("m2 kg / s2")
>>> usys.preferred("energy")
Unit("erg")

This is useful for Galactic dynamics where lengths and times are usually given in
terms of ``kpc`` and ``Myr``, but velocities are often specified in ``km/s``::

>>> usys = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian, u.km/u.s)
>>> usys["velocity"]
Unit("kpc / Myr")
>>> usys.preferred("velocity")
Unit("km / s")
"""

_core_units: list[u.UnitBase]
_registry: dict[u.PhysicalType, u.UnitBase]
Expand All @@ -56,8 +72,15 @@ class UnitSystem:
u.get_physical_type("angle"),
]

# TODO: type hint `units`
def __init__(self, units: Any, *args: u.UnitBase) -> None:
def __init__(
self,
units: Union[
u.UnitBase,
u.Quantity,
"UnitSystem",
],
*args: u.UnitBase | u.Quantity,
) -> None:
if isinstance(units, UnitSystem):
if len(args) > 0:
msg = "If passing in a UnitSystem, cannot pass in additional units."
Expand Down Expand Up @@ -87,7 +110,8 @@ def __init__(self, units: Any, *args: u.UnitBase) -> None:
self._core_units.append(self._registry[phys_type])

def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase:
if key in self._registry:
key = u.get_physical_type(key)
if key in self._required_dimensions:
return self._registry[key]

unit = None
Expand All @@ -105,6 +129,7 @@ def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase:
return unit

def __len__(self) -> int:
# Note: This is required for q.decompose(usys) to work, where q is a Quantity
return len(self._core_units)

def __iter__(self) -> Iterator[u.UnitBase]:
Expand All @@ -125,6 +150,17 @@ def __hash__(self) -> int:
"""Hash the unit system."""
return hash(tuple(self._core_units) + tuple(self._required_dimensions))

def preferred(self, key: str | u.PhysicalType) -> u.UnitBase:
"""Return the preferred unit for a given physical type."""
key = u.get_physical_type(key)
if key in self._registry:
return self._registry[key]
return self[key]

def as_preferred(self, quantity: u.Quantity) -> u.Quantity:
"""Convert a quantity to the preferred unit for this unit system."""
return quantity.to(self.preferred(quantity.unit.physical_type))


class DimensionlessUnitSystem(UnitSystem):
"""A unit system with only dimensionless units."""
Expand Down
91 changes: 91 additions & 0 deletions tests/unit/test_units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Standard library
import pickle

# Third party
import astropy.units as u
import numpy as np
import pytest

# This package
from galax.units import UnitSystem, dimensionless


def test_init():
usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun)

with pytest.raises(
ValueError, match="must specify a unit for the physical type .*mass"
):
UnitSystem(u.kpc, u.Myr, u.radian) # no mass

with pytest.raises(
ValueError, match="must specify a unit for the physical type .*angle"
):
UnitSystem(u.kpc, u.Myr, u.Msun)

with pytest.raises(
ValueError, match="must specify a unit for the physical type .*time"
):
UnitSystem(u.kpc, u.radian, u.Msun)

with pytest.raises(
ValueError, match="must specify a unit for the physical type .*length"
):
UnitSystem(u.Myr, u.radian, u.Msun)

usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun)
usys = UnitSystem(usys)


def test_quantity_init():
usys = UnitSystem(5 * u.kpc, 50 * u.Myr, 1e5 * u.Msun, u.rad)
assert np.isclose((8 * u.Myr).decompose(usys).value, 8 / 50)


def test_preferred():
usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.km / u.s)
q = 15.0 * u.km / u.s
assert usys.preferred("velocity") == u.km / u.s
assert q.decompose(usys).unit == u.kpc / u.Myr
assert usys.as_preferred(q).unit == u.km / u.s


def test_dimensionless():
assert dimensionless["dimensionless"] == u.one
assert dimensionless["length"] == u.one

with pytest.raises(ValueError, match="can not be decomposed into"):
(15 * u.kpc).decompose(dimensionless)

with pytest.raises(ValueError, match="are not convertible"):
dimensionless.as_preferred(15 * u.kpc)


def test_compare():
usys1 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr)
usys1_clone = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr)

usys2 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.kiloarcsecond / u.yr)
usys3 = UnitSystem(u.kpc, u.Myr, u.radian, u.kg, u.mas / u.yr)

assert usys1 == usys1_clone
assert usys1_clone == usys1

assert usys1 != usys2
assert usys2 != usys1

assert usys1 != usys3
assert usys3 != usys1


def test_pickle(tmpdir):
usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun)

path = tmpdir / "test.pkl"
with path.open(mode="wb") as f:
pickle.dump(usys, f)

with path.open(mode="rb") as f:
usys2 = pickle.load(f)

assert usys == usys2