In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Subset of the names EXECUTE_NB / READTHEDOCS that are set as environment
# variables; empty (and therefore falsy) when neither is present.
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

```{autolink-concat}
```

# [TR-017] Polarization sensitivity

<!-- cspell:ignore mmikhasenko Remco -->

:::{epigraph}

Mikhail Mikhasenko [@mmikhasenko](https://github.com/mmikhasenko), Remco de Boer [@redeboer](https://github.com/redeboer)

:::



```{warning}
This report is Work-in-Progress.
```

```{autolink-skip}
```

In [None]:
%pip -q install ampform==0.14.0 qrules==0.9.7 sympy==1.10.1 tensorwaves[jax,pwa]==0.4.5

This report is an attempt to formulate the equations of [this report](https://www.overleaf.com/7229968911cjshysdbfjtj) [behind login] on polarization sensitivity in $\Lambda_c \to p\pi K$ with [SymPy](https://docs.sympy.org) and [TensorWaves](https://tensorwaves.rtfd.io).

In [None]:
from __future__ import annotations

import itertools
import logging

import graphviz
import matplotlib.pyplot as plt
import numpy as np
import qrules
import sympy as sp
from ampform.sympy import (
    PoolSum,
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
    make_commutative,
)
from attrs import frozen
from IPython.display import HTML, Math, display
from matplotlib.colors import LogNorm
from qrules.io import asdot
from qrules.particle import Particle, ParticleCollection, create_particle
from sympy.physics.quantum.spin import Rotation as Wigner

# Only show messages of level ERROR and above on the root logger.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)

# Particle database loaded through qrules; indexed by particle name below.
PDG = qrules.load_pdg()


def display_definitions(definitions: dict[sp.Symbol, sp.Expr]) -> None:
    """Display a mapping of symbols to expressions as aligned equations.

    Each key and value is passed through `sympy.sympify`, rendered with
    `sympy.latex`, and placed on its own row of a LaTeX ``array``.
    """
    rows = []
    for lhs, rhs in definitions.items():
        lhs_latex = sp.latex(sp.sympify(lhs))
        rhs_latex = sp.latex(sp.sympify(rhs))
        rows.append(Rf"  {lhs_latex} & = & {rhs_latex} \\")
    lines = [R"\begin{array}{rcl}", *rows, R"\end{array}"]
    display(Math("\n".join(lines)))


def display_doit(
    expr: UnevaluatedExpression, deep=False, terms_per_line: int = 10
) -> None:
    """Display ``expr`` next to its evaluated form as a multi-line equation.

    The right-hand side is ``expr.doit(deep=deep)``; ``terms_per_line`` sets
    how many terms are put on one line of the ``eqnarray`` environment.
    """
    evaluated = expr.doit(deep=deep)
    equation = sp.multiline_latex(
        expr,
        evaluated,
        terms_per_line=terms_per_line,
        environment="eqnarray",
    )
    display(Math(equation))

## Decay visualization

::::{margin}
:::{tip}
Particle properties of $\Delta^{**}$, $\Lambda^{**}$, and $K^{**}$ are determined here.
:::
::::

In [None]:
# Final-state particles, taken unchanged from the PDG table.
p = PDG["p"]
K = PDG["K-"]
π = PDG["pi+"]
# Initial state: based on the PDG entry Lambda(c)+, with the name that is used
# as initial state when generating transitions.
Λc = create_particle(
    PDG["Lambda(c)+"],
    name="Λc⁺",
)
# Intermediate resonances: each is based on one PDG entry and gets a short
# name and LaTeX label.
K_star = create_particle(
    PDG["K*(892)0"],
    name="K*",
    latex="K^*",
)
Λ_star = create_particle(
    PDG["Lambda(1520)"],
    name="Λ**",
    latex=R"\Lambda^{**}",
)
Δ_star = create_particle(
    PDG["Delta(1232)++"],
    name="Δ**",
    latex=R"\Delta^{**}",
)
# Restricted particle database containing only the states defined above.
particle_db = ParticleCollection({Λc, p, K, π, K_star, Λ_star, Δ_star})

In [None]:
# Generate all transitions Λc⁺ → p π+ K- that qrules finds with the restricted
# particle database, using the canonical formalism.
reaction = qrules.generate_transitions(
    initial_state="Λc⁺",
    final_state=["p", "pi+", "K-"],
    particle_db=particle_db,
    formalism="canonical",
)

In [None]:
# Draw one collapsed graph per group of transitions.
for group in reaction.transition_groups:
    dot_source = asdot(group.transitions, collapse_graphs=True, size=3.6)
    display(graphviz.Source(dot_source))

Allowed $LS$-couplings:

In [None]:
def filter_transitions(resonance_name):
    """Select transitions whose state 3 has a name starting with the prefix.

    Reads the transitions from the module-level ``reaction``.
    """
    selected = []
    for transition in reaction.transitions:
        particle = transition.states[3].particle
        if particle.name.startswith(resonance_name):
            selected.append(transition)
    return selected


# One graph per resonance type, selected by the first letter of its name.
style = {"render_node": True, "strip_spin": True, "size": 6}
display(
    *(
        graphviz.Source(asdot(filter_transitions(prefix), **style))
        for prefix in ("Λ", "Δ", "K")
    )
)

:::{warning}
Selection rules are not correctly applied here, see [ComPWA/qrules#171](https://github.com/ComPWA/qrules/issues/171).
:::

## Amplitude model

### SymPy implementation of equations

#### Aligned amplitude

In [None]:
# Indexed symbols for the amplitudes of the three subsystems; they are indexed
# with two helicities in formulate_aligned_amplitude.
A_K = sp.IndexedBase(R"A^K")
A_Λ = sp.IndexedBase(R"A^{\Lambda}")
A_Δ = sp.IndexedBase(R"A^{\Delta}")

half = sp.S.Half

# Real-valued angles used as arguments of the Wigner d-functions in the
# aligned amplitude. The variable name ζ_i_jk mirrors the LaTeX \zeta^i_{j(k)}.
ζ_0_11 = sp.Symbol(R"\zeta^0_{1(1)}", real=True)
ζ_0_21 = sp.Symbol(R"\zeta^0_{2(1)}", real=True)
ζ_0_31 = sp.Symbol(R"\zeta^0_{3(1)}", real=True)
ζ_1_11 = sp.Symbol(R"\zeta^1_{1(1)}", real=True)
ζ_1_21 = sp.Symbol(R"\zeta^1_{2(1)}", real=True)
ζ_1_31 = sp.Symbol(R"\zeta^1_{3(1)}", real=True)


def formulate_aligned_amplitude(λ_Λc, λ_p):
    """Sum the three subsystem amplitudes, each rotated by two Wigner-d functions.

    For every subsystem, the amplitude ``A[ν', λ']`` is multiplied by
    ``d(1/2, λ_Λc, ν', ζ⁰)`` and ``d(1/2, λ', λ_p, ζ¹)``. The result is an
    unevaluated `PoolSum` over ``ν'`` and ``λ'``, each running over ±1/2.
    """
    _ν = sp.Symbol(R"\nu^{\prime}", rational=True)
    _λ = sp.Symbol(R"\lambda^{\prime}", rational=True)
    subsystems = [
        (A_K, ζ_0_11, ζ_1_11),
        (A_Λ, ζ_0_21, ζ_1_21),
        (A_Δ, ζ_0_31, ζ_1_31),
    ]
    summand = sp.Add(
        *[
            amplitude[_ν, _λ]
            * Wigner.d(half, λ_Λc, _ν, ζ0)
            * Wigner.d(half, _λ, λ_p, ζ1)
            for amplitude, ζ0, ζ1 in subsystems
        ]
    )
    helicities = [-half, +half]
    return PoolSum(
        summand,
        (_λ, helicities),
        (_ν, helicities),
    )


# Helicity symbols: ν is passed as the Λc helicity and λ as the proton helicity.
ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)

#### Dynamics

In [None]:
@make_commutative
@implement_doit_method
class BlattWeisskopf(UnevaluatedExpression):
    """Form factor ``F_L(z)``, defined piecewise for ``L`` equal to 0, 1, or 2."""

    def __new__(cls, z, L, **hints):
        return create_expression(cls, z, L, **hints)

    def evaluate(self):
        """Return a `Piecewise` with one square-root branch per supported ``L``."""
        z, L = self.args
        # (L value, squared form factor); no branch is defined for other L.
        squared_factors = [
            (0, 1),
            (1, 1 / (1 + z**2)),
            (2, 1 / (9 + 3 * z**2 + z**4)),
        ]
        branches = []
        for l_val, squared in squared_factors:
            branches.append((sp.sqrt(squared), sp.Eq(L, l_val)))
        return sp.Piecewise(*branches)

    def _latex(self, printer, *args):
        z, L = (printer._print(arg) for arg in self.args)
        return Rf"F_{{{L}}}\left({z}\right)"


# Show the definition of the form factor for symbolic z and L.
z = sp.Symbol("z", positive=True)
L = sp.Symbol("L", integer=True, nonnegative=True)
latex = sp.multiline_latex(BlattWeisskopf(z, L), BlattWeisskopf(z, L).doit())
Math(latex)

In [None]:
@make_commutative
@implement_doit_method
class Källén(UnevaluatedExpression):
    """Källén function ``λ(x, y, z) = x² + y² + z² - 2xy - 2yz - 2zx``."""

    def __new__(cls, x, y, z, **hints):
        return create_expression(cls, x, y, z, **hints)

    def evaluate(self) -> sp.Expr:
        x, y, z = self.args
        squares = x**2 + y**2 + z**2
        cross_terms = sp.Add(-2 * x * y, -2 * y * z, -2 * z * x)
        return squares + cross_terms

    def _latex(self, printer, *args):
        x, y, z = (printer._print(arg) for arg in self.args)
        return Rf"\lambda\left({x}, {y}, {z}\right)"


# Show the definition for generic symbols (this rebinds the name `z`).
x, y, z = sp.symbols("x:z")
display_doit(Källén(x, y, z))

In [None]:
@make_commutative
@implement_doit_method
class P(UnevaluatedExpression):
    """Momentum ``p_s = sqrt(λ(s, mi², mj²)) / (2 sqrt(s))``."""

    def __new__(cls, s, mi, mj, **hints):
        return create_expression(cls, s, mi, mj, **hints)

    def evaluate(self):
        s, mi, mj = self.args
        källén = Källén(s, mi**2, mj**2)
        denominator = 2 * sp.sqrt(s)
        return sp.sqrt(källén) / denominator

    def _latex(self, printer, *args):
        # Only the first argument (s) is shown, as a subscript.
        subscript = printer._print(self.args[0])
        return Rf"p_{{{subscript}}}"


@make_commutative
@implement_doit_method
class Q(UnevaluatedExpression):
    """Momentum ``q_s = sqrt(λ(s, m0², mk²)) / (2 m0)``."""

    def __new__(cls, s, m0, mk, **hints):
        return create_expression(cls, s, m0, mk, **hints)

    def evaluate(self):
        s, m0, mk = self.args
        källén = Källén(s, m0**2, mk**2)
        # Unlike in P, the denominator contains m0 and not sqrt(s).
        denominator = 2 * m0
        return sp.sqrt(källén) / denominator

    def _latex(self, printer, *args):
        # Only the first argument (s) is shown, as a subscript.
        subscript = printer._print(self.args[0])
        return Rf"q_{{{subscript}}}"


# Show both momentum definitions for generic, non-negative symbols.
s, m0, mi, mj, mk = sp.symbols("s m0 m_i:k", nonnegative=True)
display_doit(P(s, mi, mj))
display_doit(Q(s, m0, mk))

In [None]:
R = sp.Symbol("R")
# Default values for symbols that are treated as parameters. The dictionary is
# extended by formulate_dynamics and by the helicity coupling cells.
parameter_defaults = {
    R: 5,  # GeV^{-1} (length factor)
}


@make_commutative
@implement_doit_method
class EnergyDependentWidth(UnevaluatedExpression):
    """Energy-dependent width ``Γ(s)`` of a resonance.

    Evaluates to ``Γ0 (p/p0)^(2L+1) (m0/sqrt(s)) F_L²(pR) / F_L²(p0 R)``, where
    ``p = P(s, m1, m2)``, ``p0 = P(m0², m1, m2)`` and ``F_L`` is
    `BlattWeisskopf`.
    """

    # **hints is accepted and forwarded so that the constructor signature
    # matches the other expression classes in this notebook.
    def __new__(cls, s, m0, Γ0, m1, m2, L, R, **hints):
        return create_expression(cls, s, m0, Γ0, m1, m2, L, R, **hints)

    def evaluate(self):
        s, m0, Γ0, m1, m2, L, R = self.args
        p = P(s, m1, m2)
        p0 = P(m0**2, m1, m2)
        ff = BlattWeisskopf(p * R, L) ** 2
        ff0 = BlattWeisskopf(p0 * R, L) ** 2
        # evaluate=False keeps the factors separate in the rendered expression.
        return sp.Mul(
            Γ0,
            (p / p0) ** (2 * L + 1),
            m0 / sp.sqrt(s),
            ff / ff0,
            evaluate=False,
        )

    def _latex(self, printer, *args) -> str:
        # Only the first argument (s) is shown.
        s = printer._print(self.args[0])
        return Rf"\Gamma\left({s}\right)"


# Show the width definition for generic symbols.
l_R = sp.Symbol("l_R", integer=True, positive=True)
m, Γ0, m1, m2 = sp.symbols("m Γ0 m1 m2", nonnegative=True)
display_doit(EnergyDependentWidth(s, m, Γ0, m1, m2, l_R, R))

In [None]:
@make_commutative
@implement_doit_method
class RelativisticBreitWigner(UnevaluatedExpression):
    """Breit-Wigner lineshape with momentum and form factors.

    Evaluates to the product of

    - ``(q/q0)^l_Λc`` and a `BlattWeisskopf` ratio in ``q`` with ``l_Λc_min``,
    - ``1 / (m0² - s - i m0 Γ(s))`` with ``Γ(s)`` an `EnergyDependentWidth`,
    - ``(p/p0)^l_R`` and a `BlattWeisskopf` ratio in ``p`` with ``l_R``.
    """

    # **hints is accepted and forwarded so that the constructor signature
    # matches the other expression classes in this notebook.
    def __new__(cls, s, m0, Γ0, m1, m2, l_R, l_Λc, l_Λc_min, R, **hints):
        return create_expression(
            cls, s, m0, Γ0, m1, m2, l_R, l_Λc, l_Λc_min, R, **hints
        )

    def evaluate(self):
        s, m0, Γ0, m1, m2, l_R, l_Λc, l_Λc_min, R = self.args
        # NOTE(review): Q is declared as Q(s, m0, mk) and divides by 2*m0, but
        # it receives m1 and m2 here, so m1 ends up in that denominator.
        # Confirm against the report that this is intended.
        q = Q(s, m1, m2)
        q0 = Q(m0**2, m1, m2)
        p = P(s, m1, m2)
        p0 = P(m0**2, m1, m2)
        width = EnergyDependentWidth(s, m0, Γ0, m1, m2, l_R, R)
        # evaluate=False keeps the factors separate in the rendered expression.
        return sp.Mul(
            (q / q0) ** l_Λc,
            BlattWeisskopf(q * R, l_Λc_min) / BlattWeisskopf(q0 * R, l_Λc_min),
            1 / (m0**2 - s - sp.I * m0 * width),
            (p / p0) ** l_R,
            BlattWeisskopf(p * R, l_R) / BlattWeisskopf(p0 * R, l_R),
            evaluate=False,
        )

    def _latex(self, printer, *args) -> str:
        # Only the first argument (s) is shown.
        s = printer._print(self.args[0])
        return Rf"\mathcal{{R}}\left({s}\right)"


# Show the lineshape definition for generic symbols.
l_Λc = sp.Symbol(R"l_{\Lambda_c}", integer=True, positive=True)
l_Λc_min = sp.Symbol(
    R"{l^{\mathrm{min}}_{\Lambda_c}}", integer=True, positive=True
)
display_doit(RelativisticBreitWigner(s, m, Γ0, m1, m2, l_R, l_Λc, l_Λc_min, R))

#### Unaligned amplitudes

In [None]:
# Helicity couplings per subsystem: "prod" for the Λc decay into a resonance
# and "dec" for the decay of that resonance, as spelled out in the LaTeX names.
H_K_prod = sp.IndexedBase(R"\mathcal{H}^{\Lambda_c \to K^{**}p}")
H_K_dec = sp.IndexedBase(R"\mathcal{H}^{K^{**} \to \pi K}")
H_Λ_prod = sp.IndexedBase(R"\mathcal{H}^{\Lambda_c \to \Lambda^{**}\pi}")
H_Λ_dec = sp.IndexedBase(R"\mathcal{H}^{\Lambda^{**} \to K p}")
H_Δ_prod = sp.IndexedBase(R"\mathcal{H}^{\Lambda_c \to \Delta^{**}K}")
H_Δ_dec = sp.IndexedBase(R"\mathcal{H}^{\Delta^{**} \to p \pi}")

# Angles used in the Wigner d-functions of the unaligned amplitudes.
θ23 = sp.Symbol("theta23", real=True)
θ31 = sp.Symbol("theta31", real=True)
θ12 = sp.Symbol("theta12", real=True)

# Mandelstam variables and final-state masses (m1 and m2 are rebound here to
# symbols with the same names and assumptions).
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
m1, m2, m3 = sp.symbols("m1:4", nonnegative=True)


# Immutable (attrs `frozen`) description of one decay chain via a resonance.
@frozen
class Decay:
    resonance: Particle  # its latex, mass, width, and spin are read elsewhere
    l_R: int  # forwarded as l_R to RelativisticBreitWigner
    l_Λc: int  # forwarded as l_Λc to RelativisticBreitWigner
    l_Λc_min: int  # forwarded as l_Λc_min to RelativisticBreitWigner


def formulate_dynamics(decay: Decay, s, m1, m2):
    """Create a `RelativisticBreitWigner` for the resonance of ``decay``.

    Mass and width symbols are named after the LaTeX label of the resonance.
    As a side effect, their values are registered in the module-level
    ``parameter_defaults``.
    """
    resonance = decay.resonance
    mass = sp.Symbol(f"m_{{{resonance.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{resonance.latex}}}")
    parameter_defaults.update({mass: resonance.mass, width: resonance.width})
    angular_momenta = [
        sp.Rational(value)
        for value in (decay.l_R, decay.l_Λc, decay.l_Λc_min)
    ]
    return RelativisticBreitWigner(s, mass, width, m1, m2, *angular_momenta, R)


def create_spin_range(j):
    """List the spin projections -j, -j+1, ..., +j as SymPy rationals."""
    j = float(j)
    return [sp.Rational(projection) for projection in np.arange(-j, j + 0.5)]


def formulate_K_amplitude(λ_Λc, λ_p, decays: list[Decay]):
    """Formulate the unaligned amplitude for the K** decay chains.

    Returns the sum over ``decays`` of a `PoolSum` over the resonance helicity
    ``τ``; the dynamics are formulated in ``σ1`` with masses ``m2`` and ``m3``.
    """
    τ = sp.Symbol("tau", rational=True)
    chains = []
    for dec in decays:
        spin = sp.Rational(dec.resonance.spin)
        helicity_sum = PoolSum(
            sp.KroneckerDelta(λ_Λc, τ - λ_p)
            * H_K_prod[τ, λ_p]
            * formulate_dynamics(dec, σ1, m2, m3)
            * (-1) ** (half - λ_p)
            * Wigner.d(spin, τ, 0, θ23)
            * H_K_dec[0, 0],
            (τ, create_spin_range(dec.resonance.spin)),
        )
        chains.append(helicity_sum)
    return sp.Add(*chains)


def formulate_Λ_amplitude(λ_Λc, λ_p, decays: list[Decay]):
    """Formulate the unaligned amplitude for the Λ** decay chains.

    Returns the sum over ``decays`` of a `PoolSum` over the resonance helicity
    ``τ``; the dynamics are formulated in ``σ2`` with masses ``m1`` and ``m3``.
    """
    τ = sp.Symbol("tau", rational=True)
    chains = []
    for dec in decays:
        spin = sp.Rational(dec.resonance.spin)
        helicity_sum = PoolSum(
            sp.KroneckerDelta(λ_Λc, τ)
            * H_Λ_prod[τ, 0]
            * formulate_dynamics(dec, σ2, m1, m3)
            * Wigner.d(spin, τ, -λ_p, θ31)
            * H_Λ_dec[0, λ_p]
            * (-1) ** (half - λ_p),
            (τ, create_spin_range(dec.resonance.spin)),
        )
        chains.append(helicity_sum)
    return sp.Add(*chains)


def formulate_Δ_amplitude(λ_Λc, λ_p, decays: list[Decay]):
    """Formulate the unaligned amplitude for the Δ** decay chains.

    Returns the sum over ``decays`` of a `PoolSum` over the resonance helicity
    ``τ``; the dynamics are formulated in ``σ3`` with masses ``m1`` and ``m2``.
    """
    τ = sp.Symbol("tau", rational=True)
    chains = []
    for dec in decays:
        spin = sp.Rational(dec.resonance.spin)
        helicity_sum = PoolSum(
            sp.KroneckerDelta(λ_Λc, τ)
            * H_Δ_prod[τ, 0]
            * formulate_dynamics(dec, σ3, m1, m2)
            * Wigner.d(spin, τ, λ_p, θ12)
            * H_Δ_dec[λ_p, 0],
            (τ, create_spin_range(dec.resonance.spin)),
        )
        chains.append(helicity_sum)
    return sp.Add(*chains)


# Show one example per subsystem. The positional Decay arguments after the
# resonance are l_R, l_Λc, l_Λc_min.
display(
    formulate_K_amplitude(ν, λ, decays=[Decay(K_star, 1, 0, 0)]),
    formulate_Λ_amplitude(ν, λ, decays=[Decay(Λ_star, 1, 1, 1)]),
    formulate_Δ_amplitude(ν, λ, decays=[Decay(Δ_star, 2, 1, 1)]),
)

#### Angle definitions

The following relations apply:

$$
\begin{eqnarray}
  \zeta^0_{1(1)} &=& \hat{\theta}_{1(1)}^{0} = 0 \\
  \zeta^0_{2(1)} &=& \hat{\theta}_{2(1)} = -\hat{\theta}_{1(2)} \\
  \zeta^0_{3(1)} &=& \hat{\theta}_{3(1)} \\
  \zeta^1_{1(1)} &=& 0 \\
  \zeta^1_{3(1)} &=& -\zeta^1_{1(3)} \\
\end{eqnarray}
$$

The remaining angles $\theta_{12}, \theta_{23}, \theta_{31}$ and $\hat\theta_{1(2)}, \hat\theta_{3(1)}, \zeta^1_{1(3)}$ can be expressed in terms of Mandelstam variables $\sigma_1, \sigma_2, \sigma_3$ using {cite}`mikhasenkoDalitzplotDecompositionThreebody2020`, Appendix A:

In [None]:
m0 = sp.Symbol("m0", nonnegative=True)
# Symbolic definition of every angle in terms of the Mandelstam variables
# σ1, σ2, σ3 and the masses m0, m1, m2, m3. Each non-zero angle is the arccosine
# of a ratio with a product of two square-rooted Källén functions as denominator.
angles = {
    θ12: sp.acos(
        (
            2 * σ3 * (σ2 - m3**2 - m1**2)
            - (σ3 + m1**2 - m2**2) * (m0**2 - σ3 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m3**2, σ3))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
    θ23: sp.acos(
        (
            2 * σ1 * (σ3 - m1**2 - m2**2)
            - (σ1 + m2**2 - m3**2) * (m0**2 - σ1 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ1, m2**2, m3**2))
        )
    ),
    θ31: sp.acos(
        (
            2 * σ2 * (σ1 - m2**2 - m3**2)
            - (σ2 + m3**2 - m1**2) * (m0**2 - σ2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(σ2, m3**2, m1**2))
        )
    ),
    ζ_0_11: sp.S.Zero,  # = \hat\theta^0_{1(1)}
    ζ_0_21: -sp.acos(  # = -\hat\theta_{1(2)}
        (
            (m0**2 + m1**2 - σ1) * (m0**2 + m2**2 - σ2)
            - 2 * m0**2 * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m2**2, σ2))
            * sp.sqrt(Källén(m0**2, σ1, m1**2))
        )
    ),
    ζ_0_31: sp.acos(  # = \hat\theta_{3(1)}
        (
            (m0**2 + m3**2 - σ3) * (m0**2 + m1**2 - σ1)
            - 2 * m0**2 * (σ2 - m3**2 - m1**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(m0**2, σ3, m3**2))
        )
    ),
    ζ_1_11: sp.S.Zero,
    ζ_1_21: sp.acos(
        (
            2 * m1**2 * (σ3 - m0**2 - m3**2)
            + (m0**2 + m1**2 - σ1) * (σ2 - m1**2 - m3**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ2, m1**2, m3**2))
        )
    ),
    ζ_1_31: -sp.acos(  # = -\zeta^1_{1(3)}
        (
            2 * m1**2 * (σ2 - m0**2 - m2**2)
            + (m0**2 + m1**2 - σ1) * (σ3 - m1**2 - m2**2)
        )
        / (
            sp.sqrt(Källén(m0**2, m1**2, σ1))
            * sp.sqrt(Källén(σ3, m1**2, m2**2))
        )
    ),
}

