# Polarization sensitivity

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import json
import math
import os
from functools import reduce

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sympy as sp
from IPython.display import Markdown, Math, display
from matplotlib import cm
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection
from matplotlib.colors import LogNorm
from tensorwaves.function.sympy import create_parametrized_function
from tensorwaves.interface import DataSample
from tqdm.auto import tqdm

from polarization import formulate_polarization
from polarization.amplitude import DalitzPlotDecompositionBuilder
from polarization.data import (
    create_data_transformer,
    generate_meshgrid_sample,
    generate_phasespace_sample,
)
from polarization.function import compute_sub_function
from polarization.io import mute_jax_warnings, perform_cached_doit
from polarization.lhcb import _load_model_parameters, load_three_body_decays
from polarization.plot import get_contour_line, stylize_contour

# Mute JAX warnings (see polarization.io.mute_jax_warnings).
mute_jax_warnings()

# Build the amplitude model from the isobar definitions. Subsystem 1 is used
# as the reference subsystem when formulating the Dalitz-plot decomposition.
reference_subsystem = 1
dynamics_configurator = load_three_body_decays("../data/isobars.json")
decay = dynamics_configurator.decay
amplitude_builder = DalitzPlotDecompositionBuilder(decay)
amplitude_builder.dynamics_choices = dynamics_configurator
model = amplitude_builder.formulate(reference_subsystem)
# Overwrite the model's default parameter values with the values loaded from
# the model-parameter file.
imported_parameter_values = _load_model_parameters(
    "../data/modelparameters.json", decay
)
model.parameter_defaults.update(imported_parameter_values)

## SymPy expressions

In [None]:
%%time
# Formulate the polarization expressions for the reference subsystem (used as
# alpha_x, alpha_y, alpha_z further below) and unfold them. doit() evaluates
# the unevaluated SymPy expression, and xreplace() then substitutes the
# amplitude definitions from model.amplitudes.
# perform_cached_doit presumably applies doit() with a cache — TODO confirm.
polarization_exprs = formulate_polarization(amplitude_builder, reference_subsystem)
unfolded_polarization_exprs = [
    perform_cached_doit(expr.doit().xreplace(model.amplitudes))
    for expr in tqdm(polarization_exprs, desc="Unfolding polarization expressions")
]
# Same unfolding for the model's total intensity expression.
unfolded_intensity_expr = perform_cached_doit(model.full_expression)

## Definition of free parameters

In [None]:
# The production couplings (indexed symbols whose name contains "production")
# are kept as free parameters; all remaining parameters are fixed to their
# default values.
production_couplings = {
    par: default
    for par, default in model.parameter_defaults.items()
    if isinstance(par, sp.Indexed) and "production" in str(par)
}
fixed_parameters = {
    par: default
    for par, default in model.parameter_defaults.items()
    if par not in production_couplings
}

In [None]:
# Insert the fixed parameter values so that, of all model parameters, only the
# production couplings remain symbolic.
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)
subs_polarization_exprs = [
    unfolded.xreplace(fixed_parameters) for unfolded in unfolded_polarization_exprs
]

In [None]:
# Report, as a Markdown list, how far substituting the fixed parameters
# reduces the number of operations in each expression.
lines = ["Number of operations after substituting non-production-couplings:"]
for xyz, (expr, subs_expr) in enumerate(
    zip(unfolded_polarization_exprs, subs_polarization_exprs)
):
    n_before = sp.count_ops(expr)
    n_after = sp.count_ops(subs_expr)
    lines.append(Rf"- $\alpha_{'xyz'[xyz]}$: from {n_before:,} to {n_after:,}")
n_before = sp.count_ops(unfolded_intensity_expr)
n_after = sp.count_ops(subs_intensity_expr)
lines.append(Rf"- $I_\mathrm{{tot}}$: from {n_before:,} to {n_after:,}")
src = "\n".join(lines)
Markdown(src)

## Polarization distributions

In [None]:
# Turn the substituted expressions into parametrized JAX-backend functions in
# which only the production couplings remain adjustable parameters.
intensity_func = create_parametrized_function(
    subs_intensity_expr,
    parameters=production_couplings,
    backend="jax",
)
polarization_funcs = [
    create_parametrized_function(
        subs_polarization_exprs[i],
        parameters=production_couplings,
        backend="jax",
    )
    for i in tqdm(range(3))
]

