# Amplitude model

```{autolink-concat}
```

In [None]:
from itertools import product
from typing import List

import sympy as sp
from ampform.sympy import PoolSum
from IPython.display import Markdown, display
from sympy.core.symbol import Str
from sympy.physics.quantum.spin import Rotation as Wigner

from polarization.amplitude.angles import (
    formulate_scattering_angle,
    formulate_zeta_angle,
)
from polarization.decay import Particle, ThreeBodyDecay, ThreeBodyDecayChain
from polarization.dynamics import BreitWignerMinL, BuggBreitWigner, FlattéSWave
from polarization.io import as_markdown_table, display_latex
from polarization.lhcb import K, Λc, Σ, load_three_body_decays, p, π
from polarization.spin import create_spin_range


# Hack: typeset Indexed objects as "<base>_{<indices>}" so that the indices are
# rendered as a subscript below a superscripted base such as ``A^K``.
def _print_Indexed_latex(self, printer, *args):
    """Return LaTeX code of the form ``<base>_{<i>, <j>, ...}`` for ``self``.

    ``printer`` is the printer that is currently rendering the expression; its
    ``_print`` method is used for the base and for every index. Any additional
    positional arguments are accepted and ignored.
    """
    base_tex = printer._print(self.base)
    index_tex = ", ".join(printer._print(index) for index in self.indices)
    return "{}_{{{}}}".format(base_tex, index_tex)


# Monkey-patch SymPy: its LaTeX printer uses an object's ``_latex`` method when
# present, so every Indexed expression in this notebook is typeset with
# ``_print_Indexed_latex``.
sp.Indexed._latex = _print_Indexed_latex

## Resonances and LS-scheme

Particle definitions for $\Lambda_c^+$, $p$, $\pi^+$, and $K^-$, plus $\Sigma$, whose mass enters the Flatté line-shape:

In [None]:
# Render the particle definitions (including Σ, used in the Flatté line-shape)
# as a Markdown table.
Markdown(as_markdown_table([Λc, p, π, K, Σ]))

Resonances as defined in {download}`data/isobars.json <../data/isobars.json>`:

In [None]:
# Load the three-body decay and its resonance chains (``decay.chains``) from the
# isobar definitions file.
decay = load_three_body_decays("../data/isobars.json")

In [None]:
# Render the loaded decay (its resonances) as a Markdown table.
src = as_markdown_table(decay)
Markdown(src)

## Amplitude

### Spin-alignment with DPD

In [None]:
# Indexed symbols for the amplitudes of the three sub-systems
# (1: K resonances, 2: Λ resonances, 3: Δ resonances), indexed by the Λc and
# proton helicities.
A1 = sp.IndexedBase(R"A^K")
A2 = sp.IndexedBase(R"A^{\Lambda}")
A3 = sp.IndexedBase(R"A^{\Delta}")

half = sp.S.Half

# Alignment angles as (symbol, expression) pairs. The first argument is the
# particle index (0 = Λc, 1 = p, cf. particle_to_id) and the second the
# sub-system of the chain. The third argument (1) presumably selects chain 1 as
# the reference for the spin alignment — TODO confirm against
# formulate_zeta_angle.
ζ_0_11, ζ_0_11_expr = formulate_zeta_angle(0, 1, 1)
ζ_0_21, ζ_0_21_expr = formulate_zeta_angle(0, 2, 1)
ζ_0_31, ζ_0_31_expr = formulate_zeta_angle(0, 3, 1)
ζ_1_11, ζ_1_11_expr = formulate_zeta_angle(1, 1, 1)
ζ_1_21, ζ_1_21_expr = formulate_zeta_angle(1, 2, 1)
ζ_1_31, ζ_1_31_expr = formulate_zeta_angle(1, 3, 1)


