diff --git a/src/galax/dynamics/mockstream/_df/base.py b/src/galax/dynamics/mockstream/_df/base.py index 7dc90aa9..949b2255 100644 --- a/src/galax/dynamics/mockstream/_df/base.py +++ b/src/galax/dynamics/mockstream/_df/base.py @@ -83,10 +83,10 @@ def scan_fn(carry: Carry, t: FloatScalar) -> tuple[Carry, Wif]: xp.asarray([0.0, 0.0, 0.0]), xp.asarray([0.0, 0.0, 0.0]), ) - x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts[1:])[1] + x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts)[1] - mock_lead = MockStream(x_lead, v_lead, ts[1:]) - mock_trail = MockStream(x_trail, v_trail, ts[1:]) + mock_lead = MockStream(x_lead, v_lead, ts) + mock_trail = MockStream(x_trail, v_trail, ts) return mock_lead, mock_trail