# Amplitude model with `sympy`

This section follows up on the previous chapter, [Reaction and Models](reaction-model.md), and formulates the amplitude model for the $\gamma p \to \eta\pi^0 p$ channel example symbolically. See **[TR&#8209;033](https://compwa.github.io/report/033)** for a purely numerical tutorial.

The model we want to implement is

$$
\begin{array}{rcl}
I &=& \left|A^{12} + A^{23} + A^{31}\right|^2 \\
A^{12} &=& \frac{\sum a_m Y_2^m (\Omega_1)}{s_{12}-m^2_{a_2}+im_{a_2} \Gamma_{a_2}} \\
A^{23} &=& \frac{\sum b_m Y_1^m (\Omega_2)}{s_{23}-m^2_{\Delta^+}+im_{\Delta^+} \Gamma_{\Delta^+}} \\
A^{31} &=& \frac{c_0}{s_{31}-m^2_{N^*}+im_{N^*} \Gamma_{N^*}} \,,
\end{array}
$$

where $1=\eta$, $2=\pi^0$, and $3=p$.

In [None]:
from __future__ import annotations

import logging
import os
import warnings
from collections import defaultdict

import ipywidgets as w
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.io import aslatex
from ampform.kinematics.angles import Phi, Theta
from ampform.kinematics.lorentz import (
    ArrayMultiplication,
    ArraySize,
    BoostZMatrix,
    Energy,
    EuclideanNorm,
    FourMomentumSymbol,
    RotationYMatrix,
    RotationZMatrix,
    ThreeMomentum,
    three_momentum_norm,
)
from ampform.sympy import unevaluated
from ampform.sympy._array_expressions import ArraySum
from IPython.display import Image, Latex, display
from tensorwaves.data import (
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)
from tensorwaves.function.sympy import create_parametrized_function

# True when the EXECUTE_NB environment variable is set. Used at the end of the
# notebook to show a saved image instead of the interactive widget output.
STATIC_PAGE = "EXECUTE_NB" in os.environ

# Keep the rendered notebook output clean by muting log messages and warnings.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL=3 presumably silences TensorFlow's C++
# logging — confirm against the TensorFlow version in use.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
logging.disable(logging.WARNING)
warnings.filterwarnings("ignore")

## Model implementation

### $A^{12}$

In [None]:
# Symbols for the variable s_{12} and for the mass and width of the a_2 term.
s12, m_a2, gamma_a2 = sp.symbols(r"s_{12} m_{a_2} \Gamma_{a_2}")
# Angles that enter the spherical harmonics of this amplitude.
theta1, phi1 = sp.symbols("theta_1 phi_1")
# Indexed coefficients a[m], one per term of the sum over the index m.
a = sp.IndexedBase("a")
m = sp.symbols("m", cls=sp.Idx)
# Sum of a[m] * Y_2^m(theta_1, phi_1) for m = -2..2, divided by
# s_{12} - m^2 + i*m*Gamma.
A12 = sp.Sum(a[m] * sp.Ynm(2, m, theta1, phi1), (m, -2, 2)) / (
    s12 - m_a2**2 + sp.I * m_a2 * gamma_a2
)
A12

In [None]:
# Display Y_2^1 at theta=1, phi=1 after expanding with func=False, for comparison
# with the func=True call in the next cell.
sp.Ynm(2, 1, 1, 1).expand(func=False)

In [None]:
# Same spherical harmonic, now expanded with func=True. This is the form that is
# used before the amplitudes are lambdified further down.
sp.Ynm(2, 1, 1, 1).expand(func=True)

In [None]:
# Inspect which symbols the unevaluated A12 expression depends on.
A12.free_symbols

In [None]:
# Turn A12 into a numerical function. `doit()` evaluates the sum over m and
# `expand(func=True)` expands the spherical harmonics before `lambdify` is called.
# The positional arguments of the resulting function follow the list below.
A12_func = sp.lambdify(
    [s12, a[-2], a[-1], a[0], a[1], a[2], m_a2, gamma_a2, theta1, phi1],
    A12.doit().expand(func=True),
)
A12_func

### $A^{23}$

In [None]:
# Symbols for the variable s_{23} and for the mass and width of the Delta^+ term.
s23, m_delta, gamma_delta = sp.symbols(r"s_{23} m_{\Delta^+} \Gamma_{\Delta^+}")
# Indexed coefficients b[m], one per term of the sum over the index m.
b = sp.IndexedBase("b")
m = sp.symbols("m", cls=sp.Idx)
theta2, phi2 = sp.symbols("theta_2 phi_2")
# Sum of b[m] * Y_1^m(theta_2, phi_2) for m = -1..1, divided by
# s_{23} - m^2 + i*m*Gamma.
A23 = sp.Sum(b[m] * sp.Ynm(1, m, theta2, phi2), (m, -1, 1)) / (
    s23 - m_delta**2 + sp.I * m_delta * gamma_delta
)
A23

In [None]:
# Numerical version of A23: the sum is evaluated and the spherical harmonics are
# expanded before lambdifying. Arguments follow the order of the list below.
A23_func = sp.lambdify(
    [s23, b[-1], b[0], b[1], m_delta, gamma_delta, theta2, phi2],
    A23.doit().expand(func=True),
)

### $A^{31}$

In [None]:
# Symbols for the variable s_{31} and for the mass and width of the N^* term.
s31, m_nstar, gamma_nstar = sp.symbols(r"s_{31} m_{N^*} \Gamma_{N^*}")
c = sp.IndexedBase("c")
# Only the single coefficient c[0] appears in the numerator; there is no angular
# dependence in this amplitude.
A31 = c[0] / (s31 - m_nstar**2 + sp.I * m_nstar * gamma_nstar)
A31

:::{tip}
The expression above follows from $Y_l^m (\Omega_3)$ with $l=0$, as shown below: since $Y_0^0 = 1/\sqrt{4\pi}$ is a constant, it can be absorbed into the coefficient, so that the numerator becomes $c_0$.
:::

In [None]:
# Illustration only: the same structure as the other amplitudes, with an l=0 sum
# that has the single term m=0. Reuses the index `m` defined in an earlier cell;
# the result is displayed but not assigned.
theta3, phi3 = sp.symbols("theta_3 phi_3")
sp.Sum(c[m] * sp.Ynm(0, m, theta3, phi3), (m, 0, 0)) / (
    s31 - m_nstar**2 + sp.I * m_nstar * gamma_nstar
)

In [None]:
# Numerical version of A31 with positional arguments (s31, c[0], m_nstar, gamma_nstar).
A31_func = sp.lambdify([s31, c[0], m_nstar, gamma_nstar], A31)

### $I = |A|^2 = |A^{12}+A^{23}+A^{31}|^2$

In [None]:
# Total intensity: squared absolute value of the coherent sum of the three amplitudes.
intensity_expr = sp.Abs(A12 + A23 + A31) ** 2
intensity_expr

### Phase Space Generation

Invariant mass of the $p\gamma$ system, which serves as the initial state mass for the phase space generation:

In [None]:
# NOTE(review): energies and masses are presumably in GeV (the Dalitz plot axes
# further down are labelled GeV^2) — confirm.
E_lab_gamma = 8.5
m_proton = 0.938
# Initial-state mass sqrt(2*E*m_p + m_p^2). Judging by the names, this is the
# invariant mass of a photon of energy E_lab_gamma on a proton at rest.
m_0 = np.sqrt(2 * E_lab_gamma * m_proton + m_proton**2)
m_eta = 0.548
m_pi = 0.135
m_0

In [None]:
# Fixed seed for the random number generator that drives the sample generation.
rng = TFUniformRealNumberGenerator(seed=0)
# Final-state IDs follow the convention of this notebook: 1=eta, 2=pi0, 3=p.
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=m_0,
    final_state_masses={1: m_eta, 2: m_pi, 3: m_proton},
)
# Generate 500,000 phase space events.
phsp_momenta = phsp_generator.generate(500_000, rng)

