-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
better units argument converter (#17)
* better units converter Signed-off-by: nstarman <nstarman@users.noreply.github.com>
- Loading branch information
Showing
3 changed files
with
50 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
"""galdynamix: Galactic Dynamix in Jax.""" | ||
|
||
|
||
from functools import singledispatch | ||
from typing import Any | ||
|
||
from galdynamix.units import UnitSystem, dimensionless, galactic, solarsystem | ||
|
||
|
||
@singledispatch | ||
def converter_to_usys(value: Any, /) -> UnitSystem: | ||
"""Argument to ``eqx.field(converter=...)``.""" | ||
msg = f"cannot convert {value} to a UnitSystem" | ||
raise NotImplementedError(msg) | ||
|
||
|
||
@converter_to_usys.register | ||
def _from_usys(value: UnitSystem, /) -> UnitSystem: | ||
return value | ||
|
||
|
||
@converter_to_usys.register | ||
def _from_none(value: None, /) -> UnitSystem: | ||
return dimensionless | ||
|
||
|
||
@converter_to_usys.register(tuple) | ||
def _from_args(value: tuple[Any, ...], /) -> UnitSystem: | ||
return UnitSystem(*value) | ||
|
||
|
||
@converter_to_usys.register | ||
def _from_named(value: str, /) -> UnitSystem: | ||
if value == "dimensionless": | ||
return dimensionless | ||
if value == "solarsystem": | ||
return solarsystem | ||
if value == "galactic": | ||
return galactic | ||
|
||
msg = f"cannot convert {value} to a UnitSystem" | ||
raise NotImplementedError(msg) |