diff --git a/src/galdynamix/potential/_potential/param/core.py b/src/galdynamix/potential/_potential/param/core.py index ba5f3397..194cf500 100644 --- a/src/galdynamix/potential/_potential/param/core.py +++ b/src/galdynamix/potential/_potential/param/core.py @@ -63,6 +63,7 @@ class ConstantParameter(AbstractParameter): # TODO: link this shape to the return shape from __call__ value: FloatArrayAnyShape = eqx.field(converter=converter_float_array) + # This is a workaround since vectorized methods don't support kwargs. @partial_jit() @vectorize_method(signature="()->()") def _call_helper(self, _: FloatOrIntScalar) -> ArrayAnyShape: @@ -93,6 +94,7 @@ def __call__( ##################################################################### # User-defined Parameter +# For passing a function as a parameter. @runtime_checkable @@ -118,11 +120,19 @@ def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape: class UserParameter(AbstractParameter): - """User-defined Parameter.""" + """User-defined Parameter. + + Parameters + ---------- + func : Callable[[Array[float, ()] | float | int], Array[float, (*shape,)]] + The function to use to compute the parameter value. + unit : Unit, keyword-only + The output unit of the parameter. + """ # TODO: unit handling - func: ParameterCallable + func: ParameterCallable = eqx.field(static=True) @partial_jit() - def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape: + def __call__(self, t: FloatOrIntScalar, **kwargs: Any) -> FloatArrayAnyShape: return self.func(t, **kwargs) diff --git a/src/galdynamix/potential/_potential/param/field.py b/src/galdynamix/potential/_potential/param/field.py index 14348d97..cb390854 100644 --- a/src/galdynamix/potential/_potential/param/field.py +++ b/src/galdynamix/potential/_potential/param/field.py @@ -1,12 +1,13 @@ __all__ = ["ParameterField"] from dataclasses import KW_ONLY, dataclass, field, is_dataclass -from typing import Any, cast, overload +from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints, overload import astropy.units as u import jax.numpy as xp from galdynamix.potential._potential.core import AbstractPotential +from galdynamix.typing import Unit from .core import AbstractParameter, ConstantParameter, ParameterCallable, UserParameter @@ -27,7 +28,7 @@ class ParameterField: name: str = field(init=False) _: KW_ONLY - dimensions: u.PhysicalType # TODO: add a converter_argument + dimensions: u.PhysicalType equivalencies: u.Equivalency | tuple[u.Equivalency, ...] | None = None def __post_init__(self) -> None: @@ -76,6 +77,18 @@ def __get__( # TODO: use `Self` when beartype is happy # ----------------------------- + def _check_unit(self, potential: AbstractPotential, unit: Unit) -> None: + """Check that the given unit is compatible with the parameter's.""" + if not unit.is_equivalent( + potential.units[self.dimensions], + equivalencies=self.equivalencies, + ): + msg = ( + "Parameter function must return a value " + f"with units equivalent to {self.dimensions}" + ) + raise ValueError(msg) + def __set__( self, potential: AbstractPotential, @@ -83,13 +96,15 @@ def __set__( ) -> None: # Convert if isinstance(value, AbstractParameter): - # TODO: use the dimensions & equivalencies info to check the parameters. - # TODO: use the units on the `potential` to convert the parameter value. - pass + # TODO: this doesn't handle the correct output unit, a. la. + # potential.units[self.dimensions] + self._check_unit(potential, value.unit) # Check the unit is compatible elif callable(value): - # TODO: use the dimensions & equivalencies info to check the parameters. - # TODO: use the units on the `potential` to convert the parameter value. - value = UserParameter(func=value) + # TODO: this only gets the existing unit, it doesn't handle the + # correct output unit, a. la. potential.units[self.dimensions] + unit = _get_unit_from_return_annotation(value) + self._check_unit(potential, unit) # Check the unit is compatible + value = UserParameter(func=value, unit=unit) else: # TODO: the issue here is that ``units`` hasn't necessarily been set # on the potential yet. What is needed is to possibly bail out @@ -100,8 +115,55 @@ def __set__( unit = potential.units[self.dimensions] if isinstance(value, u.Quantity): value = value.to_value(unit, equivalencies=self.equivalencies) - value = ConstantParameter(xp.asarray(value), unit=unit) # Set potential.__dict__[self.name] = value + + +def _get_unit_from_return_annotation(func: ParameterCallable) -> Unit: + """Get the unit from the return annotation of a Parameter function. + + Parameters + ---------- + func : Callable[[Array[float, ()] | float | int], Array[float, (*shape,)]] + The function to use to compute the parameter value. + + Returns + ------- + Unit + The unit from the return annotation of the function. + """ + # Get the return annotation + type_hints = get_type_hints(func, include_extras=True) + if "return" not in type_hints: + msg = "Parameter function must have a return annotation" + raise TypeError(msg) + + # Check that the return annotation might contain a unit + return_annotation = type_hints["return"] + return_origin = get_origin(return_annotation) + if return_origin is not Annotated: + msg = "Parameter function return annotation must be annotated" + raise TypeError(msg) + + # Get the unit from the return annotation + return_args = get_args(return_annotation) + has_unit = False + for arg in return_args[1:]: + # Try to convert the argument to a unit + try: + unit = u.Unit(arg) + except ValueError: + continue + # Only one unit annotation is allowed + if has_unit: + msg = "function has more than one unit annotation" + raise ValueError(msg) + has_unit = True + + if not has_unit: + msg = "function did not have a valid unit annotation" + raise ValueError(msg) + + return unit