diff --git a/src/galax/utils/_jax.py b/src/galax/utils/_jax.py index 31bb013c..6601049b 100644 --- a/src/galax/utils/_jax.py +++ b/src/galax/utils/_jax.py @@ -69,5 +69,17 @@ def partial_vectorize( def vectorize_method( **kwargs: Unpack[VectorizeKwargs], ) -> Callable[[Callable[P, R]], Callable[P, R]]: - kwargs.setdefault("excluded", (0, *tuple(kwargs.get("excluded") or ()))) + """Decorate a method to :func:`jax.numpy.vectorize`. + + This is a wrapper around :func:`jax.numpy.vectorize` that vectorizes a + class' method by returning a :class:`functools.partial`. It is equivalent to + :func:`partial_vectorize`, except that ``excluded`` is set to exclude the + 0th argument (``self``). As a result, the ``excluded`` tuple should start + at 0 to exclude the first 'real' argument (proceeding ``self``). + """ + # Prepend 0 to excluded to exclude the first argument (self) + excluded = tuple(kwargs.get("excluded") or (-1,)) # (None -> (0,)) + excluded = excluded if excluded[0] == -1 else (-1, *excluded) + kwargs["excluded"] = tuple(i + 1 for i in excluded) + return partial_vectorize(**kwargs)