### Kinematic variables

In [None]:
@unevaluated
class SquaredInvariantMass(sp.Expr):
    """Squared invariant mass of a four-momentum.

    Evaluates to the squared energy minus the squared Euclidean norm of the
    three-momentum of :code:`momentum`.
    """

    momentum: sp.Basic
    _latex_repr_ = "m_{{{momentum}}}^2"

    def evaluate(self) -> sp.Expr:
        """Write the expression out in terms of energy and three-momentum."""
        energy = Energy(self.momentum)
        three_momentum = ThreeMomentum(self.momentum)
        return energy**2 - EuclideanNorm(three_momentum) ** 2


def formulate_helicity_angles(
    pi: FourMomentumSymbol, pj: FourMomentumSymbol
) -> tuple[Theta, Phi]:
    """Formulate the polar and azimuthal angle of :code:`pi` in the pair frame.

    The summed momentum :code:`pi + pj` defines the transformation: a rotation
    around z by minus its azimuthal angle, a rotation around y by minus its polar
    angle, and then a boost along z with its velocity. The returned angles are
    those of :code:`pi` after this chain has been applied.

    Args:
        pi: Four-momentum whose angles are computed.
        pj: Four-momentum of the other particle in the pair.

    Returns:
        The :code:`(Theta, Phi)` expressions of the transformed :code:`pi`.
    """
    pair = ArraySum(pi, pj)
    pair_phi, pair_theta = Phi(pair), Theta(pair)
    pair_velocity = three_momentum_norm(pair) / Energy(pair)
    # Listed in the order of the matrix product, so the rightmost (the rotation
    # around z) acts on the momentum first.
    transformation = (
        BoostZMatrix(pair_velocity, n_events=ArraySize(pair_velocity)),
        RotationYMatrix(-pair_theta, n_events=ArraySize(pair_theta)),
        RotationZMatrix(-pair_phi, n_events=ArraySize(pair_phi)),
    )
    pi_transformed = ArrayMultiplication(*transformation, pi)
    return Theta(pi_transformed), Phi(pi_transformed)

