Skip to content

Commit

Permalink
fix vectorize_method to always exclude self (#59)
Browse files Browse the repository at this point in the history
* fix vectorize_method to always exclude self
* Skip self in indexing

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 18, 2024
1 parent 0676a67 commit 5047e80
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/galax/utils/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,17 @@ def partial_vectorize(
def vectorize_method(
**kwargs: Unpack[VectorizeKwargs],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
kwargs.setdefault("excluded", (0, *tuple(kwargs.get("excluded") or ())))
"""Decorate a method to :func:`jax.numpy.vectorize`.
This is a wrapper around :func:`jax.numpy.vectorize` that vectorizes a
class' method by returning a :class:`functools.partial`. It is equivalent to
:func:`partial_vectorize`, except that ``excluded`` is set to exclude the
0th argument (``self``). As a result, the ``excluded`` tuple should start
at 0 to exclude the first 'real' argument (proceeding ``self``).
"""
# Prepend 0 to excluded to exclude the first argument (self)
excluded = tuple(kwargs.get("excluded") or (-1,)) # (None -> (0,))
excluded = excluded if excluded[0] == -1 else (-1, *excluded)
kwargs["excluded"] = tuple(i + 1 for i in excluded)

return partial_vectorize(**kwargs)

0 comments on commit 5047e80

Please sign in to comment.