def formulate_aligned_amplitude(λ0, λ1):
    """Formulate the spin-aligned amplitude for Λc helicity λ0 and proton helicity λ1.

    Each sub-system amplitude (A1, A2, A3) is rotated to the common frame with
    two Wigner d-functions: one in the alignment angle of the Λc (ζ_0_k1) and
    one in that of the proton (ζ_1_k1). The result is summed over the
    intermediate helicities ν′ and λ′ ∈ {-1/2, +1/2}.
    """
    _λ0 = sp.Symbol(R"\nu^{\prime}", rational=True)
    _λ1 = sp.Symbol(R"\lambda^{\prime}", rational=True)
    subsystems = (
        (A1, ζ_0_11, ζ_1_11),
        (A2, ζ_0_21, ζ_1_21),
        (A3, ζ_0_31, ζ_1_31),
    )
    summand = sp.Add(
        *(
            amplitude[_λ0, _λ1]
            * Wigner.d(half, λ0, _λ0, ζ_Λc)
            * Wigner.d(half, _λ1, λ1, ζ_p)
            for amplitude, ζ_Λc, ζ_p in subsystems
        )
    )
    return PoolSum(
        summand,
        (_λ1, [-half, +half]),
        (_λ0, [-half, +half]),
    )


# Symbolic helicities of the Λc (ν) and the proton (λ); the last line displays
# the generic aligned amplitude as the cell output.
ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
formulate_aligned_amplitude(λ0=ν, λ1=λ)

### Sub-system amplitudes

:::{warning}

