# PC and PV currents

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import itertools
import re
from collections.abc import Iterable
from functools import cache
from pathlib import Path

import cloudpickle
import jax.numpy as jnp
import numpy as np
import plotly.graph_objects as go
import sympy as sp
import unicodeitplus
from ampform.io import aslatex
from ampform.sympy._cache import cache_to_disk
from ampform_dpd import AmplitudeModel  # pyright:ignore[reportPrivateUsage]
from ampform_dpd.decay import IsobarNode, ThreeBodyDecayChain, to_particle
from ampform_dpd.io import cached
from IPython.display import HTML, Markdown, Math, display
from tensorwaves.interface import ParametrizedFunction
from tqdm.auto import tqdm

from polarimetry.data import create_data_transformer, generate_phasespace_sample
from polarimetry.function import integrate_intensity
from polarimetry.io import mute_jax_warnings
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles

# Inject a <script> tag that loads MathJax 2.7.1 (TeX-MML-AM_SVG config) from
# cdnjs into the notebook output.
# NOTE(review): presumably needed so LaTeX in HTML outputs gets typeset — confirm.
src = '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
display(HTML(src))
mute_jax_warnings()

# Model and particle definitions; paths are relative to the notebook's working directory.
MODEL_FILE = "../../data/model-definitions.yaml"
PARTICLES = load_particles("../../data/particle-definitions.yaml")

## Model formulation

In [None]:
def formulate_model(title: str) -> AmplitudeModel:
    """Formulate the amplitude model named *title* with its imported parameter values.

    Both the model definition and the parameter values are read from
    ``MODEL_FILE``; particles come from ``PARTICLES``. The imported values are
    merged into the model's ``parameter_defaults``.
    """
    model_builder = load_model_builder(MODEL_FILE, PARTICLES, title)
    parameter_values = load_model_parameters(
        MODEL_FILE, model_builder.decay, title, PARTICLES
    )
    amplitude_model = model_builder.formulate()
    amplitude_model.parameter_defaults.update(parameter_values)
    return amplitude_model


# CANO_MODEL: the alternative model with LS couplings; HELI_MODEL: the default
# model (its couplings are treated as helicity couplings further below).
CANO_MODEL = formulate_model("Alternative amplitude model obtained using LS couplings")
HELI_MODEL = formulate_model("Default amplitude model")

In [None]:
@cache_to_disk(
    # NOTE(review): presumably these package versions are part of the cache
    # key, so upgrading any of them invalidates the cached result — confirm.
    dependencies=[
        "ampform",
        "ampform_dpd",
        "cloudpickle",
        "jax",
        "lc2pkpi-polarimetry",
        "sympy",
    ],
    dump_function=cloudpickle.dump,
)
def lambdify_ls_and_default_model() -> tuple[
    ParametrizedFunction, ParametrizedFunction
]:
    """Lambdify both amplitude models, with the result cached to disk.

    Returns:
        ``(default_model_function, ls_model_function)`` — the default
        (helicity) model comes FIRST, i.e. the reverse of the order suggested
        by the function name.
    """
    return lambdify(HELI_MODEL), lambdify(CANO_MODEL)


def lambdify(model: AmplitudeModel) -> ParametrizedFunction:
    """Lambdify *model* into a numerical function whose free parameters are its couplings.

    Parameters whose name contains ``"production"`` remain adjustable in the
    returned function. All other parameter defaults are substituted into the
    unfolded intensity expression before lambdification.

    Returns:
        The parametrized function produced by ``cached.lambdify``; it exposes
        ``parameters`` and ``update_parameters``.
    """
    intensity_expr = cached.unfold(model)
    pars = model.parameter_defaults
    free_parameters = {s: v for s, v in pars.items() if "production" in str(s)}
    fixed_parameters = {s: v for s, v in pars.items() if s not in free_parameters}
    subs_intensity_expr = cached.xreplace(intensity_expr, fixed_parameters)
    return cached.lambdify(subs_intensity_expr, free_parameters)


# Unpacked in the returned order: default (helicity) model first, then LS model.
HELI_FUNC, CANO_FUNC = lambdify_ls_and_default_model()