In [None]:
# Four-momentum symbols of the three final-state particles (1=eta, 2=pi0, 3=p).
p1 = FourMomentumSymbol("p1", shape=[])
p2 = FourMomentumSymbol("p2", shape=[])
p3 = FourMomentumSymbol("p3", shape=[])
# Summed four-momenta of the three two-particle subsystems.
p12 = ArraySum(p1, p2)
p23 = ArraySum(p2, p3)
p31 = ArraySum(p3, p1)

# Angles of the first-listed particle of each pair, computed with respect to that pair.
theta1_expr, phi1_expr = formulate_helicity_angles(p1, p2)
theta2_expr, phi2_expr = formulate_helicity_angles(p2, p3)
theta3_expr, phi3_expr = formulate_helicity_angles(p3, p1)

# Squared invariant mass of each subsystem.
s12_expr = SquaredInvariantMass(p12)
s23_expr = SquaredInvariantMass(p23)
s31_expr = SquaredInvariantMass(p31)

In [None]:
# Map each symbol that appears in the amplitude model to the expression that
# computes it from the four-momenta.
kinematic_variables = {
    theta1: theta1_expr,
    theta2: theta2_expr,
    theta3: theta3_expr,
    phi1: phi1_expr,
    phi2: phi2_expr,
    phi3: phi3_expr,
    s12: s12_expr,
    s23: s23_expr,
    s31: s31_expr,
}

# Render the definitions as LaTeX in the notebook output.
Latex(aslatex(kinematic_variables))

In [None]:
# Build a data transformer from the kinematic variable expressions, using the
# JAX backend.
helicity_transformer = SympyDataTransformer.from_sympy(
    kinematic_variables, backend="jax"
)

