# Amplitude model

```{autolink-concat}
```

In [None]:
import json
from itertools import product
from typing import Dict, List, Union  # run magic does not support PEP563

import sympy as sp
from ampform.sympy import PoolSum
from IPython.display import Markdown, display
from sympy.core.symbol import Str
from sympy.physics.quantum.spin import Rotation as Wigner

from polarization.decay import Particle, Resonance, ThreeBodyDecay
from polarization.dynamics import BreitWignerMinL, BuggBreitWigner, FlattéSWave, Källén
from polarization.io import as_markdown_table, display_latex
from polarization.lhcb import K, Λc, Σ, load_three_body_decays, p, π
from polarization.spin import create_spin_range


# hack for moving Indexed indices below superscript of the base
def _print_Indexed_latex(self, printer, *args):
    base = printer._print(self.base)
    indices = ", ".join(map(printer._print, self.indices))
    return f"{base}_{{{indices}}}"


# Install the custom printer on all Indexed objects: SymPy's LaTeX printer uses
# an expression's `_latex(printer, ...)` method when one is defined.
sp.Indexed._latex = _print_Indexed_latex

## Resonances and LS-scheme

Particle definitions for $\Lambda_c^+$, its decay products $p, \pi^+, K^-$, and the $\Sigma$ baryon (which enters the $\Lambda(1405)$ Flatté line shape):

In [None]:
# Show the particle definitions as a Markdown table (last expression of the cell)
Markdown(as_markdown_table([Λc, p, π, K, Σ]))

Resonance definitions, as loaded from {download}`data/isobars.json <../data/isobars.json>`:

In [None]:
# Load the three-body decay chains defined in the isobar JSON file
decays = load_three_body_decays("../data/isobars.json")

In [None]:
# Summarize all loaded decay chains in a Markdown table
src = as_markdown_table(decays)
Markdown(src)

## Amplitude

### Spin-alignment with DPD

In [None]:
# Indexed bases for the amplitudes of the three sub-systems
# (1: K resonances, 2: Λ resonances, 3: Δ resonances)
A_K, A_Λ, A_Δ = (
    sp.IndexedBase(R"A^K"),
    sp.IndexedBase(R"A^{\Lambda}"),
    sp.IndexedBase(R"A^{\Delta}"),
)

half = sp.S.Half

# Wigner rotation angles that align sub-system k = 1, 2, 3 to sub-system 1:
# ζ^0 acts on the Λc helicity, ζ^1 on the proton helicity
ζ_0_11, ζ_0_21, ζ_0_31 = (
    sp.Symbol(latex, real=True)
    for latex in (R"\zeta^0_{1(1)}", R"\zeta^0_{2(1)}", R"\zeta^0_{3(1)}")
)
ζ_1_11, ζ_1_21, ζ_1_31 = (
    sp.Symbol(latex, real=True)
    for latex in (R"\zeta^1_{1(1)}", R"\zeta^1_{2(1)}", R"\zeta^1_{3(1)}")
)


def formulate_aligned_amplitude(λ_Λc, λ_p):
    """Formulate the full amplitude as a sum of spin-aligned sub-system amplitudes.

    For each sub-system (K, Λ, Δ), the amplitude ``A[ν', λ']`` is rotated with
    Wigner-d functions: the ζ^0 angle acts on the Λc helicity and the ζ^1
    angle on the proton helicity. The primed helicities ν' and λ' are summed
    over ±1/2.

    Args:
        λ_Λc: helicity of the Λc (symbol or number).
        λ_p: helicity of the proton (symbol or number).
    """
    nu_prime = sp.Symbol(R"\nu^{\prime}", rational=True)
    lambda_prime = sp.Symbol(R"\lambda^{\prime}", rational=True)
    # (amplitude base, Λc rotation angle, proton rotation angle) per sub-system
    subsystems = [
        (A_K, ζ_0_11, ζ_1_11),
        (A_Λ, ζ_0_21, ζ_1_21),
        (A_Δ, ζ_0_31, ζ_1_31),
    ]
    summand = sp.Add(
        *[
            amplitude[nu_prime, lambda_prime]
            * Wigner.d(half, λ_Λc, nu_prime, zeta_Λc)
            * Wigner.d(half, lambda_prime, λ_p, zeta_p)
            for amplitude, zeta_Λc, zeta_p in subsystems
        ]
    )
    return PoolSum(
        summand,
        (lambda_prime, [-half, +half]),
        (nu_prime, [-half, +half]),
    )


