Skip to content

Commit

Permalink
use exact astropy method names (#1)
Browse files Browse the repository at this point in the history
* use exact astropy method names
* fix ruff
* rm quax in pre-commit mypy
* add jaxtyping to mypy
* don’t include pypy

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 22, 2023
1 parent f422236 commit 9a91703
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 103 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ jobs:
python-version: ["3.10", "3.12"]
runs-on: [ubuntu-latest, macos-latest, windows-latest]

include:
- python-version: pypy-3.10
runs-on: ubuntu-latest
# include:
# - python-version: pypy-3.10
# runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
Expand Down
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ repos:
args: []
additional_dependencies:
- pytest
- quax

- repo: https://github.com/codespell-project/codespell
rev: "v2.2.6"
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ warn_return_any = false
"astropy.units.*",
"equinox.*",
"jax.*",
"jaxtyping.*",
"quax.*",
]
ignore_missing_imports = true
Expand All @@ -131,6 +132,7 @@ ignore = [
"ANN101", # Missing type annotation for `self` in method
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed
"ARG001", # Unused function argument # TODO: resolve
"COM812", # <- for ruff.format
"D103", # Missing docstring in public function # TODO: resolve
"D105", # Missing docstring in magic method
"D203", # 1 blank line required before class docstring
Expand Down
16 changes: 8 additions & 8 deletions src/jax_quantity/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Quantity(ArrayValue): # type: ignore[misc]
"""Represents an array, with each axis bound to a name."""

value: jax.Array = eqx.field(converter=jax.numpy.asarray)
units: Unit = eqx.field(static=True, converter=Unit)
unit: Unit = eqx.field(static=True, converter=Unit)

# ===============================================================
# Quax
Expand All @@ -35,21 +35,21 @@ def materialise(self) -> None:
raise RuntimeError(msg)

def aval(self) -> jax.core.ShapedArray:
return jax.core.get_aval(self.value) # type: ignore[no-untyped-call]
return jax.core.get_aval(self.value)

def enable_materialise(self, _: bool = True) -> Self: # noqa: FBT001, FBT002
return type(self)(self.value, self.units)
return type(self)(self.value, self.unit)

# ===============================================================
# Quantity

def to_units(self, units: Unit) -> "Quantity":
return type(self)(self.value * self.units.to(units), units)
def to(self, units: Unit) -> "Quantity":
return type(self)(self.value * self.unit.to(units), units)

def to_units_value(self, units: Unit) -> ArrayLike:
if units == self.units:
def to_value(self, units: Unit) -> ArrayLike:
if units == self.unit:
return self.value
return self.value * self.units.to(units)
return self.value * self.unit.to(units)

def __getitem__(self, key: Any) -> "Quantity":
return replace(self, value=self.value[key])
Expand Down
8 changes: 4 additions & 4 deletions src/jax_quantity/_register_dispatches.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def dispatcher(f: T) -> T: # TODO: figure out mypy stub issue.

@dispatcher
def empty_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity:
out = Quantity(jnp.empty_like(x.value, dtype=dtype), units=x.units)
out = Quantity(jnp.empty_like(x.value, dtype=dtype), units=x.unit)
return jax.device_put(out, device=device)


Expand All @@ -30,17 +30,17 @@ def full_like(
dtype: Any = None,
device: Any = None,
) -> Quantity:
out = Quantity(jnp.full_like(x.value, fill_value, dtype=dtype), units=x.units)
out = Quantity(jnp.full_like(x.value, fill_value, dtype=dtype), units=x.unit)
return jax.device_put(out, device=device)


@dispatcher
def ones_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity:
out = Quantity(jnp.ones_like(x.value, dtype=dtype), units=x.units)
out = Quantity(jnp.ones_like(x.value, dtype=dtype), units=x.unit)
return jax.device_put(out, device=device)


@dispatcher
def zeros_like(x: Quantity, /, *, dtype: Any = None, device: Any = None) -> Quantity:
out = Quantity(jnp.zeros_like(x.value, dtype=dtype), units=x.units)
out = Quantity(jnp.zeros_like(x.value, dtype=dtype), units=x.unit)
return jax.device_put(out, device=device)
Loading

0 comments on commit 9a91703

Please sign in to comment.