```{autolink-concat}
```

::::{margin}
:::{card} Faster lambdification by splitting expressions
TR-002
^^^
This notebook investigates how to speed up {func}`sympy.lambdify <sympy.utilities.lambdify.lambdify>` by splitting up the expression tree of a complicated expression into components, lambdifying those, and then combining them back again.
+++
✅&nbsp;[tensorwaves#281](https://github.com/ComPWA/tensorwaves/issues/281)
:::
::::

# Speed up lambdifying

<!-- cspell:disable -->

In [None]:
%pip install -q ampform==0.9.1 numpy==1.19.5 pandas==1.2.4 qrules==0.8.2 sympy==1.8 'tensorwaves[jax]==0.2.9' git+https://github.com/zfit/phasespace@7131fbd

In [None]:
%config InlineBackend.figure_formats = ['svg']

from __future__ import annotations

import inspect
import logging
import timeit
import warnings
from collections.abc import Generator, Sequence
from pathlib import Path
from typing import Callable

import ampform
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import qrules
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from tensorwaves.data import generate_phsp
from tensorwaves.data.phasespace import TFUniformRealNumberGenerator
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.model import LambdifiedFunction, SympyModel

LOGGER = logging.getLogger()

## Create dummy expression

First, let's create an amplitude model with {mod}`ampform`. We'll use this model as a complicated {class}`sympy.Expr <sympy.core.expr.Expr>` in the rest of this notebook.

In [None]:
LOGGER.setLevel(logging.ERROR)

In [None]:
result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)(980)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)

In [None]:
output_file = Path("002/f0(980)-graph")
output_file.parent.mkdir(exist_ok=True)
graphviz.Source(dot).render(output_file, format="svg");

![](https://user-images.githubusercontent.com/29308176/164983331-6eb948fe-d360-40bd-a4f7-fa1aad3e296a.svg)

In [None]:
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.generate()

In [None]:
free_symbols = sorted(model.expression.free_symbols, key=lambda s: s.name)
free_symbols

[C[J/\psi(1S) \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980) \to \pi^{0}_{0} \pi^{0}_{0}],
 Gamma_f(0)(980),
 d_f(0)(980),
 m_1,
 m_12,
 m_2,
 m_f(0)(980),
 phi_1+2,
 phi_1,1+2,
 theta_1+2,
 theta_1,1+2]

## Helicity model components

A {class}`~ampform.helicity.HelicityModel` has the benefit that it comes with {attr}`~ampform.helicity.HelicityModel.components` (intensities and amplitudes) that together form its {attr}`~ampform.helicity.HelicityModel.expression`. Let's separate these components into _amplitudes_ and _intensities_.

In [None]:
amplitudes = {
    name: expr for name, expr in model.components.items() if name.startswith("A")
}
sorted(amplitudes)

['A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]']

In [None]:
intensities = {
    name: expr for name, expr in model.components.items() if name.startswith("I")
}

In [None]:
assert len(amplitudes) + len(intensities) == len(model.components)

### Component structure

Note that each intensity consists of a subset of these amplitudes. This means that _intensities have a larger expression tree than amplitudes_.

In [None]:
amplitude_to_symbol = {
    expr: sp.Symbol(f"A{i}") for i, expr in enumerate(amplitudes.values(), 1)
}

In [None]:
intensity_to_symbol = {
    expr: sp.Symbol(f"I{i}") for i, expr in enumerate(intensities.values(), 1)
}

In [None]:
intensity_expr = model.expression.subs(intensity_to_symbol, simultaneous=True)
intensity_expr

I1 + I2 + I3 + I4

In [None]:
dot = sp.dotprint(intensity_expr)
graphviz.Source(dot)

In [None]:
output_file = Path("002/collapsed-expression-tree")
output_file.parent.mkdir(exist_ok=True)
graphviz.Source(dot).render(output_file, format="svg");

![](https://user-images.githubusercontent.com/29308176/164983184-fde89791-2e75-4bd1-9c03-9a45edf24216.svg)

In [None]:
amplitude_expr = model.expression.subs(amplitude_to_symbol, simultaneous=True)
amplitude_expr

Abs(A1 + A2)**2 + Abs(A3 + A4)**2 + Abs(A5 + A6)**2 + Abs(A7 + A8)**2

In [None]:
dot = sp.dotprint(amplitude_expr)
graphviz.Source(dot)

In [None]:
output_file = Path("002/simple-expression-tree")
output_file.parent.mkdir(exist_ok=True)
graphviz.Source(dot).render(output_file, format="svg");

![](https://user-images.githubusercontent.com/29308176/164983978-73d1b6a4-0f09-4a10-88d8-de9d6a055bf3.svg)
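The same effect can be reproduced without AmpForm. Below is a minimal sketch on hypothetical toy amplitudes (Breit-Wigner-like terms made up for illustration, not the expressions that AmpForm generates). It counts the nodes of each expression tree with `sympy.preorder_traversal` and shows that an intensity has a larger tree than the amplitudes it contains.

```python
import sympy as sp

# Hypothetical toy symbols; these are NOT the symbols of the AmpForm model above
s, m, width, phi1, phi2 = sp.symbols("s m Gamma phi1 phi2")


def toy_amplitude(phase: sp.Symbol) -> sp.Expr:
    """Breit-Wigner-like term with a complex phase (illustration only)."""
    return sp.exp(sp.I * phase) / (m**2 - s - sp.I * m * width)


def count_nodes(expr: sp.Expr) -> int:
    """Number of nodes in the expression tree of ``expr``."""
    return sum(1 for _ in sp.preorder_traversal(expr))


amp1 = toy_amplitude(phi1)
amp2 = toy_amplitude(phi2)
# An intensity combines several amplitudes, so its tree contains all of them
intensity = sp.Abs(amp1 + amp2) ** 2

print(count_nodes(amp1), count_nodes(intensity))
```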

### Performance check

Lambdifying the whole {attr}`HelicityModel.expression <ampform.helicity.HelicityModel.expression>` is slowest. The {func}`~sympy.utilities.lambdify.lambdify` function first prints the expression as a {obj}`str` (!) with (in this case) {mod}`numpy` syntax and then uses {func}`eval` to convert that back to actual {mod}`numpy` objects:

:::{margin}

We store the time with {mod}`timeit` for section {ref}`002:Arbitrary expressions`.

:::

In [None]:
runtime = {}
start = timeit.default_timer()

```{autolink-skip}
```

In [None]:
%%time

np_complete_model = sp.lambdify(free_symbols, model.expression.doit(), "numpy")

CPU times: user 1.46 s, sys: 703 µs, total: 1.46 s
Wall time: 1.46 s


In [None]:
stop = timeit.default_timer()
runtime["complete model"] = stop - start
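To get a feeling for the intermediate string, one can lambdify a tiny expression and print the source code that {func}`~sympy.utilities.lambdify.lambdify` generated for it. A minimal sketch, where the expression is an arbitrary example:

```python
import inspect

import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x) * y**2  # arbitrary small example expression

# lambdify prints the expression as Python source code with NumPy function names
func = sp.lambdify([x, y], expr, "numpy")
source = inspect.getsource(func)
print(source)

# The generated function can then be called like any other Python function
print(func(0.0, 3.0))
```

For the complete amplitude model, this generated source is correspondingly much longer.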

Printing to {obj}`str` and converting back with {func}`eval` becomes disproportionately slow as the expression tree grows. It is therefore more efficient to lambdify sub-trees of the expression tree separately. If we lambdify the four _intensities_ of this model separately, however, the effect is not yet noticeable:

In [None]:
LOGGER.setLevel(logging.INFO)

```{autolink-skip}
```

In [None]:
%%time

for expr, symbol in intensity_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start

INFO:root:Lambdifying I1
INFO:root:Lambdifying I2
INFO:root:Lambdifying I3
INFO:root:Lambdifying I4


CPU times: user 1.56 s, sys: 4.94 ms, total: 1.56 s
Wall time: 1.56 s


...but lambdifying each of the eight _amplitudes_ separately does result in a significant speed-up:

```{autolink-skip}
```

In [None]:
%%time

np_amplitudes = {}
for expr, symbol in amplitude_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    np_expr = sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
    np_amplitudes[symbol] = np_expr

INFO:root:Lambdifying A1
INFO:root:Lambdifying A2
INFO:root:Lambdifying A3
INFO:root:Lambdifying A4
INFO:root:Lambdifying A5
INFO:root:Lambdifying A6
INFO:root:Lambdifying A7
INFO:root:Lambdifying A8


CPU times: user 547 ms, sys: 3.85 ms, total: 550 ms
Wall time: 547 ms


### Recombining components

{ref}`Recall <002:Component structure>` what the amplitude model looks like when it is expressed in terms of its amplitude components:

In [None]:
amplitude_expr

Abs(A1 + A2)**2 + Abs(A3 + A4)**2 + Abs(A5 + A6)**2 + Abs(A7 + A8)**2

We have to lambdify that top expression as well:

In [None]:
sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)
np_amplitude_expr = sp.lambdify(sorted_amplitude_symbols, amplitude_expr, "numpy")

In [None]:
source = inspect.getsource(np_amplitude_expr)
print(source)

def _lambdifygenerated(A1, A2, A3, A4, A5, A6, A7, A8):
    return (abs(A1 + A2)**2 + abs(A3 + A4)**2 + abs(A5 + A6)**2 + abs(A7 + A8)**2)



We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions for the amplitude components that are to be plugged into its arguments.

In [None]:
def componentwise_lambdified(*args):
    """Lambdified amplitude model, recombined from its amplitude components.

    .. warning:: Order of the ``args`` has to be the same as that
        of the ``args`` of the lambdified amplitude components.
    """
    amplitude_values = []
    for amp_symbol in sorted_amplitude_symbols:
        np_amplitude = np_amplitudes[amp_symbol]
        values = np_amplitude(*args)
        amplitude_values.append(values)
    return np_amplitude_expr(*amplitude_values)

### Test with data

Okay, so does all this work? Let's first generate a phase space sample with good old {mod}`tensorwaves`. We can then use this sample as input to the component-wise lambdified function.

In [None]:
LOGGER.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

In [None]:
sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_sample = generate_phsp(10_000, model.adapter.reaction_info, random_generator=rng)
phsp_set = data_converter.transform(phsp_sample)

In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(phsp_set["m_12"], bins=50, alpha=0.5, density=True)
ax.hist(
    phsp_set["m_12"],
    bins=50,
    alpha=0.5,
    density=True,
    weights=np.array(intensity(phsp_set)),
)
plt.show()

In [None]:
output_file = Path("002/histogram-m12.svg")
output_file.parent.mkdir(exist_ok=True)
fig.savefig(output_file);

![](https://user-images.githubusercontent.com/29308176/164983924-9ecf9149-af1d-437b-b5f2-4a73a4d1d81b.svg)

The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:

In [None]:
kinematic_variable_names = set(phsp_set)
parameter_names = {symbol.name for symbol in model.parameter_defaults}
free_symbol_names = {symbol.name for symbol in free_symbols}

In [None]:
assert free_symbol_names <= kinematic_variable_names | parameter_names

That allows us to sort the input arrays and parameter defaults so that they can be used as positional argument input to the component-wise lambdified amplitude model:

In [None]:
merged_par_var_values = {
    symbol.name: value for symbol, value in model.parameter_defaults.items()
}
merged_par_var_values.update(phsp_set)
args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]

Finally, here's the result of plugging that back into the component-wise lambdified expression:

In [None]:
componentwise_result = componentwise_lambdified(*args_values)
componentwise_result

array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
       0.00030117])

And it is indeed the same as the intensity computed by {mod}`tensorwaves` (direct lambdify):

In [None]:
tensorwaves_result = np.array(intensity(phsp_set))
mean_difference = (componentwise_result - tensorwaves_result).mean()
mean_difference

-7.307471250984975e-11

In [None]:
assert abs(mean_difference) < 1e-9
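The mean of the signed differences is a weak test, because positive and negative deviations can cancel out. The following minimal sketch uses made-up arrays (not the intensities computed above) to show the effect, together with a stricter, element-wise alternative based on {func}`numpy.testing.assert_allclose`:

```python
import numpy as np

# Made-up arrays for illustration: the deviations cancel in the mean
reference = np.array([1.0, 2.0, 3.0, 4.0])
candidate = reference + np.array([0.5, -0.5, 0.5, -0.5])

mean_difference = (candidate - reference).mean()
max_abs_difference = np.abs(candidate - reference).max()
print(mean_difference, max_abs_difference)  # → 0.0 0.5

# An element-wise comparison with a relative tolerance catches the deviation
try:
    np.testing.assert_allclose(candidate, reference, rtol=1e-7)
except AssertionError:
    print("arrays differ")
```

For the intensities above, the tolerances (`rtol`, `atol`) would have to be chosen to match the expected numerical precision of the two computations.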

## Arbitrary expressions

The problem with {ref}`002:Test with data` is that it requires a {obj}`~ampform.helicity.HelicityModel`. In {mod}`tensorwaves`, we want to work with general {class}`sympy.Expr <sympy.core.expr.Expr>`s though (see [`SympyModel`](https://tensorwaves.readthedocs.io/en/0.2.7/api/tensorwaves.model.html#tensorwaves.model.SympyModel)), where we don't have sub-{attr}`ampform.helicity.HelicityModel.components` available.

Instead, we have to split up the lambdifying in a more general way that can handle arbitrary {class}`sympy.core.expr.Expr`s. For that we need:

1. A general method of traversing through a SymPy expression tree. This can be done with {doc}`sympy:tutorials/intro-tutorial/manipulation`.
2. A **fast** method to estimate the complexity of a model, so that we can decide whether a node in the expression tree is small enough to be lambdified without much runtime. A suitable measure for complexity is {func}`~sympy.core.function.count_ops` ("count operations"), which is also the default `measure` used by SymPy's `simplify()`; see the notes under {doc}`sympy:modules/simplify/simplify`.
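Both ingredients are available in SymPy itself. As a minimal sketch on an arbitrary example expression, `sympy.preorder_traversal` visits every node of the expression tree (parents before their children), and {func}`~sympy.core.function.count_ops` returns the number of operations in the sub-tree below each node:

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
expr = sp.sin(x) * y**2 + z  # arbitrary example expression

# 1. Traverse all nodes of the expression tree
for node in sp.preorder_traversal(expr):
    # 2. Estimate the complexity of the sub-tree below each node
    print(f"{sp.count_ops(node):>2}  {node}")
```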

### Expression complexity

Let's tackle 2. first and use the {attr}`HelicityModel.expression <ampform.helicity.HelicityModel.expression>` and its {attr}`~ampform.helicity.HelicityModel.components` that we lambdified earlier on. Here's an overview of the number of operations versus the time it took to lambdify each component:

In [None]:
df = pd.DataFrame(runtime.values(), index=runtime, columns=["runtime (s)"])
operations = [sp.count_ops(model.expression)]
operations.extend(sp.count_ops(expr) for expr in intensity_to_symbol)
operations.extend(sp.count_ops(expr) for expr in amplitude_to_symbol)
df.insert(0, "operations", operations)
df

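|                | operations | runtime (s) |
| -------------- | ---------: | ----------: |
| complete model |        823 |    0.980456 |
| I1             |        209 |    0.279897 |
| I2             |        203 |    0.235227 |
| I3             |        207 |    0.215937 |
| I4             |        201 |    0.233635 |
| A1             |        103 |      0.0453 |
| A2             |        103 |     0.04071 |
| A3             |        100 |    0.039767 |
| A4             |        100 |    0.035684 |
| A5             |        102 |    0.036551 |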

From this, we can already see that the lambdify runtime scales roughly with the number of SymPy operations.

To better visualize this, we can lambdify the expressions in {class}`~ampform.dynamics.BlattWeisskopfSquared` for each angular momentum and compute their runtime a number of times with {mod}`timeit`. Note that {class}`~ampform.dynamics.BlattWeisskopfSquared` becomes increasingly complex for higher angular momenta.

In [None]:
from ampform.dynamics import BlattWeisskopfSquared

angular_momentum, z = sp.symbols("L z")
BlattWeisskopfSquared(angular_momentum, z).doit()

Piecewise((1, Eq(L, 0)), (2*z/(z + 1), Eq(L, 1)), (13*z**2/(9*z + (z - 3)**2), Eq(L, 2)), (277*z**3/(z*(z - 15)**2 + (2*z - 5)*(18*z - 45)), Eq(L, 3)), (12746*z**4/(25*z*(2*z - 21)**2 + (z**2 - 45*z + 105)**2), Eq(L, 4)), (998881*z**5/(z**5 + 15*z**4 + 315*z**3 + 6300*z**2 + 99225*z + 893025), Eq(L, 5)), (118394977*z**6/(z**6 + 21*z**5 + 630*z**4 + 18900*z**3 + 496125*z**2 + 9823275*z + 108056025), Eq(L, 6)), (19727003738*z**7/(z**7 + 28*z**6 + 1134*z**5 + 47250*z**4 + 1819125*z**3 + 58939650*z**2 + 1404728325*z + 18261468225), Eq(L, 7)), (4392846440677*z**8/(z**8 + 36*z**7 + 1890*z**6 + 103950*z**5 + 5457375*z**4 + 255405150*z**3 + 9833098275*z**2 + 273922023375*z + 4108830350625), Eq(L, 8)))

In [None]:
operations = []
runtime = []
for angular_momentum in range(9):
    ff2 = BlattWeisskopfSquared(angular_momentum, z)
    operations.append(sp.count_ops(ff2.doit()))
    n_iterations = 10
    t = timeit.timeit(
        setup=f"""
import sympy as sp
from ampform.dynamics import BlattWeisskopfSquared
z = sp.Symbol("z")
ff2 = BlattWeisskopfSquared({angular_momentum}, z)
    """,
        stmt='sp.lambdify(z, ff2.doit(), "numpy")',
        number=n_iterations,
    )
    runtime.append(t / n_iterations * 1_000)

In [None]:
df = pd.DataFrame(
    {
        "operations": operations,
        "runtime (ms)": runtime,
    },
)
df

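|   L | operations | runtime (ms) |
| --: | ---------: | -----------: |
|   0 |          0 |      0.81877 |
|   1 |          3 |      1.24712 |
|   2 |          7 |      1.64094 |
|   3 |         12 |      2.52622 |
|   4 |         14 |      2.29422 |
|   5 |         16 |        1.889 |
|   6 |         19 |      2.24741 |
|   7 |         22 |      2.72068 |
|   8 |         25 |      3.01171 |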


In [None]:
fig, ax = plt.subplots(figsize=(8, 4))
plt.scatter(x=df["operations"], y=df["runtime (ms)"])
ax.set_ylim(bottom=0)
ax.set_xlabel("operations")
ax.set_ylabel("runtime (ms)")
plt.show()

In [None]:
output_file = Path("002/runtime-vs-operations.svg")
output_file.parent.mkdir(exist_ok=True)
fig.savefig(output_file);

![](https://user-images.githubusercontent.com/29308176/164983940-3da1d1df-d740-42e4-8a6e-e899c5148034.svg)
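To put a rough number on this trend, one could fit a straight line through the tabulated timings with {func}`numpy.polyfit`. This is a sketch under two assumptions: the values are copied from the table above (so they depend on the machine they were measured on), and the linear model is chosen for illustration, not derived.

```python
import numpy as np

# Values copied from the BlattWeisskopfSquared table (L = 0, ..., 8)
operations = np.array([0, 3, 7, 12, 14, 16, 19, 22, 25])
runtime_ms = np.array(
    [0.81877, 1.24712, 1.64094, 2.52622, 2.29422, 1.889, 2.24741, 2.72068, 3.01171]
)

# Least-squares fit of runtime = slope * operations + intercept (assumed linear model)
slope, intercept = np.polyfit(operations, runtime_ms, deg=1)
print(f"slope: {slope:.3f} ms per operation, intercept: {intercept:.3f} ms")
```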

### Identifying nodes

Now imagine that we don't know anything about the {attr}`~ampform.helicity.HelicityModel.expression` that we created before other than that it is a {class}`sympy.Expr <sympy.core.expr.Expr>`.

#### Approach 1: Generator

A first attempt is to use a generator that recursively identifies components in the expression whose 'complexity' (as computed by {func}`~sympy.core.function.count_ops`) lies within a certain range.

In [None]:
def recurse_tree(
    expression: sp.Expr, *, min_complexity: int = 0, max_complexity: int
) -> Generator[sp.Expr, None, None]:
    for arg in expression.args:
        complexity = sp.count_ops(arg)
        if complexity < max_complexity and complexity > min_complexity:
            yield arg
        else:
            yield from recurse_tree(
                arg,
                min_complexity=min_complexity,
                max_complexity=max_complexity,
            )

We can then use this generator function to create a mapping of these sub-expressions within the expression tree to {class}`~sympy.core.symbol.Symbol`s. That mapping can then be used in {meth}`~sympy.core.basic.Basic.xreplace` to replace the sub-expressions with those symbols.

:::{margin}

The {meth}`~sympy.core.basic.Basic.xreplace` method is **much faster** than {meth}`~sympy.core.basic.Basic.subs`, because it doesn't do any evaluation.

:::
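The difference between the two methods can be seen on a small example with arbitrary symbols: {meth}`~sympy.core.basic.Basic.xreplace` only replaces nodes that literally appear in the expression tree, whereas {meth}`~sympy.core.basic.Basic.subs` also matches mathematically, for instance a subset of the terms of a sum. A minimal sketch:

```python
import sympy as sp

x, y, z, a = sp.symbols("x y z a")
expr = sp.sin(x + y) + x + y + z

# xreplace: exact, structural replacement; only sin(x + y) contains a literal x + y node
print(expr.xreplace({x + y: a}))

# subs: also recognizes x + y as part of the larger sum x + y + z + sin(x + y)
print(expr.subs(x + y, a))
```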

```{autolink-skip}
```

In [None]:
%%time

expression = model.expression.doit()
sub_expressions = {}
for i, expr in enumerate(recurse_tree(expression, max_complexity=100)):
    symbol = sp.Symbol(f"f{i}")
    complexity = sp.count_ops(expr)
    sub_expressions[expr] = symbol
expression.xreplace(sub_expressions)

CPU times: user 314 ms, sys: 135 µs, total: 314 ms
Wall time: 313 ms


Abs(f0 + f1)**2 + Abs(f2 + f3)**2 + Abs(f4 + f5)**2 + Abs(f6 + f7)**2

#### Approach 2: Direct substitution

There is one problem though: {meth}`~sympy.core.basic.Basic.xreplace` is not accurate for larger expressions. It would therefore be better to directly substitute the sub-expression with a symbol while we loop over the nodes in the expression tree. The following function can do that:

In [None]:
def split_expression(
    expression: sp.Expr,
    max_complexity: int,
    min_complexity: int = 0,
) -> tuple[sp.Expr, dict[sp.Symbol, sp.Expr]]:
    i = 0
    symbol_mapping = {}

    def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
        nonlocal i
        for arg in sub_expression.args:
            complexity = sp.count_ops(arg)
            if complexity < max_complexity and complexity > min_complexity:
                symbol = sp.Symbol(f"f{i}")
                i += 1
                symbol_mapping[symbol] = arg
                sub_expression = sub_expression.xreplace({arg: symbol})
            else:
                new_arg = recursive_split(arg)
                sub_expression = sub_expression.xreplace({arg: new_arg})
        return sub_expression

    top_expression = recursive_split(expression)
    return top_expression, symbol_mapping

And indeed, this is **much faster** than {ref}`002:Approach 1: Generator` (it's even possible to parallelize this for loop):

```{autolink-skip}
```

In [None]:
%%time

top_expression, sub_expressions = split_expression(expression, max_complexity=100)

CPU times: user 7 µs, sys: 1 µs, total: 8 µs
Wall time: 15.5 µs


In [None]:
top_expression

Abs(f0 + f1)**2 + Abs(f2 + f3)**2 + Abs(f4 + f5)**2 + Abs(f6 + f7)**2

In [None]:
sub_expressions[sp.Symbol("f0")]

C[J/\psi(1S) \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980) \to \pi^{0}_{0} \pi^{0}_{0}]*Gamma_f(0)(980)*m_f(0)(980)*(cos(theta_1+2)/2 + 1/2)*exp(I*phi_1+2)/(-I*Gamma_f(0)(980)*m_f(0)(980)*sqrt((m_12**2 - (m_1 - m_2)**2)*(m_12**2 - (m_1 + m_2)**2)/m_12**2)*sqrt(m_f(0)(980)**2)/(sqrt((m_f(0)(980)**2 - (m_1 - m_2)**2)*(m_f(0)(980)**2 - (m_1 + m_2)**2)/m_f(0)(980)**2)*Abs(m_12)) - m_12**2 + m_f(0)(980)**2)

### Lambdify and combine

Now that we have the machinery to split up arbitrary expressions by complexity, we need to lambdify the top expression as well as each of the sub-expressions and recombine them. The following function can do that and return a recombined {class}`~typing.Callable`.

In [None]:
def optimized_lambdify(
    args: Sequence[sp.Symbol],
    expr: sp.Expr,
    modules: str | None = None,
    min_complexity: int = 0,
    max_complexity: int = 100,
) -> Callable:
    top_expression, definitions = split_expression(
        expr,
        min_complexity=min_complexity,
        max_complexity=max_complexity,
    )
    top_symbols = sorted(definitions, key=lambda s: s.name)
    top_lambdified = sp.lambdify(top_symbols, top_expression, modules)
    sub_lambdified = [
        sp.lambdify(args, definitions[symbol], modules) for symbol in top_symbols
    ]

    def recombined_function(*args):
        new_args = [sub_expr(*args) for sub_expr in sub_lambdified]
        return top_lambdified(*new_args)

    return recombined_function

We can use the same input values as in {ref}`002:Test with data` to check that the resulting lambdified function produces the same output.

```{autolink-skip}
```

In [None]:
%%time

treewise_lambdified = optimized_lambdify(free_symbols, expression, "numpy")

CPU times: user 8 µs, sys: 1 µs, total: 9 µs
Wall time: 17.4 µs


In [None]:
treewise_result = treewise_lambdified(*args_values)
treewise_result

array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
       0.00030117])

And it is indeed the same as the intensity computed by {mod}`tensorwaves` (direct lambdify):

In [None]:
mean_difference = (treewise_result - tensorwaves_result).mean()
mean_difference

-7.307471274905997e-11

In [None]:
assert abs(mean_difference) < 1e-9
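As a self-contained sanity check that does not depend on AmpForm, the split-lambdify-recombine idea can be tried on a small, made-up expression. The function `split_and_lambdify` below is a hypothetical, compact variant written for this illustration, not the `optimized_lambdify` function defined above: it rebuilds each node with `node.func(*args)` instead of calling {meth}`~sympy.core.basic.Basic.xreplace`, and it assumes that the placeholder names `f0`, `f1`, ... do not clash with symbols in the expression.

```python
import numpy as np
import sympy as sp


def split_and_lambdify(args, expr, max_complexity):
    """Hypothetical compact variant of the splitting idea, for illustration only."""
    definitions = {}  # maps placeholder symbols to the sub-expressions they replace

    def replace_small_nodes(node):
        if not node.args:  # atoms (symbols, numbers) are kept as they are
            return node
        new_args = []
        for arg in node.args:
            if 0 < sp.count_ops(arg) < max_complexity:
                # Assumes that the names f0, f1, ... do not clash with symbols in expr
                symbol = sp.Symbol(f"f{len(definitions)}")
                definitions[symbol] = arg
                new_args.append(symbol)
            else:
                new_args.append(replace_small_nodes(arg))
        return node.func(*new_args)  # rebuild the node from its (new) arguments

    top_expression = replace_small_nodes(expr)
    placeholders = list(definitions)
    top_func = sp.lambdify(placeholders, top_expression, "numpy")
    sub_funcs = [sp.lambdify(args, definitions[s], "numpy") for s in placeholders]

    def recombined(*values):
        # Evaluate each sub-expression, then feed the results into the top expression
        return top_func(*[func(*values) for func in sub_funcs])

    return recombined, top_expression


x, y = sp.symbols("x y")
expr = sp.sin(x) * y**2 + sp.cos(x * y) / (1 + x**2) + sp.sqrt(1 + y**2)
recombined, top_expression = split_and_lambdify((x, y), expr, max_complexity=10)
print(top_expression)

x_values = np.linspace(-2.0, 2.0, 5)
y_values = np.linspace(0.5, 1.5, 5)
direct = sp.lambdify((x, y), expr, "numpy")
print(np.allclose(recombined(x_values, y_values), direct(x_values, y_values)))
```

Like `optimized_lambdify`, this sketch assumes that every free symbol of the top expression is one of the placeholders: if a bare symbol such as `x` survived at the top level, the top-level function would not receive it as an argument.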

## Comparison

Now let's have a look at a slightly more complicated model:

In [None]:
result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [+1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
complex_model = model_builder.generate()
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)

In [None]:
output_file = Path("002/f0-graph")
output_file.parent.mkdir(exist_ok=True)
graphviz.Source(dot).render(output_file, format="svg");

![](https://user-images.githubusercontent.com/29308176/164983377-f7c3c5a3-edfd-49aa-b449-08ee77cda67f.svg)

The following timings make it clear that the functions defined in {ref}`002:Arbitrary expressions` result in a huge speed-up!

In [None]:
new_expression = complex_model.expression.doit()
new_free_symbols = sorted(new_expression.free_symbols, key=lambda s: s.name)

```{autolink-skip} section
```

In [None]:
%%time

np_expr = sp.lambdify(new_free_symbols, new_expression)

CPU times: user 4.57 s, sys: 3.16 ms, total: 4.57 s
Wall time: 4.57 s


In [None]:
%%time

np_expr = optimized_lambdify(new_free_symbols, new_expression)

CPU times: user 261 ms, sys: 87 µs, total: 262 ms
Wall time: 260 ms
