Skip to content

Commit

Permalink
add array_namespace
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 7, 2024
1 parent c1aeaec commit 915307e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ warn_return_any = false

[[tool.mypy.overrides]]
module = [
"array_api_jax_compat._dispatch.*", # TODO: resolve
"astropy.units.*",
"array_api.*",
"array_api_jax_compat.*",
"astropy.*",
"equinox.*",
"jax.*",
"jaxtyping.*",
Expand All @@ -139,6 +140,7 @@ ignore = [
"D213", # Multi-line docstring summary should start at the second line
"FIX002", # Line contains TODO
"ISC001", # Conflicts with formatter
"PD", # Pandas
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"PYI041", # Use `complex` instead of `int | complex` <- plum is more strict
Expand Down
16 changes: 15 additions & 1 deletion src/jax_quantity/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import operator
from dataclasses import replace
from typing import Any
from typing import TYPE_CHECKING, Any

import array_api_jax_compat
import equinox as eqx
import jax
import jax.core
Expand All @@ -15,6 +16,9 @@
from quax import ArrayValue, quaxify
from typing_extensions import Self

if TYPE_CHECKING:
from array_api import ArrayAPINamespace


class Quantity(ArrayValue): # type: ignore[misc]
"""Represents an array, with each axis bound to a name."""
Expand All @@ -40,6 +44,12 @@ def aval(self) -> jax.core.ShapedArray:
def enable_materialise(self, _: bool = True) -> Self: # noqa: FBT001, FBT002
return type(self)(self.value, self.unit)

# ===============================================================
# Array API

def __array_namespace__(self, *, api_version: Any = None) -> "ArrayAPINamespace":
return array_api_jax_compat

# ===============================================================
# Quantity

Expand All @@ -62,6 +72,10 @@ def __getitem__(self, key: Any) -> "Quantity":
# __rmul__
# __matmul__
# __rmatmul__
__pow__ = quaxify(operator.pow)
__truediv__ = quaxify(operator.truediv)

# Boolean
__and__ = quaxify(operator.__and__)
__gt__ = quaxify(operator.__gt__)
__ge__ = quaxify(operator.__ge__)
Expand Down

0 comments on commit 915307e

Please sign in to comment.