diff --git a/.cspell.json b/.cspell.json index 37f2a6cc..48422d67 100644 --- a/.cspell.json +++ b/.cspell.json @@ -125,6 +125,7 @@ "darkred", "doctest", "doctests", + "dotprint", "dtype", "eval", "evalf", @@ -176,6 +177,7 @@ "precommit", "prefactor", "pwa's", + "py's", "pygments", "pypi", "pyplot", @@ -199,8 +201,10 @@ "unnormalized", "unsubscriptable", "vstack", + "waves's", "xlabel", "xlim", + "xreplace", "ylabel", "ylim", "yticks" diff --git a/docs/.gitignore b/docs/.gitignore index 10288328..627ea8d0 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -3,3 +3,4 @@ *.inv *build/ api/ +sub_expr_* diff --git a/docs/usage.ipynb b/docs/usage.ipynb index 4dd6718d..e0e09d50 100644 --- a/docs/usage.ipynb +++ b/docs/usage.ipynb @@ -439,6 +439,7 @@ "usage/step2\n", "usage/step3\n", "usage/basics\n", + "usage/faster-lambdify\n", "```" ] } @@ -459,7 +460,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/usage/basics.ipynb b/docs/usage/basics.ipynb index c58ed581..b407ac8b 100644 --- a/docs/usage/basics.ipynb +++ b/docs/usage/basics.ipynb @@ -1076,7 +1076,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/usage/faster-lambdify.ipynb b/docs/usage/faster-lambdify.ipynb new file mode 100644 index 00000000..744349dc --- /dev/null +++ b/docs/usage/faster-lambdify.ipynb @@ -0,0 +1,370 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "slideshow": { + "slide_type": "skip" + }, + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "%%capture\n", + "%config Completer.use_jedi = False\n", + "%config InlineBackend.figure_formats = ['svg']\n", + "import os\n", + "\n", + "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)\n", + "\n", 
+ "# Install on Google Colab\n", + "import subprocess\n", + "import sys\n", + "\n", + "from IPython import get_ipython\n", + "\n", + "install_packages = \"google.colab\" in str(get_ipython())\n", + "if install_packages:\n", + " for package in [\"tensorwaves[doc]\", \"graphviz\"]:\n", + " subprocess.check_call(\n", + " [sys.executable, \"-m\", \"pip\", \"install\", package]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Speed up lambdifying" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "import logging\n", + "\n", + "import ampform\n", + "import graphviz\n", + "import qrules\n", + "import sympy as sp\n", + "from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff\n", + "from IPython.display import HTML, SVG\n", + "\n", + "from tensorwaves.model import (\n", + " LambdifiedFunction,\n", + " SympyModel,\n", + " optimized_lambdify,\n", + " split_expression,\n", + ")\n", + "\n", + "logger = logging.getLogger()\n", + "logger.setLevel(logging.ERROR)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Split expression" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lambdifying a SymPy expression can take rather long when an expression is complicated. Fortunately, TensorWaves offers a way to speed up the lambdify process. The idea is to split up an expression into sub-expressions, lambdify those separately, and then recombine them. 
Let's illustrate that idea with the following simplified example:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x, y, z = sp.symbols(\"x:z\")\n", + "expr = x ** z + 2 * y + sp.log(y * z)\n", + "expr" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This expression can be represented as a tree of mathematical operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "dot = sp.dotprint(expr)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The function {func}`.split_expression` can now be used to split up this expression tree into a 'top expression' plus definitions for each of the sub-expressions into which it was split:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "top_expr, sub_expressions = split_expression(expr, max_complexity=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "top_expr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sub_expressions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The original expression can easily be reconstructed with {meth}`~sympy.core.basic.Basic.doit` or {meth}`~sympy.core.basic.Basic.xreplace`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "top_expr.xreplace(sub_expressions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each of the expression trees is now smaller than the original:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + 
"tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "dot = sp.dotprint(top_expr)\n", + "graphviz.Source(dot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "for symbol, definition in sub_expressions.items():\n", + " dot = sp.dotprint(definition)\n", + " graph = graphviz.Source(dot)\n", + " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n", + "\n", + "html = \"\\n\"\n", + "html += \" \\n\"\n", + "html += \"\".join(\n", + " f' \\n'\n", + " for symbol in sub_expressions\n", + ")\n", + "html += \" \\n\"\n", + "html += \" \\n\"\n", + "for symbol in sub_expressions:\n", + " svg = SVG(f\"sub_expr_{symbol.name}.svg\").data\n", + " html += f' \\n'\n", + "html += \" \\n\"\n", + "html += \"
{symbol.name}
{svg}
\"\n", + "HTML(html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimized lambdify" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Generally, the lambdify time scales exponentially with the size of an expression tree. With larger expression trees, it's therefore much faster to lambdify these sub-expressions separately and to recombine them. TensorWaves offers a function that does this for you: {func}`.optimized_lambdify`. We'll use a {class}`~ampform.helicity.HelicityModel` to illustrate this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = qrules.generate_transitions(\n", + " initial_state=(\"J/psi(1S)\", [+1]),\n", + " final_state=[\"gamma\", \"pi0\", \"pi0\"],\n", + " allowed_intermediate_particles=[\"f(0)\"],\n", + ")\n", + "model_builder = ampform.get_builder(result)\n", + "for name in result.get_intermediate_particles().names:\n", + " model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)\n", + "model = model_builder.generate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "expression = model.expression.doit()\n", + "sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "lambdified_optimized = optimized_lambdify(\n", + " sorted_symbols,\n", + " expression,\n", + " max_complexity=100,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%time\n", + "sp.lambdify(sorted_symbols, expression)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Specifying complexity" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the usual workflow (see 
{doc}`/usage`), TensorWaves uses SymPy's own {func}`~sympy.utilities.lambdify.lambdify` by default. You can change this behavior with the `max_complexity` argument of {class}`.SympyModel`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sympy_model = SympyModel(\n", + " expression=model.expression,\n", + " parameters=model.parameter_defaults,\n", + " max_complexity=100,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If `max_complexity` is specified (i.e., is not {obj}`None`), {class}`.LambdifiedFunction` uses TensorWaves's {func}`.optimized_lambdify`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "intensity = LambdifiedFunction(sympy_model, backend=\"jax\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/usage/step1.ipynb b/docs/usage/step1.ipynb index 5c596ca5..4a56e7b5 100644 --- a/docs/usage/step1.ipynb +++ b/docs/usage/step1.ipynb @@ -216,7 +216,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/usage/step2.ipynb b/docs/usage/step2.ipynb index 4ef1278a..f89b93d0 100644 --- a/docs/usage/step2.ipynb +++ b/docs/usage/step2.ipynb @@ -144,19 +144,31 @@ "The {obj}`~ampform.helicity.HelicityModel` was expressed in terms of {mod}`sympy`, so we express the model as a {class}`.SympyModel` and lambdify it to a {class}`.LambdifiedFunction`:" ] }, + { + "cell_type": "markdown", + 
"metadata": {}, + "source": [ + ":::{margin}\n", + "\n", + "Here, we make use of {func}`.optimized_lambdify` by specifying `max_complexity`. See {ref}`usage/faster-lambdify:Specifying complexity`.\n", + "\n", + ":::" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from tensorwaves.model import LambdifiedFunction, SympyModel\n", + "from tensorwaves.model import LambdifiedFunction, SympyModel # noqa: F401\n", "\n", "sympy_model = SympyModel(\n", " expression=model.expression,\n", " parameters=model.parameter_defaults,\n", + " max_complexity=200,\n", ")\n", - "intensity = LambdifiedFunction(sympy_model, backend=\"numpy\")" + "%time intensity = LambdifiedFunction(sympy_model, backend=\"numpy\")" ] }, { @@ -393,7 +405,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/docs/usage/step3.ipynb b/docs/usage/step3.ipynb index 203e0c4b..2ca8adef 100644 --- a/docs/usage/step3.ipynb +++ b/docs/usage/step3.ipynb @@ -926,7 +926,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/pytest.ini b/pytest.ini index eceb4f46..d6ca2f1d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -16,6 +16,7 @@ filterwarnings = nb_diff_ignore = /cells/*/execution_count /cells/*/outputs + /metadata/language_info/version /metadata/widgets norecursedirs = _build diff --git a/src/tensorwaves/model.py b/src/tensorwaves/model.py index 4b47a54b..26ccd2d8 100644 --- a/src/tensorwaves/model.py +++ b/src/tensorwaves/model.py @@ -4,13 +4,24 @@ computational backends. Currently, mathematical expressions are implemented as `sympy` expressions only. 
""" -# cspell: ignore xreplace + import copy import logging -from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) import numpy as np import sympy as sp +from tqdm.auto import tqdm from tensorwaves.interfaces import DataSample, Function, Model @@ -45,20 +56,128 @@ def get_backend_modules( return backend +def split_expression( + expression: sp.Expr, + max_complexity: int, + min_complexity: int = 0, +) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]: + """Split an expression into a 'top expression' and several sub-expressions. + + Replace nodes in the expression tree of a `sympy.Expr + ` that lie within a certain complexity range (see + :meth:`~sympy.core.basic.Basic.count_ops`) with symbols and keep a mapping + of each to these symbols to the sub-expressions that they replaced. + + .. seealso:: :doc:`/usage/faster-lambdify` + """ + i = 0 + symbol_mapping: Dict[sp.Symbol, sp.Expr] = {} + n_operations = sp.count_ops(expression) + if n_operations < max_complexity: + return expression, symbol_mapping + progress_bar = tqdm( + total=n_operations, + desc="Splitting expression", + unit="node", + disable=logging.getLogger().level > logging.WARNING, + ) + + def recursive_split(sub_expression: sp.Expr) -> sp.Expr: + nonlocal i + for arg in sub_expression.args: + complexity = sp.count_ops(arg) + if min_complexity < complexity < max_complexity: + progress_bar.update(n=complexity) + symbol = sp.Symbol(f"f{i}") + i += 1 + symbol_mapping[symbol] = arg + sub_expression = sub_expression.xreplace({arg: symbol}) + else: + new_arg = recursive_split(arg) + sub_expression = sub_expression.xreplace({arg: new_arg}) + return sub_expression + + top_expression = recursive_split(expression) + remainder = progress_bar.total - progress_bar.n + progress_bar.update(n=remainder) # pylint crashes if total is set directly + progress_bar.close() + return top_expression, 
symbol_mapping + + +def optimized_lambdify( + args: Sequence[sp.Symbol], + expression: sp.Expr, + modules: Optional[Union[str, tuple, dict]] = None, + *, + min_complexity: int = 0, + max_complexity: int, +) -> Callable: + """Speed up `~sympy.utilities.lambdify.lambdify` with `.split_expression`. + + .. seealso:: :doc:`/usage/faster-lambdify` + """ + top_expression, definitions = split_expression( + expression, + min_complexity=min_complexity, + max_complexity=max_complexity, + ) + top_symbols = sorted(definitions, key=lambda s: s.name) + top_lambdified = sp.lambdify(top_symbols, top_expression, modules) + sub_lambdified = [ # same order as positional arguments in top_lambdified + sp.lambdify(args, definitions[symbol], modules) + for symbol in tqdm( + iterable=top_symbols, + desc="Lambdifying sub-expressions", + unit="expr", + disable=logging.getLogger().level > logging.WARNING, + ) + ] + + def recombined_function(*args): # type: ignore + new_args = [sub_expr(*args) for sub_expr in sub_lambdified] + return top_lambdified(*new_args) + + return recombined_function + + def _sympy_lambdify( + ordered_symbols: Sequence[sp.Symbol], + expression: sp.Expr, + modules: Union[str, tuple, dict], + *, + max_complexity: Optional[int] = None, +) -> Callable: + if max_complexity is None: + return sp.lambdify( + ordered_symbols, + expression, + modules=modules, + ) + return optimized_lambdify( + ordered_symbols, + expression, + modules=modules, + max_complexity=max_complexity, + ) + + +def _backend_lambdify( ordered_symbols: Tuple[sp.Symbol, ...], expression: sp.Expr, backend: Union[str, tuple, dict], + *, + max_complexity: Optional[int] = None, ) -> Callable: # pylint: disable=import-outside-toplevel,too-many-return-statements def jax_lambdify() -> Callable: import jax return jax.jit( - sp.lambdify( + _sympy_lambdify( ordered_symbols, expression, modules=backend_modules, + max_complexity=max_complexity, ) ) @@ -67,10 +186,11 @@ def numba_lambdify() -> Callable: import numba return 
numba.jit( - sp.lambdify( + _sympy_lambdify( ordered_symbols, expression, modules="numpy", + max_complexity=max_complexity, ), forceobj=True, parallel=True, @@ -80,10 +200,11 @@ def tensorflow_lambdify() -> Callable: # pylint: disable=import-error import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false - return sp.lambdify( + return _sympy_lambdify( ordered_symbols, expression, modules=tnp, + max_complexity=max_complexity, ) backend_modules = get_backend_modules(backend) @@ -103,16 +224,19 @@ def tensorflow_lambdify() -> Callable: "tf" in x.__name__ for x in backend ): return tensorflow_lambdify() - return sp.lambdify( + return _sympy_lambdify( ordered_symbols, expression, modules=backend_modules, + max_complexity=max_complexity, ) class LambdifiedFunction(Function): def __init__( - self, model: Model, backend: Union[str, tuple, dict] = "numpy" + self, + model: Model, + backend: Union[str, tuple, dict] = "numpy", ) -> None: """Implements `.Function` based on a `.Model` using `~Model.lambdify`.""" self.__lambdified_model = model.lambdify(backend=backend) @@ -212,15 +336,17 @@ def __replace_constant_sub_expressions( def lambdify(self, backend: Union[str, tuple, dict]) -> Callable: input_symbols = tuple(self.__expression.free_symbols) - lambdified_model = _sympy_lambdify( + lambdified_model = _backend_lambdify( input_symbols, self.__expression, backend=backend, ) constant_input_storage = {} for placeholder, sub_expr in self.__constant_sub_expressions.items(): - temp_lambdify = _sympy_lambdify( - tuple(sub_expr.free_symbols), sub_expr, backend + temp_lambdify = _backend_lambdify( + tuple(sub_expr.free_symbols), + sub_expr, + backend=backend, ) free_symbol_names = {x.name for x in sub_expr.free_symbols} constant_input_storage[placeholder.name] = temp_lambdify( @@ -297,6 +423,7 @@ def __init__( self, expression: sp.Expr, parameters: Dict[sp.Symbol, Union[float, complex]], + max_complexity: Optional[int] = None, ) -> None: if not all(map(lambda 
p: isinstance(p, sp.Symbol), parameters)): raise TypeError(f"Not all parameters are of type {sp.Symbol}") @@ -308,7 +435,9 @@ def __init__( " in the model!" ) + logging.info("Performing .doit() on input expression...") self.__expression: sp.Expr = expression.doit() + logging.info("done") # after .doit() certain symbols like the meson radius can disappear # hence the parameters have to be shrunk to this space self.__parameters = { @@ -324,11 +453,15 @@ def __init__( self.__argument_order = tuple(self.__variables) + tuple( self.__parameters ) + self.max_complexity = max_complexity def lambdify(self, backend: Union[str, tuple, dict]) -> Callable: """Lambdify the model using `~sympy.utilities.lambdify.lambdify`.""" - return _sympy_lambdify( - self.__argument_order, self.__expression, backend + return _backend_lambdify( + self.__argument_order, + self.__expression, + backend=backend, + max_complexity=self.max_complexity, ) def performance_optimize(self, fix_inputs: DataSample) -> "Model":