## Sort couplings by PC/PV

In [None]:
def is_pc(node: IsobarNode) -> bool:
    """Check if the production node in a decay is parity-conserving.

    The node conserves parity when the parent parity equals the product of the
    two child parities and the orbital factor ``(-1) ** L``.

    Returns:
        ``True`` for a parity-conserving node, ``False`` otherwise.
    """
    parent = to_particle(node.parent)
    child1 = to_particle(node.child1)
    child2 = to_particle(node.child2)
    L = node.interaction.L
    return parent.parity == child1.parity * child2.parity * (-1) ** L


def get_chain_identifier(chain: ThreeBodyDecayChain) -> str:
    """Label the production partial wave of *chain* as ``"<resonance latex>, L=<L>"``.

    These labels are the keys of ``PC_PV_MAPPING``, which is later looked up
    with the LS-branch output of ``get_partial_wave_latex``.
    """
    resonance = chain.resonance
    orbital_l = chain.incoming_ls.L
    return f"{resonance.latex}, L={orbital_l}"


# Map each LS-model production partial wave ("<resonance>, L=<L>") to a LaTeX
# flag string: \texttt{True} if its production node conserves parity, else
# \texttt{False}. The values are strings (not bools) so they render with
# aslatex; when used as sort keys, "\texttt{False}" < "\texttt{True}".
PC_PV_MAPPING: dict[str, str] = {
    get_chain_identifier(chain): Rf"\texttt{{{is_pc(chain.production_node)}}}"
    for chain in CANO_MODEL.decay.chains
}
Math(aslatex(PC_PV_MAPPING))

In [None]:
def sort_couplings(parameters: Iterable[str], helicity: bool) -> list[str]:
    """Select the production couplings from *parameters* and sort them.

    Helicity couplings are sorted by ``factorize_latex`` (resonance name, mass,
    indices). LS couplings are first grouped by their ``PC_PV_MAPPING`` flag
    and then sorted the same way within each group.
    """
    production_couplings = [name for name in parameters if "production" in name]

    def ls_sort_key(name: str) -> tuple:
        partial_wave = get_partial_wave_latex(name, helicity)
        return PC_PV_MAPPING[partial_wave], factorize_latex(name)

    sort_key = factorize_latex if helicity else ls_sort_key
    return sorted(production_couplings, key=sort_key)


@cache
def factorize_latex(coupling: str) -> tuple[str, int, str, str]:
    r"""Split a coupling name into resonance name, mass, and its two indices.

    The name must end in ``[<R>(<mass>), <index1>, <index2>]``, where ``<R>``
    is ``D``, ``K`` or ``L`` optionally preceded by a backslash and followed
    by lowercase letters (e.g. ``K``, ``\Delta``, ``\Lambda``).

    Returns:
        ``(resonance, mass, index1, index2)`` with the mass converted to ``int``.

    Raises:
        ValueError: If *coupling* does not match that format.
    """
    match = re.match(r"^.*\[(\\?[DKL][a-z]*)\((\d+)\), (.+), (.+)\]$", coupling)
    if match is None:
        msg = f"Cannot factorize coupling name {coupling!r}"
        raise ValueError(msg)
    resonance, mass, *indices = match.groups()
    return resonance, int(mass), *indices


@cache
def get_partial_wave_latex(coupling: str, helicity: bool) -> str:
    r"""Return a LaTeX label for the partial wave that *coupling* belongs to.

    Both indices of the coupling name are converted with ``to_frac``. For
    helicity couplings the label is ``"<R>(<mass>), <index1>, <index2>"``;
    otherwise it is ``"<R>(<mass>), L=<index1>"``.
    """
    resonance, mass, *raw_indices = factorize_latex(coupling)
    first, second = (to_frac(index) for index in raw_indices)
    if not helicity:
        return f"{resonance}({mass}), L={first}"
    return f"{resonance}({mass}), {first}, {second}"


