# Polarization sensitivity

In [None]:
%%capture
%run ./phase-space.ipynb

```{autolink-concat}
```

In [None]:
# pyright: reportUndefinedVariable=false
from __future__ import annotations

import re
from typing import Pattern

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.sympy import PoolSum
from IPython.display import Markdown, Math
from matplotlib import cm
from matplotlib.colors import LogNorm
from sympy.physics.matrices import msigma
from tensorwaves.function import ParametrizedBackendFunction
from tensorwaves.function.sympy import create_parametrized_function
from tqdm.notebook import tqdm

from polarization import formulate_polarization
from polarization.io import perform_cached_doit

## SymPy expressions

In [None]:
%%time
polarization_exprs = formulate_polarization(amplitude_builder)
# Each polarization component is evaluated, its amplitude symbols are replaced by
# their definitions from the model, and the result is unfolded via the cached doit().
amplitude_definitions = model.amplitudes
unfolded_polarization_exprs = [
    perform_cached_doit(component.doit().xreplace(amplitude_definitions))
    for component in tqdm(
        polarization_exprs, desc="Unfolding polarization expressions"
    )
]
unfolded_intensity_expr = perform_cached_doit(model.full_expression)

## Definition of free parameters

In [None]:
# Production couplings stay free; every other default parameter (plus the masses)
# is treated as fixed and will be substituted into the expressions.
free_parameters = {}
fixed_parameters = {}
for symbol, value in model.parameter_defaults.items():
    if symbol in production_couplings:
        free_parameters[symbol] = value
    else:
        fixed_parameters[symbol] = value
fixed_parameters.update(masses)

In [None]:
# Replace all fixed (non-production-coupling) parameters by their numerical values
subs_polarization_exprs = [
    unfolded.xreplace(fixed_parameters)
    for unfolded in unfolded_polarization_exprs
]
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)

In [None]:
# Report how much smaller each expression became after substituting fixed parameters
report_lines = ["Number of operations after substituting non-production-couplings:"]
for index, (unsubstituted, substituted) in enumerate(
    zip(unfolded_polarization_exprs, subs_polarization_exprs)
):
    component = "xyz"[index]
    n_before = sp.count_ops(unsubstituted)
    n_after = sp.count_ops(substituted)
    report_lines.append(Rf"- $\alpha_{component}$: from {n_before:,} to {n_after:,}")
report_lines.append(
    Rf"- $I_\mathrm{{tot}}$: from {sp.count_ops(unfolded_intensity_expr):,} to"
    rf" {sp.count_ops(subs_intensity_expr):,}"
)
src = "\n".join(report_lines)
Markdown(src)

## Polarization distributions

In [None]:
# Lambdify the substituted expressions with JAX; only production couplings remain
# as (mutable) parameters of the resulting functions.
function_options = dict(parameters=free_parameters, backend="jax")
polarization_funcs = [
    create_parametrized_function(subs_polarization_exprs[xyz], **function_options)
    for xyz in tqdm(range(3))
]
intensity_func = create_parametrized_function(subs_intensity_expr, **function_options)

In [None]:
def compute_sub_func(
    func: ParametrizedBackendFunction,
    input_data,
    non_zero_couplings: list[str] | str,
):
    r"""Evaluate ``func`` with only the selected production couplings switched on.

    Every parameter whose name matches ``\mathcal{H}...[`` and whose bracket
    content does not start with one of ``non_zero_couplings`` is temporarily set
    to zero via `set_parameter_to_zero`. The original parameter values are
    restored afterwards, also when the evaluation raises.

    Args:
        func: Parametrized function whose parameters are temporarily modified.
        input_data: Data passed on to ``func``.
        non_zero_couplings: Identifier(s) of the couplings that keep their value,
            e.g. ``["K"]``. A single string is treated as ONE identifier (not as a
            sequence of characters).

    Returns:
        The output of ``func(input_data)`` with the other couplings set to zero.
    """
    if isinstance(non_zero_couplings, str):
        non_zero_couplings = [non_zero_couplings]
    # Escape identifiers so regex metacharacters in them are matched literally
    alternatives = "|".join(re.escape(coupling) for coupling in non_zero_couplings)
    pattern = rf"\\mathcal{{H}}.*\[(?!{alternatives})"
    old_parameters = dict(func.parameters)
    try:
        set_parameter_to_zero(func, pattern)
        return func(input_data)
    finally:
        # Always restore, so a failing evaluation does not leave couplings zeroed
        func.update_parameters(old_parameters)


