# Model serialization

## Import model

In [None]:
from __future__ import annotations

import json
from collections import abc
from difflib import get_close_matches
from itertools import product
from typing import Any, Callable, Literal, Sequence, Union
from warnings import warn

import jax.numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import sympy as sp
from ampform.dynamics import BlattWeisskopfSquared
from ampform.kinematics.phasespace import Kallen
from ampform.sympy import PoolSum, argument, perform_cached_doit, unevaluated
from attrs import asdict, frozen
from IPython.display import JSON, Markdown, Math
from sympy.functions.special.tensor_functions import KroneckerDelta as δ
from sympy.physics.quantum.cg import CG
from sympy.physics.quantum.spin import Rotation as Wigner
from sympy.printing.latex import LatexPrinter
from tqdm.auto import tqdm

from ampform_dpd import (
    AmplitudeModel,
    DefinedExpression,
    _AlignmentWignerGenerator,
    _generate_amplitude_index_bases,
    create_mass_symbol_mapping,
    formulate_invariants,
)
from ampform_dpd.angles import formulate_scattering_angle
from ampform_dpd.decay import (
    FinalStateID,
    IsobarNode,
    Particle,
    State,
    StateID,
    ThreeBodyDecay,
    ThreeBodyDecayChain,
)
from ampform_dpd.dynamics import BuggBreitWigner, EnergyDependentWidth, FormFactor, P
from ampform_dpd.io import (
    as_markdown_table,
    aslatex,
    perform_cached_lambdify,
    simplify_latex_rendering,
)
from ampform_dpd.spin import create_spin_range

simplify_latex_rendering()

In [None]:
# Read the serialized model. The encoding is given explicitly so that the result
# does not depend on the platform's locale default (JSON text is UTF-8).
with open("Lc2ppiK.json", encoding="utf-8") as stream:
    MODEL_DEFINITION = json.load(stream)

In [None]:
JSON(MODEL_DEFINITION)

## Construct `ThreeBodyDecay`

In [None]:
def to_decay(
    model: dict, to_latex: Callable[[str], str] | None = None
) -> ThreeBodyDecay:
    """Convert a serialized model dictionary into a `ThreeBodyDecay`.

    Args:
        model: Deserialized model definition.
        to_latex: Optional callable that is forwarded to `to_decay_chain`.

    Returns:
        A `ThreeBodyDecay` with the states of the model and its sorted,
        de-duplicated decay chains.
    """
    parent = _get_initial_state(model)
    children = _get_final_state(model)
    states = _get_states(model)
    unique_chains = set()
    for chain_def in _get_decay_chains(model):
        unique_chains.add(to_decay_chain(chain_def, parent, children, to_latex))
    return ThreeBodyDecay(states=states, chains=sorted(unique_chains))


def _get_decay_chains(model: dict) -> list[dict]:
    """Return the chain definitions listed under the distribution's decay description."""
    decay_description = _get_distribution_def(model)["decay_description"]
    return decay_description["chains"]


def _get_distribution_def(model: dict) -> dict:
    """Return the first distribution definition of a serialized model.

    Args:
        model: Deserialized model dictionary with a ``"distributions"`` list.

    Returns:
        The first entry of the ``"distributions"`` list. A `UserWarning` is
        emitted if there is more than one entry.

    Raises:
        ValueError: If the list of distributions is empty.
        KeyError, TypeError: If ``model`` has no ``"distributions"`` entry and
            there is no module-level ``MODEL_DEFINITION`` to fall back on.
    """
    try:
        distribution_defs = model["distributions"]
    except (KeyError, TypeError):
        # This function used to ignore its argument and always read the global
        # MODEL_DEFINITION. Callers that hand over something other than the full
        # model depended on that, so keep it as a warned fallback rather than
        # breaking them.
        legacy_model = globals().get("MODEL_DEFINITION")
        if legacy_model is None:
            raise
        msg = (
            "The given object is not a full model definition (no 'distributions'"
            " entry). Falling back to the global MODEL_DEFINITION"
        )
        warn(msg, category=DeprecationWarning, stacklevel=2)
        distribution_defs = legacy_model["distributions"]
    n_distributions = len(distribution_defs)
    if n_distributions == 0:
        msg = "The serialized model does not have any distributions"
        raise ValueError(msg)
    if n_distributions > 1:
        msg = f"There are {n_distributions} distributions, but expecting one only"
        warn(msg, category=UserWarning)
    return distribution_defs[0]


def to_decay_chain(
    chain_def: dict,
    initial_state: State,
    final_state: dict[FinalStateID, State],
    to_latex: Callable[[str], str] | None = None,
) -> ThreeBodyDecayChain:
    """Build a `ThreeBodyDecayChain` from one serialized chain definition.

    Args:
        chain_def: Chain definition with ``"name"``, ``"vertices"`` and
            ``"propagators"`` entries.
        initial_state: The decaying parent state.
        final_state: Mapping of final-state ID to `State`.
        to_latex: Callable that converts the resonance name to LaTeX. If
            `None`, `name_to_latex` is used.

    Returns:
        The decay chain in which the resonance decays to the two states of the
        second vertex and recoils against the second entry of the first vertex.
    """
    if to_latex is None:
        # The callback used to be ignored in favour of name_to_latex, so keep
        # that as the default to leave existing output unchanged.
        to_latex = name_to_latex
    vertices = chain_def["vertices"]
    resonance_name = chain_def["name"]
    resonance = Particle(
        name=resonance_name,
        latex=to_latex(resonance_name),
        spin=chain_def["propagators"][0]["spin"],
        mass=0,
        width=0,
        parity=None,
    )
    # NOTE(review): assumes vertices[0] is the production vertex with the
    # spectator as its second node entry and vertices[1] the resonance decay —
    # confirm against the serialization format.
    production_node = vertices[0]["node"]
    decay_node = vertices[1]["node"]
    return ThreeBodyDecayChain(
        decay=IsobarNode(
            parent=initial_state,
            child1=IsobarNode(
                parent=resonance,
                child1=final_state[decay_node[0]],
                child2=final_state[decay_node[1]],
            ),
            child2=final_state[production_node[1]],
        )
    )


def _get_states(model: dict) -> dict[StateID, State]:
    """Collect the initial state and all final states, keyed by their index."""
    parent = _get_initial_state(model)
    children = _get_final_state(model)
    states = {parent.index: parent}
    states.update(children)
    return states


def _get_initial_state(model: dict) -> State:
    """Return the initial state described in the kinematics section of the model.

    The return annotation is a single `State` (not a mapping), which is what
    `dict_to_particle` produces.
    """
    decay_description = _get_distribution_def(model)["decay_description"]
    initial_state_def = decay_description["kinematics"]["initial_state"]
    return dict_to_particle(initial_state_def)


def _get_final_state(model: dict) -> dict[StateID, State]:
    """Return the final states of the model, keyed by their ``"index"`` entry."""
    decay_description = _get_distribution_def(model)["decay_description"]
    kinematics = decay_description["kinematics"]
    final_state = {}
    for particle_def in kinematics["final_state"]:
        state_id = particle_def["index"]
        final_state[state_id] = dict_to_particle(particle_def)
    return final_state


def dict_to_particle(dct: dict) -> State:
    """Create a `State` from a serialized particle definition.

    Reads the ``"name"``, ``"mass"``, ``"spin"`` and ``"index"`` entries. Width is
    set to zero and parity is left undefined.
    """
    particle_name = dct["name"]
    latex = name_to_latex(particle_name)
    return State(
        name=particle_name,
        latex=latex,
        mass=dct["mass"],
        width=0,
        spin=dct["spin"],
        parity=None,
        index=dct["index"],
    )

