# Polarization sensitivity in $\Lambda_c \to p \pi K$

_This notebook originates from [ComPWA/compwa-org#129](https://github.com/ComPWA/compwa-org/pull/129) [[TR-018](https://compwa-org--129.org.readthedocs.build/report/018.html)], [`625e50c`](https://github.com/ComPWA/compwa-org/pull/129/commits/94ceb51)._

In [None]:
from __future__ import annotations

import itertools
import json
import re
from typing import Pattern

import matplotlib.pyplot as plt
import numpy as np
import qrules
import sympy as sp
from ampform.sympy import PoolSum, UnevaluatedExpression
from attrs import evolve
from IPython.display import HTML, Math, display
from ipywidgets import Button, Combobox, HBox, HTMLMath, Tab, VBox, interactive_output
from matplotlib import cm
from matplotlib.colors import LogNorm
from qrules.particle import Particle
from symplot import create_slider
from sympy.core.symbol import Str
from sympy.physics.matrices import msigma
from sympy.physics.quantum.spin import Rotation as Wigner
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function import ParametrizedBackendFunction
from tensorwaves.function.sympy import create_function, create_parametrized_function

from polarization.decay import IsobarNode, Resonance, ThreeBodyDecay
from polarization.dynamics import (
    BlattWeisskopf,
    BuggBreitWigner,
    EnergyDependentWidth,
    FlattéSWave,
    Källén,
    P,
    Q,
    RelativisticBreitWigner,
)
from polarization.io import as_latex, from_qrules, load_resonance_definitions
from polarization.spin import (
    create_spin_range,
    filter_parity_violating_ls,
    generate_ls_couplings,
)

# Particle database from QRules; the particle definitions below are looked up in it.
PDG = qrules.load_pdg()


def display_latex(obj) -> None:
    """Render ``obj`` as LaTeX math in the cell output.

    The conversion to LaTeX is delegated to ``as_latex``; the result is wrapped
    in an IPython ``Math`` object and displayed.
    """
    display(Math(as_latex(obj)))


def display_doit(
    expr: UnevaluatedExpression, deep=False, terms_per_line: int | None = None
) -> None:
    """Display ``expr = expr.doit()`` as a LaTeX equation.

    Args:
        expr: Unevaluated expression to show next to its evaluated form.
        deep: Forwarded to ``expr.doit``.
        terms_per_line: If given, the right-hand side is broken into lines with
            this many terms each, using an ``eqnarray`` environment. Otherwise a
            single ``as_latex`` equation is rendered.
    """
    evaluated = expr.doit(deep=deep)
    if terms_per_line is not None:
        rendered = sp.multiline_latex(
            lhs=expr,
            rhs=evaluated,
            terms_per_line=terms_per_line,
            environment="eqnarray",
        )
    else:
        rendered = as_latex({expr: evaluated})
    display(Math(rendered))


# hack for moving Indexed indices below superscript of the base
def _print_Indexed_latex(self, printer, *args):
    base = printer._print(self.base)
    indices = ", ".join(map(printer._print, self.indices))
    return f"{base}_{{{indices}}}"


# Install the hack: SymPy's LaTeX printer uses an object's `_latex` method when
# it exists, so all `sp.Indexed` objects now render via the function above.
sp.Indexed._latex = _print_Indexed_latex

## Amplitude model

### Resonances and $LS$-scheme

Particle definitions for $\Lambda_c^+$ and $p, \pi^+, K^-$:

In [None]:
# Mother and final-state particles from the PDG database, with their masses
# (GeV) overridden explicitly.
Λc = from_qrules(evolve(PDG["Lambda(c)+"], mass=2.28646))
p = from_qrules(evolve(PDG["p"], mass=0.938272046))
K = from_qrules(evolve(PDG["K-"], mass=0.493677))
π = from_qrules(evolve(PDG["pi+"], mass=0.13957018))

# https://github.com/redeboer/polarization-sensitivity/blob/34f5330/julia/notebooks/model0.jl#L43-L47
Σ = from_qrules(evolve(PDG["Sigma-"], mass=1.18937))  # for Flatté

Resonance definitions, as specified in [`data/isobars.json`](../data/isobars.json):

In [None]:
# Mapping of resonance definitions (presumably keyed by name — TODO confirm);
# only its values (Resonance objects) are used below.
resonances = load_resonance_definitions("../data/isobars.json")
len(resonances)

In [None]:
def create_isobar(resonance: Resonance) -> ThreeBodyDecay:
    """Build the Λc → (resonance → child1 child2) + spectator decay chain.

    The first letter of the resonance name selects the sub-system:
    ``K`` → (π K) + p, ``L`` → (p K) + π, ``D`` → (p π) + K. Both vertices get
    the lowest parity-conserving LS coupling from ``generate_L_min``.

    Raises:
        NotImplementedError: If the name starts with none of these letters.
    """
    # first letter of the name -> (child1, child2, spectator)
    final_states = {
        "K": (π, K, p),
        "L": (p, K, π),
        "D": (p, π, K),
    }
    chain_letter = resonance.name[:1]
    if chain_letter not in final_states:
        raise NotImplementedError
    child1, child2, sibling = final_states[chain_letter]
    resonance_node = IsobarNode(
        parent=resonance,
        child1=child1,
        child2=child2,
        interaction=generate_L_min(resonance, child1, child2),
    )
    production_node = IsobarNode(
        parent=Λc,
        child1=sibling,
        child2=resonance_node,
        interaction=generate_L_min(Λc, sibling, resonance),
    )
    return ThreeBodyDecay(production_node)


def generate_L_min(parent: Resonance, child1: Resonance, child2: Resonance) -> int:
    """Return the minimal allowed LS coupling for ``parent → child1 child2``.

    All LS couplings allowed by the spins are generated, those that violate
    parity are removed, and the smallest remaining element (by ``min``) is
    returned. NOTE(review): the annotation says ``int``, but the value is
    whatever element type ``generate_ls_couplings`` yields (possibly an
    ``(L, S)`` pair) — confirm.
    """
    parity_conserving = filter_parity_violating_ls(
        generate_ls_couplings(parent.spin, child1.spin, child2.spin),
        parent.parity,
        child1.parity,
        child2.parity,
    )
    return min(parity_conserving)


# One three-body decay chain per resonance; rendered with spin-parity labels.
decays = [create_isobar(res) for res in resonances.values()]
Math(as_latex(decays, with_jp=True))

In [None]:
def create_html_table_row(*items, typ="td"):
    """Return one HTML ``<tr>`` row with every item wrapped in a ``<typ>`` cell.

    Use ``typ="th"`` for header rows. The row ends with a newline.
    """
    cells = "".join(f"<{typ}>{item}</{typ}>" for item in items)
    return f"<tr>{cells}</tr>\n"


# HTML overview of all decay chains: resonance, J^P, mass and width (converted
# from GeV to MeV and truncated to int), the L values of the two vertices, and
# the lineshape name.
column_names = [
    "resonance",
    R"\(j^P\)",
    R"\(m\) (MeV)",
    R"\(\Gamma_0\) (MeV)",
    R"\(l_R\)",
    R"\(l_{\Lambda_c}^\mathrm{min}\)",
    "lineshape",
]
src = "<table>\n"
src += create_html_table_row(*column_names, typ="th")
for decay in decays:
    child1, child2 = map(as_latex, decay.decay_products)
    src += create_html_table_row(
        Rf"\({decay.resonance.latex} \to" Rf" {child1} {child2}\)",
        Rf"\({as_latex(decay.resonance, only_jp=True)}\)",
        int(1e3 * decay.resonance.mass),
        int(1e3 * decay.resonance.width),
        decay.incoming_ls.L,
        decay.outgoing_ls.L,
        decay.resonance.lineshape,
    )
src += "</table>\n"
HTML(src)

### Aligned amplitude

In [None]:
# Sub-system amplitudes for the K, Λ and Δ chains, indexed by the Λc and proton
# helicities (their definitions are substituted further below).
A_K = sp.IndexedBase(R"A^K")
A_Λ = sp.IndexedBase(R"A^{\Lambda}")
A_Δ = sp.IndexedBase(R"A^{\Delta}")

half = sp.S.Half

# Alignment angles ζ used in the Wigner-d functions of the aligned amplitude;
# their expressions in terms of σ1, σ2, σ3 are defined in `angles` below.
ζ_0_11 = sp.Symbol(R"\zeta^0_{1(1)}", real=True)
ζ_0_21 = sp.Symbol(R"\zeta^0_{2(1)}", real=True)
ζ_0_31 = sp.Symbol(R"\zeta^0_{3(1)}", real=True)
ζ_1_11 = sp.Symbol(R"\zeta^1_{1(1)}", real=True)
ζ_1_21 = sp.Symbol(R"\zeta^1_{2(1)}", real=True)
ζ_1_31 = sp.Symbol(R"\zeta^1_{3(1)}", real=True)


def formulate_aligned_amplitude(λ_Λc, λ_p):
    """Aligned amplitude for Λc helicity ``λ_Λc`` and proton helicity ``λ_p``.

    Each of the three sub-system amplitudes (K, Λ, Δ) is combined with two
    Wigner-d functions of its ζ angles. The three terms are added, and the sum
    runs over the intermediate helicities ν' and λ' ∈ {-1/2, +1/2}.
    """
    _ν = sp.Symbol(R"\nu^{\prime}", rational=True)
    _λ = sp.Symbol(R"\lambda^{\prime}", rational=True)
    chains = [
        (A_K, ζ_0_11, ζ_1_11),
        (A_Λ, ζ_0_21, ζ_1_21),
        (A_Δ, ζ_0_31, ζ_1_31),
    ]
    summand = sp.Add(
        *(
            amplitude[_ν, _λ]
            * Wigner.d(half, λ_Λc, _ν, ζ_Λc)
            * Wigner.d(half, _λ, λ_p, ζ_p)
            for amplitude, ζ_Λc, ζ_p in chains
        )
    )
    return PoolSum(summand, (_λ, [-half, +half]), (_ν, [-half, +half]))


# Generic helicity symbols: ν for the Λc and λ for the proton.
ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)

