# Phase-space sample

In [None]:
%%capture
%run ./amplitude-model.ipynb

```{autolink-concat}
```

In [None]:
# pyright: reportUndefinedVariable=false
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import create_function

from polarization.dynamics import Källén
from polarization.io import display_latex

In [None]:
# The three invariants are related by σ1 + σ2 + σ3 = m0² + m1² + m2² + m3²,
# so σ3 can be eliminated in favour of σ1 and σ2
σ3_expr = sum(m**2 for m in (m0, m1, m2, m3)) - σ1 - σ2
display_latex({σ3: σ3_expr})

In [None]:
def kibble_function(σ1, σ2):
    """Symbolic Kibble function φ(σ1, σ2): a Källén function of Källén functions.

    ``σ3`` is not a parameter. The global symbol ``σ3`` is used instead, so it
    still has to be substituted (e.g. by ``σ3_expr``) to obtain an expression
    in ``σ1`` and ``σ2`` only.
    """
    λ_1 = Källén(σ1, m1**2, m0**2)
    λ_2 = Källén(σ2, m2**2, m0**2)
    λ_3 = Källén(σ3, m3**2, m0**2)
    return Källén(λ_2, λ_3, λ_1)


def is_within_phsp(σ1, σ2, non_phsp_value=sp.nan):
    """Symbolic phase-space indicator.

    Evaluates to 1 where the Kibble function φ(σ1, σ2) is non-positive and to
    ``non_phsp_value`` (NaN by default) everywhere else.
    """
    inside_region = sp.LessThan(kibble_function(σ1, σ2), 0)
    return sp.Piecewise((1, inside_region), (non_phsp_value, True))


# Render the symbolic indicator; σ3 and the masses are still plain symbols here
is_within_phsp(σ1, σ2)

In [None]:
# Eliminate σ3 and insert the numerical masses, so that the indicator depends
# on σ1 and σ2 only, then lambdify it to a NumPy function
in_phsp_expr = (
    is_within_phsp(σ1, σ2)
    .subs(σ3, σ3_expr)
    .subs(masses)
    .doit()
)
assert in_phsp_expr.free_symbols == {σ1, σ2}
in_phsp = create_function(in_phsp_expr, backend="numpy")

In [None]:
# Look each mass up by its symbol instead of unpacking `masses.values()`.
# Unpacking would silently depend on the insertion order of the dict and
# would fail as soon as it held more than these four entries.
m0_val, m1_val, m2_val, m3_val = (masses[m] for m in (m0, m1, m2, m3))
# Bounding box of the Dalitz region in the (σ1, σ2) plane: each invariant runs
# from the threshold of its pair to (m0 minus the remaining mass) squared
σ1_min = (m2_val + m3_val) ** 2
σ1_max = (m0_val - m1_val) ** 2
σ2_min = (m1_val + m3_val) ** 2
σ2_max = (m0_val - m2_val) ** 2

In [None]:
# Angle expressions from `angles` with the masses inserted. The Mandelstam
# variables are appended as identity mappings, so they are passed through too.
kinematic_variables = {
    **{key: definition.doit().subs(masses) for key, definition in angles.items()},
    **{mandelstam: mandelstam for mandelstam in (σ1, σ2, σ3)},
}
transformer = SympyDataTransformer.from_sympy(kinematic_variables, backend="jax")

In [None]:
def generate_uniform_phsp(resolution: int):
    """Evaluate the kinematics on a regular ``resolution × resolution`` grid.

    The grid covers the rectangle [σ1_min, σ1_max] × [σ2_min, σ2_max]. Points
    outside the Dalitz region are kept and only flagged through ``phsp``.

    Returns
    -------
    X, Y
        Meshgrid arrays of σ1 and σ2 values.
    data
        Output of ``transformer`` (the kinematic variables) on the grid.
    phsp
        Output of ``in_phsp``: 1 inside the phase space and NaN outside.
    """
    σ1_grid, σ2_grid = np.meshgrid(
        np.linspace(σ1_min, σ1_max, num=resolution),
        np.linspace(σ2_min, σ2_max, num=resolution),
    )
    σ3_function = create_function(σ3_expr.subs(masses), backend="jax")
    σ3_grid = σ3_function.function(σ1_grid, σ2_grid)
    σ_arrays = {"sigma1": σ1_grid, "sigma2": σ2_grid, "sigma3": σ3_grid}
    return σ1_grid, σ2_grid, transformer(σ_arrays), in_phsp(σ_arrays)

In [None]:
%config InlineBackend.figure_formats = ['svg']

In [None]:
def __plot_phsp():
    """Draw the phase-space region in the (σ1, σ2) plane as a gray area with a black outline."""
    X, Y, _, phsp = generate_uniform_phsp(resolution=500)
    # The indicator is NaN outside the phase space. Mapping NaN to 0 turns it
    # into a 0/1 step function whose boundary lies at the 0.5 level.
    phsp = np.nan_to_num(phsp)
    _, ax = plt.subplots(figsize=(4, 4))
    ax.set_xlabel(R"$\sigma_1$")
    ax.set_ylabel(R"$\sigma_2$")
    ax.set_xticks([])
    ax.set_yticks([])
    # Fill and outline the region explicitly. Modifying
    # `ContourSet.collections` is deprecated since Matplotlib 3.8 and was
    # removed in later releases.
    ax.contourf(X, Y, phsp, levels=[0.5, 1.5], colors="lightgray")
    ax.contour(X, Y, phsp, levels=[0.5], colors="black")
    plt.show()


__plot_phsp()