fix expand_batch_dims to allow multiple dims (#26)
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
nstarman committed Dec 9, 2023
1 parent 03dab46 commit 4716971
Showing 3 changed files with 100 additions and 35 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -150,6 +150,7 @@ ignore = [
"N80", # Naming conventions.
"PD", # pandas-vet
"PLR", # Design related pylint codes
"PYI041", # Use `float` instead of `int | float` <- beartype is more strict
"TCH00", # Move into a type-checking block
"TD002", # Missing author in TODO
"TD003", # Missing issue link on the line following this TODO
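The `PYI041` ignore added above is needed because beartype, in its default configuration, does not apply PEP 484's implicit numeric tower: a parameter annotated `float` rejects `int` values at call time, so the annotation has to spell out `int | float`. A minimal sketch of that behavior (illustration only, not code from this commit):

from beartype import beartype

@beartype
def scale(x: float) -> float:
    # A bare `float` hint makes beartype reject `int` arguments at call time.
    return 2.0 * x

@beartype
def scale_loose(x: int | float) -> float:
    # Spelling out `int | float` (the union PYI041 would flag) accepts both.
    return 2.0 * float(x)

scale_loose(1)  # OK
# scale(1)      # raises beartype.roar.BeartypeCallHintParamViolation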
24 changes: 21 additions & 3 deletions src/galdynamix/typing.py
@@ -3,11 +3,21 @@
# TODO: Finalize variable names and make everything public.
__all__: list[str] = []

from typing import TypeAlias

import astropy.units as u
from jaxtyping import Array, Float, Integer, Shaped

# =============================================================================

Unit: TypeAlias = u.Unit | u.UnitBase | u.CompositeUnit

# =============================================================================
# Scalars

AnyScalar = Shaped[Array, ""]
"""Any scalar."""

IntScalar = Integer[Array, ""]
"""An integer scalar."""

@@ -64,7 +74,7 @@
BatchableFloatOrIntScalarLike = BatchFloatOrIntScalar | FloatOrIntScalarLike

# -----------------
# Shaped
# Batched

BatchVec3 = Shaped[Vec3, "*batch"]
"""Zero or more batches of 3-vectors."""
@@ -78,8 +88,16 @@
BatchVec7 = Shaped[Vec7, "*batch"]
"""Zero or more batches of 7-vectors."""

ArrayAnyShape = Float[Array, "..."]
"""An array with any shape."""
# -----------------
# Any Shape

FloatArrayAnyShape = Float[Array, "..."]
"""A float array with any shape."""

IntArrayAnyShape = Integer[Array, "..."]
"""An integer array with any shape."""

ArrayAnyShape = Shaped[Array, "..."]

# =============================================================================
# Specific Vectors
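To show how these jaxtyping aliases are meant to be used in annotations, here is a short sketch (not code from this commit; `Vec3` is assumed to be `Float[Array, "3"]`, consistent with the `BatchVec3` alias above):

import jax.numpy as xp
from jaxtyping import Array, Float, Shaped

Vec3 = Float[Array, "3"]            # assumed definition of the 3-vector alias
BatchVec3 = Shaped[Vec3, "*batch"]  # zero or more leading batch dimensions

def norm3(w: BatchVec3) -> Shaped[Array, "*batch"]:
    # Norm over the last axis; any batch dimensions are preserved.
    return xp.linalg.norm(w, axis=-1)

norm3(xp.ones(3)).shape       # () -- a single 3-vector
norm3(xp.ones((5, 3))).shape  # (5,) -- a batch of five 3-vectors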
110 changes: 78 additions & 32 deletions src/galdynamix/utils/_shape.py
@@ -2,18 +2,18 @@

__all__: list[str] = []

from typing import Literal, overload
from typing import Any, Literal, NoReturn, overload

import jax.numpy as xp
from jaxtyping import Array, ArrayLike

from galdynamix.typing import ArrayAnyShape, FloatLike
from galdynamix.typing import AnyScalar, ArrayAnyShape

from ._jax import partial_jit


@overload
def atleast_batched() -> tuple[Array, ...]:
def atleast_batched() -> NoReturn:
    ...


@@ -30,7 +30,48 @@ def atleast_batched(


@partial_jit()
def atleast_batched(*arys: ArrayLike) -> Array | tuple[Array, ...]:
def atleast_batched(*arys: Any) -> Array | tuple[Array, ...]:
"""Convert inputs to arrays with at least two dimensions.
Parameters
----------
*arys : array_like
One or more array-like sequences. Non-array inputs are converted to
arrays. Arrays that already have two or more dimensions are preserved.
Returns
-------
res : tuple
A tuple of arrays, each with ``a.ndim >= 2``. Copies are made only if
necessary.
Examples
--------
>>> from galdynamix.utils._shape import atleast_batched
>>> atleast_batched(0)
Array([[0]], dtype=int64, ...)
>>> atleast_batched([1])
Array([[1]], dtype=int64)
>>> atleast_batched([[1]])
Array([[1]], dtype=int64)
>>> atleast_batched([[[1]]])
Array([[[1]]], dtype=int64)
>>> atleast_batched([1, 2, 3])
Array([[1],
[2],
[3]], dtype=int64)
>>> import jax.numpy as jnp
>>> jnp.atleast_2d(xp.array([1, 2, 3]))
Array([[1, 2, 3]], dtype=int64)
"""
if len(arys) == 0:
msg = "atleast_batched() requires at least one argument"
raise ValueError(msg)
if len(arys) == 1:
arr = xp.asarray(arys[0])
if arr.ndim >= 2:
@@ -46,34 +87,34 @@ def atleast_batched(*arys: ArrayLike) -> Array | tuple[Array, ...]:
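The rest of the single-argument branch and the multi-argument case are collapsed in the view above. A sketch that is consistent with the signature and the doctests (it appends the new axis for 1-D input and recurses over multiple arguments) — an assumption, not the committed implementation:

def atleast_batched(*arys: Any) -> Array | tuple[Array, ...]:
    if len(arys) == 0:
        msg = "atleast_batched() requires at least one argument"
        raise ValueError(msg)
    if len(arys) == 1:
        arr = xp.asarray(arys[0])
        if arr.ndim >= 2:
            return arr  # already batched
        if arr.ndim == 1:
            return xp.expand_dims(arr, axis=1)  # (n,) -> (n, 1)
        return xp.expand_dims(arr, axis=(0, 1))  # () -> (1, 1)
    return tuple(atleast_batched(a) for a in arys)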

@overload
def batched_shape(
    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[0]
    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: Literal[0]
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    ...


@overload
def batched_shape(
    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[1]
    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: Literal[1]
) -> tuple[tuple[int, ...], tuple[int]]:
    ...


@overload
def batched_shape(
    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[2]
    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: Literal[2]
) -> tuple[tuple[int, ...], tuple[int, int]]:
    ...


@overload
def batched_shape(
    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int
    arr: ArrayAnyShape | AnyScalar, /, *, expect_ndim: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    ...


def batched_shape(
    arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int
    arr: ArrayAnyShape | AnyScalar | float | int, /, *, expect_ndim: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Return the (batch_shape, arr_shape) an array.
@@ -93,34 +134,39 @@ def batched_shape(
    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from galdynamix.utils._shape import batched_shape

    Standard imports:

    >>> import jax.numpy as xp
    >>> from galdynamix.utils._shape import batched_shape

    Expecting a scalar:

    >>> batched_shape(0, expect_ndim=0)
    ((), ())
    >>> batched_shape(jnp.array([1]), expect_ndim=0)
    ((1,), ())
    >>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=0)
    ((3,), ())
    >>> batched_shape(0, expect_ndim=0)
    ((), ())
    >>> batched_shape(xp.array([1]), expect_ndim=0)
    ((1,), ())
    >>> batched_shape(xp.array([1, 2, 3]), expect_ndim=0)
    ((3,), ())

    Expecting a 1D vector:

    >>> batched_shape(jnp.array(0), expect_ndim=1)
    ((), (1,))
    >>> batched_shape(jnp.array([1]), expect_ndim=1)
    ((), (1,))
    >>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=1)
    ((), (3,))
    >>> batched_shape(jnp.array([[1, 2, 3]]), expect_ndim=1)
    ((1,), (3,))
    >>> batched_shape(xp.array(0), expect_ndim=1)
    ((), ())
    >>> batched_shape(xp.array([1]), expect_ndim=1)
    ((), (1,))
    >>> batched_shape(xp.array([1, 2, 3]), expect_ndim=1)
    ((), (3,))
    >>> batched_shape(xp.array([[1, 2, 3]]), expect_ndim=1)
    ((1,), (3,))

    Expecting a 2D matrix:

    >>> batched_shape(jnp.array([[1]]), expect_ndim=2)
    ((), (1, 1))
    >>> batched_shape(jnp.array([[[1]]]), expect_ndim=2)
    ((1,), (1, 1))
    >>> batched_shape(jnp.array([[[1]], [[1]]]), expect_ndim=2)
    ((2,), (1, 1))
    >>> batched_shape(xp.array([[1]]), expect_ndim=2)
    ((), (1, 1))
    >>> batched_shape(xp.array([[[1]]]), expect_ndim=2)
    ((1,), (1, 1))
    >>> batched_shape(xp.array([[[1]], [[1]]]), expect_ndim=2)
    ((2,), (1, 1))
    """
    shape: tuple[int, ...] = xp.shape(arr)
    ndim = len(shape)
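    # The remaining lines of batched_shape are collapsed above.  A split that
    # is consistent with every doctest shown (a sketch, not the committed
    # code) would be:
    #
    #     return shape[: ndim - expect_ndim], shape[ndim - expect_ndim :]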
@@ -159,7 +205,7 @@ def expand_batch_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape:
    >>> expand_batch_dims(jnp.array([0, 1]), ndim=1).shape
    (1, 2)
    """
    return xp.expand_dims(arr, axis=tuple(0 for _ in range(ndim)))
    return xp.expand_dims(arr, axis=tuple(range(ndim)))
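    # Note on the fix above: the old line built a repeated-axis tuple, e.g.
    # (0, 0) for ndim=2, which jax.numpy.expand_dims rejects, whereas
    # tuple(range(ndim)) prepends ndim distinct leading axes.  Assumed
    # behavior after the fix (not a doctest from the commit):
    #
    #     >>> expand_batch_dims(xp.array([0, 1]), ndim=2).shape
    #     (1, 1, 2)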


def expand_arr_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape:
