# J/ψ → K⁰Σ⁺p̅

```{autolink-concat}
```

In [None]:
# pyright: reportPrivateUsage=false
# pyright: reportUnusedExpression=false
from __future__ import annotations

import itertools
import logging
import os
from typing import Iterable

import jax.numpy as jnp
import matplotlib.pyplot as plt
import qrules
import sympy as sp
from ampform.dynamics import EnergyDependentWidth, formulate_form_factor
from ampform.kinematics.phasespace import compute_third_mandelstam
from ampform.sympy import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
    make_commutative,
)
from IPython.display import Latex, Markdown
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import create_function

from ampform_dpd import (
    DalitzPlotDecompositionBuilder,
    _get_particle,
    simplify_latex_rendering,
)
from ampform_dpd.decay import (
    IsobarNode,
    Particle,
    ThreeBodyDecay,
    ThreeBodyDecayChain,
)
from ampform_dpd.io import as_markdown_table, aslatex, perform_cached_doit
from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings

# NOTE(review): presumably switches ampform-dpd expressions to a more compact
# LaTeX rendering — confirm against the ampform_dpd documentation.
simplify_latex_rendering()
logging.getLogger("absl").setLevel(logging.ERROR)  # mute JAX
# When the EXECUTE_NB environment variable is set, only errors from
# ampform_dpd.io are logged.
NO_TQDM = "EXECUTE_NB" in os.environ
if NO_TQDM:
    logging.getLogger("ampform_dpd.io").setLevel(logging.ERROR)

## Decay definition