### Dynamics

In [None]:
# Blatt-Weisskopf form factor
z = sp.Symbol("z", positive=True)
L = sp.Symbol("L", integer=True, nonnegative=True)
display_doit(BlattWeisskopf(z, L))

In [None]:
# Källén function
x, y, z = sp.symbols("x:z")
display_doit(Källén(x, y, z))

In [None]:
# Momentum functions P and Q (presumably the break-up momenta of the resonance
# and of the Λc vertex — TODO confirm against polarization.dynamics)
s, m0, mi, mj, mk = sp.symbols("s m0 m_i:k", nonnegative=True)
display_doit(P(s, mi, mj))
display_doit(Q(s, m0, mk))

In [None]:
# Energy-dependent width with length factor R
R = sp.Symbol("R")
l_R = sp.Symbol("l_R", integer=True, positive=True)
m, Γ0, m1, m2 = sp.symbols("m Γ0 m1 m2", nonnegative=True)
display_doit(EnergyDependentWidth(s, m, Γ0, m1, m2, l_R, R))

In [None]:
# Relativistic Breit-Wigner (also depends on the Λc-vertex angular momentum)
l_Λc = sp.Symbol(R"l_{\Lambda_c}", integer=True, positive=True)
display_doit(RelativisticBreitWigner(s, m, Γ0, m1, m2, l_R, l_Λc, R))

In [None]:
# Flatté S-wave with two channels, each given as a pair of masses
m1_1, m2_1, m1_2, m2_2 = sp.symbols("m1_1 m2_1 m1_2 m2_2")
display_doit(FlattéSWave(s, m, Γ0, (m1_1, m2_1), (m1_2, m2_2)))

In [None]:
# Bugg Breit-Wigner: the evaluated expression is shown with the shorthand
# symbols s_A, ρ_Kπ and q substituted back in, followed by their definitions.
mKπ, m0, Γ0, mK, mπ, γ = sp.symbols(R"m_{K\pi} m0 Gamma0 m_K m_pi gamma")
bugg = BuggBreitWigner(mKπ**2, m0, Γ0, mK, mπ, γ)
q = P(mKπ**2, mK, mπ)
s_A = sp.Symbol("s_A")
definitions = {
    s_A: mK**2 - mπ**2,
    sp.Symbol(R"\rho_{K\pi}"): 2 * q / mKπ**2,
    q: q.evaluate(),
}
display_latex({bugg: bugg.evaluate().subs({v: k for k, v in definitions.items()})})
display_latex(definitions)

### Sub-system amplitudes

In [None]:
# Helicity couplings, indexed as [resonance, helicity, helicity].
H_prod = sp.IndexedBase(R"\mathcal{H}^\mathrm{production}")
H_dec = sp.IndexedBase(R"\mathcal{H}^\mathrm{decay}")

# Helicity angles of the three sub-systems (expressions defined in `angles`).
θ23 = sp.Symbol("theta23", real=True)
θ31 = sp.Symbol("theta31", real=True)
θ12 = sp.Symbol("theta12", real=True)

# Invariant masses squared (σ1 = m²(Kπ), σ2 = m²(pK), σ3 = m²(pπ); see plot
# labels) and the final-state masses of p, π and K.
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
m1, m2, m3 = sp.symbols(R"m_p m_pi m_K", nonnegative=True)


def formulate_subsystem_amplitude(subsystem: int, λ_Λc, λ_p):
    """Amplitude of sub-system 1 (K**), 2 (Λ**) or 3 (Δ**).

    The decay chains of the selected sub-system are taken from the global
    ``decays`` list by the first letter of their resonance name.

    Raises:
        NotImplementedError: For any other sub-system number.
    """
    chains = {
        1: (formulate_K_amplitude, "K"),
        2: (formulate_Λ_amplitude, "L"),
        3: (formulate_Δ_amplitude, "D"),
    }
    if subsystem not in chains:
        raise NotImplementedError(f"No chain implemented for sub-system {subsystem}")
    formulate, name_prefix = chains[subsystem]
    return formulate(λ_Λc, λ_p, filter_isobars(decays, name_prefix))


