Skip to content

Commit

Permalink
Shape utils (#22)
Browse files Browse the repository at this point in the history
* simplify batched_shape
* Add batch, array shape utils
* type batch matrix 3,3

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 9, 2023
1 parent 50d3eac commit 548a99b
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 23 deletions.
11 changes: 6 additions & 5 deletions src/galdynamix/dynamics/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ class AbstractPhaseSpacePosition(AbstractPhaseSpacePositionBase):
@property
def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
"""Batch ."""
qbatch, qshape = batched_shape(self.q, expect_scalar=False)
pbatch, pshape = batched_shape(self.p, expect_scalar=False)
tbatch, tshape = batched_shape(self.t, expect_scalar=True)
batch_shape = xp.broadcast_shapes(qbatch, pbatch, tbatch)
return batch_shape, (qshape, pshape, tshape)
qbatch, qshape = batched_shape(self.q, expect_ndim=1)
pbatch, pshape = batched_shape(self.p, expect_ndim=1)
tbatch, _ = batched_shape(self.t, expect_ndim=0)
batch_shape: tuple[int, ...] = xp.broadcast_shapes(qbatch, pbatch, tbatch)
array_shape: tuple[int, int, int] = qshape + pshape + (1,)
return batch_shape, array_shape

# ==========================================================================
# Convenience properties
Expand Down
8 changes: 4 additions & 4 deletions src/galdynamix/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ class MockStream(AbstractPhaseSpacePositionBase):
@property
def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]:
"""Batch ."""
qbatch, qshape = batched_shape(self.q, expect_scalar=False)
pbatch, pshape = batched_shape(self.p, expect_scalar=False)
tbatch, tshape = batched_shape(self.release_time, expect_scalar=True)
qbatch, qshape = batched_shape(self.q, expect_ndim=1)
pbatch, pshape = batched_shape(self.p, expect_ndim=1)
tbatch, _ = batched_shape(self.release_time, expect_ndim=0)
batch_shape = xp.broadcast_shapes(qbatch, pbatch, tbatch)
return batch_shape, (qshape, pshape, tshape)
return batch_shape, qshape + pshape + (1,)

@property
@partial_jit()
Expand Down
8 changes: 4 additions & 4 deletions src/galdynamix/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from galdynamix.typing import (
BatchableFloatLike,
BatchFloatScalar,
BatchMatrix33,
BatchVec3,
FloatScalar,
Matrix33,
Vec3,
Vec6,
)
Expand Down Expand Up @@ -187,13 +189,11 @@ def density(self, q: BatchVec3, /, t: BatchableFloatLike) -> BatchFloatScalar:

@partial_jit()
@vectorize_method(signature="(3),()->(3,3)")
def _hessian(self, q: Vec3, /, t: FloatScalar) -> Float[Array, "3 3"]:
def _hessian(self, q: Vec3, /, t: FloatScalar) -> Matrix33:
"""See ``hessian``."""
return hessian(self.potential_energy)(q, t)