In [None]:
# Compute the kinematic variables for the generated phase space momenta and list
# the keys of the result.
phsp = helicity_transformer(phsp_momenta)
list(phsp)

### Parameters

In [None]:
# Default coefficient values: five for a[m] (m = -2..2), three for b[m]
# (m = -1..1), and the single c[0].
a_vals = [0, 0.5, 3.5, 4, 2.5]
b_vals = [-1.5, 4, 0.5]
c0_val = 2.5
# Default masses and widths of the three resonances.
m_a2_val, m_delta_val, m_nstar_val = 1.32, 1.54, 1.87
gamma_a2_val = gamma_delta_val = gamma_nstar_val = 0.1

# Insertion order matters: the sliders further down are created by iterating
# over this dictionary.
parameters_default = {
    **{a[index]: a_val for index, a_val in zip(range(-2, 3), a_vals)},
    **{b[index]: b_val for index, b_val in zip(range(-1, 2), b_vals)},
    c[0]: c0_val,
    m_a2: m_a2_val,
    m_delta: m_delta_val,
    m_nstar: m_nstar_val,
    gamma_a2: gamma_a2_val,
    gamma_delta: gamma_delta_val,
    gamma_nstar: gamma_nstar_val,
}

Latex(aslatex(parameters_default))

:::{note}
The masses and widths of the resonances are customised so that the resonance bands are more clearly visible.
:::

In [None]:
# Check whether an indexed coefficient such as a[-2] is an `sp.Indexed` instance.
isinstance(a[-2], sp.Indexed)

In [None]:
# Same check for a symbol created with `sp.symbols`, for comparison.
isinstance(m_a2, sp.Indexed)

In [None]:
# Show the class hierarchy of a symbol created with `sp.symbols`.
type(m_a2).__mro__

In [None]:
# Show the class hierarchy of an indexed coefficient, for comparison.
type(a[-2]).__mro__

## Visualization

### Model components

In [None]:
# Evaluate each amplitude along a scan of its s variable, with both angles fixed
# at pi/4 and all parameters at their default values.
phi = np.pi / 4
theta = np.pi / 4
s_values = np.linspace(0, 5, num=500)
# Argument order must match the symbol lists passed to `sp.lambdify` above.
A12_values = A12_func(s_values, *a_vals, m_a2_val, gamma_a2_val, theta, phi)
A23_values = A23_func(s_values, *b_vals, m_delta_val, gamma_delta_val, theta, phi)
A31_values = A31_func(s_values, c0_val, m_nstar_val, gamma_nstar_val)

In [None]:
%config InlineBackend.figure_formats = ['svg']

In [None]:
fig, axes = plt.subplots(figsize=(10, 9), nrows=3)
angle_projection_text = R"$\phi=\frac{\pi}{4},\theta=\frac{\pi}{4}$"
# Pair every amplitude with the label of the subsystem it describes. The labels
# are spelled out explicitly so that titles, axis labels, and legend entries
# match the amplitude that is actually drawn (A12_values belongs to subsystem
# 12, not to the subsystem recoiling against particle 1).
amplitudes_by_subsystem = {
    "12": A12_values,
    "23": A23_values,
    "31": A31_values,
}
for ax, (subsystem, A_values) in zip(axes, amplitudes_by_subsystem.items()):
    # One panel per amplitude: imaginary part, real part, intensity, and phase
    # as a function of the scanned s values.
    ax.plot(s_values, A_values.imag, label="Imaginary Part", c="blue", ls="-")
    ax.plot(s_values, A_values.real, label="Real Part", c="red", ls="--")
    ax.plot(
        s_values,
        np.abs(A_values) ** 2,
        label=f"Intensity $I^{{{subsystem}}}=|A^{{{subsystem}}}|^2$",
        c="g",
        ls="-.",
    )
    ax.plot(
        s_values,
        np.angle(A_values),
        label="Phase (angle)",
        c="m",
        ls=":",
    )
    ax.set_title(Rf"Components of $A^{{{subsystem}}}$ vs s ({angle_projection_text})")
    ax.set_xlabel(f"$s_{{{subsystem}}}$")
    ax.set_ylabel(Rf"$A^{{{subsystem}}}$ components")
    ax.legend()
