diff --git a/pyproject.toml b/pyproject.toml index e9e4cc1..c121f5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ warn_return_any = false "astropy.units.*", "equinox.*", "jax.*", + "jaxtyping.*", "quax.*", ] ignore_missing_imports = true diff --git a/src/jax_quantity/_core.py b/src/jax_quantity/_core.py index 05d4bf1..9ee3441 100644 --- a/src/jax_quantity/_core.py +++ b/src/jax_quantity/_core.py @@ -35,7 +35,7 @@ def materialise(self) -> None: raise RuntimeError(msg) def aval(self) -> jax.core.ShapedArray: - return jax.core.get_aval(self.value) # type: ignore[no-untyped-call] + return jax.core.get_aval(self.value) def enable_materialise(self, _: bool = True) -> Self: # noqa: FBT001, FBT002 return type(self)(self.value, self.unit) diff --git a/src/jax_quantity/_register_primitives.py b/src/jax_quantity/_register_primitives.py index f1028f6..18453bc 100644 --- a/src/jax_quantity/_register_primitives.py +++ b/src/jax_quantity/_register_primitives.py @@ -388,7 +388,7 @@ def _convert_element_type_p( @register(lax.copy_p) def _copy_p(x: Quantity) -> Quantity: - return replace(x, value=lax.copy_p.bind(x.value)) # type: ignore[no-untyped-call] + return replace(x, value=lax.copy_p.bind(x.value)) # ==============================================================================