# LHCb cross-check

In [None]:
from __future__ import annotations

import json
import logging
import re
from itertools import product
from typing import Pattern

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.helicity.naming import natural_sorting
from ampform.sympy import PoolSum
from IPython.display import Markdown, Math, display
from matplotlib.colors import LogNorm
from sympy.core.symbol import Str
from sympy.physics.quantum.spin import Rotation as Wigner
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function import ParametrizedBackendFunction
from tensorwaves.function.sympy import create_function, create_parametrized_function
from tqdm.notebook import tqdm

from polarization.decay import Particle, ThreeBodyDecay
from polarization.dynamics import (
    BuggBreitWigner,
    FlattéSWave,
    Källén,
    RelativisticBreitWigner,
)
from polarization.io import as_latex, as_markdown_table, display_latex
from polarization.lhcb import K, Λc, Σ, load_three_body_decays, p, π
from polarization.spin import create_spin_range


# hack for moving Indexed indices below superscript of the base
def _print_Indexed_latex(self, printer, *args):
    base = printer._print(self.base)
    indices = ", ".join(map(printer._print, self.indices))
    return f"{base}_{{{indices}}}"


# Monkey-patch sympy's LaTeX rendering of Indexed objects for this notebook
sp.Indexed._latex = _print_Indexed_latex

## Amplitude model

### Resonances and $LS$-scheme

Particle definitions for $\Lambda_c^+$ and $p, \pi^+, K^-$:

In [None]:
# Render the particle properties (Λc, p, π, K and Σ) as a Markdown table
Markdown(as_markdown_table([Λc, p, π, K, Σ]))

Resonances and their properties, as defined in {download}`data/isobars.json <../data/isobars.json>`:

In [None]:
# Load the three-body decays (resonance, decay products, LS couplings, lineshape)
# listed in the isobar definition file
decays = load_three_body_decays("../data/isobars.json")

In [None]:
def create_markdown_table_row(*items):
    """Return *items* formatted as one Markdown table row, ending in a newline."""
    cells = " | ".join(f"{item}" for item in items)
    return f"| {cells} |\n"


column_names = [
    "resonance",
    R"$j^P$",
    R"$m$ (MeV)",
    R"$\Gamma_0$ (MeV)",
    R"$l_R$",
    R"$l_{\Lambda_c}^\mathrm{min}$",
    "lineshape",
]
# Header row followed by the Markdown separator row
src = "".join(
    create_markdown_table_row(*cells)
    for cells in (column_names, ["---"] * len(column_names))
)
# One row per decay: label, J^P, mass and width in MeV, the L values of the
# outgoing and incoming LS couplings, and the lineshape name
for decay in decays:
    child1, child2 = map(as_latex, decay.decay_products)
    src += create_markdown_table_row(
        Rf"${decay.resonance.latex} \to {child1} {child2}$",
        Rf"${as_latex(decay.resonance, only_jp=True)}$",
        int(1e3 * decay.resonance.mass),
        int(1e3 * decay.resonance.width),
        decay.outgoing_ls.L,
        decay.incoming_ls.L,
        decay.resonance.lineshape,
    )
Markdown(src)

### Aligned amplitude

In [None]:
# Indexed placeholders for the three sub-system amplitudes (K, Λ and Δ chains)
A_K = sp.IndexedBase(R"A^K")
A_Λ = sp.IndexedBase(R"A^{\Lambda}")
A_Δ = sp.IndexedBase(R"A^{\Delta}")

half = sp.S.Half  # exact rational 1/2, used for helicity values

# Alignment angles ζ^i_{k(1)} for chain k = 1, 2, 3. Superscript 0 enters the
# Wigner d-function of the Λc helicity, superscript 1 the one of the proton
# helicity (see formulate_aligned_amplitude).
ζ_0_11 = sp.Symbol(R"\zeta^0_{1(1)}", real=True)
ζ_0_21 = sp.Symbol(R"\zeta^0_{2(1)}", real=True)
ζ_0_31 = sp.Symbol(R"\zeta^0_{3(1)}", real=True)
ζ_1_11 = sp.Symbol(R"\zeta^1_{1(1)}", real=True)
ζ_1_21 = sp.Symbol(R"\zeta^1_{2(1)}", real=True)
ζ_1_31 = sp.Symbol(R"\zeta^1_{3(1)}", real=True)


def formulate_aligned_amplitude(λ_Λc, λ_p):
    """Aligned amplitude for Λc helicity *λ_Λc* and proton helicity *λ_p*.

    The three sub-system amplitudes A^K, A^Λ, A^Δ are each rotated with two
    Wigner d-functions (alignment angles ζ^0 for the Λc, ζ^1 for the proton)
    and summed. The result is a PoolSum over the internal helicities
    ν', λ' ∈ {-1/2, +1/2}.
    """
    ν_prime = sp.Symbol(R"\nu^{\prime}", rational=True)
    λ_prime = sp.Symbol(R"\lambda^{\prime}", rational=True)
    chains = (
        (A_K, ζ_0_11, ζ_1_11),
        (A_Λ, ζ_0_21, ζ_1_21),
        (A_Δ, ζ_0_31, ζ_1_31),
    )
    summand = sp.Add(
        *[
            amplitude[ν_prime, λ_prime]
            * Wigner.d(half, λ_Λc, ν_prime, ζ_Λc)
            * Wigner.d(half, λ_prime, λ_p, ζ_proton)
            for amplitude, ζ_Λc, ζ_proton in chains
        ]
    )
    return PoolSum(
        summand,
        (λ_prime, [-half, +half]),
        (ν_prime, [-half, +half]),
    )