In [None]:
def render_mean(array, weights, plus=True):
    """Format the weighted mean and standard deviation of *array* as LaTeX.

    Only the real part of *array* is used. Both numbers are rounded to three
    decimals and joined by ``\\pm``. If *plus* is true, a positive mean gets
    an explicit ``+`` sign.
    """
    values = array.real
    average = compute_weighted_average(values, weights)
    spread = jnp.sqrt(compute_weighted_average((values - average) ** 2, weights))
    mean_str = f"{average:.3f}"
    # Decide on the sign using the rounded value, so "0.000" gets no "+".
    if plus and float(mean_str) > 0:
        mean_str = "+" + mean_str
    return rf"{mean_str} \pm {spread:.3f}"


def compute_weighted_average(values, weights):
    """Return the *weights*-weighted average of *values*, ignoring NaNs.

    A (value, weight) pair is skipped whenever the value, the weight or their
    product is NaN. The pair is dropped from both the numerator and the
    normalisation, so a NaN value no longer contributes its weight to the
    denominator, which would bias the average towards zero.
    """
    values = jnp.asarray(values)
    weights = jnp.asarray(weights)
    products = values * weights
    valid = ~jnp.isnan(products)
    numerator = jnp.sum(jnp.where(valid, products, 0))
    denominator = jnp.sum(jnp.where(valid, weights, 0))
    return numerator / denominator


def create_average_polarization_table() -> None:
    """Display a LaTeX table of intensity-weighted polarization averages.

    The table has one row per subsystem, taken from the module-level
    ``subsystem_labels`` and ``subsystem_identifiers``. Each row holds the
    mean and standard deviation of the polarization magnitude and of the x,
    y and z components, weighted by the full-model intensity on
    ``data_sample``.
    """
    column_headers = [
        R"\bar{|\alpha|}",
        R"\bar\alpha_x",
        R"\bar\alpha_y",
        R"\bar\alpha_z",
    ]
    # Each row has a label column plus one column per header, so the array
    # needs len(column_headers) + 1 column specifiers.
    n_columns = 1 + len(column_headers)
    latex = R"\begin{array}{" + n_columns * "c" + "}" + "\n"
    latex += "& " + " & ".join(column_headers) + R" \\" + "\n"
    weights = intensity_func(data_sample)
    for label, identifier in zip(subsystem_labels, subsystem_identifiers):
        latex += f"  {label} & "
        x, y, z = (
            compute_sub_function(polarization_funcs[xyz], data_sample, [identifier])
            for xyz in range(3)
        )
        alpha_abs = jnp.sqrt(x**2 + y**2 + z**2)
        latex += render_mean(alpha_abs, weights, plus=False) + " & "
        latex += " & ".join(render_mean(component, weights) for component in [x, y, z])
        latex += R" \\" + "\n"
    latex += R"\end{array}"
    display(Math(latex))


# Meshgrid sample (resolution=400) over sigma1 and sigma2. X and Y keep the
# raw grid coordinates for plotting; data_sample is the transformed input for
# the model functions.
phsp_sample = generate_meshgrid_sample(decay, resolution=400)
X = phsp_sample["sigma1"]
Y = phsp_sample["sigma2"]
transformer = create_data_transformer(model)
data_sample = transformer(phsp_sample)

# Filters passed to compute_sub_function to select the resonances of each
# subsystem, together with the matching LaTeX labels.
subsystem_identifiers = ["K", "L", "D"]
subsystem_labels = ["K^{**}", R"\Lambda^{**}", R"\Delta^{**}"]
create_average_polarization_table()

In [None]:
# Render the pcolormesh-heavy figures that follow as PNG images.
%config InlineBackend.figure_formats = ['png']

