Skip to content

Commit

Permalink
Orbits (#4)
Browse files Browse the repository at this point in the history
* Add MockStream class

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* include time as integrator output

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* move i in df sampler

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* add Orbit class

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* change private method name

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

* Add example notebook

Signed-off-by: nstarman <nstarman@users.noreply.github.com>

---------

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 9, 2023
1 parent b79cef4 commit 847927f
Show file tree
Hide file tree
Showing 13 changed files with 924 additions and 57 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ repos:
rev: "v2.2.6"
hooks:
- id: codespell
exclude: notebooks

- repo: https://github.com/shellcheck-py/shellcheck-py
rev: "v0.9.0.6"
Expand Down
783 changes: 783 additions & 0 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ extend-select = [
"YTT", # flake8-2020
"EXE", # flake8-executable
"NPY", # NumPy specific rules
"PD", # pandas-vet
]
ignore = [
"PD", # pandas-vet
"PLR", # Design related pylint codes
# TODO! fix these
"ARG001",
Expand Down
4 changes: 3 additions & 1 deletion src/galdynamix/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from __future__ import annotations

from . import mockstream
from . import _orbit, mockstream
from ._orbit import *
from .mockstream import *

__all__: list[str] = []
__all__ += _orbit.__all__
__all__ += mockstream.__all__
46 changes: 46 additions & 0 deletions src/galdynamix/dynamics/_orbit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""galdynamix: Galactic Dynamix in Jax"""

from __future__ import annotations

__all__ = ["Orbit"]


import equinox as eqx
import jax.numpy as xp
import jax.typing as jt

from galdynamix.potential._potential.base import AbstractPotentialBase


class Orbit(eqx.Module): # type: ignore[misc]
"""Orbit.
TODO:
- Units stuff
- GR stuff
"""

q: jt.Array
"""Position of the stream particles (x, y, z) [kpc]."""

p: jt.Array
"""Position of the stream particles (x, y, z) [kpc/Myr]."""

t: jt.Array
"""Release time of the stream particles [Myr]."""

potential: AbstractPotentialBase
"""Potential in which the orbit was integrated."""

def to_w(self) -> jt.Array:
t = self.t[:, None] if self.t.ndim == 1 else self.t
out = xp.empty(
(
self.q.shape[0],
self.q.shape[1] + self.p.shape[1] + t.shape[1],
)
)
out = out.at[:, : self.q.shape[1]].set(self.q)
out = out.at[:, self.q.shape[1] : -1].set(self.p)
out = out.at[:, -1:].set(t)
return out # noqa: RET504
6 changes: 3 additions & 3 deletions src/galdynamix/dynamics/mockstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from __future__ import annotations

from . import _df, _mockstream
from . import _df, _mockstream_generator
from ._df import *
from ._mockstream import *
from ._mockstream_generator import *

__all__: list[str] = []
__all__ += _df.__all__
__all__ += _mockstream.__all__
__all__ += _mockstream_generator.__all__
29 changes: 29 additions & 0 deletions src/galdynamix/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""galdynamix: Galactic Dynamix in Jax"""

from __future__ import annotations

__all__ = ["MockStream"]


import equinox as eqx
import jax.typing as jt


class MockStream(eqx.Module): # type: ignore[misc]
"""Mock stream object.
TODO:
- units stuff
- change this to be a collection of sub-objects: progenitor, leading arm,
trailing arm, 3-body ejecta, etc.
- GR 4-vector stuff
"""

q: jt.Array
"""Position of the stream particles (x, y, z) [kpc]."""

p: jt.Array
"""Position of the stream particles (x, y, z) [kpc/Myr]."""

release_time: jt.Array
"""Release time of the stream particles [Myr]."""
41 changes: 20 additions & 21 deletions src/galdynamix/dynamics/mockstream/_df/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""galdynamix: Galactic Dynamix in Jax"""
# ruff: noqa: F403

from __future__ import annotations

Expand All @@ -13,6 +12,8 @@
import jax.numpy as xp
import jax.typing as jt

from galdynamix.dynamics._orbit import Orbit
from galdynamix.dynamics.mockstream._core import MockStream
from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.utils import partial_jit

Expand All @@ -35,46 +36,43 @@ def sample(
self,
# <\ parts of gala's ``prog_orbit``
potential: AbstractPotentialBase,
prog_ws: jt.Array,
ts: jt.Numeric,
prog_orbit: Orbit,
# />
prog_mass: jt.Numeric,
*,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
) -> tuple[MockStream, MockStream]:
"""Generate stream particle initial conditions.
Parameters
----------
potential : AbstractPotentialBase
The potential of the host galaxy.
prog_ws : Array[(N, 6), float]
Columns are (x, y, z) [kpc], (v_x, v_y, v_z) [kpc/Myr]
Rows are at times `ts`.
prog_orbit : Orbit
The orbit of the progenitor.
prog_mass : Numeric
Mass of the progenitor in [Msol].
TODO: allow this to be an array or function of time.
ts : Numeric
Times in [Myr]
seed_num : int, keyword-only
PRNG seed
Returns
-------
x_lead, x_trail, v_lead, v_trail : Array
mock_lead, mock_trail : MockStream
Positions and velocities of the leading and trailing tails.
"""
prog_ws = prog_orbit.to_w()[:, :-1] # -1 is time
ts = prog_orbit.t

def scan_fn(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]:
i = carry[0]
output = self._sample(
potential,
prog_ws[i, :3],
prog_ws[i, 3:],
prog_ws[i],
prog_mass,
i,
t,
i=i,
seed_num=seed_num,
)
return (i + 1, *output), tuple(output) # type: ignore[return-value]
Expand All @@ -87,18 +85,21 @@ def scan_fn(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]:
xp.array([0.0, 0.0, 0.0]),
)
x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts[1:])[1]
return x_lead, x_trail, v_lead, v_trail

mock_lead = MockStream(x_lead, v_lead, ts[1:])
mock_trail = MockStream(x_trail, v_trail, ts[1:])

return mock_lead, mock_trail

@abc.abstractmethod
def _sample(
self,
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
w: jt.Array,
prog_mass: jt.Numeric,
i: int,
t: jt.Numeric,
*,
i: int,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""Generate stream particle initial conditions.
Expand All @@ -107,10 +108,8 @@ def _sample(
----------
potential : AbstractPotentialBase
The potential of the host galaxy.
x : Array
3d position (x, y, z) in [kpc]
v : Array
3d velocity (v_x, v_y, v_z) in [kpc/Myr]
w : Array
6d position (x, y, z) [kpc], (v_x, v_y, v_z) [kpc/Myr]
prog_mass : Numeric
Mass of the progenitor in [Msol]
t : Numeric
Expand Down
24 changes: 10 additions & 14 deletions src/galdynamix/dynamics/mockstream/_df/fardal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""galdynamix: Galactic Dynamix in Jax"""
# ruff: noqa: F403

from __future__ import annotations

Expand All @@ -20,20 +19,15 @@ class FardalStreamDF(AbstractStreamDF):
@partial_jit(static_argnums=(0,), static_argnames=("seed_num",))
def _sample(
self,
# parts of gala's ``prog_orbit``
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
w: jt.Array,
prog_mass: jt.Array,
i: int,
t: jt.Array,
*,
i: int,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""
Simplification of particle spray: just release particles in gaussian blob at each lagrange point.
User sets the spatial and velocity dispersion for the "leaking" of particles
"""
"""Generate stream particle initial conditions."""
# Random number generation
# TODO: change random key handling... need to do all of the sampling up front...
key_master = jax.random.PRNGKey(seed_num)
Expand All @@ -47,6 +41,8 @@ def _sample(

# ---------------------------

x, v = w[:3], w[3:]

omega_val = orbital_angular_velocity_mag(x, v)

r = xp.linalg.norm(x)
Expand Down Expand Up @@ -85,8 +81,8 @@ def _sample(
########kvt_samp = kvt_bar + jax.random.normal(keye,shape=(1,))*sigma_kvt

## Trailing arm
pos_trail = x + kr_samp * r_hat * (r_tidal) # nudge out
pos_trail = pos_trail + z_hat * kz_samp * (
x_trail = x + kr_samp * r_hat * (r_tidal) # nudge out
x_trail = x_trail + z_hat * kz_samp * (
r_tidal / 1.0
) # r #nudge above/below orbital plane
v_trail = (
Expand All @@ -97,8 +93,8 @@ def _sample(
) # v_trail + (kvz_samp*v_circ*(-r_tidal/r))*z_hat #nudge velocity along vertical direction

## Leading arm
pos_lead = x + kr_samp * r_hat * (-r_tidal) # nudge in
pos_lead = pos_lead + z_hat * kz_samp * (
x_lead = x + kr_samp * r_hat * (-r_tidal) # nudge in
x_lead = x_lead + z_hat * kz_samp * (
-r_tidal / 1.0
) # r #nudge above/below orbital plane
v_lead = (
Expand All @@ -108,7 +104,7 @@ def _sample(
v_lead + (kvz_samp * v_circ * (-1.0)) * z_hat
) # v_lead + (kvz_samp*v_circ*(r_tidal/r))*z_hat #nudge velocity against vertical direction

return pos_lead, pos_trail, v_lead, v_trail
return x_lead, x_trail, v_lead, v_trail


#####################################################################
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""galdynamix: Galactic Dynamix in Jax"""
# ruff: noqa: F403

from __future__ import annotations

Expand All @@ -21,6 +20,8 @@
_wifT: TypeAlias = tuple[jt.Array, jt.Array, jt.Array, jt.Array]
_carryT: TypeAlias = tuple[int, jt.Array, jt.Array, jt.Array, jt.Array]

from galdynamix.dynamics._orbit import Orbit


class MockStreamGenerator(eqx.Module): # type: ignore[misc]
df: AbstractStreamDF
Expand All @@ -36,18 +37,20 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc]
@partial_jit(static_argnames=("seed_num",))
def _run_scan(
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> tuple[tuple[jt.Array, jt.Array], jt.Array]:
) -> tuple[tuple[jt.Array, jt.Array], Orbit]:
"""Generate stellar stream by scanning over the release model/integration.
Better for CPU usage.
"""
# Integrate the progenitor orbit
prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)
prog_o = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)

