# Interactive visualization

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import logging
import os
from functools import lru_cache
from textwrap import dedent
from warnings import filterwarnings

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from IPython.display import Markdown, display
from ipywidgets import (
    HTML,
    Button,
    Combobox,
    FloatSlider,
    HBox,
    HTMLMath,
    Tab,
    VBox,
    interactive_output,
)
from matplotlib.colors import LogNorm
from symplot import create_slider
from tensorwaves.interface import DataSample
from tqdm.auto import tqdm

from polarimetry import formulate_polarimetry
from polarimetry.amplitude import simplify_latex_rendering
from polarimetry.data import create_data_transformer, generate_meshgrid_sample
from polarimetry.io import (
    mute_jax_warnings,
    perform_cached_doit,
    perform_cached_lambdify,
)
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles
from polarimetry.plot import use_mpl_latex_fonts

# Global notebook configuration: ignore all Python warnings, show INFO-level
# messages from ``polarimetry.function``, mute JAX warnings, and use the
# simplified LaTeX rendering from ``polarimetry.amplitude``.
filterwarnings("ignore")
logging.getLogger("polarimetry.function").setLevel(logging.INFO)
mute_jax_warnings()
simplify_latex_rendering()

# ``EXECUTE_NB`` in the environment marks a non-interactive run (presumably
# the documentation build — TODO confirm). In that case tqdm progress bars
# are disabled, the root logger only reports errors, and the interactive
# plot is exported as a static PNG instead of live widgets.
NO_TQDM = "EXECUTE_NB" in os.environ
if NO_TQDM:
    logging.getLogger().setLevel(logging.ERROR)

In [None]:
# Load one model from the model-definition file. ``model_choice`` is passed
# as ``model_id`` to BOTH loaders, so the builder and the imported parameter
# values always refer to the same model (previously ``model_id=0`` was
# hard-coded, which made ``model_choice`` a no-op).
model_choice = 0
model_file = "../../data/model-definitions.yaml"
PARTICLES = load_particles("../../data/particle-definitions.yaml")
BUILDER = load_model_builder(model_file, PARTICLES, model_id=model_choice)
imported_parameters = load_model_parameters(
    model_file,
    BUILDER.decay,
    model_id=model_choice,
    particle_definitions=PARTICLES,
)
MODEL = BUILDER.formulate(reference_subsystem=1)
# Override the builder's default parameter values with the imported ones.
MODEL.parameter_defaults.update(imported_parameters)
# Only PARTICLES, BUILDER and MODEL are needed further down.
del model_choice, model_file, imported_parameters

In [None]:
def _is_free_parameter(symbol) -> bool:
    R"""Decide whether a model parameter gets an interactive slider.

    A parameter is free if it is one of the following:

    - an indexed production coupling (``sp.Indexed`` whose name contains
      ``production``);
    - a width (name starts with ``\Gamma_``) whose name does not mention
      ``Sigma``;
    - a resonance mass (name starts with ``m_`` and contains ``(``, such as
      ``m_{K(892)}``).
    """
    name = symbol.name
    if isinstance(symbol, sp.Indexed) and "production" in name:
        return True
    if name.startswith(R"\Gamma_") and "Sigma" not in name:
        return True
    return name.startswith("m_") and "(" in name


# Split the model defaults into slider-controlled (free) parameters and
# parameters that are substituted by their values (fixed).
FREE_PARAMETERS = {
    symbol: default
    for symbol, default in MODEL.parameter_defaults.items()
    if _is_free_parameter(symbol)
}
FIXED_PARAMETERS = {
    symbol: default
    for symbol, default in MODEL.parameter_defaults.items()
    if not _is_free_parameter(symbol)
}


@lru_cache(maxsize=None)
def unfold_and_substitute(expr: sp.Expr) -> sp.Expr:
    """Fully evaluate ``expr`` and insert the values of all fixed parameters.

    The steps are:

    1. Evaluate ``expr`` (cached ``doit``).
    2. Replace the amplitude symbols by their definitions from
       ``MODEL.amplitudes`` and evaluate again.
    3. Substitute ``FIXED_PARAMETERS``.

    Results are memoized per input expression.
    """
    evaluated = perform_cached_doit(expr)
    amplitudes_inserted = evaluated.xreplace(MODEL.amplitudes)
    unfolded = perform_cached_doit(amplitudes_inserted)
    return unfolded.xreplace(FIXED_PARAMETERS)

