# PC and PV currents

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import itertools
import re
from collections.abc import Iterable
from functools import cache
from pathlib import Path

import cloudpickle
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import sympy as sp
import unicodeitplus
from ampform.io import aslatex
from ampform.sympy._cache import cache_to_disk
from ampform_dpd import AmplitudeModel  # pyright:ignore[reportPrivateUsage]
from ampform_dpd.decay import IsobarNode, ThreeBodyDecayChain, to_particle
from ampform_dpd.io import cached
from IPython.display import HTML, Markdown, Math, display
from matplotlib.patches import Rectangle
from matplotlib.ticker import FuncFormatter
from matplotlib_inline.backend_inline import set_matplotlib_formats
from plotly.subplots import make_subplots
from tensorwaves.interface import ParametrizedFunction
from tqdm.auto import tqdm

from polarimetry.data import create_data_transformer, generate_phasespace_sample
from polarimetry.function import integrate_intensity
from polarimetry.io import mute_jax_warnings
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles
from polarimetry.plot import reduce_svg_size, use_mpl_latex_fonts

# Inject the MathJax script (CDN) into the HTML output, presumably so that
# LaTeX in rendered outputs displays correctly.
src = '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
display(HTML(src))
mute_jax_warnings()
# Render matplotlib figures as SVG, with the project's LaTeX-style fonts.
set_matplotlib_formats("svg")
use_mpl_latex_fonts()

# Relative paths: assumes the notebook is executed from its own directory in
# the docs tree — TODO confirm.
MODEL_FILE = "../../data/model-definitions.yaml"
PARTICLES = load_particles("../../data/particle-definitions.yaml")

## Model formulation

In [None]:
def formulate_model(title: str) -> AmplitudeModel:
    """Formulate the amplitude model called ``title`` in ``MODEL_FILE``.

    Subsystem 1 is used as reference subsystem. The parameter defaults of the
    formulated model are overwritten with the values that ``MODEL_FILE`` lists
    for this model.
    """
    model_builder = load_model_builder(MODEL_FILE, PARTICLES, title)
    parameter_values = load_model_parameters(
        MODEL_FILE, model_builder.decay, title, PARTICLES
    )
    amplitude_model = model_builder.formulate(reference_subsystem=1)
    amplitude_model.parameter_defaults.update(parameter_values)
    return amplitude_model


CANO_MODEL = formulate_model("Alternative amplitude model obtained using LS couplings")
HELI_MODEL = formulate_model("Default amplitude model")

In [None]:
@cache_to_disk(
    dependencies=[
        "ampform",
        "ampform_dpd",
        "cloudpickle",
        "jax",
        "lc2pkpi-polarimetry",
        "sympy",
    ],
    dump_function=cloudpickle.dump,
)
def lambdify_ls_and_default_model() -> tuple[
    ParametrizedFunction, ParametrizedFunction
]:
    """Lambdify both amplitude models, with the result cached on disk.

    Returns:
        A tuple ``(helicity_function, canonical_function)``. Note that the
        default (helicity-basis) model comes FIRST, despite the function name.
    """
    return lambdify(HELI_MODEL), lambdify(CANO_MODEL)


def lambdify(model: AmplitudeModel) -> ParametrizedFunction:
    """Lambdify the intensity of ``model`` with production couplings as free parameters.

    Parameters whose name contains ``"production"`` stay adjustable in the
    returned function. Every other parameter is substituted by its default
    value in the expression before lambdifying.
    """
    intensity_expr = cached.unfold(model)
    pars = model.parameter_defaults
    free_parameters = {s: v for s, v in pars.items() if "production" in str(s)}
    fixed_parameters = {s: v for s, v in pars.items() if s not in free_parameters}
    subs_intensity_expr = cached.xreplace(intensity_expr, fixed_parameters)
    return cached.lambdify(subs_intensity_expr, free_parameters)


HELI_FUNC, CANO_FUNC = lambdify_ls_and_default_model()

## Sort couplings by PC/PV

In [None]:
def is_pc(node: IsobarNode) -> bool:
    """Check if the production node in a decay is parity-conserving.

    The node conserves parity when the parity of the parent equals the
    product of the child parities and the orbital parity ``(-1)**L``.
    """
    parent = to_particle(node.parent)
    child1 = to_particle(node.child1)
    child2 = to_particle(node.child2)
    L = node.interaction.L
    return parent.parity == child1.parity * child2.parity * (-1) ** L