# Generate stream initial conditions along the integrated progenitor orbit
x_lead, x_trail, v_lead, v_trail = self.df.sample(
self.potential, prog_ws, ts, prog_mass, seed_num=seed_num
mock_lead, mock_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)
x_lead, v_lead = mock_lead.q, mock_lead.p
x_trail, v_trail = mock_trail.q, mock_trail.p

def scan_fn(
carry: _carryT, particle_idx: int
Expand All @@ -60,7 +63,7 @@ def scan_fn(
minval, maxval = ts[i], ts[-1]
integ_ics = lambda ics: self.potential.integrate_orbit( # noqa: E731
ics, minval, maxval, None
)[0]
).to_w()[0, :-1]
# vmap over leading and trailing arm
w_lead, w_trail = jax.vmap(integ_ics, in_axes=(0,))(w0_lead_trail)
carry_out = (
Expand All @@ -75,22 +78,24 @@ def scan_fn(
carry_init = (0, x_lead[0, :], x_trail[0, :], v_lead[0, :], v_trail[0, :])
particle_ids = xp.arange(len(x_lead))
lead_arm, trail_arm = jax.lax.scan(scan_fn, carry_init, particle_ids)[1]
return (lead_arm, trail_arm), prog_ws
return (lead_arm, trail_arm), prog_o

@partial_jit(static_argnames=("seed_num",))
def _run_vmap(
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> tuple[tuple[jt.Array, jt.Array], jt.Array]:
) -> tuple[tuple[jt.Array, jt.Array], Orbit]:
"""
Generate stellar stream by vmapping over the release model/integration. Better for GPU usage.
"""
# Integrate the progenitor orbit
prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)
prog_o = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)

# Generate stream initial conditions along the integrated progenitor orbit
x_lead, x_trail, v_lead, v_trail = self.df.sample(
self.potential, prog_ws, ts, prog_mass, seed_num=seed_num
mock_lead, mock_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)
x_lead, v_lead = mock_lead.q, mock_lead.p
x_trail, v_trail = mock_trail.q, mock_trail.p

# TODO: make this a separated method
@jax.jit # type: ignore[misc]
Expand All @@ -115,7 +120,7 @@ def single_particle_integrate(

integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0, 0, 0))
w_lead, w_trail = integrator(particle_ids, x_lead, x_trail, v_lead, v_trail)
return (w_lead, w_trail), prog_ws
return (w_lead, w_trail), prog_o

@partial_jit(static_argnames=("seed_num", "vmapped"))
def run(
Expand Down
Loading

0 comments on commit 847927f

Please sign in to comment.