# Wigner-D simplification

In [None]:
import logging

import ampform
import qrules
import sympy as sp
from ampform.helicity.align.dpd import DalitzPlotDecomposition, relabel_edge_ids
from ampform.io import aslatex
from ampform.kinematics.phasespace import Kallen, Kibble
from IPython.display import Math
from symplot import partial_doit

# Raise the threshold of the root logger so that only ERROR-level (and more
# severe) records are emitted while the notebook runs.
LOGGER = logging.getLogger(name=None)
LOGGER.setLevel(level=logging.ERROR)

In [None]:
# Enumerate the helicity-formalism transitions J/psi(1S) -> K0 Sigma+ p~,
# restricted to strong interactions through the two listed resonances.
reaction = qrules.generate_transitions(
    formalism="helicity",
    initial_state=("J/psi(1S)", [-1, +1]),  # spin projections of the J/psi
    final_state=["K0", "Sigma+", "p~"],
    allowed_interaction_types=["strong"],
    allowed_intermediate_particles=["Sigma(1660)", "N(1650)"],
)

In [None]:
# Relabel the edge IDs of the generated transitions (presumably so that the
# final state is numbered 1, 2, 3 as the DPD alignment expects — verify in
# ampform's `relabel_edge_ids` docs) and build an amplitude model from them.
reaction_123 = relabel_edge_ids(reaction)
builder_123 = ampform.get_builder(reaction_123)
# Align spins with the Dalitz-plot decomposition, using subsystem 1 as the
# reference subsystem.
builder_123.config.spin_alignment = DalitzPlotDecomposition(reference_subsystem=1)
# NOTE(review): presumably makes the initial-state mass a scalar parameter
# instead of a kinematic variable — confirm against ampform's BuilderConfig.
builder_123.config.scalar_initial_state_mass = True
builder_123.config.stable_final_state_ids = [1, 2, 3]
dpd_model = builder_123.formulate()
# Last expression of the cell: rendered as the cell output.
dpd_model.intensity

In [None]:
# Select the kinematic variables whose symbol name contains "zeta" (presumably
# the spin-alignment rotation angles of the DPD model) and render them as LaTeX.
dpd_angles = {
    symbol: definition
    for symbol, definition in dpd_model.kinematic_variables.items()
    if "zeta" in str(symbol)
}
src = aslatex(dpd_angles)
Math(src)

In [None]:
# Take the definitions of the kinematic variables whose symbol name starts
# with "m", in the dict's insertion order. Unpacking raises ValueError unless
# there are exactly three of them.
m12, m13, m23 = [
    definition
    for symbol, definition in dpd_model.kinematic_variables.items()
    if str(symbol).startswith("m")
]


def simplify_zeta(zeta: sp.Expr):
    r"""Return :math:`\sin^2\zeta` with a simplified, factorized numerator.

    Note that the result is the squared sine of ``zeta``, not ``zeta`` itself.

    Steps applied to the numerator of :math:`\sin^2\zeta`:

    1. unfold the `Kallen` functions it contains (via ``partial_doit``);
    2. replace ``m12**2``, ``m13**2`` and ``m23**2`` by the symbols
       :math:`\sigma_1, \sigma_2, \sigma_3`;
    3. eliminate :math:`\sigma_2` through
       :math:`\sigma_1 + \sigma_2 + \sigma_3 = m_0^2 + m_1^2 + m_2^2 + m_3^2`;
    4. expand and factor.

    The denominator is kept as-is. In the final expression the remaining
    :math:`\sigma` symbols are mapped back to the squared masses.

    This function reads the module-level ``m12``, ``m13`` and ``m23``, so the
    cell that defines them must have been run first.

    Args:
        zeta: A zeta-angle expression, e.g. one of the values of
            ``dpd_angles``.
    """
    sigma = sp.symbols("sigma0:4")
    masses = sp.symbols("m_0:4", nonnegative=True)

    sin_squared = sp.together(sp.sin(zeta)) ** 2
    numerator, denominator = sp.fraction(sin_squared)

    # sigma2 expressed through the other two invariants (sum rule above).
    sigma2_expr = sum(mass**2 for mass in masses) - sigma[1] - sigma[3]
    to_sigma = {m12**2: sigma[1], m13**2: sigma[2], m23**2: sigma[3]}
    from_sigma = {sigma[1]: m12**2, sigma[2]: m13**2, sigma[3]: m23**2}

    numerator = partial_doit(numerator, doit_classes=(Kallen,))
    numerator = numerator.xreplace(to_sigma)
    numerator = numerator.subs(sigma[2], sigma2_expr).expand().factor()
    return (numerator / denominator).xreplace(from_sigma)

In [None]:
# Simplified sin^2 of the first zeta angle (insertion order of `dpd_angles`).
simplify_zeta(tuple(dpd_angles.values())[0])

In [None]:
# Simplified sin^2 of the third zeta angle (insertion order of `dpd_angles`).
simplify_zeta(tuple(dpd_angles.values())[2])

In [None]:
def simplified_kibble():
    r"""Return the unfolded Kibble function with :math:`\sigma_2` eliminated.

    :math:`\sigma_2` is replaced through
    :math:`\sigma_1 + \sigma_2 + \sigma_3 = m_0^2 + m_1^2 + m_2^2 + m_3^2`,
    after which the expression is expanded and factorized.
    """
    sigma = sp.symbols("sigma0:4")
    masses = sp.symbols("m_0:4", nonnegative=True)
    sigma2_expr = sum(mass**2 for mass in masses) - sigma[1] - sigma[3]
    # NOTE(review): the sigma arguments are passed in reverse order
    # (sigma3, sigma2, sigma1) while the masses go in as m_0..m_3 — confirm
    # this matches the argument convention of `Kibble`.
    unfolded = Kibble(sigma[3], sigma[2], sigma[1], *masses).doit()
    return unfolded.subs(sigma[2], sigma2_expr).expand().factor()


simplified_kibble()