In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Subset of the documentation-build environment variables that are currently set;
# presumably a non-empty set marks a non-interactive docs build (TODO confirm).
STATIC_WEB_PAGE = set(os.environ) & {"EXECUTE_NB", "READTHEDOCS"}

```{autolink-concat}
```

# [TR-017] Polarization sensitivity

```{warning}
This report is a work in progress.
```

```{autolink-skip}
```

In [None]:
%pip -q install ampform==0.13.3 qrules==0.9.7 sympy==1.10.1 tensorwaves[jax]==0.4.3

This report is an attempt to formulate the equations of [this Overleaf document](https://www.overleaf.com/7229968911cjshysdbfjtj) (requires login) on polarization sensitivity in $\Lambda_c \to p\pi K$ with [SymPy](https://docs.sympy.org) and [TensorWaves](https://tensorwaves.rtfd.io).

In [None]:
from __future__ import annotations

import itertools
import logging

import graphviz
import qrules
import sympy as sp
from ampform.sympy import PoolSum
from IPython.display import Math, display
from qrules.io import asdot
from qrules.particle import ParticleCollection, create_particle
from symplot import substitute_indexed_symbols
from sympy.physics.quantum.spin import Rotation as Wigner

# Root logger: only records at level ERROR or above are emitted
LOGGER = logging.getLogger()
LOGGER.setLevel("ERROR")

# Particle database loaded through qrules
PDG = qrules.load_pdg()


def display_definitions(definitions: dict[sp.Symbol, sp.Expr]) -> None:
    """Render a mapping of symbols to expressions as an aligned LaTeX array.

    Every item becomes one row ``lhs & = & rhs`` of a three-column
    ``array`` (right-, centre- and left-aligned), shown via IPython's
    :class:`~IPython.display.Math`.

    Args:
        definitions: Symbols (left-hand sides) mapped to their definitions.
    """
    rows = [
        Rf"  {sp.latex(symbol)} & = & {sp.latex(expr)} \\" + "\n"
        for symbol, expr in definitions.items()
    ]
    latex = R"\begin{array}{rcl}" + "\n" + "".join(rows) + R"\end{array}"
    display(Math(latex))

## Decay visualization

::::{margin}

:::{tip}

Particle properties of $\Delta^{**}$, $\Lambda^{**}$, and $K^{**}$ are determined here.

:::

::::

In [None]:
# Final-state particles, taken as-is from the PDG database
p = PDG["p"]
K = PDG["K-"]
π = PDG["pi+"]

# Initial state and resonances: copies of PDG entries with custom names and
# LaTeX labels, so that they render nicely in the generated transitions
Λc = create_particle(PDG["Lambda(c)+"], name="Λc⁺")
# NOTE(review): K*(892)0 has strangeness +1 and decays to K⁺π⁻. The K⁻π⁺ pair
# of this final state would rather come from its antiparticle (presumably
# "K*(892)~0" in the PDG table). Confirm that the interaction type qrules
# assigns to this node is what is intended.
K_star = create_particle(PDG["K*(892)0"], name="K*", latex="K^*")
Λ_star = create_particle(PDG["Lambda(1405)"], name="Λ**", latex=R"\Lambda^{**}")
Δ_star = create_particle(PDG["Delta(1232)++"], name="Δ**", latex=R"\Delta^{**}")

# Restricted database: only these particles are available to qrules below
particle_db = ParticleCollection({Λc, p, K, π, K_star, Λ_star, Δ_star})

In [None]:
# Transitions Λc⁺ → p π⁺ K⁻, searched within the restricted particle database
reaction = qrules.generate_transitions(
    particle_db=particle_db,
    initial_state="Λc⁺",
    final_state=["p", "pi+", "K-"],
)

In [None]:
for group in reaction.transition_groups:
    # One collapsed graph per group of transitions
    group_dot = asdot(group.transitions, collapse_graphs=True, size=3.6)
    display(graphviz.Source(group_dot))

Allowed $LS$-couplings:

In [None]:
def filter_transitions(resonance_name, transitions=None):
    """Select transitions whose intermediate resonance name starts with a prefix.

    Args:
        resonance_name: Prefix matched against the name of the particle on
            state ID 3, e.g. ``"Λ"``, ``"Δ"`` or ``"K"``.
        transitions: Transitions to filter. Defaults to
            ``reaction.transitions`` from the cell above, so existing calls
            keep working. Pass an explicit collection to reuse this helper
            without relying on that global.

    Returns:
        A list of the matching transitions, in their original order.
    """
    if transitions is None:
        transitions = reaction.transitions
    return [
        transition
        for transition in transitions
        # ID 3 is presumably the single intermediate edge of the three-body
        # topology (final states being IDs 0-2) — TODO confirm
        if transition.states[3].particle.name.startswith(resonance_name)
    ]


style = {"render_node": True, "strip_spin": True, "size": 6}
# One graph per resonance family, selected by the first letter of its name
display(
    *(
        graphviz.Source(asdot(filter_transitions(prefix), **style))
        for prefix in ("Λ", "Δ", "K")
    )
)

## Amplitude model

### SymPy implementation of equations

#### Equation (1)

In [None]:
# Aligned amplitudes of the three decay chains, indexed as [ν', λ']
A_k, A_l, A_d = (
    sp.IndexedBase(name) for name in (R"A^K", R"A^{\Lambda}", R"A^{\Delta}")
)

# Summation indices of Equation (1): helicities ν' (Λc) and λ' (proton)
_nu, _lambda = (
    sp.Symbol(name, rational=True) for name in (R"\nu^{\prime}", R"\lambda^{\prime}")
)
half = sp.S.Half

# Angles ζ^i_{k(1)} passed to the Wigner-d functions of Equation (1):
# i = 0 enters the Λc-helicity rotation, i = 1 the proton-helicity rotation
zeta_0_1, zeta_0_2, zeta_0_3, zeta_1_1, zeta_1_2, zeta_1_3 = (
    sp.Symbol(Rf"\zeta^{i}_{{{k}(1)}}", real=True)
    for i in (0, 1)
    for k in (1, 2, 3)
)


def formulate_aligned_amplitude(Λc_helicity, p_helicity):
    """Formulate the aligned amplitude of Equation (1) for the given helicities.

    Each chain amplitude (K, Λ, Δ) is multiplied by a spin-1/2 Wigner-d
    function rotating the Λc helicity and one rotating the proton helicity.
    The result is summed over the intermediate helicities ν' and λ'.

    Args:
        Λc_helicity: Helicity of the Λc.
        p_helicity: Helicity of the proton.

    Returns:
        An unevaluated ``PoolSum`` over λ' and ν' in {-1/2, +1/2}.
    """
    return PoolSum(
        A_k[_nu, _lambda]
        * Wigner.d(half, Λc_helicity, _nu, zeta_0_1)
        * Wigner.d(half, _lambda, p_helicity, zeta_1_1)
        # NOTE(review): the Λ chain reuses the ζ_{1(1)} angles of the K chain,
        # while zeta_0_2 / zeta_1_2 are defined but never used. This presumably
        # should be ζ_{2(1)}; confirm against Equation (1). Fixing it adds two
        # free symbols, so any hard-coded symbol counts further down have to
        # be revisited together with it.
        + A_l[_nu, _lambda]
        * Wigner.d(half, Λc_helicity, _nu, zeta_0_1)
        * Wigner.d(half, _lambda, p_helicity, zeta_1_1)
        + A_d[_nu, _lambda]
        * Wigner.d(half, Λc_helicity, _nu, zeta_0_3)
        * Wigner.d(half, _lambda, p_helicity, zeta_1_3),
        (_lambda, [-half, +half]),
        (_nu, [-half, +half]),
    )

In [None]:
# Symbolic external helicities: ν for the Λc, λ for the proton
nu, lam = sp.symbols("nu lambda")
formulate_aligned_amplitude(Λc_helicity=nu, p_helicity=lam)

#### Equations (2-4)

In [None]:
# Production (Λc → resonance + bachelor) and decay couplings of each chain
H_K_prod = sp.IndexedBase(R"\mathcal{H}^{\Lambda_c \to K^{**}p}")
H_K_dec = sp.IndexedBase(R"\mathcal{H}^{K^{**} \to \pi K}")
H_Λ_prod = sp.IndexedBase(R"\mathcal{H}^{\Lambda_c \to \Lambda^{**}\pi}")
H_Λ_dec = sp.IndexedBase(R"\mathcal{H}^{\Lambda^{**} \to K p}")
H_Δ_prod = sp.IndexedBase(R"\mathcal{H}^{\Lambda_c \to \Delta^{**}K}")
H_Δ_dec = sp.IndexedBase(R"\mathcal{H}^{\Delta^{**} \to p \pi}")

# Angles passed to the Wigner-d functions of the K, Λ and Δ chains respectively
theta23, theta31, theta12 = sp.symbols("theta23 theta31 theta12", real=True)

# Summation indices of Equations (2-4): resonance spin j and helicity τ
_j, _tau = sp.symbols("j tau", rational=True)


def formulate_K_amplitude(Λc_helicity, p_helicity, j_values: list):
    """Formulate the K** chain amplitude (one of Equations (2-4)).

    Args:
        Λc_helicity: Helicity of the Λc.
        p_helicity: Helicity of the proton.
        j_values: Spins of the K** resonances to sum over; converted to
            SymPy rationals.

    Returns:
        An unevaluated ``PoolSum`` over ``j`` in ``j_values`` and τ in
        {-1/2, +1/2}.
    """
    j_values = list(map(sp.Rational, j_values))
    return PoolSum(
        # NOTE(review): with τ and the proton helicity both in {-1/2, +1/2},
        # τ - λ is an integer. This KroneckerDelta therefore never matches a
        # half-integer Λc helicity, and the chain evaluates to zero. τ
        # presumably has to run over the (integer) K** helicities instead, and
        # the Wigner-d below may need τ rather than the proton helicity as its
        # first index. Confirm against the paper before changing.
        sp.KroneckerDelta(Λc_helicity, _tau - p_helicity)
        * H_K_prod[_tau, p_helicity]
        * (-1) ** (half - p_helicity)
        * Wigner.d(_j, p_helicity, 0, theta23)
        * H_K_dec[0, 0],
        (_j, j_values),
        (_tau, [-half, +half]),
    )


def formulate_Λ_amplitude(Λc_helicity, p_helicity, j_values: list):
    """Return the Λ** chain amplitude (one of Equations (2-4)) as a ``PoolSum``.

    The sum runs over the spins in ``j_values`` (as SymPy rationals) and over
    τ in {-1/2, +1/2}; a Kronecker delta ties τ to ``Λc_helicity``.
    """
    spins = [sp.Rational(j) for j in j_values]
    summand = (
        sp.KroneckerDelta(Λc_helicity, _tau)
        * H_Λ_prod[_tau, 0]
        * Wigner.d(_j, _tau, -p_helicity, theta31)
        * H_Λ_dec[0, 0]
        * (-1) ** (_j - p_helicity)
    )
    return PoolSum(summand, (_j, spins), (_tau, [-half, +half]))


def formulate_Δ_amplitude(Λc_helicity, p_helicity, j_values: list):
    """Return the Δ** chain amplitude (one of Equations (2-4)) as a ``PoolSum``.

    The sum runs over the spins in ``j_values`` (as SymPy rationals) and over
    τ in {-1/2, +1/2}; a Kronecker delta ties τ to ``Λc_helicity``.
    """
    spins = [sp.Rational(j) for j in j_values]
    summand = (
        sp.KroneckerDelta(Λc_helicity, _tau)
        * H_Δ_prod[_tau, 0]
        * Wigner.d(_j, _tau, p_helicity, theta12)
        * H_Δ_dec[0, 0]
    )
    return PoolSum(summand, (_j, spins), (_tau, [-half, +half]))

In [None]:
# Each chain with the spin of its resonance, for symbolic helicities ν and λ
display(
    *(
        formulate(nu, lam, j_values=[resonance.spin])
        for formulate, resonance in (
            (formulate_K_amplitude, K_star),
            (formulate_Λ_amplitude, Λ_star),
            (formulate_Δ_amplitude, Δ_star),
        )
    )
)

#### Equations (5-7)

In [None]:
# Equations (6-7)
couplings = {
    H_Λ_dec[0, half]: 1,
    H_Δ_dec[half, 0]: 1,
    H_K_dec[0, 0]: 1,
    H_Λ_dec[0, -half]: -Λ_star.parity
    * (-1) ** sp.Rational(Λ_star.spin - 1 / 2),
    H_Δ_dec[-half, 0]: -Δ_star.parity
    * (-1) ** sp.Rational(Δ_star.spin - 1 / 2),
}
# Equation (5)
couplings.update({})  # todo
couplings = {
    substitute_indexed_symbols(s): expr for s, expr in couplings.items()
}
display_definitions(couplings)

:::{todo}

Need to implement the remaining coupling definitions from Equation (5).

:::

### Combining all definitions

Incoherent sum of the amplitudes defined by {ref}`report/017:Equation (1)`:

In [None]:
# Incoherent sum over the external helicities. ν labels the Λc and λ the
# proton, matching the Equation (1) display above and the amplitude
# definitions below (the arguments were previously swapped, which mislabelled
# the rendered formula without changing its evaluated value).
top_expr = PoolSum(
    sp.Abs(formulate_aligned_amplitude(Λc_helicity=nu, p_helicity=lam)) ** 2,
    (lam, [-half, +half]),
    (nu, [-half, +half]),
)
top_expr

Apart from the $\zeta$ angles, the remaining {attr}`~sympy.core.basic.Basic.free_symbols` are the specific amplitudes defined by {ref}`report/017:Equations (2-4)`:

In [None]:
evaluated_top_expr = substitute_indexed_symbols(top_expr.doit())
free_symbols = sorted(evaluated_top_expr.free_symbols, key=str)
# Lay the symbols out in the widest grid with at most 4 columns that fits them
# exactly. A hard-coded 4×4 shape raises as soon as the model gains or loses a
# free symbol; for the current 16 symbols this still gives 4×4.
n_cols = next(n for n in (4, 3, 2, 1) if len(free_symbols) % n == 0)
sp.Matrix(free_symbols).reshape(len(free_symbols) // n_cols, n_cols)

The specific amplitudes from {ref}`report/017:Equations (2-4)` need to be formulated for each value of $\nu, \lambda$, so that they can be substituted in the top expression:

In [None]:
# Define A[ν, λ] of every chain for all four helicity combinations. Each chain
# amplitude is formulated once with symbolic ν, λ (this used to be recomputed
# in every iteration of three copy-pasted loops) and then specialised by
# substitution. Insertion order (Δ, Λ, K) is unchanged.
amp_definitions = {}
for amp_base, formulate, resonance in [
    (A_d, formulate_Δ_amplitude, Δ_star),
    (A_l, formulate_Λ_amplitude, Λ_star),
    (A_k, formulate_K_amplitude, K_star),
]:
    generic_expr = formulate(nu, lam, j_values=[resonance.spin])
    for Λc_heli, p_heli in itertools.product([-half, +half], repeat=2):
        symbol = substitute_indexed_symbols(amp_base[Λc_heli, p_heli])
        amp_definitions[symbol] = generic_expr.subs({nu: Λc_heli, lam: p_heli})
display_definitions(amp_definitions)

Apart from the $\zeta$ and $\theta$ angles, the remaining {attr}`~sympy.core.basic.Basic.free_symbols` are the couplings defined by {ref}`report/017:Equations (5-7)`:

In [None]:
amp_definitions_eval = {
    s: substitute_indexed_symbols(expr.doit())
    for s, expr in amp_definitions.items()
}
defined_top_expr = evaluated_top_expr.subs(amp_definitions_eval)
free_symbols = sorted(defined_top_expr.free_symbols, key=str)
# Widest exact grid with at most 3 columns. This replaces a hard-coded 4×3
# shape that raises when the symbol count changes; 12 symbols still give 4×3.
n_cols = next(n for n in (3, 2, 1) if len(free_symbols) % n == 0)
sp.Matrix(free_symbols).reshape(len(free_symbols) // n_cols, n_cols)

In [None]:
amplitude_expr = defined_top_expr.subs(couplings)
free_symbols = sorted(amplitude_expr.free_symbols, key=str)
# The symbol count shrinks once the missing Equation (5) couplings are added,
# so pick the widest exact grid with at most 3 columns instead of a
# hard-coded 4×3 (which the current 12 symbols still produce).
n_cols = next(n for n in (3, 2, 1) if len(free_symbols) % n == 0)
sp.Matrix(free_symbols).reshape(len(free_symbols) // n_cols, n_cols)

:::{todo}

Some coupling values are still missing, which leaves some undefined {attr}`~sympy.core.basic.Basic.free_symbols`.

:::

In [None]:
# Render the full expression over multiple aligned lines
amplitude_latex = sp.multiline_latex(
    sp.Symbol("A"), amplitude_expr, environment="eqnarray"
)
Math(amplitude_latex)

## Polarization

:::{todo}

Formulate Section 2 with [SymPy](https://docs.sympy.org).

:::

## Computations with TensorWaves


The full [expression tree](https://docs.sympy.org/latest/tutorial/manipulation.html) can be converted to a computational function as follows. Note that _all_ {attr}`~sympy.core.basic.Basic.free_symbols` become arguments to the function:

In [None]:
from tensorwaves.function.sympy import create_function

# Convert the full expression to a computational function with the JAX backend;
# every free symbol becomes an argument, in the order shown by argument_order
func = create_function(amplitude_expr, backend="jax")
func.argument_order

Optionally, some {class}`~sympy.core.symbol.Symbol`s can be identified as (fit) parameters:

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

# The ζ angles of Equation (1) are treated as parameters, each defaulting to 0
parameter_defaults = dict.fromkeys([zeta_0_1, zeta_0_3, zeta_1_1, zeta_1_3], 0)
par_func = create_parametrized_function(
    amplitude_expr,
    parameters=parameter_defaults,
    backend="jax",
)

:::{todo}

Visualize the $I(\alpha, \beta, \gamma, m_{K\pi}, m_{pK})$ distribution, potentially with [`ipywidgets`](https://ipywidgets.readthedocs.io).

:::