# Speed up lambdifying

This notebook investigates how to speed up {func}`sympy.lambdify <sympy.utilities.lambdify.lambdify>` by splitting up the expression tree of a complicated expression into components, lambdifying those, and then combining them back again.

In [None]:
import inspect
import logging
import timeit
import warnings
from typing import Dict, Generator, Tuple

import ampform
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import qrules as q
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from tensorwaves.data import generate_phsp
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.model import LambdifiedFunction, SympyModel

logger = logging.getLogger()

## Create dummy expression

First, let's create an amplitude model with {mod}`ampform`. We'll use this model as a complicated {class}`sympy.Expr <sympy.core.expr.Expr>` in the rest of this notebook.

In [None]:
logger.setLevel(logging.ERROR)

In [None]:
result = q.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)(980)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
dot = q.io.asdot(result, collapse_graphs=True)
graphviz.Source(dot)

In [None]:
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.generate()

In [None]:
free_symbols = sorted(model.expression.free_symbols, key=lambda s: s.name)
free_symbols

## Helicity model components

A {class}`~ampform.helicity.HelicityModel` has the benefit that it comes with {attr}`~ampform.helicity.HelicityModel.components` (intensities and amplitudes) that together form its {attr}`~ampform.helicity.HelicityModel.expression`. Let's separate these components into _amplitude_ and _intensity_.

In [None]:
amplitudes = {
    name: expr
    for name, expr in model.components.items()
    if name.startswith("A")
}
list(amplitudes)

In [None]:
intensities = {
    name: expr
    for name, expr in model.components.items()
    if name.startswith("I")
}

In [None]:
assert len(amplitudes) + len(intensities) == len(model.components)

### Component structure

Note that each intensity consists of a subset of these amplitudes. This means that _intensities have a larger expression tree than amplitudes_.

In [None]:
amplitude_to_symbol = {
    expr: sp.Symbol(f"A{i}") for i, expr in enumerate(amplitudes.values(), 1)
}

In [None]:
intensity_to_symbol = {
    expr: sp.Symbol(f"I{i}") for i, expr in enumerate(intensities.values(), 1)
}

In [None]:
intensity_expr = model.expression.subs(intensity_to_symbol, simultaneous=True)
intensity_expr

In [None]:
dot = sp.dotprint(intensity_expr)
graphviz.Source(dot)

In [None]:
amplitude_expr = model.expression.subs(amplitude_to_symbol, simultaneous=True)
amplitude_expr

In [None]:
dot = sp.dotprint(amplitude_expr)
graphviz.Source(dot)

### Performance check

Lambdifying the whole {attr}`HelicityModel.expression <ampform.helicity.HelicityModel.expression>` is slowest. The {func}`~sympy.utilities.lambdify.lambdify` function first prints the expression as a {obj}`str` (!) with (in this case) {mod}`numpy` syntax and then uses {func}`eval` to convert that back to actual {mod}`numpy` objects:

:::{margin}

We record the runtime with {mod}`timeit` so that we can use it in section {ref}`reports/sympy/lambdify-speedup:Arbitrary expressions`.

:::

In [None]:
runtime = {}
start = timeit.default_timer()

In [None]:
%%time
np_complete_model = sp.lambdify(free_symbols, model.expression.doit(), "numpy")

In [None]:
stop = timeit.default_timer()
runtime["complete model"] = stop - start
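
To see the string-printing step in isolation, here is a minimal sketch with a toy expression. The symbols `x`, `y` and the expression are made up for illustration and are not part of the amplitude model. It prints the Python source that {func}`~sympy.utilities.lambdify.lambdify` generated before evaluating it:

```python
import inspect

import sympy as sp

# Toy expression, only meant to illustrate the code generation step
x, y = sp.symbols("x y")
expr = sp.sin(x) ** 2 + sp.sqrt(y)

# lambdify prints the expression as source code and evaluates that code
func = sp.lambdify([x, y], expr, "numpy")

# The generated source can be inspected afterwards
generated_source = inspect.getsource(func)
print(generated_source)
```

The larger the expression tree, the longer this generated source becomes, which gives a feeling for why lambdifying a complete model takes a while.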

Printing to {obj}`str` and converting back with {func}`eval` becomes slower and slower the larger the expression tree. This means that it's more efficient to lambdify sub-trees of the expression tree separately. When we lambdify the four _intensities_ of this model separately, however, the effect is not noticeable:

In [None]:
logger.setLevel(logging.INFO)

In [None]:
%%time
for expr, symbol in intensity_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start

...but lambdifying each of the eight _amplitudes_ separately does result in a significant speed-up:

In [None]:
%%time
np_amplitudes = {}
for expr, symbol in amplitude_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    np_expr = sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
    np_amplitudes[symbol] = np_expr

### Recombining components

{ref}`Recall <reports/sympy/lambdify-speedup:Component structure>` what the amplitude model looks like when expressed in its amplitude components:

In [None]:
amplitude_expr

We have to lambdify that top expression as well:

In [None]:
sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)
np_amplitude_expr = sp.lambdify(
    sorted_amplitude_symbols, amplitude_expr, "numpy"
)

In [None]:
source = inspect.getsource(np_amplitude_expr)
print(source)

We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions for the components whose values are to be plugged in as its arguments.

In [None]:
def componentwise_lambdified(*args):
    """Lambdified amplitude model, recombined from its amplitude components.

    .. warning:: Order of the ``args`` has to be the same as that
        of the ``args`` of the lambdified amplitude components.
    """
    amplitude_values = []
    for amp_symbol in sorted_amplitude_symbols:
        np_amplitude = np_amplitudes[amp_symbol]
        values = np_amplitude(*args)
        amplitude_values.append(values)
    return np_amplitude_expr(*amplitude_values)
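
The same recombination pattern can be tried out on a toy example. The following is a minimal sketch under stated assumptions: the expressions `a1_expr`, `a2_expr` and the top-level expression `A1**2 + A2**2` are hypothetical stand-ins, not the amplitudes of the model above. It checks that the component-wise function gives the same values as a direct lambdify of the full expression:

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")

# Hypothetical "components" and a top-level expression written in terms of them
a1_expr = sp.sin(x) * sp.exp(-y)
a2_expr = sp.cos(x) + y**2
A1, A2 = sp.symbols("A1 A2")
top_expr = A1**2 + A2**2

# Lambdify each component and the top-level expression separately
component_symbols = [A1, A2]
np_components = {
    A1: sp.lambdify([x, y], a1_expr, "numpy"),
    A2: sp.lambdify([x, y], a2_expr, "numpy"),
}
np_top = sp.lambdify(component_symbols, top_expr, "numpy")


def recombined(*args):
    # Evaluate the components first, then feed them to the top-level function
    component_values = [np_components[s](*args) for s in component_symbols]
    return np_top(*component_values)


# Reference: lambdify the fully substituted expression in one go
full_expr = top_expr.subs({A1: a1_expr, A2: a2_expr})
np_full = sp.lambdify([x, y], full_expr, "numpy")

x_values = np.linspace(0, 1, 5)
y_values = np.linspace(1, 2, 5)
max_deviation = np.max(
    np.abs(recombined(x_values, y_values) - np_full(x_values, y_values))
)
```

Note that the order of `component_symbols` has to match the argument order that was given to the top-level `lambdify` call.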

### Test with data

Okay, so does all this work? Let's first generate a phase space sample with good-old {mod}`tensorwaves`. We can then use this sample as input to the component-wise lambdified function.

In [None]:
logger.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

In [None]:
sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
phsp_sample = generate_phsp(10_000, model.adapter.reaction_info)
phsp_set = data_converter.transform(phsp_sample)

In [None]:
plt.hist(phsp_set["m_12"], bins=50, alpha=0.5, density=True)
plt.hist(
    phsp_set["m_12"],
    bins=50,
    alpha=0.5,
    density=True,
    weights=intensity(phsp_set),
)
plt.show()

The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:

In [None]:
kinematic_variable_names = {key for key in phsp_set}
parameter_names = {symbol.name for symbol in model.parameter_defaults}
free_symbol_names = {symbol.name for symbol in free_symbols}

In [None]:
assert free_symbol_names <= kinematic_variable_names | parameter_names

That allows us to sort the input arrays and parameter defaults so that they can be used as positional argument input to the component-wise lambdified amplitude model:

In [None]:
merged_par_var_values = {
    symbol.name: value for symbol, value in model.parameter_defaults.items()
}
merged_par_var_values.update(phsp_set)
args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]
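
The ordering step can be illustrated with a small sketch. The expression, the parameter values, and the data array below are made up for illustration. The point is only that the positional arguments follow the same name-sorted symbol order that was passed to `lambdify`:

```python
import numpy as np
import sympy as sp

# Hypothetical expression with two parameters (m, w) and one variable (x)
m, w, x = sp.symbols("m w x")
expr = w / ((x - m) ** 2 + w**2)

# Sort the symbols by name, just like the free symbols of the model
symbols = sorted(expr.free_symbols, key=lambda s: s.name)
func = sp.lambdify(symbols, expr, "numpy")

# Made-up parameter defaults and data, both keyed by symbol name
parameter_defaults = {"m": 1.0, "w": 0.5}
data = {"x": np.array([0.5, 1.0, 1.5])}
merged = {**parameter_defaults, **data}

# Check coverage, then build the positional arguments in symbol order
missing = {s.name for s in symbols} - set(merged)
args = [merged[s.name] for s in symbols]
values = func(*args)
```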

Finally, here's the result of plugging that back into the component-wise lambdified expression:

In [None]:
componentwise_result = componentwise_lambdified(*args_values)
componentwise_result

And it's indeed the same as the intensity computed by {mod}`tensorwaves` (direct lambdify):

In [None]:
tensorwaves_result = np.array(intensity(phsp_set))
# Use absolute differences, so that deviations of opposite sign cannot cancel
mean_difference = np.abs(componentwise_result - tensorwaves_result).mean()
mean_difference

In [None]:
assert mean_difference < 1e-9
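
An element-wise comparison is a stricter check than a single averaged number, because positive and negative deviations cannot cancel each other. A minimal sketch with made-up arrays (not the results computed above), using {func}`numpy.allclose`:

```python
import numpy as np

# Made-up arrays: the second one deviates by +0.5 and -0.5 alternately
reference = np.array([1.0, 2.0, 3.0, 4.0])
shifted = reference + np.array([0.5, -0.5, 0.5, -0.5])

# The signed deviations average out, although no entry agrees
signed_mean = (shifted - reference).mean()

# The largest absolute deviation and an element-wise check reveal the mismatch
max_abs = np.max(np.abs(shifted - reference))
is_close = np.allclose(shifted, reference, rtol=0, atol=1e-9)
```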

## Arbitrary expressions

The problem with {ref}`reports/sympy/lambdify-speedup:Test with data` is that it requires a {obj}`~ampform.helicity.HelicityModel`. In {mod}`tensorwaves`, we want to work with general {class}`sympy.Expr <sympy.core.expr.Expr>`s though (see {class}`~tensorwaves.model.SympyModel`), where we don't have sub-{attr}`ampform.helicity.HelicityModel.components` available.

Instead, we have to split up the lambdifying in a more general way that can handle arbitrary {class}`sympy.core.expr.Expr`s. For that we need:

1. A general method of traversing through a SymPy expression tree. This can be done with {doc}`sympy:tutorial/manipulation`.
2. A **fast** method to estimate the complexity of a model, so that we can decide whether a node in the expression tree is small enough to be lambdified without much runtime. The best measure for complexity is {func}`~sympy.core.function.count_ops` ("count operations"), see notes under {doc}`sympy:modules/simplify/simplify`.
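
Both ingredients can be tried out on a toy expression. This is a minimal sketch with a made-up expression: it walks over all nodes with `sympy.preorder_traversal` and prints the operation count of each node as given by {func}`~sympy.core.function.count_ops`:

```python
import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x) ** 2 + sp.sqrt(x * y + 1)

# The direct children of a node are available through its args
top_level_args = expr.args

# Visit every node in the tree, parents before children
for node in sp.preorder_traversal(expr):
    print(sp.count_ops(node), node)

# Atoms such as symbols contain no operations
n_ops_total = sp.count_ops(expr)
n_ops_symbol = sp.count_ops(x)
```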

### Expression complexity

Let's tackle 2. first and use the {attr}`HelicityModel.expression <ampform.helicity.HelicityModel.expression>` and its {attr}`~ampform.helicity.HelicityModel.components` that we lambdified earlier on. Here's an overview of the number of operations versus the time it took to lambdify each component:

In [None]:
df = pd.DataFrame(runtime.values(), index=runtime, columns=["runtime (s)"])
operations = [sp.count_ops(model.expression)]
operations.extend(sp.count_ops(expr) for expr in intensity_to_symbol)
operations.extend(sp.count_ops(expr) for expr in amplitude_to_symbol)
df.insert(0, "operations", operations)
df

From this we can already see that the lambdify runtime scales roughly with the number of SymPy operations.

To better visualize this, we can lambdify the expressions in {class}`~ampform.dynamics.BlattWeisskopfSquared` for each angular momentum and measure their lambdify runtime a number of times with {mod}`timeit`. Note that the {class}`~ampform.dynamics.BlattWeisskopfSquared` expression becomes increasingly complex the higher the angular momentum.

In [None]:
from ampform.dynamics import BlattWeisskopfSquared

angular_momentum, z = sp.symbols("L z")
BlattWeisskopfSquared(angular_momentum, z).doit()

In [None]:
operations = []
runtime = []
for angular_momentum in range(9):
    ff2 = BlattWeisskopfSquared(angular_momentum, z)
    operations.append(sp.count_ops(ff2.doit()))
    n_iterations = 10
    t = timeit.timeit(
        setup=f"""
import sympy as sp
from ampform.dynamics import BlattWeisskopfSquared
z = sp.Symbol("z")
ff2 = BlattWeisskopfSquared({angular_momentum}, z)
    """,
        stmt='sp.lambdify(z, ff2.doit(), "numpy")',
        number=n_iterations,
    )
    runtime.append(t / n_iterations * 1_000)

In [None]:
df = pd.DataFrame(
    {
        "operations": operations,
        "runtime (ms)": runtime,
    },
)
df

In [None]:
plt.scatter(x=df["operations"], y=df["runtime (ms)"])
ax = plt.gca()
ax.set_ylim(bottom=0)
ax.set_xlabel("operations")
ax.set_ylabel("runtime (ms)");

### Identifying nodes

Now imagine that we don't know anything about the {attr}`~ampform.helicity.HelicityModel.expression` that we created before other than that it is a {class}`sympy.Expr <sympy.core.expr.Expr>`.

#### Approach 1: Generator

A first attempt is to use a generator to recursively identify components in the expression whose 'complexity' (as computed by {func}`~sympy.core.function.count_ops`) lies within a certain range.

In [None]:
def recurse_tree(
    expression: sp.Expr, *, min_complexity: int = 0, max_complexity: int
) -> Generator[sp.Expr, None, None]:
    for arg in expression.args:
        complexity = sp.count_ops(arg)
        if min_complexity < complexity < max_complexity:
            yield arg
        else:
            yield from recurse_tree(
                arg,
                min_complexity=min_complexity,
                max_complexity=max_complexity,
            )

We can then use this generator function to create a mapping of these sub-expressions within the expression tree to {class}`~sympy.core.symbol.Symbol`s. That mapping can then be used in {meth}`~sympy.core.basic.Basic.xreplace` to replace the sub-expressions with those symbols.

:::{margin}

The {meth}`~sympy.core.basic.Basic.xreplace` method is **much faster** than {meth}`~sympy.core.basic.Basic.subs`, because it doesn't do any evaluation.

:::

In [None]:
%%time
expression = model.expression.doit()
sub_expressions = {}
for i, expr in enumerate(recurse_tree(expression, max_complexity=100)):
    symbol = sp.Symbol(f"f{i}")
    sub_expressions[expr] = symbol
expression.xreplace(sub_expressions)
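
It helps to keep in mind that {meth}`~sympy.core.basic.Basic.xreplace` only replaces sub-expressions that appear as exact nodes in the expression tree. The following sketch with made-up expressions shows a case where the replacement works and one where the key is not a node of the tree, so that nothing is replaced, while {meth}`~sympy.core.basic.Basic.subs` does find it:

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
f0 = sp.Symbol("f0")

# x*y is a node in both branches of the tree, so both get replaced
expr = sp.sin(x * y) + sp.exp(x * y)
replaced = expr.xreplace({x * y: f0})

# x + y + z is stored as one flat sum, so x + y is not a node of its tree
flat_sum = x + y + z
unchanged = flat_sum.xreplace({x + y: f0})

# subs does some algebraic matching and recognizes the partial sum
substituted = flat_sum.subs(x + y, f0)
```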

#### Approach 2: Direct substitution

There is one problem though: {meth}`~sympy.core.basic.Basic.xreplace` is not accurate for larger expressions. It would therefore be better to directly substitute the sub-expression with a symbol while we loop over the nodes in the expression tree. The following function can do that:

In [None]:
def simplify_expression(
    expression: sp.Expr,
    max_complexity: int,
    min_complexity: int = 0,
) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]:
    i = 0
    symbol_mapping = {}

    def recursive_simplify(sub_expression: sp.Expr) -> sp.Expr:
        nonlocal i
        for arg in sub_expression.args:
            complexity = sp.count_ops(arg)
            if min_complexity < complexity < max_complexity:
                symbol = sp.Symbol(f"f{i}")
                i += 1
                symbol_mapping[symbol] = arg
                sub_expression = sub_expression.xreplace({arg: symbol})
            else:
                new_arg = recursive_simplify(arg)
                sub_expression = sub_expression.xreplace({arg: new_arg})
        return sub_expression

    new_expression = recursive_simplify(expression)
    return new_expression, symbol_mapping

And indeed, this is **much faster** than {ref}`reports/sympy/lambdify-speedup:Approach 1: Generator` (it's even possible to parallelize this for loop):

In [None]:
%%time
simplified_expr, definitions = simplify_expression(
    expression, max_complexity=100
)

In [None]:
simplified_expr

In [None]:
definitions[sp.Symbol("f0")]
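
What remains is to lambdify the simplified expression and each of its definitions, and to chain them together. The following is a self-contained sketch of that last step on a toy expression. The helper `split_expression` is a hypothetical, compact variant of the splitting idea (it rebuilds nodes from their arguments instead of calling `xreplace`) and the expression is made up, so this is an illustration under those assumptions rather than the implementation used in {mod}`tensorwaves`:

```python
import numpy as np
import sympy as sp


def split_expression(expression, max_complexity):
    """Hypothetical helper: replace small sub-trees by placeholder symbols."""
    definitions = {}

    def recurse(node):
        if not node.args:  # atoms (symbols, numbers) are kept as they are
            return node
        new_args = []
        for arg in node.args:
            if 0 < sp.count_ops(arg) < max_complexity:
                symbol = sp.Symbol(f"f{len(definitions)}")
                definitions[symbol] = arg
                new_args.append(symbol)
            else:
                new_args.append(recurse(arg))
        return node.func(*new_args)

    return recurse(expression), definitions


x, y = sp.symbols("x y")
expr = sp.sin(x * y + 1) ** 2 + sp.exp(-(x**2) - y**2) * sp.cos(x - y)
top_expr, definitions = split_expression(expr, max_complexity=4)

# Lambdify the small definitions and the remaining top-level expression
placeholders = sorted(definitions, key=lambda s: s.name)
np_definitions = [
    sp.lambdify([x, y], definitions[s], "numpy") for s in placeholders
]
# x and y are passed along in case the top-level expression still contains them
np_top = sp.lambdify([*placeholders, x, y], top_expr, "numpy")


def recombined(x_values, y_values):
    component_values = [f(x_values, y_values) for f in np_definitions]
    return np_top(*component_values, x_values, y_values)


# Compare with lambdifying the complete expression directly
np_direct = sp.lambdify([x, y], expr, "numpy")
x_values = np.linspace(-1, 1, 7)
y_values = np.linspace(0, 2, 7)
max_deviation = np.max(
    np.abs(recombined(x_values, y_values) - np_direct(x_values, y_values))
)
```

In this sketch, identical sub-expressions that occur more than once each get their own placeholder symbol. Reusing one symbol per unique sub-expression would be a possible refinement.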