# Phase space sample

In [None]:
%%capture
# Execute the amplitude-model notebook in this kernel so that its names
# (notably `model`, used by later cells) are available here; `%%capture`
# suppresses that notebook's output.
%run ./amplitude-model.ipynb

```{autolink-concat}
```

In [None]:
# pyright: reportUndefinedVariable=false
import logging

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.kinematics.phasespace import Kallen, is_within_phasespace
from tensorwaves.data import (
    IntensityDistributionGenerator,
    NumpyDomainGenerator,
    NumpyUniformRNG,
)
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import create_function

from polarization.io import display_latex

# Only let through ERROR (and worse) messages from the "absl" logger.
JAX_LOGGER = logging.getLogger(name="absl")
JAX_LOGGER.setLevel(level=logging.ERROR)

In [None]:
# Mass symbols m0 (total) and m1, m2, m3
m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
# Squared invariant masses of the three two-body subsystems
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
# Only two σ's are independent: σ1 + σ2 + σ3 = m0² + m1² + m2² + m3²
σ3_expr = sum(m**2 for m in (m0, m1, m2, m3)) - σ1 - σ2
display_latex({σ3: σ3_expr})

In [None]:
# Display the (not yet evaluated with .doit()) phase-space condition in terms of σ1, σ2 and the masses
is_within_phasespace(σ1, σ2, m0, m1, m2, m3)

In [None]:
# Numerical mass values from the model defaults, ordered m0, m1, m2, m3
# (sorting by the symbol's string name)
masses = {
    symbol: value
    for symbol, value in sorted(
        model.parameter_defaults.items(), key=lambda item: str(item[0])
    )
    if symbol in {m0, m1, m2, m3}
}
display_latex(masses)

In [None]:
# Phase-space indicator as a function of σ1 and σ2 only: evaluate the
# condition, eliminate σ3 via σ3_expr, and insert the numerical masses.
in_phsp_expr = (
    is_within_phasespace(σ1, σ2, m0, m1, m2, m3)
    .doit()
    .subs(σ3, σ3_expr)
    .subs(masses)
)
assert in_phsp_expr.free_symbols == {σ1, σ2}, in_phsp_expr.free_symbols
in_phsp = create_function(in_phsp_expr, backend="numpy")

In [None]:
# Numerical mass values, looked up by symbol so that the result does not
# depend on the insertion order (or extra entries) of `masses`.
m0_val, m1_val, m2_val, m3_val = (masses[s] for s in (m0, m1, m2, m3))
# Bounding box of the Dalitz plane:
#   σ1 ∈ [(m2 + m3)², (m0 − m1)²],  σ2 ∈ [(m1 + m3)², (m0 − m2)²]
σ1_min = (m2_val + m3_val) ** 2
σ1_max = (m0_val - m1_val) ** 2
σ2_min = (m1_val + m3_val) ** 2
σ2_max = (m0_val - m2_val) ** 2

In [None]:
kinematic_variables = {
    # Model kinematic variables, expressed with the numerical masses inserted
    **{
        symbol: expression.doit().subs(masses)
        for symbol, expression in model.variables.items()
    },
    # Identity entries so that σ1, σ2, σ3 themselves are part of the output
    **{σ: σ for σ in (σ1, σ2, σ3)},
}
transformer = SympyDataTransformer.from_sympy(kinematic_variables, backend="jax")

In [None]:
def generate_uniform_phsp(resolution: int):
    """Evaluate the kinematics on a regular (σ1, σ2) grid.

    The grid spans the bounding box ``[σ1_min, σ1_max] × [σ2_min, σ2_max]``.
    Points outside the kinematically allowed region are included as well.

    Args:
        resolution: Number of grid points along each axis.

    Returns:
        Tuple ``(X, Y, data, phsp)``:

        - ``X``, ``Y``: ``np.meshgrid`` arrays of σ1 and σ2 values.
        - ``data``: output of ``transformer`` applied to the σ arrays.
        - ``phsp``: output of the phase-space indicator ``in_phsp`` on the grid.
    """
    x = np.linspace(σ1_min, σ1_max, num=resolution)
    y = np.linspace(σ2_min, σ2_max, num=resolution)
    X, Y = np.meshgrid(x, y)
    compute_third_mandelstam = create_function(σ3_expr.subs(masses), backend="jax")
    # Call through the DataSample interface so that arguments are matched by
    # name instead of relying on the positional order of the lambdified function.
    Z = compute_third_mandelstam({"sigma1": X, "sigma2": Y})
    σ_arrays = {"sigma1": X, "sigma2": Y, "sigma3": Z}
    data = transformer(σ_arrays)
    phsp = in_phsp(σ_arrays)
    return X, Y, data, phsp

In [None]:
def generate_phasespace_sample(n_events: int, seed=None):
    """Generate ``n_events`` points inside the Dalitz-plane phase space.

    Candidate (σ1, σ2) points are drawn with a uniform RNG over the bounding
    box ``[σ1_min, σ1_max] × [σ2_min, σ2_max]``. They are then filtered with
    the phase-space indicator, which is set to 0 outside phase space.
    Afterwards a ``"sigma3"`` column is computed from σ1 and σ2.

    Args:
        n_events: Number of events to generate.
        seed: Seed passed on to ``NumpyUniformRNG``.

    Returns:
        The generated sample, with ``"sigma1"``, ``"sigma2"`` and ``"sigma3"`` entries.
    """
    box = {
        "sigma1": (σ1_min, σ1_max),
        "sigma2": (σ2_min, σ2_max),
    }
    filter_expr = is_within_phasespace(
        σ1, σ2, m0, m1, m2, m3, outside_value=0
    ).doit()
    filter_expr = filter_expr.subs(σ3, σ3_expr).subs(masses)
    generator = IntensityDistributionGenerator(
        NumpyDomainGenerator(boundaries=box),
        create_function(filter_expr, backend="numpy"),
    )
    sample = generator.generate(n_events, NumpyUniformRNG(seed))
    third_mandelstam = create_function(σ3_expr.subs(masses), backend="numpy")
    sample["sigma3"] = third_mandelstam(sample)
    return sample

In [None]:
# Render inline Matplotlib figures as SVG (vector graphics)
%config InlineBackend.figure_formats = ['svg']

In [None]:
def __plot_phsp():
    """Draw the allowed (σ1, σ2) region as a light-gray area with a black outline.

    The fill and the outline use explicit contour levels. The original approach
    modified ``ContourSet.collections``, which was deprecated in Matplotlib 3.8
    and removed in 3.10.
    """
    X, Y, _, phsp = generate_uniform_phsp(resolution=500)
    # NaN entries (expected outside phase space) become 0, and any positive
    # value counts as inside. Binarizing makes the region boundary the
    # 0.5 level, independent of the exact indicator value used inside.
    inside = (np.nan_to_num(phsp) > 0).astype(float)
    _, ax = plt.subplots(figsize=(4, 4))
    ax.set_xlabel(R"$\sigma_1$")
    ax.set_ylabel(R"$\sigma_2$")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.contourf(X, Y, inside, levels=[0.5, 1.5], colors="lightgray")
    ax.contour(X, Y, inside, levels=[0.5], colors="black")
    plt.show()


__plot_phsp()