# PyHEP 2022 Notebook Talk ― amplitude analysis

In [None]:
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
from ampform.io import aslatex
from IPython.display import Math

# Keep the notebook output quiet: only errors pass through the "absl" logger,
# and TF_CPP_MIN_LOG_LEVEL is raised to "3" (presumably read by TensorFlow's
# native logging when it is first loaded -- confirm).
JAX_LOGGER = logging.getLogger("absl")
JAX_LOGGER.setLevel(logging.ERROR)
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "3"})

## Model formulation

In [None]:
import qrules

# Ask qrules for the transitions J/psi(1S) -> gamma pi0 pi0 in the helicity
# formalism, restricted to the listed intermediate particles and interaction
# types.
reaction = qrules.generate_transitions(
    # What decays into what.
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    # How it is allowed to get there.
    formalism="helicity",
    allowed_interaction_types=["strong", "EM"],
    allowed_intermediate_particles=["f(0)"],
)

In [None]:
import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

# Configure an amplitude-model builder for `reaction`: the initial J/psi(1S)
# is assigned `create_non_dynamic_with_ff`, while every intermediate particle
# is assigned `create_relativistic_breit_wigner_with_ff`.
model_builder = ampform.get_builder(reaction)
model_builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
# Formulate the model only after all dynamics have been registered.
model = model_builder.formulate()

In [None]:
# Bare expression: the notebook shows the model's intensity as the cell output.
model.intensity

In [None]:
# Convert the model's amplitudes to LaTeX and show them as the cell output.
Math(aslatex(model.amplitudes))

### Expression simplification

In [None]:
import sympy as sp

# Evaluate the model expression with `doit()`; the result is reused by the
# later cells as `unfolded_expression`.
unfolded_expression = model.expression.doit()
# Cell output: number of operations in the unfolded expression.
sp.count_ops(unfolded_expression)

### Conversion to numerical function

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# Turn the unfolded expression into a numerical function on the JAX backend,
# parametrized with the model's default parameter values.
intensity = create_parametrized_function(
    backend="jax",
    expression=unfolded_expression,
    parameters=model.parameter_defaults,
)

In [None]:
from tensorwaves.data import TFPhaseSpaceGenerator, TFUniformRealNumberGenerator

# Random number source with a fixed seed; it is passed to `generate` below.
rng = TFUniformRealNumberGenerator(seed=0)

# Phase-space generator configured from the reaction's masses.
# NOTE(review): `initial_state[-1]` presumably selects the decaying particle
# by its state ID -- confirm against the qrules data model.
phsp_generator = TFPhaseSpaceGenerator(
    final_state_masses={
        state_id: particle.mass
        for state_id, particle in reaction.final_state.items()
    },
    initial_state_mass=reaction.initial_state[-1].mass,
)

# Request 100,000 phase-space events.
phsp_momenta = phsp_generator.generate(100_000, rng)

In [None]:
from tensorwaves.data import SympyDataTransformer

# Build a data transformer from the model's kinematic variable definitions,
# using the JAX backend. It is applied to the generated momenta further down.
helicity_transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)

In [None]:
from tensorwaves.data import (
    IntensityDistributionGenerator,
    TFWeightedPhaseSpaceGenerator,
)

# Weighted phase-space generator, configured with the same masses as the
# reaction's initial and final state.
# NOTE(review): `initial_state[-1]` presumably selects the decaying particle
# by its state ID -- confirm against the qrules data model.
weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    final_state_masses={
        state_id: particle.mass
        for state_id, particle in reaction.final_state.items()
    },
    initial_state_mass=reaction.initial_state[-1].mass,
)

# Generator that combines the phase-space domain, the transformer and the
# intensity function.
data_generator = IntensityDistributionGenerator(
    function=intensity,
    domain_generator=weighted_phsp_generator,
    domain_transformer=helicity_transformer,
)

# Request 10,000 events, drawing random numbers from `rng`.
data_momenta = data_generator.generate(10_000, rng)

