Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 7, 2024
1 parent c1aeaec commit 1f05303
Show file tree
Hide file tree
Showing 7 changed files with 1,868 additions and 65 deletions.
11 changes: 11 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ ci:
autofix_commit_msg: "style: pre-commit fixes"
autoupdate_schedule: "monthly"

default_install_hook_types: [pre-commit, pre-push, commit-msg]

repos:
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.16.0"
Expand Down Expand Up @@ -85,3 +87,12 @@ repos:
- id: check-dependabot
- id: check-github-workflows
- id: check-readthedocs

- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [python]
stages: [pre-push, manual]
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [

[project.optional-dependencies]
test = [
"astropy",
"pytest >=6",
"pytest-cov >=3",
]
Expand Down Expand Up @@ -111,8 +112,9 @@ warn_return_any = false

[[tool.mypy.overrides]]
module = [
"array_api_jax_compat._dispatch.*", # TODO: resolve
"astropy.units.*",
"array_api.*",
"array_api_jax_compat.*",
"astropy.*",
"equinox.*",
"jax.*",
"jaxtyping.*",
Expand All @@ -139,6 +141,7 @@ ignore = [
"D213", # Multi-line docstring summary should start at the second line
"FIX002", # Line contains TODO
"ISC001", # Conflicts with formatter
"PD", # Pandas
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"PYI041", # Use `complex` instead of `int | complex` <- plum is more strict
Expand Down Expand Up @@ -167,5 +170,6 @@ messages_control.disable = [
"missing-function-docstring", # TODO: resolve
"missing-module-docstring",
"redefined-builtin", # handled by ruff
"unused-argument", # handled by ruff
"wrong-import-position",
]
16 changes: 15 additions & 1 deletion src/jax_quantity/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import operator
from dataclasses import replace
from typing import Any
from typing import TYPE_CHECKING, Any

import array_api_jax_compat
import equinox as eqx
import jax
import jax.core
Expand All @@ -15,6 +16,9 @@
from quax import ArrayValue, quaxify
from typing_extensions import Self

if TYPE_CHECKING:
from array_api import ArrayAPINamespace


class Quantity(ArrayValue): # type: ignore[misc]
"""Represents an array, with each axis bound to a name."""
Expand All @@ -40,6 +44,12 @@ def aval(self) -> jax.core.ShapedArray:
def enable_materialise(self, _: bool = True) -> Self: # noqa: FBT001, FBT002
return type(self)(self.value, self.unit)

# ===============================================================
# Array API

def __array_namespace__(self, *, api_version: Any = None) -> "ArrayAPINamespace":
return array_api_jax_compat

# ===============================================================
# Quantity

Expand All @@ -62,6 +72,10 @@ def __getitem__(self, key: Any) -> "Quantity":
# __rmul__
# __matmul__
# __rmatmul__
__pow__ = quaxify(operator.pow)
__truediv__ = quaxify(operator.truediv)

# Boolean
__and__ = quaxify(operator.__and__)
__gt__ = quaxify(operator.__gt__)
__ge__ = quaxify(operator.__ge__)
Expand Down
58 changes: 53 additions & 5 deletions src/jax_quantity/_register_dispatches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import jax
import jax.core
import jax.numpy as jnp
import jax.experimental.array_api as jax_xp
from array_api_jax_compat._dispatch import dispatcher as dispatcher_
from array_api_jax_compat._types import DType
from jax import Device

from ._core import Quantity

Expand All @@ -15,9 +17,31 @@ def dispatcher(f: T) -> T: # TODO: figure out mypy stub issue.
return dispatcher_(f)


@dispatcher
def arange(
start: Quantity,
stop: Quantity | None = None,
step: Quantity | None = None,
*,
dtype: Any = None,
device: Any = None,
) -> Quantity:
unit = start.unit
return Quantity(
jax_xp.arange(
start.value,
stop=stop.to_value(unit) if stop is not None else None,
step=step.to_value(unit) if step is not None else None,
dtype=dtype,
device=device,
),
unit=unit,
)


@dispatcher
def empty_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity:
out = Quantity(jnp.empty_like(x.value, dtype=dtype), unit=x.unit)
out = Quantity(jax_xp.empty_like(x.value, dtype=dtype), unit=x.unit)
return jax.device_put(out, device=device)


Expand All @@ -30,17 +54,41 @@ def full_like(
dtype: Any = None,
device: Any = None,
) -> Quantity:
out = Quantity(jnp.full_like(x.value, fill_value, dtype=dtype), unit=x.unit)
out = Quantity(jax_xp.full_like(x.value, fill_value, dtype=dtype), unit=x.unit)
return jax.device_put(out, device=device)


@dispatcher
def linspace(
start: Quantity,
stop: Quantity,
num: int,
*,
dtype: DType | None = None,
device: Device | None = None,
endpoint: bool = True,
) -> Quantity:
unit = start.unit
return Quantity(
jax_xp.linspace(
start.to_value(unit),
stop.to_value(unit),
num=num,
dtype=dtype,
device=device,
endpoint=endpoint,
),
unit=unit,
)


@dispatcher
def ones_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity:
out = Quantity(jnp.ones_like(x.value, dtype=dtype), unit=x.unit)
out = Quantity(jax_xp.ones_like(x.value, dtype=dtype), unit=x.unit)
return jax.device_put(out, device=device)


@dispatcher
def zeros_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity:
out = Quantity(jnp.zeros_like(x.value, dtype=dtype), unit=x.unit)
out = Quantity(jax_xp.zeros_like(x.value, dtype=dtype), unit=x.unit)
return jax.device_put(out, device=device)
Loading

0 comments on commit 1f05303

Please sign in to comment.