diff --git a/conftest.py b/conftest.py index fc6f007..677ccdd 100644 --- a/conftest.py +++ b/conftest.py @@ -1,17 +1,24 @@ """Doctest configuration.""" -from __future__ import annotations - +import platform from doctest import ELLIPSIS, NORMALIZE_WHITESPACE from sybil import Sybil from sybil.parsers.rest import DocTestParser, PythonCodeBlockParser, SkipParser +# TODO: stop skipping doctests on Windows when there is uniform support for +# numpy 2.0+ scalar repr. On windows it is printed as 1.0 instead of +# `np.float64(1.0)`. +parsers = ( + [DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE)] + if platform.system() != "Windows" + else [] +) + [ + PythonCodeBlockParser(), + SkipParser(), +] + pytest_collect_file = Sybil( - parsers=[ - DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE), - PythonCodeBlockParser(), - SkipParser(), - ], + parsers=parsers, patterns=["*.rst", "*.py"], ).pytest() diff --git a/pyproject.toml b/pyproject.toml index da07d24..cbf6e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,8 @@ filterwarnings = [ "error", "ignore:ast\\.Str is deprecated and will be removed in Python 3\\.14:DeprecationWarning", + # jax + "ignore:jax\\.core\\.pp_eqn_rules is deprecated:DeprecationWarning", ] log_cli_level = "INFO" minversion = "6.0" diff --git a/src/unxt/_quantity/functional.py b/src/unxt/_quantity/functional.py index 1e2552b..79dbe34 100644 --- a/src/unxt/_quantity/functional.py +++ b/src/unxt/_quantity/functional.py @@ -156,7 +156,7 @@ def to_units_value(value: AstropyQuantity, units: Unit | str, /) -> ArrayLike: >>> q = u.Quantity(1, "m") >>> to_units_value(q, "cm") - 100.0 + np.float64(100.0) """ return value.to_value(units) diff --git a/src/unxt/_quantity/register_primitives.py b/src/unxt/_quantity/register_primitives.py index 6df96d8..cd06fc6 100644 --- a/src/unxt/_quantity/register_primitives.py +++ b/src/unxt/_quantity/register_primitives.py @@ -1400,7 +1400,7 @@ def _cumsum_p( @register(lax.device_put_p) -def _device_put_p(x: AbstractQuantity, *, device: Any, src: Any) -> AbstractQuantity: +def _device_put_p(x: AbstractQuantity, **kwargs: Any) -> AbstractQuantity: """Put a quantity on a device. Examples @@ -1418,7 +1418,7 @@ def _device_put_p(x: AbstractQuantity, *, device: Any, src: Any) -> AbstractQuan Quantity['length'](Array(1, dtype=int32, ...), unit='m') """ - return replace(x, value=jax.device_put(x.value, device=device, src=src)) + return replace(x, value=jax.device_put(x.value, **kwargs)) # ==============================================================================