In [None]:
def create_dominant_region_contours(
    decay, data_sample: DataSample, threshold: float
) -> dict[str, jnp.ndarray]:
    """Compute integer masks of the regions where one resonance dominates.

    For each resonance in ``decay.chains``, compute the ratio of its
    sub-intensity to the total intensity. A grid point belongs to that
    resonance's region if the ratio is at least *threshold* (NaN ratios are
    excluded). The regions of all resonances whose name starts with "K", "L"
    or "D" are OR-ed together.

    Returns:
        A mapping from subsystem ("K", "L", "D") to an integer array. The
        array is the subsystem's contour level (1, 2 or 3 respectively)
        inside its dominant region and 0 elsewhere. A subsystem without any
        dominant region is left out of the mapping; previously this case
        raised a ``TypeError`` from ``reduce`` on an empty sequence.
    """
    import re

    I_tot = intensity_func(data_sample)
    resonance_names = [chain.resonance.name for chain in decay.chains]
    region_filters = {}
    # Context manager so the progress bar is always closed.
    with tqdm(
        desc="Computing dominant region contours",
        total=len(resonance_names),
    ) as progress_bar:
        for resonance_name in resonance_names:
            progress_bar.postfix = resonance_name
            # compute_sub_function treats the filter as a regex, so escape
            # every special character in the name, not only parentheses.
            regex_filter = re.escape(resonance_name)
            I_sub = compute_sub_function(intensity_func, data_sample, [regex_filter])
            ratio = I_sub / I_tot
            # NaN compares False, so NaN ratios fall outside the region,
            # just like ratios below the threshold.
            selection = jnp.where(ratio >= threshold, 1, 0)
            progress_bar.update()
            if not jnp.any(selection):
                continue
            region_filters[resonance_name] = selection
    contour_arrays = {}
    for contour_level, subsystem in enumerate(["K", "L", "D"], 1):
        masks = [a for k, a in region_filters.items() if k.startswith(subsystem)]
        if not masks:
            continue
        contour_arrays[subsystem] = contour_level * reduce(jnp.bitwise_or, masks)
    return contour_arrays


def indicate_dominant_regions(
    contour_arrays, ax: Axes, selected_subsystems=None
) -> dict[str, LineCollection]:
    """Outline the dominant-region masks of the selected subsystems on *ax*.

    Each selected entry of *contour_arrays* is drawn as a contour on the
    module-level X/Y grid, outlined in the subsystem's colour and labelled
    with its LaTeX name. By default all of "K", "L" and "D" are selected.

    Returns:
        Mapping from subsystem to the line collection of its outline, for use
        as legend handles.
    """
    wanted = {"K", "L", "D"} if selected_subsystems is None else set(selected_subsystems)
    styles = {
        "K": ("red", "K^{**}"),
        "L": ("blue", R"\Lambda^{**}"),
        "D": ("green", R"\Delta^{**}"),
    }
    legend_elements = {}
    for subsystem, Z in contour_arrays.items():
        if subsystem in wanted:
            # Invisible contour; the visible outline is applied by stylize_contour.
            contour_set = ax.contour(X, Y, Z, colors="none")
            color, latex_label = styles[subsystem]
            stylize_contour(
                contour_set,
                edgecolor=color,
                linewidth=0.5,
                label=f"${latex_label}$",
            )
            legend_elements[subsystem] = get_contour_line(contour_set)
    return legend_elements

In [None]:
%%time
nrows = 4
ncols = 5
scale = 3.0
aspect_ratio = 1.15
fig, axes = plt.subplots(
    figsize=scale * np.array([ncols, aspect_ratio * nrows]),
    ncols=ncols,
    nrows=nrows,
    sharex=True,
    sharey=True,
    gridspec_kw={"width_ratios": (ncols - 1) * [1] + [1.24]},
    tight_layout=True,
)

s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
for subsystem in range(nrows):
    for i in range(ncols):
        ax = axes[subsystem, i]
        if i == 0:
            alpha_str = R"I_\mathrm{tot}"
        elif i == 1:
            alpha_str = R"|\alpha|"
        else:
            xyz = i - 2
            alpha_str = Rf"\alpha_{'xyz'[xyz]}"
        title = alpha_str
        if subsystem > 0:
            label = subsystem_labels[subsystem - 1]
            title = Rf"{title}\left({label}\right)"
        ax.set_title(f"${title}$")
        if ax is axes[-1, i]:
            ax.set_xlabel(s1_label)
        if i == 0:
            ax.set_ylabel(s2_label)