Couplings are remapped from the conventions of the LHCb paper to those of the Dalitz-plot decomposition (DPD) using [these relations](https://user-images.githubusercontent.com/22725744/165932213-34013235-8464-4018-bd21-3ebde1e86faf.png).

:::

In [None]:
# Helicity couplings, indexed as [resonance, helicity, helicity].
H_prod = sp.IndexedBase(R"\mathcal{H}^\mathrm{production}")
H_dec = sp.IndexedBase(R"\mathcal{H}^\mathrm{decay}")

# Scattering angles of the three sub-systems as (symbol, expression) pairs.
θ12, θ12_expr = formulate_scattering_angle(1, 2)
θ23, θ23_expr = formulate_scattering_angle(2, 3)
θ31, θ31_expr = formulate_scattering_angle(3, 1)
# Particle numbering used for the mass symbols m0..m3 (see to_mass_symbol).
particle_to_id = {
    Λc: 0,
    p: 1,
    π: 2,
    K: 3,
}


def to_mass_symbol(particle: Particle) -> sp.Symbol:
    """Return the non-negative mass symbol of ``particle``.

    Particles listed in ``particle_to_id`` get the symbols ``m0``..``m3``
    (Λc, p, π, K); any other particle gets ``m_{<latex>}``.
    """
    index = particle_to_id.get(particle)
    name = f"m_{{{particle.latex}}}" if index is None else f"m{index}"
    return sp.Symbol(name, nonnegative=True)


# σi: Mandelstam variable of the pair that does not contain particle i (see
# get_mandelstam_s); m1..m3: final-state masses of p, π, K (see particle_to_id).
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
m1, m2, m3 = sp.symbols("m1:4", nonnegative=True)


def get_mandelstam_s(decay: ThreeBodyDecayChain) -> sp.Symbol:
    """Return the σ symbol of the two-body sub-system the resonance decays to.

    σ1 belongs to the (π, K) pair, σ2 to (p, K) and σ3 to (p, π); i.e. σi is
    the pair that does not contain particle i.

    Raises:
        NotImplementedError: if the decay products form none of these pairs.
    """
    decay_products = set(decay.decay_products)
    sigma_by_pair = {
        frozenset({π, K}): σ1,
        frozenset({p, K}): σ2,
        frozenset({p, π}): σ3,
    }
    sigma = sigma_by_pair.get(frozenset(decay_products))
    if sigma is not None:
        return sigma
    product_names = ", ".join(particle.name for particle in decay_products)
    raise NotImplementedError(
        f"Cannot find Mandelstam variable for {product_names}"
    )


def formulate_subsystem_amplitude(subsystem: int, λ0, λ1):
    """Formulate the amplitude of one sub-system of the module-level ``decay``.

    Sub-system 1 collects the K resonances, 2 the Λ resonances and 3 the Δ
    resonances. The chains are selected by the first letter of the resonance
    name.

    Raises:
        NotImplementedError: for any other ``subsystem``.
    """
    formulators = {
        1: (formulate_K_amplitude, "K"),
        2: (formulate_Λ_amplitude, "L"),
        3: (formulate_Δ_amplitude, "D"),
    }
    if subsystem not in formulators:
        raise NotImplementedError(
            f"No chain implemented for sub-system {subsystem}"
        )
    formulate, name_prefix = formulators[subsystem]
    return formulate(λ0, λ1, filter_isobars(decay, name_prefix))


def formulate_K_amplitude(λ0, λ1, chains: List[ThreeBodyDecayChain]):
    """Sum the amplitudes of the K-resonance chains (sub-system 1).

    Args:
        λ0: Helicity of the Λc.
        λ1: Helicity of the proton.
        chains: Decay chains to sum over (formulate_subsystem_amplitude passes
            the chains whose resonance name starts with ``"K"``).

    Returns:
        A SymPy sum with one PoolSum over the resonance helicity τ per chain.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                # Non-zero only for resonance helicity τ = λ0 + λ1.
                sp.KroneckerDelta(λ0, τ - λ1)
                * H_prod[stringify(chain.resonance), τ, -λ1]
                * formulate_dynamics(chain)
                # NOTE(review): phase factor presumably stems from the LHCb → DPD
                # coupling remapping — confirm against the linked relations.
                * (-1) ** (half - λ1)
                # The decay helicities are fixed to 0 (spinless K and π).
                * Wigner.d(sp.Rational(chain.resonance.spin), τ, 0, θ23)
                * H_dec[stringify(chain.resonance), 0, 0],
                # Sum over the helicity values allowed by the resonance spin.
                (τ, create_spin_range(chain.resonance.spin)),
            )
            for chain in chains
        ]
    )


def formulate_Λ_amplitude(λ0, λ1, chains: List[ThreeBodyDecayChain]):
    """Sum the amplitudes of the Λ-resonance chains (sub-system 2).

    Args:
        λ0: Helicity of the Λc.
        λ1: Helicity of the proton.
        chains: Decay chains to sum over (formulate_subsystem_amplitude passes
            the chains whose resonance name starts with ``"L"``).

    Returns:
        A SymPy sum with one PoolSum over the resonance helicity τ per chain.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                # Non-zero only for resonance helicity τ = λ0.
                sp.KroneckerDelta(λ0, τ)
                * H_prod[stringify(chain.resonance), -τ, 0]
                * formulate_dynamics(chain)
                * Wigner.d(sp.Rational(chain.resonance.spin), τ, -λ1, θ31)
                # Decay couplings indexed as (K helicity = 0, proton helicity).
                * H_dec[stringify(chain.resonance), 0, λ1]
                * (-1) ** (half - λ1)
                # Division by minus the resonance parity; presumably the parity
                # is ±1, so this equals multiplying by it — TODO confirm type.
                / (-chain.resonance.parity),
                # Sum over the helicity values allowed by the resonance spin.
                (τ, create_spin_range(chain.resonance.spin)),
            )
            for chain in chains
        ]
    )


def formulate_Δ_amplitude(λ0, λ1, chains: List[ThreeBodyDecayChain]):
    """Sum the amplitudes of the Δ-resonance chains (sub-system 3).

    Args:
        λ0: Helicity of the Λc.
        λ1: Helicity of the proton.
        chains: Decay chains to sum over (formulate_subsystem_amplitude passes
            the chains whose resonance name starts with ``"D"``).

    Returns:
        A SymPy sum with one PoolSum over the resonance helicity τ per chain.
    """
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(
        *[
            PoolSum(
                # Non-zero only for resonance helicity τ = λ0.
                sp.KroneckerDelta(λ0, τ)
                * H_prod[stringify(chain.resonance), -τ, 0]
                * formulate_dynamics(chain)
                * Wigner.d(sp.Rational(chain.resonance.spin), τ, λ1, θ12)
                # Decay couplings indexed as (proton helicity, π helicity = 0).
                * H_dec[stringify(chain.resonance), λ1, 0]
                # Divides by -η_R (-1)^(J_R - 1/2); presumably a sign
                # convention of the LHCb → DPD coupling remapping — confirm.
                / (-chain.resonance.parity * (-1) ** (chain.resonance.spin - half)),
                # Sum over the helicity values allowed by the resonance spin.
                (τ, create_spin_range(chain.resonance.spin)),
            )
            for chain in chains
        ]
    )


