Skip to content

Commit

Permalink
better units argument converter (#17)
Browse files Browse the repository at this point in the history
* better units converter

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 8, 2023
1 parent ef8ce99 commit 283e530
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 12 deletions.
11 changes: 4 additions & 7 deletions src/galdynamix/potential/_potential/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import jax.numpy as xp

from galdynamix.typing import BatchableFloatLike, BatchFloatScalar, BatchVec3
from galdynamix.units import UnitSystem, dimensionless
from galdynamix.units import UnitSystem
from galdynamix.utils import ImmutableDict, partial_jit
from galdynamix.utils._misc import first

from .base import AbstractPotentialBase
from .utils import converter_to_usys

K = TypeVar("K")
V = TypeVar("V")
Expand All @@ -25,12 +26,8 @@ class CompositePotential(ImmutableDict[AbstractPotentialBase], AbstractPotential

_data: dict[str, AbstractPotentialBase]
_: KW_ONLY
units: UnitSystem = eqx.field(
init=False,
static=True,
converter=lambda x: dimensionless if x is None else UnitSystem(x),
)
_G: float = eqx.field(init=False, static=True, repr=False)
units: UnitSystem = eqx.field(init=False, static=True, converter=converter_to_usys)
_G: float = eqx.field(init=False, static=True, repr=False, converter=float)

def __init__(
self,
Expand Down
9 changes: 4 additions & 5 deletions src/galdynamix/potential/_potential/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@

import equinox as eqx

from galdynamix.units import UnitSystem, dimensionless
from galdynamix.units import UnitSystem

from .base import AbstractPotentialBase
from .composite import CompositePotential
from .utils import converter_to_usys


class AbstractPotential(AbstractPotentialBase):
_: KW_ONLY
units: UnitSystem = eqx.field(
default=None,
converter=lambda x: dimensionless if x is None else UnitSystem(x),
static=True,
default=None, converter=converter_to_usys, static=True
)
_G: float = eqx.field(init=False, static=True, repr=False)
_G: float = eqx.field(init=False, static=True, repr=False, converter=float)

def __post_init__(self) -> None:
self._init_units()
Expand Down
42 changes: 42 additions & 0 deletions src/galdynamix/potential/_potential/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""galdynamix: Galactic Dynamix in Jax."""


from functools import singledispatch
from typing import Any

from galdynamix.units import UnitSystem, dimensionless, galactic, solarsystem


@singledispatch
def converter_to_usys(value: Any, /) -> UnitSystem:
"""Argument to ``eqx.field(converter=...)``."""
msg = f"cannot convert {value} to a UnitSystem"
raise NotImplementedError(msg)


@converter_to_usys.register
def _from_usys(value: UnitSystem, /) -> UnitSystem:
return value


@converter_to_usys.register
def _from_none(value: None, /) -> UnitSystem:
return dimensionless


@converter_to_usys.register(tuple)
def _from_args(value: tuple[Any, ...], /) -> UnitSystem:
return UnitSystem(*value)


@converter_to_usys.register
def _from_named(value: str, /) -> UnitSystem:
if value == "dimensionless":
return dimensionless
if value == "solarsystem":
return solarsystem
if value == "galactic":
return galactic

msg = f"cannot convert {value} to a UnitSystem"
raise NotImplementedError(msg)

0 comments on commit 283e530

Please sign in to comment.