We follow [this example](https://qrules.readthedocs.io/en/0.9.7/usage.html#investigate-intermediate-resonances), which was generated with QRules, and leave out the $K$-resonances and the resonances that lie far outside of phase space:

![](https://qrules.readthedocs.io/en/0.9.7/_images/usage_9_0.svg)

:::{warning}
In the above figure, the final states are labeled `0`, `1`, `2`, but in the DPD formalism, the final states are labeled `1`, `2`, `3`.
:::

In [None]:
PDG = qrules.load_pdg()
# Convert every PDG entry with a defined parity into an ampform-dpd Particle
# (entries without parity are skipped, since int(None) is not possible).
PARTICLE_DB = {
    entry.name: Particle(
        name=entry.name,
        latex=entry.latex,
        spin=entry.spin,
        parity=int(entry.parity),
        mass=entry.mass,
        width=entry.width,
    )
    for entry in PDG
    if entry.parity is not None
}
Jpsi = PARTICLE_DB["J/psi(1S)"]
K = PARTICLE_DB["K0"]
Σ = PARTICLE_DB["Sigma+"]
pbar = PARTICLE_DB["p~"]
# State IDs: 0 for the initial state J/ψ, then 1, 2, 3 for the final states.
PARTICLE_TO_ID = {
    particle: state_id for state_id, particle in enumerate((Jpsi, K, Σ, pbar))
}
Markdown(as_markdown_table(list(PARTICLE_TO_ID)))

In [None]:
# Resonances included in the model: anti-Σ* states (decaying to p̅K⁰) and
# N* states (decaying to K⁰Σ⁺).
resonance_names = [
    "Sigma(1660)~-",
    "Sigma(1670)~-",
    "Sigma(1750)~-",
    "Sigma(1775)~-",
    "Sigma(1910)~-",
    "N(1675)+",
    "N(1700)+",
    "N(1710)+",
    "N(1720)+",
]
resonances = list(map(PARTICLE_DB.__getitem__, resonance_names))
Markdown(as_markdown_table(resonances))

In [None]:
def load_three_body_decay(
    resonance_names: Iterable[str],
    particle_definitions: dict[str, Particle],
    min_ls: bool = True,
) -> ThreeBodyDecay:
    """Build the J/ψ → K⁰Σ⁺p̅ `ThreeBodyDecay` from a list of resonance names.

    Args:
        resonance_names: Names to look up in *particle_definitions*; each one
            becomes one or more decay chains via `_create_isobar`.
        particle_definitions: Mapping from particle name to `Particle`.
        min_ls: Forwarded to `_create_isobar` for every resonance.

    Returns:
        A decay whose states are taken from `PARTICLE_TO_ID` and whose chains
        are the concatenation of the chains of all resonances, in input order.
    """
    # Resolve all names first, so that an unknown name fails before any chain
    # is built.
    selected = [particle_definitions[name] for name in resonance_names]
    all_chains = tuple(
        chain
        for particle in selected
        for chain in _create_isobar(particle, min_ls)
    )
    id_to_state = {idx: particle for particle, idx in PARTICLE_TO_ID.items()}
    return ThreeBodyDecay(states=id_to_state, chains=all_chains)


def _create_isobar(resonance: Particle, min_ls: bool) -> list[ThreeBodyDecayChain]:
    """Create the decay chain(s) J/ψ → (resonance → child1 child2) spectator.

    The two-body subsystem is chosen from the prefix of the resonance name:

    - ``Sigma…``: resonance → p̅ K⁰, spectator Σ⁺
    - ``N…``: resonance → K⁰ Σ⁺, spectator p̅
    - ``K…``: resonance → Σ⁺ p̅, spectator K⁰

    Production (J/ψ vertex) LS-couplings are generated without a parity
    filter; decay (resonance vertex) LS-couplings are filtered for parity
    conservation.

    Args:
        resonance: The intermediate state.
        min_ls: If `True`, return a single chain that uses the smallest
            ``(L, S)`` coupling at each vertex. Otherwise, return one chain per
            combination of decay and production couplings.

    Raises:
        NotImplementedError: If the resonance name does not start with
            ``Sigma``, ``N``, or ``K``.
    """
    if resonance.name.startswith("Sigma"):
        child1, child2, spectator = pbar, K, Σ
    elif resonance.name.startswith("N"):
        child1, child2, spectator = K, Σ, pbar
    elif resonance.name.startswith("K"):
        child1, child2, spectator = Σ, pbar, K
    else:
        msg = f"Cannot determine decay products for resonance {resonance.name!r}"
        raise NotImplementedError(msg)
    prod_ls_couplings = _generate_ls(
        Jpsi, resonance, spectator, conserve_parity=False
    )
    dec_ls_couplings = _generate_ls(resonance, child1, child2, conserve_parity=True)
    if min_ls:
        # Restrict both vertices to their smallest (L, S); the product below
        # then yields exactly one chain.
        dec_ls_couplings = [min(dec_ls_couplings)]
        prod_ls_couplings = [min(prod_ls_couplings)]
    chains = []
    for dec_ls, prod_ls in itertools.product(dec_ls_couplings, prod_ls_couplings):
        decay = IsobarNode(
            parent=Jpsi,
            child1=IsobarNode(
                parent=resonance,
                child1=child1,
                child2=child2,
                interaction=dec_ls,
            ),
            child2=spectator,
            interaction=prod_ls,
        )
        chains.append(ThreeBodyDecayChain(decay))
    return chains


def _generate_ls(
    parent: Particle, child1: Particle, child2: Particle, conserve_parity: bool
) -> list[tuple[int, sp.Rational]]:
    """Return the LS-couplings for the two-body vertex parent → child1 child2.

    If *conserve_parity* is set, couplings that violate parity (given the
    three particles' parities) are removed with `filter_parity_violating_ls`.
    """
    couplings = generate_ls_couplings(parent.spin, child1.spin, child2.spin)
    if not conserve_parity:
        return couplings
    return filter_parity_violating_ls(
        couplings, parent.parity, child1.parity, child2.parity
    )


# Build the decay with one chain per resonance (smallest (L, S) at each vertex,
# min_ls=True) and display it. NOTE(review): with_jp=True presumably adds the
# J^P quantum numbers to the rendering.
DECAY = load_three_body_decay(
    resonance_names,
    particle_definitions=PARTICLE_DB,
    min_ls=True,
)
Latex(aslatex(DECAY, with_jp=True))

## Lineshapes for dynamics

:::{note}
As opposed to [AmpForm](https://ampform.rtfd.io), AmpForm-DPD defines dynamics over the **entire decay chain**, not a single isobar node. The dynamics classes and the corresponding builders would have to be extended to implement other dynamics lineshapes.
:::

In the following, we define the **relativistic Breit-Wigner function** as:

In [None]:
@make_commutative
@implement_doit_method
class RelativisticBreitWigner(UnevaluatedExpression):
    """Relativistic Breit-Wigner with an energy-dependent width.

    Evaluates to ``m0·Γ0 / (m0² − s − i·m0·Γ_L(s))``, where ``Γ_L(s)`` is an
    `~ampform.dynamics.EnergyDependentWidth` built from the daughter masses
    ``m1``, ``m2``, the angular momentum ``L`` and the meson radius.
    """

    def __new__(cls, s, mass0, gamma0, m1, m2, angular_momentum, meson_radius):
        arguments = (s, mass0, gamma0, m1, m2, angular_momentum, meson_radius)
        return create_expression(cls, *arguments)

    def evaluate(self):
        s, mass, gamma, m_a, m_b, angular_momentum, radius = self.args
        running_width = EnergyDependentWidth(
            s=s,
            mass0=mass,
            gamma0=gamma,
            m_a=m_a,
            m_b=m_b,
            angular_momentum=angular_momentum,
            meson_radius=radius,
            name=Rf"\Gamma_{{{sp.latex(angular_momentum)}}}",
        )
        numerator = mass * gamma
        denominator = mass**2 - s - running_width * mass * sp.I
        return numerator / denominator

    def _latex(self, printer, *args) -> str:
        # Print every argument (as before), but only show s, m0, Γ0 and L.
        printed = [printer._print(arg) for arg in self.args]
        s, m0, w0, L = printed[0], printed[1], printed[2], printed[5]
        return Rf"\mathcal{{R}}_{{{L}}}\left({s}, {m0}, {w0}\right)"


# Instantiate with plain symbols and show the expression next to its
# one-level (deep=False) evaluation.
bw = RelativisticBreitWigner(*sp.symbols("s m0 Gamma0 m1 m2 L R"))
Latex(aslatex({bw: bw.doit(deep=False)}))

where $\Gamma_L(s)$ is an {class}`~ampform.dynamics.EnergyDependentWidth` for angular momentum $L$. We define the **form factor** as:

In [None]:
@make_commutative
@implement_doit_method
class FormFactor(UnevaluatedExpression):
    """Form factor of a two-body vertex with daughter masses ``m1`` and ``m2``.

    Evaluates to `ampform.dynamics.formulate_form_factor` for angular momentum
    ``L`` and the given meson radius.
    """

    def __new__(cls, s, m1, m2, angular_momentum, meson_radius):
        return create_expression(cls, s, m1, m2, angular_momentum, meson_radius)

    def evaluate(self):
        s, m1, m2, angular_momentum, meson_radius = self.args
        return formulate_form_factor(
            s=s,
            m_a=m1,
            m_b=m2,
            angular_momentum=angular_momentum,
            meson_radius=meson_radius,
        )

    def _latex(self, printer, *args) -> str:
        s, m1, m2, L, *_ = map(printer._print, self.args)
        # Render both daughter masses, m1 and m2.
        return Rf"\mathcal{{F}}_{{{L}}}\left({s}, {m1}, {m2}\right)"


# Instantiate with plain symbols and show the expression next to its
# one-level (deep=False) evaluation.
ff = FormFactor(*sp.symbols("s m1 m2 L R"))
Latex(aslatex({ff: ff.doit(deep=False)}))

Here, $B_L^2$ is a {class}`~ampform.dynamics.BlattWeisskopfSquared`.

In [None]:
def formulate_breit_wigner_with_ff(
    decay_chain: ThreeBodyDecayChain,
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    """Formulate the dynamics of a chain as F_production · F_decay · BW.

    Args:
        decay_chain: Chain whose production node has the resonance vertex (an
            `IsobarNode`) as its first child.

    Returns:
        The product of the production-vertex form factor, the decay-vertex
        form factor and the relativistic Breit-Wigner of the resonance, plus
        default values for the parameter symbols those factors introduce.
    """
    prod_node = decay_chain.decay
    assert isinstance(prod_node.child1, IsobarNode), "Not a 3-body isobar decay"
    res_node = prod_node.child1
    s = _get_mandelstam_s(decay_chain)
    pieces = [
        _create_form_factor(s, prod_node),
        _create_form_factor(s, res_node),
        _create_breit_wigner(s, res_node),
    ]
    defaults: dict[sp.Symbol, float] = {}
    for _, new_defaults in pieces:
        defaults.update(new_defaults)
    (prod_ff, _), (dec_ff, _), (bw, _) = pieces
    return prod_ff * dec_ff * bw, defaults


def _create_form_factor(
    s: sp.Symbol, isobar: IsobarNode
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    """Create the form factor of the two-body vertex *isobar*.

    The form factor is evaluated at the squared mass symbol of the vertex's
    parent, not at *s*. *s* is accepted but unused here (presumably to keep
    the same signature as `_create_breit_wigner`).

    Returns:
        The `FormFactor` expression and a default of 1 for its meson radius.
    """
    assert isobar.interaction is not None, "Need LS-couplings"
    radius = _create_meson_radius_symbol(isobar.parent)
    parent_mass = _create_mass_symbol(isobar.parent)
    expr = FormFactor(
        s=parent_mass**2,
        m1=_create_mass_symbol(isobar.child1),
        m2=_create_mass_symbol(isobar.child2),
        angular_momentum=isobar.interaction.L,
        meson_radius=radius,
    )
    return expr, {radius: 1}


def _create_breit_wigner(
    s: sp.Symbol, isobar: IsobarNode
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    """Create the relativistic Breit-Wigner of the resonance at vertex *isobar*.

    Returns:
        The `RelativisticBreitWigner` in *s*, and defaults for the resonance
        mass and width (from the `Particle` definition) and the meson radius (1).
    """
    assert isobar.interaction is not None, "Need LS-couplings"
    res = isobar.parent
    mass = _create_mass_symbol(res)
    width = sp.Symbol(Rf"\Gamma_{{{res.latex}}}", nonnegative=True)
    radius = _create_meson_radius_symbol(res)
    expr = RelativisticBreitWigner(
        s=s,
        mass0=mass,
        gamma0=width,
        m1=_create_mass_symbol(isobar.child1),
        m2=_create_mass_symbol(isobar.child2),
        angular_momentum=isobar.interaction.L,
        meson_radius=radius,
    )
    defaults = {mass: res.mass, width: res.width, radius: 1}
    return expr, defaults


def _create_meson_radius_symbol(isobar: IsobarNode | Particle) -> sp.Symbol:
    """Return the meson-radius symbol for a vertex with the given parent.

    Callers pass ``isobar.parent`` (a `Particle`), so both `Particle` and
    `IsobarNode` are accepted and resolved with `_get_particle`.
    The J/ψ vertex gets its own radius symbol; all resonance vertices share
    ``R_res``.
    """
    if _get_particle(isobar) == Jpsi:
        return sp.Symbol(R"R_{J/\psi}")
    return sp.Symbol(R"R_\mathrm{res}")


def _create_mass_symbol(particle: IsobarNode | Particle) -> sp.Symbol:
    """Return a non-negative mass symbol for *particle*.

    Initial and final states are named by their state ID (``m0`` … ``m3``).
    Any other particle (a resonance) is named after its LaTeX label, e.g.
    ``m_{N(1675)}``.
    """
    resolved = _get_particle(particle)
    if resolved in PARTICLE_TO_ID:
        name = f"m{PARTICLE_TO_ID[resolved]}"
    else:
        name = f"m_{{{resolved.latex}}}"
    return sp.Symbol(name, nonnegative=True)


def _get_mandelstam_s(decay: ThreeBodyDecayChain) -> sp.Symbol:
    """Return the Mandelstam variable of the chain's resonance subsystem.

    ``sigma_k`` is the squared invariant mass of the two final-state particles
    other than particle ``k`` (1 = K⁰, 2 = Σ⁺, 3 = p̅).

    Raises:
        NotImplementedError: If the chain's decay products do not form one of
            the three final-state pairs.
    """
    s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
    m1, m2, m3 = map(_create_mass_symbol, [K, Σ, pbar])
    decay_masses = {_create_mass_symbol(p) for p in decay.decay_products}
    if decay_masses == {m2, m3}:
        return s1
    if decay_masses == {m1, m3}:
        return s2
    if decay_masses == {m1, m2}:
        return s3
    # The set holds sp.Symbol objects, which str.join cannot concatenate
    # directly; convert (and sort, for a deterministic message) first.
    mass_names = ", ".join(sorted(str(mass) for mass in decay_masses))
    msg = f"Cannot find Mandelstam variable for {mass_names}"
    raise NotImplementedError(msg)

## Model formulation

The total, aligned intensity expression looks as follows:

In [None]:
# Assemble the DPD model; every chain gets the form-factor × Breit-Wigner
# dynamics from formulate_breit_wigner_with_ff.
model_builder = DalitzPlotDecompositionBuilder(DECAY, min_ls=True)
for chain in model_builder.decay.chains:
    model_builder.dynamics_choices.register_builder(
        chain, formulate_breit_wigner_with_ff
    )
# NOTE(review): reference_subsystem=1 presumably selects subsystem 1 (the Σ⁺p̅
# pair, sigma1) as the alignment reference — confirm in the ampform-dpd docs.
model = model_builder.formulate(reference_subsystem=1)
model.intensity

where the angles can be computed from the initial- and final-state masses $m_0$, $m_1$, $m_2$, and $m_3$, together with the Mandelstam variables $\sigma_1$, $\sigma_2$, and $\sigma_3$:

In [None]:
# Display the definitions of the model's kinematic variables.
Latex(aslatex(model.variables))

In [None]:
# Numerical masses of the initial state (m0) and the final states (m1-m3),
# taken from the particle definitions; they become model parameter defaults.
masses = {
    _create_mass_symbol(particle): particle.mass
    for particle in (Jpsi, K, Σ, pbar)
}
model.parameter_defaults.update(masses)
Latex(aslatex(masses))

Each **unaligned** amplitude is defined as follows:

In [None]:
# Display the unaligned amplitude definitions of the model.
Latex(aslatex(model.amplitudes))

## Preparing for input data

The {meth}`~sympy.core.basic.Basic.doit` operation can be cached to disk with {func}`.perform_cached_doit`. We do this twice: once to unfold the {attr}`~.AmplitudeModel.intensity` expression, and once more to substitute and unfold the {attr}`~.AmplitudeModel.amplitudes`. Note that we could also have unfolded the intensity and substituted the amplitudes in one go with {attr}`~.AmplitudeModel.full_expression`, but then the unfolded {attr}`~.AmplitudeModel.intensity` expression would not be cached.

In [None]:
# First cached doit: unfold the intensity expression.
unfolded_intensity_expr = perform_cached_doit(model.intensity)
# Second cached doit: substitute the amplitude definitions and unfold those.
full_intensity_expr = perform_cached_doit(
    unfolded_intensity_expr.xreplace(model.amplitudes)
)

We set each helicity coupling to $1$, so that each parameter {class}`~sympy.core.symbol.Symbol` in the expression has a definition:

In [None]:
# Every helicity-coupling symbol (recognized by "production" or "decay" in its
# name) gets the default value 1.
couplings = {
    symbol: 1
    for symbol in full_intensity_expr.free_symbols
    if any(tag in str(symbol) for tag in ("production", "decay"))
}
model.parameter_defaults.update(couplings)

With this, the remaining {class}`~sympy.core.symbol.Symbol`s in the full expression are kinematic variables.

In [None]:
# Symbols that still have no default value, i.e. the kinematic variables.
sp.Array(full_intensity_expr.free_symbols - set(model.parameter_defaults))

The $\theta$ and $\zeta$ angles are defined by the {attr}`~.AmplitudeModel.variables` attribute (they are shown under {ref}`jpsi2ksp:Model formulation`). Those definitions allow us to create a converter that computes kinematic variables from masses and Mandelstam variables:

In [None]:
# Compile model.variables into a JAX transformer that computes the kinematic
# variables from masses and Mandelstam variables; display its functions.
masses_to_angles = SympyDataTransformer.from_sympy(model.variables, backend="jax")
masses_to_angles.functions

## Dalitz plot

The data input for this data transformer can be several things. One can compute them from a (generated) data sample of four-momenta. Or one can compute them for a Dalitz plane. We do the latter in this section.

First, the data transformer defined above expects values for the masses. We have already defined these values above, but we need to convert them from {mod}`sympy` objects to numerical data:

In [None]:
# Start the data sample from the masses, keyed by symbol name (e.g. "m0").
dalitz_data = {str(symbol): float(value) for symbol, value in masses.items()}

Next, we define a grid of data points over the Mandelstam (Dalitz) variables $\sigma_2=m_{13}^2$ and $\sigma_3=m_{12}^2$:

In [None]:
# Regular resolution × resolution grid: sigma3 along x, sigma2 along y. The
# bounds are squared masses; NOTE(review): hand-picked to enclose the Dalitz
# region — confirm.
resolution = 500
X, Y = jnp.meshgrid(
    jnp.linspace(1.66**2, 2.18**2, num=resolution),
    jnp.linspace(1.4**2, 1.93**2, num=resolution),
)
dalitz_data["sigma3"] = X
dalitz_data["sigma2"] = Y

The remaining Mandelstam variable can be expressed in terms of the others as follows:

In [None]:
s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
# Look the mass symbols up per particle (0 = J/ψ, 1 = K⁰, 2 = Σ⁺, 3 = p̅)
# instead of relying on the lexicographic order of their names.
m0, m1, m2, m3 = map(_create_mass_symbol, [Jpsi, K, Σ, pbar])
s1_expr = compute_third_mandelstam(s3, s2, m0, m1, m2, m3)
Latex(aslatex({s1: s1_expr}))

That completes the data sample over which we want to evaluate the intensity model defined above:

In [None]:
# Evaluate sigma1 on the grid from the masses, sigma2 and sigma3 already in
# dalitz_data, then display the completed sample.
sigma1_func = create_function(s1_expr, backend="jax")
dalitz_data["sigma1"] = sigma1_func(dalitz_data)
dalitz_data

We can now extend the sample with angle definitions so that we have a data sample over which the intensity can be evaluated.

In [None]:
# Add the kinematic variables from model.variables to the data sample.
angle_data = masses_to_angles(dalitz_data)
dalitz_data.update(angle_data)

In [None]:
# Sanity check: every variable must have at least one non-NaN value.
# NaNs for individual grid points are tolerated (presumably points outside
# the Dalitz region; the plots below use nansum).
for k, v in dalitz_data.items():
    assert not jnp.all(jnp.isnan(v)), f"All values for {k} are NaN"

In [None]:
# Compile the full intensity with all parameter defaults substituted, so that
# the function only depends on the kinematic variables.
intensity_func = create_function(
    full_intensity_expr.xreplace(model.parameter_defaults),
    backend="jax",
)

In [None]:
# Render the dense Dalitz-plot mesh as PNG (NOTE(review): presumably to keep
# the notebook output small; SVG is restored afterwards).
%config InlineBackend.figure_formats = ['png']

In [None]:
plt.rc("font", size=18)
intensities = intensity_func(dalitz_data)
# Normalize to unit total over the grid, ignoring NaN grid points.
normalized_intensities = intensities / jnp.nansum(intensities)

fig, ax = plt.subplots(figsize=(14, 10))
mesh = ax.pcolormesh(X, Y, normalized_intensities)
ax.set_aspect("equal")
c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)
c_bar.ax.set_ylabel("Normalized intensity (a.u.)")
ax.set_xlabel(R"$\sigma_3 = M^2\left(K^0\Sigma^+\right)$")
ax.set_ylabel(R"$\sigma_2 = M^2\left(K^0\bar{p}\right)$")
plt.show()

In [None]:
# Switch back to SVG output for the (lightweight) projection plots.
%config InlineBackend.figure_formats = ['svg']

In [None]:
# Projections of the normalized Dalitz-plot intensity onto M(K⁰Σ⁺) (left) and
# M(K⁰p̅) (right), with dashed lines at the nominal resonance masses.
plt.rc("font", size=16)
fig, axes = plt.subplots(figsize=(18, 6), ncols=2, sharey=True)
fig.subplots_adjust(wspace=0.02)
ax1, ax2 = axes
# X varies along columns and Y along rows, so sum over the other axis and
# take the square root to go from sigma to invariant mass.
ax1.fill_between(jnp.sqrt(X[0]), jnp.nansum(normalized_intensities, axis=0))
ax2.fill_between(jnp.sqrt(Y[:, 0]), jnp.nansum(normalized_intensities, axis=1))
for ax in axes:
    _, y_max = ax.get_ylim()
    ax.set_ylim(0, y_max)
    ax.autoscale(enable=False, axis="x")
ax1.set_ylabel("Normalized intensity (a.u.)")
ax1.set_xlabel(R"$M\left(K^0\Sigma^+\right)$")
ax2.set_xlabel(R"$M\left(K^0\bar{p}\right)$")
# Per-panel counters select a distinct colour for each resonance line,
# starting at C1 so that the fill colour (C0) is not reused.
i1, i2 = 0, 0
for chain in model.decay.chains:
    resonance = chain.resonance
    decay_product = set(chain.decay_products)
    if decay_product == {K, Σ}:
        ax = ax1
        i1 += 1
        i = i1
    elif decay_product == {K, pbar}:
        ax = ax2
        i2 += 1
        i = i2
    else:
        # Chains in the Σ⁺p̅ subsystem have no projection panel here.
        continue
    ax.axvline(resonance.mass, label=f"${resonance.latex}$", c=f"C{i}", ls="dashed")
for ax in axes:
    ax.legend(fontsize=12)
plt.show()