def get_chain_identifier(chain: ThreeBodyDecayChain) -> str:
    r"""Return a LaTeX label ``<resonance> \, {}^{L=..}_{S=..}`` for a decay chain.

    ``L`` and ``S`` are taken from the LS coupling of the production vertex.
    ``S`` is written as a plain ``a/b`` fraction without explicit sign.
    """
    ls_coupling = chain.incoming_ls
    spin = to_frac(str(ls_coupling.S), frac=False, use_sign=False)
    return Rf"{chain.resonance.latex} \, {{}}^{{L={ls_coupling.L}}}_{{S={spin}}}"


@cache
def to_frac(text: str, frac: bool, use_sign: bool) -> str:
    r"""Format a rational-number string such as ``"-3/2"`` for display.

    Args:
        text: A number as string. Only strings containing ``"/"`` are
            reformatted; any other string is returned unchanged.
        frac: If `True`, render as LaTeX ``\frac{numerator}{denominator}``,
            otherwise as ``numerator/denominator``.
        use_sign: If `True`, prefix an explicit ``+`` or ``-`` sign and drop
            the sign from the numerator.
    """
    if "/" not in text:
        return text
    # The part before the slash is the numerator, the part after it the
    # denominator.
    numerator_text, denominator_text = text.split("/")
    numerator = int(numerator_text)
    denominator = int(denominator_text)
    sign = ""
    if use_sign:
        sign = "-" if numerator < 0 else "+"
        numerator = abs(numerator)
    if frac:
        return Rf"{sign}\frac{{{numerator}}}{{{denominator}}}"
    return Rf"{sign}{numerator}/{denominator}"


# Map each LS partial wave (LaTeX label) to whether its production node is
# parity-conserving. The values are LaTeX *strings* (\texttt{True} or
# \texttt{False}) for display, not bools. sort_couplings() uses them as a sort
# key, where "\texttt{False}" < "\texttt{True}" puts the PV waves first.
PC_PV_MAPPING: dict[str, str] = {
    get_chain_identifier(chain): Rf"\texttt{{{is_pc(chain.production_node)}}}"
    for chain in CANO_MODEL.decay.chains
}
Math(aslatex(PC_PV_MAPPING))

In [None]:
def sort_couplings(parameters: Iterable[str], helicity: bool) -> list[str]:
    """Return the production couplings among ``parameters`` in display order.

    Helicity couplings are ordered by resonance, mass and indices. Canonical
    (LS) couplings are first grouped by their ``PC_PV_MAPPING`` entry and then
    ordered the same way.
    """
    production_couplings = [name for name in parameters if "production" in name]

    def _canonical_key(name: str) -> tuple:
        partial_wave = get_partial_wave_latex(name, helicity)
        return PC_PV_MAPPING[partial_wave], factorize_latex(name)

    sort_key = factorize_latex if helicity else _canonical_key
    return sorted(production_couplings, key=sort_key)


@cache
def factorize_latex(coupling: str) -> tuple[str, int, str, str]:
    r"""Split a coupling name like ``...[\Lambda(1405), -1/2, 0]`` into parts.

    Returns:
        ``(resonance, mass, index1, index2)``. ``resonance`` is ``D``/``K``/``L``
        followed by lowercase letters, with an optional leading backslash.
        ``mass`` is an int. The two indices (helicities or L/S) are kept as
        strings.

    Raises:
        ValueError: If ``coupling`` does not end with that bracketed pattern.
    """
    match = re.match(
        r"^.*\[(\\?[DKL][a-z]*)\((\d+)\), ([^,]+), ([^,]+)\]$", coupling
    )
    if match is None:
        msg = f"Cannot factorize coupling name {coupling!r}"
        raise ValueError(msg)
    resonance, mass, *indices = match.groups()
    return resonance, int(mass), *indices


