diff --git a/src/galax/dynamics/_core.py b/src/galax/dynamics/_core.py index 51c98716..c29c642c 100644 --- a/src/galax/dynamics/_core.py +++ b/src/galax/dynamics/_core.py @@ -9,7 +9,7 @@ import jax import jax.numpy as jnp -from galax.typing import BatchFloatScalar, BroadBatchVec1, BroadBatchVec3 +from galax.typing import BatchFloatScalar, BroadBatchVec1, BroadBatchVec3, Vec1 from galax.utils._shape import batched_shape, expand_batch_dims from galax.utils.dataclasses import converter_float_array @@ -35,11 +35,13 @@ class PhaseSpacePosition(AbstractPhaseSpacePosition): This is a 3-vector with a batch shape allowing for vector inputs. """ - t: BroadBatchVec1 = eqx.field(default=(0.0,), converter=converter_float_array) + t: BroadBatchVec1 | Vec1 = eqx.field( + default=(0.0,), converter=converter_float_array + ) """The time corresponding to the positions. This is a scalar with the same batch shape as the positions and velocities. - The default value is a scalar zero. `t` will be broadcast to the same batch + The default value is a scalar zero. If `t` will be broadcast to the same batch shape as `q` and `p`. """ @@ -49,9 +51,6 @@ def __post_init__(self) -> None: if self.t.ndim == 0: t = expand_batch_dims(self.t, ndim=self.q.ndim) object.__setattr__(self, "t", t) - elif self.t.ndim == 1: - t = expand_batch_dims(self.t, ndim=self.q.ndim - 1) - object.__setattr__(self, "t", t) # ========================================================================== # Array properties diff --git a/tests/unit/potential/test_base.py b/tests/unit/potential/test_base.py index 7052853a..a5bbf497 100644 --- a/tests/unit/potential/test_base.py +++ b/tests/unit/potential/test_base.py @@ -153,12 +153,11 @@ def test_integrate_orbit_batch(self, pot: AbstractPotentialBase, xv: Vec6) -> No orbits = pot.integrate_orbit(xv[None, :], ts) assert isinstance(orbits, gd.Orbit) assert orbits.shape == (1, len(ts)) - assert array_equal(orbits.t, ts[None, :]) + assert array_equal(orbits.t, ts) # More complicated batch xv2 = xp.stack([xv, xv], axis=0) orbits = pot.integrate_orbit(xv2, ts) assert isinstance(orbits, gd.Orbit) assert orbits.shape == (2, len(ts)) - assert array_equal(orbits.t[0], ts) - assert array_equal(orbits.t[1], ts) + assert array_equal(orbits.t, ts)