display_definitions(angles)

where $m_0$ is the mass of the initial state $\Lambda_c$ and $m_1, m_2, m_3$ are the masses of $p, \pi, K$, respectively:

In [None]:
# Mass values for the initial state (m0) and the final state (m1, m2, m3),
# taken from the particle definitions.
masses = {
    m0: Λc.mass,
    m1: p.mass,
    m2: π.mass,
    m3: K.mass,
}
display_definitions(masses)

### Combining all definitions

#### Intensity expression

Incoherent sum of the amplitudes defined by {ref}`report/017:Aligned amplitude`:

In [None]:
# Sum of |A|² over both helicities. The helicities are passed by keyword so
# that ν is the Λc helicity and λ the proton helicity, as everywhere else in
# this notebook. Both run over ±1/2, so the evaluated sum does not depend on
# which symbol plays which role.
top_expr = PoolSum(
    sp.Abs(formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)) ** 2,
    (λ, [-half, +half]),
    (ν, [-half, +half]),
)
top_expr

Remaining {attr}`~sympy.core.basic.Basic.free_symbols` are indeed the specific amplitudes as defined by {ref}`report/017:Unaligned amplitudes`:

The specific amplitudes from {ref}`report/017:Unaligned amplitudes` need to be formulated for each value of $\nu, \lambda$, so that they can be substituted in the top expression:

