diff --git a/.cspell.json b/.cspell.json
index 37f2a6cc..48422d67 100644
--- a/.cspell.json
+++ b/.cspell.json
@@ -125,6 +125,7 @@
"darkred",
"doctest",
"doctests",
+ "dotprint",
"dtype",
"eval",
"evalf",
@@ -176,6 +177,7 @@
"precommit",
"prefactor",
"pwa's",
+ "py's",
"pygments",
"pypi",
"pyplot",
@@ -199,8 +201,10 @@
"unnormalized",
"unsubscriptable",
"vstack",
+ "waves's",
"xlabel",
"xlim",
+ "xreplace",
"ylabel",
"ylim",
"yticks"
diff --git a/docs/.gitignore b/docs/.gitignore
index 10288328..627ea8d0 100644
--- a/docs/.gitignore
+++ b/docs/.gitignore
@@ -3,3 +3,4 @@
*.inv
*build/
api/
+sub_expr_*
diff --git a/docs/usage.ipynb b/docs/usage.ipynb
index 4dd6718d..e0e09d50 100644
--- a/docs/usage.ipynb
+++ b/docs/usage.ipynb
@@ -439,6 +439,7 @@
"usage/step2\n",
"usage/step3\n",
"usage/basics\n",
+ "usage/faster-lambdify\n",
"```"
]
}
@@ -459,7 +460,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.8"
+ "version": "3.8.10"
}
},
"nbformat": 4,
diff --git a/docs/usage/basics.ipynb b/docs/usage/basics.ipynb
index c58ed581..b407ac8b 100644
--- a/docs/usage/basics.ipynb
+++ b/docs/usage/basics.ipynb
@@ -1076,7 +1076,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.8"
+ "version": "3.8.10"
}
},
"nbformat": 4,
diff --git a/docs/usage/faster-lambdify.ipynb b/docs/usage/faster-lambdify.ipynb
new file mode 100644
index 00000000..744349dc
--- /dev/null
+++ b/docs/usage/faster-lambdify.ipynb
@@ -0,0 +1,370 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "jupyter": {
+ "source_hidden": true
+ },
+ "slideshow": {
+ "slide_type": "skip"
+ },
+ "tags": [
+ "remove-cell"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "%config Completer.use_jedi = False\n",
+ "%config InlineBackend.figure_formats = ['svg']\n",
+ "import os\n",
+ "\n",
+ "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)\n",
+ "\n",
+ "# Install on Google Colab\n",
+ "import subprocess\n",
+ "import sys\n",
+ "\n",
+ "from IPython import get_ipython\n",
+ "\n",
+ "install_packages = \"google.colab\" in str(get_ipython())\n",
+ "if install_packages:\n",
+ " for package in [\"tensorwaves[doc]\", \"graphviz\"]:\n",
+ " subprocess.check_call(\n",
+ " [sys.executable, \"-m\", \"pip\", \"install\", package]\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Speed up lambdifying"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "jupyter": {
+ "source_hidden": true
+ },
+ "tags": [
+ "hide-cell"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "import logging\n",
+ "\n",
+ "import ampform\n",
+ "import graphviz\n",
+ "import qrules\n",
+ "import sympy as sp\n",
+ "from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff\n",
+ "from IPython.display import HTML, SVG\n",
+ "\n",
+ "from tensorwaves.model import (\n",
+ " LambdifiedFunction,\n",
+ " SympyModel,\n",
+ " optimized_lambdify,\n",
+ " split_expression,\n",
+ ")\n",
+ "\n",
+ "logger = logging.getLogger()\n",
+ "logger.setLevel(logging.ERROR)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Split expression"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Lambdifying a SymPy expression can take rather long when an expression is complicated. Fortunately, TensorWaves offers a way to speed up the lambdify process. The idea is to split up an expression into sub-expressions, lambdify those separately, and then recombine them. Let's illustrate that idea with the following simplified example:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x, y, z = sp.symbols(\"x:z\")\n",
+ "expr = x ** z + 2 * y + sp.log(y * z)\n",
+ "expr"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This expression can be represented in a tree of mathematical operations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "jupyter": {
+ "source_hidden": true
+ },
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "dot = sp.dotprint(expr)\n",
+ "graphviz.Source(dot)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The function {func}`.split_expression` can now be used to split up this expression tree into a 'top expression' plus definitions for each of the sub-expressions into which it was split:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "top_expr, sub_expressions = split_expression(expr, max_complexity=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "top_expr"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sub_expressions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The original expression can easily be reconstructed with {meth}`~sympy.core.basic.Basic.doit` or {meth}`~sympy.core.basic.Basic.xreplace`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "top_expr.xreplace(sub_expressions)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Each of the expression trees is now smaller than the original:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "jupyter": {
+ "source_hidden": true
+ },
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "dot = sp.dotprint(top_expr)\n",
+ "graphviz.Source(dot)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "jupyter": {
+ "source_hidden": true
+ },
+ "tags": [
+ "hide-input"
+ ]
+ },
+ "outputs": [],
+ "source": [
+ "for symbol, definition in sub_expressions.items():\n",
+ " dot = sp.dotprint(definition)\n",
+ " graph = graphviz.Source(dot)\n",
+ " graph.render(filename=f\"sub_expr_{symbol.name}\", format=\"svg\")\n",
+ "\n",
+ "html = \"<table>\\n\"\n",
+ "html += \"  <tr>\\n\"\n",
+ "html += \"\".join(\n",
+ "    f'    <th>{symbol.name}</th>\\n'\n",
+ "    for symbol in sub_expressions\n",
+ ")\n",
+ "html += \"  </tr>\\n\"\n",
+ "html += \"  <tr>\\n\"\n",
+ "for symbol in sub_expressions:\n",
+ "    svg = SVG(f\"sub_expr_{symbol.name}.svg\").data\n",
+ "    html += f'    <td>{svg}</td>\\n'\n",
+ "html += \"  </tr>\\n\"\n",
+ "html += \"</table>\"\n",
+ "HTML(html)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Optimized lambdify"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generally, the lambdify time scales exponentially with the size of an expression tree. With larger expression trees, it's therefore much faster to lambdify these sub-expressions separately and to recombine them. TensorWaves offers a function that does this for you: {func}`.optimized_lambdify`. We'll use an {class}`~ampform.helicity.HelicityModel` to illustrate this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "result = qrules.generate_transitions(\n",
+ " initial_state=(\"J/psi(1S)\", [+1]),\n",
+ " final_state=[\"gamma\", \"pi0\", \"pi0\"],\n",
+ " allowed_intermediate_particles=[\"f(0)\"],\n",
+ ")\n",
+ "model_builder = ampform.get_builder(result)\n",
+ "for name in result.get_intermediate_particles().names:\n",
+ " model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)\n",
+ "model = model_builder.generate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "expression = model.expression.doit()\n",
+ "sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "lambdified_optimized = optimized_lambdify(\n",
+ " sorted_symbols,\n",
+ " expression,\n",
+ " max_complexity=100,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "sp.lambdify(sorted_symbols, expression)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Specifying complexity"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In the usual workflow (see {doc}`/usage`), TensorWaves uses SymPy's own {func}`~sympy.utilities.lambdify.lambdify` by default. You can change this behavior with the `max_complexity` argument of {class}`.SympyModel`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sympy_model = SympyModel(\n",
+ " expression=model.expression,\n",
+ " parameters=model.parameter_defaults,\n",
+ " max_complexity=100,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If `max_complexity` is specified (i.e., is not {obj}`None`), {class}`.LambdifiedFunction` uses TensorWaves's {func}`.optimized_lambdify`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%time\n",
+ "intensity = LambdifiedFunction(sympy_model, backend=\"jax\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/usage/step1.ipynb b/docs/usage/step1.ipynb
index 5c596ca5..4a56e7b5 100644
--- a/docs/usage/step1.ipynb
+++ b/docs/usage/step1.ipynb
@@ -216,7 +216,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.8"
+ "version": "3.8.10"
}
},
"nbformat": 4,
diff --git a/docs/usage/step2.ipynb b/docs/usage/step2.ipynb
index 4ef1278a..f89b93d0 100644
--- a/docs/usage/step2.ipynb
+++ b/docs/usage/step2.ipynb
@@ -144,19 +144,31 @@
"The {obj}`~ampform.helicity.HelicityModel` was expressed in terms of {mod}`sympy`, so we express the model as a {class}`.SympyModel` and lambdify it to a {class}`.LambdifiedFunction`:"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ ":::{margin}\n",
+ "\n",
+ "Here, we make use of {func}`.optimized_lambdify` by specifying `max_complexity`. See {ref}`usage/faster-lambdify:Specifying complexity`.\n",
+ "\n",
+ ":::"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "from tensorwaves.model import LambdifiedFunction, SympyModel\n",
+ "from tensorwaves.model import LambdifiedFunction, SympyModel # noqa: F401\n",
"\n",
"sympy_model = SympyModel(\n",
" expression=model.expression,\n",
" parameters=model.parameter_defaults,\n",
+ " max_complexity=200,\n",
")\n",
- "intensity = LambdifiedFunction(sympy_model, backend=\"numpy\")"
+ "%time intensity = LambdifiedFunction(sympy_model, backend=\"numpy\")"
]
},
{
@@ -393,7 +405,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.8"
+ "version": "3.8.10"
}
},
"nbformat": 4,
diff --git a/docs/usage/step3.ipynb b/docs/usage/step3.ipynb
index 203e0c4b..2ca8adef 100644
--- a/docs/usage/step3.ipynb
+++ b/docs/usage/step3.ipynb
@@ -926,7 +926,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.8"
+ "version": "3.8.10"
}
},
"nbformat": 4,
diff --git a/pytest.ini b/pytest.ini
index eceb4f46..d6ca2f1d 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -16,6 +16,7 @@ filterwarnings =
nb_diff_ignore =
/cells/*/execution_count
/cells/*/outputs
+ /metadata/language_info/version
/metadata/widgets
norecursedirs =
_build
diff --git a/src/tensorwaves/model.py b/src/tensorwaves/model.py
index 4b47a54b..26ccd2d8 100644
--- a/src/tensorwaves/model.py
+++ b/src/tensorwaves/model.py
@@ -4,13 +4,24 @@
computational backends. Currently, mathematical expressions are implemented
as `sympy` expressions only.
"""
-# cspell: ignore xreplace
+
import copy
import logging
-from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ FrozenSet,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
import numpy as np
import sympy as sp
+from tqdm.auto import tqdm
from tensorwaves.interfaces import DataSample, Function, Model
@@ -45,20 +56,128 @@ def get_backend_modules(
return backend
+def split_expression(
+    expression: sp.Expr,
+    max_complexity: int,
+    min_complexity: int = 0,
+) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]:
+    """Split an expression into a 'top expression' and several sub-expressions.
+
+    Replace nodes in the expression tree of a `sympy.Expr
+    <sympy.core.expr.Expr>` that lie within a certain complexity range (see
+    :meth:`~sympy.core.basic.Basic.count_ops`) with symbols and keep a mapping
+    of each of these symbols to the sub-expression that it replaced.
+
+    A node is replaced if its operation count lies strictly between
+    ``min_complexity`` and ``max_complexity``. Nodes outside that range are
+    searched recursively. If the complete expression has fewer operations than
+    ``max_complexity``, it is returned unchanged, together with an empty
+    mapping. Identical sub-expressions share one symbol, so each of them
+    appears only once in the mapping.
+
+    The returned symbols are named ``f0``, ``f1``, etc. The original
+    expression can be recovered by calling
+    :meth:`~sympy.core.basic.Basic.xreplace` on the top expression with the
+    returned mapping.
+
+    .. seealso:: :doc:`/usage/faster-lambdify`
+    """
+    symbol_mapping: Dict[sp.Symbol, sp.Expr] = {}
+    n_operations = sp.count_ops(expression)
+    if n_operations < max_complexity:
+        return expression, symbol_mapping
+    # Reverse lookup, so that a sub-expression that occurs more than once is
+    # not assigned a second (redundant) symbol.
+    known_symbols: Dict[sp.Expr, sp.Symbol] = {}
+
+    def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
+        for arg in sub_expression.args:
+            complexity = sp.count_ops(arg)
+            if min_complexity < complexity < max_complexity:
+                progress_bar.update(n=complexity)
+                symbol = known_symbols.get(arg)
+                if symbol is None:
+                    # one mapping entry per symbol, so this counts f0, f1, ...
+                    symbol = sp.Symbol(f"f{len(symbol_mapping)}")
+                    known_symbols[arg] = symbol
+                    symbol_mapping[symbol] = arg
+                replacement = symbol
+            else:
+                replacement = recursive_split(arg)
+            sub_expression = sub_expression.xreplace({arg: replacement})
+        return sub_expression
+
+    # The context manager closes the progress bar even if splitting fails.
+    with tqdm(
+        total=n_operations,
+        desc="Splitting expression",
+        unit="node",
+        disable=logging.getLogger().level > logging.WARNING,
+    ) as progress_bar:
+        top_expression = recursive_split(expression)
+        remainder = progress_bar.total - progress_bar.n
+        progress_bar.update(n=remainder)  # pylint crashes if total is set directly
+    return top_expression, symbol_mapping
+
+
+def optimized_lambdify(
+    args: Sequence[sp.Symbol],
+    expression: sp.Expr,
+    modules: Optional[Union[str, tuple, dict]] = None,
+    *,
+    min_complexity: int = 0,
+    max_complexity: int,
+) -> Callable:
+    """Speed up `~sympy.utilities.lambdify.lambdify` with `.split_expression`.
+
+    The expression is split with `.split_expression`, each sub-expression is
+    lambdified separately with :code:`args` as its positional arguments, and
+    the results are fed into the lambdified 'top expression'. The returned
+    function takes its positional arguments in the same order as
+    :code:`args`.
+
+    If no sub-expression was split off, this falls back to a plain
+    `~sympy.utilities.lambdify.lambdify` call. Symbols from :code:`args` that
+    remain in the top expression are passed through to it.
+
+    .. seealso:: :doc:`/usage/faster-lambdify`
+    """
+    top_expression, definitions = split_expression(
+        expression,
+        min_complexity=min_complexity,
+        max_complexity=max_complexity,
+    )
+    if not definitions:
+        # Nothing was split off, so there is nothing to recombine. Lambdifying
+        # the top expression over an empty argument list would produce a
+        # function that refers to undefined names.
+        return sp.lambdify(args, expression, modules)
+    top_symbols = sorted(definitions, key=lambda s: s.name)
+    # Input symbols can survive in the top expression (e.g. a bare symbol that
+    # is a direct argument of a node that was too complex to be replaced).
+    # They have to be arguments of the top function as well.
+    leftover_symbols = top_expression.free_symbols.difference(definitions)
+    passthrough = [
+        (position, symbol)
+        for position, symbol in enumerate(args)
+        if symbol in leftover_symbols
+    ]
+    top_lambdified = sp.lambdify(
+        [*top_symbols, *(symbol for _, symbol in passthrough)],
+        top_expression,
+        modules,
+    )
+    sub_lambdified = [  # same order as positional arguments in top_lambdified
+        sp.lambdify(args, definitions[symbol], modules)
+        for symbol in tqdm(
+            iterable=top_symbols,
+            desc="Lambdifying sub-expressions",
+            unit="expr",
+            disable=logging.getLogger().level > logging.WARNING,
+        )
+    ]
+
+    def recombined_function(*values: Any) -> Any:
+        new_args = [sub_expr(*values) for sub_expr in sub_lambdified]
+        new_args.extend(values[position] for position, _ in passthrough)
+        return top_lambdified(*new_args)
+
+    return recombined_function
+
+
def _sympy_lambdify(
+    ordered_symbols: Sequence[sp.Symbol],
+    expression: sp.Expr,
+    modules: Union[str, tuple, dict],
+    *,
+    max_complexity: Optional[int] = None,
+) -> Callable:
+    """Lambdify an expression, optionally by splitting it up first.
+
+    Dispatches to `.optimized_lambdify` if :code:`max_complexity` is given
+    and to `~sympy.utilities.lambdify.lambdify` if it is `None`.
+    """
+    if max_complexity is not None:
+        return optimized_lambdify(
+            ordered_symbols,
+            expression,
+            modules=modules,
+            max_complexity=max_complexity,
+        )
+    return sp.lambdify(ordered_symbols, expression, modules=modules)
+
+
+def _backend_lambdify(
ordered_symbols: Tuple[sp.Symbol, ...],
expression: sp.Expr,
backend: Union[str, tuple, dict],
+ *,
+ max_complexity: Optional[int] = None,
) -> Callable:
# pylint: disable=import-outside-toplevel,too-many-return-statements
def jax_lambdify() -> Callable:
import jax
return jax.jit(
- sp.lambdify(
+ _sympy_lambdify(
ordered_symbols,
expression,
modules=backend_modules,
+ max_complexity=max_complexity,
)
)
@@ -67,10 +186,11 @@ def numba_lambdify() -> Callable:
import numba
return numba.jit(
- sp.lambdify(
+ _sympy_lambdify(
ordered_symbols,
expression,
modules="numpy",
+ max_complexity=max_complexity,
),
forceobj=True,
parallel=True,
@@ -80,10 +200,11 @@ def tensorflow_lambdify() -> Callable:
# pylint: disable=import-error
import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false
- return sp.lambdify(
+ return _sympy_lambdify(
ordered_symbols,
expression,
modules=tnp,
+ max_complexity=max_complexity,
)
backend_modules = get_backend_modules(backend)
@@ -103,16 +224,19 @@ def tensorflow_lambdify() -> Callable:
"tf" in x.__name__ for x in backend
):
return tensorflow_lambdify()
- return sp.lambdify(
+ return _sympy_lambdify(
ordered_symbols,
expression,
modules=backend_modules,
+ max_complexity=max_complexity,
)
class LambdifiedFunction(Function):
def __init__(
- self, model: Model, backend: Union[str, tuple, dict] = "numpy"
+ self,
+ model: Model,
+ backend: Union[str, tuple, dict] = "numpy",
) -> None:
"""Implements `.Function` based on a `.Model` using `~Model.lambdify`."""
self.__lambdified_model = model.lambdify(backend=backend)
@@ -212,15 +336,17 @@ def __replace_constant_sub_expressions(
def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
input_symbols = tuple(self.__expression.free_symbols)
- lambdified_model = _sympy_lambdify(
+ lambdified_model = _backend_lambdify(
input_symbols,
self.__expression,
backend=backend,
)
constant_input_storage = {}
for placeholder, sub_expr in self.__constant_sub_expressions.items():
- temp_lambdify = _sympy_lambdify(
- tuple(sub_expr.free_symbols), sub_expr, backend
+ temp_lambdify = _backend_lambdify(
+ tuple(sub_expr.free_symbols),
+ sub_expr,
+ backend=backend,
)
free_symbol_names = {x.name for x in sub_expr.free_symbols}
constant_input_storage[placeholder.name] = temp_lambdify(
@@ -297,6 +423,7 @@ def __init__(
self,
expression: sp.Expr,
parameters: Dict[sp.Symbol, Union[float, complex]],
+ max_complexity: Optional[int] = None,
) -> None:
if not all(map(lambda p: isinstance(p, sp.Symbol), parameters)):
raise TypeError(f"Not all parameters are of type {sp.Symbol}")
@@ -308,7 +435,9 @@ def __init__(
" in the model!"
)
+ logging.info("Performing .doit() on input expression...")
self.__expression: sp.Expr = expression.doit()
+ logging.info("done")
# after .doit() certain symbols like the meson radius can disappear
# hence the parameters have to be shrunk to this space
self.__parameters = {
@@ -324,11 +453,15 @@ def __init__(
self.__argument_order = tuple(self.__variables) + tuple(
self.__parameters
)
+ self.max_complexity = max_complexity
def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
"""Lambdify the model using `~sympy.utilities.lambdify.lambdify`."""
- return _sympy_lambdify(
- self.__argument_order, self.__expression, backend
+ return _backend_lambdify(
+ self.__argument_order,
+ self.__expression,
+ backend=backend,
+ max_complexity=self.max_complexity,
)
def performance_optimize(self, fix_inputs: DataSample) -> "Model":