# Serialization

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import json
import logging
import math
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import Markdown, display
from scipy.interpolate import interp2d
from tensorwaves.function.sympy import create_function
from tqdm.auto import tqdm

from polarization import formulate_polarization
from polarization.amplitude import DalitzPlotDecompositionBuilder
from polarization.data import (
    create_data_transformer,
    generate_meshgrid_sample,
    generate_phasespace_sample,
)
from polarization.io import (
    export_polarization_field,
    import_polarization_field,
    mute_jax_warnings,
    perform_cached_doit,
)
from polarization.lhcb import _load_model_parameters, load_three_body_decays
from polarization.plot import use_mpl_latex_fonts

# Reduce log and warning output for the rest of the notebook.
# NOTE(review): presumably silences JAX-specific warnings — confirm in polarization.io.
mute_jax_warnings()
logging.getLogger().setLevel(logging.ERROR)  # root logger: errors and above only
warnings.filterwarnings("ignore")  # ignore every Python warning

# Formulate the amplitude model from the isobar definitions. The names defined
# here (decay, amplitude_builder, model, ...) are reused by the cells below.
reference_subsystem = 1
dynamics_configurator = load_three_body_decays("../../data/isobars.json")
decay = dynamics_configurator.decay
amplitude_builder = DalitzPlotDecompositionBuilder(decay)
amplitude_builder.dynamics_choices = dynamics_configurator
model = amplitude_builder.formulate(reference_subsystem)
# Overwrite the model's default parameter values with the imported ones.
imported_parameter_values = _load_model_parameters(
    "../../data/modelparameters.json", decay
)
model.parameter_defaults.update(imported_parameter_values)

In [None]:
def _unfold(expr):
    """Evaluate `expr`, substitute the model amplitudes, and unfold the result.

    The final unfolding goes through `perform_cached_doit`.
    """
    with_amplitudes = expr.doit().xreplace(model.amplitudes)
    return perform_cached_doit(with_amplitudes)


polarization_exprs = formulate_polarization(amplitude_builder, reference_subsystem)
unfolded_polarization_exprs = list(
    map(
        _unfold,
        tqdm(polarization_exprs, desc="Unfolding polarization expressions"),
    )
)
# The intensity expression is passed to perform_cached_doit as is.
unfolded_intensity_expr = perform_cached_doit(model.full_expression)

In [None]:
def _to_jax_function(expr):
    """Substitute the model's default parameter values and build a JAX function."""
    substituted = expr.xreplace(model.parameter_defaults)
    return create_function(substituted, backend="jax")


# One function per polarization expression, plus one for the intensity.
polarization_funcs = list(map(_to_jax_function, tqdm(unfolded_polarization_exprs)))
intensity_func = _to_jax_function(unfolded_intensity_expr)

In [None]:
# Generate a mesh-grid sample for the decay and pass it through the model's data
# transformer. X and Y are the resulting "sigma1" and "sigma2" grids.
resolution = 100
transformer = create_data_transformer(model)
grid_sample = transformer(generate_meshgrid_sample(decay, resolution))
X, Y = grid_sample["sigma1"], grid_sample["sigma2"]

In [None]:
# Evaluate all polarization components on the grid. Only alpha_x is used for the
# file-format comparison below.
polarization_grids = [func(grid_sample).real for func in polarization_funcs]
alpha_x = polarization_grids[0]

# X[0] and Y[:, 0] are the 1D sigma1 and sigma2 axes, so sigma1 varies along the
# columns of the grid and sigma2 along its rows. The row labels (index) therefore
# have to be the sigma2 values and the column labels the sigma1 values.
df = pd.DataFrame(alpha_x, index=Y[:, 0], columns=X[0])
df.to_json("alpha-x-pandas.json")
df.to_json("alpha-x-pandas-json.zip", compression={"method": "zip"})
df.to_csv("alpha-x-pandas.csv")

# Nested {sigma1: {sigma2: value}} mapping without the NaN entries.
# DataFrame.to_dict() is keyed by column label first, then by index label.
df_dict = df.to_dict()
filtered_df_dict = {
    x: {y: v for y, v in column.items() if not math.isnan(v)}
    for x, column in df_dict.items()
}
with open("alpha-x-python.json", "w", encoding="utf-8") as f:
    json.dump(filtered_df_dict, f)

