# J/ψ → K⁰ Σ⁺ p̅

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import logging
import os
import warnings
from typing import TYPE_CHECKING

import graphviz
import jax.numpy as jnp
import matplotlib.pyplot as plt
import qrules
import sympy as sp
from ampform.dynamics import BlattWeisskopfSquared, EnergyDependentWidth
from ampform.sympy import perform_cached_doit
from IPython.display import Latex, Markdown
from tensorwaves.data.transform import SympyDataTransformer
from tqdm.auto import tqdm

from ampform_dpd import DalitzPlotDecompositionBuilder, get_particle
from ampform_dpd.adapter.qrules import normalize_state_ids, to_three_body_decay
from ampform_dpd.decay import IsobarNode, Particle, ThreeBodyDecayChain
from ampform_dpd.dynamics import FormFactor, RelativisticBreitWigner
from ampform_dpd.io import (
    as_markdown_table,
    aslatex,
    perform_cached_lambdify,
    simplify_latex_rendering,
)

simplify_latex_rendering()
warnings.simplefilter("ignore")
logging.getLogger("absl").setLevel(logging.ERROR)  # mute JAX

# When the EXECUTE_NB environment variable is present (presumably a
# non-interactive notebook execution, e.g. a docs build), progress bars are
# disabled and ampform_dpd.io only reports errors.
NO_TQDM = os.environ.get("EXECUTE_NB") is not None
if NO_TQDM:
    logging.getLogger("ampform_dpd.io").setLevel(logging.ERROR)
if TYPE_CHECKING:
    from tensorwaves.interface import DataSample, ParametrizedFunction

## Decay definition