In [None]:
def name_to_latex(name: str) -> str:
    """Convert a particle or resonance name from the model file to LaTeX.

    Known state names are looked up directly. Otherwise, the first character
    selects the symbol (``D``, ``K`` or ``L``) and the remainder, stripped of
    surrounding parentheses, is placed between parentheses. Names that match
    neither rule, including the empty string, are returned unchanged.
    """
    known_states = {
        "Lc": R"\Lambda_c^+",
        "pi": R"\pi^+",
        "K": "K^-",
        "p": "p",
    }
    latex = known_states.get(name)
    if latex is not None:
        return latex
    # Slicing instead of indexing, so that an empty name does not raise IndexError.
    subsystem_letter = name[:1]
    subsystem = {"D": "D", "K": "K", "L": R"\Lambda"}.get(subsystem_letter)
    if subsystem is None:
        return name
    mass_str = name[1:].strip("(").strip(")")
    return f"{subsystem}({mass_str})"


# Decay object for the loaded model, with `name_to_latex` as LaTeX converter.
DECAY = to_decay(MODEL_DEFINITION, to_latex=name_to_latex)

In [None]:
Math(aslatex(DECAY))

In [None]:
Markdown(as_markdown_table(DECAY))

## Dynamics

:::{seealso} [RUB-EP1/amplitude-serialization#22](https://github.com/RUB-EP1/amplitude-serialization/issues/22)
:::

### Function look-up mechanism

In [None]:
# Show which function types occur in the model file.
function_defs = MODEL_DEFINITION["functions"]
{f["type"] for f in function_defs}

In [None]:
# All chain definitions of the model; the cells below index into this list.
CHAIN_DEFINITIONS = _get_decay_chains(MODEL_DEFINITION)
CHAIN_DEFINITIONS[2]

In [None]:
def get_function_definition(function_name: str, model: dict) -> dict:
    """Look up a function definition by name in the ``"functions"`` list of the model.

    Raises:
        KeyError: If no function has the requested name. The message lists
            similarly spelled names, if there are any.
    """
    available = model["functions"]
    found = next((d for d in available if d["name"] == function_name), None)
    if found is not None:
        return found
    msg = f"Could not find function with name {function_name!r}."
    known_names = {d["name"] for d in available}
    suggestions = get_close_matches(function_name, known_names)
    if suggestions:
        msg += f" Did you mean any of these? {', '.join(sorted(suggestions))}"
    raise KeyError(msg)

In [None]:
# NOTE(review): "BlattWeiskopf" (single "s") looks like a deliberate misspelling to
# show the close-match hint of the KeyError — confirm this cell is meant to raise.
get_function_definition("BlattWeiskopf", MODEL_DEFINITION)

<!-- cspell:ignore Weiskopf -->
### Vertices

#### Blatt-Weisskopf form factor

In [None]:
get_function_definition("BlattWeisskopf_resonance_l1", MODEL_DEFINITION)

In [None]:
get_function_definition("BlattWeisskopf_resonance_l2", MODEL_DEFINITION)

In [None]:
def _node_to_mass(node_item: int | Sequence[int]) -> sp.Symbol:
    """Create a non-negative mass-like symbol for a node item.

    An integer ``i`` becomes ``m{i}``. A sequence of two integers becomes
    ``sigma{k}``, where ``k`` is taken from ``{1, 2, 3}`` minus the given IDs.

    Raises:
        NotImplementedError: For any other kind of node item.
    """
    if isinstance(node_item, int):
        return sp.Symbol(f"m{node_item}", nonnegative=True)
    is_pair_of_ids = (
        isinstance(node_item, abc.Sequence)
        and all(isinstance(i, int) for i in node_item)
        and len(node_item) == 2
    )
    if not is_pair_of_ids:
        msg = f"Cannot create mass symbol for node {node_item}"
        raise NotImplementedError(msg)
    remaining_id, *_ = {1, 2, 3} - set(node_item)
    return sp.Symbol(f"sigma{remaining_id}", nonnegative=True)


# Example: symbols for the two-body node [2, 3] and for final state 1.
sp.Tuple(_node_to_mass([2, 3]), _node_to_mass(1))

In [None]:
def _node_to_mandelstam(node: Topology) -> sp.Symbol:
    """Return the symbol that `_node_to_mass` assigns to the system formed by a node.

    If the node contains integers only, the node itself is converted.
    Otherwise, the symbol for index 0 is returned.
    """
    only_final_states = all(isinstance(i, int) for i in node)
    return _node_to_mass(node if only_final_states else 0)


# Example: a node of two final states versus a node that contains a nested node.
sp.Tuple(_node_to_mandelstam([2, 3]), _node_to_mandelstam([[2, 3], 1]))

In [None]:
def formulate_form_factor(vertex: dict, model: dict) -> DefinedExpression:
    """Formulate the form factor of a vertex, if it declares one.

    Args:
        vertex: Vertex definition. Its optional ``"formfactor"`` entry names a
            function in the model.
        model: Deserialized model definition.

    Returns:
        An empty `DefinedExpression` if the vertex has no form factor. Otherwise
        a `FormFactor` expression together with the value of its radius symbol.

    Raises:
        NotImplementedError: If the function is not of type ``"BlattWeisskopf"``.
    """
    function_name = vertex.get("formfactor")
    if not function_name:
        return DefinedExpression()
    definition = get_function_definition(function_name, model)
    if definition["type"] != "BlattWeisskopf":
        msg = f"No form factor implementation for {function_name!r}"
        raise NotImplementedError(msg)
    node = vertex["node"]
    s = _node_to_mandelstam(node)
    m1, m2 = map(_node_to_mass, node)
    is_resonance_decay = all(isinstance(i, int) for i in node)
    if is_resonance_decay:
        meson_radius = sp.Symbol(R"R_\mathrm{res}")
    else:
        # A nested node means this vertex belongs to the decay of the initial state.
        parent = _get_initial_state(model)
        meson_radius = sp.Symbol(f"R_{{{parent.latex}}}")
    angular_momentum = int(definition["l"])
    form_factor = FormFactor(s, m1, m2, angular_momentum, meson_radius)
    return DefinedExpression(
        expression=form_factor,
        definitions={meson_radius: definition["radius"]},
    )


# Example: form factor of the first vertex of chain 2.
CHAIN_2 = CHAIN_DEFINITIONS[2]
Math(aslatex(formulate_form_factor(CHAIN_2["vertices"][0], MODEL_DEFINITION)))

In [None]:
# Render the form factor classes next to the result of their `evaluate()` method.
s, m1, m2, d, L, z = sp.symbols("s m1 m2 d L z")
exprs = [
    FormFactor(s, m1, m2, L, d),
    BlattWeisskopfSquared(z, L),
]
Math(aslatex({e: e.evaluate() for e in exprs}))

### Propagators

#### Breit-Wigner

In [None]:
@unevaluated
class BreitWigner(sp.Expr):
    """Breit-Wigner that wraps `SimpleBreitWigner` with an optional energy-dependent width.

    If the angular momentum and both decay product masses are zero, the width is
    kept constant and the expression evaluates directly to the simple form.
    """

    s: Any
    mass: Any
    width: Any
    m1: Any = 0
    m2: Any = 0
    angular_momentum: Any = 0
    meson_radius: Any = 1

    def evaluate(self):
        running_width = self.energy_dependent_width()
        propagator = SimpleBreitWigner(self.s, self.mass, running_width)
        has_constant_width = (
            self.angular_momentum == 0 and self.m1 == 0 and self.m2 == 0
        )
        return propagator.evaluate() if has_constant_width else propagator

    def energy_dependent_width(self) -> sp.Expr:
        """Return the nominal width, or an `EnergyDependentWidth` if masses or L are set."""
        s, m0, Γ0, m1, m2, L, d = self.args
        if all(value == 0 for value in (L, m1, m2)):
            return Γ0
        return EnergyDependentWidth(s, m0, Γ0, m1, m2, L, d)

    def _latex_repr_(self, printer: LatexPrinter, *args) -> str:
        s = printer._print(self.s)
        mass = printer._print(self.mass)
        width = printer._print(self.width)
        L = printer._print(self.angular_momentum)
        function_symbol = R"\mathcal{R}^\mathrm{BW}"
        arg = Rf"\left({s}; {mass}, {width}\right)"
        # A concrete integer is labelled explicitly as "L=...".
        subscript = f"L={L}" if isinstance(self.angular_momentum, sp.Integer) else L
        return Rf"{function_symbol}_{{{subscript}}}{arg}"


@unevaluated
class SimpleBreitWigner(sp.Expr):
    """Breit-Wigner with constant width: ``1 / (m0**2 - s - i * m0 * Γ0)``."""

    s: Any
    mass: Any
    width: Any
    _latex_repr_ = R"\mathcal{{R}}^\mathrm{{BW}}\left({s}; {mass}, {width}\right)"

    def evaluate(self):
        s, m0, Γ0 = self.args
        # Use SymPy's exact imaginary unit. The Python literal 1j is sympified
        # to the float 1.0*I and makes the symbolic expression inexact.
        return 1 / (m0**2 - s - m0 * Γ0 * sp.I)


# Render each dynamics class next to one level of its definition
# (`doit(deep=False)` does not expand nested expressions).
x, y, z = sp.symbols("x:z")
s, m0, Γ0, m1, m2, L, d = sp.symbols("s m0 Gamma0 m1 m2 L R")
exprs = [
    BreitWigner(s, m0, Γ0, m1, m2, L, d),
    SimpleBreitWigner(s, m0, Γ0),
    EnergyDependentWidth(s, m0, Γ0, m1, m2, L, d),
    FormFactor(s, m1, m2, L, d),
    P(s, m1, m2),
    Kallen(x, y, z),
]
Math(aslatex({e: e.doit(deep=False) for e in exprs}))

In [None]:
CHAIN_DEFINITIONS[20]

In [None]:
get_function_definition("K892_BW", MODEL_DEFINITION)

In [None]:
def _formulate_breit_wigner(
    propagator: dict, function_definition: dict, resonance: str, **kwargs
) -> DefinedExpression:
    """Formulate a `BreitWigner` for a propagator, with its parameter values.

    Args:
        propagator: Propagator definition. Its ``"node"`` holds the two decay products.
        function_definition: Function definition with entries ``"mass"``,
            ``"width"``, ``"ma"``, ``"mb"``, ``"l"`` and ``"d"``.
        resonance: Label that is embedded in the mass and width symbols.
        **kwargs: Ignored. Accepted so that all dynamics builders can be called
            with the same keyword arguments.
    """
    node = propagator["node"]
    child_a, child_b = node
    s = _node_to_mandelstam(node)
    mass = sp.Symbol(f"m_{{{resonance}}}")
    width = sp.Symbol(Rf"\Gamma_{{{resonance}}}")
    m1 = _node_to_mass(child_a)
    m2 = _node_to_mass(child_b)
    angular_momentum = int(function_definition["l"])
    d = sp.Symbol(R"R_\mathrm{res}")
    breit_wigner = BreitWigner(s, mass, width, m1, m2, angular_momentum, d)
    parameter_values = {
        mass: function_definition["mass"],
        width: function_definition["width"],
        m1: function_definition["ma"],
        m2: function_definition["mb"],
        d: function_definition["d"],
    }
    return DefinedExpression(expression=breit_wigner, definitions=parameter_values)


# Example: Breit-Wigner for the first propagator of chain 20.
CHAIN_20 = CHAIN_DEFINITIONS[20]
K892_BW = _formulate_breit_wigner(
    propagator=CHAIN_20["propagators"][0],
    function_definition=get_function_definition("K892_BW", MODEL_DEFINITION),
    resonance=name_to_latex(CHAIN_20["name"]),
)
Math(aslatex(K892_BW))

#### Multi-channel Breit-Wigner

In [None]:
@unevaluated
class MultichannelBreitWigner(sp.Expr):
    """Breit-Wigner whose width is the sum of the widths of several channels."""

    s: Any
    mass: Any
    channels: list[ChannelArguments] = argument(sympify=False)

    def evaluate(self):
        s, m0 = self.s, self.mass
        channel_widths = [channel.formulate_width(s, m0) for channel in self.channels]
        return BreitWigner(s, m0, sum(channel_widths))

    def _latex_repr_(self, printer: LatexPrinter, *args) -> str:
        s = printer._print(self.s)
        widths = ", ".join(printer._print(channel.width) for channel in self.channels)
        return Rf"\mathcal{{R}}^\mathrm{{BW}}_\mathrm{{multi}}\left({s}; {widths}\right)"


@frozen
class ChannelArguments:
    """Arguments of one decay channel of a `MultichannelBreitWigner`.

    All attributes are converted with `sympy.sympify` after initialization.
    """

    width: Any
    m1: Any = 0
    m2: Any = 0
    angular_momentum: Any = 0
    meson_radius: Any = 1

    def __attrs_post_init__(self) -> None:
        sympified = {
            name: sp.sympify(value) for name, value in asdict(self).items()
        }
        for name, value in sympified.items():
            # The class is frozen, so regular attribute assignment is blocked.
            object.__setattr__(self, name, value)

    def formulate_width(self, s: Any, m0: Any) -> sp.Expr:
        """Return ``Γ0 * m0 / sqrt(s)`` times the squared form factor of this channel."""
        form_factor = FormFactor(
            s, self.m1, self.m2, self.angular_momentum, self.meson_radius
        )
        return self.width * m0 / sp.sqrt(s) * form_factor**2


# Render a two-channel example and the classes it is built from, each next to
# one level of its definition.
x, y, z = sp.symbols("x:z")
s, m0, Γ0, m1, m2, L, d = sp.symbols("s m0 Gamma0 m1 m2 L R")
channels = [
    ChannelArguments(
        sp.Symbol(f"Gamma{i}"),
        sp.Symbol(f"m_{{a,{i}}}"),
        sp.Symbol(f"m_{{b,{i}}}"),
        sp.Symbol(f"L{i}"),
        d,
    )
    for i in [1, 2]
]
exprs = [
    MultichannelBreitWigner(s, m0, channels),
    BreitWigner(s, m0, Γ0, m1, m2, L, d),
    BreitWigner(s, m0, Γ0),
    EnergyDependentWidth(s, m0, Γ0, m1, m2, L, d),
    FormFactor(s, m1, m2, L, d),
    P(s, m1, m2),
    Kallen(x, y, z),
]
Math(aslatex({e: e.doit(deep=False) for e in exprs}))

In [None]:
CHAIN_DEFINITIONS[0]

In [None]:
get_function_definition("L1405_Flatte", MODEL_DEFINITION)

In [None]:
def _formulate_multichannel_breit_wigner(
    propagator: dict, function_definition: dict, resonance: str, **kwargs
) -> DefinedExpression:
    """Formulate a `MultichannelBreitWigner` for a propagator, with its parameter values.

    The first channel uses the symbols of the propagator node. Every further
    channel gets its own width and mass symbols, numbered from 2.

    Args:
        propagator: Propagator definition. Its ``"node"`` holds the two decay products.
        function_definition: Function definition with a ``"mass"`` entry and a
            ``"channels"`` list. Each channel provides ``"gsq"``, ``"ma"``,
            ``"mb"`` and ``"l"``. The ``"d"`` of the first channel is used for
            all channels.
        resonance: Label that is embedded in the parameter symbols.
        **kwargs: Ignored. Accepted so that all dynamics builders can be called
            with the same keyword arguments.

    Raises:
        NotImplementedError: If fewer than two channels are defined.
    """
    channel_definitions = function_definition["channels"]
    if len(channel_definitions) < 2:
        msg = "Need at least two channels for a multi-channel Breit-Wigner"
        raise NotImplementedError(msg)
    first_channel = channel_definitions[0]
    node = propagator["node"]
    child_a, child_b = node
    s = _node_to_mandelstam(node)
    resonance_latex = name_to_latex(resonance)
    mass = sp.Symbol(f"m_{{{resonance}}}")
    width = sp.Symbol(Rf"\Gamma_{{{resonance}}}")
    m1 = _node_to_mass(child_a)
    m2 = _node_to_mass(child_b)
    angular_momentum = int(first_channel["l"])
    d = sp.Symbol(f"R_{{{resonance_latex}}}")
    channels = [ChannelArguments(width, m1, m2, angular_momentum, d)]
    parameter_defaults = {
        mass: function_definition["mass"],
        width: first_channel["gsq"],
        m1: first_channel["ma"],
        m2: first_channel["mb"],
        d: first_channel["d"],
    }
    # A separate counter name, so that the node indices are not shadowed.
    for channel_idx, channel_definition in enumerate(channel_definitions[1:], 2):
        Γi = sp.Symbol(
            Rf"\Gamma_{{{resonance_latex}}}^\text{{ch. {channel_idx}}}",
            nonnegative=True,
        )
        # NOTE(review): these mass symbols carry no resonance label, so they are
        # shared between resonances — confirm that is intended.
        mi1 = sp.Symbol(f"m_{{a,{channel_idx}}}", nonnegative=True)
        mi2 = sp.Symbol(f"m_{{b,{channel_idx}}}", nonnegative=True)
        angular_momentum = int(channel_definition["l"])
        channels.append(ChannelArguments(Γi, mi1, mi2, angular_momentum, d))
        parameter_defaults.update({
            # The width of the additional channel needs a value as well,
            # otherwise it stays a free symbol in the model.
            Γi: channel_definition["gsq"],
            mi1: channel_definition["ma"],
            mi2: channel_definition["mb"],
        })
    return DefinedExpression(
        expression=MultichannelBreitWigner(s, mass, channels),
        definitions=parameter_defaults,
    )


# Example: multi-channel Breit-Wigner for the first propagator of chain 0.
CHAIN_0 = CHAIN_DEFINITIONS[0]
L1405_Flatte = _formulate_multichannel_breit_wigner(
    propagator=CHAIN_0["propagators"][0],
    function_definition=get_function_definition("L1405_Flatte", MODEL_DEFINITION),
    resonance=name_to_latex(CHAIN_0["name"]),
)
Math(aslatex(L1405_Flatte))

#### Breit-Wigner with exponential

In [None]:
# Render `BuggBreitWigner` next to one level of its definition.
s, m0, Γ0, m1, m2, γ = sp.symbols("s m0 Gamma0 m1 m2 gamma")
expr = BuggBreitWigner(s, m0, Γ0, m1, m2, γ)
Math(aslatex({expr: expr.doit(deep=False)}))

In [None]:
CHAIN_DEFINITIONS[18]

In [None]:
get_function_definition("K700_BuggBW", MODEL_DEFINITION)

In [None]:
def _formulate_bugg_breit_wigner(
    propagator: dict, function_definition: dict, resonance: str, model: dict, **kwargs
) -> DefinedExpression:
    """Formulate a `BuggBreitWigner` for a propagator, with its parameter values.

    Args:
        propagator: Propagator definition. Its ``"node"`` holds two final-state IDs.
        function_definition: Function definition with entries ``"mass"``,
            ``"width"`` and ``"slope"``.
        resonance: Label that is embedded in the parameter symbols.
        model: Deserialized model. The decay product masses are taken from its
            final state definitions.
        **kwargs: Ignored. Accepted so that all dynamics builders can be called
            with the same keyword arguments.
    """
    node = propagator["node"]
    child_a, child_b = node
    s = _node_to_mandelstam(node)
    mass = sp.Symbol(f"m_{{{resonance}}}", nonnegative=True)
    width = sp.Symbol(Rf"\Gamma_{{{resonance}}}", nonnegative=True)
    γ = sp.Symbol(Rf"\gamma_{{{resonance}}}")
    m1 = _node_to_mass(child_a)
    m2 = _node_to_mass(child_b)
    final_state = _get_final_state(model)
    bugg_breit_wigner = BuggBreitWigner(s, mass, width, m1, m2, γ)
    parameter_values = {
        mass: function_definition["mass"],
        width: function_definition["width"],
        m1: final_state[child_a].mass,
        m2: final_state[child_b].mass,
        γ: function_definition["slope"],
    }
    return DefinedExpression(
        expression=bugg_breit_wigner, definitions=parameter_values
    )


# Example: Bugg Breit-Wigner for the first propagator of chain 18. The propagator
# and resonance name are taken from that same chain (previously chain 0 was
# used here by mistake, which left CHAIN_18 unused).
CHAIN_18 = CHAIN_DEFINITIONS[18]
K700_BuggBW = _formulate_bugg_breit_wigner(
    propagator=CHAIN_18["propagators"][0],
    function_definition=get_function_definition("K700_BuggBW", MODEL_DEFINITION),
    resonance=name_to_latex(CHAIN_18["name"]),
    model=MODEL_DEFINITION,
)
Math(aslatex(K700_BuggBW))

#### General propagator dynamics builder

In [None]:
def formulate_dynamics(chain_definition: dict, model: dict) -> DefinedExpression:
    """Formulate the product of the dynamics of all propagators in a decay chain.

    Args:
        chain_definition: Chain definition with ``"name"`` and ``"propagators"``.
            The ``"parametrization"`` of each propagator names a function in
            the model.
        model: Deserialized model definition.

    Returns:
        The product of the expressions produced by the builder that matches the
        function type of each propagator.

    Raises:
        NotImplementedError: If there is no builder for a function type.
    """
    builders = {
        "BreitWigner": _formulate_breit_wigner,
        "MultichannelBreitWigner": _formulate_multichannel_breit_wigner,
        "BreitWignerWidthExpLikeBugg": _formulate_bugg_breit_wigner,
    }
    expr = DefinedExpression()
    for propagator in chain_definition["propagators"]:
        parametrization = propagator["parametrization"]
        function_definition = get_function_definition(parametrization, model)
        function_type = function_definition["type"]
        dynamics_builder = builders.get(function_type)
        if dynamics_builder is None:
            msg = f"No dynamics implementation for function type {function_type!r}"
            raise NotImplementedError(msg)
        # Reuse the definition found above instead of looking it up again.
        expr *= dynamics_builder(
            propagator,
            function_definition=function_definition,
            resonance=name_to_latex(chain_definition["name"]),
            model=model,
        )
    return expr

In [None]:
Math(aslatex(formulate_dynamics(CHAIN_DEFINITIONS[0], MODEL_DEFINITION)))

In [None]:
Math(aslatex(formulate_dynamics(CHAIN_DEFINITIONS[18], MODEL_DEFINITION)))

In [None]:
Math(aslatex(formulate_dynamics(CHAIN_DEFINITIONS[20], MODEL_DEFINITION)))

## Construct `AmplitudeModel`

### Unpolarized intensity

In [None]:
# Recursive alias: a topology is a list of final-state IDs and/or nested topologies.
Topology = list[Union[int, "Topology"]]


def get_reference_subsystem(model: dict) -> FinalStateID:
    """Return the spectator ID of the reference topology of the model."""
    return get_spectator_id(get_reference_topology(model))


def get_spectator_id(topology: Topology) -> FinalStateID:
    """Return the only integer entry at the top level of a topology.

    Raises:
        ValueError: If the top level does not contain exactly one distinct integer.
    """
    spectator_candidates = set()
    for item in topology:
        if isinstance(item, int):
            spectator_candidates.add(item)
    if len(spectator_candidates) == 1:
        (spectator,) = spectator_candidates
        return spectator
    msg = f"Reference topology {topology} seems not to be a three-body decay"
    raise ValueError(msg)


def get_reference_topology(model: dict) -> Topology:
    """Return the ``"reference_topology"`` entry of the decay description."""
    decay_description = _get_distribution_def(model)["decay_description"]
    return decay_description["reference_topology"]


get_reference_subsystem(MODEL_DEFINITION)

In [None]:
def get_existing_subsystem_ids(model: dict) -> list[FinalStateID]:
    """Return the sorted, distinct spectator IDs of all decay chains in the model."""
    spectator_ids = set()
    for chain_def in _get_decay_chains(model):
        spectator_ids.add(get_spectator_id(chain_def["topology"]))
    return sorted(spectator_ids)


get_existing_subsystem_ids(MODEL_DEFINITION)

In [None]:
def formulate_aligned_amplitude(
    model: dict,
    λ0: sp.Rational | sp.Symbol,
    λ1: sp.Rational | sp.Symbol,
    λ2: sp.Rational | sp.Symbol,
    λ3: sp.Rational | sp.Symbol,
) -> tuple[PoolSum, dict[sp.Symbol, sp.Expr]]:
    """Formulate the sum of subsystem amplitudes, aligned to the reference subsystem.

    Args:
        model: Deserialized model definition.
        λ0: Helicity of the initial state.
        λ1: Helicity of final state 1.
        λ2: Helicity of final state 2.
        λ3: Helicity of final state 3.

    Returns:
        A tuple of the `PoolSum` over the primed helicities and the angle
        definitions that the Wigner generator collected.
    """
    reference_subsystem = get_reference_subsystem(model)
    wigner_generator = _AlignmentWignerGenerator(reference_subsystem)
    _λ0, _λ1, _λ2, _λ3 = sp.symbols(R"\lambda_(:4)^{\prime}", rational=True)
    # _get_states expects the complete model, not its kinematics section.
    states = _get_states(model)
    j0, j1, j2, j3 = (states[i].spin for i in sorted(states))
    A = _generate_amplitude_index_bases()
    amp_expr = PoolSum(
        sum(
            A[k][_λ0, _λ1, _λ2, _λ3]
            * wigner_generator(j0, λ0, _λ0, rotated_state=0, aligned_subsystem=k)
            * wigner_generator(j1, _λ1, λ1, rotated_state=1, aligned_subsystem=k)
            * wigner_generator(j2, _λ2, λ2, rotated_state=2, aligned_subsystem=k)
            * wigner_generator(j3, _λ3, λ3, rotated_state=3, aligned_subsystem=k)
            for k in get_existing_subsystem_ids(model)
        ),
        (_λ0, create_spin_range(j0)),
        (_λ1, create_spin_range(j1)),
        (_λ2, create_spin_range(j2)),
        (_λ3, create_spin_range(j3)),
    )
    return amp_expr, wigner_generator.angle_definitions


# Helicity symbols of the initial state (0) and the final states (1-3). They are
# reused by later cells.
λ0, λ1, λ2, λ3 = sp.symbols("lambda(:4)", rational=True)
amplitude_expr, _ = formulate_aligned_amplitude(MODEL_DEFINITION, λ0, λ1, λ2, λ3)
amplitude_expr

### Amplitude for the decay chain

#### Helicity recouplings

In [None]:
@unevaluated
class HelicityRecoupling(sp.Expr):
    """Product of two Kronecker deltas that select one helicity pair."""

    λa: sp.Rational | sp.Symbol
    λb: sp.Rational | sp.Symbol
    λa0: sp.Rational | sp.Symbol
    λb0: sp.Rational | sp.Symbol
    _latex_repr_ = R"\mathcal{{H}}^\text{{helicity}}\left({λa},{λb}|{λa0},{λb0}\right)"

    def evaluate(self) -> sp.Expr:
        helicity_a, helicity_b, selected_a, selected_b = self.args
        first_matches = δ(helicity_a, selected_a)
        second_matches = δ(helicity_b, selected_b)
        return first_matches * second_matches


@unevaluated
class ParityRecoupling(sp.Expr):
    """Helicity selection plus ``f`` times the selection of the sign-flipped pair."""

    λa: sp.Rational | sp.Symbol
    λb: sp.Rational | sp.Symbol
    λa0: sp.Rational | sp.Symbol
    λb0: sp.Rational | sp.Symbol
    f: sp.Integer | sp.Symbol
    _latex_repr_ = (
        R"\mathcal{{H}}^\text{{parity}}\left({λa},{λb}|{λa0},{λb0},{f}\right)"
    )

    def evaluate(self) -> sp.Expr:
        helicity_a, helicity_b, selected_a, selected_b, factor = self.args
        direct_term = δ(helicity_a, selected_a) * δ(helicity_b, selected_b)
        flipped_term = δ(helicity_a, -selected_a) * δ(helicity_b, -selected_b)
        return direct_term + factor * flipped_term


@unevaluated
class LSRecoupling(sp.Expr):
    """Recoupling coefficient in the LS basis.

    Evaluates to ``sqrt((2l+1)/(2j+1))`` times two Clebsch-Gordan coefficients:
    one that couples the spins ``ja`` and ``jb`` to ``s`` and one that couples
    ``l`` and ``s`` to ``j``.
    """

    λa: sp.Rational | sp.Symbol
    λb: sp.Rational | sp.Symbol
    l: sp.Integer | sp.Symbol
    s: sp.Rational | sp.Symbol
    ja: sp.Rational | sp.Symbol
    jb: sp.Rational | sp.Symbol
    j: sp.Rational | sp.Symbol
    # The label is "LS". It used to read "parity", copied from ParityRecoupling,
    # which made the two render indistinguishably.
    _latex_repr_ = (
        R"\mathcal{{H}}^\text{{LS}}\left({λa},{λb}|{l},{s},{ja},{jb},{j}\right)"
    )

    def evaluate(self) -> sp.Expr:
        λa, λb, l, s, ja, jb, j = self.args
        return (
            sp.sqrt((2 * l + 1) / (2 * j + 1))
            * CG(ja, λa, jb, -λb, s, λa - λb)
            * CG(l, 0, s, λa - λb, j, λa - λb)
        )

In [None]:
# Render the three recoupling classes next to one level of their definition.
λa = sp.Symbol(R"\lambda_a", rational=True)
λb = sp.Symbol(R"\lambda_b", rational=True)
λa0 = sp.Symbol(R"\lambda_a^0", rational=True)
λb0 = sp.Symbol(R"\lambda_b^0", rational=True)
f = sp.Symbol("f", integer=True)
l = sp.Symbol("l", integer=True, nonnegative=True)
s = sp.Symbol("s", nonnegative=True, rational=True)
ja = sp.Symbol("j_a", nonnegative=True, rational=True)
jb = sp.Symbol("j_b", nonnegative=True, rational=True)
j = sp.Symbol("j", nonnegative=True, rational=True)
exprs = [
    HelicityRecoupling(λa, λb, λa0, λb0),
    ParityRecoupling(λa, λb, λa0, λb0, f),
    LSRecoupling(λa, λb, l, s, ja, jb, j),
]
Math(aslatex({e: e.doit(deep=False) for e in exprs}))

#### Recoupling deserialization

In [None]:
CHAIN_DEFINITIONS[0]

In [None]:
def _formulate_recoupling(model: dict, chain_idx: int, vertex_idx: int) -> sp.Expr:
    """Formulate the recoupling expression for one vertex of a decay chain.

    Args:
        model: Deserialized model definition.
        chain_idx: Index of the chain in the list of decay chains.
        vertex_idx: Index of the vertex within the chain (0 or 1).

    Returns:
        A `HelicityRecoupling`, `ParityRecoupling` or `LSRecoupling`, depending
        on the ``"type"`` of the vertex.

    Raises:
        ValueError: If the chain does not have two vertices or if
            ``vertex_idx`` is not 0 or 1.
        NotImplementedError: If the vertex type is not supported.
    """
    chain_definition = _get_decay_chains(model)[chain_idx]
    vertex_definitions = chain_definition["vertices"]
    if len(vertex_definitions) != 2:
        msg = f"Not a three-body decay: there are {len(vertex_definitions)} vertices"
        raise ValueError(msg)
    if vertex_idx not in {0, 1}:
        msg = f"Vertex index out of range. Can either be 0 or 1, not {vertex_idx}."
        raise ValueError(msg)
    vertex = vertex_definitions[vertex_idx]
    vertex_type = vertex["type"]
    node = vertex["node"]
    λa, λb = map(_get_helicity_symbol, node)
    if vertex_type in {"helicity", "parity"}:
        λa0, λb0 = (sp.Rational(v) for v in vertex["helicities"])
        if vertex_type == "parity":
            f = _sign_to_value(vertex.get("parity_factor", "+"))
            return ParityRecoupling(λa, λb, λa0, λb0, f)
        return HelicityRecoupling(λa, λb, λa0, λb0)
    if vertex_type == "ls":
        l = int(vertex["l"])
        s = sp.Rational(vertex["s"])
        ja, jb = _get_child_spins(model, chain_idx, vertex_idx)
        # Call the helper with the arguments of its actual signature.
        j = _get_parent_spin(model, chain_idx, vertex_idx)
        return LSRecoupling(λa, λb, l, s, ja, jb, j)
    msg = f"No implementation for vertex of type {vertex_type!r}"
    raise NotImplementedError(msg)


def _sign_to_value(sign: Literal["", "-", "+"]) -> Literal[0, -1, 1]:
    """Convert a sign string to -1, 0 or +1, ignoring surrounding whitespace.

    Raises:
        NotImplementedError: If the stripped string is not ``"-"``, ``"+"`` or empty.
    """
    value_of_sign = {"-": -1, "": 0, "+": +1}
    stripped_sign = sign.strip()
    if stripped_sign in value_of_sign:
        return value_of_sign[stripped_sign]
    msg = f"Cannot convert {sign!r} to value"
    raise NotImplementedError(msg)


def _get_parent_spin(model: dict, chain_idx: int, vertex_idx: int) -> sp.Rational:
    """Return the spin of the state that decays at the given vertex.

    If the node of the vertex contains final-state IDs only, the parent is the
    resonance and the spin of the chain's propagator is returned. Otherwise,
    the spin of the initial state is returned.

    Args:
        model: Deserialized model definition.
        chain_idx: Index of the chain in the list of decay chains.
        vertex_idx: Index of the vertex within the chain.
    """
    chain_definition = _get_decay_chains(model)[chain_idx]
    # Select the requested vertex. The vertex list itself has no "node" entry.
    vertex = chain_definition["vertices"][vertex_idx]
    node = vertex["node"]
    if all(isinstance(i, int) for i in node):
        # The propagators are listed on the chain, not on the vertex.
        return __get_propagator_spin(chain_definition)
    return _get_initial_state(model).spin


def _get_child_spins(
    model: dict, chain_idx: int, vertex_idx: int
) -> tuple[sp.Rational, sp.Rational]:
    """Return the spins of the decay products of the given vertex.

    A final-state ID in the node is resolved through the final state
    definitions of the model. A nested node stands for the resonance, whose
    spin is taken from the chain's propagator.

    Args:
        model: Deserialized model definition.
        chain_idx: Index of the chain in the list of decay chains.
        vertex_idx: Index of the vertex within the chain.
    """
    chain_definition = _get_decay_chains(model)[chain_idx]
    # Select the requested vertex. The vertex list itself has no "node" entry.
    vertex = chain_definition["vertices"][vertex_idx]
    final_state = _get_final_state(model)
    spins = []
    for node_item in vertex["node"]:
        if isinstance(node_item, int):
            # Convert the spin of the state, not the State object itself.
            spins.append(sp.Rational(final_state[node_item].spin))
        else:
            # The propagators are listed on the chain, not on the vertex.
            spins.append(__get_propagator_spin(chain_definition))
    return tuple(spins)


def __get_propagator_spin(vertex: dict) -> sp.Rational:
    """Return the spin of the single propagator listed in the given dictionary.

    Despite the parameter name, the dictionary has to provide a
    ``"propagators"`` list.

    Raises:
        ValueError: If there is not exactly one propagator.
    """
    propagators = vertex["propagators"]
    n_propagators = len(propagators)
    if n_propagators == 1:
        return sp.Rational(propagators[0]["spin"])
    msg = f"There are {n_propagators} propagators, not a three-body decay"
    raise ValueError(msg)


def _get_helicity_symbol(node: int | Topology) -> sp.Symbol:
    """Return ``lambda{i}`` for a final-state ID and ``\\lambda_R`` for anything else."""
    symbol_name = f"lambda{node}" if isinstance(node, int) else R"\lambda_R"
    return sp.Symbol(symbol_name, rational=True)

In [None]:
# Example: recoupling expressions for both vertices of chain 0.
recouplings = [
    _formulate_recoupling(MODEL_DEFINITION, chain_idx=0, vertex_idx=i) for i in range(2)
]
Math(aslatex({e: e.doit(deep=False) for e in recouplings}))

#### Chain amplitudes

In [None]:
def _get_final_state_helicities(
    chain_definition: dict,
) -> dict[FinalStateID, sp.Rational]:
    """Collect the helicity of each final state in the chain, sorted by state ID.

    Only node entries that are integers are considered. Nested nodes are skipped.
    """
    collected = {}
    for vertex in chain_definition["vertices"]:
        for helicity, node in zip(vertex["helicities"], vertex["node"]):
            if isinstance(node, int):
                collected[node] = sp.Rational(helicity)
    return dict(sorted(collected.items()))


_get_final_state_helicities(CHAIN_DEFINITIONS[0])

In [None]:
def _get_resonance_helicity(
    chain_definition: dict,
) -> tuple[tuple[FinalStateID, FinalStateID], sp.Rational]:
    """Find the two-body sub-node of the chain and the helicity listed for it.

    Vertices whose node holds integers only are skipped.

    Raises:
        ValueError: If no vertex contains a nested sequence of length two.
    """
    vertices_with_nested_node = (
        vertex
        for vertex in chain_definition["vertices"]
        if not all(isinstance(i, int) for i in vertex["node"])
    )
    for vertex in vertices_with_nested_node:
        for sub_node, helicity in zip(vertex["node"], vertex["helicities"]):
            if isinstance(sub_node, abc.Sequence) and len(sub_node) == 2:
                return tuple(sub_node), sp.Rational(helicity)
    msg = "Could not find a resonance node"
    raise ValueError(msg)


_get_resonance_helicity(CHAIN_DEFINITIONS[0])

In [None]:
def _get_spectator_helicity(chain_definition: dict) -> tuple[int, sp.Rational]:
    """Find the state ID and helicity of the integer entry next to a nested node.

    Vertices whose node holds integers only are skipped.

    Raises:
        ValueError: If no such entry is found.
    """
    vertices_with_nested_node = (
        vertex
        for vertex in chain_definition["vertices"]
        if not all(isinstance(i, int) for i in vertex["node"])
    )
    for vertex in vertices_with_nested_node:
        for state_id, helicity in zip(vertex["node"], vertex["helicities"]):
            if isinstance(state_id, int):
                return state_id, sp.Rational(helicity)
    msg = "Could not find vertex that contains a spectator"
    raise ValueError(msg)


_get_spectator_helicity(CHAIN_DEFINITIONS[0])

In [None]:
def _get_decay_product_helicities(
    chain_definition: dict,
) -> tuple[tuple[int, sp.Rational], tuple[int, sp.Rational]]:
    """Return ``(state ID, helicity)`` pairs for the first vertex with integer-only node.

    Raises:
        ValueError: If no vertex has a node that consists of integers only.
    """
    vertices = chain_definition["vertices"]
    for v in vertices:
        node = v["node"]
        if all(isinstance(i, int) for i in node):
            helicities = v["helicities"]
            return tuple((i, sp.Rational(λ)) for i, λ in zip(node, helicities))
    # Typo fixed in the message: "fine" -> "find".
    msg = "Could not find a helicity for any resonance node"
    raise ValueError(msg)


_get_decay_product_helicities(CHAIN_DEFINITIONS[0])

In [None]:
def _get_weight(chain_definition: dict) -> tuple[sp.Symbol, complex | float]:
    """Return the coefficient symbol of a chain and its numerical value.

    The ``"weight"`` entry is parsed as a complex number after removing spaces
    and replacing ``i`` with ``j``. A value without imaginary part is returned
    as its real part.
    """
    weight_str = str(chain_definition["weight"])
    parsed = complex(weight_str.replace(" ", "").replace("i", "j"))
    value = parsed if parsed.imag else parsed.real
    resonance_latex = name_to_latex(chain_definition["name"])
    _, resonance_helicity = _get_resonance_helicity(chain_definition)
    coefficient = sp.IndexedBase(f"c^{{{resonance_latex}[{resonance_helicity}]}}")
    # Unpacking requires exactly three final-state helicities.
    helicity1, helicity2, helicity3 = _get_final_state_helicities(
        chain_definition
    ).values()
    return coefficient[helicity1, helicity2, helicity3], value


Math(aslatex(dict([_get_weight(CHAIN_DEFINITIONS[0])])))

In [None]:
def formulate_chain_amplitude(
    λ0: sp.Rational,
    λ1: sp.Rational,
    λ2: sp.Rational,
    λ3: sp.Rational,
    model: dict,
    chain_idx: int,
) -> dict[sp.Basic, Any]:
    """Formulate the amplitude of one decay chain and the definitions it needs.

    Args:
        λ0: Helicity of the initial state. Used as index of the amplitude symbol.
        λ1: Helicity of final state 1. Used as index of the amplitude symbol.
        λ2: Helicity of final state 2. Used as index of the amplitude symbol.
        λ3: Helicity of final state 3. Used as index of the amplitude symbol.
        model: Deserialized model definition.
        chain_idx: Index of the chain in the list of decay chains.

    Returns:
        A dictionary whose insertion order matters to callers: first the
        amplitude symbol with its expression, then the weight and the
        parameters of the dynamics with their values, and last the scattering
        angle with its definition.
    """
    chain_defs = _get_decay_chains(model)
    chain_definition = chain_defs[chain_idx]
    # -----------------------
    dynamics = formulate_dynamics(chain_definition, model)
    for vertex in chain_definition["vertices"]:
        dynamics *= formulate_form_factor(vertex, model)
    # -----------------------
    weight, weight_val = _get_weight(chain_definition)
    # -----------------------
    (i, λi_val), (j, λj_val) = _get_decay_product_helicities(chain_definition)
    θij, θij_expr = formulate_scattering_angle(i, j)
    jR = sp.Rational(chain_definition["propagators"][0]["spin"])
    R_node, λR_val = _get_resonance_helicity(chain_definition)
    λR = _get_helicity_symbol(R_node)
    # -----------------------
    A = _generate_amplitude_index_bases()
    subsystem_id = get_spectator_id(chain_definition["topology"])
    # _formulate_recoupling expects the complete model as its first argument,
    # not the chain definition.
    h_prod = _formulate_recoupling(model, chain_idx, vertex_idx=0)
    h_dec = _formulate_recoupling(model, chain_idx, vertex_idx=1)
    amplitude_expression = (
        weight
        * h_prod
        * h_dec
        * Wigner.d(jR, λR, λi_val - λj_val, θij)
        * dynamics.expression
    )
    amplitude_expression = amplitude_expression.subs({λR: λR_val})
    amplitude_symbol = A[subsystem_id][λ0, λ1, λ2, λ3]
    return {
        amplitude_symbol: amplitude_expression,
        weight: weight_val,
        **dynamics.definitions,
        θij: θij_expr,
    }

In [None]:
definitions = formulate_chain_amplitude(λ0, λ1, λ2, λ3, MODEL_DEFINITION, chain_idx=0)
Math(aslatex(definitions))

### Full amplitude model

In [None]:
def formulate(model: dict, cleanup_summations: bool = False) -> AmplitudeModel:
    """Formulate the complete amplitude model described by `model`.

    Chain amplitudes are formulated for every combination of allowed
    helicities and every decay chain. The intensity is the sum over all
    allowed helicities of the squared absolute value of the aligned amplitude.

    Args:
        model: Model definition from which states and decay chains are read.
        cleanup_summations: If `True`, call `cleanup()` on the aligned
            amplitude and on the intensity summation.

    Returns:
        An `AmplitudeModel` holding the decay, the intensity, the amplitude
        definitions, the angle definitions (as `variables`), the parameter
        defaults (including masses), the masses, and the invariants.
    """
    states = _get_states(model)
    helicity_symbols = sp.symbols("lambda(:4)", rational=True)
    allowed_helicities = {
        symbol: create_spin_range(states[i].spin)
        for i, symbol in enumerate(helicity_symbols)
    }
    amplitude_definitions = {}
    angle_definitions = {}
    parameter_defaults = {}
    n_chains = len(_get_decay_chains(model))
    helicity_values: tuple[sp.Rational, sp.Rational, sp.Rational, sp.Rational]
    for helicity_values in product(*allowed_helicities.values()):
        # The substitution mapping only depends on the helicity combination,
        # so it is built once per combination instead of once per chain.
        helicity_substitutions = dict(zip(helicity_symbols, helicity_values))
        for chain_idx in range(n_chains):
            amp_defs = formulate_chain_amplitude(*helicity_values, model, chain_idx)
            # Relies on the insertion order of the returned dictionary: the
            # amplitude comes first and the scattering angle last.
            (amp_symbol, amp_expr), *parameters, (θij, θij_expr) = amp_defs.items()
            amplitude_definitions[amp_symbol] = amp_expr.subs(helicity_substitutions)
            angle_definitions[θij] = θij_expr
            parameter_defaults.update(dict(parameters))
    aligned_amp, zeta_defs = formulate_aligned_amplitude(model, *helicity_symbols)
    angle_definitions.update(zeta_defs)
    decay = to_decay(model)
    masses = create_mass_symbol_mapping(decay)
    parameter_defaults.update(masses)
    if cleanup_summations:
        aligned_amp = aligned_amp.cleanup()  # type:ignore[assignment]
    intensity = PoolSum(
        sp.Abs(aligned_amp) ** 2,
        *allowed_helicities.items(),
    )
    if cleanup_summations:
        intensity = intensity.cleanup()  # type:ignore[assignment]
    return AmplitudeModel(
        decay=decay,
        intensity=intensity,
        amplitudes=amplitude_definitions,
        variables=angle_definitions,
        parameter_defaults=parameter_defaults,
        masses=masses,
        invariants=formulate_invariants(decay),
    )

## Symbolic result

In [None]:
# Build the amplitude model with summation clean-up enabled and display its
# intensity expression.
MODEL = formulate(MODEL_DEFINITION, cleanup_summations=True)
MODEL.intensity

In [None]:
# Render the amplitude definitions of the model as LaTeX.
Math(aslatex(MODEL.amplitudes))

In [None]:
# Render the variable definitions of the model (the angle definitions collected
# in `formulate`) as LaTeX.
Math(aslatex(MODEL.variables))

In [None]:
# Render the invariants and the mass values in a single LaTeX listing.
Math(aslatex({**MODEL.invariants, **MODEL.masses}))

In [None]:
# Substitute the variable definitions first and the parameter defaults second.
intensity_expr = (
    MODEL.full_expression
    .xreplace(MODEL.variables)
    .xreplace(MODEL.parameter_defaults)
)

In [None]:
# After all substitutions, only sigma1, sigma2, and sigma3 may remain as free
# symbols of the intensity expression.
free_symbols = intensity_expr.free_symbols
sorted_symbol_names = str(sorted(free_symbols, key=str))
assert len(free_symbols) == 3
assert sorted_symbol_names == "[sigma1, sigma2, sigma3]"

## Numeric results

In [None]:
# Create one numerical function per invariant. In each of them, that invariant
# is replaced by its own expression (with mass values inserted), so that the
# resulting function is expected to take exactly two arguments.
intensity_funcs = {}
for s, s_expr in tqdm(MODEL.invariants.items()):
    s_expr = s_expr.xreplace(MODEL.masses).doit()
    expr = intensity_expr.xreplace({s: s_expr})
    expr = perform_cached_doit(expr)
    func = perform_cached_lambdify(expr, backend="jax")
    assert len(func.argument_order) == 2
    # The dictionary key is the last character of the symbol name, as integer.
    k = int(str(s)[-1])
    intensity_funcs[k] = func

### Validation

In [None]:
# Collect every entry under "misc" whose key contains "checksum" and map the
# names of its items to their values.
checksums = {}
for misc_key, misc_value in MODEL_DEFINITION["misc"].items():
    if "checksum" not in misc_key:
        continue
    checksums[misc_key] = {
        checksum["name"]: checksum["value"] for checksum in misc_value
    }
checksums

In [None]:
# Map the name of each parameter point to a dictionary of its parameter values.
checksum_points = {}
for point in MODEL_DEFINITION["parameter_points"]:
    point_parameters = {par["name"]: par["value"] for par in point["parameters"]}
    checksum_points[point["name"]] = point_parameters
checksum_points

In [None]:
# Compare the computed intensity with the expected checksum at each point.
# NOTE(review): the comparison is an exact `==`; a numerical tolerance may be
# more appropriate if the expected values are floats. Confirm.
# NOTE(review): sigma1 is taken from "m_31_2" and sigma2 from "m_31"; verify
# this mapping against the model file.
array = []
for distribution_name, checksum in checksums.items():
    for point_name, expected in checksum.items():
        parameters = checksum_points[point_name]
        s1 = parameters["m_31_2"] ** 2
        s2 = parameters["m_31"] ** 2
        computed = intensity_funcs[3]({"sigma1": s1, "sigma2": s2})
        if computed == expected:
            status = "🟢"
        else:
            status = "🔴"
        row = (distribution_name, point_name, computed, expected, status)
        array.append(row)
column_names = ["Distribution", "Point", "Computed", "Expected", "Status"]
pd.DataFrame(array, columns=column_names)

### Dalitz plot

In [None]:
# `i` and `j` select the invariants on the x and y axis of the Dalitz plot.
i, j = (2, 1)
# `k` is the index from {1, 2, 3} that is neither `i` nor `j`.
k, *_ = {1, 2, 3} - {i, j}
# Assumes the invariants are ordered as sigma1, sigma2, sigma3, so that
# position k - 1 holds invariant k. TODO confirm.
σk, σk_expr = list(MODEL.invariants.items())[k - 1]
Math(aslatex({σk: σk_expr}))

In [None]:
# Build a square grid over the two plotted invariants, with a 5% margin added
# on each side of the range computed from the masses.
resolution = 1_000
# Presumably sorts the mass symbols as m0, m1, m2, m3; verify the symbol names.
m = sorted(MODEL.masses, key=str)
x_min, x_max, y_min, y_max = (
    float((mass_combination**2).xreplace(MODEL.masses))
    for mass_combination in (
        m[j] + m[k],
        m[0] - m[i],
        m[i] + m[k],
        m[0] - m[j],
    )
)
x_diff = x_max - x_min
y_diff = y_max - y_min
x_min, x_max = x_min - 0.05 * x_diff, x_max + 0.05 * x_diff
y_min, y_max = y_min - 0.05 * y_diff, y_max + 0.05 * y_diff
x_values = jnp.linspace(x_min, x_max, num=resolution)
y_values = jnp.linspace(y_min, y_max, num=resolution)
X, Y = jnp.meshgrid(x_values, y_values)
dalitz_data = {f"sigma{i}": X, f"sigma{j}": Y}

In [None]:
# Evaluate the function in which invariant `k` was substituted, so that only
# the two grid invariants are needed as input.
intensities = intensity_funcs[k](dalitz_data)

In [None]:
# Sanity check: fail if every grid point evaluated to NaN. Individual NaN
# values are allowed.
assert not jnp.all(jnp.isnan(intensities)), "All intensities are NaN"

In [None]:
def get_decay_products(subsystem_id: int) -> tuple[State, State]:
    """Return the final states of `DECAY` whose index is not `subsystem_id`."""
    final_states = DECAY.final_state.values()
    return tuple(filter(lambda state: state.index != subsystem_id, final_states))


# Plot the intensities over the grid, normalized by their sum.
plt.rc("font", size=18)
# `nansum` skips NaN entries, so they do not affect the normalization.
I_tot = jnp.nansum(intensities)
normalized_intensities = intensities / I_tot

fig, ax = plt.subplots(figsize=(14, 10))
mesh = ax.pcolormesh(X, Y, normalized_intensities)
ax.set_aspect("equal")
c_bar = plt.colorbar(mesh, ax=ax, pad=0.01)
c_bar.ax.set_ylabel("Normalized intensity (a.u.)")
# The comprehension variable `i` is local to the comprehension and does not
# overwrite the global `i` that is used for the x-axis label below.
sigma_labels = {
    i: Rf"$\sigma_{i} = M^2\left({' '.join(p.latex for p in get_decay_products(i))}\right)$"
    for i in (1, 2, 3)
}
ax.set_xlabel(sigma_labels[i])
ax.set_ylabel(sigma_labels[j])
plt.show()