# Polarization sensitivity

In [None]:
%%capture
%run ./phase-space.ipynb

```{autolink-concat}
```

In [None]:
# pyright: reportUndefinedVariable=false
from __future__ import annotations

import logging
import re
from typing import Pattern

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.sympy import PoolSum
from IPython.display import Markdown, Math
from matplotlib import cm
from matplotlib.colors import LogNorm
from sympy.physics.matrices import msigma
from tensorwaves.function import ParametrizedBackendFunction
from tensorwaves.function.sympy import create_parametrized_function
from tqdm.notebook import tqdm

## SymPy expressions

In [None]:
def to_index(helicity):
    """Translate a (possibly symbolic) helicity into a Pauli matrix index.

    Returns a `sympy.Piecewise` that evaluates to ``1`` when
    ``helicity <= 0`` and to ``0`` in every other case, so the mapping stays
    valid while the helicity is still an unevaluated symbol.
    """
    is_not_positive = sp.LessThan(helicity, 0)
    return sp.Piecewise((1, is_not_positive), (0, True))


# Second helicity index; it is summed over together with ν and λ below.
ν_prime = sp.Symbol(R"\nu^{\prime}")
# One expression per Pauli matrix (msigma(1), msigma(2), msigma(3)): the
# conjugated amplitude, the Pauli matrix element selected by the two helicity
# indices, and the amplitude itself, summed over all three helicities and
# divided by the intensity expression.
polarization_exprs = [
    PoolSum(
        formulate_aligned_amplitude(ν, λ).conjugate()
        * msigma(pauli_index)[to_index(ν), to_index(ν_prime)]
        * formulate_aligned_amplitude(ν_prime, λ),
        (λ, [-half, +half]),
        (ν, [-half, +half]),
        (ν_prime, [-half, +half]),
    )
    / intensity_expr
    for pauli_index in (1, 2, 3)
]

In [None]:
%%time
# Each expression is unfolded in three steps: evaluate the outer sums with
# `doit()`, substitute the amplitude definitions, then call `doit()` again on
# whatever those definitions introduced.
unfolded_intensity_expr = intensity_expr.doit().xreplace(amp_definitions).doit()
unfolded_polarization_exprs = [
    polarization_expr.doit().xreplace(amp_definitions).doit()
    for polarization_expr in tqdm(
        polarization_exprs, desc="Unfolding polarization expressions"
    )
]

## Definition of free parameters

In [None]:
# Parameters that appear in `prod_couplings` stay free ...
free_parameters = {
    par: default
    for par, default in parameter_defaults.items()
    if par in prod_couplings
}
# ... and every other default is treated as a fixed value.
fixed_parameters = {
    par: default
    for par, default in parameter_defaults.items()
    if par not in free_parameters
}
# The entries of `masses` are fixed as well (overriding equal keys, if any).
fixed_parameters.update(masses)

In [None]:
# Substitute the fixed parameter values so that only the free parameters
# remain as symbols in the expressions.
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)
subs_polarization_exprs = [
    unfolded_expr.xreplace(fixed_parameters)
    for unfolded_expr in unfolded_polarization_exprs
]

In [None]:
# Build a Markdown bullet list comparing the operation count of each
# expression before and after substituting the fixed parameters.
report_lines = ["Number of operations after substituting non-production-couplings:"]
for axis_name, expr, subs_expr in zip(
    "xyz", unfolded_polarization_exprs, subs_polarization_exprs
):
    old_n_ops = sp.count_ops(expr)
    new_n_ops = sp.count_ops(subs_expr)
    report_lines.append(
        Rf"- $\alpha_{axis_name}$: from {old_n_ops:,} to {new_n_ops:,}"
    )
report_lines.append(
    Rf"- $I_\mathrm{{tot}}$: from {sp.count_ops(unfolded_intensity_expr):,} to"
    rf" {sp.count_ops(subs_intensity_expr):,}"
)
src = "\n".join(report_lines)
Markdown(src)

## Polarization distributions