In [None]:
decay_chains = []
amp_definitions = {}

# One entry per subsystem: the indexed amplitude symbol, the function that
# formulates its unaligned amplitude, and the decay chains to include.
subsystems = [
    # K**
    (
        A_K,
        formulate_K_amplitude,
        [Decay(K_star, l_R=1, l_Λc=0, l_Λc_min=0)],
    ),
    # Lambda**
    (
        A_Λ,
        formulate_Λ_amplitude,
        [Decay(Λ_star, l_R=1, l_Λc=1, l_Λc_min=1)],
    ),
    # Delta**
    (
        A_Δ,
        formulate_Δ_amplitude,
        [Decay(Δ_star, l_R=2, l_Λc=1, l_Λc_min=1)],
    ),
]
helicity_pairs = list(itertools.product([-half, +half], [-half, +half]))
for amplitude_symbol, formulate_amplitude, decays in subsystems:
    decay_chains.extend(decays)
    # Define the amplitude for each combination of Λc and proton helicity.
    for Λc_heli, p_heli in helicity_pairs:
        symbol = amplitude_symbol[Λc_heli, p_heli]
        expr = formulate_amplitude(ν, λ, decays)
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})

display_definitions(amp_definitions)

In [None]:
def jp(particle: Particle):
    """Return spin and parity of a particle as an inline-math LaTeX string."""
    if particle.parity > 0:
        sign = "+"
    else:
        sign = "-"
    spin = sp.Rational(particle.spin)
    return Rf"\({spin}^{sign}\)"


