# Amplitude model with LS-couplings

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import logging
import os

import jax.numpy as jnp
import matplotlib.pyplot as plt
import sympy as sp
from tensorwaves.interface import Function
from tqdm.auto import tqdm

from polarimetry.amplitude import AmplitudeModel, simplify_latex_rendering
from polarimetry.data import create_data_transformer, generate_meshgrid_sample
from polarimetry.io import (
    display_latex,
    mute_jax_warnings,
    perform_cached_doit,
    perform_cached_lambdify,
)
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles
from polarimetry.plot import use_mpl_latex_fonts

# Notebook-wide setup. Judging by their names, these polarimetry helpers
# silence JAX warnings and shorten SymPy's LaTeX output (verify in their docs).
mute_jax_warnings()
simplify_latex_rendering()
# Data-file paths are relative to the working directory the notebook runs in.
MODEL_FILE = "../../data/model-definitions.yaml"
PARTICLES = load_particles("../../data/particle-definitions.yaml")

# If EXECUTE_NB is set (presumably during automated notebook execution —
# confirm), progress bars are disabled (NO_TQDM is passed to tqdm's `disable`)
# and both the root logger and "polarimetry.io" only report errors.
NO_TQDM = "EXECUTE_NB" in os.environ
if NO_TQDM:
    logging.getLogger().setLevel(logging.ERROR)
    logging.getLogger("polarimetry.io").setLevel(logging.ERROR)

## Model inspection

In [None]:
def formulate_model(title: str) -> AmplitudeModel:
    """Build the amplitude model named ``title`` from ``MODEL_FILE``.

    The model's ``parameter_defaults`` are updated with the parameter values
    that ``load_model_parameters`` imports for the same model title.
    """
    model_builder = load_model_builder(MODEL_FILE, PARTICLES, title)
    parameter_values = load_model_parameters(
        MODEL_FILE, model_builder.decay, title, PARTICLES
    )
    amplitude_model = model_builder.formulate()
    amplitude_model.parameter_defaults.update(parameter_values)
    return amplitude_model


def simplify_notation(expr: sp.Expr) -> sp.Expr:
    """Drop trailing zero indices from ``Indexed`` symbols for display.

    Every ``sp.Indexed`` node in ``expr`` that has exactly four indices, the
    last two of which are ``0``, is replaced by an ``Indexed`` with the same
    base and only its first two indices. All other nodes are left untouched.

    Args:
        expr: Expression (or single ``Indexed`` symbol) to rewrite.

    Returns:
        The rewritten expression.
    """
    substitutions = {
        node: sp.Indexed(node.base, *node.indices[:2])
        for node in sp.preorder_traversal(expr)
        if isinstance(node, sp.Indexed) and node.indices[2:] == (0, 0)
    }
    # One xreplace pass over the tree, instead of a full-tree xreplace for
    # every visited node (which was quadratic in the expression size).
    return expr.xreplace(substitutions)


LS_MODEL = formulate_model("Alternative amplitude model obtained using LS couplings")
# Display one nested sub-expression of the intensity with trailing (0, 0)
# indices dropped. NOTE(review): the `.args[0]` chain depends on the exact
# expression tree produced by the model builder — re-check if that changes.
simplify_notation(LS_MODEL.intensity.args[0].args[0].args[0].cleanup())

In [None]:
# Render every amplitude definition; only the keys are simplified for display.
display_latex({simplify_notation(k): v for k, v in LS_MODEL.amplitudes.items()})

It is asserted that these amplitude expressions do not evaluate to $0$ once the Clebsch-Gordan coefficients are evaluated.

In [None]:
def assert_non_zero_amplitudes(model: AmplitudeModel) -> None:
    """Check that no amplitude of ``model`` evaluates to zero after ``doit()``.

    The check is structural (``doit()`` result compared with ``0``), not a full
    symbolic simplification.

    Raises:
        AssertionError: naming the first amplitude whose ``doit()`` equals 0.
    """
    for symbol, amplitude in tqdm(model.amplitudes.items(), disable=NO_TQDM):
        # Explicit raise instead of a bare `assert`, so that the check is not
        # stripped under `python -O` and the failing amplitude is reported.
        if amplitude.doit() == 0:
            raise AssertionError(f"Amplitude {symbol} evaluates to zero")


