From 805666f89f161ebbd61c542425c2f7b20d6b21d7 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Thu, 18 Jan 2024 21:43:30 -0500 Subject: [PATCH] fix selection of integrated state (#69) Signed-off-by: nstarman --- src/galax/dynamics/mockstream/_mockstream_generator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/galax/dynamics/mockstream/_mockstream_generator.py b/src/galax/dynamics/mockstream/_mockstream_generator.py index f8f269ab..b7a7dae0 100644 --- a/src/galax/dynamics/mockstream/_mockstream_generator.py +++ b/src/galax/dynamics/mockstream/_mockstream_generator.py @@ -74,9 +74,10 @@ def scan_fn(carry: Carry, idx: IntScalar) -> tuple[Carry, tuple[VecN, VecN]]: tstep = xp.asarray([ts[i], ts[-1]]) def integ_ics(ics: Vec6) -> VecN: + # TODO: only return the final state return self.potential.integrate_orbit( ics, tstep, integrator=self.stream_integrator - ).qp[0] + ).qp[-1] # vmap over leading and trailing arm qp_lead, qp_trail = jax.vmap(integ_ics, in_axes=(0,))(qp0_lead_trail) @@ -107,12 +108,13 @@ def single_particle_integrate( i: IntScalar, qp0_lead_i: Vec6, qp0_trail_i: Vec6 ) -> tuple[Vec6, Vec6]: tstep = xp.asarray([ts[i], t_f]) + # TODO: only return the final state qp_lead = self.potential.integrate_orbit( qp0_lead_i, tstep, integrator=self.stream_integrator - ).qp[0] + ).qp[-1] qp_trail = self.potential.integrate_orbit( qp0_trail_i, tstep, integrator=self.stream_integrator - ).qp[0] + ).qp[-1] return qp_lead, qp_trail particle_ids = xp.arange(len(qp0_lead))