# Generic helicity symbols: ν for the Λc and λ for the proton
ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
# Display the aligned amplitude in terms of these symbols
formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)

### Sub-system amplitudes

:::{warning}

Couplings are remapped from the conventions of the LHCb paper to those of the Dalitz-plot decomposition (DPD) using [these relations](https://user-images.githubusercontent.com/22725744/165932213-34013235-8464-4018-bd21-3ebde1e86faf.png).

:::

In [None]:
# Helicity couplings of the production and decay vertices, indexed by the
# resonance name followed by two helicity values
H_prod = sp.IndexedBase(R"\mathcal{H}^\mathrm{production}")
H_dec = sp.IndexedBase(R"\mathcal{H}^\mathrm{decay}")

# Helicity angles θ_ij of the three sub-systems
θ23 = sp.Symbol("theta23", real=True)
θ31 = sp.Symbol("theta31", real=True)
θ12 = sp.Symbol("theta12", real=True)

# Invariant masses squared: σ1 = m²(Kπ), σ2 = m²(pK), σ3 = m²(pπ)
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
# Final-state masses: 1 = proton, 2 = pion, 3 = kaon
m1, m2, m3 = sp.symbols(R"m_p m_pi m_K", nonnegative=True)


def formulate_subsystem_amplitude(subsystem: int, λ_Λc, λ_p):
    """Amplitude of sub-system *subsystem* (1: K, 2: Λ, 3: Δ) for the given helicities.

    The resonances are selected from the module-level ``decays`` by the first
    letter of their name ("K", "L" or "D").

    Raises:
        NotImplementedError: if *subsystem* is not 1, 2 or 3.
    """
    chains = (
        (1, formulate_K_amplitude, "K"),
        (2, formulate_Λ_amplitude, "L"),
        (3, formulate_Δ_amplitude, "D"),
    )
    for number, formulate, name_prefix in chains:
        if subsystem == number:
            return formulate(λ_Λc, λ_p, filter_isobars(decays, name_prefix))
    raise NotImplementedError(f"No chain implemented for sub-system {subsystem}")


def formulate_K_amplitude(λ_Λc, λ_p, decays: list[ThreeBodyDecay]):
    """Sub-system 1 (K resonances) amplitude for helicities *λ_Λc* and *λ_p*.

    Each decay contributes a PoolSum over the resonance helicity τ (values
    from ``create_spin_range(spin)``). The summand is the product of:
    production coupling, dynamics in σ1 (decay-product masses m2, m3), a
    helicity-dependent sign, the Wigner d-function in θ23, and the decay
    coupling. The Kronecker delta restricts the sum to τ = λ_Λc + λ_p.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ - λ_p)
                * H_prod[stringify(decay.resonance), τ, -λ_p]
                * formulate_dynamics(decay, σ1, m2, m3)
                # sign (-1)^(1/2 - λ_p); NOTE(review): presumably part of the
                # LHCb → Dalitz-plot decomposition coupling remapping — confirm
                * (-1) ** (half - λ_p)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, 0, θ23)
                * H_dec[stringify(decay.resonance), 0, 0],
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_Λ_amplitude(λ_Λc, λ_p, decays: list[ThreeBodyDecay]):
    """Sub-system 2 (Λ resonances) amplitude for helicities *λ_Λc* and *λ_p*.

    Each decay contributes a PoolSum over the resonance helicity τ (values
    from ``create_spin_range(spin)``). The summand is the product of:
    production coupling, dynamics in σ2 (decay-product masses m1, m3), the
    Wigner d-function in θ31, the decay coupling, and sign factors. The
    Kronecker delta restricts the sum to τ = λ_Λc.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(decay.resonance), -τ, 0]
                * formulate_dynamics(decay, σ2, m1, m3)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, -λ_p, θ31)
                * H_dec[stringify(decay.resonance), 0, λ_p]
                * (-1) ** (half - λ_p)
                # parity-dependent sign; NOTE(review): presumably from the
                # LHCb → Dalitz-plot decomposition coupling remapping — confirm
                / (-decay.resonance.parity),
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_Δ_amplitude(λ_Λc, λ_p, decays: list[ThreeBodyDecay]):
    """Sub-system 3 (Δ resonances) amplitude for helicities *λ_Λc* and *λ_p*.

    Each decay contributes a PoolSum over the resonance helicity τ (values
    from ``create_spin_range(spin)``). The summand is the product of:
    production coupling, dynamics in σ3 (decay-product masses m1, m2), the
    Wigner d-function in θ12, the decay coupling, and a parity- and
    spin-dependent sign. The Kronecker delta restricts the sum to τ = λ_Λc.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(decay.resonance), -τ, 0]
                * formulate_dynamics(decay, σ3, m1, m2)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, λ_p, θ12)
                * H_dec[stringify(decay.resonance), λ_p, 0]
                # sign -P_R (-1)^(j_R - 1/2); NOTE(review): presumably from the
                # LHCb → Dalitz-plot decomposition coupling remapping — confirm
                / (-decay.resonance.parity * (-1) ** (decay.resonance.spin - half)),
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_dynamics(decay: ThreeBodyDecay, s, m1, m2):
    """Lineshape expression for the resonance of *decay*.

    *s* is the invariant mass squared of the resonance sub-system; *m1* and
    *m2* are the masses of its decay products. The builder is chosen by
    ``decay.resonance.lineshape``.

    Raises:
        NotImplementedError: for an unknown lineshape name.
    """
    lineshape = decay.resonance.lineshape
    builders = (
        ("BreitWignerMinL", formulate_breit_wigner),
        ("BuggBreitWignerMinL", formulate_bugg_breit_wigner),
        ("Flatte1405", formulate_flatté_1405),
    )
    for name, build in builders:
        if lineshape == name:
            return build(decay, s, m1, m2)
    raise NotImplementedError(f'No dynamics implemented for lineshape "{lineshape}"')


def formulate_breit_wigner(decay: ThreeBodyDecay, s, m1, m2):
    """Relativistic Breit–Wigner lineshape for the resonance of *decay*.

    Args:
        decay: decay providing the resonance mass, width and LS couplings.
        s: invariant mass squared of the resonance sub-system.
        m1, m2: masses of the two resonance decay products.

    Default values for the mass, width and the two radii are registered in
    ``parameter_defaults`` through ``safe_update_parameters`` (existing
    entries are kept).
    """
    # The decay-vertex angular momentum belongs to the resonance (paired with
    # R_res below). That is ``outgoing_ls.L``, shown as l_R in the resonance
    # table. The production-vertex one (paired with R_Λc) is ``incoming_ls.L``,
    # shown as l_Λc^min. These two values used to be swapped.
    l_dec = sp.Rational(decay.outgoing_ls.L)
    l_prod = sp.Rational(decay.incoming_ls.L)
    mass = sp.Symbol(f"m_{{{decay.resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{decay.resonance.latex}}}")
    R_dec = sp.Symbol(R"R_\mathrm{res}")
    R_prod = sp.Symbol(R"R_{\Lambda_c}")
    safe_update_parameters(mass, decay.resonance.mass)
    safe_update_parameters(width, decay.resonance.width)
    # https://github.com/redeboer/polarization-sensitivity/pull/11#issuecomment-1128784376
    safe_update_parameters(R_dec, 1.5)
    safe_update_parameters(R_prod, 5)
    return RelativisticBreitWigner(s, mass, width, m1, m2, l_dec, l_prod, R_dec, R_prod)


def formulate_bugg_breit_wigner(decay: ThreeBodyDecay, s, m1, m2):
    """Bugg Breit–Wigner lineshape for the resonance of *decay*.

    Registers default values for the mass, width and γ parameter (γ = 1) via
    ``safe_update_parameters``. *s* is the invariant mass squared of the
    sub-system; *m1* and *m2* are the masses of the decay products.
    """
    resonance = decay.resonance
    latex = resonance.latex
    gamma = sp.Symbol(Rf"\gamma_{{{latex}}}")
    mass = sp.Symbol(f"m_{{{latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{latex}}}")
    defaults = ((mass, resonance.mass), (width, resonance.width), (gamma, 1))
    for symbol, default_value in defaults:
        safe_update_parameters(symbol, default_value)
    return BuggBreitWigner(s, mass, width, m1, m2, gamma)


def formulate_flatté_1405(decay: ThreeBodyDecay, s, m1, m2):
    """Two-channel Flatté S-wave lineshape (lineshape ``"Flatte1405"``).

    The first channel uses the masses *m1*, *m2* of the resonance decay
    products. The second channel uses the π and Σ masses. Default values for
    the resonance mass and width and for the second-channel masses are
    registered via ``safe_update_parameters``.
    """
    mass = sp.Symbol(f"m_{{{decay.resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{decay.resonance.latex}}}")
    mπ = sp.Symbol(f"m_{{{π.latex}}}")
    mΣ = sp.Symbol(f"m_{{{Σ.latex}}}")
    safe_update_parameters(mass, decay.resonance.mass)
    safe_update_parameters(width, decay.resonance.width)
    # mπ and mΣ are mass symbols, so their defaults are the particle masses
    # (the particle widths were used here before, which is wrong).
    safe_update_parameters(mπ, π.mass)
    safe_update_parameters(mΣ, Σ.mass)
    return FlattéSWave(s, mass, width, (m1, m2), (mπ, mΣ))


def safe_update_parameters(parameter: sp.Symbol, value) -> None:
    """Register *value* as default for *parameter* unless it already has one.

    Writes to the module-level ``parameter_defaults`` dict; existing entries
    are never overwritten.
    """
    parameter_defaults.setdefault(parameter, value)


def stringify(obj) -> Str:
    """Wrap *obj* in a sympy ``Str``: a Particle's LaTeX name, else ``f"{obj}"``."""
    text = obj.latex if isinstance(obj, Particle) else f"{obj}"
    return Str(text)