@cache
def get_partial_wave_latex(coupling: str, helicity: bool) -> str:
    r"""Return a LaTeX label for the partial wave of a production coupling.

    Helicity couplings are rendered as ``R(mass), ±λ_R, ±λ_recoil`` with the
    helicities as signed ``\frac`` fractions. LS couplings are rendered as
    ``R(mass) \, {}^{L=..}_{S=..}`` with plain ``a/b`` fractions.
    """
    resonance, mass, first_index, second_index = factorize_latex(coupling)
    if not helicity:
        L, S = (
            to_frac(value, frac=False, use_sign=False)
            for value in (first_index, second_index)
        )
        return Rf"{resonance}({mass}) \, {{}}^{{L={L}}}_{{S={S}}}"
    helicities = [
        to_frac(value, frac=True, use_sign=True)
        for value in (first_index, second_index)
    ]
    return f"{resonance}({mass}), {helicities[0]}, {helicities[1]}"

In [None]:
# Production couplings of both models, in display order: the LS couplings are
# grouped by their PC_PV_MAPPING value first, the helicity couplings are
# ordered by resonance only (see sort_couplings).
CANO_COUPLINGS = sort_couplings(CANO_FUNC.parameters, helicity=False)
HELI_COUPLINGS = sort_couplings(HELI_FUNC.parameters, helicity=True)

In [None]:
# Side-by-side overview of the sorted couplings: after transposing, column 0
# holds the LS partial waves and column 1 the helicity partial waves.
# NOTE(review): sp.Matrix needs rows of equal length, so this assumes both
# models have the same number of couplings.
sp.Matrix([
    [sp.Symbol(get_partial_wave_latex(s, helicity=False)) for s in CANO_COUPLINGS],
    [sp.Symbol(get_partial_wave_latex(s, helicity=True)) for s in HELI_COUPLINGS],
]).T

## Compute decay rates

In [None]:
# Phase-space sample (100,000 events, fixed seed) used to integrate the
# intensities, transformed with the helicity model's data transformer.
# NOTE(review): the same sample is also fed to CANO_FUNC, which assumes both
# models use the same kinematic variables — confirm.
TRANSFORMER = create_data_transformer(HELI_MODEL)
PHSP = generate_phasespace_sample(HELI_MODEL.decay, n_events=100_000, seed=0)
PHSP = TRANSFORMER(PHSP)

In [None]:
def compute_decay_rate_matrix(
    intensity_func: ParametrizedFunction, couplings: list[str]
) -> np.ndarray:
    """Compute the decay rate for every pair of couplings, relative to the total.

    Entry ``(i, j)`` is the phase-space integrated intensity with only
    couplings ``i`` and ``j`` switched on (all other entries of ``couplings``
    set to zero), divided by the integrated intensity at the original
    parameter values. The diagonal therefore holds single-coupling rates. The
    matrix is symmetric, so only pairs with ``i <= j`` are evaluated and then
    mirrored.

    The parameters of ``intensity_func`` are restored afterwards, also when an
    evaluation raises.
    """
    I_tot = integrate_intensity(intensity_func(PHSP))
    n_couplings = len(couplings)
    matrix = np.zeros((n_couplings, n_couplings))
    original_parameters = dict(intensity_func.parameters)
    is_helicity = is_helicity_func(intensity_func)
    pairs = itertools.combinations_with_replacement(range(n_couplings), 2)
    progress_bar = tqdm(
        desc="Computing decay rates",
        total=n_couplings * (n_couplings + 1) // 2,
    )
    try:
        for i, j in pairs:
            ci, cj = couplings[i], couplings[j]
            pi = get_partial_wave_name(ci, is_helicity)
            pj = get_partial_wave_name(cj, is_helicity)
            progress_bar.set_postfix_str(Rf"{pi} x {pj}")
            new_parameters = dict.fromkeys(couplings, 0)
            new_parameters[ci] = original_parameters[ci]
            new_parameters[cj] = original_parameters[cj]
            intensity_func.update_parameters(new_parameters)
            I_sub = integrate_intensity(intensity_func(PHSP))
            matrix[i, j] = matrix[j, i] = I_sub / I_tot
            progress_bar.update()
    finally:
        # Each iteration overwrites every coupling, so restoring once at the
        # end is enough. Doing it in `finally` also covers failed evaluations.
        intensity_func.update_parameters(original_parameters)
        progress_bar.close()
    return matrix


def is_helicity_func(intensity_func: ParametrizedFunction) -> bool:
    """Tell whether a lambdified model uses helicity (not LS) couplings.

    The function counts as LS-based as soon as one of its ``production``
    parameters contains ``LS`` in its name.
    """
    for name in intensity_func.parameters:
        if "production" in name and "LS" in name:
            return False
    return True


