Skip to content

Commit

Permalink
feat: device_put_p kwargs (#126)
Browse files Browse the repository at this point in the history
* feat: device_put_p kwargs
* test: add jax ignore
* docs: fix np 2.0 docs change
* docs: skip tests in Windows until fixed numpy 2.0+

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jun 19, 2024
1 parent ddd147d commit c5f7c64
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
21 changes: 14 additions & 7 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
"""Doctest configuration."""

from __future__ import annotations

import platform
from doctest import ELLIPSIS, NORMALIZE_WHITESPACE

from sybil import Sybil
from sybil.parsers.rest import DocTestParser, PythonCodeBlockParser, SkipParser

# TODO: stop skipping doctests on Windows when there is uniform support for
# numpy 2.0+ scalar repr. On windows it is printed as 1.0 instead of
# `np.float64(1.0)`.
parsers = (
[DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE)]
if platform.system() != "Windows"
else []
) + [
PythonCodeBlockParser(),
SkipParser(),
]

pytest_collect_file = Sybil(
parsers=[
DocTestParser(optionflags=ELLIPSIS | NORMALIZE_WHITESPACE),
PythonCodeBlockParser(),
SkipParser(),
],
parsers=parsers,
patterns=["*.rst", "*.py"],
).pytest()
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@
filterwarnings = [
"error",
"ignore:ast\\.Str is deprecated and will be removed in Python 3\\.14:DeprecationWarning",
# jax
"ignore:jax\\.core\\.pp_eqn_rules is deprecated:DeprecationWarning",
]
log_cli_level = "INFO"
minversion = "6.0"
Expand Down
2 changes: 1 addition & 1 deletion src/unxt/_quantity/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def to_units_value(value: AstropyQuantity, units: Unit | str, /) -> ArrayLike:
>>> q = u.Quantity(1, "m")
>>> to_units_value(q, "cm")
100.0
np.float64(100.0)
"""
return value.to_value(units)
4 changes: 2 additions & 2 deletions src/unxt/_quantity/register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ def _cumsum_p(


@register(lax.device_put_p)
def _device_put_p(x: AbstractQuantity, *, device: Any, src: Any) -> AbstractQuantity:
def _device_put_p(x: AbstractQuantity, **kwargs: Any) -> AbstractQuantity:
"""Put a quantity on a device.
Examples
Expand All @@ -1418,7 +1418,7 @@ def _device_put_p(x: AbstractQuantity, *, device: Any, src: Any) -> AbstractQuan
Quantity['length'](Array(1, dtype=int32, ...), unit='m')
"""
return replace(x, value=jax.device_put(x.value, device=device, src=src))
return replace(x, value=jax.device_put(x.value, **kwargs))


# ==============================================================================
Expand Down

0 comments on commit c5f7c64

Please sign in to comment.