In [None]:
# Unfolded polarimeter-vector components and intensity. Every parameter
# except those in FREE_PARAMETERS has already been replaced by its value.
POLARIMETRY_EXPRS = tuple(
    map(
        unfold_and_substitute,
        formulate_polarimetry(BUILDER, reference_subsystem=1),
    )
)
INTENSITY_EXPR = unfold_and_substitute(MODEL.intensity)

In [None]:
# Numerical functions for the expressions above. FREE_PARAMETERS are kept as
# parameters so the sliders can change them later via ``update_parameters``.
POLARIMETRY_FUNCS = tuple(
    [
        perform_cached_lambdify(component, parameters=FREE_PARAMETERS)
        for component in tqdm(POLARIMETRY_EXPRS, disable=NO_TQDM, leave=False)
    ]
)
INTENSITY_FUNC = perform_cached_lambdify(
    INTENSITY_EXPR,
    parameters=FREE_PARAMETERS,
)

In [None]:
def create_grid(resolution: int) -> DataSample:
    """Return a Dalitz-plane meshgrid sample extended with transformed variables.

    ``resolution`` is forwarded to ``generate_meshgrid_sample``. The sample
    is updated in place with the output of the module-level ``TRANSFORMER``,
    which is looked up at call time.
    """
    grid = generate_meshgrid_sample(MODEL.decay, resolution)
    transformed = TRANSFORMER(grid)
    grid.update(transformed)
    return grid


# ``create_grid`` reads TRANSFORMER at call time, so it only has to be defined
# before the first ``create_grid`` call below.
TRANSFORMER = create_data_transformer(MODEL)
# Fine grid for the intensity mesh, coarse grid for the quiver arrows.
MESH_GRID = create_grid(resolution=200)
QUIVER_GRID = create_grid(resolution=35)

