Skip to content

Commit

Permalink
refactor: dynamics library (#116)
Browse files Browse the repository at this point in the history
* refactor: dynamics library 
* refactor: cleanup integrate API in integrate

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Jan 31, 2024
1 parent 872dd03 commit 936b71c
Show file tree
Hide file tree
Showing 19 changed files with 66 additions and 45 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@
]
"noxfile.py" = ["ERA001", "T20"]
"__init__.py" = ["F403"]
"__init__.pyi" = ["F401"]
"__init__.pyi" = ["F401", "F403"]


[tool.ruff.lint.isort]
Expand Down
15 changes: 4 additions & 11 deletions src/galax/dynamics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
"""galax: Galactic Dynamix in Jax."""
""":mod:`galax.dynamics`."""

from . import _base, _core, _orbit, mockstream
from ._base import *
from ._core import *
from ._orbit import *
from .mockstream import *
from . import _dynamics
from ._dynamics import *

__all__: list[str] = []
__all__ += _base.__all__
__all__ += _core.__all__
__all__ += _orbit.__all__
__all__ += mockstream.__all__
__all__ = _dynamics.__all__
17 changes: 17 additions & 0 deletions src/galax/dynamics/_dynamics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""galax: Galactic Dynamix in Jax."""

from . import base, core, mockstream, orbit
from .base import *
from .core import *
from .mockstream import *
from .orbit import *

__all__: list[str] = ["mockstream"]
__all__ += base.__all__
__all__ += core.__all__
__all__ += orbit.__all__
__all__ += mockstream.__all__


# Cleanup
del base, core, orbit
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from galax.units import UnitSystem
from galax.utils._shape import atleast_batched

from ._utils import getitem_time_index
from .utils import getitem_time_index

if TYPE_CHECKING:
from typing import Self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from galax.utils._shape import batched_shape, expand_batch_dims
from galax.utils.dataclasses import converter_float_array

from ._base import AbstractPhaseSpacePosition
from .base import AbstractPhaseSpacePosition

if TYPE_CHECKING:
from galax.potential._potential.base import AbstractPotentialBase
Expand Down
14 changes: 14 additions & 0 deletions src/galax/dynamics/_dynamics/mockstream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""galax: Galactic Dynamix in Jax."""

from . import core, df, mockstream_generator
from .core import *
from .df import *
from .mockstream_generator import *

__all__: list[str] = []
__all__ += df.__all__
__all__ += core.__all__
__all__ += mockstream_generator.__all__

# Cleanup
del core, df, mockstream_generator
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import jax
import jax.numpy as jnp

from galax.dynamics._base import AbstractPhaseSpacePosition
from galax.dynamics._utils import getitem_time_index
from galax.dynamics._dynamics.base import AbstractPhaseSpacePosition
from galax.dynamics._dynamics.utils import getitem_time_index
from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from .base import *
from .fardal import *

__all__ = []
__all__: list[str] = []
__all__ += base.__all__
__all__ += fardal.__all__

# Cleanup
del base, fardal
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""galax: Galactic Dynamix in Jax."""


__all__ = ["AbstractStreamDF"]

import abc
Expand All @@ -12,8 +11,8 @@
import jax.experimental.array_api as xp
from jax.numpy import copy

from galax.dynamics._orbit import Orbit
from galax.dynamics.mockstream._core import MockStream
from galax.dynamics._dynamics.mockstream.core import MockStream
from galax.dynamics._dynamics.orbit import Orbit
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchVec3, FloatScalar, IntLike, Vec3, Vec6

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from jax.lib.xla_bridge import get_backend
from jaxtyping import Array, Shaped

from galax.dynamics._orbit import Orbit
from galax.dynamics._dynamics.orbit import Orbit
from galax.integrate._base import Integrator
from galax.integrate._builtin import DiffraxIntegrator
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchVec6, FloatScalar, IntScalar, Vec6, VecN, VecTime

from ._core import MockStream
from ._df import AbstractStreamDF
from .core import MockStream
from .df import AbstractStreamDF

Carry: TypeAlias = tuple[IntScalar, VecN, VecN]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array

from ._base import AbstractPhaseSpacePosition
from .base import AbstractPhaseSpacePosition


@final
Expand Down
File renamed without changes.
11 changes: 0 additions & 11 deletions src/galax/dynamics/mockstream/__init__.py

This file was deleted.

3 changes: 3 additions & 0 deletions src/galax/integrate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@
__all__: list[str] = []
__all__ += _base.__all__
__all__ += _builtin.__all__

# Cleanup
del _base, _builtin
2 changes: 1 addition & 1 deletion src/galax/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def integrate_orbit(
)
"""
# TODO: ꜛ get NORMALIZE_WHITESPACE to work correctly so Orbit is 1 line
from galax.dynamics._orbit import Orbit
from galax.dynamics._dynamics.orbit import Orbit

integrator_ = default_integrator if integrator is None else replace(integrator)

Expand Down
5 changes: 2 additions & 3 deletions tests/smoke/dynamics/mockstream/test_package.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Testing :mod:`galax.dynamics.mockstream` module."""

from galax.dynamics import mockstream
from galax.dynamics._dynamics.mockstream import core, df, mockstream_generator


def test_all() -> None:
"""Test the `galax.dynamics.mockstream` API."""
assert set(mockstream.__all__) == set(
mockstream._df.__all__
+ mockstream._core.__all__
+ mockstream._mockstream_generator.__all__
df.__all__ + core.__all__ + mockstream_generator.__all__
)
11 changes: 8 additions & 3 deletions tests/smoke/dynamics/test_package.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
"""Testing :mod:`galax.dynamics` module."""

import galax.dynamics as gd
from galax.dynamics._dynamics import base, core, mockstream, orbit


def test_all() -> None:
"""Test the `galax.potential` API."""
assert set(gd.__all__) == set(
gd._base.__all__ + gd._core.__all__ + gd._orbit.__all__ + gd.mockstream.__all__
)
assert set(gd.__all__) == {
"mockstream",
*base.__all__,
*core.__all__,
*orbit.__all__,
*mockstream.__all__,
}
5 changes: 2 additions & 3 deletions tests/smoke/integrate/test_package.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Test the :mod:`galax.integrate` module."""

from galax import integrate
from galax.integrate import _base, _builtin


def test_all():
"""Test the API."""
assert set(integrate.__all__) == set(
integrate._base.__all__ + integrate._builtin.__all__
)
assert set(integrate.__all__) == set(_base.__all__ + _builtin.__all__)

0 comments on commit 936b71c

Please sign in to comment.