# D⁰ → K⁰K⁺K⁻

```{autolink-concat}
```

The decay $D^0 \to K^0K^+K^-$ has a spinless initial state and spinless final-state particles, so no spin alignment is needed in the Dalitz-plot decomposition. This notebook shows that the model formulated with {mod}`ampform` is identical to the one formulated with {doc}`AmpForm-DPD</index>`. To simplify this comparison, we do not define any dynamics.

In [None]:
from __future__ import annotations

import itertools
import logging
import os
from typing import Iterable

import ampform
import graphviz
import jax.numpy as jnp
import matplotlib.pyplot as plt
import qrules
import sympy as sp
from ampform.helicity import HelicityModel
from ampform.kinematics import FourMomentumSymbol, InvariantMass
from ampform.sympy import perform_cached_doit
from IPython.display import SVG, Latex, Markdown, display
from ipywidgets import (
    FloatSlider,
    GridBox,
    HBox,
    Layout,
    SelectMultiple,
    Tab,
    interactive_output,
)
from qrules.transition import ReactionInfo
from tensorwaves.data.phasespace import TFPhaseSpaceGenerator
from tensorwaves.data.rng import TFUniformRealNumberGenerator
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.interface import DataSample, ParameterValue, ParametrizedFunction

from ampform_dpd import DalitzPlotDecompositionBuilder, simplify_latex_rendering
from ampform_dpd.decay import (
    IsobarNode,
    Particle,
    ThreeBodyDecay,
    ThreeBodyDecayChain,
)
from ampform_dpd.io import as_markdown_table, aslatex, perform_cached_lambdify
from ampform_dpd.spin import filter_parity_violating_ls, generate_ls_couplings

# Notebook-wide configuration: compact LaTeX rendering and quieter backends.
simplify_latex_rendering()
# NOTE(review): TensorFlow may only honour this variable if it is set before
# TensorFlow is first imported — confirm that the tensorwaves imports above do
# not already import it.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # mute TF

# EXECUTE_NB is set when the notebook is executed non-interactively; in that
# case, progress logging from ampform_dpd.io is muted as well.
NO_TQDM = "EXECUTE_NB" in os.environ
muted_loggers = ["jax", "ampform_dpd.io"] if NO_TQDM else ["jax"]
for logger_name in muted_loggers:
    logging.getLogger(logger_name).setLevel(logging.ERROR)

## Decay definition

In [None]:
PDG = qrules.load_pdg()
# Convert every PDG particle with a defined parity into an AmpForm-DPD Particle
# (int(parity) below requires the parity to be set).
PARTICLE_DB = {}
for pdg_particle in PDG:
    if pdg_particle.parity is None:
        continue
    PARTICLE_DB[pdg_particle.name] = Particle(
        name=pdg_particle.name,
        latex=pdg_particle.latex,
        spin=pdg_particle.spin,
        parity=int(pdg_particle.parity),
        mass=pdg_particle.mass,
        width=pdg_particle.width,
    )
D0, K0, Km, Kp = (PARTICLE_DB[name] for name in ("D0", "K0", "K-", "K+"))
# State IDs: 0 for the initial state, 1-3 for the final state
PARTICLE_TO_ID = {
    particle: state_id for state_id, particle in enumerate([D0, K0, Km, Kp])
}
FINAL_STATE = list(PARTICLE_TO_ID)[1:]
Markdown(as_markdown_table(list(PARTICLE_TO_ID)))

In [None]:
# Let QRules find all D0 -> K0 K- K+ transitions that go through a0, f0 or K*
# resonances.
transition_settings = dict(
    formalism="helicity",
    allowed_intermediate_particles=["a(0)", "f(0)", "K*"],
    mass_conservation_factor=0.2,
)
reaction = qrules.generate_transitions(
    initial_state="D0",
    final_state=[particle.name for particle in FINAL_STATE],
    **transition_settings,
)
dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)

In [None]:
# Resonance names in a reproducible (sorted) order, mapped to DPD particles
intermediate_states = reaction.get_intermediate_particles()
resonance_names = sorted(intermediate_states.names)
resonances = list(map(PARTICLE_DB.__getitem__, resonance_names))
Markdown(as_markdown_table(resonances))