In [None]:
# Apply the transformer to both momentum samples; the results are indexed by
# variable name later on (e.g. data["m_12"]).
phsp = helicity_transformer(phsp_momenta)
data = helicity_transformer(data_momenta)
# Cell output: whatever iterating over `data` yields (presumably the variable
# names -- confirm).
list(data)

In [None]:
# Order the intermediate particles by mass and give each one a colour sampled
# at evenly spaced points of the rainbow colormap.
resonances = list(reaction.get_intermediate_particles())
resonances.sort(key=lambda particle: particle.mass)
evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = list(map(plt.cm.rainbow, evenly_spaced_interval))

# Normalised histogram of the real part of the "m_12" values in the data.
fig, ax = plt.subplots(figsize=(9, 4))
ax.hist(np.real(data["m_12"]), bins=100, density=True, alpha=0.5)
ax.set_xlabel("$m$ [GeV]")

# One dotted vertical line per particle, drawn at its mass.
for p, color in zip(resonances, colors):
    ax.axvline(
        x=p.mass,
        color=color,
        label=p.name,
        linestyle="dotted",
    )
ax.legend()
plt.show()

In [None]:
# Starting values for the fit, keyed by parameter name. Only the parameters
# listed here are treated as free further down in the notebook.
initial_parameters = {
    # Complex-valued coefficient (the long key is split over two raw strings
    # that are concatenated by the parser).
    (
        R"C_{J/\psi(1S) \to {f_{0}(1500)}_{0} \gamma_{+1}; "
        R"f_{0}(1500) \to \pi^{0}_{0} \pi^{0}_{0}}"
    ): complex(1.0, 0.0),
    # m_{...} parameters
    "m_{f_{0}(500)}": 0.4,
    "m_{f_{0}(980)}": 0.88,
    "m_{f_{0}(1370)}": 1.22,
    "m_{f_{0}(1500)}": 1.45,
    "m_{f_{0}(1710)}": 1.83,
    # \Gamma_{...} parameters
    R"\Gamma_{f_{0}(500)}": 0.3,
    R"\Gamma_{f_{0}(980)}": 0.1,
    R"\Gamma_{f_{0}(1710)}": 0.3,
}

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm


def indicate_masses(ax):
    """Label the x-axis and draw one dotted vertical line per resonance.

    The module-level ``colors`` and ``resonances`` sequences are paired element
    by element; each line is placed at the resonance's ``mass`` and labelled
    with its ``name``.

    Args:
        ax: Axes-like object providing ``set_xlabel`` and ``axvline``.
    """
    ax.set_xlabel("$m$ [GeV]")
    for line_color, particle in zip(colors, resonances):
        line_style = dict(linestyle="dotted", label=particle.name, color=line_color)
        ax.axvline(x=particle.mass, **line_style)


def compare_model(
    variable_name,
    data,
    phsp,
    function,
    bins=100,
):
    """Overlay the data distribution and the model prediction for one variable.

    Two normalised histograms are drawn on a new 9x4 figure: the real part of
    ``data[variable_name]`` and the real part of ``phsp[variable_name]``
    weighted by ``function(phsp)``. Mass markers are added through
    ``indicate_masses`` and a legend is shown.

    Args:
        variable_name: Key selecting the variable in both ``data`` and ``phsp``.
        data: Mapping from variable name to the values of the data sample.
        phsp: Mapping from variable name to the values of the phase-space
            sample; it is also passed as a whole to ``function``.
        function: Callable evaluated on ``phsp``; its result is used as the
            histogram weights of the phase-space sample.
        bins: Forwarded unchanged to both ``ax.hist`` calls.
    """
    intensities = function(phsp)

    # `plt.subplots` already returns the axes it created, so they are used
    # directly instead of being looked up a second time with `plt.gca()`.
    _, ax = plt.subplots(figsize=(9, 4))

    # Only the real part is plotted, for both samples.
    data_projection = np.real(data[variable_name])
    phsp_projection = np.real(phsp[variable_name])

    ax.hist(
        data_projection,
        bins=bins,
        alpha=0.5,
        label="data",
        density=True,
    )
    ax.hist(
        phsp_projection,
        weights=np.array(intensities),
        bins=bins,
        histtype="step",
        color="red",
        label="fit model",
        density=True,
    )
    indicate_masses(ax)
    ax.legend()

