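# Alignment consistency

This notebook checks that the intensity distribution does not depend on which subsystem is chosen as the reference subsystem. The model is formulated once for each of the three reference subsystems, and the resulting intensities are compared on a grid over the Dalitz plane.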

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import jax.numpy as jnp
import matplotlib.pyplot as plt
import sympy as sp
from numpy.testing import assert_almost_equal
from tensorwaves.data import SympyDataTransformer
from tensorwaves.function.sympy import create_function
from tqdm.auto import tqdm

from polarization.amplitude import (
    AmplitudeModel,
    DalitzPlotDecompositionBuilder,
    simplify_latex_rendering,
)
from polarization.data import create_data_transformer, generate_meshgrid_sample
from polarization.io import display_latex, mute_jax_warnings, perform_cached_doit
from polarization.lhcb import (
    _load_model_parameters,
    flip_production_coupling_signs,
    load_three_body_decays,
)

# Notebook-wide setup helpers from the `polarization` package. Judging by their
# names, they silence JAX warnings and simplify how expressions are rendered as
# LaTeX. NOTE(review): confirm the exact effects against their docstrings.
mute_jax_warnings()
simplify_latex_rendering()

In [None]:
# Load the decay definition with its dynamics (isobars.json) and the imported
# model parameter values (modelparameters.json).
dynamics_configurator = load_three_body_decays("../../data/isobars.json")
decay = dynamics_configurator.decay
amplitude_builder = DalitzPlotDecompositionBuilder(decay)
amplitude_builder.dynamics_choices = dynamics_configurator
imported_parameter_values = _load_model_parameters(
    "../../data/modelparameters.json", decay
)

# Formulate one amplitude model per reference subsystem, each carrying the
# imported parameter values as its defaults.
models = {}
for reference_subsystem in range(1, 4):
    formulated_model = amplitude_builder.formulate(
        reference_subsystem, cleanup_summations=True
    )
    formulated_model.parameter_defaults.update(imported_parameter_values)
    models[reference_subsystem] = formulated_model

# With subsystem 2 or 3 as reference, the production couplings of some
# subsystems have to change sign (see the note below).
sign_flipped_subsystems = {2: ["K", "L"], 3: ["K", "D"]}
for reference_subsystem, subsystem_names in sign_flipped_subsystems.items():
    models[reference_subsystem] = flip_production_coupling_signs(
        models[reference_subsystem], subsystem_names=subsystem_names
    )

In [None]:
# Render the intensity expression of every model (after `.cleanup()`), one per
# reference subsystem.
display_latex(
    amplitude_model.intensity.cleanup()
    for amplitude_model in models.values()
)

See {doc}`/appendix/angles` for the definition of each $\zeta^i_{j(k)}$.

Note that changing the reference subsystem requires flipping the sign of the production couplings of certain subsystems:

- **Subsystem 2** as reference: flip the signs of $\mathcal{H}^\mathrm{production}_{K^{**}}$ and $\mathcal{H}^\mathrm{production}_{L^{**}}$
- **Subsystem 3** as reference: flip the signs of $\mathcal{H}^\mathrm{production}_{K^{**}}$ and $\mathcal{H}^\mathrm{production}_{D^{**}}$

In [None]:
# The K(892) production coupling is flipped for both subsystem 2 and
# subsystem 3 as reference, so it must have the opposite sign there compared
# to the model with subsystem 1 as reference.
coupling_name = R"\mathcal{H}^\mathrm{production}[K(892), -1, -1/2]"
matching_symbols = [
    symbol for symbol in models[1].parameter_defaults if str(symbol) == coupling_name
]
# Fail with a readable message instead of a bare IndexError if the symbol
# name ever changes.
assert matching_symbols, f"No parameter named {coupling_name!r} in model 1"
coupling = matching_symbols[0]
reference_value = models[1].parameter_defaults[coupling]
for reference_subsystem in [2, 3]:
    flipped_value = models[reference_subsystem].parameter_defaults[coupling]
    assert flipped_value == -reference_value, (
        reference_subsystem,
        flipped_value,
        reference_value,
    )

In [None]:
# Unfold the full intensity expression of each model via `perform_cached_doit`,
# keyed by reference subsystem.
unfolded_intensity_exprs = {
    subsystem_id: perform_cached_doit(amplitude_model.full_expression)
    for subsystem_id, amplitude_model in tqdm(models.items())
}

In [None]:
def assert_all_symbols_defined(expr: sp.Expr, model: AmplitudeModel) -> None:
    """Check that ``expr`` has no undefined symbols left.

    After substituting ``model.parameter_defaults`` into ``expr``, every
    remaining free symbol must be either an element of ``model.variables`` or
    one of the ``sigma1``, ``sigma2``, ``sigma3`` symbols (created here with
    ``nonnegative=True``).

    Raises:
        AssertionError: with the set of symbols that are still undefined.
    """
    substituted_expr = expr.xreplace(model.parameter_defaults)
    sigma_symbols = sp.symbols("sigma1:4", nonnegative=True)
    allowed_symbols = set(model.variables) | set(sigma_symbols)
    undefined_symbols = substituted_expr.free_symbols - allowed_symbols
    assert not undefined_symbols, undefined_symbols


# Every unfolded expression must be fully defined by its own model.
for reference_subsystem, unfolded_expr in unfolded_intensity_exprs.items():
    assert_all_symbols_defined(unfolded_expr, models[reference_subsystem])

In [None]:
# Substitute each model's default parameter values into its unfolded expression.
subs_intensity_exprs = {
    subsystem_id: unfolded_expr.xreplace(models[subsystem_id].parameter_defaults)
    for subsystem_id, unfolded_expr in unfolded_intensity_exprs.items()
}

In [None]:
# Compile every substituted expression into a JAX-backed function.
intensity_funcs = {
    subsystem_id: create_function(substituted_expr, backend="jax")
    for subsystem_id, substituted_expr in tqdm(subs_intensity_exprs.items())
}

In [None]:
# Merge the data-transformer functions of all models into a single
# transformer, so that one grid sample provides the inputs of every intensity
# function below.
transformer_functions = {}
for model in tqdm(models.values()):
    transformer_functions.update(create_data_transformer(model).functions)
transformer = SympyDataTransformer(transformer_functions)

GRID_RESOLUTION = 200  # NOTE(review): grid points per axis; confirm `resolution` semantics
raw_grid_sample = generate_meshgrid_sample(decay, resolution=GRID_RESOLUTION)
grid_sample = transformer(raw_grid_sample)
intensity_grids = {i: func(grid_sample) for i, func in intensity_funcs.items()}

In [None]:
# NaN-ignoring sum of each intensity grid, per reference subsystem.
{
    subsystem_id: jnp.nansum(intensity_grid)
    for subsystem_id, intensity_grid in intensity_grids.items()
}

In [None]:
# The intensities obtained with subsystem 2 and with subsystem 3 as reference
# must both agree with the one obtained with subsystem 1 as reference.
# (Previously subsystem 2 was checked twice and subsystem 3 never.)
# Note: this compares the NaN-ignoring sum of the differences over the grid.
for reference_subsystem in [2, 3]:
    assert_almost_equal(
        jnp.nansum(intensity_grids[reference_subsystem] - intensity_grids[1]), 0
    )

In [None]:
# Render the figures below as PNG images.
# NOTE(review): presumably this overrides a vector format configured elsewhere,
# because dense pcolormesh grids are heavy as vector graphics. Confirm.
%config InlineBackend.figure_formats = ['png']

In [None]:
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

X = grid_sample["sigma1"]
Y = grid_sample["sigma2"]

fig, axes = plt.subplots(
    ncols=3,
    figsize=(20, 7),
    tight_layout=True,
    sharey=True,
    gridspec_kw={"width_ratios": [1, 1, 1.2]},
)
fig.suptitle("Intensity distribution")
# Use one color scale (both vmin and vmax) for all panels. Then the panels can
# be compared directly, and the single colorbar drawn next to the last panel
# applies to each of them. Leaving vmin unset would autoscale it per panel.
global_min = min(map(jnp.nanmin, intensity_grids.values()))
global_max = max(map(jnp.nanmax, intensity_grids.values()))
axes[0].set_ylabel(s2_label)
for i, ax in enumerate(axes, 1):
    ax.set_title(f"Subsystem {i} as reference")
    ax.set_xlabel(s1_label)
    Z = intensity_grids[i]
    mesh = ax.pcolormesh(X, Y, Z)
    mesh.set_clim(vmin=global_min, vmax=global_max)
fig.colorbar(mesh, ax=axes[-1])
plt.show()