def filter_isobars(
    decays: list[ThreeBodyDecay], resonance_pattern: str
) -> list[ThreeBodyDecay]:
    """Return the decays whose resonance name starts with *resonance_pattern*.

    The pattern is a plain name prefix (``str.startswith``), not a regex.
    The input order is preserved.
    """
    def has_prefix(decay: ThreeBodyDecay) -> bool:
        return decay.resonance.name.startswith(resonance_pattern)

    return list(filter(has_prefix, decays))

### Angle definitions

In [None]:
# Kinematic expressions for the helicity angles θ_ij and the alignment angles
# ζ, written in terms of the invariants σ1, σ2, σ3 and the masses m0 (Λc),
# m1 (p), m2 (π) and m3 (K). Each cosine is a ratio of invariants divided by
# square roots of Källén functions. Chain 1 is the reference frame, so both
# of its alignment angles (ζ_0_11, ζ_1_11) are zero.
m0 = sp.Symbol(R"m_{\Lambda_c}", nonnegative=True)
angles = {
    θ12: sp.acos(
        (
            2 * σ3 * (σ2 - m3**2 - m1**2)
            - (σ3 + m1**2 - m2**2) * (m0**2 - σ3 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m3**2, σ3))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
    θ23: sp.acos(
        (
            2 * σ1 * (σ3 - m1**2 - m2**2)
            - (σ1 + m2**2 - m3**2) * (m0**2 - σ1 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ1, m2**2, m3**2))
        )
    ),
    θ31: sp.acos(
        (
            2 * σ2 * (σ1 - m2**2 - m3**2)
            - (σ2 + m3**2 - m1**2) * (m0**2 - σ2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(σ2, m3**2, m1**2))
        )
    ),
    ζ_0_11: sp.S.Zero,  # = \hat\theta^0_{1(1)}
    ζ_0_21: -sp.acos(  # = -\hat\theta^{1(2)}
        (
            (m0**2 + m1**2 - σ1) * (m0**2 + m2**2 - σ2)
            - 2 * m0**2 * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(m0**2, σ1, m1**2))
        )
    ),
    ζ_0_31: sp.acos(  # = \hat\theta^{3(1)}
        (
            (m0**2 + m3**2 - σ3) * (m0**2 + m1**2 - σ1)
            - 2 * m0**2 * (σ2 - m3**2 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(m0**2, σ3, m3**2))
        )
    ),
    ζ_1_11: sp.S.Zero,
    ζ_1_21: sp.acos(
        (
            2 * m1**2 * (σ3 - m0**2 - m3**2)
            + (m0**2 + m1**2 - σ1) * (σ2 - m1**2 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ2, m1**2, m3**2))
        )
    ),
    ζ_1_31: -sp.acos(  # = -\zeta^1_{1(3)}
        (
            2 * m1**2 * (σ2 - m0**2 - m2**2)
            + (m0**2 + m1**2 - σ1) * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
}

display_latex(angles)

In [None]:
# Numerical values of the Λc and final-state masses (insertion order m0..m3)
masses = dict(
    zip(
        (m0, m1, m2, m3),
        (Λc.mass, p.mass, π.mass, K.mass),
    )
)
display_latex(masses)

### Helicity coupling values

In [None]:
def to_symbol_definitions(
    parameter_dict: dict[str, str]
) -> dict[sp.Basic, complex | float]:
    """Convert the string-valued model parameters into SymPy symbols → numbers.

    Each ``Ar<id>`` / ``Ai<id>`` pair is merged into one complex value for key
    ``A<id>``. Every other key (except ``Ai…``, which is consumed with its
    ``Ar…`` partner) becomes a float. Only the central value of each
    ``"<value> ± <error>"`` string is kept.

    Raises:
        KeyError: if an ``Ar<id>`` key has no matching ``Ai<id>`` key.
    """
    values_by_key: dict[str, complex | float] = {}
    for key, str_value in parameter_dict.items():
        if key.startswith("Ar"):
            identifier = key[2:]
            imaginary_text = parameter_dict[f"Ai{identifier}"]
            values_by_key[f"A{identifier}"] = complex(
                to_float(str_value), to_float(imaginary_text)
            )
        elif not key.startswith("Ai"):
            values_by_key[key] = to_float(str_value)
    symbol_definitions = {}
    for key, value in values_by_key.items():
        symbol_definitions[to_symbol(key)] = value
    return symbol_definitions


def to_float(str_value: str) -> float:
    """Return the central value of a ``"<value> ± <uncertainty>"`` string.

    Both parts must parse as floats, and there must be exactly one
    ``" ± "`` separator; otherwise ValueError is raised.
    """
    parsed = [float(part) for part in str_value.split(" ± ")]
    central, _uncertainty = parsed
    return central


def to_symbol(key: str) -> sp.Indexed | sp.Symbol:
    r"""Convert a key of the model-parameter JSON file to a SymPy expression.

    - ``A<resonance><i>``: production coupling ``H_prod[resonance, ...]``. The
      two helicity indices depend on the resonance family (name prefix L, D
      or K, with K(700)/K(1430) treated separately) and on the coupling
      number ``i``.
    - ``gamma<resonance>``: symbol ``\gamma_{resonance}``.
    - ``M<resonance>`` / ``G<resonance>``: symbols ``m_{resonance}`` /
      ``\Gamma_{resonance}``.

    Raises:
        NotImplementedError: for any other key, including an ``A`` key whose
            coupling number has no helicity mapping.
    """
    if key.startswith("A"):
        resonance = stringify(key[1:-1])
        coupling_number = int(key[-1])
        name = str(resonance)
        if name.startswith("L") or name.startswith("D"):
            helicities = {1: (+half, 0), 2: (-half, 0)}
        elif name.startswith("K") and name in {"K(700)", "K(1430)"}:
            helicities = {1: (0, -half), 2: (0, +half)}
        elif name.startswith("K"):
            helicities = {
                1: (0, +half),
                2: (-1, +half),
                3: (+1, -half),
                4: (0, -half),
            }
        else:
            helicities = {}
        if coupling_number in helicities:
            return H_prod[(resonance, *helicities[coupling_number])]
    # Prefix → LaTeX base of the symbol; "gamma" is checked before "G"
    for prefix, latex_base in (("gamma", R"\gamma"), ("M", "m"), ("G", R"\Gamma")):
        if key.startswith(prefix):
            resonance = stringify(key[len(prefix) :])
            return sp.Symbol(f"{latex_base}_{{{resonance}}}")
    raise NotImplementedError(
        f'Cannot convert key "{key}" in model parameter JSON file to SymPy symbol'
    )


# Read explicitly as UTF-8. JSON is UTF-8 by specification, and the parameter
# strings contain the non-ASCII "±" (e.g. "1.0 ± 0.0"). Under a non-UTF-8
# locale default encoding, "±" would be mis-decoded and to_float would fail.
with open("../data/modelparameters.json", encoding="utf-8") as stream:
    data = json.load(stream)
assert len(data["modelstudies"]) == 18  # sanity check on the number of model studies

model_number = 0
model_json = data["modelstudies"][model_number]["parameters"]
# Overwrite K(892) coupling #1 with 1 + 0i, following:
# https://github.com/redeboer/polarization-sensitivity/blob/34f5330/julia/notebooks/model0.jl#L301-L302
model_json["ArK(892)1"] = "1.0 ± 0.0"
model_json["AiK(892)1"] = "0.0 ± 0.0"

parameter_defaults = to_symbol_definitions(model_json)

In [None]:
# Production helicity couplings: all Indexed parameters whose base is H_prod
prod_couplings = {
    key: value
    for key, value in parameter_defaults.items()
    if isinstance(key, sp.Indexed) and key.base == H_prod
}
display_latex(prod_couplings)

In [None]:
# Decay helicity couplings H_dec. The K chain has one coupling, set to 1. For
# Λ and Δ, the coupling with positive proton helicity is set to 1, and the
# other one is fixed by the relation P_R · P_a · P_b · (-1)^(j_R - j_a - j_b)
# with (a, b) = (K, p) for Λ and (p, π) for Δ.
dec_couplings = {}
for decay in decays:
    i = stringify(decay.resonance)
    if decay.resonance.name.startswith("K"):
        dec_couplings[H_dec[i, 0, 0]] = 1
    if decay.resonance.name.startswith("L"):
        dec_couplings[H_dec[i, 0, half]] = 1
        dec_couplings[H_dec[i, 0, -half]] = (
            int(decay.resonance.parity)
            * int(K.parity)
            * int(p.parity)
            * (-1) ** (decay.resonance.spin - K.spin - p.spin)
        )
    if decay.resonance.name.startswith("D"):
        dec_couplings[H_dec[i, half, 0]] = 1
        dec_couplings[H_dec[i, -half, 0]] = (
            int(decay.resonance.parity)
            * int(p.parity)
            * int(π.parity)
            * (-1) ** (decay.resonance.spin - p.spin - π.spin)
        )
parameter_defaults.update(dec_couplings)
display_latex(dec_couplings)

### Intensity expression

In [None]:
# Intensity: |aligned amplitude|² summed over the proton helicity λ and the
# Λc helicity ν
intensity_expr = PoolSum(
    sp.Abs(formulate_aligned_amplitude(ν, λ)) ** 2,
    (λ, [-half, +half]),
    (ν, [-half, +half]),
)
display(intensity_expr)

In [None]:
# Map each sub-system amplitude symbol A^X[λ_Λc, λ_p] to its expression
A = {1: A_K, 2: A_Λ, 3: A_Δ}
amp_definitions = {}
for subsystem in range(1, 4):
    # The symbolic amplitude does not depend on the helicity values, so
    # formulate it once per sub-system and substitute each helicity pair
    # (this was previously rebuilt four times per sub-system).
    expr = formulate_subsystem_amplitude(subsystem, ν, λ)
    for Λc_heli, p_heli in product([-half, +half], [-half, +half]):
        symbol = A[subsystem][Λc_heli, p_heli]
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})
display_latex(amp_definitions)

