# TensorWaves demo — PANDA Seminar, December 2021

This notebook accompanies [these slides](https://docs.google.com/presentation/d/e/2PACX-1vSymz5AjdhPw4Kz1pKhdFMnFGYuQvVaC8WbV_HTg770x6RDYoP-Anv9tn88DSuzvSiiQ9F4pcDGVExv/pub). They were presented during a PANDA Seminar on 13 December 2021.

Related notebooks for this presentation:
- [QRules demo](./qrules.ipynb)
- [AmpForm demo](./ampform.ipynb)

For more extensive examples, see **[tensorwaves.rtfd.io](https://tensorwaves.readthedocs.io)**.

## Import dependencies

In [None]:
%config InlineBackend.figure_formats = ['svg']
import logging
import re
import warnings

import ampform
import graphviz
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import qrules
import sympy as sp
import tensorflow as tf
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from jax.lib import xla_bridge
from matplotlib import cm
from tensorflow.python.ops.numpy_ops import np_config
from tensorwaves.data import (
    IntensityDistributionGenerator,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
    TFWeightedPhaseSpaceGenerator,
)
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.estimator import UnbinnedNLL
from tensorwaves.function.sympy import create_parametrized_function
from tensorwaves.optimizer.callbacks import CSVSummary
from tensorwaves.optimizer.minuit import Minuit2

# Reduce log noise for the demo: only errors from the "absl" logger, only
# warnings and above from TensorFlow, and no Python warnings at all.
LOGGER = logging.getLogger("absl")
LOGGER.setLevel(logging.ERROR)
tf.get_logger().setLevel("WARNING")
warnings.filterwarnings("ignore")

# Default font size (14 pt) for all figures in this notebook.
mpl.rcParams.update({"font.size": 14})
# Let TensorFlow tensors support NumPy-style methods/operators.
# NOTE(review): presumably required by the TF-based generators used below — confirm.
np_config.enable_numpy_behavior()

# Report which device each backend will run on.
# NOTE(review): newer JAX releases deprecate `jax.lib.xla_bridge`;
# `jax.default_backend()` is the modern equivalent — check the pinned JAX version.
has_tf_gpu = bool(tf.config.list_physical_devices("GPU"))
jax_backend = xla_bridge.get_backend().platform.upper()
print("TF backend: ", "GPU" if has_tf_gpu else "CPU")
print("JAX backend:", jax_backend)

Some helper functions for visualizing the distributions and fit result:

In [None]:
def indicate_masses(ax=None, reaction_info=None):
    """Mark the mass of every intermediate resonance with a dotted vertical line.

    Resonances are sorted by mass and colored along the ``rainbow`` colormap,
    from the lightest to the heaviest. Each line is labeled with the particle
    name, so it shows up in a subsequent legend. The x-axis label is set to
    ``"$m$ [GeV]"``.

    Args:
        ax: Axes to draw on. Defaults to the current axes (``plt.gca()``),
            which is what the original version always used.
        reaction_info: Object providing ``get_intermediate_particles()``.
            Defaults to ``model.reaction_info`` of the notebook-global
            ``model``, so the argument-free call keeps working.
    """
    if ax is None:
        ax = plt.gca()
    if reaction_info is None:
        # Falls back to the notebook-global `model` (defined in a later cell).
        reaction_info = model.reaction_info
    resonances = sorted(
        reaction_info.get_intermediate_particles(),
        key=lambda p: p.mass,
    )
    colors = [cm.rainbow(x) for x in np.linspace(0, 1, len(resonances))]
    ax.set_xlabel("$m$ [GeV]")
    for particle, color in zip(resonances, colors):
        ax.axvline(
            x=particle.mass,
            linestyle="dotted",
            label=particle.name,
            color=color,
        )


def compare_model(
    variable_name,
    data_set,
    phsp_set,
    intensity_model,
    bins=100,
):
    """Overlay the data distribution of one variable with the model prediction.

    The model curve is the phase-space sample weighted by ``intensity_model``
    evaluated on that sample. Both histograms are normalized to unit area and
    use the *same* bin edges, so their shapes can be compared bin by bin.
    Resonance masses are marked with ``indicate_masses()``.

    Args:
        variable_name: Key of the variable to histogram (e.g. ``"m_12"``).
        data_set: Mapping from variable names to data values.
        phsp_set: Mapping from variable names to phase-space values; it is
            also the input passed to ``intensity_model``.
        intensity_model: Callable returning one intensity per phase-space
            event (used as histogram weights).
        bins: Number of bins, or any other ``bins`` specification accepted by
            ``numpy.histogram_bin_edges``.
    """
    data = np.array(data_set[variable_name])
    phsp = np.array(phsp_set[variable_name])
    intensities = np.array(intensity_model(phsp_set))
    # One set of edges spanning both samples. With an integer `bins`, each
    # hist() call would otherwise derive its own edges from its own min/max,
    # and data and model would be binned differently.
    bin_edges = np.histogram_bin_edges(np.concatenate([data, phsp]), bins=bins)
    _, ax = plt.subplots(figsize=(10, 3.5))
    ax.hist(
        data,
        bins=bin_edges,
        alpha=0.5,
        label="data",
        density=True,
    )
    ax.hist(
        phsp,
        weights=intensities,
        bins=bin_edges,
        histtype="step",
        color="red",
        label="model",
        density=True,
    )
    ax.set_yticks([])
    # plt.subplots() made `ax` the current axes, so the lines land on it.
    indicate_masses()
    ax.legend()


def natural_sorting(text):
    """Return a sort key that orders numbers embedded in ``text`` numerically.

    The text is split on unsigned decimal numbers. For example,
    ``"m_f(0)(500)"`` becomes ``["m_f(", 0.0, ")(", 500.0, ")"]``, so that
    ``"f10"`` sorts after ``"f2"``. Use it as
    ``sorted(names, key=natural_sorting)``.

    Only the captured number groups are converted to ``float``; the text in
    between always stays a ``str``. This keeps element types aligned across
    keys (str at even positions, float at odd positions). Comparing two keys
    therefore never mixes ``str`` and ``float``, even when a text fragment
    happens to parse as a float (such as ``"nan"`` or ``"inf"``).
    """
    # https://stackoverflow.com/a/5967539/13219025
    # A leading sign is matched but lies outside the capture group, so it is
    # dropped from the key (same as the referenced answer).
    parts = re.split(r"[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)", text)
    # re.split with one capture group alternates: text, number, text, ...
    return [float(part) if i % 2 else part for i, part in enumerate(parts)]


def __attempt_number_cast(text):
    """Return ``float(text)`` when ``text`` parses as a number, else ``text`` itself."""
    try:
        number = float(text)
    except ValueError:
        # Not numeric: keep the original value untouched.
        return text
    return number

## Formulate amplitude model

Generate allowed transitions for $J/\psi \to \gamma f_0, f_0 \to \pi^0 \pi^0$ with QRules (see [QRules demo](./qrules.ipynb)):

In [None]:
# Enumerate all J/psi(1S) -> gamma pi0 pi0 transitions, restricting the
# intermediate states to particles matching "f(0)" and allowing only strong
# and electromagnetic interactions.
reaction = qrules.generate_transitions(
    initial_state="J/psi(1S)",
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
)

In [None]:
# Convert the transitions to DOT source (with collapsed graphs) and render it;
# as the last expression of the cell, the Source object is displayed inline.
dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)

Express the transitions as an amplitude model with the resonances parametrized as a relativistic Breit-Wigner with form factor (see [AmpForm demo](./ampform.ipynb)):

In [None]:
# Build the amplitude model, giving every intermediate resonance a
# relativistic Breit-Wigner lineshape with form factor.
builder = ampform.get_builder(reaction)
resonances = reaction.get_intermediate_particles()
for p in resonances:
    builder.set_dynamics(p.name, create_relativistic_breit_wigner_with_ff)
model = builder.formulate()

## Generate data

Generate a **deterministic** (seeded) phase-space sample for this decay, and a hit-and-miss sample distributed according to the intensity of this amplitude model (`intensity`):

In [None]:
# Evaluate the model expression (doit) and turn it into a parametrized JAX
# function, initialized with the model's default parameter values.
expression = model.expression.doit()
parameter_defaults = model.parameter_defaults
intensity = create_parametrized_function(expression, parameter_defaults, backend="jax")
# Converts four-momenta into the kinematic variables used by the model.
helicity_transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)
# Masses for the phase-space generators: the initial state (id -1) and each
# final-state particle keyed by its id.
reaction_info = model.reaction_info
initial_state_mass = reaction_info.initial_state[-1].mass
final_state_masses = {i: p.mass for i, p in reaction_info.final_state.items()}
# Seeded RNG, so both samples below are reproducible. The same RNG is shared:
# the phase-space sample is drawn first, so reordering the two generate()
# calls would change both samples.
rng = TFUniformRealNumberGenerator(seed=0)