fig.tight_layout()
plt.show()

Unitarity is preserved in each of the subsystems (assuming fixed $\phi,\theta$), because we assume there is only one resonance in the subsystem.

In [None]:
fig, axes = plt.subplots(figsize=(12, 4), ncols=3)
# Pair every amplitude with the label of the subsystem it describes, so that the
# axis labels and titles match the amplitude that is actually drawn.
argand_amplitudes = {
    "12": A12_values,
    "23": A23_values,
    "31": A31_values,
}
for ax, (subsystem, A_values) in zip(axes, argand_amplitudes.items()):
    ax.set_xlabel(Rf"$\mathrm{{Re}}\,A^{{{subsystem}}}$")
    ax.set_ylabel(Rf"$\mathrm{{Im}}\,A^{{{subsystem}}}$")
    ax.set_title(Rf"$A^{{{subsystem}}}$ ({angle_projection_text})")
    # Trajectory of the amplitude in the complex plane along the scanned s values.
    ax.plot(A_values.real, A_values.imag)
fig.suptitle("Argand diagrams")
fig.tight_layout()
# `plt.show` takes no figure argument (its only parameter is `block`), so it is
# called without arguments, like in the other plotting cells.
plt.show()

### Dalitz Plot

In [None]:
# `sliders` maps the keyword names that are later passed to the plot callback to
# their slider widgets. The `categorized_*` dictionaries group the same widgets
# per resonance tab: 0 = N*, 1 = Delta, 2 = a_2.
sliders = {}
categorized_sliders_m = defaultdict(list)
categorized_sliders_gamma = defaultdict(list)
categorized_cphi_pair = defaultdict(list)

for symbol, value in parameters_default.items():
    if symbol.name.startswith(R"\Gamma_{"):
        # Width parameter: one slider in the range [0, 1].
        slider = w.FloatSlider(
            description=Rf"\({sp.latex(symbol)}\)",
            min=0.0,
            max=1.0,
            step=0.01,
            value=value,
            continuous_update=False,
        )
        sliders[symbol.name] = slider
        if symbol.name.startswith(R"\Gamma_{N"):
            categorized_sliders_gamma[0].append(slider)
        elif symbol.name.startswith(R"\Gamma_{\D"):
            categorized_sliders_gamma[1].append(slider)
        elif symbol.name.startswith(R"\Gamma_{a"):
            categorized_sliders_gamma[2].append(slider)

    elif symbol.name.startswith("m_{"):
        # Mass parameter: one slider in the range [0.63, 4].
        slider = w.FloatSlider(
            description=Rf"\({sp.latex(symbol)}\)",
            min=0.63,
            max=4,
            step=0.01,
            value=value,
            continuous_update=False,
        )
        sliders[symbol.name] = slider
        if symbol.name.startswith("m_{N"):
            categorized_sliders_m[0].append(slider)
        elif symbol.name.startswith(R"m_{\D"):
            categorized_sliders_m[1].append(slider)
        elif symbol.name.startswith("m_{a"):
            categorized_sliders_m[2].append(slider)

    else:
        # Coefficient a[m], b[m], or c[0]: split the default value into a
        # magnitude slider and a phase slider. The phase slider is registered
        # under the key "phi_<name>".
        c_latex = sp.latex(symbol)
        phi_latex = Rf"\phi_{{{c_latex}}}"
        slider_c = w.FloatSlider(
            description=Rf"\({c_latex}\)",
            min=0,
            max=10,
            step=0.01,
            value=abs(value),
            continuous_update=False,
        )
        slider_phi = w.FloatSlider(
            description=Rf"\({phi_latex}\)",
            min=-np.pi,
            max=+np.pi,
            step=0.01,
            value=np.angle(value),
            continuous_update=False,
        )

        sliders[symbol.name] = slider_c
        sliders[f"phi_{symbol.name}"] = slider_phi
        cphi_hbox = w.HBox([slider_c, slider_phi])
        if symbol.base is a:
            categorized_cphi_pair[2].append(cphi_hbox)
        elif symbol.base is b:
            categorized_cphi_pair[1].append(cphi_hbox)
        elif symbol.base is c:
            categorized_cphi_pair[0].append(cphi_hbox)
        else:
            raise NotImplementedError(symbol.name)