@cache
def to_frac(text: str) -> str:
    r"""Render a fraction string such as ``"-3/2"`` as signed LaTeX (``-\frac{3}{2}``).

    Positive fractions get an explicit ``+`` sign. Strings without a ``/``
    (e.g. ``"1"``) are returned unchanged.
    """
    if "/" not in text:
        return text
    numerator_text, denominator_text = text.split("/")
    numerator = int(numerator_text)
    denominator = int(denominator_text)
    sign = "-" if numerator < 0 else "+"
    return Rf"{sign}\frac{{{abs(numerator)}}}{{{denominator}}}"

In [None]:
# Production couplings of each model, in the order used for the rows/columns
# of the decay-rate matrices below.
CANO_COUPLINGS = sort_couplings(CANO_FUNC.parameters, helicity=False)
HELI_COUPLINGS = sort_couplings(HELI_FUNC.parameters, helicity=True)

In [None]:
# Show the partial-wave labels of both models side by side
# (column 0: LS model, column 1: default helicity model).
# NOTE(review): sp.Matrix needs rows of equal length, so this assumes both
# models have the same number of production couplings — confirm.
sp.Matrix([
    [sp.Symbol(get_partial_wave_latex(s, helicity=False)) for s in CANO_COUPLINGS],
    [sp.Symbol(get_partial_wave_latex(s, helicity=True)) for s in HELI_COUPLINGS],
]).T

## Compute decay rates

In [None]:
# Phase-space sample of 100,000 events (seed 0) for the default model's decay,
# passed through the data transformer built from HELI_MODEL. This single PHSP
# is used for both models in compute_decay_rate_matrix.
# NOTE(review): reusing HELI_MODEL's transformer for the LS model presumably
# relies on both models sharing the same decay kinematics — confirm.
TRANSFORMER = create_data_transformer(HELI_MODEL)
PHSP = generate_phasespace_sample(HELI_MODEL.decay, n_events=100_000, seed=0)
PHSP = TRANSFORMER(PHSP)

In [None]:
def compute_decay_rate_matrix(
    intensity_func: ParametrizedFunction, couplings: list[str]
) -> np.ndarray:
    """Compute the normalized decay rate for every pair of couplings.

    Entry ``(i, j)`` is the intensity integrated over ``PHSP`` with every
    coupling in *couplings* set to zero except ``couplings[i]`` and
    ``couplings[j]`` (which keep their current values). It is divided by the
    integrated intensity at the current parameter values. On the diagonal only
    a single coupling is switched on. The matrix is symmetric, so only the
    upper triangle is evaluated and then mirrored.

    The parameters of *intensity_func* are restored to their original values
    afterwards, even if an evaluation raises.

    Args:
        intensity_func: Function whose parameters include all *couplings*.
        couplings: Coupling parameter names, in matrix row/column order.

    Returns:
        Array of shape ``(len(couplings), len(couplings))``.
    """
    I_tot = integrate_intensity(intensity_func(PHSP))
    n_couplings = len(couplings)
    matrix = np.zeros((n_couplings, n_couplings))
    original_parameters = dict(intensity_func.parameters)
    is_helicity = is_helicity_func(intensity_func)
    items = list(enumerate(couplings))
    # Yields (i, ci), (j, cj) with i <= j: exactly the upper triangle.
    upper_triangle = itertools.combinations_with_replacement(items, 2)
    with tqdm(
        desc="Computing decay rates",
        total=n_couplings * (n_couplings + 1) // 2,
    ) as progress_bar:
        try:
            for (i, ci), (j, cj) in upper_triangle:
                pi = get_partial_wave_name(ci, is_helicity)
                pj = get_partial_wave_name(cj, is_helicity)
                progress_bar.set_postfix_str(f"{pi} x {pj}")
                # Every coupling is assigned here, so no reset between pairs is needed.
                new_parameters = dict.fromkeys(couplings, 0)
                new_parameters[ci] = original_parameters[ci]
                new_parameters[cj] = original_parameters[cj]
                intensity_func.update_parameters(new_parameters)
                I_sub = integrate_intensity(intensity_func(PHSP))
                matrix[i, j] = matrix[j, i] = I_sub / I_tot
                progress_bar.update()
        finally:
            intensity_func.update_parameters(original_parameters)
    return matrix