intensity_arrays = []
polarization_arrays = []
for subsystem in range(nrows):
    # alpha_xyz distributions
    alpha_xyz_arrays = []
    for i in range(2, ncols):
        xyz = i - 2
        if subsystem == 0:
            z_values = polarization_funcs[xyz](data_sample)
            polarization_arrays.append(z_values)
        else:
            identifier = subsystem_identifiers[subsystem - 1]
            z_values = compute_sub_function(
                polarization_funcs[xyz], data_sample, identifier
            )
        z_values = np.real(z_values)
        alpha_xyz_arrays.append(z_values)
        mesh = axes[subsystem, i].pcolormesh(X, Y, z_values, cmap=cm.coolwarm)
        if xyz == 2:
            fig.colorbar(mesh, ax=axes[subsystem, i])
        mesh.set_clim(vmin=-1, vmax=+1)
    # absolute value of alpha_xyz vector
    alpha_abs = np.sqrt(np.sum(np.array(alpha_xyz_arrays) ** 2, axis=0))
    mesh = axes[subsystem, 1].pcolormesh(X, Y, alpha_abs, cmap=cm.coolwarm)
    mesh.set_clim(vmin=-1, vmax=+1)
    # total intensity
    if subsystem == 0:
        z_values = intensity_func(data_sample)
    else:
        identifier = subsystem_identifiers[subsystem - 1]
        z_values = compute_sub_function(intensity_func, data_sample, identifier)
    intensity_arrays.append(z_values)
    axes[subsystem, 0].pcolormesh(X, Y, z_values, norm=LogNorm())

threshold = 0.7
contour_arrays = create_dominant_region_contours(decay, data_sample, threshold)

for ax in axes[0]:
    legend_elements = indicate_dominant_regions(contour_arrays, ax)
    if ax is axes[0, -1]:
        leg = ax.legend(
            handles=legend_elements.values(),
            title=f">{100*threshold:.0f}%",
        )

for subsystem, ax_row in zip(["K", "L", "D"], axes[1:]):
    for ax in ax_row:
        indicate_dominant_regions(contour_arrays, ax, selected_subsystems=[subsystem])

plt.show()

In [None]:
fig, axes = plt.subplots(
    figsize=(13, 5),
    ncols=3,
    gridspec_kw={"width_ratios": [1, 1, 1.2]},
    sharey=True,
    tight_layout=True,
)
axes[0].set_ylabel(s2_label)
# Full-model polarization components (stacked along the first axis), each
# multiplied elementwise by the full-model intensity.
I_times_alpha = jnp.array(polarization_arrays) * intensity_arrays[0]
# Symmetric colour range shared by all three panels.
global_min_max = float(jnp.nanmax(jnp.abs(I_times_alpha)))
for ax, z_values, xyz in zip(axes, I_times_alpha, "xyz"):
    ax.set_title(Rf"$\alpha_{xyz} \cdot I$")
    ax.set_xlabel(s1_label)
    mesh = ax.pcolormesh(X, Y, np.real(z_values), cmap=cm.RdYlGn_r)
    mesh.set_clim(vmin=-global_min_max, vmax=global_min_max)
# A single colour bar, attached to the last panel.
fig.colorbar(mesh, ax=axes[-1], pad=0.02)
plt.show()

In [None]:
# Switch back to SVG output for the figure that follows.
%config InlineBackend.figure_formats = ['svg']

In [None]:
# Stack the full-model polarization components; the order follows
# polarization_funcs: index 0 = alpha_x, 1 = alpha_y, 2 = alpha_z.
polarization_arrays = jnp.array(polarization_arrays)
alpha_abs = jnp.sqrt(jnp.sum(polarization_arrays**2, axis=0))
fig, ax = plt.subplots(figsize=(8, 7), tight_layout=True)
ax.set_title(R"Total polarization field $\vec{\alpha}$")
ax.set_xlabel(s1_label + R",$\quad\vec{\alpha}_x$")
ax.set_ylabel(s2_label + R",$\quad\vec{\alpha}_z$")
# Draw only every `strides`-th arrow along each axis so the field stays readable.
strides = 12
mesh = ax.quiver(
    X[::strides, ::strides],
    Y[::strides, ::strides],
    np.real(polarization_arrays[0][::strides, ::strides]),
    # alpha_z is index 2. There are only three components, and JAX silently
    # clamps an out-of-range index such as 3 instead of raising.
    np.real(polarization_arrays[2][::strides, ::strides]),
    np.real(alpha_abs[::strides, ::strides]),
    cmap=cm.RdYlGn_r,
)
mesh.set_clim(vmin=0, vmax=+1)
color_bar = fig.colorbar(mesh, ax=ax, pad=0.01)
color_bar.set_label(R"$\left|\vec{\alpha}\right|$")
plt.show()

## Serialization

