# Polarization sensitivity in $\Lambda_c \to p \pi K$

_This notebook originates from [ComPWA/compwa-org#129](https://github.com/ComPWA/compwa-org/pull/129) [[TR-018](https://compwa-org--129.org.readthedocs.build/report/018.html)], [`625e50c`](https://github.com/ComPWA/compwa-org/pull/129/commits/94ceb51)._

In [None]:
from __future__ import annotations

import itertools
import json
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import qrules
import sympy as sp
from ampform.sympy import PoolSum, UnevaluatedExpression
from IPython.display import HTML, Math, display
from ipywidgets import Button, Combobox, HBox, HTMLMath, Tab, VBox, interactive_output
from matplotlib import cm
from matplotlib.colors import LogNorm
from qrules.particle import Particle
from symplot import create_slider
from sympy.core.symbol import Str
from sympy.physics.matrices import msigma
from sympy.physics.quantum.spin import Rotation as Wigner
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import create_function, create_parametrized_function

from polarization.decay import IsobarNode, Resonance
from polarization.dynamics import (
    BlattWeisskopf,
    EnergyDependentWidth,
    Källén,
    P,
    Q,
    RelativisticBreitWigner,
)
from polarization.io import as_latex, from_qrules, load_resonance_definitions
from polarization.spin import (
    create_spin_range,
    filter_parity_violating_ls,
    generate_ls_couplings,
)

# Particle database from QRules; used below to look up Λc, p, K and π.
PDG = qrules.load_pdg()


def display_latex(obj) -> None:
    """Render ``obj`` (converted with ``as_latex``) as math in the notebook output."""
    display(Math(as_latex(obj)))


def display_doit(
    expr: UnevaluatedExpression, deep=False, terms_per_line: int | None = None
) -> None:
    """Display ``expr`` together with its evaluated form ``expr.doit(deep=deep)``.

    Without ``terms_per_line`` the pair is rendered through ``as_latex`` as a
    single ``{lhs: rhs}`` mapping. With ``terms_per_line`` the right-hand side
    is broken over several lines in an ``eqnarray`` environment.
    """
    evaluated = expr.doit(deep=deep)
    if terms_per_line is not None:
        rendered = sp.multiline_latex(
            lhs=expr,
            rhs=evaluated,
            terms_per_line=terms_per_line,
            environment="eqnarray",
        )
    else:
        rendered = as_latex({expr: evaluated})
    display(Math(rendered))


# hack for moving Indexed indices below superscript of the base
def _print_Indexed_latex(self, printer, *args):
    base = printer._print(self.base)
    indices = ", ".join(map(printer._print, self.indices))
    return f"{base}_{{{indices}}}"


# Monkey-patch SymPy globally: every Indexed object printed to LaTeX from here
# on goes through _print_Indexed_latex.
sp.Indexed._latex = _print_Indexed_latex

## Amplitude model

### Resonances and $LS$-scheme

Particle definitions for $\Lambda_c^+$ and $p, \pi^+, K^-$:

In [None]:
# Convert the QRules PDG entries of the initial state (Λc) and the three
# final-state particles into the project's particle type.
Λc = from_qrules(PDG["Lambda(c)+"])
p = from_qrules(PDG["p"])
K = from_qrules(PDG["K-"])
π = from_qrules(PDG["pi+"])

Resonance definitions as defined in [`data/isobars.json`](../data/isobars.json):

In [None]:
# Load the resonance definitions; the last expression shows how many there are.
resonances = load_resonance_definitions("../data/isobars.json")
len(resonances)

In [None]:
def create_isobar(resonance: Resonance) -> IsobarNode:
    """Build the chain Λc → spectator + resonance, with resonance → child1 child2.

    The first letter of the resonance name selects the sub-system:
    ``K`` → (π, K) with spectator p, ``L`` → (p, K) with spectator π and
    ``D`` → (p, π) with spectator K. Any other prefix raises
    ``NotImplementedError``. Both vertices use the minimal allowed L.
    """
    final_states = {
        "K": (π, K, p),
        "L": (p, K, π),
        "D": (p, π, K),
    }
    prefix = resonance.name[:1]
    if prefix not in final_states:
        raise NotImplementedError
    child1, child2, sibling = final_states[prefix]
    resonance_decay = IsobarNode(
        parent=resonance,
        child1=child1,
        child2=child2,
        interaction=generate_L_min(resonance, child1, child2),
    )
    return IsobarNode(
        parent=Λc,
        child1=sibling,
        child2=resonance_decay,
        interaction=generate_L_min(Λc, sibling, resonance),
    )


def generate_L_min(parent: Resonance, child1: Resonance, child2: Resonance) -> int:
    """Return the minimum of the LS couplings left after removing parity-violating ones."""
    all_couplings = generate_ls_couplings(parent.spin, child1.spin, child2.spin)
    allowed = filter_parity_violating_ls(
        all_couplings, parent.parity, child1.parity, child2.parity
    )
    return min(allowed)


# Sanity-check the number of loaded resonances, then build one decay chain per
# resonance and render all chains with their spin-parities.
assert len(resonances) == 12
isobars = [create_isobar(res) for res in resonances.values()]
Math(as_latex(isobars, with_jp=True))

In [None]:
def create_html_table_row(*items, typ="td"):
    """Return one HTML ``<tr>`` row (newline-terminated) with each item in a ``<typ>`` cell."""
    cells = "".join(f"<{typ}>{item}</{typ}>" for item in items)
    return f"<tr>{cells}</tr>\n"


# Overview of all decay chains: resonance decay, spin-parity, mass, width and
# the minimal orbital angular momenta l_R (resonance → child1 child2) and
# l_Λc (Λc → spectator + resonance).
column_names = [
    "resonance",
    R"\(j^P\)",
    R"\(m\) (MeV)",
    R"\(\Gamma_0\) (MeV)",
    R"\(l_R\)",
    R"\(l_{\Lambda_c}^\mathrm{min}\)",
]
src = "<table>\n"
src += create_html_table_row(*column_names, typ="th")
for production in isobars:
    decay = production.child2
    assert isinstance(decay, IsobarNode)
    assert production.interaction is not None
    assert decay.interaction is not None
    resonance = decay.parent
    child1 = as_latex(decay.child1)
    child2 = as_latex(decay.child2)
    src += create_html_table_row(
        Rf"\({resonance.latex} \to" Rf" {child1} {child2}\)",
        as_latex(resonance, only_jp=True),
        resonance.mass,
        resonance.width,
        # Follow the column order above: l_R belongs to the resonance decay
        # node, l_Λc to the Λc production node (see create_isobar).
        decay.interaction.L,
        production.interaction.L,
    )
src += "</table>\n"
HTML(src)

### Aligned amplitude

In [None]:
# Indexed symbols for the three sub-system amplitudes A^K, A^Λ and A^Δ, indexed
# by two helicities (see formulate_aligned_amplitude).
A_K = sp.IndexedBase(R"A^K")
A_Λ = sp.IndexedBase(R"A^{\Lambda}")
A_Δ = sp.IndexedBase(R"A^{\Delta}")

half = sp.S.Half

# Rotation angles ζ^0_{k(1)} (Λc side) and ζ^1_{k(1)} (proton side) used in the
# aligned amplitude. Their expressions are given in the "Angle definitions"
# section; the k = 1 angles are set to zero there.
ζ_0_11 = sp.Symbol(R"\zeta^0_{1(1)}", real=True)
ζ_0_21 = sp.Symbol(R"\zeta^0_{2(1)}", real=True)
ζ_0_31 = sp.Symbol(R"\zeta^0_{3(1)}", real=True)
ζ_1_11 = sp.Symbol(R"\zeta^1_{1(1)}", real=True)
ζ_1_21 = sp.Symbol(R"\zeta^1_{2(1)}", real=True)
ζ_1_31 = sp.Symbol(R"\zeta^1_{3(1)}", real=True)


def formulate_aligned_amplitude(λ_Λc, λ_p):
    """Aligned amplitude for Λc helicity ``λ_Λc`` and proton helicity ``λ_p``.

    Sums the three sub-system amplitudes A^K, A^Λ and A^Δ. Each is multiplied by
    spin-1/2 Wigner small-d functions of the angles ζ^0_{k(1)} (Λc side) and
    ζ^1_{k(1)} (proton side). The result is summed over the intermediate
    helicities ν' and λ' ∈ {-1/2, +1/2}.
    """
    _ν = sp.Symbol(R"\nu^{\prime}", rational=True)
    _λ = sp.Symbol(R"\lambda^{\prime}", rational=True)
    return PoolSum(
        A_K[_ν, _λ] * Wigner.d(half, λ_Λc, _ν, ζ_0_11) * Wigner.d(half, _λ, λ_p, ζ_1_11)
        + A_Λ[_ν, _λ]
        * Wigner.d(half, λ_Λc, _ν, ζ_0_21)
        * Wigner.d(half, _λ, λ_p, ζ_1_21)
        + A_Δ[_ν, _λ]
        * Wigner.d(half, λ_Λc, _ν, ζ_0_31)
        * Wigner.d(half, _λ, λ_p, ζ_1_31),
        (_λ, [-half, +half]),
        (_ν, [-half, +half]),
    )


# Generic helicity symbols: ν for the Λc and λ for the proton (see the call below).
ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)

