Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Orbits #4

Merged
merged 6 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ repos:
rev: "v2.2.6"
hooks:
- id: codespell
exclude: notebooks

- repo: https://github.com/shellcheck-py/shellcheck-py
rev: "v0.9.0.6"
Expand Down
783 changes: 783 additions & 0 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ extend-select = [
"YTT", # flake8-2020
"EXE", # flake8-executable
"NPY", # NumPy specific rules
"PD", # pandas-vet
]
ignore = [
"PD", # pandas-vet
"PLR", # Design related pylint codes
# TODO! fix these
"ARG001",
Expand Down
4 changes: 3 additions & 1 deletion src/galdynamix/dynamics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

from __future__ import annotations

from . import mockstream
from . import _orbit, mockstream
from ._orbit import *
from .mockstream import *

__all__: list[str] = []
__all__ += _orbit.__all__
__all__ += mockstream.__all__
46 changes: 46 additions & 0 deletions src/galdynamix/dynamics/_orbit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""galdynamix: Galactic Dynamix in Jax"""

from __future__ import annotations

__all__ = ["Orbit"]


import equinox as eqx
import jax.numpy as xp
import jax.typing as jt

from galdynamix.potential._potential.base import AbstractPotentialBase


class Orbit(eqx.Module): # type: ignore[misc]
"""Orbit.

TODO:
- Units stuff
- GR stuff
"""

q: jt.Array
"""Position of the stream particles (x, y, z) [kpc]."""

p: jt.Array
"""Position of the stream particles (x, y, z) [kpc/Myr]."""

t: jt.Array
"""Release time of the stream particles [Myr]."""

potential: AbstractPotentialBase
"""Potential in which the orbit was integrated."""

def to_w(self) -> jt.Array:
t = self.t[:, None] if self.t.ndim == 1 else self.t
out = xp.empty(
(
self.q.shape[0],
self.q.shape[1] + self.p.shape[1] + t.shape[1],
)
)
out = out.at[:, : self.q.shape[1]].set(self.q)
out = out.at[:, self.q.shape[1] : -1].set(self.p)
out = out.at[:, -1:].set(t)
return out # noqa: RET504
6 changes: 3 additions & 3 deletions src/galdynamix/dynamics/mockstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from __future__ import annotations

from . import _df, _mockstream
from . import _df, _mockstream_generator
from ._df import *
from ._mockstream import *
from ._mockstream_generator import *

__all__: list[str] = []
__all__ += _df.__all__
__all__ += _mockstream.__all__
__all__ += _mockstream_generator.__all__
29 changes: 29 additions & 0 deletions src/galdynamix/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""galdynamix: Galactic Dynamix in Jax"""

from __future__ import annotations

__all__ = ["MockStream"]


import equinox as eqx
import jax.typing as jt


class MockStream(eqx.Module): # type: ignore[misc]
"""Mock stream object.

TODO:
- units stuff
- change this to be a collection of sub-objects: progenitor, leading arm,
trailing arm, 3-body ejecta, etc.
- GR 4-vector stuff
"""

q: jt.Array
"""Position of the stream particles (x, y, z) [kpc]."""

p: jt.Array
"""Position of the stream particles (x, y, z) [kpc/Myr]."""

release_time: jt.Array
"""Release time of the stream particles [Myr]."""
41 changes: 20 additions & 21 deletions src/galdynamix/dynamics/mockstream/_df/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""galdynamix: Galactic Dynamix in Jax"""
# ruff: noqa: F403

from __future__ import annotations

Expand All @@ -13,6 +12,8 @@
import jax.numpy as xp
import jax.typing as jt

from galdynamix.dynamics._orbit import Orbit
from galdynamix.dynamics.mockstream._core import MockStream
from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.utils import partial_jit

Expand All @@ -35,46 +36,43 @@ def sample(
self,
# <\ parts of gala's ``prog_orbit``
potential: AbstractPotentialBase,
prog_ws: jt.Array,
ts: jt.Numeric,
prog_orbit: Orbit,
# />
prog_mass: jt.Numeric,
*,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
) -> tuple[MockStream, MockStream]:
"""Generate stream particle initial conditions.

Parameters
----------
potential : AbstractPotentialBase
The potential of the host galaxy.
prog_ws : Array[(N, 6), float]
Columns are (x, y, z) [kpc], (v_x, v_y, v_z) [kpc/Myr]
Rows are at times `ts`.
prog_orbit : Orbit
The orbit of the progenitor.
prog_mass : Numeric
Mass of the progenitor in [Msol].
TODO: allow this to be an array or function of time.
ts : Numeric
Times in [Myr]

seed_num : int, keyword-only
PRNG seed

Returns
-------
x_lead, x_trail, v_lead, v_trail : Array
mock_lead, mock_trail : MockStream
Positions and velocities of the leading and trailing tails.
"""
prog_ws = prog_orbit.to_w()[:, :-1] # -1 is time
ts = prog_orbit.t

def scan_fn(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]:
i = carry[0]
output = self._sample(
potential,
prog_ws[i, :3],
prog_ws[i, 3:],
prog_ws[i],
prog_mass,
i,
t,
i=i,
seed_num=seed_num,
)
return (i + 1, *output), tuple(output) # type: ignore[return-value]
Expand All @@ -87,18 +85,21 @@ def scan_fn(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]:
xp.array([0.0, 0.0, 0.0]),
)
x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts[1:])[1]
return x_lead, x_trail, v_lead, v_trail

mock_lead = MockStream(x_lead, v_lead, ts[1:])
mock_trail = MockStream(x_trail, v_trail, ts[1:])

return mock_lead, mock_trail

@abc.abstractmethod
def _sample(
self,
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
w: jt.Array,
prog_mass: jt.Numeric,
i: int,
t: jt.Numeric,
*,
i: int,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""Generate stream particle initial conditions.
Expand All @@ -107,10 +108,8 @@ def _sample(
----------
potential : AbstractPotentialBase
The potential of the host galaxy.
x : Array
3d position (x, y, z) in [kpc]
v : Array
3d velocity (v_x, v_y, v_z) in [kpc/Myr]
w : Array
6d position (x, y, z) [kpc], (v_x, v_y, v_z) [kpc/Myr]
prog_mass : Numeric
Mass of the progenitor in [Msol]
t : Numeric
Expand Down
24 changes: 10 additions & 14 deletions src/galdynamix/dynamics/mockstream/_df/fardal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""galdynamix: Galactic Dynamix in Jax"""
# ruff: noqa: F403

from __future__ import annotations

Expand All @@ -20,20 +19,15 @@ class FardalStreamDF(AbstractStreamDF):
@partial_jit(static_argnums=(0,), static_argnames=("seed_num",))
def _sample(
self,
# parts of gala's ``prog_orbit``
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
w: jt.Array,
prog_mass: jt.Array,
i: int,
t: jt.Array,
*,
i: int,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""
Simplification of particle spray: just release particles in gaussian blob at each lagrange point.
User sets the spatial and velocity dispersion for the "leaking" of particles
"""
"""Generate stream particle initial conditions."""
# Random number generation
# TODO: change random key handling... need to do all of the sampling up front...
key_master = jax.random.PRNGKey(seed_num)
Expand All @@ -47,6 +41,8 @@ def _sample(

# ---------------------------

x, v = w[:3], w[3:]

omega_val = orbital_angular_velocity_mag(x, v)

r = xp.linalg.norm(x)
Expand Down Expand Up @@ -85,8 +81,8 @@ def _sample(
########kvt_samp = kvt_bar + jax.random.normal(keye,shape=(1,))*sigma_kvt

## Trailing arm
pos_trail = x + kr_samp * r_hat * (r_tidal) # nudge out
pos_trail = pos_trail + z_hat * kz_samp * (
x_trail = x + kr_samp * r_hat * (r_tidal) # nudge out
x_trail = x_trail + z_hat * kz_samp * (
r_tidal / 1.0
) # r #nudge above/below orbital plane
v_trail = (
Expand All @@ -97,8 +93,8 @@ def _sample(
) # v_trail + (kvz_samp*v_circ*(-r_tidal/r))*z_hat #nudge velocity along vertical direction

## Leading arm
pos_lead = x + kr_samp * r_hat * (-r_tidal) # nudge in
pos_lead = pos_lead + z_hat * kz_samp * (
x_lead = x + kr_samp * r_hat * (-r_tidal) # nudge in
x_lead = x_lead + z_hat * kz_samp * (
-r_tidal / 1.0
) # r #nudge above/below orbital plane
v_lead = (
Expand All @@ -108,7 +104,7 @@ def _sample(
v_lead + (kvz_samp * v_circ * (-1.0)) * z_hat
) # v_lead + (kvz_samp*v_circ*(r_tidal/r))*z_hat #nudge velocity against vertical direction

return pos_lead, pos_trail, v_lead, v_trail
return x_lead, x_trail, v_lead, v_trail


#####################################################################
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""galdynamix: Galactic Dynamix in Jax"""
# ruff: noqa: F403

from __future__ import annotations

Expand All @@ -21,6 +20,8 @@
_wifT: TypeAlias = tuple[jt.Array, jt.Array, jt.Array, jt.Array]
_carryT: TypeAlias = tuple[int, jt.Array, jt.Array, jt.Array, jt.Array]

from galdynamix.dynamics._orbit import Orbit


class MockStreamGenerator(eqx.Module): # type: ignore[misc]
df: AbstractStreamDF
Expand All @@ -36,18 +37,20 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc]
@partial_jit(static_argnames=("seed_num",))
def _run_scan(
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> tuple[tuple[jt.Array, jt.Array], jt.Array]:
) -> tuple[tuple[jt.Array, jt.Array], Orbit]:
"""Generate stellar stream by scanning over the release model/integration.

Better for CPU usage.
"""
# Integrate the progenitor orbit
prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)
prog_o = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)