def is_helicity_func(intensity_func: ParametrizedFunction) -> bool:
    """Tell whether *intensity_func* uses helicity (rather than LS) couplings.

    Returns ``False`` as soon as a parameter name contains both
    ``"production"`` and ``"LS"``; otherwise ``True``.
    """
    for name in intensity_func.parameters:
        if "production" in name and "LS" in name:
            return False
    return True


@cache
def get_partial_wave_name(coupling: str, helicity: bool) -> str:
    r"""Return a plain-text (Unicode) label for the partial wave of *coupling*.

    Every ``\frac{a}{b}`` in the LaTeX label becomes ``a/b``, ``1/2`` becomes
    ``½``, and the rest is converted with ``unicodeitplus.replace``.
    """
    plain = re.sub(
        r"\\frac{(\d+)}{(\d+)}", r"\1/\2", get_partial_wave_latex(coupling, helicity)
    )
    return unicodeitplus.replace(plain.replace("1/2", "½"))

In [None]:
# Pairwise decay-rate matrix of the LS model (one phase-space integration per
# upper-triangle entry).
CANO_RATES = compute_decay_rate_matrix(CANO_FUNC, CANO_COUPLINGS)
assert CANO_RATES.shape == (len(CANO_COUPLINGS), len(CANO_COUPLINGS))

In [None]:
# Pairwise decay-rate matrix of the default (helicity-coupling) model.
HELI_RATES = compute_decay_rate_matrix(HELI_FUNC, HELI_COUPLINGS)
assert HELI_RATES.shape == (len(HELI_COUPLINGS), len(HELI_COUPLINGS))

In [None]:
def compute_interference(rates: jnp.ndarray) -> jnp.ndarray:
    """Split a pairwise decay-rate matrix into single-coupling and interference parts.

    The diagonal keeps the single-coupling rates ``rates[i, i]``. Each
    off-diagonal entry becomes
    ``(rates[i, j] - rates[i, i] - rates[j, j]) / 2``: half of the part of the
    pairwise rate not explained by the two single-coupling rates. The
    interference term is therefore shared between ``(i, j)`` and ``(j, i)``.

    Raises:
        AssertionError: If *rates* is not square.
    """
    n_rows, n_cols = rates.shape
    assert n_rows == n_cols
    diag = rates.diagonal()
    diag_matrix = np.identity(n_rows) * diag
    pair_sums = diag[:, None] + diag[None, :]
    half_interference = (rates - pair_sums + diag_matrix) / 2
    return half_interference + diag_matrix

In [None]:
# Diagonal: single-coupling rates; off-diagonal: half of each pairwise
# interference contribution.
CANO_INTERFERENCE = compute_interference(CANO_RATES)
HELI_INTERFERENCE = compute_interference(HELI_RATES)

## Visualize decay rates