# Sanity check on the LS-coupling model defined above.
assert_non_zero_amplitudes(LS_MODEL)

:::{seealso}
See {ref}`amplitude-model:Resonances and LS-scheme` for the allowed $LS$-values.
:::

## Distribution

In [None]:
def lambdify(model: AmplitudeModel) -> Function:
    """Convert the fully unfolded intensity of ``model`` to a numerical function.

    All parameters are substituted by their values from
    ``model.parameter_defaults`` before lambdification, so they are fixed in
    the returned function instead of being free arguments.

    Returns:
        The (cached) lambdified intensity. It is called on a data sample
        further down, e.g. ``func(GRID)``.
    """
    intensity_expr = unfold_intensity(model)
    subs_intensity_expr = intensity_expr.xreplace(model.parameter_defaults)
    return perform_cached_lambdify(subs_intensity_expr)


def unfold_intensity(model: AmplitudeModel) -> sp.Expr:
    """Return the intensity of ``model`` with all amplitudes written out.

    The intensity is first evaluated with ``perform_cached_doit``. Then each
    amplitude symbol is replaced by its definition from ``model.amplitudes``,
    and the result is evaluated once more.
    """
    intensity_with_symbols = perform_cached_doit(model.intensity)
    expanded_intensity = intensity_with_symbols.xreplace(model.amplitudes)
    return perform_cached_doit(expanded_intensity)


# Build the default model for comparison and lambdify the intensities of both
# models with all parameters fixed to their default values.
NOMINAL_MODEL = formulate_model("Default amplitude model")
NOMINAL_INTENSITY_FUNC = lambdify(NOMINAL_MODEL)
LS_INTENSITY_FUNC = lambdify(LS_MODEL)

In [None]:
# Meshgrid sample (resolution 300) for the nominal decay, extended in place
# with the variables that the model's data transformer computes from it.
GRID = generate_meshgrid_sample(NOMINAL_MODEL.decay, resolution=300)
GRID.update(create_data_transformer(NOMINAL_MODEL)(GRID))

In [None]:
def compare_2d_distributions() -> None:
    """Plot the normalized nominal and LS-model intensities side by side.

    Both panels show the intensity over ``GRID["sigma1"]`` vs
    ``GRID["sigma2"]``. They use one common colour scale, so equal colours
    mean equal normalized intensity in both panels.
    """
    nominal_intensities = compute_normalized_intensity(NOMINAL_INTENSITY_FUNC)
    ls_intensities = compute_normalized_intensity(LS_INTENSITY_FUNC)
    # Share BOTH ends of the colour scale. With only a shared vmax, each panel
    # would autoscale its own vmin and the colours would not be comparable.
    vmin = min(jnp.nanmin(nominal_intensities), jnp.nanmin(ls_intensities))
    vmax = max(jnp.nanmax(nominal_intensities), jnp.nanmax(ls_intensities))
    use_mpl_latex_fonts()
    _, axes = plt.subplots(
        dpi=200,
        figsize=(12, 5),
        ncols=2,
    )
    panels = [
        ("Nominal model", nominal_intensities),
        ("LS-model", ls_intensities),
    ]
    for ax, (title, intensities) in zip(axes, panels):
        ax.set_box_aspect(1)
        ax.set_title(title)
        ax.pcolormesh(
            GRID["sigma1"],
            GRID["sigma2"],
            intensities,
            vmin=vmin,
            vmax=vmax,
        )
    plt.show()


def compute_normalized_intensity(func: Function) -> jnp.ndarray:
    """Evaluate ``func`` on ``GRID`` and scale the result to sum to one.

    NaN entries are ignored in the sum and remain NaN in the output.
    """
    values = func(GRID)
    return values / jnp.nansum(values)


# Plot the normalized intensities of both models side by side on GRID.
compare_2d_distributions()