Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move PSP to coordinates module #125

Merged
merged 2 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ Now that we have a potential model, if we want to compute an orbit, we need to
specify a set of initial conditions to initialize the numerical orbit
integration. In :mod:`galax`, initial conditions and other positions in
phase-space (locations in position and velocity space) are defined using the
:class:`~galax.dynamics.PhaseSpacePosition` class. This class allows a number of
:class:`~galax.coordinates.PhaseSpacePosition` class. This class allows a number of
possible inputs, but one of the most common inputs are Cartesian position and
velocity vectors. As an example orbit, we will use a position and velocity that
is close to the Sun's Galactocentric position and velocity::

>>> import galax.dynamics as gd
>>> psp = gd.PhaseSpacePosition(q=[-8.1, 0, 0.02] * u.kpc,
>>> import galax.coordinates as gc
>>> psp = gc.PhaseSpacePosition(q=[-8.1, 0, 0.02] * u.kpc,
... p=[13, 245, 8.] * u.km/u.s)

By convention, I typically use the variable ``w`` to represent phase-space
Expand Down Expand Up @@ -124,7 +124,7 @@ on any Potential object through the
By default, this method uses Leapfrog integration , which is a fast, symplectic
integration scheme. The returned object is an instance of the
:class:`~galax.dynamics.Orbit` class, which is similar to the
:class:`~galax.dynamics.PhaseSpacePosition` but represents a collection of
:class:`~galax.coordinates.PhaseSpacePosition` but represents a collection of
phase-space positions at times::

>>> orbit
Expand All @@ -144,11 +144,12 @@ performing common tasks, like plotting an orbit::
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import galax.coordinates as gc
import galax.dynamics as gd
import galax.potential as gp

mw = gp.MilkyWayPotential()
psp = gd.PhaseSpacePosition(pos=[-8.1, 0, 0.02] * u.kpc,
psp = gc.PhaseSpacePosition(pos=[-8.1, 0, 0.02] * u.kpc,
vel=[13, 245, 8.] * u.km/u.s)
orbit = mw.integrate_orbit(psp.w(), dt=1*u.Myr, t1=0, t2=2*u.Gyr)

Expand Down
11 changes: 11 additions & 0 deletions src/galax/coordinates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
""":mod:`galax.coordinates`."""

from . import _base, _core, _utils
from ._base import *
from ._core import *
from ._utils import *

__all__: list[str] = []
__all__ += _base.__all__
__all__ += _core.__all__
__all__ += _utils.__all__
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from galax.units import UnitSystem
from galax.utils._shape import atleast_batched

from .utils import getitem_time_index
from ._utils import getitem_time_index

if TYPE_CHECKING:
from typing import Self
Expand Down Expand Up @@ -215,7 +215,7 @@ def angular_momentum(self) -> BatchVec3:

>>> import numpy as np
>>> import astropy.units as u
>>> from galax.dynamics import PhaseSpacePosition
>>> from galax.coordinates import PhaseSpacePosition

We can compute the angular momentum of a single object

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from galax.utils._shape import batched_shape, expand_batch_dims
from galax.utils.dataclasses import converter_float_array

from .base import AbstractPhaseSpacePosition
from ._base import AbstractPhaseSpacePosition

if TYPE_CHECKING:
from galax.potential._potential.base import AbstractPotentialBase
Expand Down
File renamed without changes.
8 changes: 2 additions & 6 deletions src/galax/dynamics/_dynamics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
"""galax: Galactic Dynamix in Jax."""

from . import base, core, mockstream, orbit
from .base import *
from .core import *
from . import mockstream, orbit
from .mockstream import *
from .orbit import *

__all__: list[str] = ["mockstream"]
__all__ += base.__all__
__all__ += core.__all__
__all__ += orbit.__all__
__all__ += mockstream.__all__


# Cleanup
del base, core, orbit
del orbit
4 changes: 2 additions & 2 deletions src/galax/dynamics/_dynamics/mockstream/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import jax
import jax.numpy as jnp

from galax.dynamics._dynamics.base import AbstractPhaseSpacePosition
from galax.dynamics._dynamics.utils import getitem_vectime_index
from galax.coordinates import AbstractPhaseSpacePosition
from galax.coordinates._utils import getitem_vectime_index
from galax.typing import BatchFloatScalar, BroadBatchVec3, VecTime
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array
Expand Down
6 changes: 3 additions & 3 deletions src/galax/dynamics/_dynamics/orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
import jax.numpy as jnp
from astropy.units import Quantity

from galax.coordinates import AbstractPhaseSpacePosition
from galax.coordinates._utils import getitem_vectime_index
from galax.integrate import DiffraxIntegrator, Integrator
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import BatchFloatScalar, BatchVec6, BroadBatchVec3, VecTime
from galax.utils._shape import batched_shape
from galax.utils.dataclasses import converter_float_array

from .base import AbstractPhaseSpacePosition
from .utils import getitem_vectime_index

if TYPE_CHECKING:
from typing import Self

Expand Down Expand Up @@ -118,6 +117,7 @@ def energy(
##############################################################################


# TODO: enable setting the default integrator
_default_integrator: Integrator = DiffraxIntegrator()


Expand Down
Empty file.
9 changes: 9 additions & 0 deletions tests/smoke/coordinates/test_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Testing :mod:`galax.dynamics` module."""

import galax.coordinates as gc
from galax.coordinates import _base, _core, _utils


def test_all() -> None:
"""Test the `galax.potential` API."""
assert set(gc.__all__) == {*_base.__all__, *_core.__all__, *_utils.__all__}
10 changes: 2 additions & 8 deletions tests/smoke/dynamics/test_package.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
"""Testing :mod:`galax.dynamics` module."""

import galax.dynamics as gd
from galax.dynamics._dynamics import base, core, mockstream, orbit
from galax.dynamics._dynamics import mockstream, orbit


def test_all() -> None:
"""Test the `galax.potential` API."""
assert set(gd.__all__) == {
"mockstream",
*base.__all__,
*core.__all__,
*orbit.__all__,
*mockstream.__all__,
}
assert set(gd.__all__) == {"mockstream", *orbit.__all__, *mockstream.__all__}
2 changes: 1 addition & 1 deletion tests/unit/dynamics/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from jax import random

from galax.dynamics import PhaseSpacePosition
from galax.coordinates import PhaseSpacePosition
from galax.units import galactic

Shape: TypeAlias = tuple[int, ...]
Expand Down
Loading