def create_row(*items, typ="td"):
    """Wrap each item in a ``<typ>`` cell and return them as one HTML table row.

    The returned string ends with a newline.
    """
    cells = "".join(f"<{typ}>{item}</{typ}>" for item in items)
    return f"<tr>{cells}</tr>\n"


column_names = [
    "resonance",
    R"\(j^P\)",
    R"\(m\) (MeV)",
    R"\(\Gamma_0\) (MeV)",
    R"\(l_R\)",
    R"\(l^\mathrm{min}_{\Lambda_c}\)",
]
# Header row, followed by one row per decay chain. Mass and width are
# multiplied by 1e3 and truncated to integers for the MeV columns.
table_rows = [create_row(*column_names, typ="th")]
for dec in decay_chains:
    resonance = dec.resonance
    table_rows.append(
        create_row(
            rf"\({resonance.latex}\)",
            jp(resonance),
            int(1e3 * resonance.mass),
            int(1e3 * resonance.width),
            dec.l_R,
            dec.l_Λc_min,
        )
    )
src = "<table>\n" + "".join(table_rows) + "</table>\n"
HTML(src)

#### Helicity coupling values

In [None]:
# Values for the decay helicity couplings. The couplings with negative proton
# helicity get a sign computed from parity and spin of the resonance.
dec_couplings = {
    # Equations (6-7)
    H_Λ_dec[0, half]: 1,
    H_Δ_dec[half, 0]: 1,
    H_K_dec[0, 0]: 1,
    # NOTE(review): the exponent is a Python float (spin - 1 / 2), so these two
    # values are floats — confirm that this is acceptable downstream.
    H_Λ_dec[0, -half]: int(-Λ_star.parity) * (-1) ** (Λ_star.spin - 1 / 2),
    H_Δ_dec[-half, 0]: int(-Δ_star.parity) * (-1) ** (Δ_star.spin - 1 / 2),
}
parameter_defaults.update(dec_couplings)
display_definitions(dec_couplings)