# One tab per resonance: mass and width sliders on the first row, followed by one
# row per coefficient with its magnitude and phase slider.
resonances_name = ["N*", "Δ*", "a₂*"]
tab_contents = [
    w.VBox([
        w.HBox(categorized_sliders_m[group] + categorized_sliders_gamma[group]),
        w.VBox(categorized_cphi_pair[group]),
    ])
    for group, _ in enumerate(resonances_name)
]
UI = w.Tab(tab_contents, titles=resonances_name)
UI

In [None]:
# Evaluate the sums and expand the spherical harmonics, then build a
# parametrized function from the result with the JAX backend, using the default
# parameter values.
unfolded_expression = intensity_expr.doit().expand(func=True)
intensity_func = create_parametrized_function(
    expression=unfolded_expression,
    parameters=parameters_default,
    backend="jax",
)

In [None]:
# Intensities over the phase space sample; used as histogram weights in the
# projection plots at the end of the notebook.
intensities = intensity_func(phsp)

In [None]:
def insert_phi(parameters: dict) -> dict:
    """Combine magnitude and phase entries into complex coefficients.

    Entries whose key starts with ``"phi_"`` are dropped from the result. An
    entry whose key starts with ``"a"``, ``"b"``, or ``"c"`` is multiplied by
    ``exp(1j * phi)`` if a matching ``"phi_<key>"`` entry exists. All other
    entries are copied unchanged.

    Args:
        parameters: Mapping of parameter names to values, optionally with
            ``"phi_<name>"`` phase entries.

    Returns:
        A new dictionary without the ``"phi_"`` entries, in the original key
        order.
    """
    coefficient_prefixes = ("a", "b", "c")
    combined = {}
    for name in parameters:
        if name.startswith("phi_"):
            continue
        entry = parameters[name]
        phase_name = f"phi_{name}"
        if name.startswith(coefficient_prefixes) and phase_name in parameters:
            entry *= np.exp(1j * parameters[phase_name])
        combined[name] = entry
    return combined

In [None]:
%matplotlib widget
%config InlineBackend.figure_formats = ['png']
# Figure and axes that the callback below draws into.
fig_2d, ax_2d = plt.subplots(dpi=200)
ax_2d.set_title("Model-weighted Phase space Dalitz Plot")
ax_2d.set_xlabel(R"$m^2(\eta \pi^0)\;\left[\mathrm{GeV}^2\right]$")
ax_2d.set_ylabel(R"$m^2(\pi^0 p)\;\left[\mathrm{GeV}^2\right]$")

# Handle of the drawn mesh; stays None until the first call of the callback.
mesh = None


def update_histogram(**parameters):
    """Redraw the Dalitz plot for a new set of parameter values.

    The keyword arguments are passed through `insert_phi` and then set on
    `intensity_func`. The phase space sample is histogrammed in
    ``s_{12}`` versus ``s_{23}`` with the resulting intensities as weights. The
    first call creates the mesh on `ax_2d` and stores it in the global `mesh`;
    later calls only replace the data of that mesh.
    """
    global mesh
    intensity_func.update_parameters(insert_phi(parameters))
    weights = intensity_func(phsp)
    histogram, s12_edges, s23_edges = jnp.histogram2d(
        phsp["s_{12}"],
        phsp["s_{23}"],
        bins=200,
        weights=weights,
        density=True,
    )
    # Bins below the threshold are set to NaN.
    histogram = jnp.where(histogram < 1e-6, jnp.nan, histogram)
    if mesh is not None:
        mesh.set_array(histogram.T)
    else:
        grid_x, grid_y = jnp.meshgrid(s12_edges[:-1], s23_edges[:-1])
        mesh = ax_2d.pcolormesh(grid_x, grid_y, histogram.T, cmap="jet", vmax=0.15)
    fig_2d.canvas.draw_idle()