def set_parameter_to_zero(
    func: ParametrizedBackendFunction, search_term: Pattern | str
) -> None:
    """Set every parameter of ``func`` whose name matches ``search_term`` to zero.

    Matching uses `re.match`, i.e. the pattern is anchored at the start of the
    parameter name. If no parameter matches, a warning is logged and the
    parameters are left unchanged.

    Args:
        func: Function whose parameters are updated in place through
            ``func.update_parameters``.
        search_term: Regular expression (string or compiled) for parameter names.
    """
    import logging  # not imported at file level in this notebook

    new_parameters = dict(func.parameters)
    matched_names = [
        name for name in new_parameters if re.match(search_term, name) is not None
    ]
    for name in matched_names:
        new_parameters[name] = 0
    if not matched_names:
        # The condition detects that NOTHING matched, so say that
        logging.warning("No couplings were set to zero for search term %s", search_term)
    func.update_parameters(new_parameters)


def render_mean(array, weights=None, plus=True):
    r"""Format the mean and spread of the real part of ``array`` as LaTeX.

    Without ``weights``, the NaN-ignoring mean and standard deviation are used.
    Otherwise the weighted mean and the square root of the weighted variance are
    computed with `compute_weighted_average`.

    Args:
        array: Values to summarize; only the real part is used.
        weights: Optional weights for the average.
        plus: Prefix ``+`` when the mean, rounded to three decimals, is positive.

    Returns:
        A string of the form ``"<mean> \pm <std>"`` with three decimals each.
    """
    values = array.real
    if weights is not None:
        mean = compute_weighted_average(values, weights)
        std = jnp.sqrt(compute_weighted_average((values - mean) ** 2, weights))
    else:
        mean, std = jnp.nanmean(values), jnp.nanstd(values)
    mean_text = f"{mean:.3f}"
    std_text = f"{std:.3f}"
    sign = "+" if plus and float(mean_text) > 0 else ""
    return Rf"{sign}{mean_text} \pm {std_text}"


def compute_weighted_average(values, weights):
    """Return the weighted mean of ``values``, ignoring NaN entries.

    An entry is skipped entirely (in both numerator and denominator) when
    ``values * weights`` is NaN there. Otherwise a NaN value would drop out of
    the sum while its weight still counts, biasing the mean toward zero.
    Scalar ``weights`` broadcast against ``values``.

    Args:
        values: Array of values to average.
        weights: Weights broadcastable to ``values``.

    Returns:
        ``sum(w * v) / sum(w)`` over the non-NaN entries.
    """
    weighted_values = values * weights
    valid = ~jnp.isnan(weighted_values)
    total_weight = jnp.sum(jnp.where(valid, weights, 0))
    return jnp.sum(jnp.where(valid, weighted_values, 0)) / total_weight


def create_average_polarization_table(average_with_intensity: bool) -> None:
    r"""Display the averaged polarization per sub-system as a LaTeX table.

    For each pair in ``subsystem_labels``/``subsystem_identifiers``, the three
    polarization components are evaluated on ``data_sample``. Only that
    sub-system's couplings are switched on (see `compute_sub_func`). The table
    shows mean :math:`\pm` standard deviation of :math:`|\alpha|` and of each
    component.

    Args:
        average_with_intensity: If True, weight the averages with
            ``intensity_func(data_sample)``; otherwise use NaN-ignoring means.
    """
    # Five columns: sub-system label, |alpha|, alpha_x, alpha_y, alpha_z
    latex = R"\begin{array}{ccccc}" + "\n"
    latex += R"& \bar{|\alpha|} & \bar\alpha_x & \bar\alpha_y & \bar\alpha_z \\" + "\n"
    weights = intensity_func(data_sample) if average_with_intensity else None
    for label, identifier in zip(subsystem_labels, subsystem_identifiers):
        latex += f"  {label} & "
        # Real parts only; the |alpha| maps in the plotting cell are also built
        # from real parts
        x, y, z = (
            jnp.real(
                compute_sub_func(polarization_funcs[xyz], data_sample, [identifier])
            )
            for xyz in range(3)
        )
        alpha_abs = jnp.sqrt(x**2 + y**2 + z**2)
        latex += render_mean(alpha_abs, weights, plus=False) + " & "
        latex += " & ".join(render_mean(i, weights) for i in [x, y, z])
        latex += R" \\" + "\n"
    latex += R"\end{array}"
    display(Math(latex))