In [None]:
# Complex values for the production helicity couplings.
prod_couplings = {
    H_K_prod[0, -half]: 1,
    H_K_prod[-1, -half]: 1 - 1j,
    H_K_prod[+1, +half]: -3 - 3j,
    H_K_prod[0, +half]: -1 - 4j,
    H_Δ_prod[+half, 0]: -7 + 3j,
    H_Δ_prod[-half, 0]: -13 + 5j,
    H_Λ_prod[+half, 0]: 1,
    H_Λ_prod[-half, 0]: 2j,
}
display_definitions(prod_couplings)
# Combined mapping of decay and production couplings.
couplings = dict(dec_couplings)
couplings.update(prod_couplings)
parameter_defaults.update(prod_couplings)

After substituting symbols in the expression with the angle definitions and the collected `parameter_defaults`, the remaining {attr}`~sympy.core.basic.Basic.free_symbols` are indeed the Mandelstam variables and the initial and final state `masses`:

In [None]:
# Expand the helicity sums, then insert the definitions of the amplitudes.
evaluated_top_expr = top_expr.doit()
amplitude_expr = evaluated_top_expr.xreplace(amp_definitions).doit()
# NOTE(review): this second pass looks redundant, because the values in
# amp_definitions do not appear to contain any of the A^K, A^Λ, A^Δ symbols —
# confirm before removing.
amplitude_expr = amplitude_expr.xreplace(amp_definitions).doit()
# Insert the angle definitions and the default parameter values in order to
# inspect which symbols remain.
substituted_amp = amplitude_expr.xreplace(angles).xreplace(parameter_defaults)
substituted_amp.free_symbols