It takes about **one minute** to run the following cell, given the resonances as defined in {download}`data/isobars.json <../data/isobars.json>`.

In [None]:
def assert_all_symbols_defined(expr: sp.Expr) -> None:
    """Assert that *expr* depends only on angles, masses and σ1, σ2, σ3.

    The check is done after inserting all ``parameter_defaults``.
    """
    leftover = expr.xreplace(parameter_defaults).free_symbols
    allowed = {σ1, σ2, σ3}.union(angles, masses)
    assert leftover <= allowed


# Evaluate the PoolSums, insert the sub-system amplitude definitions, and
# evaluate the PoolSums that they introduce
subs_intensity_expr = intensity_expr.doit().xreplace(amp_definitions).doit()
assert_all_symbols_defined(subs_intensity_expr)
print(f"Intensity expression has {sp.count_ops(subs_intensity_expr):,} operations")

### Non-coupling parameter definitions

In [None]:
# Show all default parameters that are neither production nor decay couplings
couplings = set(dec_couplings) | set(prod_couplings)
display_latex({k: v for k, v in parameter_defaults.items() if k not in couplings})

## Cross-check with the LHCb paper

In [None]:
# JSON is UTF-8 by specification; do not depend on the locale default encoding
with open("../data/crosscheck.json", encoding="utf-8") as stream:
    crosscheck_data = json.load(stream)

