Skip to content

Commit

Permalink
refactor!: let Model.lambdify return Function
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Nov 25, 2021
1 parent d6f70cf commit 195c8e1
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 52 deletions.
2 changes: 1 addition & 1 deletion docs/usage/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@
"print(\n",
" format_str(\n",
" inspect.getsource(\n",
" function_1d._ParametrizedBackendFunction__lambdified_model\n",
" function_1d._ParametrizedBackendFunction__function.function\n",
" ),\n",
" mode=FileMode(),\n",
" )\n",
Expand Down
13 changes: 2 additions & 11 deletions src/tensorwaves/interface.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
"""Defines top-level interface of tensorwaves."""

from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Mapping,
Optional,
Tuple,
Union,
)
from typing import Any, Dict, FrozenSet, Mapping, Optional, Tuple, Union

import attr
import numpy as np
Expand Down Expand Up @@ -79,7 +70,7 @@ class Model(ABC):
"""Interface of a model which can be lambdified into a callable."""

@abstractmethod
def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
def lambdify(self, backend: Union[str, tuple, dict]) -> Function:
"""Lambdify the model into a Callable.
Args:
Expand Down
13 changes: 3 additions & 10 deletions src/tensorwaves/model/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,12 @@ def __init__(
model: Model,
backend: Union[str, tuple, dict] = "numpy",
) -> None:
self.__lambdified_model = model.lambdify(backend=backend)
self.__function = model.lambdify(backend=backend)
self.__parameters = model.parameters
self.__ordered_args = model.argument_order

def __call__(self, dataset: DataSample) -> np.ndarray:
return self.__lambdified_model(
*[
dataset[var_name]
if var_name in dataset
else self.__parameters[var_name]
for var_name in self.__ordered_args
],
)
extended_data = {**dataset, **self.__parameters} # type: ignore[arg-type]
return self.__function(extended_data)

@property
def parameters(self) -> Dict[str, ParameterValue]:
Expand Down
60 changes: 38 additions & 22 deletions src/tensorwaves/model/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from tqdm.auto import tqdm

from tensorwaves._backend import get_backend_modules
from tensorwaves.interface import DataSample, Model, ParameterValue
from tensorwaves.interface import DataSample, Function, Model, ParameterValue

from .function import PositionalArgumentFunction

_jax_known_functions = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_functions.items()
Expand Down Expand Up @@ -106,7 +108,7 @@ def optimized_lambdify(
min_complexity: int = 0,
max_complexity: int,
**kwargs: Any,
) -> Callable:
) -> PositionalArgumentFunction:
"""Speed up `~sympy.utilities.lambdify.lambdify` with `.split_expression`.
.. seealso:: :doc:`/usage/faster-lambdify`
Expand All @@ -117,33 +119,45 @@ def optimized_lambdify(
max_complexity=max_complexity,
)
if not sub_expressions:
return _backend_lambdify(top_expression, symbols, backend, **kwargs)
return lambdify(top_expression, symbols, backend, **kwargs)

sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name)
top_function = _backend_lambdify(
top_function = lambdify(
top_expression, sorted_top_symbols, backend, **kwargs
)
sub_functions: List[Callable] = []
sub_functions: List[PositionalArgumentFunction] = []
for symbol in tqdm(
iterable=sorted_top_symbols,
desc="Lambdifying sub-expressions",
unit="expr",
disable=not _use_progress_bar(),
):
sub_expression = sub_expressions[symbol]
sub_function = _backend_lambdify(
sub_expression, symbols, backend, **kwargs
)
sub_function = lambdify(sub_expression, symbols, backend, **kwargs)
sub_functions.append(sub_function)

def recombined_function(*args: Any) -> Any:
new_args = [sub_function(*args) for sub_function in sub_functions]
return top_function(*new_args)
new_args = [func.function(*args) for func in sub_functions]
return top_function.function(*new_args)

return PositionalArgumentFunction(
recombined_function, argument_order=map(str, symbols)
)


return recombined_function
def lambdify(
expression: sp.Expr,
symbols: Sequence[sp.Symbol],
backend: Union[str, tuple, dict],
**kwargs: Any,
) -> PositionalArgumentFunction:
return PositionalArgumentFunction(
function=_wrapped_lambdify(expression, symbols, backend, **kwargs),
argument_order=map(str, symbols),
)


def _backend_lambdify(
def _wrapped_lambdify(
expression: sp.Expr,
symbols: Sequence[sp.Symbol],
backend: Union[str, tuple, dict],
Expand Down Expand Up @@ -268,22 +282,22 @@ def __replace_constant_sub_expressions(
{v: k for k, v in self.__constant_sub_expressions.items()}
)

def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
def lambdify(self, backend: Union[str, tuple, dict]) -> Function:
input_symbols = tuple(self.__expression.free_symbols)
lambdified_model = _backend_lambdify(
lambdified_model = lambdify(
expression=self.__expression,
symbols=input_symbols,
backend=backend,
)
constant_input_storage = {}
for placeholder, sub_expr in self.__constant_sub_expressions.items():
temp_lambdify = _backend_lambdify(
temp_lambdify = lambdify(
expression=sub_expr,
symbols=tuple(sub_expr.free_symbols),
backend=backend,
)
free_symbol_names = {x.name for x in sub_expr.free_symbols}
constant_input_storage[placeholder.name] = temp_lambdify(
constant_input_storage[placeholder.name] = temp_lambdify.function(
*(self.__fix_inputs[k] for k in free_symbol_names)
)

Expand All @@ -307,11 +321,13 @@ def update_args(*args: Tuple[Any, ...]) -> None:

def wrapper(*args: Tuple[Any, ...]) -> Any:
update_args(*args)
return lambdified_model(*input_args)
return lambdified_model.function(*input_args)

return wrapper
return PositionalArgumentFunction(
wrapper, argument_order=self.argument_order
)

def performance_optimize(self, fix_inputs: DataSample) -> "Model":
def performance_optimize(self, fix_inputs: DataSample) -> Model:
return NotImplemented

@property
Expand Down Expand Up @@ -390,10 +406,10 @@ def __init__(
)
self.max_complexity = max_complexity

def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
def lambdify(self, backend: Union[str, tuple, dict]) -> Function:
"""Lambdify the model using `~sympy.utilities.lambdify.lambdify`."""
if self.max_complexity is None:
return _backend_lambdify(
return lambdify(
expression=self.__expression,
symbols=self.__argument_order,
backend=backend,
Expand All @@ -405,7 +421,7 @@ def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
max_complexity=self.max_complexity,
)

def performance_optimize(self, fix_inputs: DataSample) -> "Model":
def performance_optimize(self, fix_inputs: DataSample) -> Model:
return _ConstantSubExpressionSympyModel(
self.__expression, self.__parameters, fix_inputs
)
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/model/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_optimized_lambdify(backend: str, max_complexity: int):
backend=backend,
)

func_repr = str(function)
func_repr = str(function.function)
if max_complexity <= 3:
assert func_repr.startswith("<function optimized_lambdify.<locals>")
else:
Expand All @@ -110,13 +110,13 @@ def test_optimized_lambdify(backend: str, max_complexity: int):
repr_start = "<CompiledFunction object at 0x"
assert func_repr.startswith(repr_start)

data = (
np.array([1, 2]),
np.array([1, np.e]),
np.array([1, 2]),
)
output = function(*data)
expected = create_expression(*data)
data: DataSample = {
"x": np.array([1, 2]),
"y": np.array([1, np.e]),
"z": np.array([1, 2]),
}
output = function(data)
expected = create_expression(*data.values())
assert pytest.approx(output) == expected


Expand Down

0 comments on commit 195c8e1

Please sign in to comment.