In [None]:
def create_ui() -> tuple[VBox, dict[str, FloatSlider]]:
    """Create sliders for all free parameters and lay them out in tabs.

    Returns:
        ``(ui, sliders)``. ``ui`` is a `VBox` with two tabs ("Masses and
        widths" and one sub-tab per resonance with its couplings), a reset
        button, and a Combobox that zeroes matching couplings. ``sliders``
        maps names to `FloatSlider` widgets. Complex couplings
        (``sp.Indexed`` symbols) get two sliders, ``"<name>_real"`` and
        ``"<name>_imag"``. Every other free parameter gets one slider keyed
        by ``symbol.name``.
    """
    # Slider construction
    sliders = {}
    for symbol in FREE_PARAMETERS:
        if isinstance(symbol, sp.Indexed):
            real_slider = create_slider(symbol)
            imag_slider = create_slider(symbol)
            sliders[f"{symbol.name}_real"] = real_slider
            sliders[f"{symbol.name}_imag"] = imag_slider
            # set_slider() uses these descriptions to pick the Re/Im part
            real_slider.description = "Re"
            imag_slider.description = "Im"
        else:
            slider = create_slider(symbol)
            sliders[symbol.name] = slider
            slider.readout_format = ".3f"
            if symbol.name.startswith("m"):
                slider.description = "m"
            elif symbol.name.startswith(R"\Gamma"):
                slider.description = "Γ"

    # Slider ranges. Masses are bounded by the square root of the sigma
    # range on MESH_GRID: K resonances use sigma1, L uses sigma2, D uses sigma3.
    σ1_min, σ1_max = get_min_max(1)
    σ2_min, σ2_max = get_min_max(2)
    σ3_min, σ3_max = get_min_max(3)
    for name, slider in sliders.items():
        slider.continuous_update = True
        slider.step = 0.01
        if name.startswith("m_"):
            if "K" in name:
                slider.min = np.sqrt(σ1_min)
                slider.max = np.sqrt(σ1_max)
            elif "L" in name:
                slider.min = np.sqrt(σ2_min)
                slider.max = np.sqrt(σ2_max)
            elif "D" in name:
                slider.min = np.sqrt(σ3_min)
                slider.max = np.sqrt(σ3_max)
        elif name.startswith(R"\Gamma_"):
            # NOTE(review): the default values are only applied by
            # reset_sliders() further down. Here, slider.value is still
            # whatever create_slider() initialised it to — confirm intended.
            slider.min = 0
            slider.max = max(0.5, 2 * slider.value)
        elif "production" in name:
            slider.min = -15
            slider.max = +15

    # Slider values
    def set_slider(slider, value):
        """Set ``slider`` to the real or imaginary part of ``value``.

        The part is selected from the slider description: "Im" gives the
        imaginary part, anything else gives the real part. The widget is only
        assigned when its value differs at the slider's step precision. This
        avoids redundant widget updates.
        """
        if "Im" in slider.description:
            value = complex(value).imag
        else:
            value = complex(value).real
        n_decimals = -round(np.log10(slider.step))
        if slider.value != round(value, n_decimals):  # widget performance
            slider.value = value

    def reset_sliders(click_event):
        """Restore every slider to its default value from FREE_PARAMETERS."""
        for symbol, value in FREE_PARAMETERS.items():
            if isinstance(symbol, sp.Indexed):
                # Pass the full complex value, because set_slider() extracts
                # the Re/Im part itself. Passing ``value.imag`` would take
                # the imaginary part of a real number, which is always 0.
                set_slider(sliders[symbol.name + "_real"], value)
                set_slider(sliders[symbol.name + "_imag"], value)
            else:
                set_slider(sliders[symbol.name], value)

    def set_coupling_to_zero(filter_pattern):
        """Zero all production couplings whose slider name contains the pattern."""
        if isinstance(filter_pattern, Combobox):
            filter_pattern = filter_pattern.value
        for name, slider in sliders.items():
            if "production" in name and filter_pattern in name:
                set_slider(slider, 0)

    reset_sliders(click_event=None)
    reset_button = Button(description="Reset slider values")
    reset_button.on_click(reset_sliders)

    resonances = sorted(
        {c.resonance for c in MODEL.decay.chains},
        key=lambda p: (p.name[0], p.mass),
    )
    filter_button = Combobox(
        placeholder="Enter coupling filter pattern",
        options=[p.name for p in resonances],
        description=R"$\mathcal{H}=0$",
    )
    filter_button.on_submit(set_coupling_to_zero)

    # UI design
    latex = {symbol.name: sp.latex(symbol) for symbol in FREE_PARAMETERS}
    colors = dict(K="red", L="blue", D="green")

    def get_subscript(resonance) -> str:
        # The Γ slider of L(1405) is named after its p K^- channel.
        if "1405" in resonance.name:
            return Rf"{resonance.name} \to p K^-"
        return resonance.name

    pole_sliders = [
        (
            sliders[f"m_{{{p.name}}}"],
            sliders[Rf"\Gamma_{{{get_subscript(p)}}}"],
            HTML(f'<b><span style="color:{colors[p.name[0]]}">{p.name}</span></b>'),
        )
        for p in resonances
    ]

    def get_coupling_widgets(resonance):
        """Return (Re sliders, Im sliders, labels) for the couplings of one resonance."""
        coupling_names = [
            name[: -len("_real")]
            for name in sliders
            if name.endswith("_real") and resonance.name in name
        ]
        return (
            [sliders[f"{n}_real"] for n in coupling_names],
            [sliders[f"{n}_imag"] for n in coupling_names],
            [HTMLMath(f"${latex[n]}$") for n in coupling_names],
        )

    # Build one tab per resonance with only that resonance's couplings.
    # Matching on the full name is consistent with the filter Combobox above.
    # Matching on the first letter alone would put every K coupling into each
    # K-resonance tab, and likewise for L and D.
    coupling_sliders = {res.name: get_coupling_widgets(res) for res in resonances}
    slider_tabs = Tab(
        children=[
            VBox([HBox([m, w, t]) for m, w, t in pole_sliders]),
            Tab(
                children=[
                    VBox([HBox(s) for s in zip(*pair)])
                    for pair in coupling_sliders.values()
                ],
                titles=tuple(coupling_sliders),
            ),
        ],
        titles=("Masses and widths", "Couplings"),
    )
    ui = VBox([slider_tabs, HBox([reset_button, filter_button])])
    return ui, sliders


def get_min_max(sigma: int) -> tuple[float, float]:
    """Return ``(lower, upper)`` of ``MESH_GRID[f"sigma{sigma}"]``.

    NaN entries are ignored. A negative lower bound is clipped to 0.
    """
    values = MESH_GRID[f"sigma{sigma}"]
    lower = np.nanmin(values)
    if lower < 0:
        lower = 0
    upper = np.nanmax(values)
    return lower, upper


# Build the widgets once. SLIDERS are the controls of the interactive plot,
# and UI is the widget layout that is displayed next to it.
UI, SLIDERS = create_ui()

In [None]:
%config InlineBackend.figure_formats = ['png']
%matplotlib widget