In [None]:
# Interactive view of the cross-check data
from IPython.display import JSON

JSON(crosscheck_data)

### Lineshape comparison

In [None]:
# Kinematic point of the cross-check ("m2kpi" → σ1, "m2pk" → σ2), plus all
# masses and default parameter values
variables = dict(crosscheck_data["mainvars"])
substitutions = {
    σ1: variables["m2kpi"],
    σ2: variables["m2pk"],
}
substitutions.update(masses)
substitutions.update(parameter_defaults)

In [None]:
# Decays whose lineshapes are compared with the cross-check values
K892_decay = next(d for d in decays if d.resonance.name == "K(892)")
L1405_decay = next(d for d in decays if d.resonance.name == "L(1405)")
L1690_decay = next(d for d in decays if d.resonance.name == "L(1690)")
Math(as_latex([K892_decay, L1405_decay, L1690_decay]))

In [None]:
# Lineshape values stored in the cross-check file, for comparison below
crosscheck_data["lineshapes"]

In [None]:
# K(892) lineshape at the cross-check point (σ1 = m2kpi)
formulate_dynamics(K892_decay, σ1, m2, m3).doit().subs(substitutions).n()

In [None]:
# Λ(1405) lineshape at the cross-check point (σ2 = m2pk)
formulate_dynamics(L1405_decay, σ2, m1, m3).doit().subs(substitutions).n()