# Compact layout: two 1D axes plus the 2D value grid, where z[i][j] belongs to
# (x[j], y[i]). Whitespace-free separators keep the file small.
json_dict = {
    "x": X[0].tolist(),
    "y": Y[:, 0].tolist(),
    "z": alpha_x.tolist(),
}
with open("alpha-x-arrays.json", "w", encoding="utf-8") as f:
    json.dump(json_dict, f, separators=(",", ":"))

In [None]:
def render_kilobytes(path, markdown: bool = False) -> str:
    """Return a line that reports the size of the file at `path` in kB.

    The size is read with `os.path.getsize`, multiplied by 1e-3, and formatted
    without decimals. The returned string starts with a newline.

    Args:
        path: Location of the file whose size is reported.
        markdown: If `True`, return a Markdown bullet with a ``{download}`` role
            for `path`. Otherwise return a plain-text line in which the size is
            right-aligned to a width of four characters.
    """
    size_label = format(1e-3 * os.path.getsize(path), ".0f")
    if not markdown:
        return f"\n  {size_label.rjust(4)} kB  {path}"
    return f"\n - **{size_label} kB**: {{download}}`{path}`"


# Markdown output (with download links) is only produced when the EXECUTE_NB
# environment variable is set. Otherwise the report is printed as plain text.
markdown = "EXECUTE_NB" in os.environ
exported_files = (
    "alpha-x-arrays.json",
    "alpha-x-pandas.json",
    "alpha-x-python.json",
    "alpha-x-pandas-json.zip",
    "alpha-x-pandas.csv",
)
src = f"File sizes for {len(X[0])}x{len(Y[:, 0])} grid:"
src += "".join(render_kilobytes(path, markdown) for path in exported_files)
if not markdown:
    print(src)
else:
    display(Markdown(src))

## Exported polarization grids

:::{note}
The `alpha-x-arrays.json` format was chosen for the exported grids. Files in this format can be written with {func}`.export_polarization_field`.
:::

In [None]:
# Write the sigma1/sigma2 axes, the three polarization grids, and the intensity
# grid to one JSON file.
# NOTE(review): the polarization grids were computed with `.real`, but the
# intensity is passed here without it — confirm that the exporter expects this.
export_polarization_field(
    sigma1=X[0],  # first row of the sigma1 grid, used as the 1D sigma1 axis
    sigma2=Y[:, 0],  # first column of the sigma2 grid, used as the 1D sigma2 axis
    alpha_x=polarization_grids[0],
    alpha_y=polarization_grids[1],
    alpha_z=polarization_grids[2],
    intensity=intensity_func(grid_sample),
    filename="polarizations-model-0.json",
)

In [None]:
# Show a download link for the exported grid, but only when the EXECUTE_NB
# environment variable is set.
if "EXECUTE_NB" in os.environ:
    exported_filename = "polarizations-model-0.json"
    src = f"Polarization grid can be downloaded here: {{download}}`{exported_filename}`."
    display(Markdown(src))

### Import and interpolate

In [None]:
# Read the exported field back in. `steps` is the ratio between the exported and
# the requested resolution (1 here, because both are 100).
# NOTE(review): presumably `steps` is the stride with which grid points are
# selected on import — confirm in polarization.io.
imported_resolution = 100
steps = int(resolution / imported_resolution)
field_definition = import_polarization_field("polarizations-model-0.json", steps)
imported_sigma1, imported_sigma2 = (
    field_definition[key] for key in ("m^2_Kpi", "m^2_pK")
)
# Order: intensity first, then the three polarization components.
imported_arrays = tuple(
    field_definition[key] for key in ("intensity", "alpha_x", "alpha_y", "alpha_z")
)

In [None]:
xx, yy = np.meshgrid(imported_sigma1, imported_sigma2)


def _interpolate_linearly(zz):
    """Create a linear 2D interpolator for `zz` on the imported grid.

    NaN entries are replaced by `numpy.nan_to_num` before interpolating.
    """
    # NOTE(review): interp2d is deprecated since SciPy 1.10 and removed in 1.14.
    return interp2d(xx, yy, np.nan_to_num(zz), kind="linear")


# Same order as imported_arrays.
interpolated_funcs = tuple(map(_interpolate_linearly, tqdm(imported_arrays)))

