Skip to content

Commit

Permalink
refactor: Q in / out
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Mar 16, 2024
1 parent 8955ef9 commit 5f59be0
Show file tree
Hide file tree
Showing 38 changed files with 1,219 additions and 1,002 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@
"ARG002",
"D10",
"E731",
"ERA001", # Found commented-out code # TODO: remove this
"FBT001", # Boolean-typed positional argument in a function definition
"INP001",
"S101",
"S301",
Expand Down
2 changes: 1 addition & 1 deletion src/galax/coordinates/_psp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def potential_energy(
We can compute the kinetic energy:
>>> pot = MilkyWayPotential()
>>> w.potential_energy(pot)
>>> w.potential_energy(pot).decompose(pot.units)
Quantity['specific energy'](Array(..., dtype=float64), unit='kpc2 / Myr2')
"""
return potential.potential_energy(self.q, t=self.t)
Expand Down
7 changes: 3 additions & 4 deletions src/galax/coordinates/_psp/operator_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import jax.numpy as jnp
from plum import convert
from quax import quaxify

import quaxed.numpy as qnp
from coordinax import CartesianDifferential3D
from coordinax.operators import (
AbstractCompositeOperator,
Expand All @@ -23,6 +23,8 @@

from galax.coordinates._psp.base import AbstractPhaseSpacePosition

vec_matmul = quaxify(jnp.vectorize(jnp.matmul, signature="(3,3),(3)->(3)"))

######################################################################
# Abstract Operators

Expand Down Expand Up @@ -255,9 +257,6 @@ def call(
return replace(psp, q=q, p=p, t=t)


vec_matmul = qnp.vectorize(jnp.matmul, signature="(3,3),(3)->(3)")


@op_call_dispatch
def call(
self: GalileanRotationOperator, psp: AbstractPhaseSpacePosition, /
Expand Down
13 changes: 8 additions & 5 deletions src/galax/dynamics/_dynamics/mockstream/df/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from galax.dynamics._dynamics.mockstream.core import MockStream
from galax.dynamics._dynamics.orbit import Orbit
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchVec3, FloatScalar, Vec3, Vec6
from galax.typing import BatchVec3, FloatQScalar, FloatScalar, Vec3, Vec6

Wif: TypeAlias = tuple[Vec3, Vec3, Vec3, Vec3]
Carry: TypeAlias = tuple[int, jr.PRNG, Vec3, Vec3, Vec3, Vec3]
Expand All @@ -42,7 +42,7 @@ def sample(
prog_orbit: Orbit,
# />
/,
prog_mass: FloatScalar,
prog_mass: FloatQScalar,
) -> tuple[MockStream, MockStream]:
"""Generate stream particle initial conditions.
Expand All @@ -54,7 +54,7 @@ def sample(
The potential of the host galaxy.
prog_orbit : Orbit, positional-only
The orbit of the progenitor.
prog_mass : Numeric
prog_mass : Quantity[float, (), 'mass']
Mass of the progenitor in [Msol].
TODO: allow this to be an array or function of time.
Expand All @@ -65,15 +65,17 @@ def sample(
"""
# Progenitor positions and times. The orbit times are used as the
# release times for the mock stream.
prog_w = prog_orbit.w(units=pot.units)
prog_w = prog_orbit.w(units=pot.units) # TODO: keep as PSP
ts = prog_orbit.t

mprog = prog_mass.to_value(pot.units["mass"]) # TODO: keep units

# Scan over the release times to generate the stream particle initial
# conditions at each release time.
def scan_fn(carry: Carry, t: FloatScalar) -> tuple[Carry, Wif]:
i = carry[0]
rng, subrng = carry[1].split(2)
out = self._sample(subrng, pot, prog_w[i], prog_mass, t)
out = self._sample(subrng, pot, prog_w[i], mprog, t)
return (i + 1, rng, *out), out

# TODO: use ``jax.vmap`` instead of ``jax.lax.scan`` for GPU usage
Expand All @@ -95,6 +97,7 @@ def scan_fn(carry: Carry, t: FloatScalar) -> tuple[Carry, Wif]:

return mock_lead, mock_trail

# TODO: keep units and PSP through this func
@abc.abstractmethod
def _sample(
self,
Expand Down
11 changes: 6 additions & 5 deletions src/galax/dynamics/_dynamics/mockstream/df/fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,12 @@ def d2phidr2(
Examples
--------
>>> from unxt import Quantity
>>> from galax.potential import NFWPotential
>>> from galax.units import galactic
>>> pot = NFWPotential(m=1e12, r_s=20.0, units=galactic)
>>> pot = NFWPotential(m=Quantity(1e12, "Msun"), r_s=Quantity(20.0, "kpc"),
... units="galactic")
>>> d2phidr2(pot, xp.asarray([8.0, 0.0, 0.0]), t=0)
Array(-0.00017469, dtype=float64)
Array(-0.00259193, dtype=float64)
"""
r_hat = x / xp.linalg.vector_norm(x)

Expand Down Expand Up @@ -254,10 +255,10 @@ def tidal_radius(
>>> x=xp.asarray([8.0, 0.0, 0.0])
>>> v=xp.asarray([8.0, 0.0, 0.0])
>>> tidal_radius(pot, x, v, prog_mass=1e4, t=0)
Array(0.06362136, dtype=float64)
Array(0.06362008, dtype=float64)
"""
return (
potential._G # noqa: SLF001
potential.constants["G"].value
* prog_mass
/ (orbital_angular_velocity_mag(x, v) ** 2 - d2phidr2(potential, x, t))
) ** (1.0 / 3.0)
Expand Down
2 changes: 1 addition & 1 deletion src/galax/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

__getattr__, __dir__, __all__ = attach_stub(__name__, __file__)

install_import_hook("galax.potential", RUNTIME_TYPECHECKER)
# install_import_hook("galax.potential", RUNTIME_TYPECHECKER) # noqa: ERA001

# Cleanup
del install_import_hook, RUNTIME_TYPECHECKER
Loading

0 comments on commit 5f59be0

Please sign in to comment.