In [None]:
# Turn the substituted expressions into numerical functions on the JAX
# backend; the free parameters remain adjustable on the resulting functions.
polarization_funcs = [
    create_parametrized_function(
        polarization_expr,
        parameters=free_parameters,
        backend="jax",
    )
    for polarization_expr in tqdm(subs_polarization_exprs)
]
intensity_func = create_parametrized_function(
    subs_intensity_expr,
    parameters=free_parameters,
    backend="jax",
)

In [None]:
def compute_sub_func(
    func: ParametrizedBackendFunction,
    input_data,
    non_zero_couplings: list[str] | str,
) -> np.ndarray:
    """Evaluate `func` with only a selection of couplings switched on.

    Every parameter whose name matches ``\\mathcal{H}...[`` and whose bracket
    content does **not** start with one of the given identifiers is
    temporarily set to zero. The original parameter values are restored
    afterwards, also when the evaluation raises.

    Args:
        func: Function whose parameters are temporarily modified.
        input_data: Data that is passed on to `func` unchanged.
        non_zero_couplings: Identifiers (regular-expression fragments) of the
            couplings that keep their value. A single string is treated as a
            one-element list.

    Returns:
        The result of ``func(input_data)`` with the selected couplings only.
        (Annotated as `numpy.ndarray`; the concrete array type depends on the
        backend of `func`.)
    """
    if isinstance(non_zero_couplings, str):
        # Joining a bare string would split it into single characters.
        non_zero_couplings = [non_zero_couplings]
    old_parameters = dict(func.parameters)
    # Negative look-ahead: match couplings whose bracket does NOT start with
    # any of the identifiers that should stay non-zero.
    pattern = rf"\\mathcal{{H}}.*\[(?!{'|'.join(non_zero_couplings)})"
    set_parameter_to_zero(func, pattern)
    try:
        return func(input_data)
    finally:
        # `func` is shared between calls, so never leave it in a zeroed state.
        func.update_parameters(old_parameters)


def set_parameter_to_zero(
    func: ParametrizedBackendFunction, search_term: str | Pattern[str]
) -> None:
    """Set all parameters of `func` whose name matches `search_term` to zero.

    Names are tested with `re.match`, so the pattern has to match from the
    start of the parameter name. A warning is logged if no parameter name
    matches, because in that case the call leaves all values as they were.

    Args:
        func: Function whose parameters are updated in place.
        search_term: Regular expression (string or compiled pattern) that
            selects the parameters to set to zero.
    """
    zeroed_parameters = {
        par_name: 0
        for par_name in func.parameters
        if re.match(search_term, par_name) is not None
    }
    if not zeroed_parameters:
        logging.warning(
            "No couplings were set to zero for search term %s", search_term
        )
    new_parameters = dict(func.parameters)
    new_parameters.update(zeroed_parameters)
    func.update_parameters(new_parameters)


def render_mean(array, plus=True):
    r"""Format the mean and standard deviation of `array` as LaTeX.

    Only the real part of `array` is used and NaN entries are ignored. Both
    numbers are printed with three decimals and joined as ``mean \pm std``.
    If `plus` is true, a ``+`` is prepended when the *rounded* mean is
    strictly positive.
    """
    real_values = array.real
    mean_text = format(np.nanmean(real_values), ".3f")
    std_text = format(np.nanstd(real_values), ".3f")
    # Decide on the sign from the rounded text, so that a mean that rounds to
    # 0.000 is not decorated with a plus sign.
    sign = "+" if plus and float(mean_text) > 0 else ""
    return sign + mean_text + R" \pm " + std_text


X, Y, data_sample, phsp_filter = generate_uniform_phsp(resolution=400)

