# PC and PV currents

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import itertools
import re
from pathlib import Path

import cloudpickle
import numpy as np
import plotly.graph_objects as go
import sympy as sp
from ampform.sympy._cache import cache_to_disk
from ampform_dpd import AmplitudeModel  # pyright:ignore[reportPrivateUsage]
from ampform_dpd.io import cached
from IPython.display import Markdown
from tensorwaves.interface import ParametrizedFunction
from tqdm.auto import tqdm

from polarimetry.data import create_data_transformer, generate_phasespace_sample
from polarimetry.function import integrate_intensity
from polarimetry.io import mute_jax_warnings
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles

# Project helper imported from `polarimetry.io`; judging by its name it silences
# JAX warnings for the remainder of the notebook.
mute_jax_warnings()

# Model and particle definitions, given as relative paths to YAML files.
MODEL_FILE = "../../data/model-definitions.yaml"
PARTICLES = load_particles("../../data/particle-definitions.yaml")

## Model formulation

In [None]:
def formulate_model(title: str) -> AmplitudeModel:
    """Build the amplitude model registered under ``title`` in `MODEL_FILE`.

    The builder and the parameter values are both loaded for the given
    ``title``. After formulating the model, the loaded parameter values are
    merged into its ``parameter_defaults``, overriding entries with the same
    key.
    """
    model_builder = load_model_builder(MODEL_FILE, PARTICLES, title)
    parameter_values = load_model_parameters(
        MODEL_FILE,
        model_builder.decay,
        title,
        PARTICLES,
    )
    amplitude_model = model_builder.formulate()
    amplitude_model.parameter_defaults.update(parameter_values)
    return amplitude_model


# Both models are formulated once here and reused by the cells below. The title
# string is what `formulate_model` hands to the loaders together with MODEL_FILE.
DEFAULT_MODEL = formulate_model("Default amplitude model")
LS_MODEL = formulate_model("Alternative amplitude model obtained using LS couplings")

In [None]:
@cache_to_disk(
    dependencies=[
        "ampform",
        "ampform_dpd",
        "cloudpickle",
        "jax",
        "lc2pkpi-polarimetry",
        "sympy",
    ],
    dump_function=cloudpickle.dump,
)
def lambdify_ls_and_default_model() -> tuple[
    ParametrizedFunction, ParametrizedFunction
]:
    """Lambdify `DEFAULT_MODEL` and `LS_MODEL`, in that order.

    The pair is returned as ``(default, LS)``. The result goes through
    `cache_to_disk`, which is configured to dump with `cloudpickle.dump`.
    """
    default_function = lambdify(DEFAULT_MODEL)
    ls_function = lambdify(LS_MODEL)
    return default_function, ls_function


def lambdify(model: AmplitudeModel) -> ParametrizedFunction:
    """Lambdify the intensity of ``model`` with only production couplings free.

    Every parameter whose symbol name contains ``"production"`` is kept as a
    free parameter of the returned function. All remaining parameter defaults
    are substituted into the unfolded intensity expression beforehand.

    Returns:
        The object created by ``cached.lambdify``. The rest of this notebook
        treats it as a `ParametrizedFunction` (it is called on a data sample
        and its ``parameters`` / ``update_parameters`` are used), so the
        return annotation says that rather than ``sp.Expr``.
    """
    intensity_expr = cached.unfold(model)
    free_parameters = {}
    fixed_parameters = {}
    # Partition the defaults in one pass, preserving their original order.
    for symbol, value in model.parameter_defaults.items():
        if "production" in str(symbol):
            free_parameters[symbol] = value
        else:
            fixed_parameters[symbol] = value
    substituted_expr = cached.xreplace(intensity_expr, fixed_parameters)
    return cached.lambdify(substituted_expr, free_parameters)


# The tuple is ordered (default model, LS model), matching the return statement
# of `lambdify_ls_and_default_model`.
DEFAULT_INTENSITY_FUNC, LS_INTENSITY_FUNC = lambdify_ls_and_default_model()

## Compute decay rates

In [None]:
# Phase space sample with a fixed seed, passed through the data transformer of
# the default model. The same sample is used for every integration below.
TRANSFORMER = create_data_transformer(DEFAULT_MODEL)
PHSP = TRANSFORMER(
    generate_phasespace_sample(DEFAULT_MODEL.decay, n_events=100_000, seed=0)
)