# 1M phase-space events. They are passed to the estimator in a later cell,
# presumably for the normalization integral.
phsp_generator = TFPhaseSpaceGenerator(initial_state_mass, final_state_masses)
phsp_momenta = phsp_generator.generate(1_000_000, rng)
phsp = helicity_transformer(phsp_momenta)

# 100k "data" events distributed according to `intensity` (hit-and-miss on a
# weighted phase-space domain).
data_generator = IntensityDistributionGenerator(
    function=intensity,
    domain_generator=TFWeightedPhaseSpaceGenerator(
        initial_state_mass, final_state_masses
    ),
    domain_transformer=helicity_transformer,
)
data_momenta = data_generator.generate(100_000, rng)
data = helicity_transformer(data_momenta)

In [None]:
# Quick look at the pi0 pi0 invariant mass (m_12) in data vs phase space, both
# normalized to unit area. Each hist() call picks its own 100 bin edges.
fig, ax = plt.subplots()
ax.hist(np.array(data["m_12"]), density=True, bins=100, alpha=0.5, label="data")
ax.hist(np.array(phsp["m_12"]), density=True, bins=100, alpha=0.5, label="phsp")
ax.set_xlabel(R"$m_{\pi^0\pi^0}$")
ax.set_yticks([])
plt.legend()
plt.show()