In [None]:
# Parameters whose name appears in `initial_parameters` stay free; every other
# default value is substituted into the expression.
free_parameters = {
    parameter
    for parameter in model.parameter_defaults
    if parameter.name in initial_parameters
}
fixed_parameters = {
    parameter: value
    for parameter, value in model.parameter_defaults.items()
    if parameter not in free_parameters
}
optimized_expression = unfolded_expression.subs(fixed_parameters)

# Cell output: operation count after the substitution.
sp.count_ops(optimized_expression)

In [None]:
# Cell output: one nested sub-expression, reached by taking the first argument
# four levels down from the top of `optimized_expression`.
optimized_expression.args[0].args[0].args[0].args[0]

In [None]:
# Replace every power that contains no free symbols by its numerical value.
# The traversal iterator is created once from the expression as it is on
# entry; rebinding `optimized_expression` inside the loop does not restart it.
for node in sp.preorder_traversal(optimized_expression):
    if isinstance(node, sp.Pow) and not node.free_symbols:
        optimized_expression = optimized_expression.xreplace({node: node.n()})

# Cell output: operation count after the replacement.
sp.count_ops(optimized_expression)

In [None]:
# Cell output: the same nested position as inspected earlier (first argument,
# four levels down), now taken from the current `optimized_expression`.
optimized_expression.args[0].args[0].args[0].args[0]

In [None]:
import graphviz

# Produce DOT source for a nested sub-expression of the *unfolded* expression
# (first argument, four levels down) ...
dot = sp.dotprint(unfolded_expression.args[0].args[0].args[0].args[0], size=12)
# ... and hand it to graphviz; the resulting object is the cell output.
graphviz.Source(dot)

In [None]:
# Produce DOT source for a nested sub-expression of the *optimized* expression
# (first argument, four levels down); note that `dot` is rebound here.
dot = sp.dotprint(optimized_expression.args[0].args[0].args[0].args[0], size=12)
# The graphviz object is the cell output.
graphviz.Source(dot)

In [None]:
# Numerical (JAX) function for the optimized expression. The complete set of
# default parameters is passed, although some of them have already been
# substituted into the expression.
optimized_function = create_parametrized_function(
    expression=optimized_expression,
    parameters=model.parameter_defaults,
    backend="jax",
)

In [None]:
from tensorwaves.estimator import UnbinnedNLL

# Build the `UnbinnedNLL` estimator from the optimized function, the data
# sample and the phase-space sample, using the JAX backend.
estimator = UnbinnedNLL(optimized_function, data, phsp, backend="jax")

In [None]:
# Grab the function's parameters before they are overwritten.
# NOTE(review): if `parameters` returns a live view rather than a copy, this
# does not preserve the old values -- confirm.
original_parameters = optimized_function.parameters
# Load the starting values into the function and plot the comparison for
# "m_12" before fitting.
optimized_function.update_parameters(initial_parameters)
compare_model("m_12", data, phsp, optimized_function)

In [None]:
from tensorwaves.optimizer import Minuit2
from tensorwaves.optimizer.callbacks import CSVSummary

# Minuit2 optimizer with the analytic gradient switched off; a `CSVSummary`
# callback is attached with the file name "fit_traceback.csv".
minuit2 = Minuit2(
    use_analytic_gradient=False,
    callback=CSVSummary("fit_traceback.csv"),
)

# Run the fit from the starting values.
fit_result = minuit2.optimize(estimator, initial_parameters)

# Cell output: the fit result.
fit_result

In [None]:
# Load the fitted parameter values into the function and redraw the "m_12"
# comparison.
optimized_function.update_parameters(fit_result.parameter_values)
compare_model("m_12", data, phsp, optimized_function)

In [None]:
# Take the covariance from the optimizer-specific part of the fit result ...
covariance_matrix = fit_result.specifics.covariance
# ... and show the result of its `correlation()` method as the cell output.
covariance_matrix.correlation()