# Intensity distribution

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import logging
import re
from itertools import product
from typing import Pattern

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.helicity.naming import natural_sorting
from IPython.display import Markdown
from matplotlib.collections import LineCollection, PathCollection
from matplotlib.colors import LogNorm
from matplotlib.contour import QuadContourSet
from tensorwaves.function import ParametrizedBackendFunction
from tensorwaves.function.sympy import create_parametrized_function
from tqdm.auto import tqdm

from polarization.amplitude import DalitzPlotDecompositionBuilder
from polarization.data import (
    create_data_transformer,
    generate_meshgrid_sample,
    generate_phasespace_sample,
)
from polarization.decay import Particle
from polarization.function import compute_sub_function
from polarization.io import mute_jax_warnings, perform_cached_doit
from polarization.lhcb import load_model_parameters, load_three_body_decays

mute_jax_warnings()

# Subsystem index handed to `formulate()` as the reference subsystem.
reference_subsystem = 1
# Decay definition plus the dynamics choices for each isobar, read from the isobar file.
dynamics_configurator = load_three_body_decays("../data/isobars.json")
decay = dynamics_configurator.decay
amplitude_builder = DalitzPlotDecompositionBuilder(decay)
# Attach the dynamics choices before formulating, so the formulated model uses them.
amplitude_builder.dynamics_choices = dynamics_configurator
model = amplitude_builder.formulate(reference_subsystem)
# Overwrite the model's default parameter values with those from the parameter file.
imported_parameter_values = load_model_parameters("../data/modelparameters.json", decay)
model.parameter_defaults.update(imported_parameter_values)

In [None]:
%%time
# Evaluate (`doit`) the full model expression into one explicit sympy expression.
# `perform_cached_doit` presumably caches the result between runs — TODO confirm.
unfolded_intensity_expr = perform_cached_doit(model.full_expression)

In [None]:
def assert_all_symbols_defined(expr: sp.Expr) -> None:
    """Check that *expr* contains no undefined symbols once parameters are substituted.

    After replacing every default parameter value of ``model``, the only symbols
    that may remain are the entries of ``model.variables`` and the squared
    invariant masses ``sigma1``, ``sigma2`` and ``sigma3``.

    Raises:
        AssertionError: carrying the set of symbols that are still undefined.
    """
    allowed_symbols = set(model.variables) | set(
        sp.symbols("sigma1:4", nonnegative=True)
    )
    substituted = expr.xreplace(model.parameter_defaults)
    undefined_symbols = substituted.free_symbols - allowed_symbols
    assert not undefined_symbols, undefined_symbols


# Sanity check: no undefined symbols may remain after substituting the defaults.
assert_all_symbols_defined(unfolded_intensity_expr)
# Display the size of the unfolded expression (thousands separators via `:,`).
Markdown(
    "The complete intensity expression contains"
    f" **{sp.count_ops(unfolded_intensity_expr):,} mathematical operations**."
)

## Definition of free parameters

In [None]:
# Indexed symbols whose name mentions "production" (the production couplings) stay
# free, so that they can be changed on the numerical function later on. All other
# parameters are fixed and substituted into the expression right away.
free_parameters = {
    par: default
    for par, default in model.parameter_defaults.items()
    if isinstance(par, sp.Indexed) and "production" in str(par)
}
fixed_parameters = {
    par: default
    for par, default in model.parameter_defaults.items()
    if par not in free_parameters
}
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)

In [None]:
# Display the operation count of the expression after the fixed parameters were
# substituted, for comparison with the fully unfolded expression above.
Markdown(
    "After substituting the parameters that are not production couplings, the total"
    " intensity expression contains"
    f" **{sp.count_ops(subs_intensity_expr):,} operations**."
)

## Distribution

In [None]:
# Convert the substituted expression into a numerical function on the JAX backend.
# Only the production couplings (`free_parameters`) remain adjustable parameters.
intensity_func = create_parametrized_function(
    subs_intensity_expr,
    parameters=free_parameters,
    backend="jax",
)

In [None]:
%config InlineBackend.figure_formats = ['png']
# Render the next figure as a raster image, presumably because the dense pcolormesh
# in the following cell would produce a very large SVG.

In [None]:
# Axis labels for the three squared two-body invariant masses.
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

fig, ax = plt.subplots(
    figsize=(10, 8),
    tight_layout=True,
)
ax.set_title("Intensity distribution")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)

