Skip to content

Commit

Permalink
feat: add option to lambdify with SymPy's cse (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Dec 7, 2021
1 parent eb4f80f commit 3147abd
Showing 1 changed file with 67 additions and 18 deletions.
85 changes: 67 additions & 18 deletions src/tensorwaves/model/sympy/__init__.py
Expand Up @@ -23,58 +23,59 @@

if TYPE_CHECKING:
import sympy as sp
from sympy.printing.printer import Printer


def _sympy_lambdify(
expression: "sp.Expr",
symbols: Sequence["sp.Symbol"],
backend: Union[str, tuple, dict],
*,
use_cse: bool = True,
max_complexity: Optional[int] = None,
**kwargs: Any,
) -> Callable:
if max_complexity is None:
return _backend_lambdify(
expression=expression,
symbols=symbols,
backend=backend,
**kwargs,
use_cse=use_cse,
)
return optimized_lambdify(
expression=expression,
symbols=symbols,
backend=backend,
max_complexity=max_complexity,
**kwargs,
use_cse=use_cse,
)


def _backend_lambdify(
expression: "sp.Expr",
symbols: Sequence["sp.Symbol"],
backend: Union[str, tuple, dict],
**kwargs: Any,
use_cse: bool,
) -> Callable:
"""A wrapper around :func:`~sympy.utilities.lambdify.lambdify`."""
# pylint: disable=too-many-return-statements
import sympy as sp

def jax_lambdify() -> Callable:
from ._printer import JaxPrinter

return jit_compile(backend="jax")(
sp.lambdify(
symbols,
_wrapped_lambdify(
expression,
symbols,
modules=modules,
printer=JaxPrinter(),
**kwargs,
use_cse=use_cse,
)
)

def numba_lambdify() -> Callable:
return jit_compile(backend="numba")(
sp.lambdify(symbols, expression, modules="numpy", **kwargs)
_wrapped_lambdify(
expression, symbols, modules="numpy", use_cse=use_cse
)
)

def tensorflow_lambdify() -> Callable:
Expand All @@ -83,12 +84,12 @@ def tensorflow_lambdify() -> Callable:

from ._printer import TensorflowPrinter

return sp.lambdify(
symbols,
return _wrapped_lambdify(
expression,
symbols,
modules=tnp,
printer=TensorflowPrinter(),
**kwargs,
use_cse=use_cse,
)

modules = get_backend_modules(backend)
Expand All @@ -110,17 +111,47 @@ def tensorflow_lambdify() -> Callable:
):
return tensorflow_lambdify()

return sp.lambdify(symbols, expression, modules=modules, **kwargs)
return _wrapped_lambdify(
expression,
symbols,
modules=modules,
use_cse=use_cse,
)


def _wrapped_lambdify(
expression: "sp.Expr",
symbols: Sequence["sp.Symbol"],
modules: Union[str, tuple, dict],
use_cse: bool,
printer: Optional["Printer"] = None,
) -> Callable:
import sympy as sp

if use_cse:
dummy_replacements = {
symbol: sp.Symbol(f"z{i}", **symbol.assumptions0)
for i, symbol in enumerate(symbols)
}
expression = expression.xreplace(dummy_replacements)
symbols = [dummy_replacements[s] for s in symbols]
return sp.lambdify(
symbols,
expression,
modules=modules,
printer=printer,
cse=use_cse,
)


def optimized_lambdify(
expression: "sp.Expr",
symbols: Sequence["sp.Symbol"],
backend: Union[str, tuple, dict],
use_cse: bool = True,
*,
min_complexity: int = 0,
max_complexity: int,
**kwargs: Any,
) -> Callable:
"""Speed up `~sympy.utilities.lambdify.lambdify` with `.split_expression`.
Expand All @@ -132,11 +163,19 @@ def optimized_lambdify(
max_complexity=max_complexity,
)
if not sub_expressions:
return _backend_lambdify(top_expression, symbols, backend, **kwargs)
return _backend_lambdify(
top_expression,
symbols,
backend,
use_cse=use_cse,
)

sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name)
top_function = _backend_lambdify(
top_expression, sorted_top_symbols, backend, **kwargs
top_expression,
sorted_top_symbols,
backend,
use_cse=use_cse,
)
sub_functions: List[Callable] = []
for symbol in tqdm(
Expand All @@ -147,7 +186,10 @@ def optimized_lambdify(
):
sub_expression = sub_expressions[symbol]
sub_function = _backend_lambdify(
sub_expression, symbols, backend, **kwargs
sub_expression,
symbols,
backend,
use_cse=use_cse,
)
sub_functions.append(sub_function)

Expand Down Expand Up @@ -227,6 +269,7 @@ def __init__(
expression: "sp.Expr",
parameters: Dict["sp.Symbol", ParameterValue],
fix_inputs: DataSample,
use_cse: bool = True,
) -> None:
self.__fix_inputs = fix_inputs
self.__constant_symbols = set(self.__fix_inputs)
Expand All @@ -248,6 +291,7 @@ def __init__(
self.__argument_order = tuple(self.__not_fixed_variables) + tuple(
self.__not_fixed_parameters
)
self.__use_cse = use_cse

def __find_constant_subexpressions(self, expr: "sp.Expr") -> bool:
import sympy as sp
Expand Down Expand Up @@ -289,13 +333,15 @@ def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
expression=self.__expression,
symbols=input_symbols,
backend=backend,
use_cse=self.__use_cse,
)
constant_input_storage = {}
for placeholder, sub_expr in self.__constant_sub_expressions.items():
temp_lambdify = _backend_lambdify(
expression=sub_expr,
symbols=tuple(sub_expr.free_symbols),
backend=backend,
use_cse=self.__use_cse,
)
free_symbol_names = {x.name for x in sub_expr.free_symbols}
constant_input_storage[placeholder.name] = temp_lambdify(
Expand Down Expand Up @@ -375,6 +421,7 @@ def __init__(
self,
expression: "sp.Expr",
parameters: Dict["sp.Symbol", ParameterValue],
use_cse: bool = True,
max_complexity: Optional[int] = None,
) -> None:
import sympy as sp
Expand Down Expand Up @@ -406,6 +453,7 @@ def __init__(
self.__parameters
)
self.max_complexity = max_complexity
self.__use_cse = use_cse

def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
"""Lambdify the model using `~sympy.utilities.lambdify.lambdify`."""
Expand All @@ -414,6 +462,7 @@ def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
symbols=self.__argument_order,
backend=backend,
max_complexity=self.max_complexity,
use_cse=self.__use_cse,
)

def performance_optimize(self, fix_inputs: DataSample) -> "Model":
Expand Down

0 comments on commit 3147abd

Please sign in to comment.