subsystem_identifiers = ["K", "L", "D"]
subsystem_labels = ["K^{**}", R"\Lambda^{**}", R"\Delta^{**}"]
# Table of averaged polarizations per sub-system. The array needs five
# columns: the sub-system label, |alpha|, and the x, y, z components.
latex = R"\begin{array}{ccccc}" + "\n"
latex += R"& \bar{|\alpha|} & \bar\alpha_x & \bar\alpha_y & \bar\alpha_z \\" + "\n"
for label, identifier in zip(subsystem_labels, subsystem_identifiers):
    # Polarization components with only the couplings of this sub-system
    # switched on.
    x, y, z = (
        compute_sub_func(polarization_funcs[xyz], data_sample, [identifier])
        for xyz in range(3)
    )
    cells = [
        label,
        # A norm is never negative, so no explicit plus sign here.
        render_mean(np.sqrt(x**2 + y**2 + z**2), plus=False),
        *map(render_mean, [x, y, z]),
    ]
    latex += "  " + " & ".join(cells) + R" \\" + "\n"
latex += R"\end{array}"
Math(latex)

In [None]:
%config InlineBackend.figure_formats = ['png']

In [None]:
%%time
# Grid layout: one row per model variant, one column per plotted quantity.
nrows = 4
ncols = 5
# Size of a single panel (in inches) and its height-to-width ratio.
scale = 3.0
aspect_ratio = 1.15
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=scale * np.array([ncols, aspect_ratio * nrows]),
    sharex=True,
    sharey=True,
    tight_layout=True,
    # The last column is made wider than the others.
    gridspec_kw={"width_ratios": [1] * (ncols - 1) + [1.24]},
)

s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
# Titles and axis labels: column 0 shows the intensity, column 1 the norm of
# the polarization vector and the remaining columns its x, y, z components.
for subsystem, axes_row in enumerate(axes):
    for i, ax in enumerate(axes_row):
        if i == 0:
            quantity = R"I_\mathrm{tot}"
        elif i == 1:
            quantity = R"|\alpha|"
        else:
            quantity = Rf"\alpha_{'xyz'[i - 2]}"
        if subsystem == 0:
            # The first row has no sub-system label appended.
            title = quantity
        else:
            label = subsystem_labels[subsystem - 1]
            title = Rf"{quantity}\left({label}\right)"
        ax.set_title(f"${title}$")
        # Only the bottom row gets an x label, only the first column a y label.
        if subsystem == nrows - 1:
            ax.set_xlabel(s1_label)
        if i == 0:
            ax.set_ylabel(s2_label)

# Fill the grid row by row. Row 0 uses the functions as they are; each
# following row keeps only the couplings of one sub-system non-zero.
for subsystem in range(nrows):
    # alpha_xyz distributions
    alpha_xyz_arrays = []
    for i in range(2, ncols):
        xyz = i - 2
        if subsystem == 0:
            z_values = polarization_funcs[xyz](data_sample)
        else:
            identifier = subsystem_identifiers[subsystem - 1]
            # compute_sub_func expects a list of identifiers; a bare string
            # only works by accident for single-character identifiers.
            z_values = compute_sub_func(
                polarization_funcs[xyz], data_sample, [identifier]
            )
        z_values = np.real(z_values)
        alpha_xyz_arrays.append(z_values)
        mesh = axes[subsystem, i].pcolormesh(X, Y, z_values, cmap=cm.coolwarm)
        if xyz == 2:
            # One colorbar per row, attached to the last (alpha_z) panel.
            fig.colorbar(mesh, ax=axes[subsystem, i])
        # Same colour scale for all polarization panels.
        mesh.set_clim(vmin=-1, vmax=+1)
    # absolute value of alpha_xyz vector
    alpha_abs = np.sqrt(np.sum(np.array(alpha_xyz_arrays) ** 2, axis=0))
    mesh = axes[subsystem, 1].pcolormesh(X, Y, alpha_abs, cmap=cm.coolwarm)
    mesh.set_clim(vmin=-1, vmax=+1)
    # total intensity
    if subsystem == 0:
        z_values = intensity_func(data_sample)
    else:
        identifier = subsystem_identifiers[subsystem - 1]
        z_values = compute_sub_func(intensity_func, data_sample, [identifier])
    # The intensity is drawn on a logarithmic colour scale.
    axes[subsystem, 0].pcolormesh(X, Y, z_values, norm=LogNorm())
plt.show()