Skip to content

Commit

Permalink
FIX: set i/o types for function implementations (#522)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Mar 7, 2024
1 parent bb8b60b commit 4c10317
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 18 deletions.
8 changes: 5 additions & 3 deletions benchmarks/ampform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from ampform.helicity import HelicityModel
from qrules.combinatorics import StateDefinition

from tensorwaves.function import ParametrizedBackendFunction
from tensorwaves.interface import (
DataSample,
FitResult,
Function,
ParameterValue,
ParametrizedFunction,
)
Expand Down Expand Up @@ -55,7 +57,7 @@ def formulate_amplitude_model(

def create_function(
model: HelicityModel, backend: str, max_complexity: int | None = None
) -> ParametrizedFunction:
) -> ParametrizedBackendFunction:
return create_parametrized_function(
expression=model.expression.doit(),
parameters=model.parameter_defaults,
Expand All @@ -66,7 +68,7 @@ def create_function(

def generate_data(
model: HelicityModel,
function: ParametrizedFunction,
function: Function[DataSample, np.ndarray],
data_sample_size: int,
phsp_sample_size: int,
backend: str,
Expand Down Expand Up @@ -103,7 +105,7 @@ def generate_data(
def fit(
data: DataSample,
phsp: DataSample,
function: ParametrizedFunction,
function: ParametrizedFunction[DataSample, np.ndarray],
initial_parameters: Mapping[str, ParameterValue],
backend: str,
) -> FitResult:
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _generate_domain(

def _generate_data(
size: int,
function: Function,
function: Function[DataSample, np.ndarray],
rng: np.random.Generator,
bunch_size: int = 10_000,
) -> DataSample:
Expand Down
2 changes: 1 addition & 1 deletion src/tensorwaves/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class IntensityDistributionGenerator(DataGenerator):
def __init__(
self,
domain_generator: DataGenerator,
function: Function,
function: Function[DataSample, np.ndarray],
domain_transformer: DataTransformer | None = None,
bunch_size: int = 50_000,
) -> None:
Expand Down
7 changes: 5 additions & 2 deletions src/tensorwaves/data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._attrs import to_tuple

if TYPE_CHECKING: # pragma: no cover
import numpy as np
import sympy as sp


Expand Down Expand Up @@ -55,7 +56,9 @@ def __call__(self, data: DataSample) -> DataSample:
class SympyDataTransformer(DataTransformer):
"""Implementation of a `.DataTransformer`."""

def __init__(self, functions: Mapping[str, Function]) -> None:
def __init__(
self, functions: Mapping[str, Function[DataSample, np.ndarray]]
) -> None:
if any(not isinstance(f, Function) for f in functions.values()):
msg = (
f"Not all values in the mapping are an instance of {Function.__name__}"
Expand All @@ -64,7 +67,7 @@ def __init__(self, functions: Mapping[str, Function]) -> None:
self.__functions = dict(functions)

@property
def functions(self) -> dict[str, Function]:
def functions(self) -> dict[str, Function[DataSample, np.ndarray]]:
"""Read-only access to the internal mapping of functions."""
return dict(self.__functions)

Expand Down
6 changes: 3 additions & 3 deletions src/tensorwaves/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def create_cached_function(
backend: str,
free_parameters: Iterable[sp.Symbol],
use_cse: bool = True,
) -> tuple[ParametrizedFunction, DataTransformer]:
) -> tuple[ParametrizedFunction[DataSample, np.ndarray], DataTransformer]:
"""Create a function and data transformer for cached computations.
Once it is known which parameters in an expression are to be optimized, this
Expand Down Expand Up @@ -118,7 +118,7 @@ class ChiSquared(Estimator):

def __init__( # noqa: PLR0913
self,
function: ParametrizedFunction,
function: ParametrizedFunction[DataSample, np.ndarray],
domain: DataSample,
observed_values: np.ndarray,
weights: np.ndarray | None = None,
Expand Down Expand Up @@ -185,7 +185,7 @@ class UnbinnedNLL(Estimator):

def __init__( # noqa: PLR0913
self,
function: ParametrizedFunction,
function: ParametrizedFunction[DataSample, np.ndarray],
data: DataSample,
phsp: DataSample,
phsp_volume: float = 1.0,
Expand Down
10 changes: 4 additions & 6 deletions src/tensorwaves/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Callable, Iterable, Mapping
from typing import Callable, Iterable, Mapping

import attrs
import numpy as np
from attrs import field, frozen

from tensorwaves.interface import (
Expand All @@ -15,9 +16,6 @@
ParametrizedFunction,
)

if TYPE_CHECKING:
import numpy as np


def _all_str(
_: PositionalArgumentFunction, __: attrs.Attribute, value: Iterable[str]
Expand Down Expand Up @@ -66,7 +64,7 @@ def _to_tuple(argument_order: Iterable[str]) -> tuple[str, ...]:


@frozen
class PositionalArgumentFunction(Function):
class PositionalArgumentFunction(Function[DataSample, np.ndarray]):
"""Wrapper around a function with positional arguments.
This class provides a :meth:`~.Function.__call__` that can take a `.DataSample` for
Expand All @@ -90,7 +88,7 @@ def __call__(self, data: DataSample) -> np.ndarray:
return self.function(*args)


class ParametrizedBackendFunction(ParametrizedFunction):
class ParametrizedBackendFunction(ParametrizedFunction[DataSample, np.ndarray]):
"""Implements `.ParametrizedFunction` for a specific computational back-end.
.. seealso:: :func:`.create_parametrized_function`
Expand Down
2 changes: 1 addition & 1 deletion src/tensorwaves/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __call__(self, data: InputType) -> OutputType: ...
"""Allowed types for parameter values."""


class ParametrizedFunction(Function[DataSample, np.ndarray]):
class ParametrizedFunction(Function[InputType, OutputType]):
"""Interface of a callable function.
A `ParametrizedFunction` identifies certain variables in a mathematical expression
Expand Down
2 changes: 1 addition & 1 deletion tests/optimizer/test_fit_simple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def generate_domain(
def generate_data(
size: int,
boundaries: dict[str, tuple[float, float]],
function: Function,
function: Function[DataSample, np.ndarray],
rng: np.random.Generator,
bunch_size: int = 10_000,
) -> DataSample:
Expand Down

0 comments on commit 4c10317

Please sign in to comment.