In [None]:
# Λ(1690) lineshape at the cross-check point (σ2 = m2pk)
formulate_dynamics(L1690_decay, σ2, m1, m3).doit().subs(substitutions).n()

## Computations with TensorWaves


### Phase space and helicity angles

In [None]:
# σ3 follows from the constraint σ1 + σ2 + σ3 = m0² + m1² + m2² + m3²
σ3_expr = m0**2 + m1**2 + m2**2 + m3**2 - σ1 - σ2
compute_third_mandelstam = create_function(σ3_expr.subs(masses), backend="jax")
display_latex({σ3: σ3_expr})

In [None]:
def kibble_function(σ1, σ2):
    """Kibble function: the Källén function of three Källén functions.

    The arguments shadow the module-level σ1/σ2 symbols. σ3 is NOT an
    argument: the module-level σ3 symbol is used, and callers substitute
    σ3_expr afterwards.
    """
    return Källén(
        Källén(σ2, m2**2, m0**2),
        Källén(σ3, m3**2, m0**2),
        Källén(σ1, m1**2, m0**2),
    )


def is_within_phsp(σ1, σ2, non_phsp_value=sp.nan):
    """Piecewise phase-space indicator.

    Returns 1 where the Kibble function is non-positive (inside the phase
    space) and *non_phsp_value* (NaN by default) everywhere else.
    """
    inside = sp.LessThan(kibble_function(σ1, σ2), 0)
    outside = True
    return sp.Piecewise((1, inside), (non_phsp_value, outside))


# Display the symbolic phase-space indicator
is_within_phsp(σ1, σ2)

In [None]:
# Eliminate σ3 and insert the masses, leaving an expression in σ1 and σ2 only
in_phsp_expr = is_within_phsp(σ1, σ2).subs(σ3, σ3_expr).subs(masses).doit()
assert in_phsp_expr.free_symbols == {σ1, σ2}

In [None]:
resolution = 200
# Look up the masses by symbol. Unpacking masses.values() would depend on the
# insertion order of the dictionary.
m0_val = masses[m0]
m1_val = masses[m1]
m2_val = masses[m2]
m3_val = masses[m3]
# Kinematic ranges of σ1 = m²(Kπ) and σ2 = m²(pK)
σ1_min = (m2_val + m3_val) ** 2
σ1_max = (m0_val - m1_val) ** 2
σ2_min = (m1_val + m3_val) ** 2
σ2_max = (m0_val - m2_val) ** 2
x = np.linspace(σ1_min, σ1_max, num=resolution)
y = np.linspace(σ2_min, σ2_max, num=resolution)
X, Y = np.meshgrid(x, y)
Z = compute_third_mandelstam.function(X, Y)
σ_arrays = {"sigma1": X, "sigma2": Y, "sigma3": Z}

# 1 inside the phase space and NaN outside (default of is_within_phsp)
in_phsp = create_function(in_phsp_expr, backend="numpy")
phsp = in_phsp(σ_arrays)

