# ComPWA at PyHEP 2022

This notebook accompanies [these slides](https://docs.google.com/presentation/d/e/2PACX-1vRF-EG2B6u8a6Wb3--TY37bBEgM0bIxgNkCesokrTEwdQZbMwONMXOKqn5GZSirAIH9NXVv6v0ym_es/pub), which were [presented at PyHEP 2022](https://indico.cern.ch/event/1150631/contributions/5002013).

Depending on time, we'll cover the following parts:
1. [Formulating expressions with SymPy](#symbolic-expressions)
2. [Intro to TensorWaves](#intro-to-tensorwaves)
3. [Amplitude analysis example](#larger-expressions-―-amplitude-analysis)

Please **execute this cell** before starting the notebook talk:

In [None]:
import logging
import os

import ipywidgets
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import psutil
from ampform.io import aslatex
from IPython.display import HTML, Math, display
from matplotlib import cm

# Hide device warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
JAX_LOGGER = logging.getLogger("absl")
JAX_LOGGER.setLevel(logging.ERROR)

# Increase font size
# The last command-line argument of the parent process is inspected to decide
# whether the notebook runs under Jupyter Lab. The parent may be unknown
# (`parent()` returns None), may have an empty command line, or may not be
# inspectable; in all those cases the environment is treated as "not Jupyter
# Lab" instead of crashing the setup cell.
try:
    _parent_process = psutil.Process().parent()
    _parent_cmdline = _parent_process.cmdline() if _parent_process is not None else []
except psutil.Error:
    _parent_cmdline = []
PARENT_PROCESS = _parent_cmdline[-1] if _parent_cmdline else ""
ON_JUPYTER_LAB = "jupyter-lab" in PARENT_PROCESS
if not ON_JUPYTER_LAB:
    plt.rc("font", size=20)
    src = """
    <style>
      /* Classical Jupyter notebook (RISE) */
      div.output_subarea.output_latex.output_result { font-size: 36px; }

      /* Jupyter Lab */
      div.jp-OutputArea-output pre { font-size: 36px; }
    </style>
    """
    display(HTML(src))


def remove_ipywidget_toolbars(fig) -> None:
    """Switch off the toolbar, header, and footer flags on ``fig.canvas``."""
    canvas = fig.canvas
    for flag in ("toolbar_visible", "header_visible", "footer_visible"):
        setattr(canvas, flag, False)


def plot_distributions():
    """Show side-by-side 2D histograms of the generated sample.

    The left panel is filled from the values of the notebook-level
    ``cartesian_data`` mapping, the right panel from the ``"phi"`` and ``"r"``
    entries of the notebook-level ``polar_data`` mapping.
    """
    histogram_style = {"bins": 100, "cmap": plt.cm.coolwarm}
    fig, axes = plt.subplots(figsize=(12, 7), ncols=2)
    ax_cartesian, ax_polar = axes
    remove_ipywidget_toolbars(fig)

    ax_cartesian.hist2d(*cartesian_data.values(), **histogram_style)
    ax_polar.hist2d(polar_data["phi"], polar_data["r"], **histogram_style)
    fig.suptitle("Hit-and-miss intensity distribution")

    # Cartesian panel: axis titles only, no ticks.
    ax_cartesian.set_title("cartesian")
    ax_cartesian.set_xlabel("$x$")
    ax_cartesian.set_ylabel("$y$")
    ax_cartesian.set_xticks([])
    ax_cartesian.set_yticks([])

    # Polar panel: label the phi axis at -pi/2, 0, and +pi/2.
    ax_polar.set_title("polar")
    ax_polar.set_xlabel(R"$\phi$")
    ax_polar.set_ylabel("$r$")
    ax_polar.set_xticks([-np.pi / 2, 0, np.pi / 2])
    ax_polar.set_yticks([])
    ax_polar.set_xticklabels([r"$-\frac{\pi}{2}$", "0", r"$+\frac{\pi}{2}$"])

    fig.tight_layout()
    plt.show()


def plot_interactive():
    """Draw the polar function on a cartesian grid with parameter sliders.

    A 200x200 grid over [-5, +5] in both directions is converted with the
    notebook-level ``converter``. Each slider change updates the parameters of
    the notebook-level ``polar_function`` and redraws the colour mesh.
    """
    n_points = 200
    axis_values = np.linspace(-5, +5, n_points)
    X, Y = np.meshgrid(axis_values, axis_values)
    polar_domain = converter({"x": X, "y": Y})

    fig, ax_interactive = plt.subplots(figsize=(5, 5), tight_layout=True)
    remove_ipywidget_toolbars(fig)
    ax_interactive.set_xlabel("$x$")
    ax_interactive.set_ylabel("$y$")
    ax_interactive.set_xticks([])
    ax_interactive.set_yticks([])
    color_mesh = None  # handle of the mesh currently drawn, if any

    slider_definitions = {
        "dphi": ipywidgets.FloatSlider(value=0, min=0, max=np.pi, step=np.pi / 100),
        "k_r": (0, 3.0, np.pi / 100),
        "k_phi": (0, 6),
        "sigma": (0.1, 5),
    }

    @ipywidgets.interact(**slider_definitions)
    def plot(dphi, k_r, k_phi, sigma):
        nonlocal color_mesh
        new_values = {
            R"\Delta\phi": dphi,
            "k_r": k_r,
            "k_phi": k_phi,
            "sigma": sigma,
        }
        polar_function.update_parameters(new_values)
        intensities = polar_function(polar_domain)
        # Drop the previous mesh so meshes do not pile up on the axes.
        if color_mesh is not None:
            color_mesh.remove()
        color_mesh = ax_interactive.pcolormesh(
            X, Y, intensities, cmap=plt.cm.coolwarm
        )


def indicate_masses(ax):
    """Mark the mass of each intermediate particle with a dotted vertical line.

    The particles come from the notebook-level ``reaction``. They are ordered
    by mass and coloured with evenly spaced samples of the rainbow colour map.
    The x-axis label of ``ax`` is set as well.
    """
    ax.set_xlabel(R"$M\left(\pi^0\pi^0\right)$ [GeV]")
    particles_by_mass = list(reaction.get_intermediate_particles())
    particles_by_mass.sort(key=lambda particle: particle.mass)
    color_positions = np.linspace(0, 1, len(particles_by_mass))
    for position, particle in zip(color_positions, particles_by_mass):
        ax.axvline(
            x=particle.mass,
            linestyle="dotted",
            label=particle.name,
            color=plt.cm.rainbow(position),
        )


def compare_model(n_bins=100) -> None:
    """Overlay the binned data distribution and the intensity-weighted model.

    Reads the notebook-level ``data`` and ``phsp`` samples and the
    notebook-level ``intensity_func``. Both histograms are normalised
    (``density=True``) and share bin edges spanning the range of the real part
    of ``data["m_12"]``.

    Args:
        n_bins: Number of histogram bins.
    """

    def close_last_bin(bin_values):
        # Step plots with "post" need one y-value per bin *edge*; repeating
        # the last value gives the final bin its full width.
        bin_values = np.asarray(bin_values)
        return np.concatenate([bin_values, bin_values[-1:]])

    fig, ax = plt.subplots(figsize=(14, 5))
    remove_ipywidget_toolbars(fig)
    variable_name = "m_12"
    observed = data[variable_name].real
    bin_edges = np.linspace(observed.min(), observed.max(), num=n_bins + 1)
    bin_values_data, _ = jnp.histogram(
        observed,
        bins=bin_edges,
        density=True,
    )
    bin_values_func, _ = jnp.histogram(
        phsp[variable_name].real,
        bins=bin_edges,
        density=True,
        weights=intensity_func(phsp),
    )
    # Draw each value over [edge_i, edge_{i+1}). Pairing left edges with "pre"
    # steps would shift every bin one position to the left.
    ax.fill_between(
        bin_edges,
        close_last_bin(bin_values_data),
        alpha=0.5,
        label="data",
        step="post",
    )
    ax.step(
        bin_edges,
        close_last_bin(bin_values_func),
        alpha=0.5,
        color="red",
        label="fit model",
        where="post",
    )
    indicate_masses(ax)
    ax.legend()
    plt.show()


def plot_traceback(traceback_file: str) -> None:
    """Plot the estimator value and parameter values recorded during a fit.

    The file is read as CSV with one header row. Column 1 is used as the
    function-call counter, column 5 as the estimator value, and columns 6
    onwards as parameter values labelled with the matching header entries.

    Args:
        traceback_file: Path to the CSV file to visualise.
    """
    import csv  # local import: only needed to parse the header row

    # Parse the header with the csv module rather than `eval`: it does not
    # execute file content and keeps backslashes in LaTeX names intact
    # (`eval` would turn the "\t" of "\to" into a tab character).
    with open(traceback_file, newline="") as f:
        headers = next(csv.reader(f, skipinitialspace=True), [])
    # Read the data from the file that was passed in, not a hard-coded path.
    fit_traceback = np.genfromtxt(
        traceback_file, delimiter=",", skip_header=1, dtype=complex
    )
    fit_traceback = fit_traceback.T  # one row per CSV column
    function_call = fit_traceback[1].real.astype(int)
    estimator_values = fit_traceback[5].real
    parameter_values = dict(zip(headers[6:], fit_traceback[6:].real))

    fig, (ax1, ax2) = plt.subplots(
        2, figsize=(15, 9), sharex=True, gridspec_kw={"height_ratios": [1, 2]}
    )
    remove_ipywidget_toolbars(fig)
    ax1.plot(function_call, estimator_values)
    for label, values in parameter_values.items():
        ax2.plot(function_call, values, label=f"${label}$")
    ax2.set_xlabel("function call")
    ax2.legend(bbox_to_anchor=(0.95, 1), loc="upper left")
    ax2.set_ylim(-0.1, +2.0)
    fig.tight_layout()
    plt.show()

## Symbolic expressions

[SymPy](https://www.sympy.org) expressions are built up by applying operations to [`Symbol`](https://docs.sympy.org/latest/modules/core.html#sympy.core.symbol.Symbol)s:

In [None]:
import sympy as sp

# Declare the symbols and combine them into a Gaussian-shaped expression
# with amplitude n, mean mu, and width sigma.
x, n, mu, sigma = sp.symbols("x n mu sigma")
expression = n * sp.exp(-((x - mu) ** 2) / (2 * sigma**2))
# Last expression of the cell: rendered as the cell output.
expression

Computations with numbers are **exact**, but **not suitable for large numerical computations**:

In [None]:
# Use an exact rational for mu: the Python expression `1 / 3` is evaluated to a
# float before SymPy sees it, so the substituted result would not be exact.
expression.subs({n: 1, mu: sp.Rational(1, 3), sigma: sp.sqrt(2)})

### From expression to numerical function

SymPy can convert symbolic expressions to a numerical function with `sympy.lambdify()`:

In [None]:
# Turn the symbolic expression into a callable whose positional arguments
# follow the order given in `args`, using NumPy for the numerical operations.
func = sp.lambdify(
    args=(x, n, mu, sigma),
    expr=expression,
    modules="numpy",
)

In this case, the resulting function is Python code for a NumPy function:

In [None]:
import inspect

# Print the Python source code of the lambdified function.
src = inspect.getsource(func)
print(src)

This **numerical function** can be used for faster computations:

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Evaluate the function on 1000 points in [-4, +4] and plot the result.
x_array = np.linspace(-4.0, +4.0, num=1_000)
y_array = func(x_array, n=1, mu=0.0, sigma=1.0)
# The trailing semicolon suppresses the textual cell output.
plt.plot(x_array, y_array);

## Intro to TensorWaves

[TensorWaves](https://tensorwaves.rtfd.io) is a [ComPWA](https://compwa-org.readthedocs.io) package that **streamlines this conversion from expressions to numerical functions**.

Some responsibilities:
- Facilitate JIT-compilation etc. for backends like JAX
- **Improve argument handling** for complicated expressions and symbol names
- **Parametrize functions** for fitting
- Facilitate **data generation**
- Facilitate **data transformation**

### Lambdification to other backends

With TensorWaves, the same expression can be lambdified to different computational backends, such as NumPy, TensorFlow, and JAX:

In [None]:
# Two-dimensional example expression in x and y.
x, y = sp.symbols("x y")
expr_2d = x**3 + sp.sin(sp.pi * y) ** 2 / 5
expr_2d

In [None]:
from tensorwaves.function.sympy import create_function

# Create one numerical function per backend from the same expression.
numpy_function = create_function(expr_2d, backend="numpy")
tf_function = create_function(expr_2d, backend="tensorflow")
jax_function = create_function(expr_2d, backend="jax")

To allow for multidimensional expressions and more complicated `Symbol` names, functions take a `dict` of data as input:

In [None]:
# Uniformly distributed input sample; the keys match the symbol names
# used in `expr_2d`.
sample_size = 1_000_000
data = {
    "x": np.random.uniform(-50, +50, sample_size),
    "y": np.random.uniform(0.1, 2.0, sample_size),
}

In [None]:
# Evaluate the NumPy-backed function on the sample.
numpy_function(data)

In [None]:
# Evaluate the TensorFlow-backed function on the same sample.
tf_function(data)

In [None]:
# Evaluate the JAX-backed function on the same sample.
jax_function(data)

### Function parametrization

TensorWaves is intended to **optimize parameters** in a model. We therefore need to distinguish which symbols are data input **variables** and which are **parameters**:

In [None]:
# Symbols for an expression in polar coordinates. Note that `expression` and
# `sigma` are rebound here and replace the definitions from the earlier cells.
r, phi, dphi, k_phi, k_r, sigma = sp.symbols(R"r phi \Delta\phi k_phi k_r sigma")
expression = (
    sp.exp(-r / sigma) * sp.sin(k_r * r) ** 2 * sp.cos(k_phi * (phi + dphi)) ** 2
)
expression

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# The symbols listed under `parameters` get the given starting values;
# the remaining symbols (r and phi) are not listed there.
polar_function = create_parametrized_function(
    expression,
    parameters={dphi: 0, k_r: 0.6, k_phi: 2, sigma: 2.5},
    backend="jax",
)
# Display the current parameter values.
polar_function.parameters

### Data transformation

TensorWaves also makes it easier to transform data sets to a different representation

In [None]:
# Symbolic definitions of the polar coordinates in terms of x and y.
# phi is set to 0 where x == 0, which avoids the division y / x there.
cartesian_to_polar = {
    r: sp.sqrt(x**2 + y**2),
    phi: sp.Piecewise((0, sp.Eq(x, 0)), (sp.atan(y / x), True)),
}
# Render the definitions as LaTeX.
Math(aslatex(cartesian_to_polar))

In [None]:
from tensorwaves.data import SympyDataTransformer

# Build a data transformer from the symbolic coordinate definitions.
converter = SympyDataTransformer.from_sympy(cartesian_to_polar, backend="jax")
converter.functions

### Data generation

We now have all the tools to generate data based on our function with a **hit & miss strategy**:

In [None]:
from tensorwaves.data import (
    IntensityDistributionGenerator,
    NumpyDomainGenerator,
    NumpyUniformRNG,
)

# The domain is generated in cartesian coordinates; `converter` is passed as
# the domain transformer because `polar_function` is expressed in r and phi.
rng = NumpyUniformRNG()
domain_generator = NumpyDomainGenerator(boundaries={"x": (-5, 5), "y": (-5, +5)})
data_generator = IntensityDistributionGenerator(
    domain_generator,
    function=polar_function,
    domain_transformer=converter,
)
cartesian_data = data_generator.generate(1_000_000, rng)
# Convert the generated sample to polar coordinates as well (used for plotting).
polar_data = converter(cartesian_data)
polar_data

What does it look like...?

In [None]:
# Uses the notebook-level `cartesian_data` and `polar_data` created above.
plot_distributions()  # <-- function with matplotlib code

Finally, using the fact that the function is **parametrized** 🙌

In [None]:
# Switch to the widget backend before creating the interactive figure.
%matplotlib widget
plot_interactive()  # <-- function with ipywidgets code

&nbsp;

## Larger expressions ― amplitude analysis

### Model formulation

Amplitude model formulation is automated and standardized with two ComPWA libraries:
- [QRules](https://qrules.rtfd.io): generate allowed particle transitions
- [AmpForm](https://ampform.rtfd.io): symbolic expressions for dynamics and spin formalisms

QRules automatically finds all allowed transitions between some initial and final state.

Simple example:

In [None]:
import qrules

# Generate the transitions from J/psi(1S) to gamma pi0 pi0, restricted to
# intermediate particles matching "f(0)" and to strong and EM interactions.
# NOTE(review): [-1, +1] presumably lists the allowed spin projections of the
# initial state — confirm against the QRules documentation.
reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],  # optional
    allowed_interaction_types=["strong", "EM"],  # optional
    formalism="helicity",
)

In [None]:
import graphviz

# Convert the reaction to DOT source and render it as a graph.
dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)

We can now use AmpForm to formulate an amplitude model for these transitions

In [None]:
import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

# The initial state gets the non-dynamic builder, every intermediate
# particle gets a relativistic Breit-Wigner builder.
model_builder = ampform.get_builder(reaction)
model_builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.formulate()
# Display the intensity expression of the formulated model.
model.intensity

In [None]:
# Render the amplitude definitions of the model as LaTeX.
Math(aslatex(model.amplitudes))

For an amplitude analysis, this example is rather simple, but the full expression is large enough that we can illustrate **CAS simplification**

In [None]:
# Evaluate the model expression with `doit()` and count the operations
# in the result.
full_intensity_expr = model.expression.doit()
sp.count_ops(full_intensity_expr)

Can we simplify the full expression? Here's what one of the nodes looks like:

In [None]:
# Follow the first argument four levels down the expression tree.
full_intensity_expr.args[0].args[0].args[0].args[0]

Imagine we want to fit this model to a data distribution and **only** want to optimize these parameters:

In [None]:
# Parameters to optimize, keyed by symbol name: one complex coefficient,
# five mass parameters, and three width parameters.
initial_parameters = {
    R"C_{J/\psi(1S) \to {f_{0}(1500)}_{0} \gamma_{+1}; f_{0}(1500) \to \pi^{0}_{0} \pi^{0}_{0}}": 1.0
    + 0.0j,
    "m_{f_{0}(500)}": 0.4,
    "m_{f_{0}(980)}": 0.88,
    "m_{f_{0}(1370)}": 1.22,
    "m_{f_{0}(1500)}": 1.45,
    "m_{f_{0}(1710)}": 1.83,
    R"\Gamma_{f_{0}(500)}": 0.3,
    R"\Gamma_{f_{0}(980)}": 0.1,
    R"\Gamma_{f_{0}(1710)}": 0.3,
}

Remaining symbols in the model can now be **analytically substituted** with their suggested parameter value:

In [None]:
# Free parameters: model parameters whose name appears in `initial_parameters`.
free_parameters = {p for p in model.parameter_defaults if p.name in initial_parameters}
# All other model parameters are fixed to their default value ...
fixed_parameters = {
    par: value
    for par, value in model.parameter_defaults.items()
    if par not in free_parameters
}
# ... and substituted into the full expression.
substituted_expression = full_intensity_expr.subs(fixed_parameters)

The effect of substitution:

In [None]:
# The same node as shown earlier, before substitution.
full_intensity_expr.args[0].args[0].args[0].args[0]

In [None]:
# The node at the same position in the substituted expression.
substituted_expression.args[0].args[0].args[0].args[0]

### Numerical computations

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# NOTE(review): all parameter defaults are passed here, although the fixed
# ones have already been substituted out of the expression — confirm that
# this is intended.
intensity_func = create_parametrized_function(
    substituted_expression,
    parameters=model.parameter_defaults,
    backend="jax",
)

### Data generation

TensorWaves interfaces to the [`phasespace`](https://phasespace.readthedocs.io) package to generate a phase space sample

In [None]:
from tensorwaves.data import TFPhaseSpaceGenerator, TFUniformRealNumberGenerator

# Seeded random number generator.
rng = TFUniformRealNumberGenerator(seed=0)
# NOTE(review): presumably -1 is the state ID of the initial-state particle in
# `reaction.initial_state` — confirm against the QRules documentation.
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
phsp_momenta = phsp_generator.generate(100_000, rng)  # small size for binder
# Show the keys of the generated sample.
list(phsp_momenta)

Our intensity function takes helicity angles and masses as input, so we need to **transform our generated four-momenta to these kinematic variables**. Just like cartesian to polar coordinates!

In [None]:
from tensorwaves.data import SympyDataTransformer

# Build the transformer from the kinematic variable definitions of the model.
helicity_transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)

We now have all the components to generate a **hit & miss** distribution based on the intensity function:

In [None]:
# NOTE(review): TFWeightedPhaseSpaceGenerator is imported but not used in this
# cell; the previously created `phsp_generator` serves as domain generator.
from tensorwaves.data import TFWeightedPhaseSpaceGenerator

# Note that this rebinds `data_generator` from the earlier polar example.
data_generator = IntensityDistributionGenerator(
    domain_generator=phsp_generator,
    function=intensity_func,
    domain_transformer=helicity_transformer,
)
data_momenta = data_generator.generate(10_000, rng)  # small size for binder
list(data_momenta)

In [None]:
# Transform both momentum samples with the helicity transformer.
# `data` is rebound here and replaces the sample dict from the earlier example.
phsp = helicity_transformer(phsp_momenta)
data = helicity_transformer(data_momenta)
list(data)

### Fitting the model

We now imagine that we don't know the parameters with which we generated our distribution. Instead, we guess the following parameters:

In [None]:
# Starting values for the fit, keyed by parameter name.
initial_parameters = {
    R"C_{J/\psi(1S) \to {f_{0}(1500)}_{0} \gamma_{+1}; f_{0}(1500) \to \pi^{0}_{0} \pi^{0}_{0}}": 1.0
    + 0.0j,
    "m_{f_{0}(500)}": 0.4,
    "m_{f_{0}(980)}": 0.88,
    "m_{f_{0}(1370)}": 1.22,
    "m_{f_{0}(1500)}": 1.45,
    "m_{f_{0}(1710)}": 1.83,
    R"\Gamma_{f_{0}(500)}": 0.3,
    R"\Gamma_{f_{0}(980)}": 0.1,
    R"\Gamma_{f_{0}(1710)}": 0.3,
}

Our model now looks as follows:

In [None]:
%matplotlib inline
# Remember the current parameter values before overwriting them with the
# initial guesses.
# NOTE(review): assumes `parameters` returns a copy rather than a live view of
# the function's parameters — confirm.
original_parameters = intensity_func.parameters
intensity_func.update_parameters(initial_parameters)
compare_model()

Let's use [`iminuit`](https://iminuit.rtfd.io) to optimize these parameters so that the model fits the distribution:

In [None]:
from tensorwaves.estimator import UnbinnedNLL
from tensorwaves.optimizer import Minuit2
from tensorwaves.optimizer.callbacks import CSVSummary

# Unbinned negative log-likelihood estimator built from the intensity function
# and the `data` and `phsp` samples.
estimator = UnbinnedNLL(intensity_func, data, phsp, backend="jax")
# The CSVSummary callback is given the file that `plot_traceback` reads later.
minuit2 = Minuit2(callback=CSVSummary("fit_traceback.csv"))
fit_result = minuit2.optimize(estimator, initial_parameters)
fit_result

With the optimized parameters, the distribution looks like:

In [None]:
# Show the model with the parameter values found by the optimizer. Restoring
# `original_parameters` here would plot the model the data was generated with
# instead of the fit result.
intensity_func.update_parameters(fit_result.parameter_values)
compare_model()

Callbacks allow us to insert behavior in each fit iteration. In this case, we have recorded the parameter values during each optimization step:

In [None]:
# Same file name as the one given to the CSVSummary callback.
plot_traceback("fit_traceback.csv")

Other optimizer functionality is available through the optimizer that has been used.

In [None]:
# `specifics` gives access to the optimizer-specific result object.
# NOTE(review): presumably an iminuit object here — confirm.
covariance_matrix = fit_result.specifics.covariance
covariance_matrix.correlation()