def formulate_dynamics(chain: ThreeBodyDecayChain):
    """Return the line-shape expression of the resonance in ``chain``.

    The builder is chosen by ``chain.resonance.lineshape``.

    Raises:
        NotImplementedError: if no builder exists for the line-shape name.
    """
    builders = {
        "BreitWignerMinL": formulate_breit_wigner,
        "BuggBreitWignerMinL": formulate_bugg_breit_wigner,
        "Flatte1405": formulate_flatté_1405,
    }
    lineshape = chain.resonance.lineshape
    builder = builders.get(lineshape)
    if builder is None:
        raise NotImplementedError(
            f'No dynamics implemented for lineshape "{lineshape}"'
        )
    return builder(chain)


def formulate_breit_wigner(decay: ThreeBodyDecayChain):
    """Formulate a BreitWignerMinL line-shape for the resonance of ``decay``.

    Symbols are created for the parent, spectator, resonance and child masses,
    the resonance width, and the two radii R_res (decay) and R_Λc
    (production). Their default values are registered via
    ``safe_update_parameters``; the first registration of a symbol wins.
    """
    resonance = decay.resonance
    s = get_mandelstam_s(decay)
    child1_mass, child2_mass = map(to_mass_symbol, decay.decay_products)
    l_dec = sp.Rational(decay.outgoing_ls.L)
    l_prod = sp.Rational(decay.incoming_ls.L)
    parent_mass = sp.Symbol(f"m_{{{decay.parent.latex}}}")
    spectator_mass = sp.Symbol(f"m_{{{decay.spectator.latex}}}")
    resonance_mass = sp.Symbol(f"m_{{{resonance.latex}}}")
    resonance_width = sp.Symbol(Rf"\Gamma_{{{resonance.latex}}}")
    R_dec = sp.Symbol(R"R_\mathrm{res}")
    R_prod = sp.Symbol(R"R_{\Lambda_c}")
    first_child, second_child = decay.decay_products[0], decay.decay_products[1]
    defaults = [
        (parent_mass, decay.parent.mass),
        (spectator_mass, decay.spectator.mass),
        (resonance_mass, resonance.mass),
        (resonance_width, resonance.width),
        (child1_mass, first_child.mass),
        (child2_mass, second_child.mass),
        # https://github.com/redeboer/polarization-sensitivity/pull/11#issuecomment-1128784376
        (R_dec, 1.5),
        (R_prod, 5),
    ]
    for symbol, value in defaults:
        safe_update_parameters(symbol, value)
    return BreitWignerMinL(
        s,
        parent_mass,
        spectator_mass,
        resonance_mass,
        resonance_width,
        child1_mass,
        child2_mass,
        l_dec,
        l_prod,
        R_dec,
        R_prod,
    )


def formulate_bugg_breit_wigner(decay: ThreeBodyDecayChain):
    """Formulate the Bugg Breit-Wigner line-shape of a K* resonance → Kπ.

    Registers default values for the resonance mass and width, the π and K
    masses (module-level symbols m2 and m3), and the parameter γ (default 1).

    Raises:
        ValueError: if the decay products are not exactly {π, K}.
    """
    if set(decay.decay_products) != {π, K}:
        raise ValueError("Bugg Breit-Wigner only defined for K* → Kπ")
    s = get_mandelstam_s(decay)
    gamma = sp.Symbol(Rf"\gamma_{{{decay.resonance.latex}}}")
    mass = sp.Symbol(f"m_{{{decay.resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{decay.resonance.latex}}}")
    safe_update_parameters(mass, decay.resonance.mass)
    safe_update_parameters(width, decay.resonance.width)
    safe_update_parameters(m2, π.mass)
    safe_update_parameters(m3, K.mass)
    safe_update_parameters(gamma, 1)
    # The K mass (m3) is passed before the π mass (m2); see the trailing comment.
    return BuggBreitWigner(s, mass, width, m3, m2, gamma)  # Adler zero for K minus π


