In [None]:
import os

# The subset of {"EXECUTE_NB", "READTHEDOCS"} that is defined as an environment
# variable; an empty (falsy) set when neither of them is set.
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

```{autolink-concat}
```

::::{margin}
:::{card} Amplitude analysis with zfit
TR-020
^^^
This report builds a [simple symbolic amplitude model](https://tensorwaves.readthedocs.io/en/0.4.5/amplitude-analysis.html) with {mod}`qrules` and {mod}`ampform` and feeds it to [zfit](https://zfit.rtfd.io) instead of {mod}`tensorwaves`.
+++
✅&nbsp;[compwa.github.io#151](https://github.com/ComPWA/compwa.github.io/issues/151)<br>
WIP&nbsp;[ComPWA/.github#14](https://github.com/ComPWA/.github/issues/14)
:::
::::

# Amplitude analysis with zfit

In [None]:
%pip install -q ampform[viz]~=0.14.1 hepstats~=0.6.0 pandas~=1.4.2 sympy~=1.10.1 tensorwaves[jax,pwa]~=0.4.8 zfit~=0.10.0

In [None]:
%config InlineBackend.figure_formats = ['svg']

import logging
import os
import warnings

# Keep the notebook output readable: only errors from the "absl" logger,
# minimal TensorFlow C++ log output, and no Python warnings.
JAX_LOGGER = logging.getLogger("absl")
JAX_LOGGER.setLevel(logging.ERROR)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# One blanket filter ignores every warning category (UserWarning included),
# so no additional category-specific "ignore" filter is needed.
warnings.filterwarnings("ignore")

## Formulating the model

In [None]:
import qrules

# Generate the reaction that the rest of the notebook works with. The request:
# initial state "J/psi(1S)" (passed together with the list [-1, +1]), final
# state gamma pi0 pi0, intermediate particles restricted to "f(0)", strong and
# EM interactions, helicity formalism. `reaction` is later used to build the
# amplitude model and to look up particle masses.
reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
    formalism="helicity",
)

In [None]:
import graphviz

# Convert the reaction to DOT source (with `collapse_graphs=True`) and wrap it
# in a `graphviz.Source`; as the last expression of the cell it is displayed.
dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)

![image](https://user-images.githubusercontent.com/29308176/194346498-ff9b9379-90f1-4348-81ba-d5a57ddccc83.png)

In [None]:
import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

# Build the amplitude model from the generated reaction.
model_builder = ampform.get_builder(reaction)
model_builder.scalar_initial_state_mass = True
model_builder.stable_final_state_ids = [0, 1, 2]
# The initial state uses the "non-dynamic with form factor" builder ...
model_builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
# ... and every intermediate particle of the reaction uses the relativistic
# Breit-Wigner (with form factor) builder.
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
# `model` provides the expression, parameter defaults and kinematic variables
# that the cells below consume.
model = model_builder.formulate()

## Generate data

### Phase space sample

In [None]:
from tensorwaves.data import TFPhaseSpaceGenerator, TFUniformRealNumberGenerator

# No in-cell `%pip install` here: the pinned `%pip install -q ...` cell at the
# top of the notebook is the single place where dependencies are installed.
# NOTE(review): this relies on that cell providing the TensorFlow-based phase
# space backend (expected via the `pwa` extra) — confirm on a fresh kernel.

# Seeded random number generator; it is reused for the intensity-based sample
# further down, so both samples are reproducible.
rng = TFUniformRealNumberGenerator(seed=0)
phsp_generator = TFPhaseSpaceGenerator(
    # The initial state is stored under ID -1 in `reaction.initial_state`.
    initial_state_mass=reaction.initial_state[-1].mass,
    # Map each final-state ID to the mass of its particle.
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
phsp_momenta = phsp_generator.generate(100_000, rng)

### Intensity-based sample

In [None]:
from tensorwaves.function.sympy import create_function

# Unfold the model expression with `doit()`; the unfolded expression is reused
# by every function builder further down.
unfolded_expression = model.expression.doit()
# Substitute the model's default parameter values into the expression before
# turning it into a JAX-backed function. This function is used to generate the
# intensity-based data sample.
fixed_intensity_func = create_function(
    unfolded_expression.xreplace(model.parameter_defaults),
    backend="jax",
)

In [None]:
from tensorwaves.data import SympyDataTransformer

# JAX-backed transformer built from the model's kinematic-variable definitions.
# It is passed to the data generator as domain transformer and applied to both
# momentum samples below.
transform_momenta = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)

In [None]:
from tensorwaves.data import (
    IntensityDistributionGenerator,
    TFWeightedPhaseSpaceGenerator,
)

# Weighted phase space generator configured with the same initial- and
# final-state masses as the phase space generator above.
weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
# Combine the weighted phase space generator, the fixed-parameter intensity
# function and the momenta-to-kinematic-variables transformer.
data_generator = IntensityDistributionGenerator(
    domain_generator=weighted_phsp_generator,
    function=fixed_intensity_func,
    domain_transformer=transform_momenta,
)
# Request 10,000 events, drawing from the same seeded `rng` as before.
data_momenta = data_generator.generate(10_000, rng)

In [None]:
import pandas as pd

# Convert both momentum samples to the model's kinematic variables; `phsp` and
# `data` are the inputs of every fit below.
phsp = transform_momenta(phsp_momenta)
data = transform_momenta(data_momenta)
# Show the data sample as a table (last expression of the cell).
pd.DataFrame(data)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

# Intermediate particles of the reaction, ordered by mass.
resonances = sorted(
    reaction.get_intermediate_particles(),
    key=lambda p: p.mass,
)
# One colour per resonance, sampled at evenly spaced points of the rainbow
# colour map.
evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = [cm.rainbow(x) for x in evenly_spaced_interval]
fig, ax = plt.subplots(figsize=(9, 4))
# Normalised histogram of the real part of the "m_12" column of the data.
ax.hist(
    np.real(data["m_12"]),
    bins=100,
    alpha=0.5,
    density=True,
)
ax.set_xlabel("$m$ [GeV]")
ax.set_ylabel("$I$ a.u.")
# Mark each resonance mass with a dotted vertical line labelled by its name.
for p, color in zip(resonances, colors):
    ax.axvline(x=p.mass, linestyle="dotted", label=p.name, color=color)
ax.legend()
plt.show()

![](https://user-images.githubusercontent.com/29308176/194346791-c6e9f66f-ccf8-48b0-9b4b-5ea964db88d4.svg)

## Fit

### Determine free parameters

In [None]:
# Starting values for the fits, keyed by parameter name. The keys also select
# which model parameters are treated as free: symbols are matched against them
# by `.name` when the cached function and the zfit parameter list are built.
initial_parameters = {
    R"C_{J/\psi(1S) \to {f_{0}(1500)}_{0} \gamma_{+1}; f_{0}(1500) \to \pi^{0}_{0} \pi^{0}_{0}}": (
        1.0 + 0.0j
    ),
    "m_{f_{0}(500)}": 0.4,
    "m_{f_{0}(980)}": 0.88,
    "m_{f_{0}(1370)}": 1.22,
    "m_{f_{0}(1500)}": 1.45,
    "m_{f_{0}(1710)}": 1.83,
    R"\Gamma_{f_{0}(500)}": 0.3,
    R"\Gamma_{f_{0}(980)}": 0.1,
    R"\Gamma_{f_{0}(1710)}": 0.3,
}

### Parametrized function and caching

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# JAX-backed function of the unfolded expression whose parameters (initialised
# with the model defaults) can be changed later through `update_parameters`.
intensity_func = create_parametrized_function(
    expression=unfolded_expression,
    parameters=model.parameter_defaults,
    backend="jax",
)

In [None]:
from tensorwaves.estimator import create_cached_function

# Symbols of the model parameters that get a starting value in
# `initial_parameters`; these are the ones left free in the cached function.
# Dict membership is already a hash lookup, so no temporary set is needed.
free_parameter_symbols = [
    symbol
    for symbol in model.parameter_defaults
    if symbol.name in initial_parameters
]
# Build the cached variant of the intensity function together with the
# transformer that prepares its input.
cached_intensity_func, transform_to_cache = create_cached_function(
    unfolded_expression,
    parameters=model.parameter_defaults,
    free_parameters=free_parameter_symbols,
    backend="jax",
)
# Pre-transform both samples once; the caching estimator consumes these.
cached_data = transform_to_cache(data)
cached_phsp = transform_to_cache(phsp)

### Estimator

In [None]:
from tensorwaves.estimator import UnbinnedNLL

# Two unbinned NLL estimators over the same events: the first evaluates the
# plain parametrized function on `data`/`phsp`, ...
estimator = UnbinnedNLL(
    intensity_func,
    data=data,
    phsp=phsp,
    backend="jax",
)
# ... the second evaluates the cached function on the pre-transformed samples.
estimator_with_caching = UnbinnedNLL(
    cached_intensity_func,
    data=cached_data,
    phsp=cached_phsp,
    backend="jax",
)

### Optimize fit parameters

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

# Resonances (sorted by mass) and their colours, as read by `indicate_masses`.
# Here they are taken from the reaction info stored on the model.
reaction_info = model.reaction_info
resonances = sorted(
    reaction_info.get_intermediate_particles(),
    key=lambda p: p.mass,
)

# One rainbow colour per resonance, at evenly spaced colour-map positions.
evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = [cm.rainbow(x) for x in evenly_spaced_interval]


def indicate_masses(ax):
    """Label the x axis as mass and mark every resonance mass on ``ax``.

    Reads the notebook-level ``resonances`` and ``colors`` lists and draws one
    dotted vertical line per resonance, labelled with the resonance name.
    """
    ax.set_xlabel("$m$ [GeV]")
    for line_color, particle in zip(colors, resonances):
        line_options = dict(
            linestyle="dotted",
            label=particle.name,
            color=line_color,
        )
        ax.axvline(x=particle.mass, **line_options)


def compare_model(
    variable_name,
    data,
    phsp,
    function,
    bins=100,
):
    """Overlay the data distribution of one variable with the model prediction.

    A new 9x4 inch figure is created. The data are drawn as a filled,
    normalised histogram; the model is drawn as a red step histogram of the
    phase space sample weighted by ``function(phsp)``. Resonance masses are
    marked through ``indicate_masses``.

    Parameters
    ----------
    variable_name
        Key of the variable to project on; looked up in both samples.
    data
        Mapping from variable name to the values of the data sample.
    phsp
        Mapping from variable name to the values of the phase space sample;
        also the input of ``function``.
    function
        Callable that returns one intensity per phase space event.
    bins
        Number of histogram bins used for both histograms.

    Nothing is returned, so calling this as the last statement of a cell shows
    only the figure.
    """
    intensities = function(phsp)
    # `plt.subplots` already returns the axes of the new figure, so there is no
    # need to fetch them again through `plt.gca()`.
    _, ax = plt.subplots(figsize=(9, 4))
    data_projection = np.real(data[variable_name])
    ax.hist(
        data_projection,
        bins=bins,
        alpha=0.5,
        label="data",
        density=True,
    )
    phsp_projection = np.real(phsp[variable_name])
    ax.hist(
        phsp_projection,
        weights=np.array(intensities),
        bins=bins,
        histtype="step",
        color="red",
        label="fit model",
        density=True,
    )
    indicate_masses(ax)
    ax.legend()

In [None]:
# Store `intensity_func.parameters` as it is before the update.
# NOTE(review): `original_parameters` is not used again in the visible cells.
original_parameters = intensity_func.parameters
# Move the function to the starting values and compare it with the data
# before fitting.
intensity_func.update_parameters(initial_parameters)
compare_model("m_12", data, phsp, intensity_func)

![](https://user-images.githubusercontent.com/29308176/194347028-0fe1354e-0f2d-4684-a6d7-172ab53db8d5.svg)

In [None]:
from tensorwaves.optimizer import Minuit2
from tensorwaves.optimizer.callbacks import CSVSummary

# First fit: non-caching estimator, analytic gradient switched off, and a CSV
# callback writing to "fit_traceback.csv" (read back in the analysis section).
minuit2 = Minuit2(
    callback=CSVSummary("fit_traceback.csv"),
    use_analytic_gradient=False,
)
fit_result = minuit2.optimize(estimator, initial_parameters)
# Display the fit result as cell output.
fit_result

In [None]:
# Second fit: a Minuit2 optimizer with default settings (this rebinds
# `minuit2`) on the estimator that uses the cached function, from the same
# starting values.
minuit2 = Minuit2()
fit_result_with_caching = minuit2.optimize(estimator_with_caching, initial_parameters)
fit_result_with_caching

### Fit result analysis

In [None]:
# Load the fitted values of the first fit into the function and redraw the
# comparison with the data.
intensity_func.update_parameters(fit_result.parameter_values)
compare_model("m_12", data, phsp, intensity_func)

![](https://user-images.githubusercontent.com/29308176/194347843-0ce5e251-a78a-4123-a9b3-09fdd632a524.svg)

In [None]:
# Read back the CSV file named in the `CSVSummary` callback of the first fit.
fit_traceback = pd.read_csv("fit_traceback.csv")
# Two stacked panels sharing the x axis; the lower one is twice as tall.
fig, (ax1, ax2) = plt.subplots(
    2, figsize=(7, 9), sharex=True, gridspec_kw={"height_ratios": [1, 2]}
)
# Top: estimator value per function call. Bottom: the columns named after the
# keys of `initial_parameters`, per function call.
fit_traceback.plot("function_call", "estimator_value", ax=ax1)
fit_traceback.plot("function_call", sorted(initial_parameters), ax=ax2)
fig.tight_layout()
ax2.set_xlabel("function call")
plt.show()

![](https://user-images.githubusercontent.com/29308176/194347685-7bd12eee-ea65-44e1-b6b5-3a615234040a.svg)

## Zfit

### PDF definition

In [None]:
import jax.numpy as jnp
import zfit  # suppress tf warnings
import zfit.z.numpy as znp
from zfit import supports, z

# Global zfit run settings for this notebook: graph mode off ...
zfit.run.set_graph_mode(False)  # We cannot (yet) compile through the function
# ... and autograd mode off as well.
zfit.run.set_autograd_mode(False)


class TensorWavesPDF(zfit.pdf.BasePDF):
    """zfit PDF that evaluates a tensorwaves intensity function.

    The intensity is evaluated on the requested events and, unless
    normalisation is switched off, divided by its mean over a fixed
    normalisation sample.
    """

    def __init__(self, intensity, norm, obs, params=None, name="tensorwaves"):
        """tensorwaves intensity normalized over the *norm* dataset."""
        super().__init__(obs, params, name)
        self.intensity = intensity
        # Store the normalisation sample as a dict of JAX arrays keyed by
        # observable, which is the input format the intensity is called with.
        self.norm_sample = {
            observable: jnp.asarray(column)
            for observable, column in zip(self.obs, z.unstack_x(norm))
        }

    @supports(norm=True)
    def _pdf(self, x, norm):
        # Normalisation is handled here explicitly instead of relying on an
        # automatic mechanism, to keep full control over it.

        # Push the current zfit parameter values into the intensity function.
        # This seems not very TF compatible?
        current_values = {}
        for parameter in self.params.values():
            current_values[parameter.name] = float(parameter)
        self.intensity.update_parameters(current_values)

        # Convert the events to a dict of JAX arrays keyed by observable.
        events = {
            observable: jnp.asarray(column)
            for observable, column in zip(self.obs, z.unstack_x(x))
        }

        unnormalized = self.intensity(events)
        # Returning the raw intensity is not strictly needed, but can be useful
        # for e.g. sampling with `pdf(..., norm_range=False)`.
        if norm is False:
            return znp.asarray(unnormalized)
        normalization = jnp.mean(self.intensity(self.norm_sample))
        return znp.asarray(unnormalized / normalization)

In [None]:
# Convert every entry of `model.parameter_defaults` to a zfit parameter; the
# key is passed as the name and the default as the value.
# `prefer_constant=False` is passed for all of them.
params = [
    zfit.param.convert_to_parameter(val, name, prefer_constant=False)
    for name, val in model.parameter_defaults.items()
]

In [None]:
def reset_parameters():
    """Put the fit parameters back on their starting values.

    Every parameter in the notebook-level ``params_fit`` whose name is a key of
    ``initial_parameters`` is set to that value; other parameters are left
    untouched.
    """
    for parameter in params_fit:
        if parameter.name not in initial_parameters:
            continue
        parameter.set_value(initial_parameters[parameter.name])

In [None]:
# One zfit space per column of the phase space sample. The limits are the
# minimum and maximum of the corresponding *data* column, widened by 1 on each
# side.
# NOTE(review): the limits come from `data` but the combined space is also
# used for `phsp` — confirm that the padding covers the phase space sample.
obs = [
    zfit.Space(ob, limits=(np.min(data[ob]) - 1, np.max(data[ob]) + 1))
    for ob in pd.DataFrame(phsp)
]
# Combine the one-dimensional spaces into a single multi-dimensional space.
obs_all = zfit.dimension.combine_spaces(*obs)

### Data conversion

In [None]:
# Wrap both samples as zfit datasets defined on the combined space `obs_all`.
phsp_zfit = zfit.Data.from_pandas(pd.DataFrame(phsp), obs=obs_all)
data_zfit = zfit.Data.from_pandas(pd.DataFrame(data), obs=obs_all)

### Perform fit

{obj}`complex` parameters need to be removed first:

In [None]:
# Keep only the zfit parameters that have a starting value in
# `initial_parameters` and whose `independent` attribute is truthy.
params_fit = [p for p in params if p.name in initial_parameters if p.independent]

In [None]:
# A separate parametrized function for the zfit PDF, built from the same
# unfolded expression and default parameters as `intensity_func`.
jax_intensity_func = create_parametrized_function(
    expression=unfolded_expression,
    parameters=model.parameter_defaults,
    backend="jax",
)

In [None]:
# Start the zfit fit from the values in `initial_parameters`.
reset_parameters()

In [None]:
# Wrap the JAX intensity function in the zfit PDF. The phase space dataset is
# the normalisation sample and the fit parameters are passed keyed by name.
pdf = TensorWavesPDF(
    obs=obs_all,
    intensity=jax_intensity_func,
    norm=phsp_zfit,
    params={p.name: p for p in params_fit},
)
# Unbinned negative log-likelihood of the PDF on the data sample.
loss = zfit.loss.UnbinnedNLL(pdf, data_zfit)

In [None]:
# Minuit-based zfit minimizer, configured with `gradient=True` and `mode=0`.
minimizer = zfit.minimize.Minuit(gradient=True, mode=0)

:::{note}
You can also try different minimizers, like {class}`~zfit.minimizers.minimizers_scipy.ScipyTrustConstrV1`, but {class}`~zfit.minimizers.minimizer_minuit.Minuit` seems to perform best.
:::

In [None]:
%%time

# Minimize the loss (timed by the cell magic) and display the fit result.
result = minimizer.minimize(loss)
result

In [None]:
%%time

# Run `hesse` on the fit result under the name "hesse", then display the
# result again.
result.hesse(name="hesse")
result

In [None]:
%%time

# Run `errors` on the fit result under the name "errors", then display the
# result again.
result.errors(name="errors")
result

### Statistical inference using the hepstats library

{mod}`hepstats` is built on top of [`zfit-interface`](https://zfit-interface.readthedocs.io):

In [None]:
from hepstats.hypotests import ConfidenceInterval
from hepstats.hypotests.calculators import AsymptoticCalculator
from hepstats.hypotests.parameters import POIarray

In [None]:
# Asymptotic calculator built from the zfit fit result and the minimizer.
calculator = AsymptoticCalculator(result, minimizer)

We take one of the parameters as the parameter of interest (POI):

In [None]:
# Pick the PDF parameter named \Gamma_{f_{0}(500)} and display it.
poi = pdf.params[r"\Gamma_{f_{0}(500)}"]
poi

In [None]:
# Scan 50 evenly spaced values from 0.1 below to 0.1 above the current value
# of the parameter.
poi_null = POIarray(poi, np.linspace(poi - 0.1, poi + 0.1, 50))
ci = ConfidenceInterval(calculator, poi_null)
# `alpha` is passed to `interval` and reused for the plot further down.
alpha = 0.328
# The trailing semicolon suppresses the cell output.
ci.interval(alpha=alpha);

A helper function to plot the result:

In [None]:
def one_minus_cl_plot(x, pvalues, alpha=None, ax=None):
    """Plot p-values against ``x`` and mark one or more ``alpha`` levels.

    Parameters
    ----------
    x
        Values for the horizontal axis.
    pvalues
        Values for the vertical axis, one per entry of ``x``.
    alpha
        A single number or an iterable of numbers; each one is drawn as a red
        horizontal line. Defaults to ``[0.32]`` when omitted.
    ax
        Axes to draw on. The current axes (``plt.gca()``) are used when
        omitted.

    Returns
    -------
    The axes that were drawn on.
    """
    target = plt.gca() if ax is None else ax

    if alpha is None:
        levels = [0.32]
    elif isinstance(alpha, (float, int)):
        levels = [alpha]
    else:
        levels = alpha

    target.plot(x, pvalues, ".--")
    for level in levels:
        target.axhline(level, color="red", label=f"$\\alpha = {level!s}$")
    target.set_ylabel("1-CL")

    return target

In [None]:
# Plot the p-values of the scan against the scanned parameter values on a new
# 9x8 inch figure, with the chosen `alpha` as a horizontal line.
plt.figure(figsize=(9, 8))
one_minus_cl_plot(poi_null.values, ci.pvalues(), alpha=alpha)
# The parameter name is wrapped in `$...$` for the x-axis label.
plt.xlabel(f"${poi.name}$")
plt.show()

![](https://user-images.githubusercontent.com/29308176/194348046-c0b9026e-d4e9-434a-830d-351e6ba6e635.svg)