In [None]:
def compute_decay_rate_matrix(
    intensity_func: ParametrizedFunction,
) -> tuple[list[str], np.ndarray]:
    """Compute pairwise integrated-intensity fractions for all couplings.

    For each pair of couplings ``(i, j)``, all couplings are set to zero except
    those two, the intensity is integrated over `PHSP`, and the result is
    divided by the integral obtained with the parameters as they are on entry.
    The setup for ``(i, j)`` and ``(j, i)`` is identical, so only pairs with
    ``i <= j`` are evaluated and the value is mirrored.

    Args:
        intensity_func: Function whose parameters containing ``"product"`` in
            their name are treated as couplings. Its parameters are modified
            temporarily and restored after every evaluation, also if the
            integration raises.

    Returns:
        The sorted coupling names and the symmetric matrix of fractions, with
        rows and columns in the order of those names.
    """
    total_intensity = integrate_intensity(intensity_func(PHSP))
    couplings = sorted(p for p in intensity_func.parameters if "product" in p)
    n_couplings = len(couplings)
    matrix = np.zeros((n_couplings, n_couplings))
    original_parameters = dict(intensity_func.parameters)
    indexed_couplings = list(enumerate(couplings))
    # The context manager closes the progress bar even when an evaluation fails.
    with tqdm(
        desc="Computing decay rates",
        total=n_couplings * (n_couplings + 1) // 2,
    ) as progress_bar:
        for (i, ci), (j, cj) in itertools.combinations_with_replacement(
            indexed_couplings, 2
        ):
            progress_bar.set_postfix_str(
                f"{_simplify_name(ci)} x {_simplify_name(cj)}"
            )
            new_parameters = dict.fromkeys(couplings, 0)
            new_parameters[ci] = original_parameters[ci]
            new_parameters[cj] = original_parameters[cj]
            try:
                intensity_func.update_parameters(new_parameters)
                sub_intensity = integrate_intensity(intensity_func(PHSP))
            finally:
                # Never leave the caller's function with zeroed couplings.
                intensity_func.update_parameters(original_parameters)
            rate = sub_intensity / total_intensity
            matrix[i, j] = rate
            matrix[j, i] = rate
            progress_bar.update()
    return couplings, matrix


def _simplify_name(latex: str) -> str:
    name = "".join(re.match(r"^.*\[\\?([DKL])[a-z]*(.+)\]$", latex).groups())
    return name.replace("D", "Δ").replace("L", "Λ").replace("1/2", "½")

In [None]:
# Coupling names and rate matrix for the default model.
default_couplings, default_rates = compute_decay_rate_matrix(DEFAULT_INTENSITY_FUNC)

In [None]:
# Coupling names and rate matrix for the LS model.
ls_couplings, ls_rates = compute_decay_rate_matrix(LS_INTENSITY_FUNC)

## Visualize decay rates

In [None]:
def visualize_rates(couplings: list[str], rates: np.ndarray, model_name: str) -> Path:
    """Plot ``rates`` as a heatmap, save it as SVG, show it, and return the path.

    Cells whose row and column coupling start with the same letter (K, Λ or Δ
    after simplifying the names) get that subsystem's color scale; all other
    cells are drawn in grey. One heatmap trace is added per color scale, with
    the cells belonging to other categories masked out as NaN. The hover text
    shows the rate multiplied by 100, formatted as a percentage.

    The image is written to ``../_static/images/pc-pv-rates-<id>.svg``, where
    ``<id>`` is ``model_name`` with spaces replaced by dashes, in lowercase.
    """
    labels = [_simplify_name(coupling) for coupling in couplings]
    n_couplings = len(couplings)
    categories = np.full((n_couplings, n_couplings), "")
    for (row, row_label), (col, col_label) in itertools.product(
        enumerate(labels), repeat=2
    ):
        row_subsystem = row_label[0]
        if row_subsystem != col_label[0]:
            continue
        # Store the ASCII letter, which is what the color scale lookup uses.
        categories[row, col] = row_subsystem.replace("Δ", "D").replace("Λ", "L")

    colorscale_per_category = (
        ("K", "Reds"),
        ("L", "Blues"),
        ("D", "Greens"),
        ("", "Greys"),
    )
    figure = go.Figure()
    for category, colorscale in colorscale_per_category:
        trace = go.Heatmap(
            x=labels,
            y=labels,
            z=np.where(categories == category, rates, np.nan),
            colorscale=colorscale,
            customdata=100 * rates,
            showscale=False,
        )
        figure.add_trace(trace)

    figure_title = {
        "text": f"Decay rates of {model_name}",
        "xanchor": "center",
        "x": 0.5,
        "y": 0.89,
    }
    figure.update_layout(
        autosize=False,
        height=650,
        paper_bgcolor="rgba(0, 0, 0, 0)",
        plot_bgcolor="rgba(0, 0, 0, 0)",
        title=figure_title,
    )
    figure.update_scenes(aspectmode="data")
    figure.update_traces(
        hovertemplate="%{x}<br>%{y}<br>Decay rate: <b>%{customdata:.3g}%</b><extra></extra>"
    )

    file_id = model_name.replace(" ", "-").lower()
    file_path = f"../_static/images/pc-pv-rates-{file_id}.svg"
    figure.write_image(file_path)
    figure.show()
    return Path(file_path)

In [None]:
# Render and save one heatmap per model; the returned paths are linked in the
# Markdown cell below.
def_file = visualize_rates(default_couplings, default_rates, "default model")
ls_file = visualize_rates(ls_couplings, ls_rates, "LS model")

In [None]:
# Link text is the bare file name; the link target is the relative path that
# `visualize_rates` returned.
Markdown(f"""
SVG files can be downloaded here:
- [`{def_file.name}`]({def_file})
- [`{ls_file.name}`]({ls_file})
""")