In [None]:
# Numerical angle expressions plus the identity mapping for σ1, σ2, σ3, so
# that the transformer also passes the invariants through
kinematic_variables = {
    **{
        angle: definition.doit().subs(masses)
        for angle, definition in angles.items()
    },
    **{invariant: invariant for invariant in (σ1, σ2, σ3)},
}
transformer = SympyDataTransformer.from_sympy(kinematic_variables, backend="jax")
kinematic_arrays = transformer(σ_arrays)

### Definition of free parameters

In [None]:
# Free: resonance masses, widths and production couplings. Everything else,
# plus the final-state masses, is fixed.
free_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol.name.startswith(("m_", R"\Gamma_")) or symbol in prod_couplings
}
fixed_parameters = {
    **{
        symbol: value
        for symbol, value in parameter_defaults.items()
        if symbol not in free_parameters
    },
    **masses,
}

### Intensity distribution

In [None]:
# Fixed parameters are inserted into the expression beforehand. Only the free
# parameters remain adjustable via intensity_func.parameters.
intensity_func = create_parametrized_function(
    subs_intensity_expr.xreplace(fixed_parameters),
    parameters=free_parameters,
    backend="jax",
)

In [None]:
%config InlineBackend.figure_formats = ['png']

In [None]:
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

# Dalitz-plot intensity on a log colour scale, masked outside the phase space
fig, ax = plt.subplots(figsize=(10, 8), tight_layout=True)
ax.set_title("Intensity distribution")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)
mesh = ax.pcolormesh(X, Y, phsp * intensity_func(kinematic_arrays), norm=LogNorm())
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
%config InlineBackend.figure_formats = ['svg']

In [None]:
def compute_sub_func(
    func: ParametrizedBackendFunction, input_data, non_zero_couplings: list[str]
):
    r"""Evaluate *func* with only the selected couplings left non-zero.

    Every parameter of *func* whose name starts with ``\mathcal{H}`` and
    whose index list (after a ``[``) does not start with one of the regex
    fragments in *non_zero_couplings* is set to zero during the evaluation.
    An empty *non_zero_couplings* switches all such couplings off. The
    original parameter values are restored afterwards, even if the
    evaluation raises.

    Returns:
        The output of ``func(input_data)``.
    """
    old_parameters = dict(func.parameters)
    if non_zero_couplings:
        alternatives = "|".join(non_zero_couplings)
        pattern = rf"\\mathcal{{H}}.*\[(?!{alternatives})"
    else:
        # "(?!)" (empty negative lookahead) never matches, which would keep
        # *all* couplings. With nothing to keep, match every coupling instead.
        pattern = r"\\mathcal{H}.*\["
    try:
        set_parameter_to_zero(func, pattern)
        return func(input_data)
    finally:
        func.update_parameters(old_parameters)


def set_parameter_to_zero(
    func: ParametrizedBackendFunction, search_term: Pattern | str
) -> None:
    """Set every parameter of *func* whose name matches *search_term* to zero.

    Matching uses ``re.match``, i.e. it is anchored at the start of the
    parameter name. If no parameter matches, nothing is changed and a warning
    is logged. (The old warning claimed that *all* couplings had been set to
    zero, which is the opposite of what happened.)
    """
    new_parameters = dict(func.parameters)
    matched_names = [
        name for name in func.parameters if re.match(search_term, name) is not None
    ]
    for name in matched_names:
        new_parameters[name] = 0
    if not matched_names:
        logging.warning(
            "No parameters were set to zero for search term %s", search_term
        )
    func.update_parameters(new_parameters)


def set_ylim_to_zero(ax):
    """Pin the lower y-limit of *ax* to 0 and keep its current upper limit."""
    upper = ax.get_ylim()[1]
    ax.set_ylim(0, upper)


# Projections of the intensity onto σ1 (left) and σ2 (right). The total
# intensity is drawn as a filled area; the K, Λ and Δ sub-chains as lines.
fig, (ax1, ax2) = plt.subplots(
    ncols=2,
    figsize=(12, 5),
    tight_layout=True,
)
ax1.set_xlabel(s1_label)
ax2.set_xlabel(s2_label)

subsystem_identifiers = ["K", "L", "D"]
subsystem_labels = ["K^{**}", R"\Lambda^{**}", R"\Delta^{**}"]
intensity_array = intensity_func(kinematic_arrays)
# With the meshgrid above, axis 0 runs along σ2: summing it projects onto σ1
# and vice versa. nansum ignores NaN entries.
ax1.fill(x, np.nansum(intensity_array, axis=0), alpha=0.3)
ax2.fill(y, np.nansum(intensity_array, axis=1), alpha=0.3)

original_parameters = dict(intensity_func.parameters)
for label, identifier in zip(subsystem_labels, subsystem_identifiers):
    label = f"${label}$"
    # Intensity with only the couplings of this sub-system switched on
    intensity_array = compute_sub_func(intensity_func, kinematic_arrays, [identifier])
    ax1.plot(x, np.nansum(intensity_array, axis=0), label=label)
    ax2.plot(y, np.nansum(intensity_array, axis=1), label=label)
    intensity_func.update_parameters(original_parameters)