In [None]:
def load_three_body_decay(
    resonance_names: Iterable[str],
    particle_definitions: dict[str, Particle],
    min_ls: bool = True,
) -> ThreeBodyDecay:
    """Build the three-body decay with the states of ``PARTICLE_TO_ID``.

    Args:
        resonance_names: Names of the intermediate resonances; each is looked
            up in ``particle_definitions``.
        particle_definitions: Mapping from particle name to `Particle`.
        min_ls: If `True`, each resonance contributes one chain built from
            its minimal LS couplings; otherwise one chain per combination of
            decay and production LS couplings.

    Returns:
        A `ThreeBodyDecay` whose states are keyed by state ID and whose chains
        are those of all resonances, in the order of ``resonance_names``.
    """
    selected = [particle_definitions[name] for name in resonance_names]
    all_chains = itertools.chain.from_iterable(
        _create_isobar(resonance, min_ls) for resonance in selected
    )
    chain_tuple = tuple(all_chains)
    id_to_particle = {
        state_id: particle for particle, state_id in PARTICLE_TO_ID.items()
    }
    return ThreeBodyDecay(states=id_to_particle, chains=chain_tuple)


def _create_isobar(resonance: Particle, min_ls: bool) -> list[ThreeBodyDecayChain]:
    """Create the decay chain(s) D0 -> (resonance -> child1 child2) spectator.

    The charge suffix of the resonance name ("-", "+", or anything else for a
    neutral resonance) selects which two kaons it decays into; the remaining
    kaon is the spectator. Production LS couplings (D0 -> resonance spectator)
    are not filtered by parity; decay LS couplings (resonance -> child1 child2)
    are.

    Args:
        resonance: The intermediate state.
        min_ls: If `True`, return a single chain built from the minimum (in
            tuple order) decay and production LS couplings; otherwise return
            one chain for every (decay, production) LS combination.
    """
    child1, child2, spectator = {
        "-": (K0, Km, Kp),
        "+": (Kp, K0, Km),
    }.get(resonance.name[-1:], (Kp, Km, K0))
    production_ls = _generate_ls(D0, resonance, spectator, conserve_parity=False)
    decay_ls = _generate_ls(resonance, child1, child2, conserve_parity=True)
    if min_ls:
        ls_pairs = [(min(decay_ls), min(production_ls))]
    else:
        ls_pairs = itertools.product(decay_ls, production_ls)

    def make_chain(dec_ls, prod_ls) -> ThreeBodyDecayChain:
        isobar = IsobarNode(
            parent=resonance,
            child1=child1,
            child2=child2,
            interaction=dec_ls,
        )
        top_node = IsobarNode(
            parent=D0,
            child1=isobar,
            child2=spectator,
            interaction=prod_ls,
        )
        return ThreeBodyDecayChain(top_node)

    return [make_chain(dec_ls, prod_ls) for dec_ls, prod_ls in ls_pairs]


def _generate_ls(
    parent: Particle, child1: Particle, child2: Particle, conserve_parity: bool
) -> list[tuple[int, sp.Rational]]:
    """List the (L, S) couplings for the two-body decay parent -> child1 child2.

    With ``conserve_parity``, parity-violating couplings are removed with
    `filter_parity_violating_ls`; otherwise every coupling produced by
    `generate_ls_couplings` for the three spins is returned.
    """
    all_couplings = generate_ls_couplings(parent.spin, child1.spin, child2.spin)
    if not conserve_parity:
        return all_couplings
    return filter_parity_violating_ls(
        all_couplings, parent.parity, child1.parity, child2.parity
    )


# One decay chain per resonance, each with its minimal LS couplings
DECAY = load_three_body_decay(resonance_names, PARTICLE_DB, min_ls=True)
decay_latex = aslatex(DECAY, with_jp=True)
Latex(decay_latex)

## Model formulation

### DPD model

Note that, as opposed to {ref}`Λc⁺ → pπ⁺K⁻<lc2pkpi:Model formulation>` and {ref}`J/ψ → K⁰Σ⁺p̅<jpsi2ksp:Model formulation>`, there are no alignment Wigner-$d$ functions, because the final state is spinless.

In [None]:
# Formulate the amplitude model with the Dalitz-plot decomposition, using
# subsystem 1 as reference subsystem, and display its intensity expression.
model_builder = DalitzPlotDecompositionBuilder(DECAY, min_ls=True)
dpd_model = model_builder.formulate(reference_subsystem=1)
dpd_model.intensity