def create_interactive_plot() -> None:
    """Draw the intensity and polarimeter-field plots and connect them to SLIDERS.

    The left panel shows the intensity on MESH_GRID with a logarithmic colour
    scale. The right panel shows the polarimeter vector field on QUIVER_GRID,
    drawn as arrows ``(αz, αx)`` coloured by ``|α|``. When NO_TQDM is set,
    the figure is exported to ``interactive-plot.png`` and embedded through
    Markdown instead of displaying the live widgets.
    """
    plt.rcdefaults()
    use_mpl_latex_fonts()
    plt.rc("font", size=17)
    fig, axes = plt.subplots(
        figsize=(12, 6.2),
        ncols=2,
        sharey=True,
    )
    ax1, ax2 = axes
    ax1.set_title("Intensity distribution")
    ax2.set_title("Polarimeter vector field")
    ax1.set_xlabel(R"$m^2(K^- \pi^+)$")
    # The quiver arrows use αz as horizontal (U) and αx as vertical (V)
    # component, so the x axis carries α_z and the shared y axis α_x.
    ax2.set_xlabel(R"$m^2(K^- \pi^+), \alpha_z$")
    ax1.set_ylabel(R"$m^2(p K^-), \alpha_x$")
    for ax in axes:
        ax.set_box_aspect(1)
    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False

    mesh = None
    quiver = None
    intensity_bar = None

    def plot3(**kwargs):
        """Re-evaluate all functions with the slider values and redraw both panels."""
        nonlocal quiver, mesh, intensity_bar
        kwargs = to_complex_kwargs(**kwargs)
        for func in (*POLARIMETRY_FUNCS, INTENSITY_FUNC):
            func.update_parameters(kwargs)
        intensities = INTENSITY_FUNC(MESH_GRID)
        αx, αy, αz = tuple(func(QUIVER_GRID).real for func in POLARIMETRY_FUNCS)
        abs_α = jnp.sqrt(αx**2 + αy**2 + αz**2)
        if mesh is None:
            mesh = ax1.pcolormesh(
                MESH_GRID["sigma1"],
                MESH_GRID["sigma2"],
                intensities,
                cmap=plt.cm.YlOrRd,
                norm=LogNorm(),
            )
            intensity_bar = fig.colorbar(mesh, ax=ax1, pad=0.01, fraction=0.0473)
            intensity_bar.ax.set_ylabel("normalized intensity (a.u.)")
        else:
            # Keep the current lower colour-bar limit, and rescale the upper
            # limit to the new maximum intensity.
            y_min, _ = intensity_bar.ax.get_ylim()
            y_max = np.nanmax(intensities)
            mesh.set_array(intensities)
            mesh.set_clim(vmax=y_max)
            intensity_bar.ax.set_ylim(y_min, y_max)
        if quiver is None:
            quiver = ax2.quiver(
                QUIVER_GRID["sigma1"],
                QUIVER_GRID["sigma2"],
                αz,
                αx,
                abs_α,
                cmap=plt.cm.viridis_r,
                clim=(0, 1),
            )
            c_bar = fig.colorbar(quiver, ax=ax2, pad=0.01, fraction=0.0473)
            c_bar.ax.set_ylabel(R"$\left|\vec\alpha\right|$")
        else:
            quiver.set_UVC(αz, αx, abs_α)
        fig.canvas.draw_idle()

    def to_complex_kwargs(**kwargs):
        """Merge ``<name>_real``/``<name>_imag`` slider values into complex ``<name>`` entries.

        All other keyword arguments are passed through unchanged.
        """
        complex_valued_kwargs = {}
        for key, value in kwargs.items():
            if key.endswith("_real"):
                symbol_name = key[: -len("_real")]
                imag = kwargs[f"{symbol_name}_imag"]
                complex_valued_kwargs[symbol_name] = complex(value, imag)
            elif key.endswith("_imag"):
                continue  # consumed together with the matching "_real" key
            else:
                complex_valued_kwargs[key] = value
        return complex_valued_kwargs

    output = interactive_output(plot3, controls=SLIDERS)
    fig.tight_layout()
    if NO_TQDM:
        export_file = "interactive-plot.png"
        fig.savefig(export_file, dpi=200)
        src = f"""
        :::{{tip}}
        Run this notebook in Jupyter to modify parameters interactively!
        :::
        
        :::{{container}} full-width
        ![]({export_file})
        :::
        """
        src = dedent(src)
        display(Markdown(src))
    else:
        display(UI, output)


create_interactive_plot()