Skip to content

Commit

Permalink
feat: change shape property of PhaseSpacePosition
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Jan 29, 2024
1 parent 79637ae commit 62a5494
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
11 changes: 9 additions & 2 deletions src/galax/dynamics/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,20 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
@property
def shape(self) -> tuple[int, ...]:
"""Shape of the position and velocity arrays."""
batch_shape, component_shapes = self._shape_tuple
return (*batch_shape, sum(component_shapes))
return self._shape_tuple[0]

def __len__(self) -> int:
"""Return the number of particles."""
return self.shape[0]

# ==========================================================================

@property
def full_shape(self) -> tuple[int, ...]:
"""Shape of the position and velocity arrays."""
batch_shape, component_shapes = self._shape_tuple
return (*batch_shape, sum(component_shapes))

# ==========================================================================
# Convenience properties

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/potential/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_integrate_orbit(self, pot: AbstractPotentialBase, xv: Vec6) -> None:

orbit = pot.integrate_orbit(xv, ts)
assert isinstance(orbit, gd.Orbit)
assert orbit.shape == (len(ts), 7)
assert orbit.shape == (len(ts),)
assert array_equal(orbit.t, ts)

def test_integrate_orbit_batch(self, pot: AbstractPotentialBase, xv: Vec6) -> None:
Expand All @@ -152,13 +152,13 @@ def test_integrate_orbit_batch(self, pot: AbstractPotentialBase, xv: Vec6) -> No
# Simple batch
orbits = pot.integrate_orbit(xv[None, :], ts)
assert isinstance(orbits, gd.Orbit)
assert orbits.shape == (1, len(ts), 7)
assert orbits.shape == (1, len(ts))
assert array_equal(orbits.t, ts[None, :])

# More complicated batch
xv2 = xp.stack([xv, xv], axis=0)
orbits = pot.integrate_orbit(xv2, ts)
assert isinstance(orbits, gd.Orbit)
assert orbits.shape == (2, len(ts), 7)
assert orbits.shape == (2, len(ts))
assert array_equal(orbits.t[0], ts)
assert array_equal(orbits.t[1], ts)

0 comments on commit 62a5494

Please sign in to comment.