# Component-wise lambdify

This notebook investigates how to speed up {func}`sympy.lambdify <sympy.utilities.lambdify.lambdify>` by splitting up the expression tree of a complicated expression into components, lambdifying those, and then combining them back again.

In [None]:
import inspect
import logging
import warnings

import ampform
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import qrules as q
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from tensorwaves.data import generate_phsp
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.model import LambdifiedFunction, SympyModel

logger = logging.getLogger()

## Create amplitude model

First, let's create an amplitude model with {mod}`ampform`. We'll use this model as a complicated {class}`sympy.Expr <sympy.core.expr.Expr>` in the rest of this notebook.

In [None]:
logger.setLevel(logging.ERROR)

In [None]:
result = q.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)(980)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
dot = q.io.asdot(result, collapse_graphs=True)
graphviz.Source(dot)

In [None]:
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.generate()

In [None]:
free_symbols = sorted(model.expression.free_symbols, key=lambda s: s.name)
free_symbols

## Component-wise lambdifying

A {class}`~ampform.helicity.HelicityModel` has the benefit that it comes with {attr}`~ampform.helicity.HelicityModel.components` (intensities and amplitudes) that together form its {attr}`~ampform.helicity.HelicityModel.expression`. Let's separate these components into _amplitudes_ and _intensities_.

In [None]:
amplitudes = {
    name: expr
    for name, expr in model.components.items()
    if name.startswith("A")
}
list(amplitudes)

In [None]:
intensities = {
    name: expr
    for name, expr in model.components.items()
    if name.startswith("I")
}

In [None]:
assert len(amplitudes) + len(intensities) == len(model.components)

### Structure of helicity model components

Note that each intensity consists of a subset of these amplitudes. This means that _intensities have a larger expression tree than amplitudes_.
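
The cells in this section replace each component by a placeholder symbol, so that the top-level structure becomes visible. As a minimal sketch of that idea on a toy expression (the symbols `x`, `y` and the two component expressions are made up for illustration and are not part of the amplitude model):

```python
import sympy as sp

x, y = sp.symbols("x y")

# Hypothetical components standing in for two amplitudes
components = {
    "A1": sp.sin(x * y),
    "A2": sp.exp(-(x**2)),
}
# Toy "intensity" built from those components
total = sp.Abs(components["A1"] + components["A2"]) ** 2

# Map each component expression to a placeholder symbol and substitute
placeholders = {expr: sp.Symbol(name) for name, expr in components.items()}
collapsed = total.subs(placeholders, simultaneous=True)
print(collapsed)
```

The collapsed expression only contains the placeholder symbols, which is what makes its expression tree small enough to inspect.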

In [None]:
amplitude_to_symbol = {
    expr: sp.Symbol(f"A{i}") for i, expr in enumerate(amplitudes.values(), 1)
}

In [None]:
intensity_to_symbol = {
    expr: sp.Symbol(f"I{i}") for i, expr in enumerate(intensities.values(), 1)
}

In [None]:
intensity_expr = model.expression.subs(intensity_to_symbol, simultaneous=True)
intensity_expr

In [None]:
dot = sp.dotprint(intensity_expr)
graphviz.Source(dot)

In [None]:
amplitude_expr = model.expression.subs(amplitude_to_symbol, simultaneous=True)
amplitude_expr

In [None]:
dot = sp.dotprint(amplitude_expr)
graphviz.Source(dot)

### Performance check

Lambdifying the whole {attr}`HelicityModel.expression <ampform.helicity.HelicityModel.expression>` is slowest. The {func}`~sympy.utilities.lambdify.lambdify` function first prints the expression as a {obj}`str` (!) with (in this case) {mod}`numpy` syntax and then uses {func}`eval` to convert that back to actual {mod}`numpy` objects:

In [None]:
%%time
np_complete_model = sp.lambdify(free_symbols, model.expression.doit(), "numpy")
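
To see the printed code that {func}`~sympy.utilities.lambdify.lambdify` generates, one can inspect the source of a lambdified function. The following is a small sketch on a toy expression (the expression itself is an arbitrary choice for illustration):

```python
import inspect

import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x) ** 2 + sp.sqrt(y)

# lambdify prints the expression as Python source and compiles it
func = sp.lambdify([x, y], expr, "numpy")

# The generated source can be retrieved with inspect
source = inspect.getsource(func)
print(source)
```

The printed function body is the string representation of the expression, with names that resolve to {mod}`numpy` functions in the namespace of the generated function.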

Printing to {obj}`str` and converting back with {func}`eval` becomes exponentially slower as the expression tree grows. This means that it's more efficient to lambdify sub-trees of the expression tree separately. When we lambdify the four _intensities_ of this model separately, however, the effect is not noticeable:

In [None]:
logger.setLevel(logging.INFO)

In [None]:
%%time
for name, expr in intensities.items():
    logging.info(f"Lambdifying {name}")
    sp.lambdify(free_symbols, expr.doit(), "numpy")

...but lambdifying each of the eight _amplitudes_ separately does result in a significant speed-up:

In [None]:
%%time
np_amplitudes = {}
for expr, symbol in amplitude_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    np_expr = sp.lambdify(free_symbols, expr.doit(), "numpy")
    np_amplitudes[symbol] = np_expr
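
One rough way to compare the size of expression trees is to count their nodes. The sketch below does this for a toy sum of components; the helper `count_nodes` and the toy expressions are assumptions for illustration, and node count is only a proxy for lambdification time:

```python
import sympy as sp


def count_nodes(expr):
    """Count all nodes in the expression tree (hypothetical helper)."""
    return sum(1 for _ in sp.preorder_traversal(expr))


x = sp.Symbol("x")

# Toy components and a top-level expression that contains all of them
parts = [sp.sin(k * x) * sp.exp(-x / k) for k in (1, 2, 3)]
total = sp.Abs(sp.Add(*parts)) ** 2

sizes = [count_nodes(part) for part in parts]
total_size = count_nodes(total)
print(total_size, sizes)
```

The combined expression has more nodes than any single component, which is why lambdifying the components one by one keeps each printed string small.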

### Recombining lambdified components

{ref}`Recall <reports/sympy/lambdify-speedup:Structure of helicity model components>` what the amplitude model looks like when expressed in its amplitude components:

In [None]:
amplitude_expr

We have to lambdify that top expression as well:

In [None]:
sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)
np_amplitude_expr = sp.lambdify(
    sorted_amplitude_symbols, amplitude_expr, "numpy"
)

In [None]:
source = inspect.getsource(np_amplitude_expr)
print(source)

We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions for the components that are to be plugged into its arguments.

In [None]:
def componentwise_lambdified(*args):
    """Lambdified amplitude model, recombined from its amplitude components.

    .. warning:: Order of the ``args`` has to be the same as that
        of the ``args`` of the lambdified amplitude components.
    """
    amplitude_values = []
    for amp_symbol in sorted_amplitude_symbols:
        np_amplitude = np_amplitudes[amp_symbol]
        values = np_amplitude(*args)
        amplitude_values.append(values)
    return np_amplitude_expr(*amplitude_values)
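
The same recombination pattern can be sketched end to end on a toy expression. Everything in this example (the symbols, the two components, and the function name `recombined`) is made up for illustration; it only mirrors the structure of the function above:

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")
args = [x, y]
A1, A2 = sp.symbols("A1 A2")

# Hypothetical components standing in for the amplitudes
components = {A1: sp.sin(x * y), A2: sp.exp(-(x**2))}
# Top-level expression in terms of placeholders, and its fully expanded form
top_expr = sp.Abs(A1 + A2) ** 2
full_expr = top_expr.subs(components)

# Lambdify the components and the top-level expression separately
placeholder_symbols = sorted(components, key=lambda s: s.name)
np_components = {
    s: sp.lambdify(args, components[s], "numpy") for s in placeholder_symbols
}
np_top = sp.lambdify(placeholder_symbols, top_expr, "numpy")


def recombined(*values):
    """Evaluate each component, then feed the results to the top expression."""
    component_values = [np_components[s](*values) for s in placeholder_symbols]
    return np_top(*component_values)


# Compare with lambdifying the full expression directly
np_direct = sp.lambdify(args, full_expr, "numpy")
x_values = np.linspace(-1, 1, 5)
y_values = np.linspace(0, 2, 5)
np.testing.assert_allclose(
    recombined(x_values, y_values), np_direct(x_values, y_values)
)
```

Sorting the placeholder symbols once and reusing that order for both the top-level lambdify call and the evaluation loop is what keeps the positional arguments aligned.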

### Test with data

Okay, so does all this work? Let's first generate a phase space sample with good old {mod}`tensorwaves`. We can then use this sample as input to the component-wise lambdified function.

In [None]:
logger.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

In [None]:
sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
phsp_sample = generate_phsp(10_000, model.adapter.reaction_info)
phsp_set = data_converter.transform(phsp_sample)

In [None]:
plt.hist(phsp_set["m_12"], bins=50, alpha=0.5, density=True)
plt.hist(
    phsp_set["m_12"],
    bins=50,
    alpha=0.5,
    density=True,
    weights=intensity(phsp_set),
)
plt.show()

The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:

In [None]:
kinematic_variable_names = set(phsp_set)
parameter_names = {symbol.name for symbol in model.parameter_defaults}
free_symbol_names = {symbol.name for symbol in free_symbols}

In [None]:
# Use the union (|): a symmetric difference (^) would drop names present in both sets
assert free_symbol_names <= kinematic_variable_names | parameter_names

That allows us to sort the input arrays and parameter defaults so that they can be used as positional argument input to the component-wise lambdified amplitude model:

In [None]:
merged_par_var_values = {
    symbol.name: value for symbol, value in model.parameter_defaults.items()
}
merged_par_var_values.update(phsp_set)
args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]
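
As a minimal sketch of this ordering step on a toy expression (the symbols `a` and `x`, the parameter value, and the data array are hypothetical and not taken from the model):

```python
import numpy as np
import sympy as sp

a, x = sp.symbols("a x")
expr = a * x**2

# Fix the argument order once, by symbol name
ordered_symbols = sorted(expr.free_symbols, key=lambda s: s.name)
func = sp.lambdify(ordered_symbols, expr, "numpy")

parameter_values = {"a": 2.0}  # hypothetical parameter default
data = {"x": np.array([0.0, 1.0, 2.0])}  # hypothetical data column

# All arguments have to be covered by either the parameters or the data
assert {s.name for s in ordered_symbols} <= set(parameter_values) | set(data)

# Merge by name and look the values up in the same order as the arguments
merged = {**parameter_values, **data}
positional_args = [merged[s.name] for s in ordered_symbols]
result = func(*positional_args)
```

Scalars and arrays can be mixed as positional arguments, because {mod}`numpy` broadcasts the scalar parameter over the data array.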

Finally, here's the result of plugging that back into the component-wise lambdified expression:

In [None]:
componentwise_result = componentwise_lambdified(*args_values)
componentwise_result

And it's indeed the same as the intensity computed by {mod}`tensorwaves` (direct lambdify):

In [None]:
tensorwaves_result = np.array(intensity(phsp_set))
# Take the absolute value so that positive and negative differences cannot cancel
mean_difference = np.abs(componentwise_result - tensorwaves_result).mean()
mean_difference

In [None]:
assert mean_difference < 1e-9
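
When comparing two arrays of results, it can help to look at an elementwise measure in addition to an average. The sketch below uses made-up arrays to illustrate how a signed mean can be zero while the arrays still differ; it does not use the model output:

```python
import numpy as np

reference = np.array([1.0, 2.0, 3.0])
# Deviations of +0.5 and -0.5 that cancel on average (toy values)
shifted = reference + np.array([0.5, -0.5, 0.0])

signed_mean = (shifted - reference).mean()
max_abs_difference = np.abs(shifted - reference).max()
arrays_agree = np.allclose(shifted, reference)

print(signed_mean, max_abs_difference, arrays_agree)
```

A check such as `np.testing.assert_allclose` compares the arrays element by element and would flag this case.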