We follow [this example](https://qrules.readthedocs.io/en/0.10.1/usage.html#investigate-intermediate-resonances), which was generated with QRules, and leave out the $K$-resonances and the resonances that lie far outside of phase space.

In [None]:
# Generate all strong-interaction transitions J/psi -> K0 Sigma+ p~ in the
# canonical-helicity formalism. `mass_conservation_factor` controls how strictly
# intermediate resonance masses must fit the available energy (see QRules docs).
REACTION = qrules.generate_transitions(
    initial_state="J/psi(1S)",
    final_state=["K0", "Sigma+", "p~"],
    allowed_interaction_types="strong",
    formalism="canonical-helicity",
    mass_conservation_factor=0.05,
)
# Relabel the state IDs into the convention that ampform_dpd's qrules adapter
# expects (exact numbering scheme: see normalize_state_ids -- TODO confirm).
REACTION = normalize_state_ids(REACTION)
# Merge the transition graphs into a compact overview for display.
dot = qrules.io.asdot(REACTION, collapse_graphs=True)
graphviz.Source(dot)

In [None]:
# Convert the QRules transitions into a three-body decay description.
# min_ls=True presumably keeps only the lowest LS coupling per node -- TODO confirm.
DECAY = to_three_body_decay(REACTION.transitions, min_ls=True)
# Tabulate the initial state and the three final-state particles.
Markdown(as_markdown_table([DECAY.initial_state, *DECAY.final_state.values()]))

In [None]:
# Collect every distinct resonance once, then order them by the first letter of
# their name and, within such a group, by mass.
unique_resonances = {chain.resonance for chain in DECAY.chains}
resonances = sorted(unique_resonances, key=lambda res: (res.name[0], res.mass))
resonance_names = [res.name for res in resonances]
Markdown(as_markdown_table(resonances))

In [None]:
# Render the decay chains as LaTeX, including the J^P quantum numbers.
Latex(aslatex(DECAY, with_jp=True))

## Lineshapes for dynamics

:::{note}
As opposed to [AmpForm](https://ampform.rtfd.io), AmpForm-DPD defines dynamics over the **entire decay chain**, not a single isobar node. The dynamics classes and the corresponding builders would have to be extended to implement other dynamics lineshapes.
:::

In [None]:
# Generic placeholder symbols, used only to display the lineshape definitions.
s, m0, w0, m1, m2, L, R, z = sp.symbols("s m0 Gamma0 m1 m2 L R z")
exprs = [
    RelativisticBreitWigner(s, m0, w0, m1, m2, L, R),
    EnergyDependentWidth(s, m0, w0, m1, m2, L, R),
    FormFactor(s, m1, m2, L, R),
    BlattWeisskopfSquared(z, L),
]
# Show each expression next to its definition, unfolded by one level only.
Latex(aslatex({e: e.doit(deep=False) for e in exprs}))

In [None]:
def formulate_breit_wigner_with_ff(
    decay_chain: ThreeBodyDecayChain,
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    """Formulate the dynamics of a complete three-body decay chain.

    The lineshape is the product of the form factor of the production node, the
    form factor of the resonance decay node, and a relativistic Breit-Wigner for
    the resonance, all expressed in the chain's Mandelstam variable.

    Returns:
        The lineshape expression and the default values of all symbols that the
        three factors introduce.
    """
    production_node = decay_chain.decay
    assert isinstance(production_node.child1, IsobarNode), "Not a 3-body isobar decay"
    decay_node = production_node.child1
    s = _get_mandelstam_s(decay_chain)

    production_ff, production_defaults = _create_form_factor(s, production_node)
    decay_ff, decay_defaults = _create_form_factor(s, decay_node)
    breit_wigner, bw_defaults = _create_breit_wigner(s, decay_node)

    # Later factors override earlier ones on shared symbols (e.g. the resonance's
    # meson radius, which appears in both the decay form factor and the BW).
    parameter_defaults = {**production_defaults, **decay_defaults, **bw_defaults}
    return production_ff * decay_ff * breit_wigner, parameter_defaults


def _create_form_factor(
    s: sp.Symbol, isobar: IsobarNode
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    """Create the Blatt-Weisskopf form factor of a single isobar node.

    For the production node (parent ``J/psi(1S)``) the form factor is evaluated at
    the squared initial-state mass ``m0**2``. For a resonance decay node it is
    evaluated at the node's Mandelstam variable ``sigma_k``. That variable already
    is a squared invariant mass, just like the ``s`` argument of
    ``RelativisticBreitWigner``.

    Args:
        s: Mandelstam variable of the decay chain. Not used directly, because the
            invariant is derived from ``isobar`` itself. Kept so that all lineshape
            helpers share the same signature.
        isobar: Node whose ``parent`` decays into ``child1`` and ``child2``. It must
            carry LS-couplings (``interaction``).

    Returns:
        The form-factor expression and default values for the symbols it
        introduces: meson radius, both child masses, and ``m0`` for the
        production node.
    """
    assert isobar.interaction is not None, "Need LS-couplings"
    initial_state_mass = None
    if isobar.parent.name == "J/psi(1S)":
        initial_state_mass = sp.Symbol("m0", nonnegative=True)
        squared_inv_mass = initial_state_mass**2
    else:
        # sigma_k is already a squared invariant mass, so it must NOT be squared
        # again (the Breit-Wigner of the same chain also uses s=sigma_k).
        squared_inv_mass = _get_mandelstam_s(isobar)
    outgoing_state_mass1 = _create_mass_symbol(isobar.child1)
    outgoing_state_mass2 = _create_mass_symbol(isobar.child2)
    meson_radius = _create_meson_radius_symbol(isobar.parent)
    form_factor = FormFactor(
        s=squared_inv_mass,
        m1=outgoing_state_mass1,
        m2=outgoing_state_mass2,
        angular_momentum=isobar.interaction.L,
        meson_radius=meson_radius,
    )
    parameter_defaults = {
        meson_radius: 1,
        outgoing_state_mass1: get_particle(isobar.child1).mass,
        outgoing_state_mass2: get_particle(isobar.child2).mass,
    }
    if initial_state_mass is not None:
        # Only the fixed initial-state mass is a parameter; sigma_k is kinematic.
        parameter_defaults[initial_state_mass] = get_particle(isobar).mass
    return form_factor, parameter_defaults


def _create_breit_wigner(
    s: sp.Symbol, isobar: IsobarNode
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    """Create a relativistic Breit-Wigner for the resonance that ``isobar`` decays.

    Args:
        s: Mandelstam variable at which the lineshape is evaluated.
        isobar: Decay node of the resonance. It must carry LS-couplings.

    Returns:
        The Breit-Wigner expression and default values for the resonance mass,
        the resonance width, and the meson radius.
    """
    assert isobar.interaction is not None, "Need LS-couplings"
    resonance = isobar.parent
    mass = _create_mass_symbol(resonance)
    width = sp.Symbol(Rf"\Gamma_{{{resonance.latex}}}", nonnegative=True)
    radius = _create_meson_radius_symbol(resonance)
    expression = RelativisticBreitWigner(
        s=s,
        mass0=mass,
        gamma0=width,
        m1=_create_mass_symbol(isobar.child1),
        m2=_create_mass_symbol(isobar.child2),
        angular_momentum=isobar.interaction.L,
        meson_radius=radius,
    )
    return expression, {mass: resonance.mass, width: resonance.width, radius: 1}


def _create_meson_radius_symbol(isobar: IsobarNode | Particle) -> sp.Symbol:
    """Return the meson-radius symbol for the form factor of a decaying particle.

    The J/psi gets its own radius ``R_{J/psi}``. All other (resonance) particles
    share the single symbol ``R_res``.

    Args:
        isobar: The decaying particle, or an isobar node. It is resolved with
            ``get_particle``. The callers in this notebook pass
            ``IsobarNode.parent``, a ``Particle``, which is why the annotation
            accepts both types.
    """
    if get_particle(isobar).name == "J/psi(1S)":
        return sp.Symbol(R"R_{J/\psi}")
    return sp.Symbol(R"R_\mathrm{res}")


def _create_mass_symbol(particle: IsobarNode | Particle) -> sp.Symbol:
    """Return a non-negative mass symbol ``m_{<latex name>}`` for ``particle``.

    ``particle`` is resolved with ``get_particle`` first, so an isobar node is
    accepted as well (presumably yielding its parent's symbol -- TODO confirm).
    """
    latex_name = get_particle(particle).latex
    return sp.Symbol(f"m_{{{latex_name}}}", nonnegative=True)


def _get_mandelstam_s(decay: ThreeBodyDecayChain | IsobarNode) -> sp.Symbol:
    """Return the Mandelstam variable of the two-body subsystem of ``decay``.

    ``sigma_k`` belongs to the pair that excludes final-state particle ``k``
    (1 = K0, 2 = Sigma+, 3 = p~).

    Raises:
        NotImplementedError: If the decay products form no known pair.
    """
    sigma1, sigma2, sigma3 = sp.symbols("sigma1:4", nonnegative=True)
    pair_to_sigma = {
        frozenset({"Sigma+", "p~"}): sigma1,
        frozenset({"K0", "p~"}): sigma2,
        frozenset({"K0", "Sigma+"}): sigma3,
    }
    product_names = {p.name for p in _get_decay_products(decay)}
    sigma = pair_to_sigma.get(frozenset(product_names))
    if sigma is None:
        msg = f"Cannot find Mandelstam variable for {', '.join(product_names)}"
        raise NotImplementedError(msg)
    return sigma


def _get_decay_products(
    decay: ThreeBodyDecayChain | IsobarNode,
) -> tuple[Particle, Particle]:
    """Return the two decay products of ``decay``.

    For a decay chain these are its ``decay_products``. For an isobar node they
    are its ``children``.
    """
    return (
        decay.decay_products
        if isinstance(decay, ThreeBodyDecayChain)
        else decay.children
    )

## Model formulation

The total, aligned intensity expression looks as follows:

In [None]:
# Note min_ls=False here, unlike when DECAY was created. Presumably the builder
# then includes all LS couplings -- TODO confirm against ampform_dpd docs.
model_builder = DalitzPlotDecompositionBuilder(DECAY, min_ls=False)
# Use the custom lineshape (form factors x Breit-Wigner) for every decay chain.
for chain in model_builder.decay.chains:
    model_builder.dynamics_choices.register_builder(
        chain, formulate_breit_wigner_with_ff
    )
# Subsystem 1 serves as the reference subsystem (presumably for the alignment
# of the other chains).
model = model_builder.formulate(reference_subsystem=1)
model.intensity

where the angles can be computed from the initial- and final-state masses $m_0$, $m_1$, $m_2$, and $m_3$ and the Mandelstam variables $\sigma_1$, $\sigma_2$, and $\sigma_3$:

In [None]:
# Display the definitions of the kinematic variables (angles) of the model.
Latex(aslatex(model.variables))

Each **unaligned** amplitude is defined as follows:

In [None]:
# Display only the amplitudes whose definition is non-zero (truthy).
Latex(aslatex({k: v for k, v in model.amplitudes.items() if v}))

## Preparing for input data

The {meth}`~sympy.core.basic.Basic.doit` operation can be cached to disk with {func}`~ampform.sympy.perform_cached_doit`. We do this twice: once to unfold the {attr}`~.AmplitudeModel.intensity` expression, and once to substitute and unfold the {attr}`~.AmplitudeModel.amplitudes`. Note that we could also have unfolded the intensity and substituted the amplitudes with {attr}`~.AmplitudeModel.full_expression`, but then the unfolded {attr}`~.AmplitudeModel.intensity` expression would not be cached.

In [None]:
# Two cached doit() calls: first unfold the intensity, then substitute the
# amplitude definitions into it and unfold the result.
unfolded_intensity_expr = perform_cached_doit(model.intensity)
intensity_with_amplitudes = unfolded_intensity_expr.xreplace(model.amplitudes)
full_intensity_expr = perform_cached_doit(intensity_with_amplitudes)

With this, the remaining {class}`~sympy.core.symbol.Symbol`s in the full expression are kinematic variables.[^1]

[^1]: Yes, there are still $\mathcal{H}^\mathrm{production}$ and $\mathcal{H}^\mathrm{decay}$, but these are the {attr}`~sympy.tensor.indexed.Indexed.base`s of the {class}`~sympy.tensor.indexed.Indexed` coupling symbols. They should **NOT** be substituted.

In [None]:
# Symbols without a default parameter value, sorted by name. These are the
# kinematic variables, plus the Indexed base symbols of the couplings.
sp.Array(
    sorted(full_intensity_expr.free_symbols - set(model.parameter_defaults), key=str)
)

The $\theta$ and $\zeta$ angles are defined by the {attr}`~.AmplitudeModel.variables` attribute (they are shown under {ref}`jpsi2ksp:Model formulation`). Those definitions allow us to create a converter that computes kinematic variables from masses and Mandelstam variables:

In [None]:
# Compile the angle definitions into a JAX-backed transformer that maps masses
# and Mandelstam variables to the angles.
masses_to_angles = SympyDataTransformer.from_sympy(model.variables, backend="jax")
masses_to_angles.functions

## Dalitz plot

The input for this data transformer can come from several sources. It can, for instance, be computed from a (generated) data sample of four-momenta, or over a grid in the Dalitz plane. We do the latter in this section.

First, the data transformer defined above expects values for the masses. We have already defined these values above, but we need to convert them from {mod}`sympy` objects to numerical data:

In [None]:
# Numerical mass values keyed by the symbol's name, for the data transformer.
dalitz_data = {str(symbol): float(value) for symbol, value in model.masses.items()}

Next, we define a grid of data points over the Mandelstam (Dalitz) variables $\sigma_2=m_{13}^2$ and $\sigma_3=m_{12}^2$:

In [None]:
resolution = 500
# Evenly spaced squared invariant masses. sigma3 varies along the horizontal
# (column) axis of the grid and sigma2 along the vertical (row) axis.
sigma3_range = jnp.linspace(1.66**2, 2.18**2, num=resolution)
sigma2_range = jnp.linspace(1.4**2, 1.93**2, num=resolution)
X, Y = jnp.meshgrid(sigma3_range, sigma2_range)
dalitz_data.update(sigma3=X, sigma2=Y)

The remaining Mandelstam variable can be expressed in terms of the others as follows:

In [None]:
# Take the first invariant (presumably sigma1) with its expression in terms of
# the other variables.
(s1, s1_expr), *_ = model.invariants.items()
Latex(aslatex({s1: s1_expr}))

That completes the data sample over which we want to evaluate the intensity model defined above:

In [None]:
# Compile the sigma1 expression and evaluate it on the (sigma2, sigma3) grid.
sigma1_func = perform_cached_lambdify(s1_expr, backend="jax")
dalitz_data["sigma1"] = sigma1_func(dalitz_data)
dalitz_data

We can now extend the sample with angle definitions so that we have a data sample over which the intensity can be evaluated.

In [None]:
# Compute the angles from the masses and Mandelstam variables and add them to the sample.
angle_data = masses_to_angles(dalitz_data)
dalitz_data.update(angle_data)

In [None]:
# Sanity check: every entry of the data sample contains at least one non-NaN value.
for name, values in dalitz_data.items():
    assert not jnp.all(jnp.isnan(values)), f"All values for {name} are NaN"

In [None]:
# Split the default parameters. The Indexed coupling symbols (production and
# decay couplings) stay free; all other parameters are substituted as fixed values.
free_parameters = {}
fixed_parameters = {}
for symbol, default_value in model.parameter_defaults.items():
    is_coupling = isinstance(symbol, sp.Indexed) and (
        "production" in str(symbol) or "decay" in str(symbol)
    )
    target = free_parameters if is_coupling else fixed_parameters
    target[symbol] = default_value
intensity_func = perform_cached_lambdify(
    full_intensity_expr.xreplace(fixed_parameters),
    parameters=free_parameters,
    backend="jax",
)

In [None]:
# Render the next figure as PNG (presumably to keep the dense 500x500 mesh light).
%config InlineBackend.figure_formats = ['png']

In [None]:
plt.rc("font", size=18)
intensities = intensity_func(dalitz_data)
# Normalize to the sum over the grid. nansum skips NaN grid points (presumably
# those outside the physical region).
I_tot = jnp.nansum(intensities)
normalized_intensities = intensities / I_tot
assert not jnp.all(jnp.isnan(normalized_intensities)), "All intensities are NaN"

# Dalitz plot: sigma3 on the horizontal axis, sigma2 on the vertical axis.
fig, ax = plt.subplots(figsize=(14, 10))
mesh = ax.pcolormesh(X, Y, normalized_intensities)
ax.set_aspect("equal")
c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)
c_bar.ax.set_ylabel("Normalized intensity (a.u.)")
ax.set_xlabel(R"$\sigma_3 = M^2\left(K^0\Sigma^+\right)$")
ax.set_ylabel(R"$\sigma_2 = M^2\left(K^0\bar{p}\right)$")
plt.show()

In [None]:
# Switch back to vector (SVG) output for the remaining figures.
%config InlineBackend.figure_formats = ['svg']

In [None]:
def compute_sub_intensity(
    func: ParametrizedFunction, phsp: DataSample, resonance_latex: str
) -> jnp.ndarray:
    """Evaluate ``func`` with only the couplings of one resonance switched on.

    Every parameter whose name contains ``\\mathcal{H}`` (the coupling symbols)
    but not ``resonance_latex`` is temporarily set to zero. The original
    parameter values are restored afterwards, also when evaluation raises, so
    that ``func`` is never left in a modified state.

    Args:
        func: Parametrized intensity function. It is modified only temporarily.
        phsp: Data sample on which ``func`` is evaluated.
        resonance_latex: LaTeX name of the resonance. It is matched as a
            substring of the parameter names.

    Returns:
        The intensities computed with the other resonances' couplings zeroed.
    """
    original_parameters = dict(func.parameters)
    zero_parameters = {
        name: 0
        for name in original_parameters
        if R"\mathcal{H}" in name and resonance_latex not in name
    }
    func.update_parameters(zero_parameters)
    try:
        return func(phsp)
    finally:
        func.update_parameters(original_parameters)


# Project the 2D intensity onto the sigma3 and sigma2 mass axes and overlay the
# contribution of each individual resonance.
plt.rc("font", size=16)
fig, axes = plt.subplots(figsize=(18, 6), ncols=2, sharey=True)
fig.subplots_adjust(wspace=0.02)
ax1, ax2 = axes
# Convert the squared-mass grid axes into invariant masses M = sqrt(sigma).
x = jnp.sqrt(X[0])
y = jnp.sqrt(Y[:, 0])
# Total projections: summing over rows (axis=0) leaves the sigma3 dependence,
# summing over columns (axis=1) leaves the sigma2 dependence.
ax1.fill_between(x, jnp.nansum(normalized_intensities, axis=0), alpha=0.5)
ax2.fill_between(y, jnp.nansum(normalized_intensities, axis=1), alpha=0.5)
for ax in axes:
    _, y_max = ax.get_ylim()
    ax.set_ylim(0, y_max)
    # Freeze the x-range, presumably so the resonance markers drawn below
    # (which may lie outside the grid) do not widen it.
    ax.autoscale(enable=False, axis="x")
ax1.set_ylabel("Normalized intensity (a.u.)")
ax1.set_xlabel(R"$M\left(K^0\Sigma^+\right)$")
ax2.set_xlabel(R"$M\left(K^0\bar{p}\right)$")
# Per-panel colour counters. They start at C1 because C0 is the total projection.
i1, i2 = 0, 0
for chain in tqdm(model.decay.chains, disable=NO_TQDM):
    resonance = chain.resonance
    decay_product = {p.name for p in chain.decay_products}
    if decay_product == {"K0", "Sigma+"}:  # sigma3 subsystem -> left panel
        ax = ax1
        i1 += 1
        i = i1
        projection_axis = 0
        x_data = x
    elif decay_product == {"K0", "p~"}:  # sigma2 subsystem -> right panel
        ax = ax2
        i2 += 1
        i = i2
        projection_axis = 1
        x_data = y
    else:
        # Resonances in the Sigma+ p~ (sigma1) subsystem are not plotted.
        continue
    # Intensity with only this resonance's couplings switched on, normalized by
    # the same total as the full model so the curves are comparable.
    sub_intensities = compute_sub_intensity(
        intensity_func, dalitz_data, resonance.latex
    )
    ax.plot(
        x_data, jnp.nansum(sub_intensities / I_tot, axis=projection_axis), c=f"C{i}"
    )
    # Mark the nominal resonance mass.
    ax.axvline(resonance.mass, label=f"${resonance.latex}$", c=f"C{i}", ls="dashed")
for ax in axes:
    ax.legend(fontsize=12)
plt.show()