Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve passing callable function as a parameter #33

Merged
merged 2 commits into from
Dec 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/galdynamix/potential/_potential/param/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class ConstantParameter(AbstractParameter):
# TODO: link this shape to the return shape from __call__
value: FloatArrayAnyShape = eqx.field(converter=converter_float_array)

# This is a workaround since vectorized methods don't support kwargs.
@partial_jit()
@vectorize_method(signature="()->()")
def _call_helper(self, _: FloatOrIntScalar) -> ArrayAnyShape:
Expand Down Expand Up @@ -93,6 +94,7 @@ def __call__(

#####################################################################
# User-defined Parameter
# For passing a function as a parameter.


@runtime_checkable
Expand All @@ -118,11 +120,19 @@ def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape:


class UserParameter(AbstractParameter):
"""User-defined Parameter."""
"""User-defined Parameter.

Parameters
----------
func : Callable[[Array[float, ()] | float | int], Array[float, (*shape,)]]
The function to use to compute the parameter value.
unit : Unit, keyword-only
The output unit of the parameter.
"""

# TODO: unit handling
func: ParameterCallable
func: ParameterCallable = eqx.field(static=True)

@partial_jit()
def __call__(self, t: FloatScalar, **kwargs: Any) -> ArrayAnyShape:
def __call__(self, t: FloatOrIntScalar, **kwargs: Any) -> FloatArrayAnyShape:
return self.func(t, **kwargs)
80 changes: 71 additions & 9 deletions src/galdynamix/potential/_potential/param/field.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
__all__ = ["ParameterField"]

from dataclasses import KW_ONLY, dataclass, field, is_dataclass
from typing import Any, cast, overload
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints, overload

import astropy.units as u
import jax.numpy as xp

from galdynamix.potential._potential.core import AbstractPotential
from galdynamix.typing import Unit

from .core import AbstractParameter, ConstantParameter, ParameterCallable, UserParameter

Expand All @@ -27,7 +28,7 @@ class ParameterField:

name: str = field(init=False)
_: KW_ONLY
dimensions: u.PhysicalType # TODO: add a converter_argument
dimensions: u.PhysicalType
equivalencies: u.Equivalency | tuple[u.Equivalency, ...] | None = None

def __post_init__(self) -> None:
Expand Down Expand Up @@ -76,20 +77,34 @@ def __get__( # TODO: use `Self` when beartype is happy

# -----------------------------

def _check_unit(self, potential: AbstractPotential, unit: Unit) -> None:
"""Check that the given unit is compatible with the parameter's."""
if not unit.is_equivalent(
potential.units[self.dimensions],
equivalencies=self.equivalencies,
):
msg = (
"Parameter function must return a value "
f"with units equivalent to {self.dimensions}"
)
raise ValueError(msg)

def __set__(
self,
potential: AbstractPotential,
value: AbstractParameter | ParameterCallable | Any,
) -> None:
# Convert
if isinstance(value, AbstractParameter):
# TODO: use the dimensions & equivalencies info to check the parameters.
# TODO: use the units on the `potential` to convert the parameter value.
pass
# TODO: this doesn't handle the correct output unit, a. la.
# potential.units[self.dimensions]
self._check_unit(potential, value.unit) # Check the unit is compatible
elif callable(value):
# TODO: use the dimensions & equivalencies info to check the parameters.
# TODO: use the units on the `potential` to convert the parameter value.
value = UserParameter(func=value)
# TODO: this only gets the existing unit, it doesn't handle the
# correct output unit, a. la. potential.units[self.dimensions]
unit = _get_unit_from_return_annotation(value)
self._check_unit(potential, unit) # Check the unit is compatible
value = UserParameter(func=value, unit=unit)
else:
# TODO: the issue here is that ``units`` hasn't necessarily been set
# on the potential yet. What is needed is to possibly bail out
Expand All @@ -100,8 +115,55 @@ def __set__(
unit = potential.units[self.dimensions]
if isinstance(value, u.Quantity):
value = value.to_value(unit, equivalencies=self.equivalencies)

value = ConstantParameter(xp.asarray(value), unit=unit)

# Set
potential.__dict__[self.name] = value


def _get_unit_from_return_annotation(func: ParameterCallable) -> Unit:
"""Get the unit from the return annotation of a Parameter function.

Parameters
----------
func : Callable[[Array[float, ()] | float | int], Array[float, (*shape,)]]
The function to use to compute the parameter value.

Returns
-------
Unit
The unit from the return annotation of the function.
"""
# Get the return annotation
type_hints = get_type_hints(func, include_extras=True)
if "return" not in type_hints:
msg = "Parameter function must have a return annotation"
raise TypeError(msg)

# Check that the return annotation might contain a unit
return_annotation = type_hints["return"]
return_origin = get_origin(return_annotation)
if return_origin is not Annotated:
msg = "Parameter function return annotation must be annotated"
raise TypeError(msg)

# Get the unit from the return annotation
return_args = get_args(return_annotation)
has_unit = False
for arg in return_args[1:]:
# Try to convert the argument to a unit
try:
unit = u.Unit(arg)
except ValueError:
continue
# Only one unit annotation is allowed
if has_unit:
msg = "function has more than one unit annotation"
raise ValueError(msg)
has_unit = True

if not has_unit:
msg = "function did not have a valid unit annotation"
raise ValueError(msg)

return unit