# Connect the sliders to `update_histogram`; the keys of the `sliders`
# dictionary become its keyword arguments.
interactive_plot = w.interactive_output(update_histogram, sliders)
fig_2d.tight_layout()
# NOTE(review): relies on `mesh` having been set by a first call of
# `update_histogram` triggered by the line above — confirm with the ipywidgets
# version in use.
fig_2d.colorbar(mesh, ax=ax_2d)

if STATIC_PAGE:
    # Static build: save the figure to a PNG file, close it, and display the
    # saved image below the slider tabs.
    filename = "dalitz-plot.png"
    fig_2d.savefig(filename)
    plt.close(fig_2d)
    display(UI, Image(filename))
else:
    # Interactive session: display the slider tabs with the widget output.
    display(UI, interactive_plot)

### Projection

In [None]:
%config InlineBackend.figure_formats = ['svg']

# Titles for the 3x3 grid: one row per observable (cos theta, phi, mass), one
# column per subsystem (12, 23, 31).
theta_subtitles = [
    R"$\cos (\theta_{1}^{{12}}) \equiv \cos (\theta_{\eta}^{{\eta \pi^0}})$",
    R"$\cos (\theta_{2}^{{23}}) \equiv \cos (\theta_{\pi^0}^{{\pi^0 p}})$",
    R"$\cos (\theta_{3}^{{31}}) \equiv \cos (\theta_{p}^{{p \eta}})$",
]
phi_subtitles = [
    R"$\phi_1^{12} \equiv \phi_{\eta}^{{\eta \pi^0}}$",
    R"$\phi_2^{23} \equiv \phi_{\pi^0}^{{\pi^0 p}}$",
    R"$\phi_3^{31} \equiv \phi_{p}^{{p \eta}}$",
]
mass_subtitles = [
    R"$m_{12} \equiv m_{\eta \pi^0}$",
    R"$m_{23} \equiv m_{\pi^0 p}$",
    R"$m_{31} \equiv m_{p \eta}$",
]

fig, (theta_ax, phi_ax, mass_ax) = plt.subplots(figsize=(12, 8), ncols=3, nrows=3)
for ax1, subtitle in zip(theta_ax, theta_subtitles):
    ax1.set_title(subtitle)
    ax1.set_xticks([-1, 0, 1])

for ax2, subtitle in zip(phi_ax, phi_subtitles):
    ax2.set_title(subtitle)
    ax2.set_xticks([-np.pi, 0, np.pi])
    ax2.set_xticklabels([R"-$\pi$", 0, R"$\pi$"])

for ax3, subtitle in zip(mass_ax, mass_subtitles):
    ax3.set_title(subtitle)

# All histograms share the binning and are weighted with the model intensities.
plot_style = {"bins": 100, "weights": intensities, "density": True}

for ax1, key in zip(theta_ax, ["theta_1", "theta_2", "theta_3"]):
    ax1.hist(np.cos(phsp[key]), **plot_style)
for ax2, key in zip(phi_ax, ["phi_1", "phi_2", "phi_3"]):
    ax2.hist(phsp[key], **plot_style)
# The mass is the square root of the squared invariant mass of each subsystem.
for ax3, key in zip(mass_ax, ["s_{12}", "s_{23}", "s_{31}"]):
    ax3.hist(np.sqrt(phsp[key]), **plot_style)

fig.tight_layout()
plt.show()