# Generic helicity symbols of the Λc (ν) and the proton (λ)
ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
# Display the aligned amplitude for these generic helicities
formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)

### Sub-system amplitudes

:::{warning}

Couplings are remapped from the convention of the LHCb paper to that of the Dalitz-plot decomposition (DPD) using [these relations](https://user-images.githubusercontent.com/22725744/165932213-34013235-8464-4018-bd21-3ebde1e86faf.png).

:::

In [None]:
# Helicity couplings, indexed as [resonance, helicity, helicity]
H_prod, H_dec = (
    sp.IndexedBase(R"\mathcal{H}^\mathrm{production}"),
    sp.IndexedBase(R"\mathcal{H}^\mathrm{decay}"),
)

# Helicity angles of sub-systems 1 (θ23), 2 (θ31), and 3 (θ12)
θ23, θ31, θ12 = (
    sp.Symbol(name, real=True) for name in ("theta23", "theta31", "theta12")
)


def to_mass_symbol(particle: Resonance) -> sp.Symbol:
    """Return the mass symbol ``m_{<LaTeX of particle>}`` of a particle."""
    latex = particle.latex
    return sp.Symbol(f"m_{{{latex}}}")


# Mandelstam variables of sub-systems 1 (pair 2-3), 2 (pair 3-1), and 3 (pair 1-2)
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
# Mass symbols of the final-state particles 1 = p, 2 = π, 3 = K
m1, m2, m3 = [to_mass_symbol(particle) for particle in (p, π, K)]


def get_mandelstam_s(decay: ThreeBodyDecay) -> sp.Symbol:
    """Return the Mandelstam variable of the sub-system in which the resonance decays.

    The sub-system is identified by the masses of the two decay products:
    (m2, m3) → σ1, (m1, m3) → σ2, (m1, m2) → σ3.

    Raises:
        NotImplementedError: if the decay products form none of these pairs.
    """
    decay_masses = {to_mass_symbol(product) for product in decay.decay_products}
    for pair_masses, mandelstam_s in (
        ({m2, m3}, σ1),
        ({m1, m3}, σ2),
        ({m1, m2}, σ3),
    ):
        if decay_masses == pair_masses:
            return mandelstam_s
    # The masses are SymPy symbols, so convert them to str before joining;
    # sort them to get a deterministic message from the unordered set.
    mass_names = ", ".join(sorted(str(mass) for mass in decay_masses))
    raise NotImplementedError(f"Cannot find Mandelstam variable for {mass_names}")


def formulate_subsystem_amplitude(subsystem: int, λ_Λc, λ_p):
    """Formulate the amplitude of one sub-system, summed over its resonances.

    Sub-system 1 uses the resonances in ``decays`` whose name starts with "K",
    sub-system 2 those starting with "L" (Λ), and sub-system 3 those starting
    with "D" (Δ).

    Raises:
        NotImplementedError: for any other sub-system number.
    """
    for index, formulate, name_prefix in (
        (1, formulate_K_amplitude, "K"),
        (2, formulate_Λ_amplitude, "L"),
        (3, formulate_Δ_amplitude, "D"),
    ):
        if subsystem == index:
            return formulate(λ_Λc, λ_p, filter_isobars(decays, name_prefix))
    raise NotImplementedError(f"No chain implemented for sub-system {subsystem}")


def formulate_K_amplitude(λ_Λc, λ_p, decays: List[ThreeBodyDecay]):
    """Formulate the amplitude of sub-system 1 (K resonances).

    For every decay chain, a PoolSum over the resonance helicity τ (ranging
    over ``create_spin_range(spin)``) of: a Kronecker delta enforcing
    λ_Λc = τ - λ_p, the production coupling H_prod[resonance, τ, -λ_p], the
    resonance line shape, a phase (-1)^(1/2 - λ_p), the Wigner d-function
    d^J_{τ,0}(θ23), and the decay coupling H_dec[resonance, 0, 0].
    The PoolSums of all chains are added.

    Args:
        λ_Λc: helicity of the Λc (symbol or number).
        λ_p: helicity of the proton (symbol or number).
        decays: decay chains of the K resonances.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ - λ_p)
                * H_prod[stringify(decay.resonance), τ, -λ_p]
                * formulate_dynamics(decay)
                * (-1) ** (half - λ_p)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, 0, θ23)
                * H_dec[stringify(decay.resonance), 0, 0],
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_Λ_amplitude(λ_Λc, λ_p, decays: List[ThreeBodyDecay]):
    """Formulate the amplitude of sub-system 2 (Λ resonances).

    For every decay chain, a PoolSum over the resonance helicity τ of: a
    Kronecker delta enforcing λ_Λc = τ, the production coupling
    H_prod[resonance, -τ, 0], the resonance line shape, the Wigner d-function
    d^J_{τ,-λ_p}(θ31), the decay coupling H_dec[resonance, 0, λ_p], and a
    phase (-1)^(1/2 - λ_p), all divided by minus the resonance parity.
    The PoolSums of all chains are added.

    Args:
        λ_Λc: helicity of the Λc (symbol or number).
        λ_p: helicity of the proton (symbol or number).
        decays: decay chains of the Λ resonances.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(decay.resonance), -τ, 0]
                * formulate_dynamics(decay)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, -λ_p, θ31)
                * H_dec[stringify(decay.resonance), 0, λ_p]
                * (-1) ** (half - λ_p)
                # NOTE(review): the parity and phase factors presumably implement
                # the LHCb → DPD coupling remapping mentioned in the notebook — confirm
                / (-decay.resonance.parity),
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_Δ_amplitude(λ_Λc, λ_p, decays: List[ThreeBodyDecay]):
    """Formulate the amplitude of sub-system 3 (Δ resonances).

    For every decay chain, a PoolSum over the resonance helicity τ of: a
    Kronecker delta enforcing λ_Λc = τ, the production coupling
    H_prod[resonance, -τ, 0], the resonance line shape, the Wigner d-function
    d^J_{τ,λ_p}(θ12), and the decay coupling H_dec[resonance, λ_p, 0], all
    divided by -P·(-1)^(J - 1/2), with P the parity and J the spin of the
    resonance. The PoolSums of all chains are added.

    Args:
        λ_Λc: helicity of the Λc (symbol or number).
        λ_p: helicity of the proton (symbol or number).
        decays: decay chains of the Δ resonances.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                sp.KroneckerDelta(λ_Λc, τ)
                * H_prod[stringify(decay.resonance), -τ, 0]
                * formulate_dynamics(decay)
                * Wigner.d(sp.Rational(decay.resonance.spin), τ, λ_p, θ12)
                * H_dec[stringify(decay.resonance), λ_p, 0]
                / (-decay.resonance.parity * (-1) ** (decay.resonance.spin - half)),
                (τ, create_spin_range(decay.resonance.spin)),
            )
            for decay in decays
        ]
    )


def formulate_dynamics(decay: ThreeBodyDecay):
    """Build the line-shape expression of the resonance in a decay chain.

    The builder is selected by the ``lineshape`` name of the resonance.

    Raises:
        NotImplementedError: if the line-shape name is not supported.
    """
    lineshape = decay.resonance.lineshape
    for name, build in (
        ("BreitWignerMinL", formulate_breit_wigner),
        ("BuggBreitWignerMinL", formulate_bugg_breit_wigner),
        ("Flatte1405", formulate_flatté_1405),
    ):
        if lineshape == name:
            return build(decay)
    raise NotImplementedError(f'No dynamics implemented for lineshape "{lineshape}"')


def formulate_breit_wigner(decay: ThreeBodyDecay):
    """Build a ``BreitWignerMinL`` line shape for the resonance of a decay chain.

    Default values for every symbol introduced here are registered with
    ``safe_update_parameters``, so existing defaults are not overwritten.
    """
    s = get_mandelstam_s(decay)
    first_child, second_child = decay.decay_products
    child1_mass = to_mass_symbol(first_child)
    child2_mass = to_mass_symbol(second_child)
    l_dec = sp.Rational(decay.outgoing_ls.L)
    l_prod = sp.Rational(decay.incoming_ls.L)
    parent_mass = to_mass_symbol(decay.parent)
    spectator_mass = to_mass_symbol(decay.spectator)
    resonance_mass = to_mass_symbol(decay.resonance)
    resonance_width = sp.Symbol(Rf"\Gamma_{{{decay.resonance.latex}}}")
    R_dec = sp.Symbol(R"R_\mathrm{res}")
    R_prod = sp.Symbol(R"R_{\Lambda_c}")
    # Registration order is kept fixed, because it determines the key order of
    # the parameter_defaults dict.
    defaults = [
        (parent_mass, decay.parent.mass),
        (spectator_mass, decay.spectator.mass),
        (resonance_mass, decay.resonance.mass),
        (resonance_width, decay.resonance.width),
        (child1_mass, first_child.mass),
        (child2_mass, second_child.mass),
        # https://github.com/redeboer/polarization-sensitivity/pull/11#issuecomment-1128784376
        (R_dec, 1.5),
        (R_prod, 5),
    ]
    for symbol, value in defaults:
        safe_update_parameters(symbol, value)
    return BreitWignerMinL(
        s,
        parent_mass,
        spectator_mass,
        resonance_mass,
        resonance_width,
        child1_mass,
        child2_mass,
        l_dec,
        l_prod,
        R_dec,
        R_prod,
    )


def formulate_bugg_breit_wigner(decay: ThreeBodyDecay):
    """Build a ``BuggBreitWigner`` line shape for a K* resonance decaying to K π.

    Default values for the introduced symbols are registered with
    ``safe_update_parameters``.

    Raises:
        ValueError: if the resonance does not decay into π and K (masses m2, m3).
    """
    if set(map(to_mass_symbol, decay.decay_products)) != {m2, m3}:
        raise ValueError("Bugg Breit-Wigner only defined for K* → Kπ")
    s = get_mandelstam_s(decay)
    resonance_latex = decay.resonance.latex
    gamma = sp.Symbol(Rf"\gamma_{{{resonance_latex}}}")
    mass = sp.Symbol(f"m_{{{resonance_latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{resonance_latex}}}")
    for symbol, value in (
        (mass, decay.resonance.mass),
        (width, decay.resonance.width),
        (m2, π.mass),
        (m3, K.mass),
        (gamma, 1),
    ):
        safe_update_parameters(symbol, value)
    # Adler zero for K minus π
    return BuggBreitWigner(s, mass, width, m3, m2, gamma)


def formulate_flatté_1405(decay: ThreeBodyDecay):
    """Build a ``FlattéSWave`` line shape for the Λ(1405) resonance.

    Two mass pairs are passed: the masses of the resonance decay products and
    the masses of π and Σ. Default values for all introduced symbols are
    registered with ``safe_update_parameters``.
    """
    s = get_mandelstam_s(decay)
    # Local names differ from the global m1/m2 to avoid shadowing them
    child1_mass, child2_mass = map(to_mass_symbol, decay.decay_products)
    resonance_latex = decay.resonance.latex
    mass = sp.Symbol(f"m_{{{resonance_latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{resonance_latex}}}")
    mπ = to_mass_symbol(π)
    mΣ = to_mass_symbol(Σ)
    for symbol, value in (
        (mass, decay.resonance.mass),
        (width, decay.resonance.width),
        (child1_mass, decay.decay_products[0].mass),
        (child2_mass, decay.decay_products[1].mass),
        (mπ, π.mass),
        (mΣ, Σ.mass),
    ):
        safe_update_parameters(symbol, value)
    return FlattéSWave(s, mass, width, (child1_mass, child2_mass), (mπ, mΣ))


# Default numerical values of the model symbols: filled through
# safe_update_parameters while building the expressions, and later updated
# with the model parameters and decay couplings
parameter_defaults = {}


def safe_update_parameters(parameter: sp.Symbol, value) -> None:
    """Register ``value`` as default of ``parameter``, keeping any existing value."""
    parameter_defaults.setdefault(parameter, value)


def stringify(obj) -> Str:
    """Wrap ``obj`` as a SymPy ``Str``: the LaTeX of a Particle, else its formatted text."""
    text = obj.latex if isinstance(obj, Particle) else f"{obj}"
    return Str(text)


def filter_isobars(
    decays: List[ThreeBodyDecay], resonance_pattern: str
) -> List[ThreeBodyDecay]:
    """Select the decay chains whose resonance name starts with ``resonance_pattern``."""

    def has_matching_resonance(decay: ThreeBodyDecay) -> bool:
        return decay.resonance.name.startswith(resonance_pattern)

    return list(filter(has_matching_resonance, decays))

### Angle definitions

In [None]:
# Mass symbol of the decaying Λc
m0 = sp.Symbol(R"m_{\Lambda_c}", nonnegative=True)
# NOTE(review): m0 has `nonnegative=True`, whereas the mass symbols made by
# `to_mass_symbol` (and the parent mass in `formulate_breit_wigner`) have no
# assumptions. SymPy treats symbols with different assumptions as distinct, so
# if Λc.latex renders as `\Lambda_c` these are NOT the same symbol — confirm.
#
# Expressions for the helicity angles θ of the three sub-systems and for the
# Wigner rotation angles ζ, in terms of the Mandelstam variables σ1, σ2, σ3
# and the masses m0..m3. `Källén` is presumably the Källén (triangle)
# function λ(x, y, z).
angles = {
    θ12: sp.acos(
        (
            2 * σ3 * (σ2 - m3**2 - m1**2)
            - (σ3 + m1**2 - m2**2) * (m0**2 - σ3 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m3**2, σ3))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
    θ23: sp.acos(
        (
            2 * σ1 * (σ3 - m1**2 - m2**2)
            - (σ1 + m2**2 - m3**2) * (m0**2 - σ1 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ1, m2**2, m3**2))
        )
    ),
    θ31: sp.acos(
        (
            2 * σ2 * (σ1 - m2**2 - m3**2)
            - (σ2 + m3**2 - m1**2) * (m0**2 - σ2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(σ2, m3**2, m1**2))
        )
    ),
    # Sub-system 1 is the reference, so its rotation angles vanish
    ζ_0_11: sp.S.Zero,  # = \hat\theta^0_{1(1)}
    ζ_0_21: -sp.acos(  # = -\hat\theta^{1(2)}
        (
            (m0**2 + m1**2 - σ1) * (m0**2 + m2**2 - σ2)
            - 2 * m0**2 * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(m0**2, σ1, m1**2))
        )
    ),
    ζ_0_31: sp.acos(  # = \hat\theta^{3(1)}
        (
            (m0**2 + m3**2 - σ3) * (m0**2 + m1**2 - σ1)
            - 2 * m0**2 * (σ2 - m3**2 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(m0**2, σ3, m3**2))
        )
    ),
    ζ_1_11: sp.S.Zero,
    ζ_1_21: sp.acos(
        (
            2 * m1**2 * (σ3 - m0**2 - m3**2)
            + (m0**2 + m1**2 - σ1) * (σ2 - m1**2 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ2, m1**2, m3**2))
        )
    ),
    ζ_1_31: -sp.acos(  # = -\zeta^1_{1(3)}
        (
            2 * m1**2 * (σ2 - m0**2 - m2**2)
            + (m0**2 + m1**2 - σ1) * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
}

display_latex(angles)

In [None]:
# Numerical values of the Λc and final-state mass symbols used in the angles
masses = {
    m0: Λc.mass,
    m1: p.mass,
    m2: π.mass,
    m3: K.mass,
}
display_latex(masses)

### Intensity expression

In [None]:
# Intensity: squared modulus of the aligned amplitude, summed incoherently over
# the proton (λ) and Λc (ν) helicities ±1/2
intensity_expr = PoolSum(
    sp.Abs(formulate_aligned_amplitude(ν, λ)) ** 2,
    (λ, [-half, +half]),
    (ν, [-half, +half]),
)
display(intensity_expr)

In [None]:
# Map every sub-system amplitude symbol A[ν, λ] to its expression in terms of
# the resonance chains of that sub-system, for all helicity combinations
A = {1: A_K, 2: A_Λ, 3: A_Δ}
amp_definitions = {}
for subsystem in range(1, 4):
    # The expression is built for generic helicities (ν, λ), so it does not
    # depend on the loop below: build it once per sub-system and substitute.
    expr = formulate_subsystem_amplitude(subsystem, ν, λ)
    for Λc_heli, p_heli in product([-half, +half], [-half, +half]):
        symbol = A[subsystem][Λc_heli, p_heli]
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})
display_latex(amp_definitions)

## Parameter definitions

### Helicity coupling values

#### Production couplings

In [None]:
def to_symbol_definitions(
    parameter_dict: Dict[str, str]
) -> Dict[sp.Basic, Union[complex, float]]:
    """Convert string-valued model parameters into a symbol → value mapping.

    A coupling is stored as a real part ``Ar<id>`` and an imaginary part
    ``Ai<id>``; each pair is merged into one complex value under ``A<id>``.
    All other entries become floats. Values are parsed with ``to_float`` and
    keys are converted with ``to_symbol``.

    Raises:
        KeyError: if an ``Ar<id>`` entry has no matching ``Ai<id>`` entry.
    """
    key_to_val: Dict[str, Union[complex, float]] = {}
    for key, str_value in parameter_dict.items():
        if key.startswith("Ai"):
            # imaginary parts are consumed together with their real part
            continue
        if not key.startswith("Ar"):
            key_to_val[key] = to_float(str_value)
            continue
        identifier = key[2:]
        str_imag = parameter_dict[f"Ai{identifier}"]
        key_to_val[f"A{identifier}"] = complex(
            to_float(str_value), to_float(str_imag)
        )
    return {to_symbol(key): value for key, value in key_to_val.items()}


def to_float(str_value: str) -> float:
    """Parse the central value of a ``"<value> ± <uncertainty>"`` string.

    The uncertainty part is optional (``"1.5"`` is accepted as well), and the
    spacing around ``±`` does not matter. If an uncertainty is present, it must
    be a valid number, but it is otherwise discarded.

    Raises:
        ValueError: if the value or the uncertainty is not a valid float.
    """
    value_str, separator, uncertainty_str = str_value.partition("±")
    value = float(value_str)  # float() ignores surrounding whitespace
    if separator:
        float(uncertainty_str)  # validate only; the uncertainty is not used
    return value


def to_symbol(key: str) -> Union[sp.Indexed, sp.Symbol]:
    """Convert a key of the model-parameter JSON file into a SymPy symbol.

    Supported key forms:

    - ``A<resonance><i>``: production coupling ``H_prod[resonance, ., .]``,
      where the coupling number ``i`` selects the helicity indices.
    - ``gamma<resonance>``: γ parameter of the resonance.
    - ``M<resonance>``: mass of the resonance.
    - ``G<resonance>``: width of the resonance.

    Raises:
        NotImplementedError: if the key matches none of these forms.
    """
    if key.startswith("A"):
        res = stringify(key[1:-1])
        i = int(key[-1])
        resonance_name = str(res)
        # coupling number → (first helicity index, second helicity index)
        if resonance_name.startswith(("L", "D")):
            helicity_table = {1: (+half, 0), 2: (-half, 0)}
        elif resonance_name.startswith("K"):
            if resonance_name in {"K(700)", "K(1430)"}:
                helicity_table = {1: (0, -half), 2: (0, +half)}
            else:
                helicity_table = {
                    1: (0, +half),
                    2: (-1, +half),
                    3: (+1, -half),
                    4: (0, -half),
                }
        else:
            helicity_table = {}
        if i in helicity_table:
            return H_prod[(res, *helicity_table[i])]
    if key.startswith("gamma"):
        res = stringify(key[5:])
        return sp.Symbol(Rf"\gamma_{{{res}}}")
    if key.startswith("M"):
        res = stringify(key[1:])
        return sp.Symbol(Rf"m_{{{res}}}")
    if key.startswith("G"):
        res = stringify(key[1:])
        return sp.Symbol(Rf"\Gamma_{{{res}}}")
    raise NotImplementedError(
        f'Cannot convert key "{key}" in model parameter JSON file to SymPy symbol'
    )


# Load the published model parameters. Their values are "<value> ± <uncertainty>"
# strings, so read the file explicitly as UTF-8 instead of relying on the
# platform's default (locale-dependent) encoding.
with open("../data/modelparameters.json", encoding="utf-8") as stream:
    data = json.load(stream)
assert len(data["modelstudies"]) == 18

model_number = 0
model_json = data["modelstudies"][model_number]["parameters"]
# https://github.com/redeboer/polarization-sensitivity/blob/34f5330/julia/notebooks/model0.jl#L301-L302
model_json["ArK(892)1"] = "1.0 ± 0.0"
model_json["AiK(892)1"] = "0.0 ± 0.0"

parameter_defaults.update(to_symbol_definitions(model_json))

In [None]:
# Select the production helicity couplings from all parameter defaults
prod_couplings = {
    key: value
    for key, value in parameter_defaults.items()
    if isinstance(key, sp.Indexed) and key.base == H_prod
}
display_latex(prod_couplings)

#### Decay couplings

In [None]:
# Decay helicity couplings: one coupling per resonance is fixed to 1; for Λ and
# Δ resonances, the coupling of the opposite helicity follows from the parity
# relation η_R·η_1·η_2·(-1)^(J_R - J_1 - J_2) of the resonance R and its two
# decay products
dec_couplings = {}
for decay in decays:
    i = stringify(decay.resonance)
    if decay.resonance.name.startswith("K"):
        # only the (0, 0) helicity coupling, fixed to 1
        dec_couplings[H_dec[i, 0, 0]] = 1
    if decay.resonance.name.startswith("L"):
        # Λ → p K: proton helicity enters the last index
        dec_couplings[H_dec[i, 0, half]] = 1
        dec_couplings[H_dec[i, 0, -half]] = (
            int(decay.resonance.parity)
            * int(K.parity)
            * int(p.parity)
            * (-1) ** (decay.resonance.spin - K.spin - p.spin)
        )
    if decay.resonance.name.startswith("D"):
        # Δ → p π: proton helicity enters the middle index
        dec_couplings[H_dec[i, half, 0]] = 1
        dec_couplings[H_dec[i, -half, 0]] = (
            int(decay.resonance.parity)
            * int(p.parity)
            * int(π.parity)
            * (-1) ** (decay.resonance.spin - p.spin - π.spin)
        )
parameter_defaults.update(dec_couplings)
display_latex(dec_couplings)

### Non-coupling parameters

In [None]:
# Display all remaining parameter defaults (masses, widths, radii, ...), i.e.
# everything that is neither a production nor a decay coupling
couplings = set(dec_couplings) | set(prod_couplings)
display_latex({k: v for k, v in parameter_defaults.items() if k not in couplings})