```{autolink-concat}
```

::::{margin}
:::{card} Investigation of acceptance matrix with Dalitz-plot decomposition
TR-035
^^^
We investigate how to compute the acceptance matrix for a three-body decay with a spinful final state ($J/\psi \to \bar{p} K^0_S \Sigma^+$). The acceptance matrix then becomes a complex-valued matrix.
:::
::::

# DPD acceptance matrix

In [None]:
from __future__ import annotations

import logging
import re
import warnings
from collections.abc import Iterable

import attrs
import graphviz
import jax.numpy as jnp
import qrules
import sympy as sp
from ampform.dynamics.form_factor import FormFactor
from ampform.kinematics.lorentz import (
    Energy,
    EuclideanNorm,
    FourMomentumSymbol,
    InvariantMass,
    ThreeMomentum,
)
from ampform.sympy import PoolSum
from ampform.sympy._decorator import unevaluated
from ampform_dpd import AmplitudeModel, DalitzPlotDecompositionBuilder
from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay
from ampform_dpd.decay import IsobarNode, Particle, State, ThreeBodyDecayChain
from ampform_dpd.dynamics import RelativisticBreitWigner
from ampform_dpd.io import simplify_latex_rendering
from IPython.display import HTML, display
from tensorwaves.data.phasespace import TFPhaseSpaceGenerator
from tensorwaves.data.rng import TFUniformRealNumberGenerator
from tensorwaves.data.transform import SympyDataTransformer

simplify_latex_rendering()

logging.getLogger().setLevel(logging.ERROR)
warnings.simplefilter("ignore", category=RuntimeWarning)

In [None]:
display(HTML("<style>.container { width:100% !important; }</style>"))


To speed up the fit, the part of the negative log-likelihood (NLL) function that represents the normalization, computed over the Monte Carlo phase-space (phsp) sample, can be pre-computed. Note that this method can only be applied in the _specific case that masses and widths are fixed_, because the normalization depends on these parameters.

Statistically, this approach can be derived as follows.
Given a total sample $\vec{x}$ consisting of $N$ independent observations $x_i$, each a set of measured quantities, the likelihood function can be written as the product of the probability density functions (PDFs) $f$ of the individual observations:

$$L(\vec{x} ; \vec{\theta})=\prod_{i=1}^N f\left(x_i ; \vec{\theta}\right)$$

Here, $\vec{\theta}$ represents the set of $m$ unknown parameters that the chosen PDF depends on.
In the case of this amplitude analysis, the PDF is the intensity function normalized over phase space:

$$f(x_i ; \vec{\theta})= \frac{I(x_i ; \vec{\theta})}{\int I(\vec{x} ; \vec{\theta}) d\vec{x}}$$ 

Each $x_i$ represents a selected event, containing the Mandelstam variables and the decay and production angles. With this expression for $f$, the NLL function is given by:

$$
\mathrm{NLL}(\vec{x} ; \vec{\theta})=-\sum^N_{i=1} \log \left(\frac{I(x_i ; \vec{\theta})}{\int I(\vec{x} ; \vec{\theta}) d\vec{x}}\right) = \underbrace{-\sum^N_{i=1} \log(I(x_i ; \vec{\theta}))}_{\mathrm{Data}} + N\cdot \log\left(\int I(\vec{x} ; \vec{\theta}) d\vec{x} \right)
$$
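The following is a minimal numerical check of this split. It assumes the unnormalized intensity has already been evaluated on the $N$ data events and on a phase-space sample. The function names and toy arrays are hypothetical. The normalization integral is estimated as the sample mean, without the constant phase-space volume. Dropping that volume only shifts the NLL by $N \log \int d\vec{x}$.

```python
import numpy as np


def nll_direct(intensity_data: np.ndarray, intensity_mc: np.ndarray) -> float:
    """NLL with the normalized PDF f = I / <I>_MC evaluated event by event."""
    norm = np.mean(intensity_mc)  # MC estimate of the integral, volume dropped
    return float(-np.sum(np.log(intensity_data / norm)))


def nll_split(intensity_data: np.ndarray, intensity_mc: np.ndarray) -> float:
    """Same NLL written as data term plus N * log(normalization)."""
    n_events = len(intensity_data)
    data_term = -np.sum(np.log(intensity_data))
    normalization_term = n_events * np.log(np.mean(intensity_mc))
    return float(data_term + normalization_term)


rng = np.random.default_rng(seed=1)
intensity_data = rng.uniform(0.5, 2.0, size=50)  # toy values, not a real sample
intensity_mc = rng.uniform(0.5, 2.0, size=1_000)
print(nll_direct(intensity_data, intensity_mc), nll_split(intensity_data, intensity_mc))
```

Rescaling the intensity by a constant factor leaves `nll_split` unchanged, because the data term and the normalization term shift by opposite amounts.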

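The second term integrates over the full phase space, so it does not depend on the individual events $x_i$. The intensity is a coherent sum of amplitudes $A_k$ weighted by complex couplings $c_k$, summed incoherently over unobserved spin projections $\lambda$: $I(\vec{x};\vec{\theta}) = \sum_\lambda \left|\sum_k c_k A_k^{\lambda}(\vec{x})\right|^2$. The integral therefore factorizes into the couplings and a coupling-independent matrix. That matrix can be estimated with Monte Carlo integration over a phase-space sample of $N_{\mathrm{MC}}$ events:

$$
\int I(\vec{x} \,; \vec{\theta}) \, d\vec{x}
\approx \left( \int d\vec{x} \right) \cdot \frac{1}{N_{\mathrm{MC}}} \sum_{j=1}^{N_{\mathrm{MC}}} I(\vec{x}_j^{\mathrm{MC}} \,; \vec{\theta})
= \left( \int d\vec{x} \right) \cdot c^{\dagger} X c,
\qquad
X_{kl} = \frac{1}{N_{\mathrm{MC}}} \sum_{j=1}^{N_{\mathrm{MC}}} \sum_\lambda A_k^{\lambda*}(\vec{x}_j^{\mathrm{MC}}) \, A_l^{\lambda}(\vec{x}_j^{\mathrm{MC}})
$$

The matrix $X$ is the so-called acceptance matrix. It is Hermitian ($X_{lk} = X_{kl}^*$). For fixed masses and widths it does not depend on the couplings, so it only has to be computed once. The constant phase-space volume only shifts the NLL by a constant. Up to that volume, the integrated intensity follows from the **bilinear relation**: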

$$
\int I \, d\vec{x} = c^{\dagger} \cdot X \cdot c
$$
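The sketch below illustrates the bilinear relation on toy data. It assumes the $K$ amplitudes have been evaluated on the phase-space sample with all couplings set to one and stored in a complex array of shape $(N_\mathrm{MC}, K)$. The random amplitudes and the function names are hypothetical stand-ins for the model in this notebook. $X$ is computed once and then reused for any set of couplings:

```python
import numpy as np


def acceptance_matrix(amplitudes: np.ndarray) -> np.ndarray:
    """X_kl = 1/N_MC * sum_j conj(A_k(x_j)) * A_l(x_j).

    `amplitudes` has shape (N_MC, K): one column per amplitude.
    """
    n_mc = amplitudes.shape[0]
    return amplitudes.conj().T @ amplitudes / n_mc


def integrated_intensity(x_matrix: np.ndarray, couplings: np.ndarray) -> float:
    """Bilinear relation c^dagger X c, real because X is Hermitian."""
    return float(np.real(couplings.conj() @ x_matrix @ couplings))


rng = np.random.default_rng(seed=0)
amplitudes = rng.normal(size=(10_000, 3)) + 1j * rng.normal(size=(10_000, 3))
x_matrix = acceptance_matrix(amplitudes)  # computed once

couplings = np.array([1.0, 0.5 - 0.2j, -0.3j])
# Direct MC average of |sum_k c_k A_k|^2, which loops over all events again
direct = np.mean(np.abs(amplitudes @ couplings) ** 2)
print(integrated_intensity(x_matrix, couplings), direct)
```

If the intensity also sums incoherently over spin projections, as in the polarized model prepared in this notebook, $X$ is the sum of one such matrix per helicity combination.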
## Construction of the acceptance matrix
The acceptance matrix can be constructed by probing the integrated intensity with four combinations of values for $c_i$ and $c_j$, with all other couplings set to zero. This leads to the four equations:

$$
A_{ij}^{+}=\int I(c_i=1,\, c_j=1) \, d\vec{x} =
\begin{pmatrix} 1 & 1 \end{pmatrix}
\begin{pmatrix} X_{ii} & X_{ij} \\ X_{ji} & X_{jj} \end{pmatrix}
\begin{pmatrix} 1 \\ 1 \end{pmatrix}
= X_{ii} + X_{jj} + 2\,\mathrm{Re}\,X_{ij}
$$

$$
A_{ij}^{-}=\int I(c_i=1,\, c_j=-1) \, d\vec{x} =
\begin{pmatrix} 1 & -1 \end{pmatrix}
\begin{pmatrix} X_{ii} & X_{ij} \\ X_{ji} & X_{jj} \end{pmatrix}
\begin{pmatrix} 1 \\ -1 \end{pmatrix}
= X_{ii} + X_{jj} - 2\,\mathrm{Re}\,X_{ij}
$$

$$
B_{ij}^{+}=\int I(c_i=1,\, c_j=i) \, d\vec{x} =
\begin{pmatrix} 1 & -i \end{pmatrix}
\begin{pmatrix} X_{ii} & X_{ij} \\ X_{ji} & X_{jj} \end{pmatrix}
\begin{pmatrix} 1 \\ i \end{pmatrix}
= X_{ii} + X_{jj} - 2\,\mathrm{Im}\,X_{ij}
$$

$$
B_{ij}^{-}=\int I(c_i=1,\, c_j=-i) \, d\vec{x} =
\begin{pmatrix} 1 & i \end{pmatrix}
\begin{pmatrix} X_{ii} & X_{ij} \\ X_{ji} & X_{jj} \end{pmatrix}
\begin{pmatrix} 1 \\ -i \end{pmatrix}
= X_{ii} + X_{jj} + 2\,\mathrm{Im}\,X_{ij}
$$

In each equation the row vector is the conjugate transpose $c^\dagger$ of the probe vector. The last step uses the Hermiticity $X_{ji} = X_{ij}^*$. The $A_{ij}^{\pm}$ and $B_{ij}^{\pm}$ are elements of the sub-intensity matrix. Each one is the phase-space integral of the intensity with $c_i$ and $c_j$ set to the respective values and all other couplings set to zero.
The elements of $X$ can finally be expressed as:

$$
X_{ij} = \frac{A_{ij}^{+}-A_{ij}^{-}}{4}+i\cdot\frac{B_{ij}^{-}-B_{ij}^{+}}{4}
$$

As the expression shows, two probes determine the real part of $X_{ij}$ and two determine its imaginary part.
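Here is a minimal sketch of the four-probe construction. It assumes a hypothetical callable `integrate(c)` that returns the integrated intensity for a full coupling vector `c`, for example $c^\dagger X c$ or a Monte Carlo integral of the model. With $X_{ji}=X_{ij}^*$, the probes give $A_{ij}^{\pm} = X_{ii}+X_{jj}\pm 2\,\mathrm{Re}\,X_{ij}$ and $B_{ij}^{\pm} = X_{ii}+X_{jj}\mp 2\,\mathrm{Im}\,X_{ij}$. That is the sign convention the code uses for the imaginary part. The diagonal elements come from the single probe $c_i=1$:

```python
from typing import Callable, Mapping

import numpy as np


def _probe(integrate: Callable[[np.ndarray], float], n: int, values: Mapping[int, complex]) -> float:
    """Integrate with the given couplings set and all other couplings zero."""
    c = np.zeros(n, dtype=complex)
    for index, value in values.items():
        c[index] = value
    return integrate(c)


def x_from_four_probes(integrate: Callable[[np.ndarray], float], n: int) -> np.ndarray:
    x = np.zeros((n, n), dtype=complex)
    for i in range(n):
        x[i, i] = _probe(integrate, n, {i: 1})  # X_ii: only c_i = 1
        for j in range(n):
            if j == i:
                continue
            a_plus = _probe(integrate, n, {i: 1, j: 1})
            a_minus = _probe(integrate, n, {i: 1, j: -1})
            b_plus = _probe(integrate, n, {i: 1, j: 1j})
            b_minus = _probe(integrate, n, {i: 1, j: -1j})
            x[i, j] = (a_plus - a_minus) / 4 + 1j * (b_minus - b_plus) / 4
    return x


rng = np.random.default_rng(seed=2)
m = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
x_true = m.conj().T @ m  # toy Hermitian matrix standing in for the real X


def integrate(c: np.ndarray) -> float:
    """Hypothetical stand-in for integrating the model over phase space."""
    return float(np.real(c.conj() @ x_true @ c))


x_reco = x_from_four_probes(integrate, n=4)
```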
## Speed up the construction of $X$
Using four probes per element is not computationally efficient, because the sub-intensity has to be computed for all four combinations of $c_i$ and $c_j$. Two properties of $X$ reduce the work.

First, $X$ is Hermitian: it is mirror-symmetric with respect to its main diagonal, up to complex conjugation of all entries. So only the upper triangle ($i<j$) has to be computed, and the lower triangle is filled with the complex conjugates.

Second, the diagonal elements $X_{ii}$ are real. They equal the diagonal of the sub-intensity matrix, i.e. the integrated intensity with only $c_i = 1$.

With the diagonal known, two probes per off-diagonal element suffice. The vector ($c_i=1$, $c_j=1$) yields the real part of $X_{ij}$, and ($c_i=i$, $c_j=1$) yields the imaginary part:

$$
A_{c_i=1,c_j=1} = X_{ii} + X_{jj} + 2\,\mathrm{Re}\,X_{ij}, \qquad
A_{c_i=i,c_j=1} = X_{ii} + X_{jj} + 2\,\mathrm{Im}\,X_{ij}
$$

The speed-up construction therefore results in the following expression:

$$
X_{ij} = \frac{A_{c_i=1,c_j=1} + i\cdot A_{c_i=i,c_j=1} - (1 + i)\left(X_{ii} + X_{jj}\right)}{2}
$$
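The two-probe variant can be sketched with a hypothetical `integrate(c)` callable that returns the integrated intensity for a coupling vector `c`. Only the diagonal and the upper triangle are probed, and the lower triangle is filled by complex conjugation. For $K$ couplings this needs $K$ diagonal probes plus $2 \cdot K(K-1)/2$ off-diagonal probes, i.e. $K^2$ integrations in total:

```python
from typing import Callable

import numpy as np


def x_from_two_probes(integrate: Callable[[np.ndarray], float], n: int) -> np.ndarray:
    x = np.zeros((n, n), dtype=complex)
    diagonal = np.zeros(n)
    for i in range(n):
        c = np.zeros(n, dtype=complex)
        c[i] = 1
        diagonal[i] = integrate(c)  # X_ii: only c_i = 1, real valued
        x[i, i] = diagonal[i]
    for i in range(n):
        for j in range(i + 1, n):  # upper triangle only
            c = np.zeros(n, dtype=complex)
            c[i], c[j] = 1, 1
            a_11 = integrate(c)  # X_ii + X_jj + 2 Re X_ij
            c[i] = 1j
            a_i1 = integrate(c)  # X_ii + X_jj + 2 Im X_ij
            x[i, j] = (a_11 + 1j * a_i1 - (1 + 1j) * (diagonal[i] + diagonal[j])) / 2
            x[j, i] = np.conj(x[i, j])  # Hermitian: lower triangle by conjugation
    return x


rng = np.random.default_rng(seed=3)
m = rng.normal(size=(4, 4)) + 1j * rng.normal(size=(4, 4))
x_true = m.conj().T @ m  # toy Hermitian matrix standing in for the real X
n_calls = 0


def integrate(c: np.ndarray) -> float:
    """Hypothetical stand-in for integrating the model; counts its calls."""
    global n_calls
    n_calls += 1
    return float(np.real(c.conj() @ x_true @ c))


x_reco = x_from_two_probes(integrate, n=4)
```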

## Prepare model and data

### Define reaction

In [None]:
PDG = qrules.load_default_particles()
PDG.add(
    qrules.particle.create_particle(
        template_particle=PDG.find("N(1720)+"),
        name="N(2060)+",
        mass=2.1,
        width=0.4,
        pid=200004,
        parity=-1,
        spin=5 / 2,
        latex="N(2060)^+",
    ),
)

In [None]:
reaction = qrules.generate_transitions(
    initial_state=[("J/psi(1S)", [-1, +1])],
    final_state=["K0", "Sigma+", "p~"],
    allowed_interaction_types="strong",
    allowed_intermediate_particles=[
        "N(2060)",
        "Sigma(1750)",
        "Sigma(1775)",
    ],
    mass_conservation_factor=0,
    particle_db=PDG,
)
reaction = normalize_state_ids(reaction)
dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)

### Define amplitude model

In [None]:
def formulate_breit_wigner_with_ff(
    chain: ThreeBodyDecayChain,
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    s = _get_mandelstam_s(chain)
    parameter_defaults = {}
    production_ff, new_pars = _create_form_factor(s, chain.production_node)
    parameter_defaults.update(new_pars)
    decay_ff, new_pars = _create_form_factor(s, chain.decay_node)
    parameter_defaults.update(new_pars)
    breit_wigner, new_pars = _create_breit_wigner(s, chain.decay_node)
    parameter_defaults.update(new_pars)
    return (
        production_ff * decay_ff * breit_wigner,
        parameter_defaults,
    )


def _create_form_factor(
    s: sp.Symbol,
    isobar: IsobarNode,
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    assert isobar.interaction is not None, "Need LS-couplings"
    inv_mass = _generate_mass_symbol(isobar.parent, s)
    outgoing_state_mass1 = _generate_mass_symbol(_get_particle(isobar.child1), s)
    outgoing_state_mass2 = _generate_mass_symbol(_get_particle(isobar.child2), s)
    meson_radius = _create_meson_radius_symbol(isobar.parent)
    form_factor = FormFactor(
        s=inv_mass**2,
        m1=outgoing_state_mass1,
        m2=outgoing_state_mass2,
        angular_momentum=isobar.interaction.L,
        meson_radius=meson_radius,
    )
    parameter_defaults = {
        meson_radius: 1,
    }
    return form_factor, parameter_defaults


def _generate_mass_symbol(state: State | Particle, s: sp.Symbol) -> sp.Symbol:
    if isinstance(state, State):
        return create_mass_symbol(state)
    return sp.sqrt(s)


def _create_breit_wigner(
    s: sp.Symbol,
    isobar: IsobarNode,
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    assert isobar.interaction is not None, "Need LS-couplings"
    outgoing_state_mass1 = create_mass_symbol(isobar.child1)
    outgoing_state_mass2 = create_mass_symbol(isobar.child2)
    angular_momentum = isobar.interaction.L
    res_mass = create_mass_symbol(isobar.parent)
    res_width = sp.Symbol(Rf"\Gamma_{{{isobar.parent.latex}}}", nonnegative=True)
    meson_radius = _create_meson_radius_symbol(isobar.parent)

    breit_wigner_expr = RelativisticBreitWigner(
        s=s,
        mass0=res_mass,
        gamma0=res_width,
        m1=outgoing_state_mass1,
        m2=outgoing_state_mass2,
        angular_momentum=angular_momentum,
        meson_radius=meson_radius,
    )
    parameter_defaults = {
        res_mass: isobar.parent.mass,
        res_width: isobar.parent.width,
        meson_radius: 1,
    }
    return breit_wigner_expr, parameter_defaults


def _create_meson_radius_symbol(node: IsobarNode | Particle) -> sp.Symbol:
    particle = _get_particle(node)
    if isinstance(particle, State):
        if particle.index != 0:
            msg = "Only the initial state has a meson radius"
            raise NotImplementedError(msg)
        return sp.Symbol(R"R_{J/\psi}")
    return sp.Symbol(Rf"R_\mathrm{{{particle.latex}}}")


def create_mass_symbol(particle: IsobarNode | Particle) -> sp.Symbol:
    particle = _get_particle(particle)
    if isinstance(particle, State):
        return sp.Symbol(f"m{particle.index}", nonnegative=True)
    return sp.Symbol(f"m_{{{particle.latex}}}", nonnegative=True)


def _get_mandelstam_s(decay: ThreeBodyDecayChain) -> sp.Symbol:
    s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
    m1, m2, m3 = map(create_mass_symbol, decay.final_state)
    decay_masses = {create_mass_symbol(p) for p in decay.decay_products}
    if decay_masses == {m2, m3}:
        return s1
    if decay_masses == {m1, m3}:
        return s2
    if decay_masses == {m1, m2}:
        return s3
    msg = f"Cannot find Mandelstam variable for {', '.join(str(m) for m in decay_masses)}"
    raise NotImplementedError(msg)


def _get_particle(isobar: IsobarNode | State) -> State | Particle:
    if isinstance(isobar, IsobarNode):
        return isobar.parent
    return isobar

In [None]:
def prepare_for_phsp(model: AmplitudeModel) -> AmplitudeModel:
    p1, p2, p3 = (FourMomentumSymbol(f"p{i}", shape=[]) for i in (1, 2, 3))
    s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
    mass_definitions = {
        s1: InvariantMassSquared(p2 + p3),
        s2: InvariantMassSquared(p1 + p3),
        s3: InvariantMassSquared(p1 + p2),
        sp.Symbol("m0", nonnegative=True): InvariantMass(p1 + p2 + p3),
        sp.Symbol("m1", nonnegative=True): InvariantMass(p1),
        sp.Symbol("m2", nonnegative=True): InvariantMass(p2),
        sp.Symbol("m3", nonnegative=True): InvariantMass(p3),
    }
    mass_definitions = {k: sp.sympify(v) for k, v in mass_definitions.items()}
    angle_and_mandelstam_definitions = {
        symbol: expr.xreplace(mass_definitions)
        for symbol, expr in model.variables.items()
    }
    angle_and_mandelstam_definitions.update(mass_definitions)
    polarized_intensity = set_initial_state_polarization(
        model.intensity,
        spin_projections=(-1, +1),
    )
    new_parameter_defaults = {
        symbol: v
        for symbol, v in model.parameter_defaults.items()
        if not re.match(r"m[0-3]", symbol.name)
    }
    return attrs.evolve(
        model,
        intensity=polarized_intensity,
        variables=angle_and_mandelstam_definitions,
        parameter_defaults=new_parameter_defaults,
    )


def set_initial_state_polarization(
    intensity: PoolSum, spin_projections: Iterable[sp.Rational | float]
) -> PoolSum:
    helicity_symbol, _ = intensity.indices[0]
    helicity_values = tuple(sp.Rational(i) for i in spin_projections)
    new_indices = (
        (helicity_symbol, helicity_values),
        *intensity.indices[1:],
    )
    return PoolSum(intensity.expression, *new_indices)


@unevaluated
class InvariantMassSquared(sp.Expr):
    momentum: sp.Basic
    _latex_repr_ = R"M^2\left({momentum}\right)"

    def evaluate(self) -> sp.Expr:
        p = self.momentum
        p_xyz = ThreeMomentum(p)
        return Energy(p) ** 2 - EuclideanNorm(p_xyz) ** 2
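The mass definitions in `prepare_for_phsp` can be cross-checked numerically with plain NumPy, independent of the symbolic `InvariantMassSquared` expression. The sketch below uses hypothetical helper names, random toy momenta and toy masses, and the metric $(+,-,-,-)$. It checks the Dalitz-plot identity $\sigma_1+\sigma_2+\sigma_3 = m_0^2+m_1^2+m_2^2+m_3^2$:

```python
import numpy as np


def invariant_mass_squared(p: np.ndarray) -> np.ndarray:
    """M^2 = E^2 - |p|^2 for four-momenta stored as (..., 4) arrays (E, px, py, pz)."""
    return p[..., 0] ** 2 - np.sum(p[..., 1:] ** 2, axis=-1)


def on_shell(mass: float, momentum: np.ndarray) -> np.ndarray:
    """Build on-shell four-momenta from a mass and (..., 3) three-momenta."""
    energy = np.sqrt(mass**2 + np.sum(momentum**2, axis=-1))
    return np.concatenate([energy[..., None], momentum], axis=-1)


rng = np.random.default_rng(seed=4)
masses = (0.5, 1.2, 0.9)  # toy final-state masses
p1, p2, p3 = (on_shell(m, rng.normal(size=(5, 3))) for m in masses)
sigma1 = invariant_mass_squared(p2 + p3)
sigma2 = invariant_mass_squared(p1 + p3)
sigma3 = invariant_mass_squared(p1 + p2)
m0_squared = invariant_mass_squared(p1 + p2 + p3)
print(sigma1 + sigma2 + sigma3 - m0_squared)  # equals m1^2 + m2^2 + m3^2 per event
```

Because the identity only involves invariant masses, it is a convenient sanity check for a generated phase-space sample.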

In [None]:
decay = to_three_body_decay(reaction.transitions, min_ls=False)
builder = DalitzPlotDecompositionBuilder(decay, min_ls=False)
for chain in builder.decay.chains:
    builder.dynamics_choices.register_builder(chain, formulate_breit_wigner_with_ff)
model = builder.formulate(reference_subsystem=2, cleanup_summations=True)
model = prepare_for_phsp(model)
model.intensity.cleanup()

### Generate phase space sample

In [None]:
dpd_transformer = SympyDataTransformer.from_sympy(model.variables, backend="jax")

In [None]:
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[0].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
rng = TFUniformRealNumberGenerator(seed=0)
phsp = phsp_generator.generate(100_000, rng)
phsp = dpd_transformer(phsp)
phsp = {k: phsp[k] for k in sorted(phsp)}
phsp = {k: v if jnp.iscomplex(v).any() else v.real for k, v in phsp.items()}

In [None]:
phsp

## Compute acceptance matrix