From 4c037e4de9fe65b5baa435bf905b6d92a448110b Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan Date: Mon, 1 Jan 2024 12:31:54 -0500 Subject: [PATCH 1/2] remove niche subhalo module and jax_cosmo dep --- pyproject.toml | 1 - src/galax/potential/_potential/subhalo.py | 76 ----------------------- 2 files changed, 77 deletions(-) delete mode 100644 src/galax/potential/_potential/subhalo.py diff --git a/pyproject.toml b/pyproject.toml index 71653688..0ac2a0c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,6 @@ dependencies = [ "diffrax", "equinox", "jax", - "jax_cosmo", "typing_extensions", ] diff --git a/src/galax/potential/_potential/subhalo.py b/src/galax/potential/_potential/subhalo.py deleted file mode 100644 index 9361ec01..00000000 --- a/src/galax/potential/_potential/subhalo.py +++ /dev/null @@ -1,76 +0,0 @@ -"""galax: Galactic Dynamix in Jax.""" - - -__all__ = [ - "SubHaloPopulation", -] - -from typing import Any - -import equinox as eqx -import jax -import jax.numpy as xp -import jax.typing as jt -from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline - -from galax.potential._potential.builtin import IsochronePotential -from galax.potential._potential.core import AbstractPotential -from galax.potential._potential.param import AbstractParameter, ParameterField -from galax.units import galactic -from galax.utils import partial_jit - -# ------------------------------------------------------------------- - - -@jax.jit # type: ignore[misc] -def get_splines(x_eval: jt.Array, x: jt.Array, y: jt.Array) -> Any: - return InterpolatedUnivariateSpline(x, y, k=3)(x_eval) - - -@jax.jit # type: ignore[misc] -def single_subhalo_potential( - params: dict[str, jt.Array], q: jt.Array, /, t: jt.Array -) -> jt.Array: - """Potential for a single subhalo. - - TODO: custom unit specification/subhalo potential specficiation. - Currently supports units kpc, Myr, Msun, rad. - """ - pot_single = IsochronePotential(m=params["m"], a=params["a"], units=galactic) - return pot_single.potential_energy(q, t) - - -class SubHaloPopulation(AbstractPotential): - """m has length n_subhalo. - - a has length n_subhalo - tq_subhalo_arr has shape t_orbit x n_subhalo x 3 - t_orbit is the array of times the subhalos are integrated over - """ - - m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment] - a: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment] - tq_subhalo_arr: jt.Array = eqx.field(converter=xp.asarray) - t_orbit: jt.Array = eqx.field(converter=xp.asarray) - - @partial_jit() - def _potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: - # expect n_subhalo x-positions - x_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 0]) - # expect n_subhalo y-positions - y_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 1]) - # expect n_subhalo z-positions - z_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 2]) - - # n_subhalo x 3: the position of all subhalos at time t - subhalo_locations = xp.vstack([x_at_t_eval, y_at_t_eval, z_at_t_eval]).T - - delta_position = q - subhalo_locations # n_subhalo x 3 - # sum over potential due to all subhalos in the field by vmapping over - # m, a, and delta_position - return xp.sum( - jax.vmap( - single_subhalo_potential, - in_axes=(({"m": 0, "a": 0}, 0, None)), - )({"m": self.m(t), "a": self.a(t)}, delta_position, t), - ) From a7341c42431f7f1022a598abed30b083636acb6e Mon Sep 17 00:00:00 2001 From: Adrian Price-Whelan Date: Mon, 1 Jan 2024 12:37:53 -0500 Subject: [PATCH 2/2] add explicit jaxlib dep --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0ac2a0c7..09f77296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "diffrax", "equinox", "jax", + "jaxlib", "typing_extensions", ]