diff --git a/.copier-answers.yml b/.copier-answers.yml index 4256caa8..9461176d 100644 --- a/.copier-answers.yml +++ b/.copier-answers.yml @@ -3,10 +3,10 @@ _commit: 2023.10.27 _src_path: gh:scientific-python/cookie backend: hatch email: nstarman@users.noreply.github.com -full_name: Nathaniel Starkman -license: BSD -org: nstarman +full_name: galdynamix maintainers +license: MIT +org: galdynamix project_name: galdynamix -project_short_description: Galactic Dynamix in Jax -url: https://github.com/nstarman/galdynamix +project_short_description: Galactic Dynamix in Jax. +url: https://github.com/galdynamix/galdynamix vcs: true diff --git a/LICENSE b/LICENSE index 48f9efbc..68e13639 100644 --- a/LICENSE +++ b/LICENSE @@ -1,29 +1,19 @@ -BSD 3-Clause License +Copyright 2023 galdynamix maintainers -Copyright (c) 2023, Nathaniel Starkman. -All rights reserved. +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* Neither the name of the vector package developers nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index faad22d3..474d342e 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,12 @@ -[actions-badge]: https://github.com/nstarman/galdynamix/workflows/CI/badge.svg -[actions-link]: https://github.com/nstarman/galdynamix/actions +[actions-badge]: https://github.com/galdynamix/galdynamix/workflows/CI/badge.svg +[actions-link]: https://github.com/galdynamix/galdynamix/actions [conda-badge]: https://img.shields.io/conda/vn/conda-forge/galdynamix [conda-link]: https://github.com/conda-forge/galdynamix-feedstock [github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github -[github-discussions-link]: https://github.com/nstarman/galdynamix/discussions +[github-discussions-link]: https://github.com/galdynamix/galdynamix/discussions [pypi-link]: https://pypi.org/project/galdynamix/ [pypi-platforms]: https://img.shields.io/pypi/pyversions/galdynamix [pypi-version]: https://img.shields.io/pypi/v/galdynamix diff --git a/docs/conf.py b/docs/conf.py index d6fc8bc8..884e55f1 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,10 +1,12 @@ +"""Sphinx configuration.""" + from __future__ import annotations import importlib.metadata project = "galdynamix" -copyright = "2023, Nathaniel Starkman" -author = "Nathaniel Starkman" +copyright = "2023, galdynamix maintainers" +author = "galdynamix maintainers" version = release = importlib.metadata.version("galdynamix") extensions = [ diff --git a/noxfile.py b/noxfile.py index 24c8ea76..d8121b2a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,3 +1,5 @@ +"""Configuration for Nox.""" + from __future__ import annotations import argparse @@ -13,9 +15,7 @@ @nox.session def lint(session: nox.Session) -> None: - """ - Run the linter. - """ + """Run the linter.""" session.install("pre-commit") session.run( "pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs @@ -35,19 +35,14 @@ def lint(session: nox.Session) -> None: @nox.session def tests(session: nox.Session) -> None: - """ - Run the unit and regular tests. - """ + """Run the unit and regular tests.""" session.install(".[test]") session.run("pytest", *session.posargs) @nox.session(reuse_venv=True) def docs(session: nox.Session) -> None: - """ - Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links. - """ - + """Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links.""" parser = argparse.ArgumentParser() parser.add_argument("--serve", action="store_true", help="Serve after building") parser.add_argument( @@ -86,10 +81,7 @@ def docs(session: nox.Session) -> None: @nox.session def build_api_docs(session: nox.Session) -> None: - """ - Build (regenerate) API docs. - """ - + """Build (regenerate) API docs.""" session.install("sphinx") session.chdir("docs") session.run( @@ -105,10 +97,7 @@ def build_api_docs(session: nox.Session) -> None: @nox.session def build(session: nox.Session) -> None: - """ - Build an SDist and wheel. - """ - + """Build an SDist and wheel.""" build_path = DIR.joinpath("build") if build_path.exists(): shutil.rmtree(build_path) diff --git a/pyproject.toml b/pyproject.toml index a2ed4a91..835ad09f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,16 +6,18 @@ build-backend = "hatchling.build" [project] name = "galdynamix" authors = [ + { name = "galdynamix maintainers", email = "nstarman@users.noreply.github.com" }, + { name = "Jake Nibauer", email = "jnibauer@princeton.edu" }, { name = "Nathaniel Starkman", email = "nstarman@users.noreply.github.com" }, ] -description = "Galactic Dynamix in Jax" +description = "Galactic Dynamix in Jax." readme = "README.md" requires-python = ">=3.11" classifiers = [ "Development Status :: 1 - Planning", "Intended Audience :: Science/Research", "Intended Audience :: Developers", - "License :: OSI Approved :: BSD License", + "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", @@ -53,10 +55,10 @@ docs = [ ] [project.urls] -Homepage = "https://github.com/nstarman/galdynamix" -"Bug Tracker" = "https://github.com/nstarman/galdynamix/issues" -Discussions = "https://github.com/nstarman/galdynamix/discussions" -Changelog = "https://github.com/nstarman/galdynamix/releases" +Homepage = "https://github.com/galdynamix/galdynamix" +"Bug Tracker" = "https://github.com/galdynamix/galdynamix/issues" +Discussions = "https://github.com/galdynamix/galdynamix/discussions" +Changelog = "https://github.com/galdynamix/galdynamix/releases" [tool.hatch] @@ -114,32 +116,22 @@ ignore_missing_imports = true src = ["src"] [tool.ruff.lint] -extend-select = [ - "B", # flake8-bugbear - "I", # isort - "ARG", # flake8-unused-arguments - "C4", # flake8-comprehensions - "EM", # flake8-errmsg - "ICN", # flake8-import-conventions - "G", # flake8-logging-format - "PGH", # pygrep-hooks - "PIE", # flake8-pie - "PL", # pylint - "PT", # flake8-pytest-style - "PTH", # flake8-use-pathlib - "RET", # flake8-return - "RUF", # Ruff-specific - "SIM", # flake8-simplify - "T20", # flake8-print - "UP", # pyupgrade - "YTT", # flake8-2020 - "EXE", # flake8-executable - "NPY", # NumPy specific rules -] +extend-select = ["ALL"] ignore = [ - "PD", # pandas-vet - "PLR", # Design related pylint codes - # TODO! fix these + "ANN101", # Missing type annotation for self in method + "COM812", # Missing trailing comma in Python 3.6+ + "D203", # 1 blank line required before class docstring + "D105", # Missing docstring in magic method + "D107", # Missing docstring in `__init__` + "D213", # Multi-line docstring summary should start at the second line + "FIX002", # Line contains TODO, consider resolving the issue + "N80", # Naming conventions. + "PD", # pandas-vet + "PLR", # Design related pylint codes + "TCH00", # Move into a type-checking block + "TD002", # Missing author in TODO + "TD003", # Missing issue link on the line following this TODO + # TODO: fix these "ARG001", "ARG002", "F841", @@ -149,8 +141,12 @@ isort.required-imports = ["from __future__ import annotations"] # typing-modules = ["galdynamix._compat.typing"] [tool.ruff.lint.per-file-ignores] -"tests/**" = ["T20"] -"noxfile.py" = ["T20"] +"tests/**" = ["ANN", "D10", "INP001", "S101", "T20"] +"noxfile.py" = ["ERA001", "T20"] +"docs/conf.py" = [ + "A001", # Variable `copyright` is shadowing a Python builtin + "INP001", # implicit namespace package +] [tool.pylint] diff --git a/src/galdynamix/__init__.py b/src/galdynamix/__init__.py index 33b4515e..662f5629 100644 --- a/src/galdynamix/__init__.py +++ b/src/galdynamix/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""Copyright (c) 2023 galdynamix maintainers. All rights reserved.""" from __future__ import annotations __all__ = ["__version__"] @@ -7,4 +7,4 @@ from ._version import version as __version__ -config.update("jax_enable_x64", True) +config.update("jax_enable_x64", True) # noqa: FBT003 diff --git a/src/galdynamix/_version.pyi b/src/galdynamix/_version.pyi index 91744f98..5bb2b22f 100644 --- a/src/galdynamix/_version.pyi +++ b/src/galdynamix/_version.pyi @@ -1,4 +1,2 @@ -from __future__ import annotations - version: str version_tuple: tuple[int, int, int] | tuple[int, int, int, str, str] diff --git a/src/galdynamix/dynamics/__init__.py b/src/galdynamix/dynamics/__init__.py index aaa4c850..febab8f4 100644 --- a/src/galdynamix/dynamics/__init__.py +++ b/src/galdynamix/dynamics/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" # ruff: noqa: F403 from __future__ import annotations diff --git a/src/galdynamix/dynamics/_orbit.py b/src/galdynamix/dynamics/_orbit.py index 31a4fe15..9d63fe4d 100644 --- a/src/galdynamix/dynamics/_orbit.py +++ b/src/galdynamix/dynamics/_orbit.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations @@ -10,12 +10,14 @@ import jax.typing as jt from galdynamix.potential._potential.base import AbstractPotentialBase +from galdynamix.utils._jax import partial_jit class Orbit(eqx.Module): # type: ignore[misc] """Orbit. - TODO: + Todo: + ---- - Units stuff - GR stuff """ @@ -27,20 +29,36 @@ class Orbit(eqx.Module): # type: ignore[misc] """Position of the stream particles (x, y, z) [kpc/Myr].""" t: jt.Array - """Release time of the stream particles [Myr].""" + """Array of times [Myr].""" potential: AbstractPotentialBase """Potential in which the orbit was integrated.""" - def to_w(self) -> jt.Array: + @property + @partial_jit() + def qp(self) -> jt.Array: + """Return as a single Array[(N, Q + P),].""" + # Determine output shape + qd = self.q.shape[1] # dimensionality of q + shape = (self.q.shape[0], qd + self.p.shape[1]) + # Create output array (jax will fuse these ops) + out = xp.empty(shape) + out = out.at[:, :qd].set(self.q) + out = out.at[:, qd:].set(self.p) + return out # noqa: RET504 + + @property + @partial_jit() + def w(self) -> jt.Array: + """Return as a single Array[(N, Q + P + T),].""" + qp = self.qp + qpd = qp.shape[1] # dimensionality of qp + # Reshape t to (N, 1) if necessary t = self.t[:, None] if self.t.ndim == 1 else self.t - out = xp.empty( - ( - self.q.shape[0], - self.q.shape[1] + self.p.shape[1] + t.shape[1], - ) - ) - out = out.at[:, : self.q.shape[1]].set(self.q) - out = out.at[:, self.q.shape[1] : -1].set(self.p) - out = out.at[:, -1:].set(t) + # Determine output shape + shape = (qp.shape[0], qpd + t.shape[1]) + # Create output array (jax will fuse these ops) + out = xp.empty(shape) + out = out.at[:, :qpd].set(qp) + out = out.at[:, qpd:].set(t) return out # noqa: RET504 diff --git a/src/galdynamix/dynamics/mockstream/__init__.py b/src/galdynamix/dynamics/mockstream/__init__.py index cecd6698..4a6bee40 100644 --- a/src/galdynamix/dynamics/mockstream/__init__.py +++ b/src/galdynamix/dynamics/mockstream/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" # ruff: noqa: F403 from __future__ import annotations diff --git a/src/galdynamix/dynamics/mockstream/_core.py b/src/galdynamix/dynamics/mockstream/_core.py index fc9ff02c..b4f3f874 100644 --- a/src/galdynamix/dynamics/mockstream/_core.py +++ b/src/galdynamix/dynamics/mockstream/_core.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations @@ -6,13 +6,17 @@ import equinox as eqx +import jax.numpy as xp import jax.typing as jt +from galdynamix.utils._jax import partial_jit + class MockStream(eqx.Module): # type: ignore[misc] """Mock stream object. - TODO: + Todo: + ---- - units stuff - change this to be a collection of sub-objects: progenitor, leading arm, trailing arm, 3-body ejecta, etc. @@ -27,3 +31,36 @@ class MockStream(eqx.Module): # type: ignore[misc] release_time: jt.Array """Release time of the stream particles [Myr].""" + + @property + @partial_jit() + def qp(self) -> jt.Array: + """Return as a single Array[(N, Q + P),].""" + # Determine output shape + qd = self.q.shape[1] # dimensionality of q + shape = (self.q.shape[0], qd + self.p.shape[1]) + # Create output array (jax will fuse these ops) + out = xp.empty(shape) + out = out.at[:, :qd].set(self.q) + out = out.at[:, qd:].set(self.p) + return out # noqa: RET504 + + @property + @partial_jit() + def w(self) -> jt.Array: + """Return as a single Array[(N, Q + P + T),].""" + qp = self.qp + qpd = qp.shape[1] # dimensionality of qp + # Reshape t to (N, 1) if necessary + t = ( + self.release_time[:, None] + if self.release_time.ndim == 1 + else self.release_time + ) + # Determine output shape + shape = (qp.shape[0], qpd + t.shape[1]) + # Create output array (jax will fuse these ops) + out = xp.empty(shape) + out = out.at[:, :qpd].set(qp) + out = out.at[:, qpd:].set(t) + return out # noqa: RET504 diff --git a/src/galdynamix/dynamics/mockstream/_df/__init__.py b/src/galdynamix/dynamics/mockstream/_df/__init__.py index 82f7c369..48d3f41d 100644 --- a/src/galdynamix/dynamics/mockstream/_df/__init__.py +++ b/src/galdynamix/dynamics/mockstream/_df/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" # ruff: noqa: F403 from __future__ import annotations diff --git a/src/galdynamix/dynamics/mockstream/_df/base.py b/src/galdynamix/dynamics/mockstream/_df/base.py index 8b044522..8a3317dd 100644 --- a/src/galdynamix/dynamics/mockstream/_df/base.py +++ b/src/galdynamix/dynamics/mockstream/_df/base.py @@ -1,11 +1,12 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" +# ruff: noqa: F403 from __future__ import annotations __all__ = ["AbstractStreamDF"] import abc -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, TypeAlias import equinox as eqx import jax @@ -18,8 +19,8 @@ from galdynamix.utils import partial_jit if TYPE_CHECKING: - _wifT: TypeAlias = tuple[jt.Array, jt.Array, jt.Array, jt.Array] - _carryT: TypeAlias = tuple[int, jt.Array, jt.Array, jt.Array, jt.Array] + Wif: TypeAlias = tuple[jt.Array, jt.Array, jt.Array, jt.Array] + Carry: TypeAlias = tuple[int, jt.Array, jt.Array, jt.Array, jt.Array] class AbstractStreamDF(eqx.Module): # type: ignore[misc] @@ -62,14 +63,14 @@ def sample( mock_lead, mock_trail : MockStream Positions and velocities of the leading and trailing tails. """ - prog_ws = prog_orbit.to_w()[:, :-1] # -1 is time + prog_qps = prog_orbit.qp ts = prog_orbit.t - def scan_fn(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]: + def scan_fn(carry: Carry, t: jt.Numeric) -> tuple[Carry, Wif]: i = carry[0] output = self._sample( potential, - prog_ws[i], + prog_qps[i], prog_mass, t, i=i, diff --git a/src/galdynamix/dynamics/mockstream/_df/fardal.py b/src/galdynamix/dynamics/mockstream/_df/fardal.py index e0b06bdc..e8b1a4c5 100644 --- a/src/galdynamix/dynamics/mockstream/_df/fardal.py +++ b/src/galdynamix/dynamics/mockstream/_df/fardal.py @@ -1,19 +1,22 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations __all__ = ["FardalStreamDF"] +from typing import TYPE_CHECKING import jax import jax.numpy as xp import jax.typing as jt -from galdynamix.potential._potential.base import AbstractPotentialBase from galdynamix.utils import partial_jit from .base import AbstractStreamDF +if TYPE_CHECKING: + from galdynamix.potential._potential.base import AbstractPotentialBase + class FardalStreamDF(AbstractStreamDF): @partial_jit(static_argnums=(0,), static_argnames=("seed_num",)) @@ -51,7 +54,7 @@ def _sample( rel_v = omega_val * r_tidal # relative velocity # circlar_velocity - v_circ = rel_v ##xp.sqrt( r*dphi_dr ) + v_circ = rel_v L_vec = xp.cross(x, v) z_hat = L_vec / xp.linalg.norm(L_vec) @@ -61,7 +64,6 @@ def _sample( kr_bar = 2.0 kvphi_bar = 0.3 - ####################kvt_bar = 0.3 ## FROM GALA kz_bar = 0.0 kvz_bar = 0.0 @@ -70,7 +72,6 @@ def _sample( sigma_kvphi = 0.5 sigma_kz = 0.5 sigma_kvz = 0.5 - ##############sigma_kvt = 0.5 ##FROM GALA kr_samp = kr_bar + jax.random.normal(keya, shape=(1,)) * sigma_kr kvphi_samp = kr_samp * ( @@ -78,31 +79,26 @@ def _sample( ) kz_samp = kz_bar + jax.random.normal(keyc, shape=(1,)) * sigma_kz kvz_samp = kvz_bar + jax.random.normal(keyd, shape=(1,)) * sigma_kvz - ########kvt_samp = kvt_bar + jax.random.normal(keye,shape=(1,))*sigma_kvt - ## Trailing arm - x_trail = x + kr_samp * r_hat * (r_tidal) # nudge out - x_trail = x_trail + z_hat * kz_samp * ( - r_tidal / 1.0 - ) # r #nudge above/below orbital plane - v_trail = ( - v + (0.0 + kvphi_samp * v_circ * (1.0)) * phi_hat - ) # v + (0.0 + kvphi_samp*v_circ*(-r_tidal/r))*phi_hat #nudge velocity along tangential direction + # Trailing arm + x_trail = ( + x + (kr_samp * r_hat * (r_tidal)) + (z_hat * kz_samp * (r_tidal / 1.0)) + ) v_trail = ( - v_trail + (kvz_samp * v_circ * (1.0)) * z_hat - ) # v_trail + (kvz_samp*v_circ*(-r_tidal/r))*z_hat #nudge velocity along vertical direction - - ## Leading arm - x_lead = x + kr_samp * r_hat * (-r_tidal) # nudge in - x_lead = x_lead + z_hat * kz_samp * ( - -r_tidal / 1.0 - ) # r #nudge above/below orbital plane - v_lead = ( - v + (0.0 + kvphi_samp * v_circ * (-1.0)) * phi_hat - ) # v + (0.0 + kvphi_samp*v_circ*(r_tidal/r))*phi_hat #nudge velocity along tangential direction + v + + (0.0 + kvphi_samp * v_circ * (1.0)) * phi_hat + + (kvz_samp * v_circ * (1.0)) * z_hat + ) + + # Leading arm + x_lead = ( + x + (kr_samp * r_hat * (-r_tidal)) + (z_hat * kz_samp * (-r_tidal / 1.0)) + ) v_lead = ( - v_lead + (kvz_samp * v_circ * (-1.0)) * z_hat - ) # v_lead + (kvz_samp*v_circ*(r_tidal/r))*z_hat #nudge velocity against vertical direction + v + + (0.0 + kvphi_samp * v_circ * (-1.0)) * phi_hat + + (kvz_samp * v_circ * (-1.0)) * z_hat + ) return x_lead, x_trail, v_lead, v_trail @@ -113,7 +109,7 @@ def _sample( @partial_jit() def dphidr(potential: AbstractPotentialBase, x: jt.Array, t: jt.Numeric) -> jt.Array: - """Computes the derivative of the potential at a position x. + """Compute the derivative of the potential at a position x. Parameters ---------- @@ -135,7 +131,9 @@ def dphidr(potential: AbstractPotentialBase, x: jt.Array, t: jt.Numeric) -> jt.A @partial_jit() def d2phidr2(potential: AbstractPotentialBase, x: jt.Array, t: jt.Numeric) -> jt.Array: - """Computes the second derivative of the potential at a position x (in the simulation frame). + """Compute the second derivative of the potential. + + At a position x (in the simulation frame). Parameters ---------- @@ -162,23 +160,25 @@ def d2phidr2(potential: AbstractPotentialBase, x: jt.Array, t: jt.Numeric) -> jt @partial_jit() def orbital_angular_velocity(x: jt.Array, v: jt.Array, /) -> jt.Array: - """Computes the orbital angular velocity about the origin. + """Compute the orbital angular velocity about the origin. - Arguments + Arguments: --------- x: Array[(3,), Any] 3d position (x, y, z) in [length] v: Array[(3,), Any] 3d velocity (v_x, v_y, v_z) in [length/time] - Returns + Returns: ------- Array Angular velocity in [rad/time] - Examples + Examples: -------- - >>> orbital_angular_velocity(x=xp.array([8.0, 0.0, 0.0]), v=xp.array([8.0, 0.0, 0.0])) + >>> x = xp.array([8.0, 0.0, 0.0]) + >>> v = xp.array([8.0, 0.0, 0.0]) + >>> orbital_angular_velocity(x=x, v=v) """ r = xp.linalg.norm(x) return xp.cross(x, v) / r**2 @@ -186,23 +186,25 @@ def orbital_angular_velocity(x: jt.Array, v: jt.Array, /) -> jt.Array: @partial_jit() def orbital_angular_velocity_mag(x: jt.Array, v: jt.Array, /) -> jt.Array: - """Computes the magnitude of the angular momentum in the simulation frame. + """Compute the magnitude of the angular momentum in the simulation frame. - Arguments + Arguments: --------- x: Array[(3,), Any] 3d position (x, y, z) in [kpc] v: Array[(3,), Any] 3d velocity (v_x, v_y, v_z) in [kpc/Myr] - Returns + Returns: ------- Array Magnitude of angular momentum in [rad/Myr] - Examples + Examples: -------- - >>> orbital_angular_velocity_mag(x=xp.array([8.0, 0.0, 0.0]), v=xp.array([8.0, 0.0, 0.0])) + >>> x = xp.array([8.0, 0.0, 0.0]) + >>> v = xp.array([8.0, 0.0, 0.0]) + >>> orbital_angular_velocity_mag(x=x, v=v) """ return xp.linalg.norm(orbital_angular_velocity(x, v)) @@ -216,10 +218,12 @@ def tidal_radius( prog_mass: jt.Array, t: jt.Array, ) -> jt.Array: - """Computes the tidal radius of a cluster in the potential. + """Compute the tidal radius of a cluster in the potential. Parameters ---------- + potential: AbstractPotentialBase + The gravitational potential of the host. x: Array 3d position (x, y, z) in [kpc] v: Array @@ -236,10 +240,12 @@ def tidal_radius( Examples -------- - >>> tidal_radius(x=xp.array([8.0, 0.0, 0.0]), v=xp.array([8.0, 0.0, 0.0]), prog_mass=1e4) + >>> x=xp.array([8.0, 0.0, 0.0]) + >>> v=xp.array([8.0, 0.0, 0.0] + >>> tidal_radius(x=x, v=v, prog_mass=1e4) """ return ( - potential._G + potential._G # noqa: SLF001 * prog_mass / (orbital_angular_velocity_mag(x, v) ** 2 - d2phidr2(potential, x, t)) ) ** (1.0 / 3.0) @@ -253,7 +259,7 @@ def lagrange_points( prog_mass: jt.Array, t: jt.Array, ) -> tuple[jt.Array, jt.Array]: - """Computes the lagrange points of a cluster in a host potential. + """Compute the lagrange points of a cluster in a host potential. Parameters ---------- diff --git a/src/galdynamix/dynamics/mockstream/_mockstream_generator.py b/src/galdynamix/dynamics/mockstream/_mockstream_generator.py index a6307baa..73e3a992 100644 --- a/src/galdynamix/dynamics/mockstream/_mockstream_generator.py +++ b/src/galdynamix/dynamics/mockstream/_mockstream_generator.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations @@ -17,8 +17,7 @@ from ._df import AbstractStreamDF if TYPE_CHECKING: - _wifT: TypeAlias = tuple[jt.Array, jt.Array, jt.Array, jt.Array] - _carryT: TypeAlias = tuple[int, jt.Array, jt.Array, jt.Array, jt.Array] + Carry: TypeAlias = tuple[int, jt.Array, jt.Array] from galdynamix.dynamics._orbit import Orbit @@ -26,11 +25,6 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc] df: AbstractStreamDF potential: AbstractPotentialBase - # progenitor_potential: AbstractPotentialBase | None = None - - # @property - # def self_gravity(self) -> bool: - # return self.progenitor_potential is not None # ========================================================================== @@ -46,37 +40,27 @@ def _run_scan( prog_o = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts) # Generate stream initial conditions along the integrated progenitor orbit - mock_lead, mock_trail = self.df.sample( + mock0_lead, mock0_trail = self.df.sample( self.potential, prog_o, prog_mass, seed_num=seed_num ) - x_lead, v_lead = mock_lead.q, mock_lead.p - x_trail, v_trail = mock_trail.q, mock_trail.p - - def scan_fn( - carry: _carryT, particle_idx: int - ) -> tuple[_carryT, tuple[jt.Array, jt.Array]]: - i, x_lead_i, x_trail_i, v_lead_i, v_trail_i = carry - w0_lead_i = xp.hstack([x_lead_i, v_lead_i]) - w0_trail_i = xp.hstack([x_trail_i, v_trail_i]) - w0_lead_trail = xp.vstack([w0_lead_i, w0_trail_i]) - - minval, maxval = ts[i], ts[-1] - integ_ics = lambda ics: self.potential.integrate_orbit( # noqa: E731 - ics, minval, maxval, None - ).to_w()[0, :-1] + qp0_lead = mock0_lead.qp + qp0_trail = mock0_trail.qp + + def scan_fn(carry: Carry, idx: int) -> tuple[Carry, tuple[jt.Array, jt.Array]]: + i, qp0_lead_i, qp0_trail_i = carry + qp0_lead_trail = xp.vstack([qp0_lead_i, qp0_trail_i]) + t_i, t_f = ts[i], ts[-1] + + def integ_ics(ics: jt.Array) -> jt.Array: + return self.potential.integrate_orbit(ics, t_i, t_f, None).qp[0] + # vmap over leading and trailing arm - w_lead, w_trail = jax.vmap(integ_ics, in_axes=(0,))(w0_lead_trail) - carry_out = ( - i + 1, - x_lead[i + 1, :], - x_trail[i + 1, :], - v_lead[i + 1, :], - v_trail[i + 1, :], - ) - return carry_out, (w_lead, w_trail) - - carry_init = (0, x_lead[0, :], x_trail[0, :], v_lead[0, :], v_trail[0, :]) - particle_ids = xp.arange(len(x_lead)) + qp_lead, qp_trail = jax.vmap(integ_ics, in_axes=(0,))(qp0_lead_trail) + carry_out = (i + 1, qp0_lead[i + 1, :], qp0_trail[i + 1, :]) + return carry_out, (qp_lead, qp_trail) + + carry_init = (0, qp0_lead[0, :], qp0_trail[0, :]) + particle_ids = xp.arange(len(qp0_lead)) lead_arm, trail_arm = jax.lax.scan(scan_fn, carry_init, particle_ids)[1] return (lead_arm, trail_arm), prog_o @@ -84,8 +68,9 @@ def scan_fn( def _run_vmap( self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int ) -> tuple[tuple[jt.Array, jt.Array], Orbit]: - """ - Generate stellar stream by vmapping over the release model/integration. Better for GPU usage. + """Generate stellar stream by vmapping over the release model/integration. + + Better for GPU usage. """ # Integrate the progenitor orbit prog_o = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts) @@ -94,33 +79,24 @@ def _run_vmap( mock_lead, mock_trail = self.df.sample( self.potential, prog_o, prog_mass, seed_num=seed_num ) - x_lead, v_lead = mock_lead.q, mock_lead.p - x_trail, v_trail = mock_trail.q, mock_trail.p + qp0_lead = mock_lead.qp + qp0_trail = mock_trail.qp + t_f = ts[-1] + 0.01 # TODO: make this a separated method @jax.jit # type: ignore[misc] def single_particle_integrate( - i: int, - x_lead_i: jt.Array, - x_trail_i: jt.Array, - v_lead_i: jt.Array, - v_trail_i: jt.Array, + i: int, qp0_lead_i: jt.Array, qp0_trail_i: jt.Array ) -> tuple[jt.Array, jt.Array]: - w0_lead_i = xp.hstack([x_lead_i, v_lead_i]) - w0_trail_i = xp.hstack([x_trail_i, v_trail_i]) t_i = ts[i] - t_f = ts[-1] + 0.01 - - w_lead = self.integrate_orbit(w0_lead_i, t_i, t_f, None)[0] - w_trail = self.integrate_orbit(w0_trail_i, t_i, t_f, None)[0] - - return w_lead, w_trail - - particle_ids = xp.arange(len(x_lead)) + qp_lead = self.integrate_orbit(qp0_lead_i, t_i, t_f, None).qp[0] + qp_trail = self.integrate_orbit(qp0_trail_i, t_i, t_f, None).qp[0] + return qp_lead, qp_trail - integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0, 0, 0)) - w_lead, w_trail = integrator(particle_ids, x_lead, x_trail, v_lead, v_trail) - return (w_lead, w_trail), prog_o + particle_ids = xp.arange(len(qp0_lead)) + integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0)) + qp_lead, qp_trail = integrator(particle_ids, qp0_lead, qp0_trail) + return (qp_lead, qp_trail), prog_o @partial_jit(static_argnames=("seed_num", "vmapped")) def run( @@ -132,6 +108,7 @@ def run( seed_num: int, vmapped: bool = False, ) -> tuple[jt.Array, jt.Array]: + # TODO: figure out better return type: MockStream? if vmapped: return self._run_vmap(ts, prog_w0, prog_mass, seed_num=seed_num) return self._run_scan(ts, prog_w0, prog_mass, seed_num=seed_num) diff --git a/src/galdynamix/integrate/__init__.py b/src/galdynamix/integrate/__init__.py index 6878ebbd..bddb7263 100644 --- a/src/galdynamix/integrate/__init__.py +++ b/src/galdynamix/integrate/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" # ruff: noqa: F403 from __future__ import annotations diff --git a/src/galdynamix/integrate/_base.py b/src/galdynamix/integrate/_base.py index c6959b05..ddd2dbd5 100644 --- a/src/galdynamix/integrate/_base.py +++ b/src/galdynamix/integrate/_base.py @@ -10,7 +10,7 @@ class FCallable(Protocol): - def __call__(self, t: jt.Array, xv: jt.Array, args: Any) -> jt.Array: + def __call__(self, t: jt.Array, xv: jt.Array, args: tuple[Any, ...]) -> jt.Array: ... diff --git a/src/galdynamix/potential/__init__.py b/src/galdynamix/potential/__init__.py index 2ecdc19e..0fc3f675 100644 --- a/src/galdynamix/potential/__init__.py +++ b/src/galdynamix/potential/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" # ruff: noqa: F403 from __future__ import annotations diff --git a/src/galdynamix/potential/_potential/__init__.py b/src/galdynamix/potential/_potential/__init__.py index 97813388..0425bfec 100644 --- a/src/galdynamix/potential/_potential/__init__.py +++ b/src/galdynamix/potential/_potential/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" # ruff: noqa: F403 from __future__ import annotations diff --git a/src/galdynamix/potential/_potential/base.py b/src/galdynamix/potential/_potential/base.py index 880cf40d..33c31832 100644 --- a/src/galdynamix/potential/_potential/base.py +++ b/src/galdynamix/potential/_potential/base.py @@ -4,21 +4,23 @@ import abc from dataclasses import KW_ONLY, fields -from typing import Any +from typing import TYPE_CHECKING, Any import astropy.units as u import equinox as eqx import jax import jax.numpy as xp import jax.typing as jt -from astropy.constants import G as apy_G +from astropy.constants import G as _G -from galdynamix.integrate._base import AbstractIntegrator from galdynamix.integrate._builtin import DiffraxIntegrator from galdynamix.potential._potential.param.field import ParameterField from galdynamix.units import UnitSystem, dimensionless from galdynamix.utils import partial_jit +if TYPE_CHECKING: + from galdynamix.integrate._base import AbstractIntegrator + class AbstractPotentialBase(eqx.Module): # type: ignore[misc] """Potential Class.""" @@ -37,7 +39,7 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: # Parsing def _init_units(self) -> None: - G = 1 if self.units == dimensionless else apy_G.decompose(self.units).value + G = 1 if self.units == dimensionless else _G.decompose(self.units).value object.__setattr__(self, "_G", G) # Handle unit conversion for all ParameterField diff --git a/src/galdynamix/potential/_potential/builtin.py b/src/galdynamix/potential/_potential/builtin.py index 034e8f9c..7a0d8383 100644 --- a/src/galdynamix/potential/_potential/builtin.py +++ b/src/galdynamix/potential/_potential/builtin.py @@ -1,9 +1,7 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations -from typing import Any - __all__ = [ "MiyamotoNagaiDisk", "BarPotential", @@ -13,6 +11,7 @@ ] from dataclasses import KW_ONLY +from typing import Any import equinox as eqx import jax @@ -48,8 +47,8 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: class BarPotential(AbstractPotential): - """ - Rotating bar potentil, with hard-coded rotation. + """Rotating bar potentil, with hard-coded rotation. + Eq 8a in https://articles.adsabs.harvard.edu/pdf/1992ApJ...397...44L Rz according to https://en.wikipedia.org/wiki/Rotation_matrix """ @@ -64,15 +63,14 @@ class BarPotential(AbstractPotential): def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: ## First take the simulation frame coordinates and rotate them by Omega*t ang = -self.Omega(t) * t - Rot_mat = xp.array( + rotation_matrix = xp.array( [ [xp.cos(ang), -xp.sin(ang), 0], [xp.sin(ang), xp.cos(ang), 0.0], [0.0, 0.0, 1.0], - ] + ], ) - # Rot_inv = xp.linalg.inv(Rot_mat) - q_corot = xp.matmul(Rot_mat, q) + q_corot = xp.matmul(rotation_matrix, q) a = self.a(t) b = self.b(t) @@ -90,7 +88,7 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: # potential in a corotating frame return (self._G * self.m(t) / (2.0 * a)) * xp.log( - (q_corot[0] - a + T_minus) / (q_corot[0] + a + T_plus) + (q_corot[0] - a + T_minus) / (q_corot[0] + a + T_plus), ) @@ -125,7 +123,7 @@ class NFWPotential(AbstractPotential): def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: v_h2 = -self._G * self.m(t) / self.r_s(t) m = xp.sqrt( - q[0] ** 2 + q[1] ** 2 + q[2] ** 2 + self.softening_length + q[0] ** 2 + q[1] ** 2 + q[2] ** 2 + self.softening_length, ) / self.r_s(t) return v_h2 * xp.log(1.0 + m) / m @@ -134,7 +132,7 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: @jax.jit # type: ignore[misc] -def get_splines(x_eval: jt.Array, x: jt.Array, y: jt.Array) -> Any: +def get_splines(x_eval: jt.Array, x: jt.Array, y: jt.Array) -> Any: # noqa: ANN401 return InterpolatedUnivariateSpline(x, y, k=3)(x_eval) @@ -142,8 +140,8 @@ def get_splines(x_eval: jt.Array, x: jt.Array, y: jt.Array) -> Any: def single_subhalo_potential( params: dict[str, jt.Array], q: jt.Array, /, t: jt.Array ) -> jt.Array: - """ - Potential for a single subhalo + """Potential for a single subhalo. + TODO: custom unit specification/subhalo potential specficiation. Currently supports units kpc, Myr, Msun, rad. """ @@ -152,8 +150,8 @@ def single_subhalo_potential( class SubHaloPopulation(AbstractPotential): - """ - m has length n_subhalo + """m has length n_subhalo. + a has length n_subhalo tq_subhalo_arr has shape t_orbit x n_subhalo x 3 t_orbit is the array of times the subhalos are integrated over @@ -166,26 +164,22 @@ class SubHaloPopulation(AbstractPotential): @partial_jit() def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array: - x_at_t_eval = get_splines( - t, self.t_orbit, self.tq_subhalo_arr[:, :, 0] - ) # expect n_subhalo x-positions - y_at_t_eval = get_splines( - t, self.t_orbit, self.tq_subhalo_arr[:, :, 1] - ) # expect n_subhalo y-positions - z_at_t_eval = get_splines( - t, self.t_orbit, self.tq_subhalo_arr[:, :, 2] - ) # expect n_subhalo z-positions - - subhalo_locations = xp.vstack( - [x_at_t_eval, y_at_t_eval, z_at_t_eval] - ).T # n_subhalo x 3: the position of all subhalos at time t + # expect n_subhalo x-positions + x_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 0]) + # expect n_subhalo y-positions + y_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 1]) + # expect n_subhalo z-positions + z_at_t_eval = get_splines(t, self.t_orbit, self.tq_subhalo_arr[:, :, 2]) + + # n_subhalo x 3: the position of all subhalos at time t + subhalo_locations = xp.vstack([x_at_t_eval, y_at_t_eval, z_at_t_eval]).T delta_position = q - subhalo_locations # n_subhalo x 3 - # sum over potential due to all subhalos in the field by vmapping over m, a, and delta_position - ##dct = {'m': self.m, 'a': self.a,} + # sum over potential due to all subhalos in the field by vmapping over + # m, a, and delta_position return xp.sum( jax.vmap( single_subhalo_potential, in_axes=(({"m": 0, "a": 0}, 0, None)), - )({"m": self.m(t), "a": self.a(t)}, delta_position, t) + )({"m": self.m(t), "a": self.a(t)}, delta_position, t), ) diff --git a/src/galdynamix/potential/_potential/param/core.py b/src/galdynamix/potential/_potential/param/core.py index e85ab6e1..7e2ecfbf 100644 --- a/src/galdynamix/potential/_potential/param/core.py +++ b/src/galdynamix/potential/_potential/param/core.py @@ -49,7 +49,7 @@ class AbstractParameter(eqx.Module): # type: ignore[misc] """ _: KW_ONLY - unit: u.Unit = eqx.field(static=True) # TODO? move this to an annotation? + unit: u.Unit = eqx.field(static=True) # TODO: move this to an annotation? @abc.abstractmethod def __call__(self, t: jt.Array) -> jt.Array: diff --git a/src/galdynamix/potential/_potential/param/field.py b/src/galdynamix/potential/_potential/param/field.py index 0a533142..43ed1d0b 100644 --- a/src/galdynamix/potential/_potential/param/field.py +++ b/src/galdynamix/potential/_potential/param/field.py @@ -30,7 +30,7 @@ class ParameterField: name: str = field(init=False) _: KW_ONLY - physical_type: u.PhysicalType # TODO add a converter_argument + physical_type: u.PhysicalType # TODO: add a converter_argument equivalencies: u.Equivalency | tuple[u.Equivalency, ...] | None = None def __post_init__(self) -> None: @@ -41,7 +41,10 @@ def __post_init__(self) -> None: self, "physical_type", u.get_physical_type(self.physical_type) ) elif not isinstance(self.physical_type, u.PhysicalType): - msg = f"Expected physical_type to be a PhysicalType, got {self.physical_type!r}" + msg = ( + "Expected physical_type to be a PhysicalType, " + f"got {self.physical_type!r}" + ) raise TypeError(msg) # =========================================== @@ -81,16 +84,16 @@ def __get__( def __set__( self, potential: AbstractPotential, - value: AbstractParameter | ParameterCallable | Any, + value: AbstractParameter | ParameterCallable | Any, # noqa: ANN401 ) -> None: # Convert if isinstance(value, AbstractParameter): - # TODO! use the physical_type information to check the parameters. - # TODO! use the units on the `potential` to convert the parameter value. + # TODO: use the physical_type information to check the parameters. + # TODO: use the units on the `potential` to convert the parameter value. pass elif callable(value): - # TODO! use the physical_type information to check the parameters. - # TODO! use the units on the `potential` to convert the parameter value. + # TODO: use the physical_type information to check the parameters. + # TODO: use the units on the `potential` to convert the parameter value. value = UserParameter(func=value) else: unit = potential.units[self.physical_type] diff --git a/src/galdynamix/units.py b/src/galdynamix/units.py index 07798dcc..dedcb3e1 100644 --- a/src/galdynamix/units.py +++ b/src/galdynamix/units.py @@ -29,7 +29,7 @@ from __future__ import annotations -from typing import Any, ClassVar +from typing import ClassVar __all__ = [ "UnitSystem", @@ -56,14 +56,14 @@ class UnitSystem: u.get_physical_type("angle"), ] - def __init__(self, units: UnitSystem | u.UnitBase, *args: u.UnitBase): + def __init__(self, units: UnitSystem | u.UnitBase, *args: u.UnitBase) -> None: if isinstance(units, UnitSystem): if len(args) > 0: - msg = "If passing in a UnitSystem instance, you cannot pass in additional units." + msg = "If passing in a UnitSystem, cannot pass in additional units." raise ValueError(msg) - self._registry = units._registry.copy() - self._core_units = units._core_units + self._registry = units._registry.copy() # noqa: SLF001 + self._core_units = units._core_units # noqa: SLF001 return units = (units, *args) @@ -98,14 +98,18 @@ def __iter__(self) -> u.UnitBase: def __repr__(self) -> str: return f"UnitSystem({', '.join(str(uu) for uu in self._core_units)})" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, UnitSystem): + return NotImplemented return bool(self._registry == other._registry) - def __ne__(self, other: Any) -> bool: + def __ne__(self, other: object) -> bool: return not self.__eq__(other) class DimensionlessUnitSystem(UnitSystem): + """A unit system with only dimensionless units.""" + _required_physical_types: ClassVar[list[u.PhysicalType]] = [] def __init__(self) -> None: diff --git a/src/galdynamix/utils/__init__.py b/src/galdynamix/utils/__init__.py index 267d6d23..e76124d4 100644 --- a/src/galdynamix/utils/__init__.py +++ b/src/galdynamix/utils/__init__.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations diff --git a/src/galdynamix/utils/_collections.py b/src/galdynamix/utils/_collections.py index 05bbdb1e..a6ffc190 100644 --- a/src/galdynamix/utils/_collections.py +++ b/src/galdynamix/utils/_collections.py @@ -1,4 +1,4 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations @@ -47,7 +47,7 @@ def __repr__(self) -> str: # === PyTree === def tree_flatten(self) -> tuple[tuple[V, ...], tuple[str, ...]]: - """Specifies a flattening recipe. + """Flatten to a dict. Returns ------- @@ -61,16 +61,19 @@ def tree_flatten(self) -> tuple[tuple[V, ...], tuple[str, ...]]: @classmethod def tree_unflatten( - cls: type[Self], aux_data: tuple[str, ...], children: tuple[V, ...] + cls: type[Self], + aux_data: tuple[str, ...], + children: tuple[V, ...], ) -> Self[str, V]: # type: ignore[misc] - """Specifies an unflattening recipe. + """Unflatten. Params: aux_data: the opaque data that was specified during flattening of the current treedef. children: the unflattened children - Returns: + Returns + ------- a re-constructed object of the registered type, using the specified children and auxiliary data. """ diff --git a/src/galdynamix/utils/_jax.py b/src/galdynamix/utils/_jax.py index 91546ba1..14b4376a 100644 --- a/src/galdynamix/utils/_jax.py +++ b/src/galdynamix/utils/_jax.py @@ -1,20 +1,17 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations __all__ = ["partial_jit"] -from collections.abc import Callable, Iterable, Sequence from functools import partial -from typing import ( - Any, - NotRequired, - TypedDict, - TypeVar, -) +from typing import TYPE_CHECKING, Any, NotRequired, TypedDict, TypeVar import jax -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Unpack + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Sequence P = ParamSpec("P") R = TypeVar("R") @@ -34,6 +31,6 @@ class JITKwargs(TypedDict): def partial_jit( - **kwargs: Any, + **kwargs: Unpack[JITKwargs], ) -> Callable[[Callable[P, R]], Callable[P, R]]: return partial(jax.jit, **kwargs) diff --git a/src/galdynamix/utils/dataclasses.py b/src/galdynamix/utils/dataclasses.py index 84a6c32c..acff9933 100644 --- a/src/galdynamix/utils/dataclasses.py +++ b/src/galdynamix/utils/dataclasses.py @@ -1,16 +1,18 @@ -"""galdynamix: Galactic Dynamix in Jax""" +"""galdynamix: Galactic Dynamix in Jax.""" from __future__ import annotations __all__ = ["field"] import dataclasses -from collections.abc import Callable, Mapping -from typing import Any, Generic, NotRequired, TypedDict, TypeVar +from typing import TYPE_CHECKING, Any, Generic, NotRequired, TypedDict, TypeVar import astropy.units as u from typing_extensions import ParamSpec, Unpack +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + P = ParamSpec("P") R = TypeVar("R")