# Generate stream initial conditions along the integrated progenitor orbit
x_lead, x_trail, v_lead, v_trail = self.df.sample(
self.potential, prog_ws, ts, prog_mass, seed_num=seed_num
mock_lead, mock_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)
x_lead, v_lead = mock_lead.q, mock_lead.p
x_trail, v_trail = mock_trail.q, mock_trail.p

def scan_fn(
carry: _carryT, particle_idx: int
Expand All @@ -60,7 +63,7 @@ def scan_fn(
minval, maxval = ts[i], ts[-1]
integ_ics = lambda ics: self.potential.integrate_orbit( # noqa: E731
ics, minval, maxval, None
)[0]
).to_w()[0, :-1]
# vmap over leading and trailing arm
w_lead, w_trail = jax.vmap(integ_ics, in_axes=(0,))(w0_lead_trail)
carry_out = (
Expand All @@ -75,22 +78,24 @@ def scan_fn(
carry_init = (0, x_lead[0, :], x_trail[0, :], v_lead[0, :], v_trail[0, :])
particle_ids = xp.arange(len(x_lead))
lead_arm, trail_arm = jax.lax.scan(scan_fn, carry_init, particle_ids)[1]
return (lead_arm, trail_arm), prog_ws
return (lead_arm, trail_arm), prog_o

@partial_jit(static_argnames=("seed_num",))
def _run_vmap(
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> tuple[tuple[jt.Array, jt.Array], jt.Array]:
) -> tuple[tuple[jt.Array, jt.Array], Orbit]:
"""
Generate stellar stream by vmapping over the release model/integration. Better for GPU usage.
"""
# Integrate the progenitor orbit
prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)
prog_o = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)

# Generate stream initial conditions along the integrated progenitor orbit
x_lead, x_trail, v_lead, v_trail = self.df.sample(
self.potential, prog_ws, ts, prog_mass, seed_num=seed_num
mock_lead, mock_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)
x_lead, v_lead = mock_lead.q, mock_lead.p
x_trail, v_trail = mock_trail.q, mock_trail.p

# TODO: make this a separated method
@jax.jit # type: ignore[misc]
Expand All @@ -115,7 +120,7 @@ def single_particle_integrate(

integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0, 0, 0))
w_lead, w_trail = integrator(particle_ids, x_lead, x_trail, v_lead, v_trail)
return (w_lead, w_trail), prog_ws
return (w_lead, w_trail), prog_o

@partial_jit(static_argnames=("seed_num", "vmapped"))
def run(
Expand Down
Loading