@cache
def get_partial_wave_name(coupling: str, helicity: bool) -> str:
    """Return a plain-Unicode (non-LaTeX) label for the partial wave of a coupling.

    The LaTeX label from `get_partial_wave_latex` is flattened. An LS
    superscript/subscript becomes ``,L,S`` and ``\\frac{a}{b}`` becomes
    ``a/b``. ``1/2`` and ``3/2`` are replaced by compact glyphs, and the rest
    is converted with `unicodeitplus`.
    """
    text = get_partial_wave_latex(coupling, helicity)
    for pattern, replacement in (
        (r" \\, {}\^{L=([^}]+)}_{S=([^}]+)}", r",\1,\2"),
        (r"\\frac{(\d+)}{(\d+)}", r"\1/\2"),
    ):
        text = re.sub(pattern, replacement, text)
    for fraction, glyph in (("1/2", "½"), ("3/2", "3⁄2")):
        text = text.replace(fraction, glyph)
    return unicodeitplus.replace(text)

In [None]:
# Pairwise decay-rate matrix of the LS (canonical) model; one row/column per
# entry of CANO_COUPLINGS.
CANO_RATES = compute_decay_rate_matrix(CANO_FUNC, CANO_COUPLINGS)
assert CANO_RATES.shape == (len(CANO_COUPLINGS), len(CANO_COUPLINGS))

In [None]:
# Pairwise decay-rate matrix of the default (helicity) model; one row/column
# per entry of HELI_COUPLINGS.
HELI_RATES = compute_decay_rate_matrix(HELI_FUNC, HELI_COUPLINGS)
assert HELI_RATES.shape == (len(HELI_COUPLINGS), len(HELI_COUPLINGS))

In [None]:
def compute_interference(rates: jnp.ndarray) -> jnp.ndarray:
    """Turn a pairwise decay-rate matrix into single-wave and interference parts.

    For a matrix as produced by `compute_decay_rate_matrix`:

    - ``rates[i, i]`` is the rate with only coupling ``i`` switched on.
    - ``rates[i, j]`` is the rate with only ``i`` and ``j`` switched on.

    The returned matrix keeps the diagonal. Each off-diagonal entry becomes
    ``(rates[i, j] - rates[i, i] - rates[j, j]) / 2``, i.e. half of the
    interference term between ``i`` and ``j``, stored symmetrically.
    """
    size, other_size = rates.shape
    assert size == other_size
    diagonal = rates.diagonal()
    diagonal_matrix = diagonal * np.eye(size)
    pairwise_sums = diagonal[:, None] + diagonal[None, :]
    return diagonal_matrix + (rates - pairwise_sums + diagonal_matrix) / 2

In [None]:
# Diagonal: single-coupling rates; off-diagonal: half of each pairwise
# interference term (see compute_interference).
CANO_INTERFERENCE = compute_interference(CANO_RATES)
HELI_INTERFERENCE = compute_interference(HELI_RATES)

## Visualize decay rates

