Skip to content

Commit

Permalink
auto-detect backend to choose stream gen method
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 19, 2024
1 parent 805666f commit 0cd0ee9
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/galax/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import equinox as eqx
import jax
import jax.numpy as xp
from jax.lib.xla_bridge import get_backend

from galax.dynamics._orbit import Orbit
from galax.integrate._base import AbstractIntegrator
Expand Down Expand Up @@ -130,7 +131,7 @@ def run(
prog_mass: FloatScalar,
*,
seed_num: int,
vmapped: bool = False,
vmapped: bool | None = None,
) -> tuple[tuple[MockStream, MockStream], Orbit]:
"""Generate mock stellar stream.
Expand All @@ -148,10 +149,12 @@ def run(
:todo: a better way to handle PRNG
vmapped : bool, optional keyword-only
vmapped : bool | None, optional keyword-only
Whether to use `jax.vmap` (`True`) or `jax.lax.scan` (`False`) to
parallelize the integration. ``vmapped=True`` is recommended for GPU
usage, while ``vmapped=False`` is recommended for CPU usage.
usage, while ``vmapped=False`` is recommended for CPU usage. If
`None` (default), then `jax.vmap` is used on GPU and `jax.lax.scan`
otherwise.
Returns
-------
Expand All @@ -160,7 +163,9 @@ def run(
prog_o : Orbit
Orbit of the progenitor.
"""
# TODO: a discussion about the stripping times
# TODO: ꜛ a discussion about the stripping times
# Parse vmapped
use_vmap = get_backend().platform == "gpu" if vmapped is None else vmapped

# Integrate the progenitor orbit to the stripping times
prog_o = self.potential.integrate_orbit(
Expand All @@ -172,7 +177,7 @@ def run(
self.potential, prog_o, prog_mass, seed_num=seed_num
)

if vmapped:
if use_vmap:
lead_arm_qp, trail_arm_qp = self._run_vmap(ts, mock0_lead, mock0_trail)
else:
lead_arm_qp, trail_arm_qp = self._run_scan(ts, mock0_lead, mock0_trail)
Expand Down

0 comments on commit 0cd0ee9

Please sign in to comment.