# Evaluate the intensity on a (sigma1, sigma2) meshgrid with resolution 1000.
phsp_sample = generate_meshgrid_sample(decay, resolution=1_000)
# Transform the grid into the input variables that `intensity_func` expects.
transformer = create_data_transformer(model)
data_sample = transformer(phsp_sample)
X = phsp_sample["sigma1"]
Y = phsp_sample["sigma2"]
Z = intensity_func(data_sample)
# Logarithmic colour scale for the intensity.
mesh = ax.pcolormesh(X, Y, Z, norm=LogNorm())
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
%config InlineBackend.figure_formats = ['svg']
# Switch inline figures back to SVG output for the remaining plots.

In [None]:
def set_parameter_to_zero(
    func: ParametrizedBackendFunction, search_term: str | Pattern[str]
) -> None:
    """Set all parameters of *func* whose name matches *search_term* to zero.

    Names are matched with :func:`re.match`, i.e. the pattern has to match at the
    start of the parameter name. The resulting parameter mapping is always passed
    to ``func.update_parameters``; if nothing matched, the values are unchanged
    and a warning is logged.

    Args:
        func: Parametrized function whose parameter values are updated.
        search_term: Regular expression (string or compiled) that selects the
            parameters to set to zero.
    """
    new_parameters = dict(func.parameters)
    selected = [
        name for name in func.parameters if re.match(search_term, name) is not None
    ]
    for name in selected:
        new_parameters[name] = 0
    if not selected:
        # Nothing matched, so this call has no effect; tell the user.
        logging.warning("No couplings were set to zero for search term %s", search_term)
    func.update_parameters(new_parameters)


def set_ylim_to_zero(ax):
    """Move the lower y-axis limit of *ax* to zero and keep the current upper limit."""
    current_top = ax.get_ylim()[1]
    ax.set_ylim(0, current_top)


fig, (ax1, ax2) = plt.subplots(
    ncols=2,
    figsize=(12, 5),
    tight_layout=True,
)
ax1.set_xlabel(s1_label)
ax2.set_xlabel(s2_label)

# Identifiers passed to `compute_sub_function` for the three resonance families,
# together with their LaTeX labels (K**, Λ**, Δ**).
subsystem_identifiers = ["K", "L", "D"]
subsystem_labels = ["K^{**}", R"\Lambda^{**}", R"\Delta^{**}"]
# Total intensity as shaded background: summing over axis 0 gives the projection
# onto sigma1 (x), over axis 1 the projection onto sigma2 (y). `np.nansum` skips
# NaN grid points.
intensity_array = intensity_func(data_sample)
x, y = X[0], Y[:, 0]
ax1.fill(x, np.nansum(intensity_array, axis=0), alpha=0.3)
ax2.fill(y, np.nansum(intensity_array, axis=1), alpha=0.3)

# Overlay the projections of each resonance family on its own.
original_parameters = dict(intensity_func.parameters)
for label, identifier in zip(subsystem_labels, subsystem_identifiers):
    # Wrap the label in `$...$` so matplotlib renders it as math text.
    label = f"${label}$"
    intensity_array = compute_sub_function(intensity_func, data_sample, [identifier])
    ax1.plot(x, np.nansum(intensity_array, axis=0), label=label)
    ax2.plot(y, np.nansum(intensity_array, axis=1), label=label)
    # Restore the saved parameter values, in case the sub-function evaluation
    # modified them (not visible from here — defensive).
    intensity_func.update_parameters(original_parameters)
set_ylim_to_zero(ax1)
set_ylim_to_zero(ax2)
ax2.legend()
plt.show()

In [None]:
def stylize_contour(
    contour_set: QuadContourSet,
    *,
    color=None,
    label: str | None = None,
    linestyle: str | None = None,
    linewidth: float | None = None,
) -> None:
    """Style the first contour level of *contour_set* in place.

    Only the keyword arguments that are not ``None`` are applied; the remaining
    style properties of the contour line are left untouched.

    NOTE(review): ``ContourSet.collections`` is deprecated since Matplotlib 3.8;
    verify this still works with the Matplotlib version in use.
    """
    first_level: PathCollection = contour_set.collections[0]
    requested_styles = (
        (first_level.set_edgecolor, color),
        (first_level.set_label, label),
        (first_level.set_linestyle, linestyle),
        (first_level.set_linewidth, linewidth),
    )
    for apply_style, value in requested_styles:
        if value is None:
            continue
        apply_style(value)


def get_contour_line(contour_set: QuadContourSet) -> LineCollection:
    """Return the first legend handle produced by ``contour_set.legend_elements()``.

    ``legend_elements()`` returns a ``(handles, labels)`` pair; only the first
    handle is used, and an empty handle list raises ``ValueError``.
    """
    handles, _labels = contour_set.legend_elements()
    first_handle, *_remaining = handles
    return first_handle


