Skip to content

Commit

Permalink
pylint cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 5, 2024
1 parent 7ca75f0 commit 55b51d4
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/jax_quantity/array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def __getattr__(name: str) -> Any: # TODO: fuller annotation

# =============================================================================
# Grad
# Credit to dfm/jpu
# Lightly modified from dfm/jpu

AuxT: TypeAlias = Any
Aux: TypeAlias = Any


def is_quantity(obj: Any) -> TypeGuard[Quantity]:
Expand All @@ -49,26 +49,26 @@ def grad(

@overload
def grad(
fun: Callable[..., tuple[Quantity, AuxT]],
fun: Callable[..., tuple[Quantity, Aux]],
argnums: int | Sequence[int],
*,
has_aux: bool = True,
holomorphic: bool,
allow_int: bool,
reduce_axes: Sequence[int],
) -> Callable[..., tuple[Quantity, AuxT]]:
) -> Callable[..., tuple[Quantity, Aux]]:
...


def grad(
fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]],
fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, Aux]],
argnums: int | Sequence[int] = 0,
*,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[int] = (),
) -> Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]]:
) -> Callable[..., Quantity] | Callable[..., tuple[Quantity, Aux]]:
value_and_grad_f = value_and_grad(
fun,
argnums,
Expand All @@ -91,7 +91,7 @@ def grad_f(*args: Any, **kwargs: Any) -> Quantity:
return g

@wraps(fun, docstr=docstr, argnums=argnums) # type: ignore[misc] # untyped decorator
def grad_f_aux(*args: Any, **kwargs: Any) -> tuple[Quantity, AuxT]:
def grad_f_aux(*args: Any, **kwargs: Any) -> tuple[Quantity, Aux]:
(_, aux), g = value_and_grad_f(*args, **kwargs)
return g, aux

Expand All @@ -116,26 +116,26 @@ def value_and_grad(

@overload
def value_and_grad(
fun: Callable[..., tuple[Quantity, AuxT]],
fun: Callable[..., tuple[Quantity, Aux]],
argnums: int | Sequence[int],
*,
has_aux: bool = True,
holomorphic: bool,
allow_int: bool,
reduce_axes: Sequence[int],
) -> Callable[..., tuple[tuple[Quantity, AuxT], Quantity]]:
) -> Callable[..., tuple[tuple[Quantity, Aux], Quantity]]:
...


def value_and_grad(
fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, AuxT]],
fun: Callable[..., Quantity] | Callable[..., tuple[Quantity, Aux]],
argnums: int | Sequence[int] = 0,
*,
has_aux: bool = False,
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[int] = (),
) -> Callable[..., tuple[Quantity, Quantity] | tuple[tuple[Quantity, AuxT], Quantity]]:
) -> Callable[..., tuple[Quantity, Quantity] | tuple[tuple[Quantity, Aux], Quantity]]:
# inspired by: https://twitter.com/shoyer/status/1531703890512490499
docstr = (
"Value and gradient of {fun} with respect to positional "
Expand All @@ -147,7 +147,7 @@ def value_and_grad(

def fun_wo_units(
*args: Any, **kwargs: Any
) -> tuple[jax.Array, tuple[UnitBase, AuxT]]:
) -> tuple[jax.Array, tuple[UnitBase, Aux]]:
if has_aux:
result, aux = fun(*args, **kwargs)
else:
Expand Down Expand Up @@ -175,7 +175,7 @@ def fun_wo_units(
@wraps(fun, docstr=docstr, argnums=argnums) # type: ignore[misc] # untyped decorator
def wrapped(
*args: Any, **kwargs: Any
) -> tuple[Quantity, Quantity] | tuple[tuple[Quantity, AuxT], Quantity]:
) -> tuple[Quantity, Quantity] | tuple[tuple[Quantity, Aux], Quantity]:
(result_wo_units, (result_units, aux)), grad = value_and_grad_fun(
*args, **kwargs
)
Expand Down

0 comments on commit 55b51d4

Please sign in to comment.