### Dynamics

In [None]:
# Show the BlattWeisskopf expression for a generic argument z and angular momentum L.
z = sp.Symbol("z", positive=True)
L = sp.Symbol("L", integer=True, nonnegative=True)
display_doit(BlattWeisskopf(z, L))

In [None]:
# Show the Källén function for generic arguments.
x, y, z = sp.symbols("x:z")
display_doit(Källén(x, y, z))

In [None]:
# Show the P and Q helper expressions for generic masses (presumably the
# break-up momenta of the resonance and Λc decays — TODO confirm).
s, m0, mi, mj, mk = sp.symbols("s m0 m_i:k", nonnegative=True)
display_doit(P(s, mi, mj))
display_doit(Q(s, m0, mk))

In [None]:
# parameter_defaults collects the default numerical values of model parameters.
# It starts with R and is extended later by formulate_dynamics and the
# coupling cells.
R = sp.Symbol("R")
parameter_defaults = {
    R: 5,  # GeV^{-1} (length factor)
}
l_R = sp.Symbol("l_R", integer=True, positive=True)
m, Γ0, m1, m2 = sp.symbols("m Γ0 m1 m2", nonnegative=True)
display_doit(EnergyDependentWidth(s, m, Γ0, m1, m2, l_R, R))

In [None]:
# Show the relativistic Breit–Wigner with resonance-decay (l_R) and Λc-decay
# (l_Λc) orbital angular momenta.
l_Λc = sp.Symbol(R"l_{\Lambda_c}", integer=True, positive=True)
display_doit(RelativisticBreitWigner(s, m, Γ0, m1, m2, l_R, l_Λc, R))

### Decay chain amplitudes

In [None]:
# Helicity couplings of the production vertex (Λc → resonance + spectator) and
# the decay vertex (resonance → two final-state particles). Each is indexed by
# the resonance label and two helicities.
H_prod = sp.IndexedBase(R"\mathcal{H}^\mathrm{production}")
H_dec = sp.IndexedBase(R"\mathcal{H}^\mathrm{decay}")

# Decay angles entering the Wigner d-functions of the three sub-systems.
θ23 = sp.Symbol("theta23", real=True)
θ31 = sp.Symbol("theta31", real=True)
θ12 = sp.Symbol("theta12", real=True)

# Invariant masses squared of the sub-systems and the final-state masses (p, π, K).
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
m1, m2, m3 = sp.symbols(R"m_p m_pi m_K", nonnegative=True)


def formulate_K_amplitude(λ_Λc, λ_p, isobars: list[IsobarNode]):
    """Sum of the K-resonance chain amplitudes (sub-system 1) for fixed helicities.

    For every isobar, sums over the resonance helicity τ in
    ``create_spin_range(resonance spin)`` the product of:
    the production coupling H_prod[R, τ, λ_p], the Breit–Wigner in σ1, the
    phase (-1)^(1/2 - λ_p), the Wigner d^j_{τ,0}(θ23) and the decay coupling
    H_dec[R, 0, 0]. The Kronecker delta enforces λ_Λc = τ - λ_p.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ - λ_p)
                * H_prod[stringify(iso.child2), τ, λ_p]
                * formulate_dynamics(iso, σ1, m2, m3)
                * (-1) ** (half - λ_p)
                * Wigner.d(sp.Rational(iso.child2.parent.spin), τ, 0, θ23)
                * H_dec[stringify(iso.child2), 0, 0],
                (τ, create_spin_range(iso.child2.parent.spin)),
            )
            for iso in isobars
        ]
    )


def formulate_Λ_amplitude(λ_Λc, λ_p, isobars: list[IsobarNode]):
    """Sum of the Λ-resonance chain amplitudes (sub-system 2) for fixed helicities.

    For every isobar, sums over the resonance helicity τ the product of:
    the production coupling H_prod[R, τ, 0], the Breit–Wigner in σ2, the
    Wigner d^j_{τ,-λ_p}(θ31), the decay coupling H_dec[R, 0, λ_p] and the
    phase (-1)^(1/2 - λ_p). The Kronecker delta enforces λ_Λc = τ.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(iso.child2), τ, 0]
                * formulate_dynamics(iso, σ2, m1, m3)
                * Wigner.d(sp.Rational(iso.child2.parent.spin), τ, -λ_p, θ31)
                * H_dec[stringify(iso.child2), 0, λ_p]
                * (-1) ** (half - λ_p),
                (τ, create_spin_range(iso.child2.parent.spin)),
            )
            for iso in isobars
        ]
    )