def formulate_flatté_1405(decay: ThreeBodyDecayChain):
    """Formulate the Flatté S-wave line-shape used for lineshape "Flatte1405".

    The first mass pair passed to ``FlattéSWave`` consists of the masses of the
    decay products of ``decay``. The second pair is (π, Σ), presumably the
    second coupled channel. Default values for the resonance mass and width
    and for all four masses are registered via ``safe_update_parameters``.
    """
    s = get_mandelstam_s(decay)
    # Named child*_mass rather than m1/m2: to_mass_symbol may map the children
    # to other symbols (e.g. m1 and m3), and the old local names shadowed the
    # module-level m1/m2 symbols.
    child1_mass, child2_mass = map(to_mass_symbol, decay.decay_products)
    mass = sp.Symbol(f"m_{{{decay.resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{decay.resonance.latex}}}")
    # NOTE(review): mπ is a different symbol from to_mass_symbol(π) (= m2),
    # although both get π.mass as default value.
    mπ = sp.Symbol(f"m_{{{π.latex}}}")
    mΣ = sp.Symbol(f"m_{{{Σ.latex}}}")
    safe_update_parameters(mass, decay.resonance.mass)
    safe_update_parameters(width, decay.resonance.width)
    safe_update_parameters(child1_mass, decay.decay_products[0].mass)
    safe_update_parameters(child2_mass, decay.decay_products[1].mass)
    safe_update_parameters(mπ, π.mass)
    safe_update_parameters(mΣ, Σ.mass)
    return FlattéSWave(s, mass, width, (child1_mass, child2_mass), (mπ, mΣ))


# Default numerical values of the model parameters, keyed by symbol. Filled by
# safe_update_parameters while the amplitude is formulated; later cells extend
# and override it with dict.update.
parameter_defaults = {}


def safe_update_parameters(parameter: sp.Symbol, value) -> None:
    """Register ``value`` as the default of ``parameter`` unless one is set.

    The first registration wins, so formulating the same line-shape several
    times never overwrites an existing default.
    """
    parameter_defaults.setdefault(parameter, value)


def stringify(obj) -> Str:
    """Wrap ``obj`` in a SymPy ``Str`` for use as an H_prod / H_dec index.

    A Particle is represented by its LaTeX string; any other object by its
    formatted string representation.
    """
    label = obj.latex if isinstance(obj, Particle) else f"{obj}"
    return Str(label)


def filter_isobars(
    decay: ThreeBodyDecay, resonance_pattern: str
) -> List[ThreeBodyDecayChain]:
    """Return the chains of ``decay`` whose resonance name starts with a prefix.

    ``resonance_pattern`` is a plain string prefix (not a regular expression);
    the chains keep their original order.
    """

    def has_prefix(chain) -> bool:
        return chain.resonance.name.startswith(resonance_pattern)

    return list(filter(has_prefix, decay.chains))

### Angle definitions

In [None]:
# Mass symbol of the Λc (index 0 in particle_to_id).
m0 = sp.Symbol("m0", nonnegative=True)
# Map every scattering angle θ and alignment angle ζ to its defining
# expression, and display the definitions.
angles = {
    θ12: θ12_expr,
    θ23: θ23_expr,
    θ31: θ31_expr,
    ζ_0_11: ζ_0_11_expr,
    ζ_0_21: ζ_0_21_expr,
    ζ_0_31: ζ_0_31_expr,
    ζ_1_11: ζ_1_11_expr,
    ζ_1_21: ζ_1_21_expr,
    ζ_1_31: ζ_1_31_expr,
}
display_latex(angles)