In [None]:
def visualize_rates(
    labels: list[str], rates: np.ndarray, model_name: str, helicity: bool
) -> go.Figure:
    """Plot a decay-rate (interference) matrix as a Plotly heatmap.

    Args:
        labels: Axis labels, one per coupling, in the same order as ``rates``.
        rates: Square matrix of decay-rate fractions.
        model_name: Name shown in the figure title.
        helicity: Whether ``rates`` belongs to the helicity-basis model. This
            selects the hard-coded index ranges of the K*/Δ*/Λ* boxes and, for
            the canonical basis, adds the PV/PC annotations.

    Returns:
        The Plotly figure. Writing it to disk is left to the caller.
    """
    raw_abs_max = float(jnp.abs(rates).max())
    # Round the color range down to a multiple of 0.2 (larger values
    # saturate). Fall back to the raw maximum when rounding would give 0;
    # that would collapse the color scale to zmin == zmax == 0.
    abs_max = np.floor(raw_abs_max * 5) / 5 or raw_abs_max
    if not helicity:
        # Blank out (numerically) vanishing entries, presumably the PC-PV
        # cross terms.
        rates = jnp.where(jnp.abs(rates) > 1e-7, rates, jnp.nan)
    fig = go.Figure()
    fig.add_trace(
        go.Heatmap(
            x=labels,
            y=labels,
            z=rates,
            colorscale="RdBu_r",
            customdata=100 * rates,
            hovertemplate="%{x}<br>%{y}<br>Decay rate: <b>%{customdata:+.3g}%</b><extra></extra>",
            zmin=-abs_max,
            zmax=+abs_max,
            colorbar=dict(
                title="Decay rate",
                title_side="right",
                tickformat="+.0%",
            ),
        )
    )
    # The index ranges follow the order produced by sort_couplings() and are
    # specific to this model.
    if helicity:
        indicate_subsystem(fig, 0, 7, label="𝐾<sup>*</sup>")
        indicate_subsystem(fig, 8, 13, label="𝛥<sup>*</sup>")
        indicate_subsystem(fig, 14, 25, label="𝛬</sup>*</sup>")
    else:
        # PV couplings occupy indices 0-12, PC couplings indices 13-25.
        indicate_subsystem(fig, 0, 3, label="𝐾<sup>*</sup>")
        indicate_subsystem(fig, 4, 6, label="𝛥<sup>*</sup>")
        indicate_subsystem(fig, 7, 12, label="𝛬</sup>*</sup>")
        indicate_subsystem(fig, 13, 16, label="𝐾<sup>*</sup>", legend=False)
        indicate_subsystem(fig, 17, 19, label="𝛥<sup>*</sup>", legend=False)
        indicate_subsystem(fig, 20, 25, label="𝛬</sup>*</sup>", legend=False)
        fig.add_annotation(
            x=13.5,
            y=6,
            font_size=15,
            showarrow=False,
            text="parity-violating",
            textangle=+90,
        )
        fig.add_annotation(
            x=11.5,
            y=19,
            font_size=15,
            text="parity-conserving",
            showarrow=False,
            textangle=-90,
        )
    fig.update_layout(
        autosize=False,
        height=750,
        legend=dict(font_size=16, yanchor="top", y=1, xanchor="left", x=0.85),
        paper_bgcolor="rgba(0, 0, 0, 0)",
        title=dict(
            text=f"Decay rates of {model_name}",
            xanchor="center",
            x=0.5,
            y=0.89,
        ),
    )
    fig.update_scenes(aspectmode="data")
    fig.update_yaxes(autorange="reversed")
    return fig


def indicate_subsystem(
    fig: go.Figure,
    idx1: int,
    idx2: int,
    label: str,
    legend: bool = True,
    col: int | None = None,
) -> go.Figure:
    """Draw a colored box around the diagonal block ``idx1..idx2`` of a heatmap.

    Args:
        fig: Figure to add the rectangle shape to.
        idx1: First row/column index of the block (inclusive).
        idx2: Last row/column index of the block (inclusive).
        label: Resonance group: ``"𝐾<sup>*</sup>"``, ``"𝛥<sup>*</sup>"`` or
            ``"𝛬<sup>*</sup>"``. It selects the box color and is the legend
            entry.
        legend: Whether the shape appears in the legend.
        col: Subplot column for figures made with ``make_subplots``. Row 1 is
            assumed.

    Returns:
        Whatever ``fig.add_shape`` returns (the figure itself for Plotly).

    Raises:
        KeyError: If ``label`` is not one of the known resonance groups.
    """
    # Callers pass the malformed markup "𝛬</sup>*</sup>" for Λ*. Normalize it
    # so the legend shows a proper superscript while the lookup still works.
    label = label.replace("</sup>*</sup>", "<sup>*</sup>")
    colors = {
        "𝐾<sup>*</sup>": ("red", "rgba(255,0,0,0.1)"),
        "𝛬<sup>*</sup>": ("green", "rgba(0,255,0,0.1)"),
        "𝛥<sup>*</sup>": ("blue", "rgba(0,0,255,0.1)"),
    }
    linecolor, fillcolor = colors[label]
    # Heatmap cells are centered on integer indices; widen by half a cell so
    # the box encloses the cells entirely.
    left = idx1 - 0.5
    right = idx2 + 0.5
    kwargs = dict(
        x0=left,
        x1=right,
        y0=left,
        y1=right,
        fillcolor=fillcolor,
        line=dict(color=linecolor, width=1),
        name=label,
        opacity=0.3,
        showlegend=legend,
        type="rect",
    )
    if col is not None:
        kwargs["col"] = col
        kwargs["row"] = 1
    return fig.add_shape(**kwargs)

