From bef32751b2ad2c0e00f9c920579cc57b97a82724 Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan Date: Tue, 2 Jan 2024 18:45:56 -0500 Subject: [PATCH] Some enhancements to the unit system stuff (#45) * remove gala license - blessed by @adrn * add docstring and type hints for input * docstring fixes * Add nox command to run doctests * implementation of preferred() * ignore ruff telling me pickle is unsafe in the tests * add as_preferred and unit tests of units * remove no type check --- noxfile.py | 17 ++++++- pyproject.toml | 2 +- src/galax/units.py | 104 ++++++++++++++++++++++++++------------- tests/unit/test_units.py | 91 ++++++++++++++++++++++++++++++++++ 4 files changed, 178 insertions(+), 36 deletions(-) create mode 100644 tests/unit/test_units.py diff --git a/noxfile.py b/noxfile.py index 227ff1e0..4be50f7a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -8,7 +8,7 @@ DIR = Path(__file__).parent.resolve() -nox.options.sessions = ["lint", "tests"] +nox.options.sessions = ["lint", "tests", "doctests"] @nox.session @@ -38,6 +38,21 @@ def tests(session: nox.Session) -> None: session.run("pytest", *session.posargs) +@nox.session +def doctests(session: nox.Session) -> None: + """Run the regular tests and doctests.""" + session.install(".[test]") + session.run( + "pytest", + "--doctest-modules", + '--doctest-glob="*.rst"', + '--doctest-glob="*.md"', + "docs", + "src/galax", + *session.posargs, + ) + + @nox.session(reuse_venv=True) def docs(session: nox.Session) -> None: """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" diff --git a/pyproject.toml b/pyproject.toml index 09f77296..1a0f2ae5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F403"] "tests/**" = [ - "ANN", "D10", "E731", "INP001", "S101", "SLF001", "T20", + "ANN", "D10", "E731", "INP001", "S101", "S301", "SLF001", "T20", "TID252", # Relative imports from parent modules are banned ] "noxfile.py" = ["ERA001", "T20"] diff --git a/src/galax/units.py b/src/galax/units.py index 849c6702..09809901 100644 --- a/src/galax/units.py +++ b/src/galax/units.py @@ -1,31 +1,4 @@ -"""Paired down UnitSystem class from gala. - -See gala's license below. - -``` -The MIT License (MIT) - -Copyright (c) 2012-2023 Adrian M. Price-Whelan - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -``` -""" +"""Tools for representing systems of units using ``astropy.units``.""" __all__ = [ "UnitSystem", @@ -36,15 +9,58 @@ ] from collections.abc import Iterator -from typing import Any, ClassVar, no_type_check +from typing import ClassVar, Union import astropy.units as u from astropy.units.physical import _physical_unit_mapping -@no_type_check # TODO: get beartype working with this class UnitSystem: - """Represents a system of units.""" + """Represents a system of units. + + At minimum, this consists of a set of length, time, mass, and angle units, but may + also contain preferred representations for composite units. For example, the base + unit system could be ``{kpc, Myr, Msun, radian}``, but you can also specify a + preferred velocity unit, such as ``km/s``. + + This class behaves like a dictionary with keys set by physical types (i.e. "length", + "velocity", "energy", etc.). If a unit for a particular physical type is not + specified on creation, a composite unit will be created with the base units. See the + examples below for some demonstrations. + + Parameters + ---------- + *units, **units + The units that define the unit system. At minimum, this must contain length, + time, mass, and angle units. If passing in keyword arguments, the keys must be + valid :mod:`astropy.units` physical types. + + Examples + -------- + If only base units are specified, any physical type specified as a key + to this object will be composed out of the base units:: + + >>> usys = UnitSystem(u.m, u.s, u.kg, u.radian) + >>> usys["velocity"] + Unit("m / s") + + However, preferred representations for composite units can also be specified:: + + >>> usys = UnitSystem(u.m, u.s, u.kg, u.radian, u.erg) + >>> usys["energy"] + Unit("m2 kg / s2") + >>> usys.preferred("energy") + Unit("erg") + + This is useful for Galactic dynamics where lengths and times are usually given in + terms of ``kpc`` and ``Myr``, but velocities are often specified in ``km/s``:: + + >>> usys = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian, u.km/u.s) + >>> usys["velocity"] + Unit("kpc / Myr") + >>> usys.preferred("velocity") + Unit("km / s") + """ _core_units: list[u.UnitBase] _registry: dict[u.PhysicalType, u.UnitBase] @@ -56,8 +72,15 @@ class UnitSystem: u.get_physical_type("angle"), ] - # TODO: type hint `units` - def __init__(self, units: Any, *args: u.UnitBase) -> None: + def __init__( + self, + units: Union[ + u.UnitBase, + u.Quantity, + "UnitSystem", + ], + *args: u.UnitBase | u.Quantity, + ) -> None: if isinstance(units, UnitSystem): if len(args) > 0: msg = "If passing in a UnitSystem, cannot pass in additional units." @@ -87,7 +110,8 @@ def __init__(self, units: Any, *args: u.UnitBase) -> None: self._core_units.append(self._registry[phys_type]) def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase: - if key in self._registry: + key = u.get_physical_type(key) + if key in self._required_dimensions: return self._registry[key] unit = None @@ -105,6 +129,7 @@ def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase: return unit def __len__(self) -> int: + # Note: This is required for q.decompose(usys) to work, where q is a Quantity return len(self._core_units) def __iter__(self) -> Iterator[u.UnitBase]: @@ -125,6 +150,17 @@ def __hash__(self) -> int: """Hash the unit system.""" return hash(tuple(self._core_units) + tuple(self._required_dimensions)) + def preferred(self, key: str | u.PhysicalType) -> u.UnitBase: + """Return the preferred unit for a given physical type.""" + key = u.get_physical_type(key) + if key in self._registry: + return self._registry[key] + return self[key] + + def as_preferred(self, quantity: u.Quantity) -> u.Quantity: + """Convert a quantity to the preferred unit for this unit system.""" + return quantity.to(self.preferred(quantity.unit.physical_type)) + class DimensionlessUnitSystem(UnitSystem): """A unit system with only dimensionless units.""" diff --git a/tests/unit/test_units.py b/tests/unit/test_units.py new file mode 100644 index 00000000..3b940e87 --- /dev/null +++ b/tests/unit/test_units.py @@ -0,0 +1,91 @@ +# Standard library +import pickle + +# Third party +import astropy.units as u +import numpy as np +import pytest + +# This package +from galax.units import UnitSystem, dimensionless + + +def test_init(): + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) + + with pytest.raises( + ValueError, match="must specify a unit for the physical type .*mass" + ): + UnitSystem(u.kpc, u.Myr, u.radian) # no mass + + with pytest.raises( + ValueError, match="must specify a unit for the physical type .*angle" + ): + UnitSystem(u.kpc, u.Myr, u.Msun) + + with pytest.raises( + ValueError, match="must specify a unit for the physical type .*time" + ): + UnitSystem(u.kpc, u.radian, u.Msun) + + with pytest.raises( + ValueError, match="must specify a unit for the physical type .*length" + ): + UnitSystem(u.Myr, u.radian, u.Msun) + + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) + usys = UnitSystem(usys) + + +def test_quantity_init(): + usys = UnitSystem(5 * u.kpc, 50 * u.Myr, 1e5 * u.Msun, u.rad) + assert np.isclose((8 * u.Myr).decompose(usys).value, 8 / 50) + + +def test_preferred(): + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.km / u.s) + q = 15.0 * u.km / u.s + assert usys.preferred("velocity") == u.km / u.s + assert q.decompose(usys).unit == u.kpc / u.Myr + assert usys.as_preferred(q).unit == u.km / u.s + + +def test_dimensionless(): + assert dimensionless["dimensionless"] == u.one + assert dimensionless["length"] == u.one + + with pytest.raises(ValueError, match="can not be decomposed into"): + (15 * u.kpc).decompose(dimensionless) + + with pytest.raises(ValueError, match="are not convertible"): + dimensionless.as_preferred(15 * u.kpc) + + +def test_compare(): + usys1 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) + usys1_clone = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.mas / u.yr) + + usys2 = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun, u.kiloarcsecond / u.yr) + usys3 = UnitSystem(u.kpc, u.Myr, u.radian, u.kg, u.mas / u.yr) + + assert usys1 == usys1_clone + assert usys1_clone == usys1 + + assert usys1 != usys2 + assert usys2 != usys1 + + assert usys1 != usys3 + assert usys3 != usys1 + + +def test_pickle(tmpdir): + usys = UnitSystem(u.kpc, u.Myr, u.radian, u.Msun) + + path = tmpdir / "test.pkl" + with path.open(mode="wb") as f: + pickle.dump(usys, f) + + with path.open(mode="rb") as f: + usys2 = pickle.load(f) + + assert usys == usys2