# Minimal ratio I_sub / I_tot for a grid point to count as dominated by a resonance.
threshold = 0.5
percentage = int(100 * threshold)
# Total intensity on the meshgrid sample of the previous section.
I_tot = intensity_func(data_sample)

# NOTE(review): `sharey=True` has no effect for a figure with a single Axes.
fig, ax = plt.subplots(figsize=(7, 7), sharey=True, tight_layout=True)
ax.set_ylabel(s2_label)
ax.set_xlabel(s1_label)
ax.set_title(Rf"Regions where the resonance has a decay ratio of $\geq {percentage}$%")

# 1 where the total intensity is positive, 0 elsewhere (NaN compares as False).
# Its contour outlines the region where the intensity is defined.
phsp_region = jnp.select(
    [I_tot > 0, True],
    (1, 0),
)
contour_set = ax.contour(X, Y, phsp_region, colors="none")
stylize_contour(contour_set, color="black", linewidth=0.2)

# Every resonance gets its own contour level (1, 2, ...) and its own colour.
resonances_names = [c.resonance.name for c in decay.chains]
contour_levels = [i for i, _ in enumerate(resonances_names, 1)]
colors = [plt.cm.rainbow(x) for x in np.linspace(0, 1, len(resonances_names))]
items = list(zip(contour_levels, resonances_names, colors))  # tqdm requires len
legend_elements = []
for res_id, resonance, color in tqdm(items):
    # Escape the parentheses in the resonance name so it can be used as a regex
    # filter (the same escaping as `to_regex` in the fit-fraction section).
    regex_filter = resonance.replace("(", r"\(").replace(")", r"\)")
    I_sub = compute_sub_function(intensity_func, data_sample, [regex_filter])
    ratio = I_sub / I_tot
    # Mark points where this resonance alone reaches the threshold with its level;
    # NaN ratios and ratios below the threshold become 0.
    selection = jnp.select(
        [jnp.isnan(ratio), ratio < threshold, True],
        [0, 0, res_id],
    )
    # Skip resonances that never reach the threshold anywhere on the grid.
    if jnp.all(selection == 0):
        continue
    contour_set = ax.contour(X, Y, selection, colors="none")
    # Same colour limits for every resonance: the full range of contour levels.
    contour_set.set_clim(vmin=1, vmax=len(decay.chains))
    stylize_contour(contour_set, label=resonance, color=color)
    line_collection = get_contour_line(contour_set)
    legend_elements.append(line_collection)
leg = plt.legend(handles=legend_elements)
plt.show()

## Fit fractions

In [None]:
# Phase-space sample of 100,000 events (fixed seed for reproducibility) for the
# Monte-Carlo integration below, transformed with the same transformer as the grid.
integration_sample = generate_phasespace_sample(decay, n_events=100_000, seed=0)
integration_sample = transformer(integration_sample)

In [None]:
def sub_intensity(data, non_zero_couplings: list[str]):
    """Integrate the intensity with only the selected couplings switched on.

    Args:
        data: Transformed data sample on which ``intensity_func`` is evaluated.
        non_zero_couplings: Coupling identifiers forwarded to
            ``compute_sub_function`` (presumably used as regular expressions —
            see ``to_regex``).

    Returns:
        The NaN-ignoring mean of the sub-intensity over *data*, as computed by
        ``integrate_intensity``.
    """
    return integrate_intensity(
        compute_sub_function(intensity_func, data, non_zero_couplings)
    )


def integrate_intensity(intensities) -> float:
    """Return the mean of *intensities*, ignoring NaN entries.

    On a phase-space sample this mean serves as the (unnormalised) integral of the
    intensity. NaN entries are presumably points where the intensity is undefined
    and are excluded from both the sum and the count.

    Args:
        intensities: Array-like of any shape (JAX/NumPy array or nested list); it
            is treated as flat.

    Returns:
        The NaN-ignoring mean as a Python ``float``. If the input is empty or every
        entry is NaN, the result is NaN.
    """
    # `nanmean` replaces the manual boolean-mask + sum/len. Mask indexing gives a
    # data-dependent output shape, which JAX cannot trace under `jax.jit`.
    return float(jnp.nanmean(jnp.asarray(intensities)))


# NOTE: this rebinds `I_tot` from the intensity array on the meshgrid (previous
# section) to the scalar integrated total intensity over the integration sample.
I_tot = integrate_intensity(intensity_func(integration_sample))

In [None]:
# Consistency check: with the couplings of all three families (K, L, D) switched
# on, the sub-intensity must reproduce the total integrated intensity.
np.testing.assert_allclose(
    I_tot,
    sub_intensity(integration_sample, ["K", "L", "D"]),
)