set_ylim_to_zero(ax1)
set_ylim_to_zero(ax2)
ax2.legend()
plt.show()

### Fit fractions

In [None]:
def sub_intensity(data, non_zero_couplings: list[str]):
    """Integrated intensity with only the *non_zero_couplings* switched on.

    Uses the module-level ``intensity_func``; see compute_sub_func for how
    the coupling fragments are matched.
    """
    return integrate_intensity(
        compute_sub_func(intensity_func, data, non_zero_couplings)
    )


def integrate_intensity(intensities):
    """Average intensity over all points of the (possibly N-D) grid.

    NaN entries (e.g. outside the phase space) contribute zero but are still
    counted in the denominator. ``np.size`` is used instead of ``len``: for a
    2-D grid, ``len`` gives only the number of rows, so the old result was
    scaled by the number of columns.
    """
    return np.nansum(intensities) / np.size(intensities)


# Total integrated intensity. Sanity check: keeping all K, L and D couplings
# must reproduce it.
I_tot = integrate_intensity(intensity_func(kinematic_arrays))
np.testing.assert_allclose(
    I_tot,
    sub_intensity(kinematic_arrays, ["K", "L", "D"]),
)

In [None]:
def interference_intensity(
    data,
    chain1: list[str],
    chain2: list[str],
) -> float:
    """Interference term between two sets of chains: I(1 + 2) − I(1) − I(2)."""
    combined = sub_intensity(data, chain1 + chain2)
    only_first = sub_intensity(data, chain1)
    only_second = sub_intensity(data, chain2)
    return combined - only_first - only_second


# Per-chain intensities and pairwise interference terms. Together they must
# add up to the total intensity.
I_K = sub_intensity(kinematic_arrays, non_zero_couplings=["K"])
I_Λ = sub_intensity(kinematic_arrays, non_zero_couplings=["L"])
I_Δ = sub_intensity(kinematic_arrays, non_zero_couplings=["D"])
I_ΛΔ = interference_intensity(kinematic_arrays, ["L"], ["D"])
I_KΔ = interference_intensity(kinematic_arrays, ["K"], ["D"])
I_KΛ = interference_intensity(kinematic_arrays, ["K"], ["L"])
np.testing.assert_allclose(I_tot, I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ)

In [None]:
def to_regex(text: str) -> str:
    """Escape *text* so that it matches literally inside a regular expression.

    Resonance names contain parentheses (e.g. ``K(892)``). ``re.escape``
    also covers every other regex metacharacter (``.``, ``*``, ``+``,
    ``[``, ...), which the old parenthesis-only replacement missed.
    """
    return re.escape(text)


resonances = sorted(
    (d.resonance for d in decays),
    key=lambda resonance: natural_sorting(resonance.name),
    reverse=True,
)
n_resonances = len(resonances)
decay_rates = np.zeros(shape=(n_resonances, n_resonances))
combinations = list(product(enumerate(resonances), enumerate(resonances)))
# The rate matrix is symmetric, so only the upper triangle (j >= i, including
# the diagonal) is computed: n (n + 1) / 2 evaluations.
n_evaluations = n_resonances * (n_resonances + 1) // 2
# The context manager closes the progress bar even if an evaluation raises
with tqdm(desc="Calculating rate matrix", total=n_evaluations) as progress_bar:
    for (i, resonance1), (j, resonance2) in combinations:
        if j < i:
            continue
        progress_bar.set_postfix_str(
            f"{resonance1.name} × {resonance2.name}", refresh=False
        )
        res1 = to_regex(resonance1.name)
        res2 = to_regex(resonance2.name)
        if res1 == res2:
            # Diagonal: intensity of a single resonance
            I_sub = sub_intensity(kinematic_arrays, non_zero_couplings=[res1])
        else:
            # Off-diagonal: interference between the two resonances
            I_sub = interference_intensity(kinematic_arrays, [res1], [res2])
        decay_rates[i, j] = I_sub / I_tot
        if i != j:
            decay_rates[j, i] = decay_rates[i, j]
        progress_bar.update()

In [None]:
# Symmetric colour scale around zero: negative entries are destructive
# interference terms
vmax = np.max(decay_rates)
fig, ax = plt.subplots(figsize=(9, 9))
# Titles are not %-formatted, so a single "%" is needed (the old "%%" was
# shown literally)
ax.set_title("Rate matrix for isobars (%)")
# np.rot90(A).T equals np.fliplr(A): columns are drawn in reversed order,
# matching the reversed x tick labels and the text positions below
ax.matshow(np.rot90(decay_rates).T, cmap=plt.cm.coolwarm, vmin=-vmax, vmax=+vmax)

resonance_names = [resonance.name for resonance in resonances]
ax.set_xticks(range(n_resonances))
ax.set_xticklabels(resonance_names[::-1])
ax.set_yticks(range(n_resonances))
ax.set_yticklabels(resonance_names)
# Annotate the upper triangle (as computed) with the rates in percent
for i in range(n_resonances):
    for j in range(n_resonances):
        if j < i:
            continue
        rate = decay_rates[i, j]
        ax.text(n_resonances - j - 1, i, f"{100 * rate:.2f}", va="center", ha="center")
fig.tight_layout()
plt.show()