Skip to content

Commit

Permalink
typing hints (#21)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 9, 2023
1 parent b514329 commit 50d3eac
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/galdynamix/dynamics/mockstream/_df/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from galdynamix.dynamics._orbit import Orbit
from galdynamix.dynamics.mockstream._core import MockStream
from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.typing import BatchVec3, FloatScalar, IntegerLike, Vec3, Vec6
from galdynamix.typing import BatchVec3, FloatScalar, IntLike, Vec3, Vec6
from galdynamix.utils import partial_jit

Wif: TypeAlias = tuple[Vec3, Vec3, Vec3, Vec3]
Carry: TypeAlias = tuple[IntegerLike, Vec3, Vec3, Vec3, Vec3]
Carry: TypeAlias = tuple[IntLike, Vec3, Vec3, Vec3, Vec3]


class AbstractStreamDF(eqx.Module): # type: ignore[misc]
Expand Down Expand Up @@ -97,7 +97,7 @@ def _sample(
prog_mass: FloatScalar,
t: FloatScalar,
*,
i: IntegerLike,
i: IntLike,
seed_num: int,
) -> tuple[BatchVec3, BatchVec3, BatchVec3, BatchVec3]:
"""Generate stream particle initial conditions.
Expand Down
4 changes: 2 additions & 2 deletions src/galdynamix/dynamics/mockstream/_df/fardal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as xp

from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.typing import FloatScalar, IntegerLike, Vec3, Vec6
from galdynamix.typing import FloatScalar, IntLike, Vec3, Vec6
from galdynamix.utils import partial_jit

from .base import AbstractStreamDF
Expand All @@ -23,7 +23,7 @@ def _sample(
prog_mass: FloatScalar,
t: FloatScalar,
*,
i: IntegerLike,
i: IntLike,
seed_num: int,
) -> tuple[Vec3, Vec3, Vec3, Vec3]:
"""Generate stream particle initial conditions."""
Expand Down
8 changes: 3 additions & 5 deletions src/galdynamix/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.typing import (
FloatScalar,
IntegerScalar,
IntScalar,
TimeVector,
Vec6,
VecN,
Expand All @@ -27,7 +27,7 @@

from ._df import AbstractStreamDF

Carry: TypeAlias = tuple[IntegerScalar, VecN, VecN]
Carry: TypeAlias = tuple[IntScalar, VecN, VecN]


def _converter_immutabledict_or_none(x: Any) -> ImmutableDict[Any] | None:
Expand Down Expand Up @@ -92,9 +92,7 @@ def _run_scan(
qp0_lead = mock0_lead.qp
qp0_trail = mock0_trail.qp

def scan_fn(
carry: Carry, idx: IntegerScalar
) -> tuple[Carry, tuple[VecN, VecN]]:
def scan_fn(carry: Carry, idx: IntScalar) -> tuple[Carry, tuple[VecN, VecN]]:
i, qp0_lead_i, qp0_trail_i = carry
qp0_lead_trail = xp.vstack([qp0_lead_i, qp0_trail_i])
t_i, t_f = ts[i], ts[-1]
Expand Down
27 changes: 22 additions & 5 deletions src/galdynamix/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,37 @@
# =============================================================================
# Scalars

IntegerScalar = Integer[Scalar, ""]
IntScalar = Integer[Scalar, ""]
"""An integer scalar."""

IntegerLike = IntegerScalar | int
IntLike = IntScalar | int
"""An integer or integer scalar."""

FloatScalar = Float[Scalar, ""]
"""A float scalar."""

FloatLike = FloatScalar | float
FloatLike = FloatScalar | float | int
"""A float(/int) or float scalar."""

FloatOrIntScalar = FloatScalar | IntScalar
"""A float or integer scalar."""

FloatOrIntScalarLike = FloatLike | IntLike
"""A float or integer or float(/int) scalar."""


# =============================================================================
# Vectors

# -----------------------------------------------------------------------------
# Shaped

Vec3 = Float[Array, "3"]
"""A 3-vector, e.g. q=(x, y, z) or p=(vx, vy, vz)."""

Matrix33 = Float[Array, "3 3"]
"""A 3x3 matrix."""

Vec6 = Float[Array, "6"]
"""A 6-vector e.g. qp=(x, y, z, vx, vy, vz)."""

Expand All @@ -43,6 +56,7 @@
BatchableFloatLike = BatchFloatScalar | FloatLike

# -----------------
# Shaped

BatchVec3 = Shaped[Vec3, "*batch"]
"""Zero or more batches of 3-vectors."""
Expand All @@ -53,11 +67,14 @@
BatchVec7 = Shaped[Vec7, "*batch"]
"""Zero or more batches of 7-vectors."""

VecN = Float[Array, "N"]

ArrayAnyShape = Float[Array, "..."]
"""An array with any shape."""

# =============================================================================
# Specific Vectors

VecN = Float[Array, "N"]
"""An (N,)-vector."""

TimeVector = Float[Array, "time"]
"""A vector of times."""

0 comments on commit 50d3eac

Please sign in to comment.