:::{warning}
{obj}`scipy.interpolate.interp2d` in combination with {mod}`jax.numpy` is slower than using {mod}`numpy`. Also note that {obj}`~numpy.nan` values have to be replaced with `0.0` using {func}`numpy.nan_to_num`.
:::

In [None]:
def _format_relative_difference(actual: float, computed: float) -> str:
    """Render the relative deviation of `computed` from `actual` as a table cell.

    The deviation is shown as a whole percentage and wrapped in a red ``<span>``
    when it exceeds 5%. If `actual` is exactly zero, the relative deviation is
    undefined: it is reported as 0% when both values agree and as ``inf%``
    otherwise, instead of raising a `ZeroDivisionError`.
    """
    diff = actual - computed
    if actual == 0:
        percent = 0.0 if diff == 0 else math.inf
    else:
        percent = 100 * abs(diff / actual)
    diff_str = f"{percent:.0f}%"
    if percent > 5:
        diff_str = f'<span style="color:red">{diff_str}</span>'
    return f"{diff_str} |"


# Compare the interpolated values with the values of the original functions on a
# few phase space points and render the result as a Markdown table.
n_points = 10
mini_sample = generate_phasespace_sample(model.decay, n_points, seed=0)
mini_sample = transformer(mini_sample)
table_rows = [
    r"""
|   |   | $I$ | $\alpha_x$ | $\alpha_y$ | $\alpha_z$ |
|---|--:|:---:|:----------:|:----------:|:----------:|
""".strip()
]
for i in range(n_points):
    # Select point i from every array. Zero-dimensional entries are used as is.
    phsp_point = {
        k: float(v[i]) if len(v.shape) else float(v) for k, v in mini_sample.items()
    }
    x = phsp_point["sigma1"]
    y = phsp_point["sigma2"]
    # Both tuples are ordered as: intensity, alpha_x, alpha_y, alpha_z.
    computed_values = tuple(float(func(x, y)) for func in interpolated_funcs)
    actual_values = tuple(
        float(func(phsp_point).real) for func in [intensity_func, *polarization_funcs]
    )
    table_rows.append(
        f"| **point {i}** | **interpolated** |"
        + "".join(f"{val:.2f} |" for val in computed_values)
    )
    table_rows.append(
        "|  | **actual** |" + "".join(f"{val:.2f} |" for val in actual_values)
    )
    table_rows.append(
        "|  | **difference** |"
        + "".join(
            _format_relative_difference(actual, computed)
            for actual, computed in zip(actual_values, computed_values)
        )
    )
src = "\n".join(table_rows)
Markdown(src)

In [None]:
# Plot the intensity and the three polarization components over the grid and
# mark the random phase space points by their index.
use_mpl_latex_fonts()
plt.rc("font", size=18)
fig, axes = plt.subplots(dpi=200, figsize=(15, 5.2), ncols=4, sharey=True)
fig.suptitle("Locations of random points", y=0.92)
# sigma is a squared invariant mass, so its unit is GeV^2/c^4.
axes[0].set_ylabel(R"$\sigma_2 = m^2_{pK}$ [GeV$^2/c^4$]")
for i, ax in enumerate(axes):
    if ax is axes[0]:
        # First panel: intensity.
        func = intensity_func
        title = "$I$"
        cmap = plt.cm.viridis
    else:
        # Remaining panels: alpha_x, alpha_y, alpha_z.
        xyz = i - 1
        func = polarization_funcs[xyz]
        title = Rf"$\alpha_{'xyz'[xyz]}$"
        cmap = plt.cm.coolwarm
    ax.set_title(title)
    Z = func(grid_sample).real
    ax.set_xlabel(R"$\sigma_1 = m^2_{K\pi}$ [GeV$^2/c^4$]")
    mesh = ax.pcolormesh(X, Y, Z, cmap=cmap)
    if ax is not axes[0]:
        # Use the same color range of -1 to +1 for all polarization panels.
        mesh.set_clim(-1, +1)
    # A separate name for the point index keeps the panel index `i` intact.
    for point_id in range(n_points):
        x = mini_sample["sigma1"][point_id]
        y = mini_sample["sigma2"][point_id]
        ax.text(x, y, f"{point_id}", c="red", fontsize=10)
fig.tight_layout()
plt.show()