In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Subset of the build-related environment variables that are currently set
# (presumably non-empty when the notebook is executed for the static docs).
STATIC_WEB_PAGE = {
    env_var for env_var in ("EXECUTE_NB", "READTHEDOCS") if env_var in os.environ
}

```{autolink-concat}
```

````{margin}
```{spec} Amplitude analysis with zfit
:id: TR-018
:status: WIP
:tags: physics;sympy;tensorwaves

This report builds a [simple symbolic amplitude model](https://tensorwaves.readthedocs.io/en/0.4.5/amplitude-analysis.html) with {mod}`qrules` and {mod}`ampform` and feeds it to [zfit](https://zfit.rtfd.io) instead of {mod}`tensorwaves`. See {issue}`ComPWA/compwa-org#156`.
```
````

# Amplitude analysis with zfit

In [None]:
# Pinned dependency versions for reproducibility; zfit is installed from a
# specific commit of its git repository.
%pip install -q ampform==0.14.1 pandas==1.4.2 sympy==1.10.1 tensorflow==2.6.5 tensorwaves[jax,pwa]==0.4.5 git+https://github.com/zfit/zfit@fbeb661

In [None]:
import logging
import os
import warnings

# Reduce log noise from the numerical backends before they are imported:
# TensorFlow's C++ logging via an environment variable, the "absl" logger
# (used by JAX) via the logging module, and all Python warnings.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
JAX_LOGGER = logging.getLogger("absl")
JAX_LOGGER.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

## Formulating the model

In [None]:
import qrules

# J/psi(1S) -> gamma pi0 pi0 through intermediate f0 resonances, formulated in
# the helicity formalism. The initial state is given with spin projections -1
# and +1; only strong and electromagnetic interactions are allowed.
reaction = qrules.generate_transitions(
    formalism="helicity",
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
)

In [None]:
import graphviz

# Render the generated transitions as a DOT graph (with collapse_graphs=True).
dot = qrules.io.asdot(
    reaction,
    collapse_graphs=True,
)
graphviz.Source(dot)

In [None]:
import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

# Formulate the symbolic amplitude model for the transitions generated above.
model_builder = ampform.get_builder(reaction)
# NOTE(review): presumably treats the initial-state mass as a fixed scalar
# value rather than a kinematic variable — confirm against the ampform docs.
model_builder.scalar_initial_state_mass = True
# Mark final-state particles 0, 1 and 2 as stable.
model_builder.stable_final_state_ids = [0, 1, 2]
# Initial state J/psi(1S): no dynamics, but with a form factor.
model_builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
# Every intermediate resonance gets a relativistic Breit-Wigner with form factor.
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.formulate()

## Generate data

### Phase space sample

In [None]:
from tensorwaves.data import (
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)

# Seeded random number generator; it is reused for the intensity-based sample.
rng = TFUniformRealNumberGenerator(seed=0)
phsp_generator = TFPhaseSpaceGenerator(
    final_state_masses={
        state_id: particle.mass
        for state_id, particle in reaction.final_state.items()
    },
    initial_state_mass=reaction.initial_state[-1].mass,
)
# 100,000 phase-space events as four-momenta per final-state id.
phsp_momenta = phsp_generator.generate(100_000, rng)

In [None]:
import numpy as np
import pandas as pd

# One column per (final-state id, four-momentum component).
# Assumes each momentum array has shape (n_events, 4) — TODO confirm.
pd.DataFrame(
    {
        (state_id, component): np.transpose(momenta)[column]
        for state_id, momenta in phsp_momenta.items()
        for column, component in enumerate(["E", "px", "py", "pz"])
    }
)

### Intensity-based sample

In [None]:
from tensorwaves.function.sympy import create_function

unfolded_expression = model.expression.doit()
# Substitute every parameter by its default value, so that the resulting JAX
# function only takes kinematic variables as input.
fixed_intensity_func = create_function(
    unfolded_expression.xreplace(model.parameter_defaults), backend="jax"
)

In [None]:
from tensorwaves.data import SympyDataTransformer

# Converts four-momenta into the model's kinematic variables (JAX backend).
transform_momenta = SympyDataTransformer.from_sympy(
    model.kinematic_variables,
    backend="jax",
)

In [None]:
from tensorwaves.data import (
    IntensityDistributionGenerator,
    TFWeightedPhaseSpaceGenerator,
)

weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    final_state_masses={
        state_id: particle.mass
        for state_id, particle in reaction.final_state.items()
    },
    initial_state_mass=reaction.initial_state[-1].mass,
)
# Draw events distributed according to the fixed-parameter intensity,
# using the momentum transformer to evaluate it on phase-space samples.
data_generator = IntensityDistributionGenerator(
    function=fixed_intensity_func,
    domain_generator=weighted_phsp_generator,
    domain_transformer=transform_momenta,
)
data_momenta = data_generator.generate(10_000, rng)

In [None]:
# Express both samples in kinematic variables instead of four-momenta.
phsp, data = transform_momenta(phsp_momenta), transform_momenta(data_momenta)
pd.DataFrame(data)

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm

# Order the resonances by mass so the rainbow colours run from light to heavy.
resonances = sorted(
    reaction.get_intermediate_particles(), key=lambda particle: particle.mass
)
evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = list(map(cm.rainbow, evenly_spaced_interval))
fig, ax = plt.subplots(figsize=(9, 4))
# Normalized m_12 distribution of the generated data sample.
ax.hist(np.real(data["m_12"]), bins=100, alpha=0.5, density=True)
ax.set_xlabel("$m$ [GeV]")
# Dotted vertical line at the nominal mass of each resonance.
for p, color in zip(resonances, colors):
    ax.axvline(x=p.mass, linestyle="dotted", label=p.name, color=color)
ax.legend()
plt.show()

## Fit

### Determine free parameters

In [None]:
# Starting values of the fit, keyed by parameter (symbol) name. Only these
# parameters are treated as free below; all other model parameters keep their
# default values. NOTE(review): the entry order presumably sets the parameter
# order passed to the optimizer, so keep it stable.
initial_parameters = {
    # complex coefficient of the f0(1500) decay chain
    R"C_{J/\psi(1S) \to {f_{0}(1500)}_{0} \gamma_{+1}; f_{0}(1500) \to \pi^{0}_{0} \pi^{0}_{0}}": 1.0
    + 0.0j,
    # resonance masses [GeV]
    "m_{f_{0}(500)}": 0.4,
    "m_{f_{0}(980)}": 0.88,
    "m_{f_{0}(1370)}": 1.22,
    "m_{f_{0}(1500)}": 1.45,
    "m_{f_{0}(1710)}": 1.83,
    # resonance widths
    R"\Gamma_{f_{0}(500)}": 0.3,
    R"\Gamma_{f_{0}(980)}": 0.1,
    R"\Gamma_{f_{0}(1710)}": 0.3,
}

### Parametrized function and caching

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# Intensity whose parameters can be updated later on (JAX backend).
intensity_func = create_parametrized_function(
    backend="jax",
    expression=unfolded_expression,
    parameters=model.parameter_defaults,
)

In [None]:
from tensorwaves.estimator import create_cached_function

# Symbols of the parameters that are optimized in the fit: those whose name is
# a key of `initial_parameters`. Dict membership checks the keys directly, so
# there is no need to rebuild a set on every iteration.
free_parameter_symbols = [
    symbol
    for symbol in model.parameter_defaults
    if symbol.name in initial_parameters
]
# Cached variant of the intensity: `transform_to_cache` converts a sample into
# the inputs of `cached_intensity_func`. Presumably, sub-expressions that do not
# depend on the free parameters are pre-computed there (see tensorwaves docs).
cached_intensity_func, transform_to_cache = create_cached_function(
    unfolded_expression,
    parameters=model.parameter_defaults,
    free_parameters=free_parameter_symbols,
    backend="jax",
)
cached_data = transform_to_cache(data)
cached_phsp = transform_to_cache(phsp)

### Estimator

In [None]:
from tensorwaves.estimator import UnbinnedNLL

# Unbinned negative log-likelihood on the plain samples ...
estimator = UnbinnedNLL(
    intensity_func,
    backend="jax",
    data=data,
    phsp=phsp,
)
# ... and on the pre-transformed (cached) samples with the cached function.
estimator_with_caching = UnbinnedNLL(
    cached_intensity_func,
    backend="jax",
    data=cached_data,
    phsp=cached_phsp,
)

### Optimize fit parameters

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

reaction_info = model.reaction_info
# Resonances ordered by mass, each with its own evenly spaced rainbow colour;
# used by the plotting helpers below.
resonances = sorted(
    reaction_info.get_intermediate_particles(),
    key=lambda particle: particle.mass,
)
evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = list(map(cm.rainbow, evenly_spaced_interval))


def indicate_masses(ax, particles=None, line_colors=None):
    """Label the x-axis as mass and mark each resonance mass with a dotted line.

    Args:
        ax: matplotlib axes to draw on.
        particles: objects with ``name`` and ``mass`` attributes to indicate.
            Defaults to the module-level ``resonances``.
        line_colors: one color per particle, paired in order. Defaults to the
            module-level ``colors``.
    """
    # Fall back to the notebook globals so existing calls keep working.
    if particles is None:
        particles = resonances
    if line_colors is None:
        line_colors = colors
    ax.set_xlabel("$m$ [GeV]")
    for color, resonance in zip(line_colors, particles):
        ax.axvline(
            x=resonance.mass,
            linestyle="dotted",
            label=resonance.name,
            color=color,
        )


def compare_model(
    variable_name,
    data,
    phsp,
    function,
    bins=100,
):
    """Overlay the data distribution with the model-weighted phase space.

    Args:
        variable_name: key of the kinematic variable to project onto.
        data: mapping of kinematic variable name to values for the data sample.
        phsp: same mapping for the phase-space sample; its events are weighted
            by ``function`` to draw the model curve.
        function: intensity function, evaluated on ``phsp``.
        bins: number of histogram bins.
    """
    intensities = function(phsp)
    # Draw on the axes returned by subplots; no need to query pyplot's
    # current-axes state afterwards.
    _, ax = plt.subplots(figsize=(9, 4))
    data_projection = np.real(data[variable_name])
    ax.hist(
        data_projection,
        bins=bins,
        alpha=0.5,
        label="data",
        density=True,
    )
    phsp_projection = np.real(phsp[variable_name])
    ax.hist(
        phsp_projection,
        weights=np.array(intensities),
        bins=bins,
        histtype="step",
        color="red",
        label="fit model",
        density=True,
    )
    indicate_masses(ax)
    ax.legend()

In [None]:
# Store the current parameter values before overwriting them with the fit's
# starting values. NOTE(review): assumes `.parameters` returns a copy rather
# than a live view — confirm; the name is not used further in this section.
original_parameters = intensity_func.parameters
intensity_func.update_parameters(initial_parameters)
# Compare the starting model (before fitting) to the data.
compare_model("m_12", data, phsp, intensity_func)

In [None]:
from tensorwaves.optimizer import Minuit2
from tensorwaves.optimizer.callbacks import CSVSummary

# Minuit2 without the backend's analytic gradient; every estimator call is
# logged to a CSV file so the fit history can be plotted later.
minuit2 = Minuit2(
    use_analytic_gradient=False,
    callback=CSVSummary("fit_traceback.csv"),
)
fit_result = minuit2.optimize(estimator, initial_parameters)
fit_result

In [None]:
# Same fit with default optimizer settings, but on the cached estimator.
minuit2 = Minuit2()
fit_result_with_caching = minuit2.optimize(estimator_with_caching, initial_parameters)
fit_result_with_caching

### Fit result analysis

In [None]:
# Load the best-fit values of the (non-cached) fit into the intensity and
# compare the fitted model to the data.
intensity_func.update_parameters(fit_result.parameter_values)
compare_model("m_12", data, phsp, intensity_func)

In [None]:
# Fit history written by the CSVSummary callback: estimator value (top) and
# parameter values (bottom) per function call.
fit_traceback = pd.read_csv("fit_traceback.csv")
fig, (ax1, ax2) = plt.subplots(
    2, figsize=(7, 9), sharex=True, gridspec_kw={"height_ratios": [1, 2]}
)
fit_traceback.plot("function_call", "estimator_value", ax=ax1)
fit_traceback.plot("function_call", sorted(initial_parameters), ax=ax2)
# Set every label before tight_layout() so the layout takes it into account.
ax2.set_xlabel("function call")
fig.tight_layout()

## Zfit

### PDF definition

In [None]:
import tensorflow.experimental.numpy as tnp
import zfit  # suppress tf warnings
from zfit import z
from zfit.core.space import supports

# Run zfit eagerly (graph mode off) and without automatic differentiation.
# TensorWavesPDF._pdf converts zfit parameters to Python floats, which presumably
# only works eagerly and breaks gradient tracking.
zfit.run.set_graph_mode(False)
zfit.run.set_autograd_mode(False)


class TensorWavesPDF(zfit.pdf.BasePDF):
    """zfit PDF wrapping a parametrized tensorwaves intensity function.

    The intensity is normalized by its mean over a fixed normalization sample
    (e.g. a phase-space sample) instead of by an analytic integral.
    """

    def __init__(self, intensity, norm, obs, params=None, name="tensorwaves"):
        """tensorwaves intensity normalized over the *norm* dataset.

        Args:
            intensity: parametrized tensorwaves function; its parameters are
                overwritten from the zfit parameters on every evaluation.
            norm: zfit data sample over which the intensity is averaged.
            obs: zfit space spanned by the kinematic variables.
            params: mapping of zfit parameters; their ``name`` must match the
                parameter names of ``intensity``.
            name: name of the PDF.
        """
        # Pass by keyword: in zfit's BasePDF, the third positional parameter is
        # `dtype`, not `name` (confirm against the pinned zfit commit), so a
        # positional `name` would be misrouted.
        super().__init__(obs=obs, params=params, name=name)
        self.intensity = intensity
        # tensorwaves expects a mapping of variable name -> array
        self.norm_sample = {
            ob: tnp.array(column) for ob, column in zip(self.obs, z.unstack_x(norm))
        }

    @supports(norm=True)
    def _pdf(self, x, norm_range):
        """Evaluate the intensity on *x*, normalized unless ``norm_range`` is False.

        Normalization is done explicitly here, rather than through zfit's
        automatic mechanism, to keep full control over it.
        """
        # Push the current zfit parameter values into the tensorwaves function.
        # Converting to Python floats is not TF-graph compatible, hence the
        # eager mode set with zfit.run.set_graph_mode(False).
        self.intensity.update_parameters(
            {p.name: float(p) for p in self.params.values()}
        )
        data = {ob: tnp.array(ar) for ob, ar in zip(self.obs, z.unstack_x(x))}
        non_normalized_pdf = self.intensity(data)
        # Not strictly needed, but useful for e.g. sampling with
        # `pdf(..., norm_range=False)`.
        if norm_range is False:
            return non_normalized_pdf
        return non_normalized_pdf / tnp.mean(self.intensity(self.norm_sample))

In [None]:
# One zfit parameter per model parameter. The parameter_defaults keys are sympy
# Symbols; pass their *string* name so zfit parameter names match the string
# keys of `initial_parameters` and the names used by tensorwaves'
# update_parameters(). Complex defaults become composite (non-independent)
# parameters, which are filtered out before fitting.
params = [
    zfit.param.convert_to_parameter(value, symbol.name, prefer_constant=False)
    for symbol, value in model.parameter_defaults.items()
]

In [None]:
# One observable per kinematic variable. The limits must enclose both the data
# and the phase-space (normalization) sample, so derive them from both samples,
# with the same margin of 1 on either side.
obs = [
    zfit.Space(
        ob,
        limits=(
            np.min(np.concatenate([data[ob], phsp[ob]])) - 1,
            np.max(np.concatenate([data[ob], phsp[ob]])) + 1,
        ),
    )
    for ob in pd.DataFrame(phsp)
]
obs_all = zfit.dimension.combine_spaces(*obs)

### Data conversion

In [None]:
# Wrap both samples as zfit datasets over the combined observable space.
phsp_zfit, data_zfit = (
    zfit.Data.from_pandas(pd.DataFrame(sample), obs=obs_all)
    for sample in (phsp, data)
)

### Perform fit

In [None]:
# Fit only the parameters listed in `initial_parameters`; composite (complex)
# parameters are not independent and are skipped.
params_fit = [p for p in params if p.name in initial_parameters and p.independent]

In [None]:
# TensorFlow flavour of the parametrized intensity, for use inside zfit.
tf_intensity_func = create_parametrized_function(
    backend="tf",
    parameters=model.parameter_defaults,
    expression=unfolded_expression,
)

In [None]:
# Start the zfit fit from the same values as the tensorwaves fit.
for p in params_fit:
    if p.name not in initial_parameters:
        continue
    p.set_value(initial_parameters[p.name])

In [None]:
# PDF normalized over the phase-space sample; the free parameters are
# registered under generic keys ("param0", "param1", ...).
pdf = TensorWavesPDF(
    intensity=tf_intensity_func,
    norm=phsp_zfit,
    obs=obs_all,
    params={f"param{index}": param for index, param in enumerate(params_fit)},
)
loss = zfit.loss.UnbinnedNLL(pdf, data_zfit)

In [None]:
%%time
# Minimize with Minuit and gradient=True (per the zfit docs, Minuit then uses
# its own numerical gradient — confirm for the pinned zfit commit).
minimizer = zfit.minimize.Minuit(gradient=True)
result = minimizer.minimize(loss)
result

In [None]:
%%time
minimizer = zfit.minimize.Minuit(gradient=False)
result = minimizer.minimize(loss)
result

In [None]:
%%time
# Estimate parameter uncertainties from the Hessian for the most recent fit
# result (the gradient=False fit).
result.hesse()