In [None]:
# Heatmap of the LS (canonical-basis) interference matrix, exported as SVG.
# The trailing `fig` is the cell's displayed output.
CANO_LABELS = [get_partial_wave_name(p, helicity=False) for p in CANO_COUPLINGS]
fig = visualize_rates(
    CANO_LABELS, CANO_INTERFERENCE, "LS model (canonical basis)", helicity=False
)
CANONICAL_PATH = Path("../_static/images/decay-rates-default-canonical.svg")
fig.write_image(CANONICAL_PATH)
fig

In [None]:
# Heatmap of the default (helicity-basis) interference matrix, exported as SVG.
# The trailing `fig` is the cell's displayed output.
HELI_LABELS = [get_partial_wave_name(p, helicity=True) for p in HELI_COUPLINGS]
fig = visualize_rates(
    HELI_LABELS, HELI_INTERFERENCE, "default model (helicity basis)", helicity=True
)
HELICITY_PATH = Path("../_static/images/decay-rates-default-helicity.svg")
fig.write_image(HELICITY_PATH)
fig

In [None]:
# Split the canonical-basis interference matrix into its PV and PC diagonal
# blocks. sort_couplings() puts the PV couplings first ("\texttt{False}" sorts
# before "\texttt{True}"), so the first n_pv rows/columns are PV.
# NOTE(review): n_pv is hard-coded and must match the number of PV couplings.
n_pv = 13
rates = CANO_INTERFERENCE
pv_rates = rates[:n_pv, :n_pv]
pc_rates = rates[n_pv:, n_pv:]
# Shared symmetric color range, rounded down to a multiple of 0.2.
abs_max = float(jnp.abs(rates).max())
abs_max = np.floor(abs_max * 5) / 5

fig = make_subplots(
    cols=2,
    horizontal_spacing=0.18,
    rows=1,
    subplot_titles=(
        "parity-violating",
        "parity-conserving",
    ),
)
# Heatmap settings shared by both panels (reused by a later cell).
heatmap_kwargs = dict(
    colorscale="RdBu_r",
    hovertemplate="%{x}<br>%{y}<br>Decay rate: <b>%{customdata:+.3g}%</b><extra></extra>",
    zmin=-abs_max,
    zmax=+abs_max,
)
fig.add_trace(
    go.Heatmap(
        x=CANO_LABELS[:n_pv],
        y=CANO_LABELS[:n_pv],
        z=pv_rates,
        customdata=100 * pv_rates,
        showscale=False,
        **heatmap_kwargs,
    ),
    col=1,
    row=1,
)
# Only the PC panel carries the color bar; both panels share zmin/zmax.
fig.add_trace(
    go.Heatmap(
        x=CANO_LABELS[n_pv:],
        y=CANO_LABELS[n_pv:],
        z=pc_rates,
        customdata=100 * pc_rates,
        colorbar=dict(
            title="Decay rate",
            title_side="right",
            tickformat="+.0%",
            thickness=12,
            xpad=1,
        ),
        **heatmap_kwargs,
    ),
    col=2,
    row=1,
)
# Both blocks have the same layout: 4 K*, 3 Δ*, 6 Λ* couplings.
indicate_subsystem(fig, 0, 3, label="𝐾<sup>*</sup>", col=1)
indicate_subsystem(fig, 4, 6, label="𝛥<sup>*</sup>", col=1)
indicate_subsystem(fig, 7, 12, label="𝛬</sup>*</sup>", col=1)
indicate_subsystem(fig, 0, 3, label="𝐾<sup>*</sup>", col=2, legend=False)
indicate_subsystem(fig, 4, 6, label="𝛥<sup>*</sup>", col=2, legend=False)
indicate_subsystem(fig, 7, 12, label="𝛬</sup>*</sup>", col=2, legend=False)
fig.update_layout(
    autosize=False,
    height=420,
    legend=dict(font_size=16, yanchor="top", y=1, xanchor="left", x=0.85),
    paper_bgcolor="rgba(0, 0, 0, 0)",
    title=dict(
        text="Decay rates of default model",
        xanchor="center",
        x=0.5,
        y=0.87,
    ),
)
fig.update_yaxes(autorange="reversed")
CANO_PATH_SPLIT = Path("../_static/images/decay-rates-default-canonical-split.svg")
fig.write_image(CANO_PATH_SPLIT)
fig.show()