In [None]:
# Expected: three Mandelstam variables (σ1, σ2, σ3) and four masses (m0, ..., m3).
assert len(substituted_amp.free_symbols) == 7

## Polarization

:::{todo}

Formulate Section 2 with [SymPy](https://docs.sympy.org).

:::

## Computations with TensorWaves


### Conversion to computational backend

The full [expression tree](https://docs.sympy.org/latest/tutorial/manipulation.html) can be converted to a computational, _parametrized_ function as follows. Note that all symbols in `parameter_defaults`, such as the couplings, are interpreted as parameters. The remaining symbols (the angles and Mandelstam variables) become arguments to the function.

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# Convert the intensity expression to a function on the JAX backend. The
# symbols in parameter_defaults become its parameters.
func = create_parametrized_function(
    amplitude_expr.subs(masses),  # hardcode the mass values m0, ..., m3
    parameters=parameter_defaults,
    backend="jax",
)

Generate phase space sample for $\Lambda_c \to p \pi K$:

In [None]:
from tensorwaves.data import (
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)

# A fixed seed makes the generated sample reproducible.
rng = TFUniformRealNumberGenerator(seed=0)
# Final state particles are numbered 1, 2, 3 in the order of m1, m2, m3.
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=masses[m0],
    final_state_masses={i: masses[m] for i, m in enumerate([m1, m2, m3], 1)},
)
phsp = phsp_generator.generate(1_000_000, rng)

Values for the angles will be computed from the Mandelstam values with a data transformer for the symbolic angle definitions:

In [None]:
from tensorwaves.data.transform import SympyDataTransformer

# Evaluate the Källén functions in the angle definitions and insert the mass
# values, so that only the Mandelstam variables remain as inputs.
kinematic_variables = {
    symbol: expression.doit().subs(masses)
    for symbol, expression in angles.items()
}
# Map the Mandelstam variables onto themselves, so that they are also part of
# the transformer output.
kinematic_variables.update({s: s for s in [σ1, σ2, σ3]})  # include identity
transformer = SympyDataTransformer.from_sympy(
    kinematic_variables, backend="numpy"
)

The three Mandelstam variables $\sigma_1, \sigma_2, \sigma_3$ are computed from the four-momenta in the phase space sample, after which the data transformer computes the angles:

In [None]:
def compute_mass_squared(array):
    """Compute ``E² - |p|²`` for each row of a two-dimensional array.

    Column 0 is used as the energy and all remaining columns as the momentum
    components.
    """
    energy_squared = np.square(array[:, 0])
    momentum_squared = np.square(array[:, 1:]).sum(axis=1)
    return energy_squared - momentum_squared


# σ_k is computed from the summed four-momenta of the two final state
# particles other than particle k. The keys are the names of the symbols.
data = {
    σ1.name: compute_mass_squared(phsp["p2"] + phsp["p3"]),
    σ2.name: compute_mass_squared(phsp["p3"] + phsp["p1"]),
    σ3.name: compute_mass_squared(phsp["p1"] + phsp["p2"]),
}
transformed_data = transformer(data)
transformed_data

### Intensity distributions

Finally, all intensities can be computed as follows:

```{autolink-skip}
```

In [None]:
%%time
# The transformation is repeated here, so that the timing covers both the
# computation of the angles and the evaluation of the intensity function.
transformed_data = transformer(data)
intensities = func(transformed_data)
intensities

In [None]:
# Axis labels, also used for the one-dimensional projections.
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

# Two-dimensional histogram of σ1 versus σ2, weighted with the intensities
# and drawn with a logarithmic color scale.
fig, ax = plt.subplots(figsize=(8, 6.5))
ax.set_title("Intensity distribution")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)
h = ax.hist2d(
    data["sigma1"],
    data["sigma2"],
    weights=np.array(intensities),
    bins=100,
    norm=LogNorm(),
)
# The fourth element returned by hist2d is the image used for the color bar.
fig.colorbar(h[3])

