diff --git a/src/jax_quantity/array_namespace.py b/src/jax_quantity/array_namespace.py index 040a867..57fc8bb 100644 --- a/src/jax_quantity/array_namespace.py +++ b/src/jax_quantity/array_namespace.py @@ -22,9 +22,9 @@ def __getattr__(name: str) -> Any: # TODO: fuller annotation # ============================================================================= # Grad -# Credit to dfm/jpu +# Lightly modified from dfm/jpu -AuxT: TypeAlias = Any +Aux: TypeAlias = Any def is_quantity(obj: Any) -> TypeGuard[Quantity]: @@ -49,26 +49,26 @@ def grad( @overload def grad( - fun: Callable[..., tuple[Quantity, AuxT]], + fun: Callable[..., tuple[Quantity, Aux]], argnums: int | Sequence[int], *, has_aux: bool = True, holomorphic: bool, allow_int: bool, reduce_axes: Sequence[int], -) -> Callable[..., tuple[Quantity, AuxT]]: +) -> Callable[..., tuple[Quantity, Aux]]: ... def grad( - fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]], + fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, Aux]], argnums: int | Sequence[int] = 0, *, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[int] = (), -) -> Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]]: +) -> Callable[..., Quantity] | Callable[..., tuple[Quantity, Aux]]: value_and_grad_f = value_and_grad( fun, argnums, @@ -91,7 +91,7 @@ def grad_f(*args: Any, **kwargs: Any) -> Quantity: return g @wraps(fun, docstr=docstr, argnums=argnums) # type: ignore[misc] # untyped decorator - def grad_f_aux(*args: Any, **kwargs: Any) -> tuple[Quantity, AuxT]: + def grad_f_aux(*args: Any, **kwargs: Any) -> tuple[Quantity, Aux]: (_, aux), g = value_and_grad_f(*args, **kwargs) return g, aux @@ -116,26 +116,26 @@ def value_and_grad( @overload def value_and_grad( - fun: Callable[..., tuple[Quantity, AuxT]], + fun: Callable[..., tuple[Quantity, Aux]], argnums: int | Sequence[int], *, has_aux: bool = True, holomorphic: bool, allow_int: bool, reduce_axes: Sequence[int], -) -> Callable[..., tuple[tuple[Quantity, AuxT], Quantity]]: +) -> Callable[..., tuple[tuple[Quantity, Aux], Quantity]]: ... def value_and_grad( - fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]], + fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, Aux]], argnums: int | Sequence[int] = 0, *, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[int] = (), -) -> Callable[..., tuple[Quantity, Quantity] | tuple[tuple[Quantity, AuxT], Quantity]]: +) -> Callable[..., tuple[Quantity, Quantity] | tuple[tuple[Quantity, Aux], Quantity]]: # inspired by: https://twitter.com/shoyer/status/1531703890512490499 docstr = ( "Value and gradient of {fun} with respect to positional " @@ -147,7 +147,7 @@ def value_and_grad( def fun_wo_units( *args: Any, **kwargs: Any - ) -> tuple[jax.Array, tuple[UnitBase, AuxT]]: + ) -> tuple[jax.Array, tuple[UnitBase, Aux]]: if has_aux: result, aux = fun(*args, **kwargs) else: @@ -175,7 +175,7 @@ def fun_wo_units( @wraps(fun, docstr=docstr, argnums=argnums) # type: ignore[misc] # untyped decorator def wrapped( *args: Any, **kwargs: Any - ) -> tuple[Quantity, Quantity] | tuple[tuple[Quantity, AuxT], Quantity]: + ) -> tuple[Quantity, Quantity] | tuple[tuple[Quantity, Aux], Quantity]: (result_wo_units, (result_units, aux)), grad = value_and_grad_fun( *args, **kwargs )