In [None]:
def indicate_subsystem_mpl(ax, start, end, label, color):
    """Outline the diagonal block ``start..end`` (inclusive) of an ``imshow`` plot.

    The unfilled rectangle is widened by half a cell on each side so that it
    encloses the image cells. The added patch is returned, e.g. for use as a
    legend handle.
    """
    side_length = end - start + 1
    corner = start - 0.5
    outline = Rectangle(
        (corner, corner),
        side_length,
        side_length,
        edgecolor=color,
        linewidth=1,
        fill=False,
        label=label,
    )
    ax.add_patch(outline)
    return outline


# Matplotlib rendering of the split PV/PC heatmaps, saved as SVG.
plt.rc("font", size=13)
fig, axes = plt.subplots(
    constrained_layout=True,
    dpi=300,
    figsize=(10.5, 4.6),
    ncols=2,
)
# Transparent figure/axes backgrounds and no spines.
fig.patch.set_facecolor("none")
for ax in axes:
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.patch.set_facecolor("none")
    # NOTE(review): negative tick length is unusual — confirm intended effect.
    ax.tick_params(axis="both", which="both", length=-4)
    ax.xaxis.set_tick_params(pad=7)
    ax.yaxis.set_tick_params(pad=7)
ax1, ax2 = axes

# Fixed color range for this figure; it overrides the rounded abs_max of the
# Plotly cell.
abs_max = 0.15
vmin, vmax = -abs_max, +abs_max
cmap = "RdBu_r"

# LaTeX tick labels; capital Δ and Λ are wrapped in \mathit to typeset them in
# italic.
CANO_LABELS_LATEX = [
    f"${get_partial_wave_latex(p, helicity=False)}$".replace(
        R"\Delta", R"\mathit{\Delta}"
    ).replace(R"\Lambda", R"\mathit{\Lambda}")
    for p in CANO_COUPLINGS
]
im1 = ax1.imshow(pv_rates, cmap=cmap, vmin=vmin, vmax=vmax)
im2 = ax2.imshow(pc_rates, cmap=cmap, vmin=vmin, vmax=vmax)
ax1.set_xticks(np.arange(pv_rates.shape[1]))
ax1.set_yticks(np.arange(pv_rates.shape[0]))
ax1.set_xticklabels(CANO_LABELS_LATEX[:n_pv], rotation=90, ha="center")
ax1.set_yticklabels(CANO_LABELS_LATEX[:n_pv])
ax2.set_xticks(np.arange(pc_rates.shape[1]))
ax2.set_yticks(np.arange(pc_rates.shape[0]))
ax2.set_xticklabels(CANO_LABELS_LATEX[n_pv:], rotation=90, ha="center")
ax2.set_yticklabels(CANO_LABELS_LATEX[n_pv:])

# Resonance-group boxes (same 4/3/6 layout in both blocks); the boxes of the
# PC panel double as legend handles.
indicate_subsystem_mpl(ax1, 0, 3, R"$K^*$", "C0")
indicate_subsystem_mpl(ax1, 4, 6, R"$\mathit{\Delta}^*$", "C2")
indicate_subsystem_mpl(ax1, 7, 12, R"$\mathit{\Lambda}^*$", "C1")
handles = [
    indicate_subsystem_mpl(ax2, 0, 3, R"$K^*$", "C0"),
    indicate_subsystem_mpl(ax2, 4, 6, R"$\mathit{\Delta}^*$", "C2"),
    indicate_subsystem_mpl(ax2, 7, 12, R"$\mathit{\Lambda}^*$", "C1"),
]
ax2.legend(
    bbox_to_anchor=(0.7, 1.01),
    frameon=False,
    handles=handles,
    handlelength=1.1,
    handletextpad=0.4,
    loc="upper left",
)

# Shared color bar with ticks formatted as signed percentages ("0%" at zero).
cbar = fig.colorbar(im2, ax=axes, fraction=0.028, pad=0.01)
cbar.formatter = FuncFormatter(lambda x, _: Rf"${100 * x:+.0f}\%$" if x else R"$0\%$")
cbar.set_label("Decay rate", rotation=90, labelpad=-2)

CANO_PATH_MPL = "../_static/images/decay-rates-default-canonical-split-matplotlib.svg"
fig.savefig(CANO_PATH_MPL, bbox_inches="tight")
plt.show()