def hessian(
self, q: BatchVec3, /, t: BatchableFloatLike
) -> Float[Array, "*batch 3 3"]:
def hessian(self, q: BatchVec3, /, t: BatchableFloatLike) -> BatchMatrix33:
"""Compute the Hessian of the potential at the given position(s).
Parameters
Expand Down
3 changes: 3 additions & 0 deletions src/galdynamix/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
BatchVec3 = Shaped[Vec3, "*batch"]
"""Zero or more batches of 3-vectors."""

BatchMatrix33 = Shaped[Matrix33, "*batch"]
"""Zero or more batches of 3x3 matrices."""

BatchVec6 = Shaped[Vec6, "*batch"]
"""Zero or more batches of 6-vectors."""

Expand Down
166 changes: 156 additions & 10 deletions src/galdynamix/utils/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

__all__: list[str] = []

from typing import overload
from typing import Literal, overload

import jax.numpy as xp
from jaxtyping import Array, ArrayLike

from galdynamix.typing import ArrayAnyShape
from galdynamix.typing import ArrayAnyShape, FloatLike

from ._jax import partial_jit

Expand Down Expand Up @@ -41,12 +41,158 @@ def atleast_batched(*arys: ArrayLike) -> Array | tuple[Array, ...]:
return tuple(atleast_batched(arr) for arr in arys)


# =============================================================================


@overload
def batched_shape(
arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[0]
) -> tuple[tuple[int, ...], tuple[int, ...]]:
...


@overload
def batched_shape(
arr: ArrayAnyShape, /, *, expect_scalar: bool
) -> tuple[tuple[int, ...], int]:
"""Return the shape of the batch dimensions of an array."""
if arr.ndim == 0:
raise NotImplementedError
if arr.ndim == 1:
return (arr.shape, 1) if expect_scalar else ((), arr.shape[0])
return arr.shape[:-1], arr.shape[-1]
arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[1]
) -> tuple[tuple[int, ...], tuple[int]]:
...


@overload
def batched_shape(
arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: Literal[2]
) -> tuple[tuple[int, ...], tuple[int, int]]:
...


@overload
def batched_shape(
arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
...


def batched_shape(
arr: ArrayAnyShape | FloatLike, /, *, expect_ndim: int
) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Return the (batch_shape, arr_shape) an array.
Parameters
----------
arr : array-like
The array to get the shape of.
expect_ndim : int
The expected dimensionality of the array.
Returns
-------
batch_shape : tuple[int, ...]
The shape of the batch.
arr_shape : tuple[int, ...]
The shape of the array.
Examples
--------
>>> import jax.numpy as jnp
>>> from galdynamix.utils._shape import batched_shape
Expecting a scalar:
>>> batched_shape(0, expect_ndim=0)
((), ())
>>> batched_shape(jnp.array([1]), expect_ndim=0)
((1,), ())
>>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=0)
((3,), ())
Expecting a 1D vector:
>>> batched_shape(jnp.array(0), expect_ndim=1)
((), (1,))
>>> batched_shape(jnp.array([1]), expect_ndim=1)
((), (1,))
>>> batched_shape(jnp.array([1, 2, 3]), expect_ndim=1)
((), (3,))
>>> batched_shape(jnp.array([[1, 2, 3]]), expect_ndim=1)
((1,), (3,))
Expecting a 2D matrix:
>>> batched_shape(jnp.array([[1]]), expect_ndim=2)
((), (1, 1))
>>> batched_shape(jnp.array([[[1]]]), expect_ndim=2)
((1,), (1, 1))
>>> batched_shape(jnp.array([[[1]], [[1]]]), expect_ndim=2)
((2,), (1, 1))
"""
shape: tuple[int, ...] = xp.shape(arr)
ndim = len(shape)
return shape[: ndim - expect_ndim], shape[ndim - expect_ndim :]


def expand_batch_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape:
"""Expand the batch dimensions of an array.
Parameters
----------
arr : array-like
The array to expand the batch dimensions of.
ndim : int
The number of batch dimensions to expand.
Returns
-------
arr : array-like
The array with expanded batch dimensions.
Examples
--------
>>> import jax.numpy as jnp
>>> from galdynamix.utils._shape import expand_batch_dims
>>> expand_batch_dims(jnp.array(0), ndim=0).shape
()
>>> expand_batch_dims(jnp.array([0]), ndim=0).shape
(1,)
>>> expand_batch_dims(jnp.array(0), ndim=1).shape
(1,)
>>> expand_batch_dims(jnp.array([0, 1]), ndim=1).shape
(1, 2)
"""
return xp.expand_dims(arr, axis=tuple(0 for _ in range(ndim)))


def expand_arr_dims(arr: ArrayAnyShape, /, ndim: int) -> ArrayAnyShape:
"""Expand the array dimensions of an array.
Parameters
----------
arr : array-like
The array to expand the array dimensions of.
ndim : int
The number of array dimensions to expand.
Returns
-------
arr : array-like
The array with expanded array dimensions.
Examples
--------
>>> import jax.numpy as jnp
>>> from galdynamix.utils._shape import expand_arr_dims
>>> expand_arr_dims(jnp.array(0), ndim=0).shape
()
>>> expand_arr_dims(jnp.array([0]), ndim=0).shape
(1,)
>>> expand_arr_dims(jnp.array(0), ndim=1).shape
(1,)
>>> expand_arr_dims(jnp.array([0, 0]), ndim=1).shape
(2, 1)
"""
nbatch = len(arr.shape)
return xp.expand_dims(arr, axis=tuple(nbatch + i for i in range(ndim)))

0 comments on commit 548a99b

Please sign in to comment.