From c5f7c64700d95378ab9abb5470512ae56a95aa8e Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Wed, 19 Jun 2024 13:25:54 -0400 Subject: [PATCH] feat: device_put_p kwargs (#126) * feat: device_put_p kwargs * test: add jax ignore * docs: fix np 2.0 docs change * docs: skip tests in Windows until fixed numpy 2.0+ Signed-off-by: nstarman --- conftest.py | 21 ++++++++++++++------- pyproject.toml | 2 ++ src/unxt/_quantity/functional.py | 2 +- src/unxt/_quantity/register_primitives.py | 4 ++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/conftest.py b/conftest.py index fc6f007..677ccdd 100644 --- a/conftest.py +++ b/conftest.py @@ -1,17 +1,24 @@ """Doctest configuration.""" -from __future__ import annotations - +import platform from doctest import ELLIPSIS, NORMALIZE_WHITESPACE from sybil import Sybil from sybil.parsers.rest import DocTestParser, PythonCodeBlockParser, SkipParser +# TODO: stop skipping doctests on Windows when there is uniform support for +# numpy 2.0+ scalar repr. On windows it is printed as 1.0 instead of +# `np.float64(1.0)`. +parsers = ( + [DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE)] + if platform.system() != "Windows" + else [] +) + [ + PythonCodeBlockParser(), + SkipParser(), +] + pytest_collect_file = Sybil( - parsers=[ - DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE), - PythonCodeBlockParser(), - SkipParser(), - ], + parsers=parsers, patterns=["*.rst", "*.py"], ).pytest() diff --git a/pyproject.toml b/pyproject.toml index da07d24..cbf6e98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,8 @@ filterwarnings = [ "error", "ignore:ast\\.Str is deprecated and will be removed in Python 3\\.14:DeprecationWarning", + # jax + "ignore:jax\\.core\\.pp_eqn_rules is deprecated:DeprecationWarning", ] log_cli_level = "INFO" minversion = "6.0" diff --git a/src/unxt/_quantity/functional.py b/src/unxt/_quantity/functional.py index 1e2552b..79dbe34 100644 --- a/src/unxt/_quantity/functional.py +++ b/src/unxt/_quantity/functional.py @@ -156,7 +156,7 @@ def to_units_value(value: AstropyQuantity, units: Unit | str, /) -> ArrayLike: >>> q = u.Quantity(1, "m") >>> to_units_value(q, "cm") - 100.0 + np.float64(100.0) """ return value.to_value(units) diff --git a/src/unxt/_quantity/register_primitives.py b/src/unxt/_quantity/register_primitives.py index 6df96d8..cd06fc6 100644 --- a/src/unxt/_quantity/register_primitives.py +++ b/src/unxt/_quantity/register_primitives.py @@ -1400,7 +1400,7 @@ def _cumsum_p( @register(lax.device_put_p) -def _device_put_p(x: AbstractQuantity, *, device: Any, src: Any) -> AbstractQuantity: +def _device_put_p(x: AbstractQuantity, **kwargs: Any) -> AbstractQuantity: """Put a quantity on a device. Examples @@ -1418,7 +1418,7 @@ def _device_put_p(x: AbstractQuantity, *, device: Any, src: Any) -> AbstractQuan Quantity['length'](Array(1, dtype=int32, ...), unit='m') """ - return replace(x, value=jax.device_put(x.value, device=device, src=src)) + return replace(x, value=jax.device_put(x.value, **kwargs)) # ==============================================================================