def formulate_Δ_amplitude(λ_Λc, λ_p, isobars: list[IsobarNode]):
    """Sum of the Δ-resonance chain amplitudes (sub-system 3) for fixed helicities.

    For every isobar, sums over the resonance helicity τ the product of:
    the production coupling H_prod[R, τ, 0], the Breit–Wigner in σ3, the
    Wigner d^j_{τ,λ_p}(θ12) and the decay coupling H_dec[R, λ_p, 0]. Unlike
    the K and Λ chains there is no extra (-1)^(1/2 - λ_p) phase. The
    Kronecker delta enforces λ_Λc = τ.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(iso.child2), τ, 0]
                * formulate_dynamics(iso, σ3, m1, m2)
                * Wigner.d(sp.Rational(iso.child2.parent.spin), τ, λ_p, θ12)
                * H_dec[stringify(iso.child2), λ_p, 0],
                (τ, create_spin_range(iso.child2.parent.spin)),
            )
            for iso in isobars
        ]
    )


def formulate_dynamics(isobar: IsobarNode, s, m1, m2):
    """Return the relativistic Breit–Wigner lineshape of the resonance in ``isobar``.

    ``isobar`` is the production node Λc → spectator + resonance. Its ``child2``
    is the resonance decay node. ``s`` is the invariant mass squared of the
    resonance sub-system, and ``m1`` and ``m2`` are the masses of the
    resonance's decay products.

    Side effect: registers the default mass and width of the resonance in the
    global ``parameter_defaults`` mapping.
    """
    production, decay = isobar, isobar.child2
    resonance = decay.parent
    assert production.interaction is not None
    assert decay.interaction is not None
    assert isinstance(resonance, Resonance)
    # create_isobar attaches the resonance-decay L (l_R) to the inner decay node
    # and the Λc-decay L (l_Λc) to the outer production node.
    l_R = sp.Rational(decay.interaction.L)
    l_Λc = sp.Rational(production.interaction.L)
    mass = sp.Symbol(f"m_{{{resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{resonance.latex}}}")
    parameter_defaults[mass] = resonance.mass
    parameter_defaults[width] = resonance.width
    return RelativisticBreitWigner(s, mass, width, m1, m2, l_R, l_Λc, R)


def stringify(obj) -> Str:
    """Return a SymPy ``Str`` label for ``obj``.

    Isobar nodes are labelled by their parent's LaTeX, particles and
    resonances by their own LaTeX, and anything else by its formatted string.
    """
    if isinstance(obj, IsobarNode):
        label = obj.parent.latex
    elif isinstance(obj, (Particle, Resonance)):
        label = obj.latex
    else:
        label = f"{obj}"
    return Str(label)


def filter_isobars(
    isobars: list[IsobarNode], resonance_pattern: str
) -> list[IsobarNode]:
    """Return the isobars whose resonance name starts with ``resonance_pattern``, in order."""
    selected = []
    for node in isobars:
        resonance_name = node.child2.parent.name
        if resonance_name.startswith(resonance_pattern):
            selected.append(node)
    return selected


def formulate_subsystem_amplitude(subsystem: int, λ_Λc, λ_p):
    """Amplitude of one sub-system for the given Λc and proton helicities.

    Sub-system 1 uses the K resonances, 2 the Λ resonances ("L") and 3 the Δ
    resonances ("D") of the global ``isobars`` list.
    """
    builders = {
        1: (formulate_K_amplitude, "K"),
        2: (formulate_Λ_amplitude, "L"),
        3: (formulate_Δ_amplitude, "D"),
    }
    if subsystem not in builders:
        raise NotImplementedError(f"No chain implemented for sub-system {subsystem}")
    builder, prefix = builders[subsystem]
    return builder(λ_Λc, λ_p, filter_isobars(isobars, prefix))

### Angle definitions

In [None]:
# Symbolic expressions for the decay angles θ12, θ23, θ31 and the rotation
# angles ζ in terms of the invariants σ1, σ2, σ3 and the masses. Note that m0
# is redefined here as the Λc mass symbol.
m0 = sp.Symbol(R"m_{\Lambda_c}", nonnegative=True)
angles = {
    θ12: sp.acos(
        (
            2 * σ3 * (σ2 - m3**2 - m1**2)
            - (σ3 + m1**2 - m2**2) * (m0**2 - σ3 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m3**2, σ3))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
    θ23: sp.acos(
        (
            2 * σ1 * (σ3 - m1**2 - m2**2)
            - (σ1 + m2**2 - m3**2) * (m0**2 - σ1 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ1, m2**2, m3**2))
        )
    ),
    θ31: sp.acos(
        (
            2 * σ2 * (σ1 - m2**2 - m3**2)
            - (σ2 + m3**2 - m1**2) * (m0**2 - σ2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(σ2, m3**2, m1**2))
        )
    ),
    # Rotation angles: the k = 1 angles vanish (chain 1 is the reference), and
    # several of the others enter with an explicit minus sign.
    ζ_0_11: sp.S.Zero,  # = \hat\theta^0_{1(1)}
    ζ_0_21: -sp.acos(  # = -\hat\theta^{1(2)}
        (
            (m0**2 + m1**2 - σ1) * (m0**2 + m2**2 - σ2)
            - 2 * m0**2 * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(m0**2, σ1, m1**2))
        )
    ),
    ζ_0_31: sp.acos(  # = \hat\theta^{3(1)}
        (
            (m0**2 + m3**2 - σ3) * (m0**2 + m1**2 - σ1)
            - 2 * m0**2 * (σ2 - m3**2 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(m0**2, σ3, m3**2))
        )
    ),
    ζ_1_11: sp.S.Zero,
    ζ_1_21: sp.acos(
        (
            2 * m1**2 * (σ3 - m0**2 - m3**2)
            + (m0**2 + m1**2 - σ1) * (σ2 - m1**2 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ2, m1**2, m3**2))
        )
    ),
    ζ_1_31: -sp.acos(  # = -\zeta^1_{1(3)}
        (
            2 * m1**2 * (σ2 - m0**2 - m2**2)
            + (m0**2 + m1**2 - σ1) * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
}

display_latex(angles)

In [None]:
# Numerical masses taken from the particle definitions above.
masses = {
    m0: Λc.mass,
    m1: p.mass,
    m2: π.mass,
    m3: K.mass,
}
display_latex(masses)

### Helicity coupling values

In [None]:
def to_symbol_definitions(
    parameter_dict: dict[str, str]
) -> dict[sp.Basic, complex | float]:
    """Convert the raw JSON parameter mapping into complex coupling values per symbol.

    Every ``Ar<id>`` entry is paired with its ``Ai<id>`` entry and stored as
    ``complex(real, imag)`` under ``A<id>``. That key is then mapped to a SymPy
    symbol with ``to_symbol``. Keys not starting with ``Ar`` are only read
    through their ``Ar`` partner.
    """
    combined: dict[str, complex | float] = {}
    for name, real_text in parameter_dict.items():
        if not name.startswith("Ar"):
            continue
        suffix = name[2:]
        imag_text = parameter_dict[f"Ai{suffix}"]
        real_part = to_float(real_text)
        imag_part = to_float(imag_text)
        combined[f"A{suffix}"] = complex(real_part, imag_part)
    return {to_symbol(name): value for name, value in combined.items()}


def to_float(str_value: str) -> float:
    """Return the central value of a ``"value ± uncertainty"`` string.

    The uncertainty is discarded. Plain numbers without an uncertainty
    (e.g. ``"0.5"``) and compact forms such as ``"0.5±0.1"`` are accepted too.
    """
    central, _, _ = str_value.partition("±")
    return float(central)  # float() strips surrounding whitespace


def to_symbol(key: str) -> sp.Indexed | sp.Symbol:
    """Map a coupling key from the model-parameter JSON file to an ``H_prod`` symbol.

    Keys are expected to look like ``A<resonance><i>``. The resonance label is
    ``key[1:-1]`` and the coupling number ``i`` is the last character only, so
    only single-digit coupling numbers are supported. ``i`` selects the two
    helicity indices of ``H_prod[resonance, ·, ·]``. The index layout matches
    the ``formulate_*_amplitude`` functions: second index is the resonance
    helicity τ, third index the spectator helicity.

    Raises:
        NotImplementedError: for keys not starting with ``A`` or for
            unsupported resonance/coupling-number combinations.
    """
    # NOTE(review): the label is built from the JSON key, while the amplitudes
    # use stringify(resonance) (its LaTeX); these presumably coincide — confirm.
    if key.startswith("A"):
        res = stringify(key[1:-1])
        i = int(key[-1])
        # Λ resonances (spectator π → third index 0): A1 → τ=+1/2, A2 → τ=-1/2
        if str(res).startswith("L"):
            if i == 1:
                return H_prod[res, +half, 0]
            if i == 2:
                return H_prod[res, -half, 0]
        # Δ resonances (spectator K → third index 0): same convention as Λ
        if str(res).startswith("D"):
            if i == 1:
                return H_prod[res, +half, 0]
            if i == 2:
                return H_prod[res, -half, 0]
        # K resonances (spectator p): K(700) and K(1430) only get τ = 0
        # couplings; the other K resonances get four (τ, λ_p) combinations.
        if str(res).startswith("K"):
            if str(res) in {"K(700)", "K(1430)"}:
                if i == 1:
                    return H_prod[res, 0, -half]
                if i == 2:
                    return H_prod[res, 0, +half]
            else:
                if i == 1:
                    return H_prod[res, 0, +half]
                if i == 2:
                    return H_prod[res, +1, +half]
                if i == 3:
                    return H_prod[res, -1, -half]
                if i == 4:
                    return H_prod[res, 0, -half]
    raise NotImplementedError(
        f'Cannot convert key "{key}" in model parameter JSON file to SymPy symbol'
    )


# Load the fit results of the model studies. Read explicitly as UTF-8: the
# values contain "±" (see to_float), which a non-UTF-8 platform default
# encoding (e.g. on Windows) would mis-decode.
with open("../data/modelparameters.json", encoding="utf-8") as stream:
    data = json.load(stream)
assert len(data["modelstudies"]) == 18
# Index of the model study whose parameters are used throughout this notebook.
model_number = 0
imported_parameters = to_symbol_definitions(
    data["modelstudies"][model_number]["parameters"]
)

In [None]:
# Keep the production couplings among the imported parameters, display them
# and register them both as couplings and as default parameter values.
prod_couplings = {
    key: value for key, value in imported_parameters.items() if key.base == H_prod
}
display_latex(prod_couplings)
couplings = dict(prod_couplings)
parameter_defaults.update(prod_couplings)

In [None]:
# Decay helicity couplings are set to fixed values here. For each resonance one
# coupling is 1. For Λ and Δ resonances, the coupling with opposite proton
# helicity follows from the parity relation
# η_R · η_1 · η_2 · (-1)^(j_R - j_1 - j_2).
dec_couplings = {}
for isobar in isobars:
    res = isobar.child2.parent
    assert isinstance(res, Resonance)
    i = stringify(res)
    if res.name.startswith("K"):
        dec_couplings[H_dec[i, 0, 0]] = 1
    if res.name.startswith("L"):
        dec_couplings[H_dec[i, 0, half]] = 1
        dec_couplings[H_dec[i, 0, -half]] = (
            int(res.parity)
            * int(K.parity)
            * int(p.parity)
            * (-1) ** (res.spin - K.spin - p.spin)
        )
    if res.name.startswith("D"):
        dec_couplings[H_dec[i, half, 0]] = 1
        dec_couplings[H_dec[i, -half, 0]] = (
            int(res.parity)
            * int(p.parity)
            * int(π.parity)
            * (-1) ** (res.spin - p.spin - π.spin)
        )
parameter_defaults.update(dec_couplings)
couplings.update(dec_couplings)
display_latex(dec_couplings)

### Intensity expression

In [None]:
def formulate_intensity(amplitude_builder):
    """Sum of |amplitude(ν, λ)|² over the Λc helicity ν and the proton helicity λ."""
    squared_amplitude = sp.Abs(amplitude_builder(ν, λ)) ** 2
    proton_range = (λ, [-half, +half])
    lambda_c_range = (ν, [-half, +half])
    return PoolSum(squared_amplitude, proton_range, lambda_c_range)


# Intensity expressions: 0 is the full aligned model; 1, 2, 3 are the K, Λ and
# Δ sub-systems on their own.
intensity_expressions = {
    0: formulate_intensity(formulate_aligned_amplitude),
    1: formulate_intensity(partial(formulate_subsystem_amplitude, 1)),
    2: formulate_intensity(partial(formulate_subsystem_amplitude, 2)),
    3: formulate_intensity(partial(formulate_subsystem_amplitude, 3)),
}
intensity_expressions[0]

In [None]:
# Express each symbol A^K, A^Λ, A^Δ (for all four helicity combinations) through
# the corresponding sub-system amplitude.
A = {1: A_K, 2: A_Λ, 3: A_Δ}
amp_definitions = {}
for subsystem in range(1, 4):
    for Λc_heli, p_heli in itertools.product([-half, +half], [-half, +half]):
        symbol = A[subsystem][Λc_heli, p_heli]
        expr = formulate_subsystem_amplitude(subsystem, ν, λ)
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})
display_latex(amp_definitions)

In [None]:
# Substitute amplitude definitions, angles and masses. Check that, after
# inserting the parameter defaults, only σ1, σ2, σ3 remain free (a strict
# subset of them for the single sub-systems).
substituted_intensity_expressions = {}
for subsystem, expr in intensity_expressions.items():
    expr = expr.doit().xreplace(amp_definitions).doit()
    expr = expr.xreplace(angles).doit().xreplace(masses)
    substituted_intensity_expressions[subsystem] = expr
    expr = expr.xreplace(parameter_defaults)
    display(sp.Array(expr.free_symbols))
    if subsystem == 0:
        assert expr.free_symbols == {σ1, σ2, σ3}
    else:
        assert expr.free_symbols < {σ1, σ2, σ3}

## Polarization sensitivity

### Total polarization sensitivity

In [None]:
def to_index(helicity):
    """Symbolic conversion of half-value helicities to Pauli matrix indices.

    Helicities ≤ 0 (i.e. -1/2) map to index 1; all others map to index 0.
    """
    # https://github.com/ComPWA/compwa-org/pull/129#issuecomment-1096599896
    is_non_positive = sp.LessThan(helicity, 0)
    return sp.Piecewise((1, is_non_positive), (0, True))


# Polarization sensitivity α_i for i = x, y, z (Pauli matrices 1, 2, 3): the
# sum over λ, ν, ν' of A*(ν, λ) σ_i[ν, ν'] A(ν', λ), divided by the intensity.
ν_prime = sp.Symbol(R"\nu^{\prime}")
total_polarization = sp.Array(
    PoolSum(
        formulate_aligned_amplitude(ν, λ).conjugate()
        * msigma(i)[to_index(ν), to_index(ν_prime)]
        * formulate_aligned_amplitude(ν_prime, λ),
        (λ, [-half, +half]),
        (ν, [-half, +half]),
        (ν_prime, [-half, +half]),
    )
    / intensity_expressions[0]
    for i in [1, 2, 3]
)

### Polarization sensitivity per chain

In [None]:
# Same polarization sensitivity, but for each sub-system amplitude on its own.
polarization_expressions = {
    0: total_polarization,
}
for subsystem in range(1, 4):
    polarization_expressions[subsystem] = sp.Array(
        PoolSum(
            formulate_subsystem_amplitude(subsystem, ν, λ).conjugate()
            * msigma(i)[to_index(ν), to_index(ν_prime)]
            * formulate_subsystem_amplitude(subsystem, ν_prime, λ),
            (λ, [-half, +half]),
            (ν, [-half, +half]),
            (ν_prime, [-half, +half]),
        )
        / intensity_expressions[subsystem]
        for i in [1, 2, 3]
    )

In [None]:
# Substitute amplitude definitions, angles and masses, with the same
# free-symbol checks as for the intensities.
substituted_polarization_expressions = {}
for subsystem, expr in polarization_expressions.items():
    expr = expr.doit().xreplace(amp_definitions).doit()
    expr = expr.xreplace(angles).doit().xreplace(masses)
    substituted_polarization_expressions[subsystem] = expr
    expr = expr.xreplace(parameter_defaults)
    if subsystem == 0:
        assert expr.free_symbols == {σ1, σ2, σ3}
    else:
        assert expr.free_symbols < {σ1, σ2, σ3}

## Computations with TensorWaves


### Conversion to computational backend

In [None]:
# Masses, widths and production couplings stay adjustable in the compiled
# functions; all other defaults (R and decay couplings) are substituted.
free_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol.name.startswith("m_")
    or symbol.name.startswith(R"\Gamma_")
    or symbol in prod_couplings
}
fixed_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol not in free_parameters
}

In [None]:
# Compile each intensity expression into a JAX function with the free
# parameters kept adjustable.
intensity_functions = {
    subsystem: create_parametrized_function(
        expr.subs(fixed_parameters),
        parameters=free_parameters,
        backend="jax",
    )
    for subsystem, expr in substituted_intensity_expressions.items()
}

In [None]:
# One compiled function per polarization component (x, y, z) and sub-system.
polarization_functions = {
    subsystem: [
        create_parametrized_function(
            expr[i].subs(fixed_parameters),
            parameters=free_parameters,
            backend="jax",
        )
        for i in range(3)
    ]
    for subsystem, expr in substituted_polarization_expressions.items()
}

### Phase space

In [None]:
# σ3 follows from σ1 and σ2 through σ1 + σ2 + σ3 = m0² + m1² + m2² + m3².
computed_σ3 = m0**2 + m1**2 + m2**2 + m3**2 - σ1 - σ2
compute_third_mandelstam = create_function(computed_σ3.subs(masses), backend="jax")
display_latex({σ3: computed_σ3})

In [None]:
def kibble_function(σ1, σ2):
    """Kibble function built from nested Källén functions.

    ``is_within_phsp`` treats non-positive values as inside the phase space.
    σ3 is the global symbol; callers substitute it afterwards.
    """
    källén_2 = Källén(σ2, m2**2, m0**2)
    källén_3 = Källén(σ3, m3**2, m0**2)
    källén_1 = Källén(σ1, m1**2, m0**2)
    return Källén(källén_2, källén_3, källén_1)


def is_within_phsp(σ1, σ2, non_phsp_value=sp.nan):
    """Return 1 where the Kibble function is ≤ 0 and ``non_phsp_value`` elsewhere."""
    inside = sp.LessThan(kibble_function(σ1, σ2), 0)
    return sp.Piecewise((1, inside), (non_phsp_value, True))


# Display the symbolic phase-space indicator.
is_within_phsp(σ1, σ2)

In [None]:
# Express the indicator in σ1 and σ2 only, with numerical masses.
in_phsp_expr = is_within_phsp(σ1, σ2).subs(σ3, computed_σ3).subs(masses).doit()
in_phsp_expr.free_symbols

In [None]:
# Regular σ1 × σ2 grid spanning the kinematic limits. σ3 is computed from σ1
# and σ2, and `phsp` is 1 inside the phase space and NaN outside.
resolution = 200
m0_val, m1_val, m2_val, m3_val = masses.values()
σ1_min = (m2_val + m3_val) ** 2
σ1_max = (m0_val - m1_val) ** 2
σ2_min = (m1_val + m3_val) ** 2
σ2_max = (m0_val - m2_val) ** 2
x = np.linspace(σ1_min, σ1_max, num=resolution)
y = np.linspace(σ2_min, σ2_max, num=resolution)
X, Y = np.meshgrid(x, y)
Z = compute_third_mandelstam.function(X, Y)
σ_arrays = {"sigma1": X, "sigma2": Y, "sigma3": Z}

in_phsp = create_function(in_phsp_expr, backend="numpy")
phsp = in_phsp(σ_arrays)

In [None]:
# Compute all angles (plus the σ's themselves) on the grid; these arrays are the
# input data for the compiled intensity and polarization functions.
kinematic_variables = {
    symbol: expression.doit().subs(masses).subs(fixed_parameters)
    for symbol, expression in angles.items()
}
kinematic_variables.update({s: s for s in [σ1, σ2, σ3]})  # include identity
transformer = SympyDataTransformer.from_sympy(kinematic_variables, backend="jax")
kinematic_arrays = transformer(σ_arrays)

### Intensity distribution

In [None]:
# Evaluate every intensity function (total and per sub-system) on the grid.
intensities = {
    subsystem: func(kinematic_arrays) for subsystem, func in intensity_functions.items()
}

In [None]:
# Use PNG output for the heat map in the next cell (switched back to SVG afterwards).
%config InlineBackend.figure_formats = ['png']

In [None]:
# Dalitz plot of the total intensity on a log color scale; multiplying by
# `phsp` blanks out points outside the phase space (NaN).
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

fig, ax = plt.subplots(
    figsize=(10, 8),
    tight_layout=True,
)
ax.set_title("Intensity distribution")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)

mesh = ax.pcolormesh(X, Y, phsp * intensities[0], norm=LogNorm())
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
# Back to vector (SVG) output for the line plots below.
%config InlineBackend.figure_formats = ['svg']

In [None]:
def set_ylim_to_zero(ax):
    """Pin the lower y-limit of ``ax`` at zero and keep its current upper limit."""
    upper = ax.get_ylim()[1]
    ax.set_ylim(0, upper)


# Projections of the total intensity (filled) and of each sub-system (lines)
# onto σ1 (summing over the σ2 axis, axis 0) and onto σ2 (summing over axis 1).
fig, (ax1, ax2) = plt.subplots(
    ncols=2,
    figsize=(12, 5),
    tight_layout=True,
)
ax1.set_xlabel(s1_label)
ax2.set_xlabel(s2_label)

subsystem_names = {
    1: "K^{**}",
    2: R"\Lambda^{**}",
    3: R"\Delta^{**}",
}
ax1.fill(x, np.nansum(intensities[0], axis=0), alpha=0.3)
ax2.fill(y, np.nansum(intensities[0], axis=1), alpha=0.3)
for subsystem in range(1, 4):
    label = subsystem_names[subsystem]
    label = f"${label}$"
    ax1.plot(x, np.nansum(intensities[subsystem], axis=0), label=label)
    ax2.plot(y, np.nansum(intensities[subsystem], axis=1), label=label)
set_ylim_to_zero(ax1)
set_ylim_to_zero(ax2)
ax1.legend()
plt.show()

### Fit fractions

In [None]:
def sub_intensity(data, non_zero_couplings: list[str]):
    r"""Integrated total intensity with only a subset of couplings switched on.

    Every parameter of the total-intensity function whose name starts with
    ``\mathcal{H}`` is set to zero unless its name contains one of the
    substrings in ``non_zero_couplings``. The function is then evaluated on
    ``data`` and integrated with ``integrate_intensity``.

    The shared function's original parameters are always restored, even when
    the evaluation raises, so later cells are not left with zeroed couplings.
    """
    func = intensity_functions[0]
    old_parameters = dict(func.parameters)
    new_parameters = dict(old_parameters)
    for par_name in new_parameters:
        if not par_name.startswith(R"\mathcal{H}"):
            continue
        if any(pattern in par_name for pattern in non_zero_couplings):
            continue
        new_parameters[par_name] = 0
    func.update_parameters(new_parameters)
    try:
        intensities = func(data)
    finally:
        func.update_parameters(old_parameters)
    return integrate_intensity(intensities)


def integrate_intensity(intensities):
    """Average ``intensities`` over all grid points, treating NaN entries as zero.

    NaN entries (presumably points outside the phase space) contribute nothing
    to the sum. The normalization is the total number of points, not only the
    first axis as ``len`` would give for a 2-D grid. Every sub-intensity on the
    same grid therefore shares one constant normalization, and fit fractions
    and interference terms stay additive.
    """
    values = np.asarray(intensities)
    return np.nansum(values) / values.size


# Total integrated intensity. Sanity check: keeping all K, Λ and Δ couplings
# must reproduce it.
I_tot = integrate_intensity(intensity_functions[0](kinematic_arrays))
np.testing.assert_allclose(
    I_tot,
    sub_intensity(kinematic_arrays, ["K", "L", "D"]),
)

In [None]:
def interference_intensity(
    data,
    chain1: list[str],
    chain2: list[str],
):
    """Interference term between two groups of decay chains.

    Computed as I(chain1 + chain2) − I(chain1) − I(chain2), with every term
    obtained from ``sub_intensity``.
    """
    combined = sub_intensity(data, chain1 + chain2)
    only_first = sub_intensity(data, chain1)
    only_second = sub_intensity(data, chain2)
    return combined - only_first - only_second


# Sub-system intensities and pairwise interference terms; together they must
# add up to the total intensity.
I_K = sub_intensity(kinematic_arrays, non_zero_couplings=["K"])
I_Λ = sub_intensity(kinematic_arrays, non_zero_couplings=["L"])
I_Δ = sub_intensity(kinematic_arrays, non_zero_couplings=["D"])
I_ΛΔ = interference_intensity(kinematic_arrays, ["L"], ["D"])
I_KΔ = interference_intensity(kinematic_arrays, ["K"], ["D"])
I_KΛ = interference_intensity(kinematic_arrays, ["K"], ["L"])
np.testing.assert_allclose(I_tot, I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ)

In [None]:
def render_resonance_row(resonance_identifier: str):
    """Gray (label, fit fraction) rows for the individual resonances of one sub-system.

    Resonances are selected by name prefix. Rows are only returned when more
    than one resonance matches; otherwise the result is an empty list.
    """
    selected: list[Resonance] = [
        iso.child2.parent
        for iso in isobars
        if iso.child2.parent.name.startswith(resonance_identifier)
    ]
    gray_rows = []
    for resonance in selected:
        fraction = sub_intensity(kinematic_arrays, [resonance.name]) / I_tot
        gray_rows.append(
            (
                Rf"\color{{gray}}{{{resonance.latex}}}",
                Rf"\color{{gray}}{{{fraction:.3f}}}",
            )
        )
    return gray_rows if len(gray_rows) > 1 else []


# Fit-fraction table. The individual resonances of each sub-system (gray, only
# for sub-systems with several resonances) are listed directly before their
# sub-system total. Then come the pairwise interference terms and the sum.
rows = [
    R"\hline",
    *render_resonance_row("K"),
    ("K^{**}", f"{I_K/I_tot:.3f}"),
    *render_resonance_row("L"),
    (R"\Lambda^{**}", f"{I_Λ/I_tot:.3f}"),
    *render_resonance_row("D"),
    (R"\Delta^{**}", f"{I_Δ/I_tot:.3f}"),
    (R"\Delta/\Lambda", f"{I_ΛΔ/I_tot:.3f}"),
    (R"K/\Delta", f"{I_KΔ/I_tot:.3f}"),
    (R"K/\Lambda", f"{I_KΛ/I_tot:.3f}"),
    R"\hline",
    (
        R"\mathrm{total}",
        f"{(I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ) /I_tot:.3f}",
    ),
]

latex = R"\begin{array}{crr}" + "\n"
latex += R"& I_\mathrm{sub}\,/\,I \\" + "\n"
for row in rows:
    if row == R"\hline":
        latex += R"\hline" + "\n"
    else:
        latex += "  " + " & ".join(row) + R" \\" + "\n"
latex += R"\end{array}"
Math(latex)

### Polarization distributions

In [None]:
# Evaluate the three polarization components on the grid for every sub-system.
# The components are expected to be real, so their imaginary parts must be
# negligible in either sign, hence the absolute value in the check.
polarization_values = {
    subsystem: [func[i](kinematic_arrays) for i in range(3)]
    for subsystem, func in polarization_functions.items()
}
for subsystem in range(4):
    for array in polarization_values[subsystem]:
        assert np.nanmax(np.abs(array.imag)) < 1e-10

In [None]:
def render_mean(array, plus=True):
    """Format ``mean \\pm std`` of the real part of ``array``, ignoring NaN entries.

    Both numbers use three decimals. With ``plus=True``, a mean that is still
    positive after rounding gets an explicit ``+`` sign.
    """
    real_part = array.real
    mean_text = f"{np.nanmean(real_part):.3f}"
    std_text = f"{np.nanstd(real_part):.3f}"
    sign = "+" if plus and float(mean_text) > 0 else ""
    return Rf"{sign}{mean_text} \pm {std_text}"


# Mean polarization sensitivity per sub-system. The array has five columns:
# sub-system label, |α|, α_x, α_y and α_z.
latex = R"\begin{array}{ccccc}" + "\n"
latex += R"& \bar{|\alpha|} & \bar\alpha_x & \bar\alpha_y & \bar\alpha_z \\" + "\n"
for subsystem, label in subsystem_names.items():
    latex += f"  {label} & "
    # Dedicated names, so the grid axes `x` and `y` defined earlier are not
    # overwritten (re-running the projection plot depends on them).
    α_x, α_y, α_z = polarization_values[subsystem]
    latex += render_mean(np.sqrt(α_x**2 + α_y**2 + α_z**2), plus=False) + " & "
    latex += " & ".join(map(render_mean, [α_x, α_y, α_z]))
    latex += R" \\" + "\n"
latex += R"\end{array}"
Math(latex)

In [None]:
# Rasterized output and the interactive ipympl backend for the slider plots below.
%config InlineBackend.figure_formats = ['png']
%matplotlib widget

#### Slider definitions

In [None]:
# Slider construction
# Each complex production coupling gets two sliders ("<name>_real" and
# "<name>_imag"); each mass and width gets one slider keyed by its symbol name.
sliders = {}
for symbol, value in free_parameters.items():
    if symbol.name.startswith(R"\mathcal{H}"):
        real_slider = create_slider(symbol)
        imag_slider = create_slider(symbol)
        sliders[f"{symbol.name}_real"] = real_slider
        sliders[f"{symbol.name}_imag"] = imag_slider
        real_slider.description = R"\(\mathrm{Re}\)"
        imag_slider.description = R"\(\mathrm{Im}\)"
    else:
        slider = create_slider(symbol)
        sliders[symbol.name] = slider

# Slider ranges
# Mass sliders span the kinematic range of the matching sub-system's invariant
# mass (K → σ1, Λ → σ2, Δ → σ3). Width sliders go from 0 to at least 0.5 (or
# twice the slider's current value), and coupling sliders cover [-15, +15].
σ3_max = (m0_val - m3_val) ** 2
σ3_min = (m1_val + m2_val) ** 2

for name, slider in sliders.items():
    slider.continuous_update = False
    slider.step = 0.01
    if name.startswith("m_"):
        if "K" in name:
            slider.min = np.sqrt(σ1_min)
            slider.max = np.sqrt(σ1_max)
        elif R"\Lambda" in name:
            slider.min = np.sqrt(σ2_min)
            slider.max = np.sqrt(σ2_max)
        elif R"\Delta" in name:
            slider.min = np.sqrt(σ3_min)
            slider.max = np.sqrt(σ3_max)
    elif name.startswith(R"\Gamma_"):
        slider.min = 0
        slider.max = max(0.5, 2 * slider.value)
    elif name.startswith(R"\mathcal{H}"):
        slider.min = -15
        slider.max = +15


# Slider values
def reset_sliders(click_event):
    """Set every slider back to the default value of its free parameter.

    Complex couplings drive a ``_real`` and an ``_imag`` slider; ``set_slider``
    picks the matching component. ``click_event`` is ignored; it only exists so
    the function can serve as a button callback.
    """
    for symbol, default in free_parameters.items():
        name = symbol.name
        if not name.startswith(R"\mathcal{H}"):
            set_slider(sliders[name], default)
            continue
        set_slider(sliders[name + "_real"], default)
        set_slider(sliders[name + "_imag"], default)


def set_coupling_to_zero(filter_pattern):
    """Set every coupling slider whose name contains ``filter_pattern`` to zero.

    ``filter_pattern`` is either a plain string or the Combobox widget that
    triggered the callback, in which case its current value is used.
    """
    if isinstance(filter_pattern, Combobox):
        pattern = filter_pattern.value
    else:
        pattern = filter_pattern
    matching = (
        slider
        for name, slider in sliders.items()
        if name.startswith(R"\mathcal{H}") and pattern in name
    )
    for slider in matching:
        set_slider(slider, 0)


def set_slider(slider, value):
    """Write the real or imaginary part of ``value`` into ``slider``.

    Sliders described as ``Im`` receive the imaginary part; all others receive
    the real part. The value is only written when it differs from the current
    one at the slider's step precision, to avoid needless widget updates.
    """
    is_imaginary = slider.description == R"\(\mathrm{Im}\)"
    as_complex = complex(value)
    target = as_complex.imag if is_imaginary else as_complex.real
    precision = -round(np.log10(slider.step))
    if slider.value != round(target, precision):  # widget performance
        slider.value = target


# Initialise all sliders with the model defaults, then wire up the reset button
# and the coupling filter (submitting a pattern zeroes the matching couplings).
reset_sliders(click_event=None)
reset_button = Button(description="Reset slider values")
reset_button.on_click(reset_sliders)

all_resonances = [prod.child2.parent for prod in isobars]
filter_button = Combobox(
    placeholder="Enter coupling filter pattern",
    options=all_resonances,
    description=R"$\mathcal{H}=0$",
)
filter_button.on_submit(set_coupling_to_zero)

# UI design
# Coupling sliders are grouped per resonance as (real sliders, imaginary
# sliders, LaTeX labels) and laid out one coupling per row. Masses and widths
# go in a second tab.
latex = {symbol.name: sp.latex(symbol) for symbol in free_parameters}
mass_sliders = [sliders[n] for n in sliders if n.startswith("m_")]
width_sliders = [sliders[n] for n in sliders if n.startswith(R"\Gamma_")]
coupling_sliders = {}
for isobar in isobars:
    res = isobar.child2.parent
    coupling_sliders[res.name] = (
        [s for n, s in sliders.items() if n.endswith("_real") and res.latex in n],
        [s for n, s in sliders.items() if n.endswith("_imag") and res.latex in n],
        [
            HTMLMath(f"${latex[n[:-5]]}$")
            for n in sliders
            if n.endswith("_real") and res.latex in n
        ],
    )
# NOTE(review): masses and widths are paired by list position, which assumes
# both lists follow the same resonance order — confirm.
slider_tabs = Tab(
    children=[
        Tab(
            children=[
                VBox([HBox(s) for s in zip(*pair)])
                for pair in coupling_sliders.values()
            ],
            _titles={i: label for i, label in enumerate(coupling_sliders)},
        ),
        VBox([HBox([r, i]) for r, i in zip(mass_sliders, width_sliders)]),
    ],
    _titles=dict(enumerate(["Couplings", "Masses and widths"])),
)
ui = VBox([slider_tabs, HBox([reset_button, filter_button])])

#### Interactive plot

In [None]:
# Figure grid: rows are sub-systems 0 (total), 1 (K) and 2 (Λ); columns show
# the intensity, |α|, α_x, α_y and α_z. The last column is wider to make room
# for the colorbar.
nrows = 3
ncols = 5
scale = 2.6
aspect_ratio = 1.15
fig, axes = plt.subplots(
    figsize=scale * np.array([ncols, aspect_ratio * nrows]),
    ncols=ncols,
    nrows=nrows,
    sharex=True,
    sharey=True,
    gridspec_kw={"width_ratios": (ncols - 1) * [1] + [1.24]},
    tight_layout=True,
)
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False

for subsystem in range(nrows):
    for i in range(ncols):
        ax = axes[subsystem, i]
        if i == 0:
            alpha_str = R"I_\mathrm{tot}"
        elif i == 1:
            alpha_str = R"|\alpha|"
        else:
            alpha_str = Rf"\alpha_{'xyz'[i-2]}"
        title = alpha_str
        if subsystem > 0:
            title = Rf"{title}\left({subsystem_names[subsystem]}\right)"
        ax.set_title(f"${title}$")
        if ax is axes[-1, i]:
            ax.set_xlabel(s1_label)
        if i == 0:
            ax.set_ylabel(s2_label)

# Color meshes are created lazily on the first plot3 call and updated afterwards.
color_mesh = np.full([nrows, ncols], None)


def plot3(**kwargs):
    """Redraw all panels of the interactive figure for the given slider values.

    ``kwargs`` are the raw slider values, with coupling sliders split into
    ``_real`` and ``_imag`` parts. They are merged into complex values and
    pushed into the parametrized functions before evaluation. In each row,
    columns 2–4 show α_x, α_y and α_z, column 1 shows |α| and column 0 the
    intensity. Meshes are created on the first call and only updated after that.
    """
    # NOTE(review): `global` is not required because color_mesh is only
    # mutated, never rebound.
    global color_mesh
    kwargs = to_complex_kwargs(**kwargs)
    # NOTE(review): only sub-systems 0..nrows-1 (total, K, Λ) are drawn; the Δ
    # sub-system (3) is not shown — confirm this is intended.
    for subsystem in range(nrows):
        # alpha_xyz distributions
        alpha_xyz_arrays = []
        for i in range(2, ncols):
            xyz = i - 2
            func = polarization_functions[subsystem][xyz]
            func.update_parameters(kwargs)
            z_values = func(kinematic_arrays)
            z_values = np.real(z_values)
            alpha_xyz_arrays.append(z_values)
            ax = axes[subsystem, i]
            if color_mesh[subsystem, i] is None:
                color_mesh[subsystem, i] = ax.pcolormesh(
                    X, Y, z_values, cmap=cm.coolwarm
                )
                if ax is axes[subsystem, -1]:
                    fig.colorbar(color_mesh[subsystem, i], ax=ax)
            else:
                color_mesh[subsystem, i].set_array(z_values)
            color_mesh[subsystem, i].set_clim(vmin=-1, vmax=+1)
        # absolute value of alpha_xyz vector
        # (non-negative, but drawn on the same symmetric [-1, +1] color scale)
        i = 1
        alpha_abs = np.sqrt(np.sum(np.array(alpha_xyz_arrays) ** 2, axis=0))
        ax = axes[subsystem, i]
        if color_mesh[subsystem, i] is None:
            color_mesh[subsystem, i] = ax.pcolormesh(X, Y, alpha_abs, cmap=cm.coolwarm)
        else:
            color_mesh[subsystem, i].set_array(alpha_abs)
        color_mesh[subsystem, i].set_clim(vmin=-1, vmax=+1)
        # total intensity
        i = 0
        func = intensity_functions[subsystem]
        func.update_parameters(kwargs)
        z_values = func(kinematic_arrays)
        ax = axes[subsystem, i]
        if color_mesh[subsystem, i] is None:
            color_mesh[subsystem, i] = ax.pcolormesh(X, Y, z_values)
        else:
            color_mesh[subsystem, i].set_array(z_values)
    fig.canvas.draw()


def to_complex_kwargs(**kwargs):
    """Merge ``<name>_real`` / ``<name>_imag`` slider values into complex numbers.

    Each ``<name>_real`` key is combined with its ``<name>_imag`` partner into a
    single ``<name>: complex(real, imag)`` entry. ``_imag`` keys are consumed
    that way, and all other keys pass through unchanged. Suffixes are matched
    with the underscore, exactly as the sliders are named, so a parameter
    whose name merely ends in "real" is not mangled.

    Raises:
        KeyError: if a ``<name>_real`` key has no matching ``<name>_imag`` key.
    """
    real_suffix, imag_suffix = "_real", "_imag"
    complex_valued_kwargs = {}
    for key, value in kwargs.items():
        if key.endswith(real_suffix):
            symbol_name = key[: -len(real_suffix)]
            imag = kwargs[symbol_name + imag_suffix]
            complex_valued_kwargs[symbol_name] = complex(value, imag)
        elif not key.endswith(imag_suffix):
            complex_valued_kwargs[key] = value
    return complex_valued_kwargs


# Re-run plot3 whenever a slider changes and show the controls above the figure.
output = interactive_output(plot3, controls=sliders)
display(ui, output)