In [None]:
def visualize_rates(
    labels: list[str], rates: np.ndarray, model_name: str, helicity: bool
) -> go.Figure:
    """Plot a square matrix of decay rates as a heatmap with resonance-group boxes.

    Args:
        labels: Axis labels, one per row/column of *rates*.
        rates: Square matrix of fractional decay rates; shown multiplied by 100
            as percentages.
        model_name: Inserted into the figure title.
        helicity: Selects which hard-coded index ranges are boxed. If ``False``,
            entries with an absolute value of at most ``1e-7`` percent are set to
            NaN (hidden).

    Returns:
        The plotly figure.
    """
    rates_percentage = 100 * rates
    abs_max = float(jnp.abs(rates_percentage).max())
    if not helicity:
        rates_percentage = jnp.where(
            jnp.abs(rates_percentage) > 1e-7, rates_percentage, jnp.nan
        )
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            x=labels,
            y=labels,
            z=rates_percentage,
            colorscale="RdBu_r",
            customdata=rates_percentage,
            hovertemplate="%{x}<br>%{y}<br>Decay rate: <b>%{customdata:.3g}%</b><extra></extra>",
            # Symmetric range so that zero sits at the centre of the diverging colorscale.
            zmin=-abs_max,
            zmax=+abs_max,
        )
    )
    # NOTE(review): these index ranges are hard-coded and assume the coupling
    # order produced by sort_couplings (K*, Δ*, Λ* groups; for LS presumably
    # the is_pc=False group first, then is_pc=True) — confirm when models change.
    if helicity:
        indicate_subsystem(fig, 0, 7, label="K*")
        indicate_subsystem(fig, 8, 13, label="Δ*")
        indicate_subsystem(fig, 14, 25, label="Λ*")
    else:
        indicate_subsystem(fig, 0, 3, label="K*")
        indicate_subsystem(fig, 4, 6, label="Δ*")
        indicate_subsystem(fig, 7, 12, label="Λ*")
        indicate_subsystem(fig, 13, 15, label="K*", legend=False)
        indicate_subsystem(fig, 16, 18, label="Δ*", legend=False)
        indicate_subsystem(fig, 19, 24, label="Λ*", legend=False)
    fig.update_layout(
        autosize=False,
        height=750,
        legend=dict(font_size=16, yanchor="top", y=1, xanchor="left", x=0.85),
        paper_bgcolor="rgba(0, 0, 0, 0)",
        plot_bgcolor="rgba(0, 0, 0, 0)",
        title=dict(
            text=f"Decay rates of {model_name}",
            xanchor="center",
            x=0.5,
            y=0.89,
        ),
    )
    fig.update_scenes(aspectmode="data")
    # Put the first label at the top so the matrix reads like a table.
    fig.update_yaxes(autorange="reversed")
    return fig


def indicate_subsystem(
    fig: go.Figure, idx1: int, idx2: int, label: str, legend: bool = True
) -> go.Figure:
    """Draw a shaded square on the heatmap diagonal spanning rows/columns idx1..idx2.

    Args:
        fig: Figure to draw on (modified in place).
        idx1: First row/column index inside the square.
        idx2: Last row/column index inside the square (inclusive).
        label: Resonance group; one of ``"K*"``, ``"Λ*"`` or ``"Δ*"``.
        legend: Whether the shape gets a legend entry.

    Returns:
        The value returned by ``fig.add_shape``.

    Raises:
        KeyError: If *label* is not one of the known resonance groups.
    """
    colors = {
        "K*": ("red", "rgba(255,0,0,0.1)"),
        "Λ*": ("green", "rgba(0,255,0,0.1)"),
        "Δ*": ("blue", "rgba(0,0,255,0.1)"),
    }
    linecolor, fillcolor = colors[label]
    # Heatmap cells are centred on integer indices, so extend by half a cell.
    left = idx1 - 0.5
    right = idx2 + 0.5
    return fig.add_shape(
        type="rect",
        x0=left,
        x1=right,
        y0=left,
        y1=right,
        fillcolor=fillcolor,
        line=dict(color=linecolor, width=1),
        name=label,
        opacity=0.3,
        showlegend=legend,
    )

In [None]:
CANO_LABELS = [get_partial_wave_name(p, helicity=False) for p in CANO_COUPLINGS]
fig = visualize_rates(CANO_LABELS, CANO_INTERFERENCE, "LS model", helicity=False)
# Must differ from HELICITY_PATH: sharing one file name let the default-model
# figure overwrite the LS-model figure, and both download links showed one file.
CANONICAL_PATH = Path("../_static/images/decay-rates-ls-canonical.svg")
fig.write_image(CANONICAL_PATH)
fig

In [None]:
# Heatmap of the default (helicity-coupling) model, also exported as SVG.
HELI_LABELS = [get_partial_wave_name(p, helicity=True) for p in HELI_COUPLINGS]
fig = visualize_rates(HELI_LABELS, HELI_INTERFERENCE, "default model", helicity=True)
HELICITY_PATH = Path("../_static/images/decay-rates-default-helicity.svg")
fig.write_image(HELICITY_PATH)
fig

In [None]:
# Render Markdown download links for the two exported SVG figures.
Markdown(f"""
SVG files can be downloaded here:
- [`{HELICITY_PATH.name}`]({HELICITY_PATH})
- [`{CANONICAL_PATH.name}`]({CANONICAL_PATH})
""")