In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Names among EXECUTE_NB and READTHEDOCS that are defined as environment
# variables; the set is empty (falsy) when neither of them is set.
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

```{autolink-concat}
```

# [TR-014] Amplitude model with sum notation

In [None]:
%pip install -q "ampform[viz]@git+https://github.com/ComPWA/ampform@3ed3ed5" sympy==1.9

In [None]:
import itertools
import logging
from typing import Any, Iterable, List, Sequence, Tuple, Union

import ampform
import graphviz
import qrules
import sympy as sp
from ampform.sympy import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
)
from IPython.display import Math, display
from sympy.physics.quantum.spin import Rotation as Wigner
from sympy.printing.latex import LatexPrinter
from sympy.printing.precedence import PRECEDENCE

# Raise the threshold of the root logger, so that only messages of level ERROR
# and above are handled.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)

## Problem description

[ampform#213](https://github.com/ComPWA/ampform/pull/213) implements spin alignment, which results in large sum combinatorics for all helicity combinations. The result is an amplitude model expression that is too large to be rendered.

To some extent, this is already the case with the current implementation of the 'standard' helicity formalism. Forgetting about dynamics, many of the terms in the total intensity expression differ only in the helicities of the final and initial state.

In [None]:
# Generate the transitions for Lambda(c)+ -> K- p pi+ in the helicity formalism,
# with Delta(1600)++ as the only allowed intermediate particle.
reaction = qrules.generate_transitions(
    initial_state="Lambda(c)+",
    final_state=["K-", "p", "pi+"],
    formalism="helicity",
    allowed_intermediate_particles=["Delta(1600)++"],
)

In [None]:
# Show a Graphviz rendering (size 3) of every transition in the reaction.
display(
    *(
        graphviz.Source(qrules.io.asdot(transition, size=3))
        for transition in reaction.transitions
    )
)

In [None]:
# Let ampform provide a builder for this reaction and formulate the model.
builder = ampform.get_builder(reaction)
model = builder.formulate()

In [None]:
# NOTE(review): `model` is already formulated in the preceding cell with the
# same builder; this call repeats that step.
model = builder.formulate()
# Substitute the default parameter values into the model expression and let
# sp.nsimplify tidy up the numbers that remain.
full_expression = model.expression.subs(model.parameter_defaults)
full_expression = sp.nsimplify(full_expression)
# Symbol used as left-hand side of the rendered multi-line equation.
I = sp.Symbol("I")  # noqa: E741
latex = sp.multiline_latex(I, full_expression)
Math(latex)

## Simplified notation with `PoolSum`

The definition of a special class that mimics {class}`sympy.Sum <sympy.concrete.summations.Sum>` may help to simplify the total intensity model.

In [None]:
@implement_doit_method
class PoolSum(UnevaluatedExpression):
    """Unevaluated sum of an expression over explicit pools of index values.

    Every index is given as a ``(symbol, values)`` pair. Evaluating the sum
    substitutes each combination of values (one value per index) into the
    summand and adds up the results.
    """

    precedence = PRECEDENCE["Mul"]

    def __new__(
        cls,
        expression: sp.Expr,
        *indices: Tuple[sp.Symbol, Iterable[sp.Float]],
        **hints: Any,
    ) -> "PoolSum":
        # Materialize each pool of values as a tuple before the node is built.
        frozen_indices = tuple(
            (symbol, tuple(values)) for symbol, values in indices
        )
        return create_expression(cls, expression, *frozen_indices, **hints)

    @property
    def expression(self) -> sp.Expr:
        """Summand, stored as the first argument of this node."""
        return self.args[0]

    @property
    def indices(self) -> List[Tuple[sp.Symbol, Tuple[sp.Float, ...]]]:
        """The ``(symbol, values)`` pairs: all arguments after the summand."""
        return self.args[1:]

    def evaluate(self) -> sp.Expr:
        """Write the sum out as a `sympy.Add` over all value combinations."""
        pools = dict(self.indices)
        terms = []
        for combination in itertools.product(*pools.values()):
            # Iterating over the dict yields the index symbols, in the same
            # order as the pools that were passed to itertools.product.
            terms.append(self.expression.subs(zip(pools, combination)))
        return sp.Add(*terms)

    def _latex(self, printer: LatexPrinter, *args: Any) -> str:
        """Render one summation sign per index, then the summand in braces."""
        sum_symbols: List[str] = [
            _render_sum_symbol(printer, idx, values)
            for idx, values in dict(self.indices).items()
        ]
        summand = printer._print(self.expression)
        return " ".join(sum_symbols) + f"{{{summand}}}"

    def cleanup(self) -> Union[sp.Expr, "PoolSum"]:
        """Remove redundant summations, like indices with one or no value.

        Indices that do not appear in the summand or that have an empty pool
        of values are dropped. An index with exactly one value is substituted
        into the summand. If no index is left, the summand itself is returned.

        >>> x, i = sp.symbols("x i")
        >>> PoolSum(x**i, (i, [0, 1, 2])).cleanup().doit()
        x**2 + x + 1
        >>> PoolSum(x, (i, [0, 1, 2])).cleanup()
        x
        >>> PoolSum(x).cleanup()
        x
        >>> PoolSum(x**i, (i, [0])).cleanup()
        1
        """
        # NOTE(review): evaluate() adds one term per value combination even for
        # an index that is absent from the summand, so dropping such an index
        # here does not preserve the evaluated value -- confirm this is intended.
        free_symbols = self.expression.free_symbols
        substitutions = {}
        remaining_indices = []
        for idx, values in self.indices:
            n_values = len(values)
            if idx not in free_symbols or n_values == 0:
                continue
            if n_values > 1:
                remaining_indices.append((idx, values))
            else:
                substitutions[idx] = values[0]
        new_expression = self.expression.xreplace(substitutions)
        if not remaining_indices:
            return new_expression
        return PoolSum(new_expression, *remaining_indices)


def _render_sum_symbol(
    printer: LatexPrinter, idx: sp.Symbol, values: Sequence[float]
) -> str:
    """Build the LaTeX summation sign for one index and its pool of values.

    An empty pool gives an empty string, a single value gives ``idx=value``
    as lower limit, a series with unit steps gives lower and upper limits,
    and any other pool is listed explicitly as a set.
    """
    n_values = len(values)
    if n_values == 0:
        return ""
    idx_latex = printer._print(idx)
    if n_values == 1:
        return Rf"\sum_{{{idx_latex}={values[0]}}}"
    if not _is_regular_series(values):
        listed_values = ",".join(map(printer._print, values))
        return Rf"\sum_{{{idx_latex}\in\left\{{{listed_values}\right\}}}}"
    # Regular series: the limits are the smallest and largest value.
    ordered = sorted(values)
    lower, upper = ordered[0], ordered[-1]
    return Rf"\sum_{{{idx_latex}={lower}}}^{{{upper}}}"


def _is_regular_series(values: Sequence[float]) -> bool:
    if len(values) <= 1:
        return False
    sorted_values = sorted(values)
    for val, next_val in zip(sorted_values, sorted_values[1:]):
        difference = float(next_val - val)
        if difference != 1.0:
            return False
    return True

In [None]:
# Exact rational 1/2, used for the spin and helicity values below.
half = sp.S.Half

# Spin symbols; sp.Subs below substitutes values for them.
spin_parent = sp.Symbol(R"s_{\Lambda_c}", real=True)
spin_resonance = sp.Symbol(R"s_\Delta", real=True)

# Angle symbols that are passed to the two Wigner-D functions.
phi_12, theta_12 = sp.symbols("phi_12 theta_12", real=True)
phi_1_12, theta_1_12 = sp.symbols(R"phi_1^12 theta_1^12", real=True)

# Helicity symbols that serve as summation indices of the PoolSums.
lambda_parent = sp.Symbol(R"\lambda_{\Lambda_c}", real=True)
lambda_resonance = sp.Symbol(R"\lambda_\Delta", real=True)
lambda_p = sp.Symbol(R"\lambda_p", real=True)
lambda_k = sp.Symbol(R"\lambda_K", real=True)
lambda_pi = sp.Symbol(R"\lambda_\pi", real=True)
# Outer PoolSum: sums the squared absolute value over the helicities of the
# parent, p, pi and K. Inner PoolSum: sums the product of the two Wigner-D
# functions over the helicity of the resonance. The surrounding sp.Subs
# replaces spin_parent by 1/2 and spin_resonance by 3/2.
sum_expr = sp.Subs(
    PoolSum(
        sp.Abs(
            PoolSum(
                Wigner.D(
                    spin_parent,
                    lambda_parent,
                    lambda_k - lambda_resonance,
                    phi_12,
                    theta_12,
                    0,
                )
                * Wigner.D(
                    spin_resonance,
                    lambda_resonance,
                    lambda_p - lambda_pi,
                    phi_1_12,
                    theta_1_12,
                    0,
                ),
                (lambda_resonance, (-half, +half)),
            )
        )
        ** 2,
        (lambda_parent, (-half, +half)),
        (lambda_p, (-half, +half)),
        (lambda_pi, (0,)),
        (lambda_k, (0,)),
    ),
    (spin_parent, spin_resonance),
    (half, 3 * half),
)
# Show, in this order: the unevaluated expression, the cleaned-up PoolSum
# inside the Subs, the result of a shallow doit() as multi-line LaTeX, and the
# result of a deep doit() on top of that.
display(
    sum_expr,
    sum_expr.expr.cleanup(),
    Math(sp.multiline_latex(I, sum_expr.doit(deep=False))),
    sum_expr.doit(deep=False).doit(deep=True),
)

In [None]:
# Check that the evaluated PoolSum expression is equal to the evaluated
# expression of the model that was formulated by ampform.
assert sum_expr.doit(deep=False).doit() == full_expression.doit()

## Generic solution

The current implementation of the [`HelicityAmplitudeBuilder`](https://ampform.readthedocs.io/en/0.12.3/api/ampform.helicity.html#ampform.helicity.HelicityAmplitudeBuilder) has to be changed quite a bit to produce an amplitude model with `PoolSum`s. First of all, we have to introduce special {class}`~sympy.core.symbol.Symbol`s for the helicities, $\lambda_i$, with $i$ the state ID (taking a sum of attached final state IDs in case of a resonance ID).[^spin-alignment] Next, [`formulate_wigner_d()`](https://ampform.readthedocs.io/en/0.12.3/api/ampform.helicity.html#ampform.helicity.formulate_wigner_d) has to be modified to insert these {class}`~sympy.core.symbol.Symbol`s into the {class}`sympy.physics.quantum.spin.WignerD`:

[^spin-alignment]: When introducing spin alignment ([ampform#213](https://github.com/ComPWA/ampform/pull/213)), we have to distinguish the helicity symbols between different topologies.

In [None]:
from ampform.helicity import _generate_kinematic_variables
from ampform.helicity.decay import TwoBodyDecay, determine_attached_final_state
from qrules.topology import Topology
from qrules.transition import StateTransition


def formulate_wigner_d(transition: StateTransition, node_id: int) -> sp.Expr:
    """Formulate the Wigner-D function for the decay at one node.

    The helicities of the parent and its two children enter as symbols
    created with `create_helicity_symbol`, not as numerical values. The
    angles come from `_generate_kinematic_variables`.
    """
    from sympy.physics.quantum.spin import Rotation as Wigner

    decay = TwoBodyDecay.from_transition(transition, node_id)
    topology = transition.topology
    states = (decay.parent, decay.children[0], decay.children[1])
    helicity_parent, helicity_child1, helicity_child2 = (
        create_helicity_symbol(topology, state.id) for state in states
    )
    _, phi, theta = _generate_kinematic_variables(transition, node_id)
    spin = sp.Rational(decay.parent.particle.spin)
    helicity_difference = helicity_child1 - helicity_child2
    return Wigner.D(
        j=spin,
        m=helicity_parent,
        mp=helicity_difference,
        alpha=-phi,
        beta=theta,
        gamma=0,
    )


def create_helicity_symbol(topology: Topology, state_id: int) -> sp.Symbol:
    """Create the rational helicity symbol for a state in a topology.

    A state on an incoming edge gets the plain name ``lambda``; any other
    state gets its resonance identifier as a suffix after an underscore.
    """
    if state_id not in topology.incoming_edge_ids:
        identifier = get_resonance_identifier(topology, state_id)
        return sp.Symbol(f"lambda_{identifier}", rational=True)
    return sp.Symbol("lambda", rational=True)


def get_resonance_identifier(topology: Topology, state_id: int) -> str:
    """Concatenate the IDs of the final states that are attached to a state."""
    final_state_ids = determine_attached_final_state(topology, state_id)
    return "".join(str(final_state_id) for final_state_id in final_state_ids)

In [None]:
# For each transition: the product of the Wigner-D functions of all its nodes.
# Collecting the products in a set keeps each distinct product only once.
wigner_functions = {
    sp.Mul(
        *(
            formulate_wigner_d(transition, node_id)
            for node_id in transition.topology.nodes
        )
    )
    for transition in reaction.transitions
}
display(*wigner_functions)

We also have to collect the allowed helicity values for each of these helicity symbols.

In [None]:
from collections import defaultdict
from typing import Dict, Literal, Set

from qrules import ReactionInfo


def get_helicities(
    reaction: ReactionInfo, which: Literal["inner", "outer"]
) -> Dict[sp.Symbol, Set[sp.Rational]]:
    """Collect the helicity values that occur for each helicity symbol.

    Args:
        reaction: Reaction of which all transitions are scanned.
        which: ``"inner"`` selects the intermediate states only, ``"outer"``
            selects all states that are not intermediate.

    Returns:
        A mapping from helicity symbol (see `create_helicity_symbol`) to the
        set of spin projections, converted to `sympy.Rational`, that the
        selected states carry in any of the transitions.
    """
    helicities: Dict[sp.Symbol, Set[sp.Rational]] = defaultdict(set)
    for transition in reaction.transitions:
        for state_id, state in transition.states.items():
            is_intermediate = state_id in transition.intermediate_states
            # Skip the states that do not belong to the requested selection.
            if (which == "inner" and not is_intermediate) or (
                which == "outer" and is_intermediate
            ):
                continue
            symbol = create_helicity_symbol(transition.topology, state_id)
            helicities[symbol].add(sp.Rational(state.spin_projection))
    # Return a plain dict, so that a lookup of a missing symbol raises KeyError
    # instead of silently inserting an empty set.
    return dict(helicities)

In [None]:
# Collect the helicity values per symbol, separately for the intermediate
# ("inner") states and for all other ("outer") states, and show both mappings.
inner_helicities = get_helicities(reaction, which="inner")
outer_helicities = get_helicities(reaction, which="outer")
display(inner_helicities, outer_helicities)

These collected helicity values can then be combined with the Wigner-$D$ expressions through a `PoolSum`:

In [None]:
def formulate_intensity(reaction: ReactionInfo):
    """Formulate the intensity of a reaction as two nested `PoolSum` objects.

    For each transition, the Wigner-D functions of all its nodes are
    multiplied. The distinct products are added and summed over the "inner"
    helicities. The squared absolute value of that sum is then summed over the
    "outer" helicities.
    """
    amplitudes = {
        sp.Mul(
            *(
                formulate_wigner_d(transition, node_id)
                for node_id in transition.topology.nodes
            )
        )
        for transition in reaction.transitions
    }
    inner = get_helicities(reaction, which="inner")
    outer = get_helicities(reaction, which="outer")
    coherent_sum = PoolSum(sp.Add(*amplitudes), *inner.items())
    return PoolSum(sp.Abs(coherent_sum) ** 2, *outer.items())


# Show the nested PoolSum expression for the reaction defined earlier.
formulate_intensity(reaction)

Note how this approach also works in case there are two decay topologies:

In [None]:
# Same initial and final state as before, but now with two allowed
# intermediate particles: Lambda(1405) and Delta(1600)++.
reaction_two_resonances = qrules.generate_transitions(
    initial_state="Lambda(c)+",
    final_state=["K-", "p", "pi+"],
    formalism="helicity",
    allowed_intermediate_particles=["Lambda(1405)", "Delta(1600)++"],
)
# The example below relies on the transitions being split into two groups.
assert len(reaction_two_resonances.transition_groups) == 2
# Show one collapsed Graphviz diagram (size 4) per transition group.
display(
    *[
        graphviz.Source(qrules.io.asdot(g, collapse_graphs=True, size=4))
        for g in reaction_two_resonances.transition_groups
    ]
)

In [None]:
# Formulate the intensity for the reaction with two resonances and remove the
# redundant summation indices from the outer PoolSum.
formulate_intensity(reaction_two_resonances).cleanup()