Skip to content

Commit

Permalink
feat: change broadcasting of t in PhaseSpacePosition
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Jan 29, 2024
1 parent 62a5494 commit e86e23e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
11 changes: 5 additions & 6 deletions src/galax/dynamics/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax
import jax.numpy as jnp

from galax.typing import BatchFloatScalar, BroadBatchVec1, BroadBatchVec3
from galax.typing import BatchFloatScalar, BroadBatchVec1, BroadBatchVec3, Vec1
from galax.utils._shape import batched_shape, expand_batch_dims
from galax.utils.dataclasses import converter_float_array

Expand All @@ -35,11 +35,13 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition):
This is a 3-vector with a batch shape allowing for vector inputs.
"""

t: BroadBatchVec1 = eqx.field(default=(0.0,), converter=converter_float_array)
t: BroadBatchVec1 | Vec1 = eqx.field(
default=(0.0,), converter=converter_float_array
)
"""The time corresponding to the positions.
This is a scalar with the same batch shape as the positions and velocities.
The default value is a scalar zero. `t` will be broadcast to the same batch
The default value is a scalar zero. If `t` will be broadcast to the same batch
shape as `q` and `p`.
"""

Expand All @@ -49,9 +51,6 @@ def __post_init__(self) -> None:
if self.t.ndim == 0:
t = expand_batch_dims(self.t, ndim=self.q.ndim)
object.__setattr__(self, "t", t)
elif self.t.ndim == 1:
t = expand_batch_dims(self.t, ndim=self.q.ndim - 1)
object.__setattr__(self, "t", t)

# ==========================================================================
# Array properties
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/potential/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,11 @@ def test_integrate_orbit_batch(self, pot: AbstractPotentialBase, xv: Vec6) -> No
orbits = pot.integrate_orbit(xv[None, :], ts)
assert isinstance(orbits, gd.Orbit)
assert orbits.shape == (1, len(ts))
assert array_equal(orbits.t, ts[None, :])
assert array_equal(orbits.t, ts)

# More complicated batch
xv2 = xp.stack([xv, xv], axis=0)
orbits = pot.integrate_orbit(xv2, ts)
assert isinstance(orbits, gd.Orbit)
assert orbits.shape == (2, len(ts))
assert array_equal(orbits.t[0], ts)
assert array_equal(orbits.t[1], ts)
assert array_equal(orbits.t, ts)

0 comments on commit e86e23e

Please sign in to comment.