From 915307ed5dc51052dc4e8c33bfdfbfe6c6658335 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 7 Jan 2024 18:02:51 -0500 Subject: [PATCH] add array_namespace Signed-off-by: nstarman --- pyproject.toml | 6 ++++-- src/jax_quantity/_core.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c121f5a..ed0ef39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,8 +111,9 @@ warn_return_any = false [[tool.mypy.overrides]] module = [ - "array_api_jax_compat._dispatch.*", # TODO: resolve - "astropy.units.*", + "array_api.*", + "array_api_jax_compat.*", + "astropy.*", "equinox.*", "jax.*", "jaxtyping.*", @@ -139,6 +140,7 @@ ignore = [ "D213", # Multi-line docstring summary should start at the second line "FIX002", # Line contains TODO "ISC001", # Conflicts with formatter + "PD", # Pandas "PLR09", # Too many <...> "PLR2004", # Magic value used in comparison "PYI041", # Use `complex` instead of `int | complex` <- plum is more strict diff --git a/src/jax_quantity/_core.py b/src/jax_quantity/_core.py index 9ee3441..65aa816 100644 --- a/src/jax_quantity/_core.py +++ b/src/jax_quantity/_core.py @@ -5,8 +5,9 @@ import operator from dataclasses import replace -from typing import Any +from typing import TYPE_CHECKING, Any +import array_api_jax_compat import equinox as eqx import jax import jax.core @@ -15,6 +16,9 @@ from quax import ArrayValue, quaxify from typing_extensions import Self +if TYPE_CHECKING: + from array_api import ArrayAPINamespace + class Quantity(ArrayValue): # type: ignore[misc] """Represents an array, with each axis bound to a name.""" @@ -40,6 +44,12 @@ def aval(self) -> jax.core.ShapedArray: def enable_materialise(self, _: bool = True) -> Self: # noqa: FBT001, FBT002 return type(self)(self.value, self.unit) + # =============================================================== + # Array API + + def __array_namespace__(self, *, api_version: Any = None) -> "ArrayAPINamespace": + return array_api_jax_compat + # =============================================================== # Quantity @@ -62,6 +72,10 @@ def __getitem__(self, key: Any) -> "Quantity": # __rmul__ # __matmul__ # __rmatmul__ + __pow__ = quaxify(operator.pow) + __truediv__ = quaxify(operator.truediv) + + # Boolean __and__ = quaxify(operator.__and__) __gt__ = quaxify(operator.__gt__) __ge__ = quaxify(operator.__ge__)