Skip to content

Commit

Permalink
Dimensions (#8)
Browse files Browse the repository at this point in the history
* add ruff ignore
* move subhalo potential
* partial_vmap
* change physical_type to dimensions
* improve density calculation

---------

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 2, 2023
1 parent 5a71769 commit e3eaf24
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 151 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ src = ["src"]
extend-select = ["ALL"]
ignore = [
"ANN101", # Missing type annotation for self in method
"ANN102", # Missing type annotation for cls in classmethod
"ANN401", # Dynamically typed expressions (typing.Any) are disallowed in `*args`
"COM812", # Missing trailing comma in Python 3.6+
"D203", # 1 blank line required before class docstring
Expand Down
43 changes: 18 additions & 25 deletions src/galdynamix/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,23 @@ def _init_units(self) -> None:
G = 1 if self.units == dimensionless else _G.decompose(self.units).value
object.__setattr__(self, "_G", G)

# Handle unit conversion for all ParameterField
# Handle unit conversion for all fields, e.g. the parameters.
for f in fields(self):
# Process ParameterFields
param = getattr(self.__class__, f.name, None)
if not isinstance(param, ParameterField):
continue

value = getattr(self, f.name)
if isinstance(value, u.Quantity):
value = value.to_value(
self.units[param.physical_type], equivalencies=param.equivalencies
)
object.__setattr__(self, f.name, value)

# other parameters, check their metadata
for f in fields(self):
if "physical_type" not in f.metadata:
continue

value = getattr(self, f.name)
if isinstance(value, u.Quantity):
value = value.to_value(
self.units[f.metadata["physical_type"]],
equivalencies=f.metadata.get("equivalencies", None),
)
object.__setattr__(self, f.name, value)
if isinstance(param, ParameterField):
# Set, since the ``.units`` are now known
param.__set__(self, getattr(self, f.name))

# Other fields, check their metadata
elif "dimensions" in f.metadata:
value = getattr(self, f.name)
if isinstance(value, u.Quantity):
value = value.to_value(
self.units[f.metadata["dimensions"]],
equivalencies=f.metadata.get("equivalencies", None),
)
object.__setattr__(self, f.name, value)

###########################################################################
# Core methods that use the above implemented functions
Expand All @@ -79,11 +71,12 @@ def __call__(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
@partial_jit()
def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
"""Compute the gradient."""
return jax.grad(self.potential_energy)(q, t)
return jax.grad(self.potential_energy, argnums=0)(q, t)

@partial_jit()
def density(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
lap = xp.trace(jax.hessian(self.potential_energy)(q, t))
# Note: trace(jacobian(gradient)) is faster than trace(hessian(energy))
lap = xp.trace(jax.jacfwd(self.gradient)(q, t))
return lap / (4 * xp.pi * self._G)

@partial_jit()
Expand Down
91 changes: 13 additions & 78 deletions src/galdynamix/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,25 @@
"BarPotential",
"Isochrone",
"NFWPotential",
"SubHaloPopulation",
]

from dataclasses import KW_ONLY
from typing import Any

import equinox as eqx
import jax
import jax.numpy as xp
import jax.typing as jt
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

from galdynamix.potential._potential.base import AbstractPotential
from galdynamix.potential._potential.param import AbstractParameter, ParameterField
from galdynamix.units import galactic
from galdynamix.utils import partial_jit
from galdynamix.utils.dataclasses import field

# -------------------------------------------------------------------


class MiyamotoNagaiDisk(AbstractPotential):
m: AbstractParameter = ParameterField(physical_type="mass") # type: ignore[assignment]
a: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
b: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
a: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
b: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]

@partial_jit()
def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
Expand All @@ -53,11 +47,11 @@ class BarPotential(AbstractPotential):
Rz according to https://en.wikipedia.org/wiki/Rotation_matrix
"""

m: AbstractParameter = ParameterField(physical_type="mass") # type: ignore[assignment]
a: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
b: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
c: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
Omega: AbstractParameter = ParameterField(physical_type="frequency") # type: ignore[assignment]
m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
a: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
b: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
c: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
Omega: AbstractParameter = ParameterField(dimensions="frequency") # type: ignore[assignment]

@partial_jit()
def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
Expand Down Expand Up @@ -96,8 +90,8 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:


class Isochrone(AbstractPotential):
m: AbstractParameter = ParameterField(physical_type="mass") # type: ignore[assignment]
a: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
a: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]

@partial_jit()
def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
Expand All @@ -112,12 +106,10 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
class NFWPotential(AbstractPotential):
"""NFW Potential."""

m: AbstractParameter = ParameterField(physical_type="mass") # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment]
r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment]
_: KW_ONLY
softening_length: jt.Array = field(
default=0.001, static=True, physical_type="length"
)
softening_length: jt.Array = field(default=0.001, static=True, dimensions="length")

@partial_jit()
def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
Expand All @@ -126,60 +118,3 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
q[0] ** 2 + q[1] ** 2 + q[2] ** 2 + self.softening_length,
) / self.r_s(t)
return v_h2 * xp.log(1.0 + m) / m


# -------------------------------------------------------------------


@jax.jit # type: ignore[misc]
def get_splines(x_eval: jt.Array, x: jt.Array, y: jt.Array) -> Any:
return InterpolatedUnivariateSpline(x, y, k=3)(x_eval)


@jax.jit # type: ignore[misc]
def single_subhalo_potential(
params: dict[str, jt.Array], q: jt.Array, /, t: jt.Array
) -> jt.Array:
"""Potential for a single subhalo.
TODO: custom unit specification/subhalo potential specficiation.
Currently supports units kpc, Myr, Msun, rad.
"""
pot_single = Isochrone(m=params["m"], a=params["a"], units=galactic)
return pot_single.potential_energy(q, t)


class SubHaloPopulation(AbstractPotential):
"""m has length n_subhalo.
a has length n_subhalo
tq_subhalo_arr has shape t_orbit x n_subhalo x 3
t_orbit is the array of times the subhalos are integrated over
"""

m: AbstractParameter = ParameterField(physical_type="mass") # type: ignore[assignment]
a: AbstractParameter = ParameterField(physical_type="length") # type: ignore[assignment]
tq_subhalo_arr: jt.Array = eqx.field(converter=xp.asarray)
t_orbit: jt.Array = eqx.field(converter=xp.asarray)

@partial_jit()
def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
# expect n_subhalo x-positions
x_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 0])
# expect n_subhalo y-positions
y_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 1])
# expect n_subhalo z-positions
z_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 2])

# n_subhalo x 3: the position of all subhalos at time t
subhalo_locations = xp.vstack([x_at_t_eval, y_at_t_eval, z_at_t_eval]).T

delta_position = q - subhalo_locations # n_subhalo x 3
# sum over potential due to all subhalos in the field by vmapping over
# m, a, and delta_position
return xp.sum(
jax.vmap(
single_subhalo_potential,
in_axes=(({"m": 0, "a": 0}, 0, None)),
)({"m": self.m(t), "a": self.a(t)}, delta_position, t),
)
73 changes: 50 additions & 23 deletions src/galdynamix/potential/_potential/param/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,6 @@
from galdynamix.utils import partial_jit


class ParameterCallable(Protocol):
"""Protocol for a Parameter callable."""

def __call__(self, t: jt.Array) -> jt.Array:
"""Compute the parameter value at the given time(s).
Parameters
----------
t : Array
Time(s) at which to compute the parameter value.
Returns
-------
Array
Parameter value(s) at the given time(s).
"""
...


# ====================================================================


class AbstractParameter(eqx.Module): # type: ignore[misc]
"""Abstract Base Class for Parameters on a Potential.
Expand All @@ -53,6 +31,18 @@ class AbstractParameter(eqx.Module): # type: ignore[misc]

@abc.abstractmethod
def __call__(self, t: jt.Array) -> jt.Array:
"""Compute the parameter value at the given time(s).
Parameters
----------
t : Array
The time(s) at which to compute the parameter value.
Returns
-------
Array
The parameter value at times ``t``.
"""
...


Expand All @@ -63,10 +53,47 @@ class ConstantParameter(AbstractParameter):
value: jt.Array

@partial_jit()
def __call__(self, t: jt.Array) -> jt.Array:
def __call__(self, t: jt.Array = 0) -> jt.Array:
"""Return the constant parameter value.
Parameters
----------
t : Array, optional
This is ignored and is thus optional.
Note that for most :class:`~galdynamix.potential.AbstractParameter`
the time is required.
Returns
-------
Array
The constant parameter value.
"""
return self.value


#####################################################################
# User-defined Parameter


class ParameterCallable(Protocol):
"""Protocol for a Parameter callable."""

def __call__(self, t: jt.Array) -> jt.Array:
"""Compute the parameter value at the given time(s).
Parameters
----------
t : Array
Time(s) at which to compute the parameter value.
Returns
-------
Array
Parameter value(s) at the given time(s).
"""
...


class UserParameter(AbstractParameter):
"""User-defined Parameter."""

Expand Down
33 changes: 18 additions & 15 deletions src/galdynamix/potential/_potential/param/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class ParameterField:
Parameters
----------
physical_type : PhysicalType
Physical type of the parameter.
dimensions : PhysicalType
Dimensions (unit-wise) of the parameter.
equivalencies : Equivalency or tuple[Equivalency, ...], optional
Equivalencies to use when converting the parameter value to the
physical type. If not specified, the default equivalencies for the
Expand All @@ -30,21 +30,17 @@ class ParameterField:

name: str = field(init=False)
_: KW_ONLY
physical_type: u.PhysicalType # TODO: add a converter_argument
dimensions: u.PhysicalType # TODO: add a converter_argument
equivalencies: u.Equivalency | tuple[u.Equivalency, ...] | None = None

def __post_init__(self) -> None:
# Process the physical type
# TODO: move this to a ``converter`` argument for a custom
# ``dataclass_transform``'s ``__init__`` method.
if isinstance(self.physical_type, str):
object.__setattr__(
self, "physical_type", u.get_physical_type(self.physical_type)
)
elif not isinstance(self.physical_type, u.PhysicalType):
msg = (
"Expected physical_type to be a PhysicalType, "
f"got {self.physical_type!r}"
)
if isinstance(self.dimensions, str):
object.__setattr__(self, "dimensions", u.get_physical_type(self.dimensions))
elif not isinstance(self.dimensions, u.PhysicalType):
msg = f"Expected dimensions to be a PhysicalType, got {self.dimensions!r}"
raise TypeError(msg)

# ===========================================
Expand Down Expand Up @@ -88,18 +84,25 @@ def __set__(
) -> None:
# Convert
if isinstance(value, AbstractParameter):
# TODO: use the physical_type information to check the parameters.
# TODO: use the dimensions & equivalencies info to check the parameters.
# TODO: use the units on the `potential` to convert the parameter value.
pass
elif callable(value):
# TODO: use the physical_type information to check the parameters.
# TODO: use the dimensions & equivalencies info to check the parameters.
# TODO: use the units on the `potential` to convert the parameter value.
value = UserParameter(func=value)
else:
unit = potential.units[self.physical_type]
# TODO: the issue here is that ``units`` hasn't necessarily been set
# on the potential yet. What is needed is to possibly bail out
# here and defer the conversion until the units are set.
# AbstractPotentialBase has the ``_init_units`` method that
# can then call this method, hitting ``AbstractParameter``
# this time.
unit = potential.units[self.dimensions]
if isinstance(value, u.Quantity):
value = value.to_value(unit, equivalencies=self.equivalencies)

value = ConstantParameter(xp.asarray(value), unit=unit)

# Set
potential.__dict__[self.name] = value
Loading

0 comments on commit e3eaf24

Please sign in to comment.