diff --git a/pyproject.toml b/pyproject.toml
index fe4aaf31..a0c180e4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -150,6 +150,7 @@ ignore = [
   "N80",  # Naming conventions.
   "PD",  # pandas-vet
   "PLR",  # Design related pylint codes
+  "PYI041",  # Use `float` instead of `int | float` <- beartype is more strict
   "TCH00",  # Move into a type-checking block
   "TD002",  # Missing author in TODO
   "TD003",  # Missing issue link on the line following this TODO
diff --git a/src/galdynamix/typing.py b/src/galdynamix/typing.py
index 8d82b54f..000b5000 100644
--- a/src/galdynamix/typing.py
+++ b/src/galdynamix/typing.py
@@ -3,11 +3,21 @@
 # TODO: Finalize variable names and make everything public.
 __all__: list[str] = []
 
+from typing import TypeAlias
+
+import astropy.units as u
 from jaxtyping import Array, Float, Integer, Shaped
 
+# =============================================================================
+
+Unit: TypeAlias = u.Unit | u.UnitBase | u.CompositeUnit
+
 # =============================================================================
 # Scalars
 
+AnyScalar = Shaped[Array, ""]
+"""Any scalar."""
+
 IntScalar = Integer[Array, ""]
 """An integer scalar."""
 
@@ -64,7 +74,7 @@ BatchableFloatOrIntScalarLike = BatchFloatOrIntScalar | FloatOrIntScalarLike
 
 # -----------------
-# Shaped
+# Batched
 
 BatchVec3 = Shaped[Vec3, "*batch"]
 """Zero or more batches of 3-vectors."""
 
@@ -78,8 +88,17 @@ BatchVec7 = Shaped[Vec7, "*batch"]
 """Zero or more batches of 7-vectors."""
 
-ArrayAnyShape = Float[Array, "..."]
-"""An array with any shape."""
+# -----------------
+# Any Shape
+
+FloatArrayAnyShape = Float[Array, "..."]
+"""A float array with any shape."""
+
+IntArrayAnyShape = Integer[Array, "..."]
+"""An integer array with any shape."""
+
+ArrayAnyShape = Shaped[Array, "..."]
+"""An array with any shape."""
 
 # =============================================================================
 # Specific Vectors
 
diff --git a/src/galdynamix/utils/_shape.py b/src/galdynamix/utils/_shape.py
index a280c3da..eceeaa36 100644
--- a/src/galdynamix/utils/_shape.py
+++ b/src/galdynamix/utils/_shape.py
@@ -2,18 +2,18 @@
 
 __all__: list[str] = []
 
-from typing import Literal, overload
+from typing import Any, Literal, NoReturn, overload
 
 import jax.numpy as xp
 from jaxtyping import Array, ArrayLike
 
-from galdynamix.typing import ArrayAnyShape, FloatLike
+from galdynamix.typing import AnyScalar, ArrayAnyShape
 
 from ._jax import partial_jit
 
 
 @overload
-def atleast_batched() -> tuple[Array, ...]:
+def atleast_batched() -> NoReturn:
     ...
 
 
@@ -30,7 +30,50 @@ def atleast_batched(
 
 @partial_jit()
-def atleast_batched(*arys: ArrayLike) -> Array | tuple[Array, ...]:
+def atleast_batched(*arys: Any) -> Array | tuple[Array, ...]:
+    """Convert inputs to arrays with at least two dimensions.
+
+    Parameters
+    ----------
+    *arys : array_like
+        One or more array-like sequences. Non-array inputs are converted to
+        arrays. Arrays that already have two or more dimensions are preserved.
+
+    Returns
+    -------
+    res : Array | tuple[Array, ...]
+        An array, or tuple of arrays, each with ``a.ndim >= 2``. Copies are
+        made only if necessary.
+
+    Examples
+    --------
+    >>> from galdynamix.utils._shape import atleast_batched
+    >>> atleast_batched(0)
+    Array([[0]], dtype=int64, ...)
+
+    >>> atleast_batched([1])
+    Array([[1]], dtype=int64)
+
+    >>> atleast_batched([[1]])
+    Array([[1]], dtype=int64)
+
+    >>> atleast_batched([[[1]]])
+    Array([[[1]]], dtype=int64)
+
+    >>> atleast_batched([1, 2, 3])
+    Array([[1],
+           [2],
+           [3]], dtype=int64)
+
+    For comparison, ``jnp.atleast_2d`` prepends the new axis instead:
+
+    >>> import jax.numpy as jnp
+    >>> jnp.atleast_2d(jnp.array([1, 2, 3]))
+    Array([[1, 2, 3]], dtype=int64)
+    """
+    if len(arys) == 0:
+        msg = "atleast_batched() requires at least one argument"
+        raise ValueError(msg)
     if len(arys) == 1:
         arr = xp.asarray(arys[0])
         if arr.ndim >= 2:
             return arr
@@ -46,34 +87,34 @@ def atleast_batched(*arys: ArrayLike) -> Array | tuple[Array, ...]:
 
 @overload
 def batched_shape(
-    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[0]
+    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: Literal[0]
 ) -> tuple[tuple[int, ...], tuple[int, ...]]:
     ...
 
 
 @overload
 def batched_shape(
-    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[1]
+    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: Literal[1]
 ) -> tuple[tuple[int, ...], tuple[int]]:
     ...
 
 
 @overload
 def batched_shape(
-    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[2]
+    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: Literal[2]
 ) -> tuple[tuple[int, ...], tuple[int, int]]:
     ...
 
 
 @overload
 def batched_shape(
-    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int
+    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: int
 ) -> tuple[tuple[int, ...], tuple[int, ...]]:
     ...
 
 
 def batched_shape(
-    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int
+    arr: ArrayAnyShape | AnyScalar | float | int, /, *, expect_ndim: int
 ) -> tuple[tuple[int, ...], tuple[int, ...]]:
     """Return the (batch_shape, arr_shape) of an array.
@@ -93,34 +134,39 @@
     Examples
     --------
-    >>> import jax.numpy as jnp
-    >>> from galdynamix.utils._shape import batched_shape
+    Standard imports:
+
+    >>> import jax.numpy as xp
+    >>> from galdynamix.utils._shape import batched_shape
 
     Expecting a scalar:
-    >>> batched_shape(0, expect_ndim=0)
-    ((), ())
-    >>> batched_shape(jnp.array([1]), expect_ndim=0)
-    ((1,), ())
-    >>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=0)
-    ((3,), ())
+
+    >>> batched_shape(0, expect_ndim=0)
+    ((), ())
+    >>> batched_shape(xp.array([1]), expect_ndim=0)
+    ((1,), ())
+    >>> batched_shape(xp.array([1, 2, 3]), expect_ndim=0)
+    ((3,), ())
 
     Expecting a 1D vector:
-    >>> batched_shape(jnp.array(0), expect_ndim=1)
-    ((), (1,))
-    >>> batched_shape(jnp.array([1]), expect_ndim=1)
-    ((), (1,))
-    >>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=1)
-    ((), (3,))
-    >>> batched_shape(jnp.array([[1, 2, 3]]), expect_ndim=1)
-    ((1,), (3,))
+
+    >>> batched_shape(xp.array(0), expect_ndim=1)
+    ((), ())
+    >>> batched_shape(xp.array([1]), expect_ndim=1)
+    ((), (1,))
+    >>> batched_shape(xp.array([1, 2, 3]), expect_ndim=1)
+    ((), (3,))
+    >>> batched_shape(xp.array([[1, 2, 3]]), expect_ndim=1)
+    ((1,), (3,))
 
     Expecting a 2D matrix:
-    >>> batched_shape(jnp.array([[1]]), expect_ndim=2)
-    ((), (1, 1))
-    >>> batched_shape(jnp.array([[[1]]]), expect_ndim=2)
-    ((1,), (1, 1))
-    >>> batched_shape(jnp.array([[[1]], [[1]]]), expect_ndim=2)
-    ((2,), (1, 1))
+
+    >>> batched_shape(xp.array([[1]]), expect_ndim=2)
+    ((), (1, 1))
+    >>> batched_shape(xp.array([[[1]]]), expect_ndim=2)
+    ((1,), (1, 1))
+    >>> batched_shape(xp.array([[[1]], [[1]]]), expect_ndim=2)
+    ((2,), (1, 1))
     """
     shape: tuple[int, ...] = xp.shape(arr)
     ndim = len(shape)
@@ -159,7 +205,7 @@ def expand_batch_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape:
     >>> expand_batch_dims(jnp.array([0, 1]), ndim=1).shape
     (1, 2)
     """
-    return xp.expand_dims(arr, axis=tuple(0 for _ in range(ndim)))
+    return xp.expand_dims(arr, axis=tuple(range(ndim)))
 
 
 def expand_arr_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape:
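
A minimal usage sketch of the helpers this patch touches (not part of the
patch; it assumes `galdynamix` is importable with the changes above applied):

    import jax.numpy as xp
    from galdynamix.utils._shape import atleast_batched, batched_shape

    # atleast_batched appends the new axis, so a length-3 vector becomes a
    # batch of three scalars: (3,) -> (3, 1). jnp.atleast_2d would instead
    # prepend the axis, giving (1, 3).
    x = atleast_batched(xp.array([1, 2, 3]))
    print(x.shape)  # (3, 1)

    # batched_shape splits a shape into (batch_shape, arr_shape), where
    # expect_ndim is the expected core dimensionality of the array part.
    batch, arr = batched_shape(xp.ones((5, 3)), expect_ndim=1)
    print(batch, arr)  # (5,) (3,)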