def formulate_K_amplitude(λ_Λc, λ_p, decays: list[ThreeBodyDecay]):
    """Sub-system 1 amplitude: sum of all K** chains (K** → π K, spectator p).

    For each chain, the amplitude is summed over the resonance helicity τ. The
    Kronecker delta enforces λ_Λc = τ - λ_p. The decay coupling is
    H_dec[R, 0, 0] (both decay-product helicities zero). The angle is θ23, and
    the lineshape is evaluated at σ1 with masses (m2, m3). A phase factor
    (-1)^(1/2 - λ_p) is included.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ - λ_p)
                * H_prod[stringify(decay.resonance), τ, λ_p]
                * formulate_dynamics(decay, σ1, m2, m3)
                * (-1) ** (half - λ_p)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, 0, θ23)
                * H_dec[stringify(decay.resonance), 0, 0],
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_Λ_amplitude(λ_Λc, λ_p, decays: list[ThreeBodyDecay]):
    """Sub-system 2 amplitude: sum of all Λ** chains (Λ** → p K, spectator π).

    For each chain, the amplitude is summed over the resonance helicity τ. The
    Kronecker delta enforces λ_Λc = τ. The decay coupling is H_dec[R, 0, λ_p],
    and the Wigner-d uses -λ_p and angle θ31. The lineshape is evaluated at σ2
    with masses (m1, m3). A phase factor (-1)^(1/2 - λ_p) is included.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(decay.resonance), τ, 0]
                * formulate_dynamics(decay, σ2, m1, m3)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, -λ_p, θ31)
                * H_dec[stringify(decay.resonance), 0, λ_p]
                * (-1) ** (half - λ_p),
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_Δ_amplitude(λ_Λc, λ_p, decays: list[ThreeBodyDecay]):
    """Sub-system 3 amplitude: sum of all Δ** chains (Δ** → p π, spectator K).

    For each chain, the amplitude is summed over the resonance helicity τ. The
    Kronecker delta enforces λ_Λc = τ. The decay coupling is H_dec[R, λ_p, 0],
    and the Wigner-d uses λ_p and angle θ12. The lineshape is evaluated at σ3
    with masses (m1, m2). Unlike the K and Λ chains, no extra phase factor is
    applied.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(decay.resonance), τ, 0]
                * formulate_dynamics(decay, σ3, m1, m2)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, λ_p, θ12)
                * H_dec[stringify(decay.resonance), λ_p, 0],
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_dynamics(decay: ThreeBodyDecay, s, m1, m2):
    """Build the lineshape named by ``decay.resonance.lineshape``.

    Args:
        decay: Decay chain whose resonance selects the lineshape.
        s: Invariant mass squared at which the lineshape is evaluated.
        m1, m2: Masses of the resonance's decay products.

    Raises:
        NotImplementedError: For an unknown lineshape name.
    """
    builders = {
        "BreitWignerMinL": formulate_breit_wigner,
        "BuggBreitWignerMinL": formulate_bugg_breit_wigner,
        "Flatte1405": formulate_flatté_1405,
    }
    lineshape = decay.resonance.lineshape
    builder = builders.get(lineshape)
    if builder is None:
        raise NotImplementedError(f'No dynamics implemented for lineshape "{lineshape}"')
    return builder(decay, s, m1, m2)


def formulate_breit_wigner(decay: ThreeBodyDecay, s, m1, m2):
    """Relativistic Breit-Wigner lineshape for ``decay.resonance``.

    The angular momenta of both vertices come from the decay's LS couplings.
    Mass and width symbols are registered in ``parameter_defaults`` (existing
    entries are kept). The global length factor ``R`` is used.
    """
    l_R = sp.Rational(decay.incoming_ls.L)
    l_Λc = sp.Rational(decay.outgoing_ls.L)
    resonance = decay.resonance
    mass = sp.Symbol(f"m_{{{resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{resonance.latex}}}")
    for symbol, default in ((mass, resonance.mass), (width, resonance.width)):
        safe_update_parameters(symbol, default)
    return RelativisticBreitWigner(s, mass, width, m1, m2, l_R, l_Λc, R)


def formulate_bugg_breit_wigner(decay: ThreeBodyDecay, s, m1, m2):
    """Bugg Breit-Wigner lineshape for ``decay.resonance``.

    Registers default values for the mass, the width and the γ parameter
    (default 1) in ``parameter_defaults``, in that order; existing entries are
    kept.
    """
    label = decay.resonance.latex
    gamma = sp.Symbol(Rf"\gamma_{{{label}}}")
    mass = sp.Symbol(f"m_{{{label}}}")
    width = sp.Symbol(Rf"\Gamma_{{{label}}}")
    defaults = (
        (mass, decay.resonance.mass),
        (width, decay.resonance.width),
        (gamma, 1),
    )
    for symbol, default in defaults:
        safe_update_parameters(symbol, default)
    return BuggBreitWigner(s, mass, width, m1, m2, gamma)


def formulate_flatté_1405(decay: ThreeBodyDecay, s, m1, m2):
    """Two-channel Flatté S-wave lineshape (``"Flatte1405"``).

    The first channel consists of the sub-system's decay products ``(m1, m2)``.
    The second channel is (π, Σ). Defaults for the resonance mass and width and
    for the π and Σ masses are registered in ``parameter_defaults``; existing
    entries are kept.
    """
    mass = sp.Symbol(f"m_{{{decay.resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{decay.resonance.latex}}}")
    mπ = sp.Symbol(f"m_{{{π.latex}}}")
    mΣ = sp.Symbol(f"m_{{{Σ.latex}}}")
    safe_update_parameters(mass, decay.resonance.mass)
    safe_update_parameters(width, decay.resonance.width)
    # The second-channel symbols are *masses*, so their defaults must be the
    # particle masses (Σ is defined with an explicit mass for this purpose).
    safe_update_parameters(mπ, π.mass)
    safe_update_parameters(mΣ, Σ.mass)
    return FlattéSWave(s, mass, width, (m1, m2), (mπ, mΣ))


def safe_update_parameters(parameter: sp.Symbol, value) -> None:
    """Register ``value`` as default for ``parameter`` unless one already exists.

    Defaults loaded from the model JSON therefore take precedence over values
    set while formulating the lineshapes.
    """
    parameter_defaults.setdefault(parameter, value)


def stringify(obj) -> Str:
    """Wrap a LaTeX label in a SymPy ``Str`` (for use as an Indexed index).

    Particles and resonances are represented by their ``latex`` attribute;
    anything else by its ``str`` formatting.
    """
    text = obj.latex if isinstance(obj, (Particle, Resonance)) else f"{obj}"
    return Str(text)


def filter_isobars(
    decays: list[ThreeBodyDecay], resonance_pattern: str
) -> list[ThreeBodyDecay]:
    """Return the decays whose resonance name starts with ``resonance_pattern``.

    The input order is preserved. Despite the parameter name, this is a plain
    prefix match, not a regular expression.
    """
    selected = []
    for decay in decays:
        if decay.resonance.name.startswith(resonance_pattern):
            selected.append(decay)
    return selected

### Angle definitions

In [None]:
m0 = sp.Symbol(R"m_{\Lambda_c}", nonnegative=True)
angles = {
    θ12: sp.acos(
        (
            2 * σ3 * (σ2 - m3**2 - m1**2)
            - (σ3 + m1**2 - m2**2) * (m0**2 - σ3 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m3**2, σ3))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
    θ23: sp.acos(
        (
            2 * σ1 * (σ3 - m1**2 - m2**2)
            - (σ1 + m2**2 - m3**2) * (m0**2 - σ1 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ1, m2**2, m3**2))
        )
    ),
    θ31: sp.acos(
        (
            2 * σ2 * (σ1 - m2**2 - m3**2)
            - (σ2 + m3**2 - m1**2) * (m0**2 - σ2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(σ2, m3**2, m1**2))
        )
    ),
    ζ_0_11: sp.S.Zero,  # = \hat\theta^0_{1(1)}
    ζ_0_21: -sp.acos(  # = -\hat\theta^{1(2)}
        (
            (m0**2 + m1**2 - σ1) * (m0**2 + m2**2 - σ2)
            - 2 * m0**2 * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(m0**2, σ1, m1**2))
        )
    ),
    ζ_0_31: sp.acos(  # = \hat\theta^{3(1)}
        (
            (m0**2 + m3**2 - σ3) * (m0**2 + m1**2 - σ1)
            - 2 * m0**2 * (σ2 - m3**2 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(m0**2, σ3, m3**2))
        )
    ),
    ζ_1_11: sp.S.Zero,
    ζ_1_21: sp.acos(
        (
            2 * m1**2 * (σ3 - m0**2 - m3**2)
            + (m0**2 + m1**2 - σ1) * (σ2 - m1**2 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ2, m1**2, m3**2))
        )
    ),
    ζ_1_31: -sp.acos(  # = -\zeta^1_{1(3)}
        (
            2 * m1**2 * (σ2 - m0**2 - m2**2)
            + (m0**2 + m1**2 - σ1) * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
}

display_latex(angles)

In [None]:
# Numerical masses (GeV) of Λc and of the final-state particles p, π, K.
masses = {
    m0: Λc.mass,
    m1: p.mass,
    m2: π.mass,
    m3: K.mass,
}
display_latex(masses)

### Helicity coupling values

In [None]:
def to_symbol_definitions(
    parameter_dict: dict[str, str]
) -> dict[sp.Basic, complex | float]:
    """Convert the model-parameter JSON entries into ``{symbol: value}``.

    ``"Ar<id>"``/``"Ai<id>"`` pairs (real and imaginary parts) are merged into
    one complex value under ``"A<id>"``. All other keys map to a float. Only the
    central value of each ``"<value> ± <error>"`` string is kept. Keys are then
    converted with ``to_symbol``.

    Raises:
        KeyError: If an ``"Ar"`` entry has no matching ``"Ai"`` entry.
    """
    values_by_key: dict[str, complex | float] = {}
    for key, str_value in parameter_dict.items():
        if key.startswith("Ai"):
            continue  # consumed together with its "Ar" counterpart
        if not key.startswith("Ar"):
            values_by_key[key] = to_float(str_value)
            continue
        identifier = key[2:]
        imaginary_text = parameter_dict[f"Ai{identifier}"]
        real_part = to_float(str_value)
        imag_part = to_float(imaginary_text)
        values_by_key[f"A{identifier}"] = complex(real_part, imag_part)
    return {to_symbol(key): value for key, value in values_by_key.items()}


def to_float(str_value: str) -> float:
    """Return the central value of a ``"<value> ± <uncertainty>"`` string.

    Both parts must parse as floats and there must be exactly two of them,
    otherwise ``ValueError`` is raised. The uncertainty is discarded.
    """
    central, _uncertainty = (float(part) for part in str_value.split(" ± "))
    return central


def to_symbol(key: str) -> sp.Indexed | sp.Symbol:
    """Map a model-parameter JSON key to the corresponding SymPy symbol.

    - ``"A<resonance><n>"``: production coupling ``H_prod[resonance, ...]``,
      where the helicity indices for coupling number ``n`` depend on the chain
      (Λ/Δ, the spin-0 K(700)/K(1430), or other K resonances).
    - ``"gamma<resonance>"``, ``"M<resonance>"``, ``"G<resonance>"``: the
      Bugg γ, mass and width symbols.

    Raises:
        NotImplementedError: For keys (or coupling numbers) with no mapping.
    """
    if key.startswith("A"):
        resonance = stringify(key[1:-1])
        coupling_number = int(key[-1])
        name = str(resonance)
        if name.startswith(("L", "D")):
            helicities = {1: (+half, 0), 2: (-half, 0)}
        elif name.startswith("K") and name in {"K(700)", "K(1430)"}:
            helicities = {1: (0, -half), 2: (0, +half)}
        elif name.startswith("K"):
            helicities = {
                1: (0, +half),
                2: (+1, +half),
                3: (-1, -half),
                4: (0, -half),
            }
        else:
            helicities = {}
        if coupling_number in helicities:
            return H_prod[(resonance, *helicities[coupling_number])]
    symbol_prefixes = (
        ("gamma", R"\gamma_"),
        ("M", "m_"),
        ("G", R"\Gamma_"),
    )
    for key_prefix, symbol_prefix in symbol_prefixes:
        if key.startswith(key_prefix):
            resonance = stringify(key[len(key_prefix) :])
            return sp.Symbol(f"{symbol_prefix}{{{resonance}}}")
    raise NotImplementedError(
        f'Cannot convert key "{key}" in model parameter JSON file to SymPy symbol'
    )


# Load the published model parameters. Every value string contains "±", so the
# file is read explicitly as UTF-8 instead of with the platform's default
# encoding (which would garble "±" on e.g. Windows and break `to_float`).
with open("../data/modelparameters.json", encoding="utf-8") as stream:
    data = json.load(stream)
assert len(data["modelstudies"]) == 18

model_number = 0
model_json = data["modelstudies"][model_number]["parameters"]
# https://github.com/redeboer/polarization-sensitivity/blob/34f5330/julia/notebooks/model0.jl#L301-L302
model_json["ArK(892)1"] = "1.0 ± 0.0"
model_json["AiK(892)1"] = "0.0 ± 0.0"

parameter_defaults = to_symbol_definitions(model_json)
parameter_defaults[R] = 5  # GeV^{-1} (length factor)

In [None]:
# Production couplings loaded from the model file (all H_prod entries).
prod_couplings = {
    key: value
    for key, value in parameter_defaults.items()
    if isinstance(key, sp.Indexed) and key.base == H_prod
}
display_latex(prod_couplings)

In [None]:
# Decay couplings: one helicity coupling per resonance is fixed to 1. For Λ and
# Δ chains, the opposite-helicity coupling follows from parity:
# η_R η_1 η_2 (-1)^(j_R - j_1 - j_2).
dec_couplings = {}
for decay in decays:
    i = stringify(decay.resonance)
    if decay.resonance.name.startswith("K"):
        dec_couplings[H_dec[i, 0, 0]] = 1
    if decay.resonance.name.startswith("L"):
        dec_couplings[H_dec[i, 0, half]] = 1
        dec_couplings[H_dec[i, 0, -half]] = (
            int(decay.resonance.parity)
            * int(K.parity)
            * int(p.parity)
            * (-1) ** (decay.resonance.spin - K.spin - p.spin)
        )
    if decay.resonance.name.startswith("D"):
        dec_couplings[H_dec[i, half, 0]] = 1
        dec_couplings[H_dec[i, -half, 0]] = (
            int(decay.resonance.parity)
            * int(p.parity)
            * int(π.parity)
            * (-1) ** (decay.resonance.spin - p.spin - π.spin)
        )
parameter_defaults.update(dec_couplings)
display_latex(dec_couplings)

In [None]:
# Show all remaining (non-coupling) parameter defaults.
couplings = set(dec_couplings) | set(prod_couplings)
display_latex({k: v for k, v in parameter_defaults.items() if k not in couplings})

### Intensity expression

In [None]:
# Intensity: |aligned amplitude|² summed over both the Λc (ν) and the proton
# (λ) helicities.
intensity_expr = PoolSum(
    sp.Abs(formulate_aligned_amplitude(ν, λ)) ** 2,
    (λ, [-half, +half]),
    (ν, [-half, +half]),
)
display(intensity_expr)

In [None]:
# Definitions of the sub-system amplitudes A^K, A^Λ, A^Δ for every combination
# of Λc and proton helicities.
A = {1: A_K, 2: A_Λ, 3: A_Δ}
amp_definitions = {}
for subsystem in range(1, 4):
    # The symbolic sub-system amplitude does not depend on the helicity values,
    # so build it once per sub-system and only substitute inside the loop.
    expr = formulate_subsystem_amplitude(subsystem, ν, λ)
    for Λc_heli, p_heli in itertools.product([-half, +half], [-half, +half]):
        symbol = A[subsystem][Λc_heli, p_heli]
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})
display_latex(amp_definitions)

It takes about **one minute** to run this cell, given the resonances as defined in [`data/isobars.json`](../data/isobars.json).

In [None]:
def assert_all_symbols_defined(expr: sp.Expr) -> None:
    """Assert that ``expr`` only has kinematic symbols left once defaults are inserted.

    After substituting ``parameter_defaults``, the remaining free symbols must
    be angles, masses or the invariant masses σ1, σ2, σ3.
    """
    unresolved = expr.xreplace(parameter_defaults).free_symbols
    kinematic_symbols = {σ1, σ2, σ3}.union(angles, masses)
    assert unresolved <= kinematic_symbols


# Expand the helicity sums, insert the sub-system amplitude definitions and
# expand again (the inserted amplitudes contain sums themselves).
subs_intensity_expr = intensity_expr.doit().xreplace(amp_definitions).doit()
assert_all_symbols_defined(subs_intensity_expr)
print(f"Intensity expression has {sp.count_ops(subs_intensity_expr):,} operations")

## Polarization sensitivity

In [None]:
def to_index(helicity):
    """Symbolic conversion of half-value helicities to Pauli matrix indices."""
    # https://github.com/ComPWA/compwa-org/pull/129#issuecomment-1096599896
    # Helicities <= 0 (i.e. -1/2) map to row/column 1, all others (+1/2) to 0.
    non_positive = sp.LessThan(helicity, 0)
    return sp.Piecewise((1, non_positive), (0, True))


# Polarization components α_i = Σ A*(ν) σ_i[ν, ν'] A(ν') / I for the three
# Pauli matrices σ_1, σ_2, σ_3, summed over λ, ν and ν'.
ν_prime = sp.Symbol(R"\nu^{\prime}")
polarization_expr = [
    PoolSum(
        formulate_aligned_amplitude(ν, λ).conjugate()
        * pauli_matrix[to_index(ν), to_index(ν_prime)]
        * formulate_aligned_amplitude(ν_prime, λ),
        (λ, [-half, +half]),
        (ν, [-half, +half]),
        (ν_prime, [-half, +half]),
    )
    / intensity_expr
    for pauli_matrix in map(msigma, [1, 2, 3])
]

It takes about **three minutes** to run this cell, given the resonances as defined in [`data/isobars.json`](../data/isobars.json).

In [None]:
# Expand each polarization component exactly like the intensity expression.
subs_polarization_expr = []
for xyz, subscript in enumerate("₁₂₃"):
    expr = polarization_expr[xyz].doit().xreplace(amp_definitions).doit()
    assert_all_symbols_defined(expr)
    print(f"Polarization expression α{subscript} has {sp.count_ops(expr):,} operations")
    subs_polarization_expr.append(expr)

## Computations with TensorWaves


### Phase space and helicity angles

In [None]:
# σ3 follows from σ1 + σ2 + σ3 = m0² + m1² + m2² + m3².
σ3_expr = m0**2 + m1**2 + m2**2 + m3**2 - σ1 - σ2
compute_third_mandelstam = create_function(σ3_expr.subs(masses), backend="jax")
display_latex({σ3: σ3_expr})

In [None]:
def kibble_function(σ1, σ2):
    """Kibble function: a Källén function of three Källén functions.

    Note that σ3 is *not* a parameter: the module-level symbol ``σ3`` is used
    and has to be substituted by the caller (``is_within_phsp`` treats a
    non-positive value as inside phase space).
    """
    λ_2 = Källén(σ2, m2**2, m0**2)
    λ_3 = Källén(σ3, m3**2, m0**2)
    λ_1 = Källén(σ1, m1**2, m0**2)
    return Källén(λ_2, λ_3, λ_1)


def is_within_phsp(σ1, σ2, non_phsp_value=sp.nan):
    """Piecewise phase-space indicator.

    Returns 1 where the Kibble function is <= 0 and ``non_phsp_value``
    (default NaN) elsewhere.
    """
    inside = sp.LessThan(kibble_function(σ1, σ2), 0)
    return sp.Piecewise((1, inside), (non_phsp_value, True))


is_within_phsp(σ1, σ2)

In [None]:
# Phase-space indicator as a function of σ1 and σ2 only (σ3 and masses inserted).
in_phsp_expr = is_within_phsp(σ1, σ2).subs(σ3, σ3_expr).subs(masses).doit()
assert in_phsp_expr.free_symbols == {σ1, σ2}

In [None]:
# Regular grid over the (σ1, σ2) Dalitz plane, bounded by the kinematic limits
# of both invariant masses. Grid points outside phase space get NaN in `phsp`.
resolution = 200
m0_val, m1_val, m2_val, m3_val = masses.values()
σ1_min = (m2_val + m3_val) ** 2
σ1_max = (m0_val - m1_val) ** 2
σ2_min = (m1_val + m3_val) ** 2
σ2_max = (m0_val - m2_val) ** 2
x = np.linspace(σ1_min, σ1_max, num=resolution)
y = np.linspace(σ2_min, σ2_max, num=resolution)
X, Y = np.meshgrid(x, y)
Z = compute_third_mandelstam.function(X, Y)
σ_arrays = {"sigma1": X, "sigma2": Y, "sigma3": Z}

in_phsp = create_function(in_phsp_expr, backend="numpy")
phsp = in_phsp(σ_arrays)

In [None]:
# Compute all angles on the grid; σ1, σ2, σ3 are passed through unchanged.
kinematic_variables = {
    symbol: expression.doit().subs(masses) for symbol, expression in angles.items()
}
kinematic_variables.update({s: s for s in [σ1, σ2, σ3]})  # include identity
transformer = SympyDataTransformer.from_sympy(kinematic_variables, backend="jax")
kinematic_arrays = transformer(σ_arrays)

### Definition of free parameters

In [None]:
# Masses, widths and production couplings stay adjustable. Everything else
# (plus the final-state masses) is inserted as a constant.
free_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol.name.startswith("m_")
    or symbol.name.startswith(R"\Gamma_")
    or symbol in prod_couplings
}
fixed_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol not in free_parameters
}
fixed_parameters.update(masses)

### Intensity distribution

In [None]:
# Numerical intensity function with the free parameters kept as parameters.
intensity_func = create_parametrized_function(
    subs_intensity_expr.xreplace(fixed_parameters),
    parameters=free_parameters,
    backend="jax",
)

In [None]:
%config InlineBackend.figure_formats = ['png']

In [None]:
# Dalitz-plot intensity on a log colour scale; points outside phase space are
# masked by multiplying with `phsp` (NaN there).
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

fig, ax = plt.subplots(
    figsize=(10, 8),
    tight_layout=True,
)
ax.set_title("Intensity distribution")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)

mesh = ax.pcolormesh(
    X,
    Y,
    phsp * intensity_func(kinematic_arrays),
    norm=LogNorm(),
)
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
%config InlineBackend.figure_formats = ['svg']

In [None]:
def compute_sub_func(
    func: ParametrizedBackendFunction, input_data, non_zero_couplings: list[str]
):
    r"""Evaluate ``func`` with only a subset of the helicity couplings switched on.

    Every parameter whose name contains ``\mathcal{H}`` followed by a ``[``
    that is *not* directly followed by one of the regex fragments in
    ``non_zero_couplings`` is set to zero for the evaluation. The original
    parameter values are restored afterwards, also when the evaluation raises.

    Args:
        func: Parametrized function whose coupling parameters are toggled.
        input_data: Data passed on to ``func``.
        non_zero_couplings: Regular-expression fragments (e.g. ``["K", "L"]``)
            selecting the couplings that keep their values.

    Returns:
        The output array of ``func`` for the reduced set of couplings.
    """
    old_parameters = dict(func.parameters)
    pattern = rf"\\mathcal{{H}}.*\[(?!{'|'.join(non_zero_couplings)})"
    try:
        set_parameter_to_zero(func, pattern)
        return func(input_data)
    finally:
        # restore even on failure so later cells don't see zeroed couplings
        func.update_parameters(old_parameters)


def set_parameter_to_zero(
    func: ParametrizedBackendFunction, search_term: Pattern
) -> None:
    """Set every parameter of ``func`` whose name matches ``search_term`` to 0.

    Matching uses ``re.match``, i.e. it is anchored at the start of the name.
    All other parameters are passed back unchanged.
    """
    updated = {
        name: 0 if re.match(search_term, name) is not None else value
        for name, value in func.parameters.items()
    }
    func.update_parameters(updated)


def set_ylim_to_zero(ax):
    """Pin the lower y-limit of ``ax`` to zero, keeping its current upper limit."""
    lower_and_upper = ax.get_ylim()
    ax.set_ylim(0, lower_and_upper[1])


# 1D projections onto σ1 and σ2: total intensity (filled) and the contribution
# of each sub-system with all other couplings set to zero.
fig, (ax1, ax2) = plt.subplots(
    ncols=2,
    figsize=(12, 5),
    tight_layout=True,
)
ax1.set_xlabel(s1_label)
ax2.set_xlabel(s2_label)

subsystem_identifiers = ["K", "L", "D"]
subsystem_labels = ["K^{**}", R"\Lambda^{**}", R"\Delta^{**}"]
intensity_array = intensity_func(kinematic_arrays)
ax1.fill(x, np.nansum(intensity_array, axis=0), alpha=0.3)
ax2.fill(y, np.nansum(intensity_array, axis=1), alpha=0.3)

original_parameters = dict(intensity_func.parameters)
for label, identifier in zip(subsystem_labels, subsystem_identifiers):
    label = f"${label}$"
    intensity_array = compute_sub_func(intensity_func, kinematic_arrays, [identifier])
    ax1.plot(x, np.nansum(intensity_array, axis=0), label=label)
    ax2.plot(y, np.nansum(intensity_array, axis=1), label=label)
    intensity_func.update_parameters(original_parameters)
set_ylim_to_zero(ax1)
set_ylim_to_zero(ax2)
ax1.legend()
plt.show()

### Fit fractions

In [None]:
def sub_intensity(data, non_zero_couplings: list[str]):
    """Integrated intensity with only the given couplings switched on.

    The global ``intensity_func`` is evaluated through ``compute_sub_func`` and
    the result is reduced with ``integrate_intensity``.
    """
    return integrate_intensity(
        compute_sub_func(intensity_func, data, non_zero_couplings)
    )


def integrate_intensity(intensities):
    """NaN-ignoring sum of ``intensities`` divided by ``len(intensities)``.

    NaN entries (presumably grid points outside phase space) are skipped. For
    a 2D grid, ``len`` is the number of rows, so this is a scaled sum rather
    than a mean. It is only used in ratios and differences, where the scale
    cancels.
    """
    n_rows = len(intensities)
    total = np.nansum(intensities)
    return total / n_rows


# Total integrated intensity; keeping all three chains must reproduce it.
I_tot = integrate_intensity(intensity_func(kinematic_arrays))
np.testing.assert_allclose(
    I_tot,
    sub_intensity(kinematic_arrays, ["K", "L", "D"]),
)

In [None]:
def interference_intensity(
    data,
    chain1: list[str],
    chain2: list[str],
) -> float:
    """Interference term between two sets of chains: I(1 ∪ 2) - I(1) - I(2)."""
    combined, first, second = (
        sub_intensity(data, chains) for chains in (chain1 + chain2, chain1, chain2)
    )
    return combined - first - second


# Per-chain intensities and pairwise interference terms; together they must
# add up to the total intensity.
I_K = sub_intensity(kinematic_arrays, non_zero_couplings=["K"])
I_Λ = sub_intensity(kinematic_arrays, non_zero_couplings=["L"])
I_Δ = sub_intensity(kinematic_arrays, non_zero_couplings=["D"])
I_ΛΔ = interference_intensity(kinematic_arrays, ["L"], ["D"])
I_KΔ = interference_intensity(kinematic_arrays, ["K"], ["D"])
I_KΛ = interference_intensity(kinematic_arrays, ["K"], ["L"])
np.testing.assert_allclose(I_tot, I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ)

In [None]:
def render_resonance_row(resonance_identifier: str):
    """Render per-resonance fit-fraction rows for one group of chains.

    Args:
        resonance_identifier: Substring (e.g. ``"K"``) that selects resonances
            by their ``name``.

    Returns:
        A list of greyed-out ``(label, fraction)`` LaTeX cell pairs, one per
        selected resonance. When at most one resonance matches, an empty list
        is returned (presumably because it would only repeat the group total).
    """
    rows = []
    for decay in decays:
        resonance = decay.resonance
        if resonance_identifier not in resonance.name:
            continue
        # The label is spliced into the coupling-name regex of
        # `compute_sub_func`, so escape *all* regex metacharacters, not only
        # parentheses.
        pattern = re.escape(resonance.latex)
        I_sub = sub_intensity(kinematic_arrays, [pattern])
        rows.append(
            (
                Rf"\color{{gray}}{{{resonance.latex}}}",
                Rf"\color{{gray}}{{{I_sub/I_tot:.3f}}}",
            )
        )
    if len(rows) > 1:
        return rows
    return []


# Fit-fraction table: the share of each chain (with per-resonance breakdown),
# the interference terms, and their sum, all relative to the total intensity.
rows = [
    R"\hline",
    ("K^{**}", f"{I_K/I_tot:.3f}"),
    *render_resonance_row("K"),
    (R"\Lambda^{**}", f"{I_Λ/I_tot:.3f}"),
    *render_resonance_row("L"),
    (R"\Delta^{**}", f"{I_Δ/I_tot:.3f}"),
    *render_resonance_row("D"),
    (R"\Delta/\Lambda", f"{I_ΛΔ/I_tot:.3f}"),
    (R"K/\Delta", f"{I_KΔ/I_tot:.3f}"),
    (R"K/\Lambda", f"{I_KΛ/I_tot:.3f}"),
    R"\hline",
    (
        R"\mathrm{total}",
        f"{(I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ) /I_tot:.3f}",
    ),
]

latex = R"\begin{array}{crr}" + "\n"
latex += R"& I_\mathrm{sub}\,/\,I \\" + "\n"
for row in rows:
    if row == R"\hline":
        latex += R"\hline"
    else:
        latex += "  " + " & ".join(row) + R" \\" + "\n"
latex += R"\end{array}"
Math(latex)

### Polarization distributions

In [None]:
# One numerical function per polarization component (x, y, z).
polarization_func = [
    create_parametrized_function(
        subs_polarization_expr[xyz].xreplace(fixed_parameters),
        parameters=free_parameters,
        backend="jax",
    )
    for xyz in range(3)
]

In [None]:
def render_mean(array, plus=True):
    r"""Format the NaN-ignoring mean and standard deviation of ``array.real``.

    Returns ``"<mean> \pm <std>"`` with three decimals. If ``plus`` is set and
    the *rounded* mean is positive, the mean gets an explicit ``+`` sign.
    """
    real_part = array.real
    mean_text = f"{np.nanmean(real_part):.3f}"
    std_text = f"{np.nanstd(real_part):.3f}"
    sign = "+" if plus and float(mean_text) > 0 else ""
    return Rf"{sign}{mean_text} \pm {std_text}"


# Mean polarization (± spread over the Dalitz plane) per decay chain. The table
# has five columns: chain label, |α|, α_x, α_y, α_z.
latex = R"\begin{array}{ccccc}" + "\n"
latex += R"& \bar{|\alpha|} & \bar\alpha_x & \bar\alpha_y & \bar\alpha_z \\" + "\n"
for label, identifier in zip(subsystem_labels, subsystem_identifiers):
    latex += f"  {label} & "
    # Dedicated names, so the grid arrays `x`, `y` (and symbol `z`) from earlier
    # cells are not overwritten.
    α_x, α_y, α_z = (
        compute_sub_func(polarization_func[xyz], kinematic_arrays, [identifier])
        for xyz in range(3)
    )
    latex += render_mean(np.sqrt(α_x**2 + α_y**2 + α_z**2), plus=False) + " & "
    latex += " & ".join(map(render_mean, [α_x, α_y, α_z]))
    latex += R" \\" + "\n"
latex += R"\end{array}"
Math(latex)

In [None]:
%config InlineBackend.figure_formats = ['png']
%matplotlib widget

#### Slider definitions

In [None]:
# Slider construction
# One slider per free parameter; complex couplings (\mathcal{H}) get a separate
# real and imaginary slider, distinguished by their description.
sliders = {}
for symbol, value in free_parameters.items():
    if symbol.name.startswith(R"\mathcal{H}"):
        real_slider = create_slider(symbol)
        imag_slider = create_slider(symbol)
        sliders[f"{symbol.name}_real"] = real_slider
        sliders[f"{symbol.name}_imag"] = imag_slider
        real_slider.description = R"\(\mathrm{Re}\)"
        imag_slider.description = R"\(\mathrm{Im}\)"
    else:
        slider = create_slider(symbol)
        sliders[symbol.name] = slider

# Slider ranges
σ3_max = (m0_val - m3_val) ** 2
σ3_min = (m1_val + m2_val) ** 2

for name, slider in sliders.items():
    slider.continuous_update = False
    slider.step = 0.01
    if name.startswith("m_"):
        # Resonance masses span the allowed invariant-mass range of their
        # sub-system, widened by 20% on either side.
        if "K" in name:
            min_, max_ = np.sqrt(σ1_min), np.sqrt(σ1_max)
        elif "L" in name:
            min_, max_ = np.sqrt(σ2_min), np.sqrt(σ2_max)
        elif "D" in name:
            min_, max_ = np.sqrt(σ3_min), np.sqrt(σ3_max)
        else:
            # Masses that belong to none of the sub-systems (presumably the
            # π/Σ masses registered for the Flatté lineshape) have no
            # phase-space window. Without this branch `min_`/`max_` would be
            # stale values from a previous slider (or undefined), so keep the
            # slider's default range instead.
            continue
        diff = max_ - min_
        slider.min = min_ - 0.2 * diff
        slider.max = max_ + 0.2 * diff
    elif name.startswith(R"\Gamma_"):
        slider.min = 0
        slider.max = max(0.7, 2 * slider.value)
    elif name.startswith(R"\mathcal{H}"):
        slider.min = -15
        slider.max = +15


# Slider values
def reset_sliders(click_event):
    """Set every slider back to the default of its free parameter.

    Couplings update both their real and imaginary slider. ``click_event`` is
    ignored; it only exists so the function can serve as a button callback.
    """
    for symbol, value in free_parameters.items():
        name = symbol.name
        if not name.startswith(R"\mathcal{H}"):
            set_slider(sliders[name], value)
            continue
        for suffix in ("_real", "_imag"):
            set_slider(sliders[name + suffix], value)


def set_coupling_sliders_to_zero(filter_pattern):
    """Zero all coupling sliders whose name contains ``filter_pattern``.

    ``filter_pattern`` may be a plain string or the submitting ``Combobox``
    widget, in which case its current value is used.
    """
    if isinstance(filter_pattern, Combobox):
        filter_pattern = filter_pattern.value
    for name, slider in sliders.items():
        is_coupling = name.startswith(R"\mathcal{H}")
        if is_coupling and filter_pattern in name:
            set_slider(slider, 0)


def set_slider(slider, value):
    """Assign the real or imaginary part of ``value`` to ``slider``.

    Sliders whose description is the Im label receive ``complex(value).imag``;
    all others receive ``complex(value).real``. The assignment is skipped when
    the current value already equals the target rounded to the slider's step
    precision, because every widget update is costly.
    """
    shows_imaginary = slider.description == R"\(\mathrm{Im}\)"
    as_complex = complex(value)
    target = as_complex.imag if shows_imaginary else as_complex.real
    precision = -round(np.log10(slider.step))
    if round(target, precision) == slider.value:
        return
    slider.value = target


# Initialise all sliders to the model values, then create the two controls:
# a reset button and a combobox that zeroes couplings matching a pattern.
reset_sliders(click_event=None)

reset_button = Button(description="Reset slider values")
reset_button.on_click(reset_sliders)

all_resonances = [d.resonance for d in decays]
filter_options = [resonance.latex for resonance in all_resonances]
filter_button = Combobox(
    placeholder="Enter coupling filter pattern",
    options=filter_options,
    description=R"$\mathcal{H}=0$",
)
filter_button.on_submit(set_coupling_sliders_to_zero)

# UI design
latex = {symbol.name: sp.latex(symbol) for symbol in free_parameters}
mass_sliders = [sliders[n] for n in sliders if n.startswith("m_")]
width_sliders = [sliders[n] for n in sliders if n.startswith(R"\Gamma_")]
# Per resonance: (real-part sliders, imaginary-part sliders, LaTeX labels).
# Sliders are matched by the resonance's LaTeX string occurring in their name.
coupling_sliders = {}
for decay in decays:
    res = decay.resonance
    real_names = [n for n in sliders if n.endswith("_real") and res.latex in n]
    coupling_sliders[res.name] = (
        [sliders[n] for n in real_names],
        [s for n, s in sliders.items() if n.endswith("_imag") and res.latex in n],
        # n[:-5] strips the "_real" suffix to recover the parameter name
        [HTMLMath(f"${latex[n[:-5]]}$") for n in real_names],
    )
# One row per (mass, width) pair. zip_longest instead of zip so that sliders
# without a partner (unequal numbers of mass and width parameters) are still
# shown instead of being silently dropped from the UI.
mass_width_rows = [
    HBox([s for s in pair if s is not None])
    for pair in itertools.zip_longest(mass_sliders, width_sliders)
]
slider_tabs = Tab(
    children=[
        Tab(
            children=[
                VBox([HBox(row) for row in zip(*triple)])
                for triple in coupling_sliders.values()
            ],
            _titles=dict(enumerate(coupling_sliders)),
        ),
        VBox(mass_width_rows),
    ],
    _titles=dict(enumerate(["Couplings", "Masses and widths"])),
)
ui = VBox([slider_tabs, HBox([reset_button, filter_button])])

#### Interactive plot

In [None]:
nrows = 4
ncols = 5
scale = 2.6
aspect_ratio = 1.15
# Rows: full model (row 0), then one row per entry of subsystem_labels.
# Columns: I_tot, |alpha|, alpha_x, alpha_y, alpha_z; the last column is a bit
# wider, presumably to leave room for the colour bar added in plot3.
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=scale * np.array([ncols, aspect_ratio * nrows]),
    sharex=True,
    sharey=True,
    gridspec_kw={"width_ratios": [1] * (ncols - 1) + [1.24]},
    tight_layout=True,
)
for canvas_attribute in ("toolbar_visible", "header_visible", "footer_visible"):
    setattr(fig.canvas, canvas_attribute, False)

column_titles = [R"I_\mathrm{tot}", R"|\alpha|", *(Rf"\alpha_{c}" for c in "xyz")]
for (subsystem, i), ax in np.ndenumerate(axes):
    title = column_titles[i]
    if subsystem > 0:
        title = Rf"{title}\left({subsystem_labels[subsystem - 1]}\right)"
    ax.set_title(f"${title}$")
    if subsystem == nrows - 1:
        ax.set_xlabel(s1_label)
    if i == 0:
        ax.set_ylabel(s2_label)

color_mesh = np.full([nrows, ncols], None)


def plot3(**kwargs):
    """Redraw every panel of the interactive figure for the given slider values.

    Parameters
    ----------
    **kwargs
        Slider values keyed by slider name, as passed by ``interactive_output``.
        Coupling sliders come in ``<name>_real`` / ``<name>_imag`` pairs and are
        merged into complex values by ``to_complex_kwargs``.

    Row 0 shows the full model; row ``k > 0`` shows the contribution selected
    by ``subsystem_identifiers[k - 1]`` (via ``compute_sub_func``). Columns are
    total intensity, ``|alpha|``, then ``alpha_x``, ``alpha_y``, ``alpha_z``.
    """
    parameters = to_complex_kwargs(**kwargs)
    polarization_funcs = [polarization_func[xyz] for xyz in range(ncols - 2)]
    # Push the slider values into the functions; without this the sliders had
    # no effect and every redraw showed the same data.
    for func in [intensity_func, *polarization_funcs]:
        _update_known_parameters(func, parameters)
    for subsystem in range(nrows):
        # alpha_x, alpha_y, alpha_z in columns 2, 3, 4
        alpha_xyz_arrays = []
        for xyz, func in enumerate(polarization_funcs):
            z_values = np.real(_evaluate_in_subsystem(func, subsystem))
            alpha_xyz_arrays.append(z_values)
            _draw_panel(subsystem, xyz + 2, z_values, (-1, +1), cmap=cm.coolwarm)
        # length of the alpha vector in column 1
        alpha_abs = np.sqrt(np.sum(np.array(alpha_xyz_arrays) ** 2, axis=0))
        _draw_panel(subsystem, 1, alpha_abs, (-1, +1), cmap=cm.coolwarm)
        # total intensity in column 0, on a logarithmic colour scale
        intensity = _evaluate_in_subsystem(intensity_func, subsystem)
        _draw_panel(subsystem, 0, intensity, None, norm=LogNorm())
    fig.canvas.draw()


def _update_known_parameters(func, parameters):
    """Update ``func`` with those entries of ``parameters`` it is defined for.

    Filtering avoids passing names the function does not have; tensorwaves'
    ``update_parameters`` raises on parameters that are not function arguments.
    """
    func.update_parameters(
        {name: value for name, value in parameters.items() if name in func.parameters}
    )


def _evaluate_in_subsystem(func, subsystem):
    """Evaluate ``func`` for the full model (row 0) or one sub-contribution."""
    if subsystem == 0:
        return func(kinematic_arrays)
    identifier = subsystem_identifiers[subsystem - 1]
    return compute_sub_func(func, kinematic_arrays, identifier)


def _draw_panel(subsystem, i, z_values, clim, **mesh_kwargs):
    """Create the colour mesh of panel (subsystem, i) on first use, else refresh it.

    ``mesh_kwargs`` are only used when the mesh is created. A colour bar is
    added to the last column on creation. ``clim`` (vmin, vmax) is re-applied
    on every call unless it is ``None``.
    """
    ax = axes[subsystem, i]
    mesh = color_mesh[subsystem, i]
    if mesh is None:
        mesh = ax.pcolormesh(X, Y, z_values, **mesh_kwargs)
        color_mesh[subsystem, i] = mesh
        if ax is axes[subsystem, -1]:
            fig.colorbar(mesh, ax=ax)
    else:
        mesh.set_array(z_values)
    if clim is not None:
        mesh.set_clim(vmin=clim[0], vmax=clim[1])


def to_complex_kwargs(**kwargs):
    """Merge ``<name>_real`` / ``<name>_imag`` slider values into complex values.

    Each ``<name>_real`` key is combined with its ``<name>_imag`` partner into
    ``{<name>: complex(real, imag)}``. All keys without either suffix are passed
    through unchanged. The suffixes match the slider names created for coupling
    parameters.

    Raises
    ------
    KeyError
        If a ``<name>_real`` key has no matching ``<name>_imag`` key.
    """
    real_suffix, imag_suffix = "_real", "_imag"
    complex_valued_kwargs = {}
    for key, value in kwargs.items():
        if key.endswith(real_suffix):
            symbol_name = key[: -len(real_suffix)]
            imag = kwargs[f"{symbol_name}{imag_suffix}"]
            complex_valued_kwargs[symbol_name] = complex(value, imag)
        elif key.endswith(imag_suffix):
            continue  # consumed together with the matching "_real" entry
        else:
            complex_valued_kwargs[key] = value
    return complex_valued_kwargs


# Re-run plot3 whenever any slider changes; show the controls above the figure.
output = interactive_output(f=plot3, controls=sliders)
display(ui, output)