In [None]:
# Numerical values of the Λc and final-state masses (m0..m3). These are only
# displayed here; this cell does not add them to parameter_defaults.
masses = {
    m0: Λc.mass,
    m1: p.mass,
    m2: π.mass,
    m3: K.mass,
}
display_latex(masses)

### Intensity expression

In [None]:
# Intensity: incoherent sum of |aligned amplitude|² over the proton helicity λ
# and the Λc helicity ν.
intensity_expr = PoolSum(
    sp.Abs(formulate_aligned_amplitude(ν, λ)) ** 2,
    (λ, [-half, +half]),
    (ν, [-half, +half]),
)
display(intensity_expr)

In [None]:
# Express each aligned sub-system amplitude A^k[Λc helicity, p helicity] in
# terms of the decay chains of sub-system k.
A = {1: A1, 2: A2, 3: A3}
amp_definitions = {}
for subsystem in range(1, 4):
    # The expression depends on the helicities only through the symbols ν and
    # λ, so formulate it once per sub-system. The old code re-formulated it
    # inside the helicity loop. Then substitute each helicity combination.
    expr = formulate_subsystem_amplitude(subsystem, ν, λ)
    for Λc_heli, p_heli in product([-half, +half], [-half, +half]):
        symbol = A[subsystem][Λc_heli, p_heli]
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})
display_latex(amp_definitions)

## Parameter definitions

In [None]:
from polarization.lhcb import load_model_parameters

# Load the parameter values of model number 0. dict.update lets them override
# any defaults registered while formulating the amplitudes.
imported_parameter_values = load_model_parameters(
    "../data/modelparameters.json", model_number=0
)
parameter_defaults.update(imported_parameter_values)

### Helicity coupling values

#### Production couplings

In [None]:
# Production couplings: all H_prod-indexed entries among the parameter defaults.
# Presumably all of them come from load_model_parameters, since no formulate_*
# helper registers H_prod defaults.
prod_couplings = {
    key: value
    for key, value in parameter_defaults.items()
    if isinstance(key, sp.Indexed) and key.base == H_prod
}
display_latex(prod_couplings)

#### Decay couplings

In [None]:
# Decay helicity couplings. For each resonance, one coupling is fixed to 1. For
# Λ and Δ resonances, the helicity-flipped coupling follows from parity
# conservation: H_{-λa,-λb} = η_R η_a η_b (-1)^(J_R - j_a - j_b) H_{λa,λb}.
dec_couplings = {}
for chain in decay.chains:
    i = stringify(chain.resonance)
    # K resonances (→ π K): only the (0, 0) helicity coupling.
    if chain.resonance.name.startswith("K"):
        dec_couplings[H_dec[i, 0, 0]] = 1
    # Λ resonances (→ K p): couplings indexed as (K helicity, proton helicity).
    if chain.resonance.name.startswith("L"):
        dec_couplings[H_dec[i, 0, half]] = 1
        dec_couplings[H_dec[i, 0, -half]] = (
            int(chain.resonance.parity)
            * int(K.parity)
            * int(p.parity)
            * (-1) ** (chain.resonance.spin - K.spin - p.spin)
        )
    # Δ resonances (→ p π): couplings indexed as (proton helicity, π helicity).
    if chain.resonance.name.startswith("D"):
        dec_couplings[H_dec[i, half, 0]] = 1
        dec_couplings[H_dec[i, -half, 0]] = (
            int(chain.resonance.parity)
            * int(p.parity)
            * int(π.parity)
            * (-1) ** (chain.resonance.spin - p.spin - π.spin)
        )
# Overrides any existing defaults for these coupling keys.
parameter_defaults.update(dec_couplings)
display_latex(dec_couplings)

### Non-coupling parameters

In [None]:
# Display every remaining default, i.e. everything that is neither a production
# nor a decay coupling (e.g. masses, widths, radii, line-shape parameters).
couplings = set(dec_couplings) | set(prod_couplings)
display_latex({k: v for k, v in parameter_defaults.items() if k not in couplings})