## Perform fit

In [None]:
# Starting values for the fit. This dict is handed to the optimizer in a later
# cell; presumably only these parameters are varied there — confirm.
initial_parameters = {
    "m_f(0)(500)": 0.35,
    "m_f(0)(980)": 0.88,
    "m_f(0)(1370)": 1.22,
    "m_f(0)(1500)": 1.45,
    "m_f(0)(1710)": 1.83,
    "Gamma_f(0)(500)": 0.3,
    "Gamma_f(0)(980)": 0.1,
    "Gamma_f(0)(1710)": 0.3,
}
# Keep the default (generating) values before overwriting them; they are used
# as reference lines in the traceback plot.
# NOTE(review): assumes `.parameters` returns a copy, not a live view — confirm.
original_parameters = intensity.parameters
intensity.update_parameters(initial_parameters)
# Show how the starting model compares with the data.
compare_model("m_12", data, phsp, intensity, bins=200)

Note that we first run the function on a single event. This JIT-compiles the JAX function, so that the measured execution time reflects the optimization process itself rather than the compilation.

_For an optimized version of the fit, see [`tensorwaves-optimized-expression.ipynb`](./tensorwaves-optimized-expression.ipynb)._

In [None]:
# A fresh function (with default parameters) for the fit, separate from
# `intensity`, which holds the initial parameter values.
function = create_parametrized_function(expression, parameter_defaults, backend="jax")
# Warm-up call on the first entry of each data column to trigger JIT compilation.
# NOTE(review): jax.jit specializes on input shapes; a scalar-shaped warm-up
# may not cover the array shapes used during the fit — confirm the timing
# below really excludes compilation.
function({k: v[0] for k, v in data.items()})  # JIT-compile function
estimator = UnbinnedNLL(function, data, phsp, backend="jax")
# CSVSummary writes the fit traceback to a CSV file read in a later cell.
optimizer = Minuit2(callback=CSVSummary("fit_traceback.csv"))
fit_result = optimizer.optimize(estimator, initial_parameters)
fit_result.execution_time

_Time estimate:_<br>
For this (deterministic) data sample and these initial parameter values, the fit requires **515 iterations** (GPU; 516 on CPU). On Google Colab, this should take **around 20 seconds** with GPU.

In [None]:
# Load the best-fit parameter values and overlay the fitted model on the data.
function.update_parameters(fit_result.parameter_values)
compare_model("m_12", data, phsp, function, bins=200)

## Visualize fit traceback

In [None]:
# Read the traceback written by CSVSummary during the fit into a structured array.
traceback = np.genfromtxt("fit_traceback.csv", delimiter=",", names=True)
# Two stacked panels sharing the x-axis (function call index):
# top = estimator value, bottom = parameter values.
fig, (ax1, ax2) = plt.subplots(
    nrows=2,
    figsize=(7, 8),
    gridspec_kw={"height_ratios": [1, 2.5]},
    sharex=True,
)
ax1.set_title("Negative log likelihood")
ax2.set_title("Parameter values")
ax2.set_xlabel("function call")
fig.tight_layout()

x = traceback["function_call"]
ax1.plot(x, traceback["estimator_value"])
for par in initial_parameters:
    label = f"${sp.latex(sp.Symbol(par))}$"
    # Strip parentheses to match the CSV column names.
    # NOTE(review): presumably because genfromtxt(names=True) drops such
    # characters from header names — confirm.
    key = re.sub(r"[\(\)]", "", par)
    ax2.plot(x, traceback[key], label=label)

legend = ax2.legend(loc=(0.77, 0.43))
# Reset the legend frame's transparency.
legend.get_frame().set_alpha(None)
# Dotted horizontal line at each parameter's default (generating) value, drawn
# in the same color as that parameter's trace. get_lines() returns the traces
# plotted above, in the same order as `initial_parameters`.
# NOTE(review): `label` is not used in the visible lines of this loop.
for par, line in zip(initial_parameters, ax2.get_lines(), strict=True):
    label = line.get_label()
    color = line.get_color()
    ax2.axhline(
        y=original_parameters[par],
        color=color,
        alpha=0.5,
        linestyle="dotted",
    )