In [None]:
# Serialize the full-model alpha_x distribution on the meshgrid in several
# formats (file sizes are compared in the next cell).
array = polarization_funcs[0](data_sample).real
# The mesh arrays vary along sigma1 across columns (X[0]) and along sigma2
# down rows (Y[:, 0]), consistent with how X and Y are sliced here and passed
# to pcolormesh. Label the DataFrame accordingly: rows = sigma2,
# columns = sigma1.
df = pd.DataFrame(array, index=Y[:, 0], columns=X[0])
df.to_json("alpha-x-pandas.json")
df.to_json("alpha-x-pandas-json.zip", compression={"method": "zip"})
df.to_csv("alpha-x-pandas.csv")

# DataFrame.to_dict() is keyed {column: {index: value}}, i.e.
# {sigma1: {sigma2: alpha_x}}. NaN entries (presumably grid points outside
# the phase space) are dropped to keep the JSON compact.
df_dict = df.to_dict()
filtered_df_dict = {
    x: {y: v for y, v in row.items() if not math.isnan(v)} for x, row in df_dict.items()
}
with open("alpha-x-python.json", "w", encoding="utf-8") as f:
    json.dump(filtered_df_dict, f, separators=(",", ":"))

In [None]:
def render_kilobytes(path, markdown: bool = False) -> str:
    """Return a newline-prefixed line reporting the size of the file *path*.

    Despite the function name, the size is given in megabytes (10**6 bytes)
    with two decimals. With *markdown*, the line is a bold Markdown bullet
    followed by a MyST ``{download}`` role for *path*. Otherwise it is a
    plain-text line with a right-aligned size.
    """
    megabytes = 1e-6 * os.path.getsize(path)
    if not markdown:
        return f"\n  {megabytes:5.2f} MB  {path}"
    return f"\n - **{megabytes:.2f} MB**: {{download}}`{path}`"


# Use Markdown (with download links) when the EXECUTE_NB environment variable
# is set (presumably during the documentation build); otherwise plain text.
markdown = "EXECUTE_NB" in os.environ
src = f"File sizes for {len(X[0])}x{len(Y[:, 0])} grid:" + "".join(
    render_kilobytes(path, markdown)
    for path in (
        "alpha-x-pandas.json",
        "alpha-x-python.json",
        "alpha-x-pandas-json.zip",
        "alpha-x-pandas.csv",
    )
)
if not markdown:
    print(src)
else:
    display(Markdown(src))

## Benchmarking

:::{tip}
Compare with Julia results as reported in [redeboer/polarization-sensitivity#27](https://github.com/redeboer/polarization-sensitivity/issues/27).
:::

In [None]:
# Benchmark the data transformer on a 100,000-event phase-space sample
# (fixed seed), then keep the transformed sample for the benchmarks below.
phsp = generate_phasespace_sample(decay, 100_000, seed=0)
%timeit transformer(phsp)
phsp = transformer(phsp)

In [None]:
# Single-event input: take the first entry of every array-valued column and
# keep zero-dimensional (scalar) entries as they are.
random_point = {k: v[0] if len(v.shape) > 0 else v for k, v in phsp.items()}
%timeit intensity_func(random_point)

In [None]:
%%timeit
# Single-event evaluation of all three polarization components.
polarization_funcs[0](random_point)
polarization_funcs[1](random_point)
polarization_funcs[2](random_point)

In [None]:
# Coarser meshgrid sample (resolution=54) for benchmarking grid evaluation.
phsp_meshgrid54 = generate_meshgrid_sample(decay, resolution=54)
data_sample54 = transformer(phsp_meshgrid54)
X54 = phsp_meshgrid54["sigma1"]
Y54 = phsp_meshgrid54["sigma2"]
%timeit intensity_func(data_sample54)

In [None]:
%%timeit
# All three polarization components on the resolution-54 meshgrid.
polarization_funcs[0](data_sample54)
polarization_funcs[1](data_sample54)
polarization_funcs[2](data_sample54)

In [None]:
# Switch back to PNG output for the mesh plot below.
%config InlineBackend.figure_formats = ['png']

In [None]:
# Quick visual check of the total intensity on the resolution-54 meshgrid.
fig, ax = plt.subplots(figsize=(4, 4))
intensity54 = intensity_func(data_sample54)
ax.pcolormesh(X54, Y54, intensity54)
plt.show()

In [None]:
%%timeit
# Total intensity on the full transformed 100,000-event sample.
intensity_func(phsp)

In [None]:
%%timeit
# All three polarization components on the full transformed sample.
polarization_funcs[0](phsp)
polarization_funcs[1](phsp)
polarization_funcs[2](phsp)