fig.tight_layout()
plt.show()

In [None]:
fig, axes = plt.subplots(figsize=(12, 4), ncols=3)
ax1, ax2, ax3 = axes
hist_style = dict(bins=100, weights=np.array(intensities), histtype="step")
line_style = dict(c="red", linestyle="dotted")

# One projection per Mandelstam variable: (axis, key in data, x-axis label,
# resonance expected in that subsystem).
projections = [
    (ax1, "sigma1", s1_label, K_star),
    (ax2, "sigma2", s2_label, Λ_star),
    (ax3, "sigma3", s3_label, Δ_star),
]
for ax, key, x_label, resonance in projections:
    ax.set_xlabel(x_label)
    ax.hist(data[key], **hist_style)
    # The x-axis shows a squared invariant mass, so the line is drawn at, and
    # labelled as, the squared mass of the resonance.
    ax.axvline(
        resonance.mass**2,
        label=f"$m_{{{resonance.latex}}}^2$",
        **line_style,
    )
    ax.legend()

fig.tight_layout()
plt.show()

### Polarization distributions

:::{todo}

Visualize $\vec{\alpha}(\alpha, \beta, \gamma, m_{K\pi}, m_{K\pi})$ with [`ipywidgets`](https://ipywidgets.readthedocs.io).

:::