In [None]:
# Separate Plotly figures for the PV and PC blocks, exported as two SVGs.
fig1 = go.Figure(
    data=go.Heatmap(
        x=CANO_LABELS[:n_pv],
        y=CANO_LABELS[:n_pv],
        z=pv_rates,
        customdata=100 * pv_rates,
        showscale=False,
        **heatmap_kwargs,
    )
)
indicate_subsystem(fig1, 0, 3, label="𝐾<sup>*</sup>")
indicate_subsystem(fig1, 4, 6, label="𝛥<sup>*</sup>")
indicate_subsystem(fig1, 7, 12, label="𝛬</sup>*</sup>")

fig2 = go.Figure(
    data=go.Heatmap(
        x=CANO_LABELS[n_pv:],
        y=CANO_LABELS[n_pv:],
        z=pc_rates,
        customdata=100 * pc_rates,
        colorbar=dict(
            tickformat="+.0%",
            thickness=12,
            xpad=2,
        ),
        **heatmap_kwargs,
    )
)
# The PC block has the same layout as the PV block (4 K*, 3 Δ*, 6 Λ*
# couplings), so it uses the same index ranges as fig1 and the split figures.
indicate_subsystem(fig2, 0, 3, label="𝐾<sup>*</sup>", legend=False)
indicate_subsystem(fig2, 4, 6, label="𝛥<sup>*</sup>", legend=False)
indicate_subsystem(fig2, 7, 12, label="𝛬</sup>*</sup>", legend=False)

fig1.update_layout(height=335, width=400)
fig2.update_layout(height=335, width=400)
# Common layout: transparent backgrounds, square cells (y scaled to x), and
# the first coupling at the top.
for fig in [fig1, fig2]:
    fig.update_layout(
        autosize=False,
        legend=dict(font_size=16, yanchor="top", y=1, xanchor="left", x=0.7),
        margin=dict(t=10),
        paper_bgcolor="rgba(0, 0, 0, 0)",
        plot_bgcolor="rgba(0, 0, 0, 0)",
        yaxis_scaleanchor="x",
    )
    fig.update_yaxes(autorange="reversed")

CANO_PATH_PV = Path("../_static/images/decay-rates-default-canonical-pv.svg")
CANO_PATH_PC = Path("../_static/images/decay-rates-default-canonical-pc.svg")
fig1.write_image(CANO_PATH_PV)
fig2.write_image(CANO_PATH_PC)

In [None]:
# Post-process the exported Plotly SVGs with reduce_svg_size (presumably
# shrinks the files in place).
# NOTE(review): CANO_PATH_MPL (the matplotlib SVG) is neither processed here
# nor listed in the download links below — confirm this is intentional.
for path in [
    CANONICAL_PATH,
    CANO_PATH_SPLIT,
    CANO_PATH_PC,
    CANO_PATH_PV,
    HELICITY_PATH,
]:
    reduce_svg_size(path)

# Download links for the SVG files; this Markdown object is the cell output.
Markdown(f"""
SVG files can be downloaded here:
- [`{CANONICAL_PATH.name}`]({CANONICAL_PATH})
  - [`{CANO_PATH_SPLIT.name}`]({CANO_PATH_SPLIT})
  - [`{CANO_PATH_PC.name}`]({CANO_PATH_PC})
  - [`{CANO_PATH_PV.name}`]({CANO_PATH_PV})
- [`{HELICITY_PATH.name}`]({HELICITY_PATH})
""")

In [None]:
# compute_interference() stores HALF of each pairwise interference term in
# both (i, j) and (j, i). The full contribution of a set of couplings is
# therefore the sum over its entire square sub-block. Summing only the lower
# triangle (np.tril) would drop half of every interference term.
pc_current = pc_rates.sum()
pv_current = pv_rates.sum()
# Check that the PC-PV cross blocks of the matrix are negligible, i.e. that
# the two currents account for the sum over the full matrix (presumably ≈ 1,
# since the rates are normalized to the total intensity).
np.testing.assert_allclose(
    pc_current + pv_current,
    rates.sum(),
    atol=1e-12,
    rtol=1e-12,
)
Markdown(f"""
:::{{tip}}
- Parity-conserving (PC) current: **{pc_current:.1%}**
- Parity-violating (PV) current: **{pv_current:.1%}**
:::
""")