# Manually modelling amplitudes using `sympy`

In this notebook, we formulate the amplitude model for $\gamma p \to K^+ \pi^0 \Lambda$ symbolically, adapting the model that was originally built for the $\gamma p \to \eta\pi^0 p$ channel example described in [Reaction and Models](reaction-model.md).

The model we want to implement is

$$
\begin{array}{rcl}
I &=& \left|A^{12} + A^{23} + A^{31}\right|^2 \\
A^{12} &=& \frac{\sum a_m Y_2^m (\Omega_1)}{s_{12}-m^2_{K^{*+}_2}+im_{K^{*+}_2} \Gamma_{K^{*+}_2}} \\
A^{23} &=& \frac{\sum b_m Y_1^m (\Omega_2)}{s_{23}-m^2_{\Sigma^*}+im_{\Sigma^*} \Gamma_{\Sigma^*}} \\
A^{31} &=& \frac{c_0}{s_{31}-m^2_{N^{*+}}+im_{N^{*+}} \Gamma_{N^{*+}}} \,,
\end{array}
$$

where $1=K^+$, $2=\pi^0$, and $3=\Lambda$.

In [None]:
from __future__ import annotations

import logging
import os
import warnings
from collections import defaultdict

import ipywidgets as w
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.io import aslatex
from ampform.kinematics.angles import Phi, Theta
from ampform.kinematics.lorentz import (
    ArrayMultiplication,
    ArraySize,
    BoostZMatrix,
    Energy,
    EuclideanNorm,
    FourMomentumSymbol,
    RotationYMatrix,
    RotationZMatrix,
    ThreeMomentum,
    three_momentum_norm,
)
from ampform.sympy import unevaluated
from ampform.sympy._array_expressions import ArraySum
from IPython.display import SVG, Image, Latex, display
from tensorwaves.data import (
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)
from tensorwaves.function.sympy import create_parametrized_function

# True when the notebook runs with EXECUTE_NB set; the plotting cells then save
# static images instead of showing live widgets
STATIC_PAGE = "EXECUTE_NB" in os.environ

# Quieten TensorFlow's C++ logging, log records at WARNING level and below, and
# Python warnings.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL only takes effect if TensorFlow has not been
# loaded yet by the imports above — confirm.
os.environ.update({"TF_CPP_MIN_LOG_LEVEL": "3"})
logging.disable(logging.WARNING)
warnings.filterwarnings("ignore")

## Model implementation

In [None]:
l_max = 2  # highest l (and |m|) for which amplitudes and intensities are built

### $A^{12}$

In [None]:
s12, m_Kstar2, Gamma_Kstar2, l12 = sp.symbols(r"s_{12} m_{K^*_2} \Gamma_{K^*_2} l_{12}")
theta1, phi1 = sp.symbols("theta_1 phi_1")
a = sp.IndexedBase("a")
m = sp.symbols("m", cls=sp.Idx)

# Sum over m = -l12 ... l12 of a[m] Y_{l12}^m(theta1, phi1), divided by
# s12 - m^2 + i m Gamma
A12_denominator = s12 - m_Kstar2**2 + sp.I * m_Kstar2 * Gamma_Kstar2
A12 = sp.Sum(a[m] * sp.Ynm(l12, m, theta1, phi1), (m, -l12, l12)) / A12_denominator
A12

In [None]:
# Numerical A12 for each l12 = 0 ... l_max. Every coupling a[-l_max] ... a[l_max]
# is a positional argument; those with |m| > l12 do not appear in the sum and
# are simply ignored.
A12_args = [
    s12,
    *(a[j] for j in range(-l_max, l_max + 1)),
    m_Kstar2,
    Gamma_Kstar2,
    theta1,
    phi1,
]
A12_funcs = [
    sp.lambdify(A12_args, A12.subs(l12, ell).doit().expand(func=True))
    for ell in range(l_max + 1)
]
A12_funcs

### $A^{23}$

In [None]:
s23, m_Sigma, Gamma_Sigma, l23 = sp.symbols(
    r"s_{23} m_{\Sigma^{*+}} \Gamma_{\Sigma^{*+}} l_{23}"
)
theta2, phi2 = sp.symbols("theta_2 phi_2")
b = sp.IndexedBase("b")
m = sp.symbols("m", cls=sp.Idx)

# Sum over m = -l23 ... l23 of b[m] Y_{l23}^m(theta2, phi2), divided by
# s23 - m^2 + i m Gamma
A23_denominator = s23 - m_Sigma**2 + sp.I * m_Sigma * Gamma_Sigma
A23 = sp.Sum(b[m] * sp.Ynm(l23, m, theta2, phi2), (m, -l23, l23)) / A23_denominator
A23

In [None]:
# Numerical A23 for each l23 = 0 ... l_max; couplings b[m] with |m| > l23 are
# accepted as arguments but unused.
A23_args = [
    s23,
    *(b[j] for j in range(-l_max, l_max + 1)),
    m_Sigma,
    Gamma_Sigma,
    theta2,
    phi2,
]
A23_funcs = [
    sp.lambdify(A23_args, A23.subs(l23, ell).doit().expand(func=True))
    for ell in range(l_max + 1)
]
A23_funcs

### $A^{31}$

In [None]:
s31, m_Nstar, Gamma_Nstar = sp.symbols(r"s_{31} m_{N^*} \Gamma_{N^*}")
theta3, phi3, l31 = sp.symbols("theta_3 phi_3 l_{31}")
c = sp.IndexedBase("c")

# Same structure as A12/A23; reuses the summation index m from the earlier cells
A31_denominator = s31 - m_Nstar**2 + sp.I * m_Nstar * Gamma_Nstar
A31 = sp.Sum(c[m] * sp.Ynm(l31, m, theta3, phi3), (m, -l31, l31)) / A31_denominator
A31

In [None]:
# Numerical A31 for each l31 = 0 ... l_max; couplings c[m] with |m| > l31 are
# accepted as arguments but unused.
A31_args = [
    s31,
    *(c[j] for j in range(-l_max, l_max + 1)),
    m_Nstar,
    Gamma_Nstar,
    theta3,
    phi3,
]
A31_funcs = [
    sp.lambdify(A31_args, A31.subs(l31, ell).doit().expand(func=True))
    for ell in range(l_max + 1)
]
A31_funcs

### $I = |A|^2 = |A^{12}+A^{23}+A^{31}|^2$

In [None]:
# Coherent sum of the three decay chains, squared in absolute value
total_amplitude = A12 + A23 + A31
intensity_expr = sp.Abs(total_amplitude) ** 2
intensity_expr

### Phase Space Generation

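Invariant mass $m_0 = \sqrt{s}$ of the $\gamma p$ system for a photon beam of energy $E_\gamma$ hitting a proton at rest, $s = m_p^2 + 2 E_\gamma m_p$: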

In [None]:
# Lab-frame photon energy and particle masses, in GeV
E_lab_gamma = 8.5
m_proton, m_lambda, m_k, m_pi = 0.938, 1.12, 0.494, 0.135

# sqrt(s) of a photon hitting a proton at rest: s = m_p^2 + 2 E_gamma m_p
m_0 = np.sqrt(m_proton**2 + 2 * E_lab_gamma * m_proton)
m_0

In [None]:
# Seeded generator so that the phase-space sample is reproducible
rng = TFUniformRealNumberGenerator(seed=0)
# Final-state ids follow the labelling 1 = K+, 2 = pi0, 3 = Lambda
phsp_generator = TFPhaseSpaceGenerator(
    final_state_masses={1: m_k, 2: m_pi, 3: m_lambda},
    initial_state_mass=m_0,
)
phsp_momenta = phsp_generator.generate(500_000, rng)

### Kinematic variables

In [None]:
@unevaluated
class SquaredInvariantMass(sp.Expr):
    r"""Squared invariant mass :math:`E^2 - |\vec{p}|^2` of a four-momentum.

    :meth:`evaluate` expands it into ampform's ``Energy`` / ``ThreeMomentum``
    helpers; presumably ``@unevaluated`` keeps the compact form until ``doit()``.
    """

    momentum: sp.Basic  # four-momentum expression, e.g. an ArraySum of two momenta
    # LaTeX template; "{momentum}" is presumably filled in by @unevaluated, the
    # surrounding doubled braces are literal braces
    _latex_repr_ = "m_{{{momentum}}}^2"

    def evaluate(self) -> sp.Expr:
        # E^2 minus the squared Euclidean norm of the spatial part
        p = self.momentum
        p_xyz = ThreeMomentum(p)
        return Energy(p) ** 2 - EuclideanNorm(p_xyz) ** 2


def formulate_helicity_angles(
    pi: FourMomentumSymbol, pj: FourMomentumSymbol
) -> tuple[Theta, Phi]:
    """Return the polar and azimuthal angle of ``pi`` in the ``pi + pj`` frame.

    The momentum ``pi`` is rotated by R_z(-phi) and then R_y(-theta), where
    (theta, phi) are the angles of the pair momentum ``pi + pj``. It is then
    boosted along z with the pair velocity beta = |p| / E, which takes it to the
    pair rest frame under ampform's ``BoostZMatrix`` convention. The angles of
    the transformed ``pi`` are returned as ``(Theta, Phi)`` expressions.
    """
    pair = ArraySum(pi, pj)
    pair_phi = Phi(pair)
    pair_theta = Theta(pair)
    pair_beta = three_momentum_norm(pair) / Energy(pair)

    undo_phi = RotationZMatrix(-pair_phi, n_events=ArraySize(pair_phi))
    undo_theta = RotationYMatrix(-pair_theta, n_events=ArraySize(pair_theta))
    boost = BoostZMatrix(pair_beta, n_events=ArraySize(pair_beta))

    pi_in_pair_frame = ArrayMultiplication(boost, undo_theta, undo_phi, pi)
    return Theta(pi_in_pair_frame), Phi(pi_in_pair_frame)

In [None]:
# Symbolic four-momenta of the final-state particles 1 = K+, 2 = pi0, 3 = Lambda
p1, p2, p3 = (FourMomentumSymbol(name, shape=[]) for name in ("p1", "p2", "p3"))
p12, p23, p31 = ArraySum(p1, p2), ArraySum(p2, p3), ArraySum(p3, p1)

# Helicity angles of the first particle of each pair in that pair's frame
(theta1_expr, phi1_expr), (theta2_expr, phi2_expr), (theta3_expr, phi3_expr) = (
    formulate_helicity_angles(first, second)
    for first, second in ((p1, p2), (p2, p3), (p3, p1))
)

# Squared invariant masses of the three two-body subsystems
s12_expr, s23_expr, s31_expr = (
    SquaredInvariantMass(pair) for pair in (p12, p23, p31)
)

In [None]:
# Map each kinematic symbol of the model to its expression in terms of the
# four-momenta (polar angles, then azimuthal angles, then squared masses)
kinematic_variables = dict(
    zip(
        (theta1, theta2, theta3, phi1, phi2, phi3, s12, s23, s31),
        (
            theta1_expr,
            theta2_expr,
            theta3_expr,
            phi1_expr,
            phi2_expr,
            phi3_expr,
            s12_expr,
            s23_expr,
            s31_expr,
        ),
    )
)

Latex(aslatex(kinematic_variables))

In [None]:
# Compile the symbolic kinematic variables into a JAX-backed data transformer
helicity_transformer = SympyDataTransformer.from_sympy(kinematic_variables, backend="jax")

In [None]:
# Evaluate the kinematic variables on the generated phase-space momenta
phsp = helicity_transformer(phsp_momenta)
[*phsp]

### Parameters

In [None]:
# Coupling values for m = -l_max ... +l_max (list index m + l_max)
a_vals = [0, 1.0, 3.0, 3.5, 2.0]  # K*_2 couplings, emphasising higher waves
b_vals = [0, -1.0, 3.5, 0.5, 0]  # Sigma* couplings for the new final state
c_vals = [0, 0, 3.0, 0, 0]  # N* couplings: only m = 0 is non-zero

# Resonance masses and widths in GeV
m_Kstar2_val, Gamma_Kstar2_val = 1.43, 0.1  # K*_2(1430)
m_Sigma_val, Gamma_Sigma_val = 1.66, 0.1  # candidates: Sigma(1385), Sigma(1660), Sigma(1670)
m_Nstar_val, Gamma_Nstar_val = 1.90, 0.1  # candidates: N(1710), N(1720), N(1900)

# Degree l of the spherical harmonic in each decay chain
l12_val, l23_val, l31_val = 2, 1, 0  # l12 = 2 assuming a K*_2

In [None]:
# Couplings keyed by indexed symbol, e.g. a[-2] -> a_vals[0]
m_indices = range(-l_max, l_max + 1)
a_dict = {a[k]: a_vals[k + l_max] for k in m_indices}
b_dict = {b[k]: b_vals[k + l_max] for k in m_indices}
c_dict = {c[k]: c_vals[k + l_max] for k in m_indices}

# Default value of every model parameter: masses, widths, l values, couplings
parameters_default = {
    m_Kstar2: m_Kstar2_val,
    m_Sigma: m_Sigma_val,
    m_Nstar: m_Nstar_val,
    Gamma_Kstar2: Gamma_Kstar2_val,
    Gamma_Sigma: Gamma_Sigma_val,
    Gamma_Nstar: Gamma_Nstar_val,
    l12: l12_val,
    l23: l23_val,
    l31: l31_val,
    **a_dict,
    **b_dict,
    **c_dict,
}

Latex(aslatex(parameters_default))

## Visualization

In [None]:
# One slider per entry of parameters_default, keyed by symbol name: these are the
# keyword names the plotting callbacks receive. Each coupling additionally gets a
# phase slider keyed "phi_<name>", which insert_phi merges back into the coupling.
sliders = {}
# Slider groups per tab: 0 = N* (31), 1 = Sigma* (23), 2 = K* (12)
categorized_sliders_m = defaultdict(list)
categorized_sliders_gamma = defaultdict(list)
categorized_cphi_pair = defaultdict(list)
categorized_sliders_l = defaultdict(list)

# Tab index of a resonance, keyed by how its label starts inside a mass or width
# symbol name (e.g. "m_{N^*}" or the Sigma* width "\Gamma_{\Sigma^{*+}}")
RESONANCE_TAB_INDEX = {"N": 0, R"\S": 1, "K": 2}
# Tab index of the resonance each coupling base belongs to
COUPLING_TAB_INDEX = {a: 2, b: 1, c: 0}
# Slider description and tab index per orbital-angular-momentum symbol name
L_SLIDER_SPECS = {
    "l_{12}": ("ℓ₁₂", 2),
    "l_{23}": ("ℓ₂₃", 1),
    "l_{31}": ("ℓ₃₁", 0),
}


def _resonance_tab_index(name: str, prefix: str) -> int | None:
    r"""Return the tab index for a mass/width symbol ``name``, or ``None``.

    ``prefix`` is the part of the name before the resonance label, e.g.
    ``"m_{"`` or ``"\Gamma_{"``.
    """
    for label_start, index in RESONANCE_TAB_INDEX.items():
        if name.startswith(prefix + label_start):
            return index
    return None


for symbol, value in parameters_default.items():
    name = symbol.name
    if name.startswith(R"\Gamma_{"):
        slider = w.FloatSlider(
            description=Rf"\({sp.latex(symbol)}\)",
            min=0.0,
            max=1.0,
            step=0.01,
            value=value,
            continuous_update=False,
        )
        sliders[name] = slider
        tab_index = _resonance_tab_index(name, R"\Gamma_{")
        if tab_index is not None:
            categorized_sliders_gamma[tab_index].append(slider)

    elif name.startswith("m_{"):
        slider = w.FloatSlider(
            description=Rf"\({sp.latex(symbol)}\)",
            min=0.63,
            max=4,
            step=0.01,
            value=value,
            continuous_update=False,
        )
        sliders[name] = slider
        tab_index = _resonance_tab_index(name, "m_{")
        if tab_index is not None:
            categorized_sliders_m[tab_index].append(slider)

    elif isinstance(symbol, sp.Indexed):
        tab_index = COUPLING_TAB_INDEX.get(symbol.base)
        if tab_index is None:
            raise NotImplementedError(name)
        c_latex = sp.latex(symbol)
        phi_latex = Rf"\phi_{{{c_latex}}}"
        # Magnitude and phase sliders start from |value| and arg(value), so a
        # negative default coupling becomes magnitude |value| with phase pi
        slider_c = w.FloatSlider(
            description=Rf"\({c_latex}\)",
            min=0,
            max=10,
            step=0.01,
            value=abs(value),
            continuous_update=False,
        )
        slider_phi = w.FloatSlider(
            description=Rf"\({phi_latex}\)",
            min=-np.pi,
            max=+np.pi,
            step=0.01,
            value=np.angle(value),
            continuous_update=False,
        )
        sliders[name] = slider_c
        sliders[f"phi_{name}"] = slider_phi
        categorized_cphi_pair[tab_index].append(w.HBox([slider_c, slider_phi]))

    elif name in L_SLIDER_SPECS:
        description, tab_index = L_SLIDER_SPECS[name]
        # Start at the model default and allow every l up to l_max, the range
        # for which the intensity functions are precompiled
        slider = w.IntSlider(
            value=value,
            min=0,
            max=l_max,
            step=1,
            description=description,
            disabled=False,
            continuous_update=False,
            orientation="horizontal",
            readout=True,
            readout_format="d",
        )
        sliders[name] = slider
        categorized_sliders_l[tab_index].append(slider)
    else:
        raise NotImplementedError(name)

resonances_name = ["N*   [ΛK⁺ (31)]", "Σ*  [π⁰Λ(23)]", "K*  [K⁺π⁰(12)]"]
tab_contents = [
    w.VBox([
        w.HBox(categorized_sliders_m[i] + categorized_sliders_gamma[i]),
        w.VBox(categorized_cphi_pair[i]),
    ])
    for i in range(len(resonances_name))
]
# l sliders shown in tab order: l31, l23, l12
UI_lsliders = w.HBox([
    *categorized_sliders_l[0],
    *categorized_sliders_l[1],
    *categorized_sliders_l[2],
])
UI_1 = w.Tab(tab_contents, titles=resonances_name)
UI = w.VBox([UI_lsliders, UI_1])

In [None]:
# Precompile the intensity for every combination of orbital angular momenta, so
# that moving an l slider only selects a function instead of re-lambdifying.
# Axis order is (l12, l23, l31), matching how the plotting callbacks look the
# functions up: intensity_funcs[l12, l23, l31]. (Previously the nested
# comprehension put l31 on the first axis and l12 on the last, swapping them.)
intensity_funcs = np.empty((l_max + 1, l_max + 1, l_max + 1), dtype=object)
for l12_value, l23_value, l31_value in np.ndindex(intensity_funcs.shape):
    intensity_funcs[l12_value, l23_value, l31_value] = create_parametrized_function(
        expression=intensity_expr.xreplace({
            l12: l12_value,
            l23: l23_value,
            l31: l31_value,
        })
        .doit()
        .expand(func=True),
        parameters=parameters_default,
        backend="jax",
    )
intensity_funcs.shape

In [None]:
def insert_phi(parameters: dict) -> dict:
    """Merge the phase sliders into their couplings.

    Every key starting with ``a``, ``b`` or ``c`` that has a companion
    ``"phi_<key>"`` entry is multiplied by ``exp(1j * phi)``. The ``phi_*``
    entries themselves are dropped, and all other entries are copied unchanged
    into a new dict. The input dict is not modified.
    """
    merged = {}
    for name, magnitude in parameters.items():
        if name.startswith("phi_"):
            continue
        phase_name = f"phi_{name}"
        if name.startswith(("a", "b", "c")) and phase_name in parameters:
            magnitude *= np.exp(1j * parameters[phase_name])  # noqa: PLW2901
        merged[name] = magnitude
    return merged

In [None]:
%matplotlib widget
%config InlineBackend.figure_formats = ['png']
fig_2d, ax_2d = plt.subplots(dpi=150)
ax_2d.set_title("Model-weighted Phase space Dalitz Plot")
ax_2d.set_xlabel(R"$m^2(\Lambda K^+)\;\left[\mathrm{GeV}^2\right]$")
ax_2d.set_ylabel(R"$m^2(K^+ \pi^0)\;\left[\mathrm{GeV}^2\right]$")

# Colour mesh of the Dalitz plot: created on the first call of update_histogram,
# only its values are replaced afterwards
mesh = None


def update_histogram(**parameters):
    """Redraw the intensity-weighted Dalitz plot for the current slider values.

    ``parameters`` maps slider names to their values, as passed by
    ``w.interactive_output``. The ``l_{12}``, ``l_{23}`` and ``l_{31}`` entries
    index ``intensity_funcs``. All entries, with phases merged into the couplings
    by ``insert_phi``, are then set as the parameters of the selected function.
    """
    global mesh

    intensity_func = intensity_funcs[
        parameters["l_{12}"], parameters["l_{23}"], parameters["l_{31}"]
    ]
    intensity_func.update_parameters(insert_phi(parameters))
    intensity_weights = intensity_func(phsp)
    bin_values, xedges, yedges = jnp.histogram2d(
        phsp["s_{31}"],
        phsp["s_{12}"],
        bins=200,
        weights=intensity_weights,
        density=True,
    )
    # Leave (near-)empty bins blank, e.g. outside the kinematically allowed region
    bin_values = jnp.where(bin_values < 1e-10, jnp.nan, bin_values)
    if mesh is None:
        # Pass all bin edges (one more than the number of bins per axis) with flat
        # shading, so each cell spans exactly its bin. With only the lower edges,
        # pcolormesh centres the cells on them and shifts the plot by half a bin.
        x, y = jnp.meshgrid(xedges, yedges)
        mesh = ax_2d.pcolormesh(
            x, y, bin_values.T, cmap="jet", vmax=0.15, shading="flat"
        )
    else:
        mesh.set_array(bin_values.T)
    fig_2d.canvas.draw_idle()


interactive_plot = w.interactive_output(
    update_histogram,
    {**sliders},
)
fig_2d.tight_layout()
fig_2d.colorbar(mesh, ax=ax_2d)

if STATIC_PAGE:
    filename = "dalitz-plot-man.png"
    fig_2d.savefig(filename)
    plt.close(fig_2d)
    display(UI, Image(filename))
else:
    display(UI, interactive_plot)

In [None]:
%matplotlib widget
%config InlineBackend.figure_formats = ['svg']

theta_subtitles = [
    R"$\cos (\theta_{1}^{{12}}) \equiv \cos (\theta_{K^+}^{{K^+ \pi^0}})$",
    R"$\cos (\theta_{2}^{{23}}) \equiv \cos (\theta_{\pi^0}^{{\pi^0 \Lambda}})$",
    R"$\cos (\theta_{3}^{{31}}) \equiv \cos (\theta_{\Lambda}^{{\Lambda K^+}})$",
]
phi_subtitles = [
    R"$\phi_1^{12} \equiv \phi_{K^+}^{{K^+ \pi^0}}$",
    R"$\phi_2^{23} \equiv \phi_{\pi^0}^{{\pi^0 \Lambda}}$",
    R"$\phi_3^{31} \equiv \phi_{\Lambda}^{{\Lambda K^+}}$",
]
mass_subtitles = [
    R"$m_{12} \equiv m_{K^+ \pi^0}$",
    R"$m_{23} \equiv m_{\pi^0 \Lambda}$",
    R"$m_{31} \equiv m_{\Lambda K^+}$",
]

fig, (theta_axes, phi_axes, mass_axes) = plt.subplots(figsize=(12, 8), ncols=3, nrows=3)
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False

for ax, subtitle in zip(theta_axes, theta_subtitles):
    ax.set_title(subtitle)
    ax.set_xticks([-1, 0, 1])

for ax, subtitle in zip(phi_axes, phi_subtitles):
    ax.set_title(subtitle)
    ax.set_xticks([-np.pi, 0, np.pi])
    ax.set_xticklabels([R"$-\pi$", "0", R"$\pi$"])

for ax, subtitle in zip(mass_axes, mass_subtitles):
    ax.set_title(subtitle)
    ax.set_xlabel("Mass [GeV]")

# Step lines per row of panels, created on the first update and reused afterwards
LINES = 3 * [None]
THETA_LINES = 3 * [None]
PHI_LINES = 3 * [None]
# Dashed resonance-mass markers per mass panel, keyed by mass symbol name
RESONANCE_LINES = defaultdict(dict)
# Mass symbol of the resonance in each mass panel (12, 23, 31)
RESONANCES_MASS_NAME = [m_Kstar2, m_Sigma, m_Nstar]


def _draw_histogram(lines: list, index: int, ax, bin_values, bin_edges) -> None:
    """Draw histogram ``index`` as a step line on ``ax``, or update the cached one.

    ``bin_values[k]`` belongs to ``[bin_edges[k], bin_edges[k + 1])``, so the line
    runs over all edges with ``where="post"`` and the last value is repeated to
    close the final bin. The default ``where="pre"`` applied to the lower edges
    would draw every bin one bin too far to the left.
    """
    y_values = np.append(bin_values, bin_values[-1])
    if lines[index] is None:
        lines[index] = ax.step(bin_edges, y_values, where="post", alpha=0.5)[0]
    else:
        lines[index].set_ydata(y_values)


def update_plot(**parameters):  # noqa: PLR0914
    """Redraw the 1D mass and angle projections for the current slider values.

    ``parameters`` maps slider names to their values, as passed by
    ``w.interactive_output``. The ``l_{..}`` entries index ``intensity_funcs``.
    All entries, with phases merged in by ``insert_phi``, are set as the
    parameters of the selected function. Its values on ``phsp`` are the
    histogram weights.
    """
    intensity_func = intensity_funcs[
        parameters["l_{12}"], parameters["l_{23}"], parameters["l_{31}"]
    ]
    parameters = insert_phi(parameters)
    intensity_func.update_parameters(parameters)
    intensities = intensity_func(phsp)
    hist_kwargs = dict(bins=120, weights=intensities, density=True)
    line_labels = [R"$m_{K^{*+}_2}$", R"$m_{\Sigma^*}$", R"$m_{N^*}$"]
    line_colors = ["r", "g", "b"]

    max_value_mass = 0.0
    for i, (ax, m2_key) in enumerate(zip(mass_axes, ["s_{12}", "s_{23}", "s_{31}"])):
        bin_values, bin_edges = jnp.histogram(np.sqrt(phsp[m2_key]), **hist_kwargs)
        max_value_mass = max(max_value_mass, bin_values.max())
        _draw_histogram(LINES, i, ax, bin_values, bin_edges)

        # Vertical marker at the current mass of this channel's resonance; the
        # slider values are keyed by symbol name
        mass_name = RESONANCES_MASS_NAME[i].name
        mass_value = parameters[mass_name]
        resonance_line = RESONANCE_LINES[i].get(mass_name)
        if resonance_line is None:
            RESONANCE_LINES[i][mass_name] = ax.axvline(
                mass_value, color=line_colors[i], linestyle="--", label=line_labels[i]
            )
        else:
            resonance_line.set_xdata([mass_value, mass_value])

    max_value_theta = 0.0
    for i, (ax, theta_key) in enumerate(
        zip(theta_axes, ["theta_1", "theta_2", "theta_3"])
    ):
        bin_values, bin_edges = jnp.histogram(np.cos(phsp[theta_key]), **hist_kwargs)
        max_value_theta = max(max_value_theta, bin_values.max())
        _draw_histogram(THETA_LINES, i, ax, bin_values, bin_edges)

    max_value_phi = 0.0
    for i, (ax, phi_key) in enumerate(zip(phi_axes, ["phi_1", "phi_2", "phi_3"])):
        bin_values, bin_edges = jnp.histogram(phsp[phi_key], **hist_kwargs)
        max_value_phi = max(max_value_phi, bin_values.max())
        _draw_histogram(PHI_LINES, i, ax, bin_values, bin_edges)

    for ax in mass_axes:
        ax.set_ylim(0, max_value_mass * 1.1)
        ax.legend()

    for ax in theta_axes:
        ax.set_ylim(0, max_value_theta * 1.1)

    for ax in phi_axes:
        ax.set_ylim(0, max_value_phi * 1.1)

    # Artists were modified in place; request a redraw like update_histogram does
    fig.canvas.draw_idle()


interactive_plot = w.interactive_output(update_plot, sliders)
fig.tight_layout()

if STATIC_PAGE:
    filename = "1d-histograms-man.svg"
    fig.savefig(filename)
    plt.close(fig)
    display(UI, SVG(filename))
else:
    display(UI, interactive_plot)