X, Y, data_sample, phsp_filter = generate_uniform_phsp(resolution=400)

# Identifiers are matched against coupling names in compute_sub_func; the labels
# are the LaTeX names shown in the tables and plots.
subsystem_identifiers = ["K", "L", "D"]
subsystem_labels = ["K^{**}", R"\Lambda^{**}", R"\Delta^{**}"]
for weight_by_intensity in (False, True):
    if weight_by_intensity:
        display(Markdown("Average with intensity weights:"))
    create_average_polarization_table(average_with_intensity=weight_by_intensity)

In [None]:
%config InlineBackend.figure_formats = ['png']

In [None]:
%%time
# One row for the full model plus one per sub-system; columns are
# I_tot, |alpha|, alpha_x, alpha_y, alpha_z.
nrows, ncols = 4, 5
scale = 3.0
aspect_ratio = 1.15
fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=scale * np.array([ncols, aspect_ratio * nrows]),
    # last column is wider: it also hosts the colour bar of alpha_z
    gridspec_kw={"width_ratios": [1] * (ncols - 1) + [1.24]},
    sharex=True,
    sharey=True,
    tight_layout=True,
)

s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
column_titles = [R"I_\mathrm{tot}", R"|\alpha|", *(Rf"\alpha_{c}" for c in "xyz")]
for row, row_axes in enumerate(axes):
    for col, ax in enumerate(row_axes):
        title = column_titles[col]
        if row > 0:
            # rows after the first show a single sub-system contribution
            title = Rf"{title}\left({subsystem_labels[row - 1]}\right)"
        ax.set_title(f"${title}$")
        if row == nrows - 1:
            ax.set_xlabel(s1_label)
        if col == 0:
            ax.set_ylabel(s2_label)

intensity_arrays = []
polarization_arrays = []
for subsystem in range(nrows):
    # Row 0 uses the full model; row k > 0 keeps only the couplings of
    # subsystem_identifiers[k - 1] (all other couplings set to zero).
    # alpha_xyz distributions
    alpha_xyz_arrays = []
    for i in range(2, ncols):
        xyz = i - 2
        if subsystem == 0:
            z_values = polarization_funcs[xyz](data_sample)
            polarization_arrays.append(z_values)
        else:
            identifier = subsystem_identifiers[subsystem - 1]
            # compute_sub_func expects a *list* of coupling identifiers
            z_values = compute_sub_func(
                polarization_funcs[xyz], data_sample, [identifier]
            )
        z_values = np.real(z_values)
        alpha_xyz_arrays.append(z_values)
        mesh = axes[subsystem, i].pcolormesh(X, Y, z_values, cmap=cm.coolwarm)
        if xyz == 2:
            # a single colour bar next to the last column (alpha_z)
            fig.colorbar(mesh, ax=axes[subsystem, i])
        mesh.set_clim(vmin=-1, vmax=+1)
    # absolute value of alpha_xyz vector
    alpha_abs = np.sqrt(np.sum(np.array(alpha_xyz_arrays) ** 2, axis=0))
    mesh = axes[subsystem, 1].pcolormesh(X, Y, alpha_abs, cmap=cm.coolwarm)
    mesh.set_clim(vmin=-1, vmax=+1)
    # total intensity (logarithmic colour scale)
    if subsystem == 0:
        z_values = intensity_func(data_sample)
    else:
        identifier = subsystem_identifiers[subsystem - 1]
        z_values = compute_sub_func(intensity_func, data_sample, [identifier])
    intensity_arrays.append(z_values)
    axes[subsystem, 0].pcolormesh(X, Y, z_values, norm=LogNorm())
plt.show()

