From 66de806e1d49f79bb8e52295826a877320a36a4d Mon Sep 17 00:00:00 2001
From: nstarman
Date: Tue, 2 Jan 2024 11:14:33 -0800
Subject: [PATCH] stand-in array_namespace

Until array-api-jax-compat can handle this without detailed special-casing.

Signed-off-by: nstarman
---
 src/jax_quantity/array_namespace.py | 210 ++++++++++++++++++++++++++++
 1 file changed, 210 insertions(+)
 create mode 100644 src/jax_quantity/array_namespace.py

diff --git a/src/jax_quantity/array_namespace.py b/src/jax_quantity/array_namespace.py
new file mode 100644
index 0000000..2f378c6
--- /dev/null
+++ b/src/jax_quantity/array_namespace.py
@@ -0,0 +1,210 @@
+"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved.
+
+jax-quantity: Quantities in JAX
+"""
+
+from collections.abc import Callable, Sequence
+from typing import Any, TypeAlias, TypeGuard, overload
+
+import array_api_jax_compat
+import jax
+from astropy.units import UnitBase
+from jax._src.util import wraps
+from jax.tree_util import tree_map
+
+from ._core import Quantity
+
+
+def __getattr__(name: str) -> Any:  # TODO: fuller annotation
+    """Forward all other attribute accesses to Quaxified JAX."""
+    return getattr(array_api_jax_compat, name)
+
+
+# =============================================================================
+# Grad
+# Credit to dfm/jpu
+
+AuxT: TypeAlias = Any
+
+
+def is_quantity(obj: Any) -> TypeGuard[Quantity]:
+    """Check (by duck-typing) whether ``obj`` is a Quantity."""
+    return hasattr(obj, "unit") and hasattr(obj, "value")
+
+
+# -----------------
+
+
+@overload
+def grad(
+    fun: Callable[..., Quantity],
+    argnums: int | Sequence[int],
+    *,
+    has_aux: bool = False,
+    holomorphic: bool,
+    allow_int: bool,
+    reduce_axes: Sequence[int],
+) -> Callable[..., Quantity]:
+    ...
+
+
+@overload
+def grad(
+    fun: Callable[..., tuple[Quantity, AuxT]],
+    argnums: int | Sequence[int],
+    *,
+    has_aux: bool = True,
+    holomorphic: bool,
+    allow_int: bool,
+    reduce_axes: Sequence[int],
+) -> Callable[..., tuple[Quantity, AuxT]]:
+    ...
+
+
+def grad(
+    fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]],
+    argnums: int | Sequence[int] = 0,
+    *,
+    has_aux: bool = False,
+    holomorphic: bool = False,
+    allow_int: bool = False,
+    reduce_axes: Sequence[int] = (),
+):
+    value_and_grad_f = value_and_grad(
+        fun,
+        argnums,
+        has_aux=has_aux,
+        holomorphic=holomorphic,
+        allow_int=allow_int,
+        reduce_axes=reduce_axes,
+    )
+
+    docstr = (
+        "Gradient of {fun} with respect to positional argument(s) "
+        "{argnums}. Takes the same arguments as {fun} but returns the "
+        "gradient, which has the same shape as the arguments at "
+        "positions {argnums}."
+    )
+
+    @wraps(fun, docstr=docstr, argnums=argnums)
+    def grad_f(*args: Any, **kwargs: Any) -> Quantity:
+        _, g = value_and_grad_f(*args, **kwargs)
+        return g
+
+    @wraps(fun, docstr=docstr, argnums=argnums)
+    def grad_f_aux(*args: Any, **kwargs: Any) -> tuple[Quantity, AuxT]:
+        (_, aux), g = value_and_grad_f(*args, **kwargs)
+        return g, aux
+
+    return grad_f_aux if has_aux else grad_f
+
+
+# -----------------
+
+
+@overload
+def value_and_grad(
+    fun: Callable[..., Quantity],
+    argnums: int | Sequence[int],
+    *,
+    has_aux: bool = False,
+    holomorphic: bool,
+    allow_int: bool,
+    reduce_axes: Sequence[int],
+) -> Callable[..., tuple[Quantity, Quantity]]:
+    ...
+
+
+@overload
+def value_and_grad(
+    fun: Callable[..., tuple[Quantity, AuxT]],
+    argnums: int | Sequence[int],
+    *,
+    has_aux: bool = True,
+    holomorphic: bool,
+    allow_int: bool,
+    reduce_axes: Sequence[int],
+) -> Callable[..., tuple[tuple[Quantity, AuxT], Quantity]]:
+    ...
+
+
+def value_and_grad(
+    fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]],
+    argnums: int | Sequence[int] = 0,
+    *,
+    has_aux: bool = False,
+    holomorphic: bool = False,
+    allow_int: bool = False,
+    reduce_axes: Sequence[int] = (),
+) -> Callable[..., tuple[Quantity, Quantity] | tuple[tuple[Quantity, AuxT], Quantity]]:
+    # inspired by: https://twitter.com/shoyer/status/1531703890512490499
+    docstr = (
+        "Value and gradient of {fun} with respect to positional "
+        "argument(s) {argnums}. Takes the same arguments as {fun} but "
+        "returns a two-element tuple where the first element is the value "
+        "of {fun} and the second element is the gradient, which has the "
+        "same shape as the arguments at positions {argnums}."
+    )
+
+    def fun_wo_units(
+        *args: Any, **kwargs: Any
+    ) -> tuple[jax.Array, tuple[UnitBase, AuxT]]:
+        # Strip the unit from the output so JAX differentiates a plain array;
+        # the unit is threaded through as auxiliary data.
+        if has_aux:
+            result, aux = fun(*args, **kwargs)
+        else:
+            result = fun(*args, **kwargs)
+            aux = None
+
+        if is_quantity(result):
+            value = result.value
+            unit = result.unit
+        else:
+            value = result
+            unit = None
+
+        return value, (unit, aux)
+
+    value_and_grad_fun = jax.value_and_grad(
+        fun_wo_units,
+        argnums=argnums,
+        has_aux=True,
+        holomorphic=holomorphic,
+        allow_int=allow_int,
+        reduce_axes=reduce_axes,
+    )
+
+    @wraps(fun, docstr=docstr, argnums=argnums)
+    def wrapped(
+        *args: Any, **kwargs: Any
+    ) -> tuple[Quantity, Quantity] | tuple[tuple[Quantity, AuxT], Quantity]:
+        (result_wo_units, (result_units, aux)), grad = value_and_grad_fun(
+            *args, **kwargs
+        )
+
+        if result_units is None:
+            result = result_wo_units
+            # Unitless output: each Quantity gradient leaf gets units 1 / g.unit.
+            grad = tree_map(
+                lambda g: (g.value * (1 / g.unit) if is_quantity(g) else g),
+                grad,
+                is_leaf=is_quantity,
+            )
+
+        else:
+            result = result_wo_units * result_units
+            # Reattach units: each Quantity gradient leaf gets
+            # result_units / g.unit; other leaves pick up result_units.
+            grad = tree_map(
+                lambda g: (
+                    g.value * (result_units / g.unit)
+                    if is_quantity(g)
+                    else g * result_units
+                ),
+                grad,
+                is_leaf=is_quantity,
+            )
+
+        if has_aux:
+            return (result, aux), grad
+        return result, grad
+
+    return wrapped
+
+
+# =============================================================================
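
---

For reviewers, a minimal usage sketch of the value_and_grad wrapper added above. This is not part of the patch and rests on assumptions the diff does not confirm: that Quantity can be constructed as Quantity(value, unit), that it is registered as a JAX pytree whose unit is static metadata (as the gradient handling above implies), and that astropy units compose with it as shown. The function name specific_kinetic_energy is hypothetical.

    import astropy.units as u
    import jax.numpy as jnp

    from jax_quantity._core import Quantity  # same import path the patch uses
    from jax_quantity.array_namespace import value_and_grad

    def specific_kinetic_energy(v):
        # Hypothetical toy function returning a Quantity, as the wrapper expects.
        # Assumes Quantity(value, unit) construction and a JAX-traceable v.value.
        return Quantity(0.5 * v.value**2, u.J / u.kg)

    v = Quantity(jnp.asarray(3.0), u.m / u.s)
    energy, d_energy_dv = value_and_grad(specific_kinetic_energy)(v)
    # energy should come back with the output unit (J / kg) reattached;
    # d_energy_dv should carry output unit / argument unit, i.e. (J / kg) / (m / s).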