From 548a99b4e4869d11b1ab056c93a9d311644695a5 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Fri, 8 Dec 2023 22:31:26 -0500 Subject: [PATCH] Shape utils (#22) * simplify batched_shape * Add batch, array shape utils * type batch matrix 3,3 Signed-off-by: nstarman --- src/galdynamix/dynamics/_core.py | 11 +- src/galdynamix/dynamics/mockstream/_core.py | 8 +- src/galdynamix/potential/_potential/base.py | 8 +- src/galdynamix/typing.py | 3 + src/galdynamix/utils/_shape.py | 166 ++++++++++++++++++-- 5 files changed, 173 insertions(+), 23 deletions(-) diff --git a/src/galdynamix/dynamics/_core.py b/src/galdynamix/dynamics/_core.py index c2bd3bb4..353523a2 100644 --- a/src/galdynamix/dynamics/_core.py +++ b/src/galdynamix/dynamics/_core.py @@ -94,11 +94,12 @@ class AbstractPhaseSpacePosition(AbstractPhaseSpacePositionBase): @property def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: """Batch .""" - qbatch, qshape = batched_shape(self.q, expect_scalar=False) - pbatch, pshape = batched_shape(self.p, expect_scalar=False) - tbatch, tshape = batched_shape(self.t, expect_scalar=True) - batch_shape = xp.broadcast_shapes(qbatch, pbatch, tbatch) - return batch_shape, (qshape, pshape, tshape) + qbatch, qshape = batched_shape(self.q, expect_ndim=1) + pbatch, pshape = batched_shape(self.p, expect_ndim=1) + tbatch, _ = batched_shape(self.t, expect_ndim=0) + batch_shape: tuple[int, ...] = xp.broadcast_shapes(qbatch, pbatch, tbatch) + array_shape: tuple[int, int, int] = qshape + pshape + (1,) + return batch_shape, array_shape # ========================================================================== # Convenience properties diff --git a/src/galdynamix/dynamics/mockstream/_core.py b/src/galdynamix/dynamics/mockstream/_core.py index 6d3caf8f..7d6eeb3a 100644 --- a/src/galdynamix/dynamics/mockstream/_core.py +++ b/src/galdynamix/dynamics/mockstream/_core.py @@ -28,11 +28,11 @@ class MockStream(AbstractPhaseSpacePositionBase): @property def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: """Batch .""" - qbatch, qshape = batched_shape(self.q, expect_scalar=False) - pbatch, pshape = batched_shape(self.p, expect_scalar=False) - tbatch, tshape = batched_shape(self.release_time, expect_scalar=True) + qbatch, qshape = batched_shape(self.q, expect_ndim=1) + pbatch, pshape = batched_shape(self.p, expect_ndim=1) + tbatch, _ = batched_shape(self.release_time, expect_ndim=0) batch_shape = xp.broadcast_shapes(qbatch, pbatch, tbatch) - return batch_shape, (qshape, pshape, tshape) + return batch_shape, qshape + pshape + (1,) @property @partial_jit() diff --git a/src/galdynamix/potential/_potential/base.py b/src/galdynamix/potential/_potential/base.py index 58d9000d..98d69677 100644 --- a/src/galdynamix/potential/_potential/base.py +++ b/src/galdynamix/potential/_potential/base.py @@ -17,8 +17,10 @@ from galdynamix.typing import ( BatchableFloatLike, BatchFloatScalar, + BatchMatrix33, BatchVec3, FloatScalar, + Matrix33, Vec3, Vec6, ) @@ -187,13 +189,11 @@ def density(self, q: BatchVec3, /, t: BatchableFloatLike) -> BatchFloatScalar: @partial_jit() @vectorize_method(signature="(3),()->(3,3)") - def _hessian(self, q: Vec3, /, t: FloatScalar) -> Float[Array, "3 3"]: + def _hessian(self, q: Vec3, /, t: FloatScalar) -> Matrix33: """See ``hessian``.""" return hessian(self.potential_energy)(q, t) - def hessian( - self, q: BatchVec3, /, t: BatchableFloatLike - ) -> Float[Array, "*batch 3 3"]: + def hessian(self, q: BatchVec3, /, t: BatchableFloatLike) -> BatchMatrix33: """Compute the Hessian of the potential at the given position(s). Parameters diff --git a/src/galdynamix/typing.py b/src/galdynamix/typing.py index a0087a88..ea501835 100644 --- a/src/galdynamix/typing.py +++ b/src/galdynamix/typing.py @@ -61,6 +61,9 @@ BatchVec3 = Shaped[Vec3, "*batch"] """Zero or more batches of 3-vectors.""" +BatchMatrix33 = Shaped[Matrix33, "*batch"] +"""Zero or more batches of 3x3 matrices.""" + BatchVec6 = Shaped[Vec6, "*batch"] """Zero or more batches of 6-vectors.""" diff --git a/src/galdynamix/utils/_shape.py b/src/galdynamix/utils/_shape.py index 7593f3f0..a280c3da 100644 --- a/src/galdynamix/utils/_shape.py +++ b/src/galdynamix/utils/_shape.py @@ -2,12 +2,12 @@ __all__: list[str] = [] -from typing import overload +from typing import Literal, overload import jax.numpy as xp from jaxtyping import Array, ArrayLike -from galdynamix.typing import ArrayAnyShape +from galdynamix.typing import ArrayAnyShape, FloatLike from ._jax import partial_jit @@ -41,12 +41,158 @@ def atleast_batched(*arys: ArrayLike) -> Array | tuple[Array, ...]: return tuple(atleast_batched(arr) for arr in arys) +# ============================================================================= + + +@overload +def batched_shape( + arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[0] +) -> tuple[tuple[int, ...], tuple[int, ...]]: + ... + + +@overload def batched_shape( - arr: ArrayAnyShape, /, *, expect_scalar: bool -) -> tuple[tuple[int, ...], int]: - """Return the shape of the batch dimensions of an array.""" - if arr.ndim == 0: - raise NotImplementedError - if arr.ndim == 1: - return (arr.shape, 1) if expect_scalar else ((), arr.shape[0]) - return arr.shape[:-1], arr.shape[-1] + arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[1] +) -> tuple[tuple[int, ...], tuple[int]]: + ... + + +@overload +def batched_shape( + arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[2] +) -> tuple[tuple[int, ...], tuple[int, int]]: + ... + + +@overload +def batched_shape( + arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int +) -> tuple[tuple[int, ...], tuple[int, ...]]: + ... + + +def batched_shape( + arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """Return the (batch_shape, arr_shape) an array. + + Parameters + ---------- + arr : array-like + The array to get the shape of. + expect_ndim : int + The expected dimensionality of the array. + + Returns + ------- + batch_shape : tuple[int, ...] + The shape of the batch. + arr_shape : tuple[int, ...] + The shape of the array. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from galdynamix.utils._shape import batched_shape + + Expecting a scalar: + >>> batched_shape(0, expect_ndim=0) + ((), ()) + >>> batched_shape(jnp.array([1]), expect_ndim=0) + ((1,), ()) + >>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=0) + ((3,), ()) + + Expecting a 1D vector: + >>> batched_shape(jnp.array(0), expect_ndim=1) + ((), (1,)) + >>> batched_shape(jnp.array([1]), expect_ndim=1) + ((), (1,)) + >>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=1) + ((), (3,)) + >>> batched_shape(jnp.array([[1, 2, 3]]), expect_ndim=1) + ((1,), (3,)) + + Expecting a 2D matrix: + >>> batched_shape(jnp.array([[1]]), expect_ndim=2) + ((), (1, 1)) + >>> batched_shape(jnp.array([[[1]]]), expect_ndim=2) + ((1,), (1, 1)) + >>> batched_shape(jnp.array([[[1]], [[1]]]), expect_ndim=2) + ((2,), (1, 1)) + """ + shape: tuple[int, ...] = xp.shape(arr) + ndim = len(shape) + return shape[: ndim - expect_ndim], shape[ndim - expect_ndim :] + + +def expand_batch_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape: + """Expand the batch dimensions of an array. + + Parameters + ---------- + arr : array-like + The array to expand the batch dimensions of. + ndim : int + The number of batch dimensions to expand. + + Returns + ------- + arr : array-like + The array with expanded batch dimensions. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from galdynamix.utils._shape import expand_batch_dims + + >>> expand_batch_dims(jnp.array(0), ndim=0).shape + () + + >>> expand_batch_dims(jnp.array([0]), ndim=0).shape + (1,) + + >>> expand_batch_dims(jnp.array(0), ndim=1).shape + (1,) + + >>> expand_batch_dims(jnp.array([0, 1]), ndim=1).shape + (1, 2) + """ + return xp.expand_dims(arr, axis=tuple(0 for _ in range(ndim))) + + +def expand_arr_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape: + """Expand the array dimensions of an array. + + Parameters + ---------- + arr : array-like + The array to expand the array dimensions of. + ndim : int + The number of array dimensions to expand. + + Returns + ------- + arr : array-like + The array with expanded array dimensions. + + Examples + -------- + >>> import jax.numpy as jnp + >>> from galdynamix.utils._shape import expand_arr_dims + + >>> expand_arr_dims(jnp.array(0), ndim=0).shape + () + + >>> expand_arr_dims(jnp.array([0]), ndim=0).shape + (1,) + + >>> expand_arr_dims(jnp.array(0), ndim=1).shape + (1,) + + >>> expand_arr_dims(jnp.array([0, 0]), ndim=1).shape + (2, 1) + """ + nbatch = len(arr.shape) + return xp.expand_dims(arr, axis=tuple(nbatch + i for i in range(ndim)))