In [None]:
fig, axes = plt.subplots(
    ncols=3,
    figsize=(13, 5),
    sharey=True,
    gridspec_kw={"width_ratios": [1, 1, 1.2]},
    tight_layout=True,
)
axes[0].set_ylabel(s2_label)
# alpha_i * I_tot for the full model (first row of the grid figure)
full_intensity = intensity_arrays[0]
I_times_alpha = jnp.array([alpha * full_intensity for alpha in polarization_arrays])
# symmetric colour range shared by all three panels
global_min_max = float(jnp.nanmax(jnp.abs(I_times_alpha)))
for ax, product, component in zip(axes, I_times_alpha, "xyz"):
    ax.set_xlabel(s1_label)
    ax.set_title(Rf"$\alpha_{component} \cdot I$")
    mesh = ax.pcolormesh(X, Y, np.real(product), cmap=cm.RdYlGn_r)
    mesh.set_clim(vmin=-global_min_max, vmax=global_min_max)
# colour bar attached to the last panel drawn in the loop
color_bar = fig.colorbar(mesh, ax=ax, pad=0.02)
plt.show()

In [None]:
%config InlineBackend.figure_formats = ['svg']

In [None]:
fig, ax = plt.subplots(figsize=(9, 8), tight_layout=True)
ax.set_title(R"Total polarization field $\vec{\alpha}$")
ax.set_xlabel(s1_label + R",$\quad\vec{\alpha}_x$")
ax.set_ylabel(s2_label + R",$\quad\vec{\alpha}_y$")
# Draw only every `strides`-th arrow along each axis to keep the field readable;
# alpha_x/alpha_y set the arrow direction, alpha_z its colour.
strides = 10
alpha_x, alpha_y, alpha_z = (
    np.real(polarization_arrays[k][::strides, ::strides]) for k in range(3)
)
mesh = ax.quiver(
    X[::strides, ::strides],
    Y[::strides, ::strides],
    alpha_x,
    alpha_y,
    alpha_z,
    cmap=cm.coolwarm,
)
mesh.set_clim(vmin=-1, vmax=+1)
color_bar = fig.colorbar(mesh, ax=ax, pad=0.01)
color_bar.ax.set_yticks([-1, 0, 1])
color_bar.ax.set_yticklabels(["-1", "0", "+1"])
color_bar.set_label(R"$\vec{\alpha}_z$")
plt.show()

## Benchmarking

:::{tip}
Compare these timings with the Julia results reported in [redeboer/polarization-sensitivity#27](https://github.com/redeboer/polarization-sensitivity/issues/27).
:::

In [None]:
# Generate a reproducible 100k-event phase-space sample, time the kinematic
# transformation, and keep the transformed sample in `phsp` for later benchmarks
phsp = generate_phasespace_sample(100_000, seed=0)
%timeit transformer(phsp)
phsp = transformer(phsp)

In [None]:
# Take the first event of each array-valued entry (0-d entries are kept as-is)
# to time the intensity for a single phase-space point
random_point = {k: v[0] if len(v.shape) > 0 else v for k, v in phsp.items()}
%timeit intensity_func(random_point)

In [None]:
%%timeit
polarization_funcs[0](random_point)
polarization_funcs[1](random_point)
polarization_funcs[2](random_point)

In [None]:
# Smaller 54x54 grid sample, used to time the intensity on a coarse grid
X54, Y54, data_sample54, _ = generate_uniform_phsp(resolution=54)
%timeit intensity_func(data_sample54)

In [None]:
%%timeit
polarization_funcs[0](data_sample54)
polarization_funcs[1](data_sample54)
polarization_funcs[2](data_sample54)

In [None]:
%config InlineBackend.figure_formats = ['png']

In [None]:
# Quick visual check of the intensity on the coarse 54x54 grid
fig, ax = plt.subplots(figsize=(4, 4))
intensity54 = intensity_func(data_sample54)
ax.pcolormesh(X54, Y54, intensity54)
plt.show()

In [None]:
%%timeit
polarization_funcs[0](data_sample54)
polarization_funcs[1](data_sample54)
polarization_funcs[2](data_sample54)

In [None]:
%%timeit
intensity_func(phsp)

In [None]:
%%timeit
polarization_funcs[0](phsp)
polarization_funcs[1](phsp)
polarization_funcs[2](phsp)