In [None]:
def interference_intensity(data, chain1: list[str], chain2: list[str]) -> float:
    """Return the integrated interference term between two groups of couplings.

    Computed as ``I(chain1 + chain2) - I(chain1) - I(chain2)``, where each ``I``
    is the integrated sub-intensity from ``sub_intensity`` on *data*.
    """
    both_groups = sub_intensity(data, chain1 + chain2)
    first_only = sub_intensity(data, chain1)
    second_only = sub_intensity(data, chain2)
    return both_groups - first_only - second_only


# Decompose the total into the three family sub-intensities plus the three pairwise
# interference terms; together they must add up to the total intensity.
I_K = sub_intensity(integration_sample, non_zero_couplings=["K"])
I_Λ = sub_intensity(integration_sample, non_zero_couplings=["L"])
I_Δ = sub_intensity(integration_sample, non_zero_couplings=["D"])
I_ΛΔ = interference_intensity(integration_sample, ["L"], ["D"])
I_KΔ = interference_intensity(integration_sample, ["K"], ["D"])
I_KΛ = interference_intensity(integration_sample, ["K"], ["L"])
np.testing.assert_allclose(I_tot, I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ)

In [None]:
def to_regex(text: str) -> str:
    """Escape *text* so that, used as a regular expression, it matches itself literally.

    Resonance names such as ``L(1405)`` contain parentheses, which a regex would
    otherwise treat as a group. :func:`re.escape` also covers all other
    metacharacters (``*``, ``+``, ``.``, ``[`` …), which the previous hand-written
    replacement of only ``(`` and ``)`` did not.
    """
    return re.escape(text)


def sort_resonances(resonance: Particle):
    """Sort key: resonance family first (L, then D, then K), then natural name order.

    The family is taken from the first character of ``resonance.name``; a name
    starting with any other character raises ``KeyError``.
    """
    family_rank = {"L": 1, "D": 2, "K": 3}
    name = resonance.name
    return family_rank[name[0]], natural_sorting(name)


# All resonances of the decay chains, ordered by `sort_resonances` in reverse.
resonances = sorted(
    (chain.resonance for chain in decay.chains),
    key=sort_resonances,
    reverse=True,
)
n_resonances = len(resonances)
# Symmetric matrix of intensities relative to I_tot: diagonal = a single resonance
# on its own, off-diagonal = interference term between two resonances.
decay_rates = np.zeros(shape=(n_resonances, n_resonances))
combinations = list(product(enumerate(resonances), enumerate(resonances)))
# Only pairs with i <= j are computed: (n**2 + n) // 2 of them.
progress_bar = tqdm(
    desc="Calculating rate matrix",
    total=(len(combinations) + n_resonances) // 2,
)
for (i, resonance1), (j, resonance2) in combinations:
    # Lower triangle is filled by symmetry below.
    if j < i:
        continue
    progress_bar.postfix = f"{resonance1.name} × {resonance2.name}"
    res1 = to_regex(resonance1.name)
    res2 = to_regex(resonance2.name)
    if res1 == res2:
        I_sub = sub_intensity(integration_sample, non_zero_couplings=[res1])
    else:
        I_sub = interference_intensity(integration_sample, [res1], [res2])
    decay_rates[i, j] = I_sub / I_tot
    if i != j:
        decay_rates[j, i] = decay_rates[i, j]
    progress_bar.update()
progress_bar.close()

In [None]:
# Symmetric colour range around zero, so positive and negative entries get
# distinguishable colours.
vmax = np.max(np.abs(decay_rates))
fig, ax = plt.subplots(figsize=(9, 9))
ax.set_title("Rate matrix for isobars (%)")
# rot90 followed by a transpose reverses the column order: cell (row i, column c)
# shows decay_rates[i, n_resonances - 1 - c].
ax.matshow(np.rot90(decay_rates).T, cmap=plt.cm.coolwarm, vmin=-vmax, vmax=+vmax)

resonance_names = [p.name for p in resonances]
ax.set_xticks(range(n_resonances))
# Column labels are reversed to match the reversed column order of the image.
ax.set_xticklabels(reversed(resonance_names))
ax.set_yticks(range(n_resonances))
ax.set_yticklabels(resonance_names)
# Annotate only the entries with j >= i (the matrix is symmetric), in percent.
for i in range(n_resonances):
    for j in range(n_resonances):
        if j < i:
            continue
        rate = decay_rates[i, j]
        ax.text(n_resonances - j - 1, i, f"{100 * rate:.2f}", va="center", ha="center")
fig.tight_layout()
plt.show()