In [None]:
# Render the amplitude definitions of the DPD model as LaTeX.
Latex(aslatex(dpd_model.amplitudes))

The isobar Wigner-$d$ functions remain, however; they take the following helicity angles as arguments:

In [None]:
# Render the definitions of the DPD model's kinematic variables as LaTeX.
Latex(aslatex(dpd_model.variables))

In [None]:
# Masses m0..m3 of D0, K0, K-, K+ (indexed by state ID), rounded to 7 digits,
# added to the DPD model's parameter defaults.
masses = {
    sp.Symbol(f"m{state_id}", nonnegative=True): round(particle.mass, 7)
    for particle, state_id in PARTICLE_TO_ID.items()
}
dpd_model.parameter_defaults.update(masses)
Latex(aslatex(masses))

### AmpForm model

AmpForm does not formulate alignment Wigner-$D$ functions. For this spinless final state, this means that the intensity is the same as that of the [](#dpd-model).

In [None]:
# Formulate the same reaction with AmpForm (the transitions were generated
# with the helicity formalism), with helicity couplings switched off, and
# display its intensity expression.
model_builder = ampform.get_builder(reaction)
model_builder.use_helicity_couplings = False
ampform_model = model_builder.formulate()
ampform_model.intensity

In [None]:
# Render the amplitude definitions of the AmpForm model as LaTeX.
Latex(aslatex(ampform_model.amplitudes))

In [None]:
# Render the definitions of the AmpForm model's kinematic variables as LaTeX.
Latex(aslatex(ampform_model.kinematic_variables))

## Phase space sample

In [None]:
# Symbolic four-momenta "p0", "p1", "p2" of the three final-state particles,
# presumably matching the columns of the generated phase-space sample
# (final-state IDs 0, 1, 2 in FINAL_STATE order) — TODO confirm.
# NOTE: the Python names are shifted by one: Python name p<k> holds symbol
# "p<k-1>", so that p<k> refers to DPD particle k.
p1, p2, p3 = tuple(FourMomentumSymbol(f"p{i}", shape=[]) for i in (0, 1, 2))
# DPD invariants sigma1, sigma2, sigma3
s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
# sigma_k is the squared invariant mass of the pair that excludes particle k;
# m_ij are two-body invariant masses (named with the shifted, 0-based IDs).
mass_definitions = {
    s1: InvariantMass(p2 + p3) ** 2,
    s2: InvariantMass(p1 + p3) ** 2,
    s3: InvariantMass(p1 + p2) ** 2,
    sp.Symbol("m_01", nonnegative=True): InvariantMass(p1 + p2),
    sp.Symbol("m_02", nonnegative=True): InvariantMass(p1 + p3),
    sp.Symbol("m_12", nonnegative=True): InvariantMass(p2 + p3),
}
# Seed with numerical masses m0..m3 (not rounded, unlike the `masses` cell)
dpd_variables = {
    sp.Symbol(f"m{i}", nonnegative=True): sp.Float(p.mass)
    for i, p in enumerate(PARTICLE_TO_ID)
}
# Turn every DPD variable into an expression of four-momenta. dpd_variables
# grows inside this loop and is itself used in xreplace, so variables that
# were converted earlier are substituted into later ones: order matters.
for symbol, expr in dpd_model.variables.items():
    expr = expr.doit().xreplace(mass_definitions).xreplace(dpd_variables)
    dpd_variables[symbol] = expr
# Transformer from a four-momentum sample to the DPD kinematic variables
dpd_transformer = SympyDataTransformer.from_sympy(dpd_variables, backend="jax")

# Transformer from a four-momentum sample to the AmpForm kinematic variables
ampform_transformer = SympyDataTransformer.from_sympy(
    ampform_model.kinematic_variables, backend="jax"
)

In [None]:
def generate_phase_space(
    reaction: ReactionInfo, size: int, seed: int = 0
) -> DataSample:
    """Generate a phase-space sample of four-momenta for ``reaction``.

    Args:
        reaction: Reaction whose initial-state mass and final-state masses
            define the phase space.
        size: Number of events to generate.
        seed: Seed for the random-number generator. The default of 0 makes
            repeated runs produce the same sample.

    Returns:
        The data sample produced by `TFPhaseSpaceGenerator.generate`
        (presumably four-momenta keyed by final-state ID — TODO confirm).
    """
    rng = TFUniformRealNumberGenerator(seed=seed)
    phsp_generator = TFPhaseSpaceGenerator(
        # The (single) initial state is looked up under state ID -1
        initial_state_mass=reaction.initial_state[-1].mass,
        final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
    )
    return phsp_generator.generate(size, rng)


phsp = generate_phase_space(ampform_model.reaction_info, size=100_000)
# Convert the same four-momentum sample into each model's kinematic variables
ampform_phsp, dpd_phsp = ampform_transformer(phsp), dpd_transformer(phsp)

## Convert to numerical functions

In [None]:
def unfold_intensity(model: HelicityModel) -> sp.Expr:
    """Return the model's intensity with its amplitude symbols substituted.

    The intensity and every amplitude expression are evaluated with
    `perform_cached_doit`. The amplitude symbols in the evaluated intensity are
    then replaced by their evaluated definitions. Only the ``intensity`` and
    ``amplitudes`` attributes are used, so the DPD model works as well.
    """
    intensity = perform_cached_doit(model.intensity)
    substitutions = {}
    for amplitude_symbol, amplitude_expr in model.amplitudes.items():
        substitutions[amplitude_symbol] = perform_cached_doit(amplitude_expr)
    return intensity.xreplace(substitutions)


# Fully unfolded intensity expressions of both models
ampform_intensity_expr = unfold_intensity(ampform_model)
dpd_intensity_expr = unfold_intensity(dpd_model)

In [None]:
# Lambdify both unfolded intensities into parametrized numerical functions,
# initialised with each model's default parameter values.
ampform_func, dpd_func = (
    perform_cached_lambdify(expression, parameters=model.parameter_defaults)
    for expression, model in (
        (ampform_intensity_expr, ampform_model),
        (dpd_intensity_expr, dpd_model),
    )
)

## Visualization

In [None]:
def compute_sub_intensities(
    func: ParametrizedFunction,
    phsp: DataSample,
    resonance_name: str | list[str],
) -> jnp.ndarray:
    """Evaluate ``func`` on ``phsp`` with only the selected resonances enabled.

    Every coupling parameter (as selected by `_get_couplings`) whose name does
    not contain one of the selected resonance names is temporarily set to zero.
    The original parameters of ``func`` are restored afterwards, also if the
    evaluation raises.

    Args:
        func: Parametrized intensity function. Its parameters are modified
            during the call and restored before returning.
        phsp: Data sample on which ``func`` is evaluated.
        resonance_name: One resonance name, or a list of resonance names.
            They are matched as substrings of the parameter names.

    Returns:
        The values returned by ``func(phsp)``.
    """
    original_parameters = dict(func.parameters)
    try:
        _set_couplings_to_zero(func, resonance_name)
        return func(phsp)
    finally:
        func.update_parameters(original_parameters)


def _set_couplings_to_zero(
    func: ParametrizedFunction, resonance_names: str | list[str]
) -> None:
    """Set every coupling of ``func`` that matches none of ``resonance_names`` to zero."""
    if isinstance(resonance_names, str):
        # A bare string would otherwise be iterated character by character
        resonance_names = [resonance_names]
    couplings_to_zero = {
        key: value if any(r in key for r in resonance_names) else 0
        for key, value in _get_couplings(func).items()
    }
    func.update_parameters(couplings_to_zero)


def _get_couplings(func: ParametrizedFunction) -> dict[str, ParameterValue]:
    """Return the coupling parameters of ``func``.

    A parameter counts as a coupling if its name starts with ``"C"`` or
    contains ``"production"``.
    """
    return {
        key: value
        for key, value in func.parameters.items()
        if key.startswith("C") or "production" in key
    }

In [None]:
def to_unicode(particle: Particle) -> str:
    """Convert a particle name into a compact Unicode label.

    The conversions are:
    - charges "+"/"-" become superscripts;
    - "(0)" becomes a subscript zero, and ")0" becomes ")⁰";
    - "Sigma" becomes "Σ".

    Tildes are removed, except that an anti-Sigma ("Sigma~") is written as
    "~Σ"; an ordinary Sigma gets no tilde. Only ``particle.name`` is read, so
    any object with a ``name`` attribute (e.g. a QRules particle) works.
    """
    name = particle.name
    is_anti_sigma = "Sigma~" in name
    name = name.replace("~", "")
    name = name.replace("Sigma", "~Σ" if is_anti_sigma else "Σ")
    for old, new in (("+", "⁺"), ("-", "⁻"), ("(0)", "₀"), (")0", ")⁰")):
        name = name.replace(old, new)
    return name


resonances = ampform_model.reaction_info.get_intermediate_particles()
resonances = sorted(resonances, key=lambda p: (p.charge, p.mass, p.name))
resonance_selector = SelectMultiple(
    description="Resonance",
    options={to_unicode(p): p.latex for p in resonances},
    # Pre-select (up to) the first two resonances; slicing instead of indexing
    # avoids an IndexError when fewer than two resonances were generated.
    value=[p.latex for p in resonances[:2]],
    layout=Layout(
        height=f"{14 * (len(resonances) + 1)}pt",
        width="auto",
    ),
)

# Default parameter values of both models. For a symbol that appears in both,
# the DPD default wins because it is inserted last.
all_parameters = dict(ampform_model.parameter_defaults.items())
all_parameters.update(dpd_model.parameter_defaults.items())


def _create_slider(symbol: sp.Symbol, value: ParameterValue) -> FloatSlider:
    """Create a slider for ``symbol``, initialised at the real part of ``value``.

    The range spans from zero to twice that real part, for either sign. A
    parameter whose real part is zero gets the range [0, 1] instead of an
    empty range.
    """
    real_value = complex(value).real
    lower, upper = sorted((0.0, 2.0 * real_value))
    if lower == upper:
        upper = 1.0
    return FloatSlider(
        description=Rf"\({sp.latex(symbol)}\)",
        value=real_value,
        min=lower,
        max=upper,
        continuous_update=False,
        step=0.01,
        layout=Layout(
            width="16cm",
        ),
        style={
            "description_width": "8cm",
        },
    )


sliders = {
    str(symbol): _create_slider(symbol, value)
    for symbol, value in all_parameters.items()
}

ui = HBox(
    [
        resonance_selector,
        Tab(
            # One tab per resonance with the sliders whose name contains its
            # LaTeX label
            children=[
                GridBox([sliders[key] for key in sorted(sliders) if p.latex in key])
                for p in resonances
            ],
            titles=[to_unicode(p) for p in resonances],
        ),
    ]
)

In [None]:
%matplotlib widget
plt.rc("font", size=12)
fig, axes = plt.subplots(ncols=3, figsize=(16, 5), sharey=True)
ax1, ax2, ax3 = axes
final_state = ampform_model.reaction_info.final_state
ax1.set_ylabel("Intensity (a.u.)")
# Each panel shows the invariant mass of a FINAL_STATE pair; panel k leaves
# out FINAL_STATE[k].
pair_indices = [(1, 2), (0, 2), (0, 1)]
for panel, (first, second) in zip((ax1, ax2, ax3), pair_indices):
    panel.set_xlabel(
        Rf"$m({FINAL_STATE[first].latex}, {FINAL_STATE[second].latex})$"
    )
fig.tight_layout()

lines = None  # step-line handles; created on the first call of plot_contributions


def _weighted_histograms(
    phsp: DataSample,
    key: str,
    ampform_weights: jnp.ndarray,
    dpd_weights: jnp.ndarray,
    bins: int = 50,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Histogram ``phsp[key]`` with the intensities of both models as weights.

    The DPD histogram reuses the bin edges of the AmpForm histogram, so the two
    can be compared bin by bin.

    Returns:
        Bin centers, AmpForm bin contents, DPD bin contents.
    """
    values = phsp[key].real
    ampform_values, edges = jnp.histogram(values, bins=bins, weights=ampform_weights)
    dpd_values, _ = jnp.histogram(values, bins=edges, weights=dpd_weights)
    centers = (edges[:-1] + edges[1:]) / 2
    return centers, ampform_values, dpd_values


def _as_rational(value: ParameterValue) -> sp.Rational:
    """Convert a (float) slider value into the rational number it displays as.

    ``sp.Rational(0.1)`` gives the exact binary fraction
    3602879701896397/36028797018963968. Going through ``str`` gives 1/10.
    """
    return sp.Rational(str(value))


def plot_contributions(**kwargs) -> None:
    """Redraw both models' sub-intensities for the current widget state.

    Called by `interactive_output` with one keyword argument per slider
    (parameter name -> value) plus ``resonance_selector``.

    Steps:
    - Only the selected resonances contribute (see `compute_sub_intensities`).
    - The AmpForm and DPD intensities are histogrammed over the three two-body
      invariant masses.
    - The resulting intensity expressions are displayed as LaTeX.
    """
    global lines
    # LaTeX labels of the selected resonances; the remaining keyword arguments
    # are slider values keyed by parameter name.
    resonance_names = list(kwargs.pop("resonance_selector"))
    dpd_pars = {k: v for k, v in kwargs.items() if k in dpd_func.parameters}
    ampform_pars = {k: v for k, v in kwargs.items() if k in ampform_func.parameters}
    ampform_func.update_parameters(ampform_pars)
    dpd_func.update_parameters(dpd_pars)
    ampform_intensities = compute_sub_intensities(
        ampform_func, ampform_phsp, resonance_names
    )
    dpd_intensities = compute_sub_intensities(dpd_func, dpd_phsp, resonance_names)

    # Both models are histogrammed over the invariant masses of the AmpForm
    # sample, so that they share the same x-values (one entry per panel).
    panels = (ax1, ax2, ax3)
    histograms = [
        _weighted_histograms(ampform_phsp, key, ampform_intensities, dpd_intensities)
        for key in ("m_12", "m_02", "m_01")
    ]
    amp_kwargs = dict(color="r", label="ampform", linestyle="solid")
    dpd_kwargs = dict(color="blue", label="dpd", linestyle="dotted")
    if lines is None:
        new_lines = []
        for ax, (x, amp_values, dpd_values) in zip(panels, histograms):
            new_lines.append(ax.step(x, amp_values, **amp_kwargs)[0])
            new_lines.append(ax.step(x, dpd_values, **dpd_kwargs)[0])
        lines = new_lines
        ax1.legend(loc="upper right")
    else:
        # lines holds (ampform, dpd) pairs in panel order
        for idx, (_, amp_values, dpd_values) in enumerate(histograms):
            lines[2 * idx].set_ydata(amp_values)
            lines[2 * idx + 1].set_ydata(dpd_values)
    y_max = max(
        jnp.nanmax(values)
        for _, amp_values, dpd_values in histograms
        for values in (amp_values, dpd_values)
    )
    for ax in panels:
        ax.set_ylim(0, 1.05 * y_max)

    fig.canvas.draw_idle()

    def get_symbol_values(
        expr: sp.Expr, parameters: dict[str, ParameterValue]
    ) -> dict[sp.Symbol, sp.Rational]:
        """Map the parameter symbols in ``expr`` to displayable rational values.

        Couplings of unselected resonances are shown as zero. The criterion is
        the same as in `compute_sub_intensities`: names that start with "C" or
        contain "production". Other parameters (e.g. masses) keep their values.
        """

        def display_value(key: str, value: ParameterValue) -> sp.Rational:
            is_coupling = key.startswith("C") or "production" in key
            if is_coupling and not any(r in key for r in resonance_names):
                return sp.Integer(0)
            return _as_rational(value)

        return {
            s: display_value(s.name, parameters[s.name])
            for s in expr.free_symbols
            if s.name in parameters
        }

    ampform_symbols = get_symbol_values(ampform_intensity_expr, ampform_pars)
    dpd_symbols = get_symbol_values(dpd_intensity_expr, dpd_pars)
    src = Rf"""
    \begin{{eqnarray}}
      \text{{AmpForm:}} && {sp.latex(ampform_intensity_expr.xreplace(ampform_symbols))} \\
      \text{{DPD:}}     && {sp.latex(dpd_intensity_expr.xreplace(dpd_symbols))} \\
    \end{{eqnarray}}
    """
    display(Latex(src))


# Every slider plus the resonance selector drives plot_contributions
controls = dict(sliders)
controls["resonance_selector"] = resonance_selector
output = interactive_output(plot_contributions, controls=controls)
display(output, ui)

In [None]:
if NO_TQDM:
    # Save the comparison figure explicitly, rather than whichever figure is
    # current, and embed it as SVG. This is presumably for non-interactive
    # builds, where the widget output is not rendered.
    filename = "d0-to-kkk-comparison.svg"
    fig.savefig(filename)
    display(SVG(filename))