In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Set of those environment variables that are defined; non-empty presumably
# when the notebook is executed for a documentation build (TODO confirm).
STATIC_WEB_PAGE = set(os.environ) & {"EXECUTE_NB", "READTHEDOCS"}

```{autolink-concat}
```

# [TR-014] Amplitude model with sum notation

In [None]:
%pip install -q ampform==0.12.3 sympy==1.9

In [None]:
import itertools
import logging
from typing import Any, Iterable, List, Sequence, Tuple, Union

import ampform
import graphviz
import qrules
import sympy as sp
from ampform.sympy import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
)
from IPython.display import Math, display
from sympy.physics.quantum.spin import Rotation as Wigner
from sympy.printing.latex import LatexPrinter
from sympy.printing.precedence import PRECEDENCE

# Raise the root logger's threshold so that only errors are shown; loggers
# that do not set their own level (e.g. library loggers) inherit it.
LOGGER = logging.getLogger()
LOGGER.setLevel("ERROR")

## Problem description

[ampform#213](https://github.com/ComPWA/ampform/pull/213) implements spin alignment, which results in large sums over all helicity combinations. The resulting amplitude model expression is too large to be rendered.

To some extent, this is already the case with the current implementation of the 'standard' helicity formalism. Ignoring dynamics, many of the terms in the total intensity expression differ only in the helicities of the initial and final states.

In [None]:
# Lambda_c+ -> K- p pi+ in the helicity formalism, restricted to the
# Delta(1600)++ as the only intermediate resonance.
reaction = qrules.generate_transitions(
    formalism="helicity",
    initial_state="Lambda(c)+",
    final_state=["K-", "p", "pi+"],
    allowed_intermediate_particles=["Delta(1600)++"],
)

In [None]:
# One Graphviz diagram per generated transition.
display(
    *[
        graphviz.Source(qrules.io.asdot(transition, size=3))
        for transition in reaction.transitions
    ]
)

In [None]:
# Build an amplitude model for the generated transitions; the next cell
# renders its `expression` with `parameter_defaults` substituted.
builder = ampform.get_builder(reaction)
model = builder.formulate()

In [None]:
# Reuse `model` from the previous cell: calling `builder.formulate()` again here
# only rebuilt the same model a second time.
# Substitute the default parameter values and let `nsimplify` replace numerical
# values by simple exact forms where possible, which keeps the LaTeX compact.
full_expression = sp.nsimplify(model.expression.subs(model.parameter_defaults))
latex = sp.multiline_latex(sp.Symbol("I"), full_expression)
Math(latex)

## Simplified notation with `PoolSum`

Defining a special class that mimics {class}`sympy.Sum <sympy.concrete.summations.Sum>` may help to simplify the notation of the total intensity model.

In [None]:
@implement_doit_method
class PoolSum(UnevaluatedExpression):
    """Sum of an expression over every combination of explicit index values.

    Mimics `sympy.Sum`, but each index runs over a given collection of values
    instead of a range. ``doit()`` expands the sum (see `evaluate`).
    """

    # Printers use this to decide whether to parenthesize a PoolSum that
    # appears inside a larger expression.
    precedence = PRECEDENCE["Mul"]

    def __new__(
        cls,
        expression: sp.Expr,
        *indices: Tuple[sp.Symbol, Iterable[sp.Float]],
        **hints: Any,
    ) -> "PoolSum":
        """Create a sum of ``expression`` over ``(symbol, values)`` pairs.

        Each collection of values is converted to a tuple, so any iterable
        (list, generator, ...) can be passed.
        """
        indices = tuple((s, tuple(v)) for s, v in indices)
        return create_expression(cls, expression, *indices, **hints)

    @property
    def expression(self) -> sp.Expr:
        """The summand."""
        return self.args[0]

    @property
    def indices(self) -> Tuple[Tuple[sp.Symbol, Tuple[sp.Float, ...]], ...]:
        """The ``(index symbol, values)`` pairs, in the order they were given."""
        return self.args[1:]

    def evaluate(self) -> sp.Expr:
        """Expand the sum into an `sympy.Add`.

        The summand is substituted with every combination of index values (the
        Cartesian product). An index without values makes the product empty,
        so the result is then zero.
        """
        indices = dict(self.indices)
        return sp.Add(
            *[
                self.expression.subs(zip(indices, combi))
                for combi in itertools.product(*indices.values())
            ]
        )

    def _latex(self, printer: LatexPrinter, *args: Any) -> str:
        """Render as one sum sign per index, followed by the summand."""
        sum_symbols = [
            _render_sum_symbol(printer, idx, values)
            for idx, values in dict(self.indices).items()
        ]
        expression = printer._print(self.expression)
        if isinstance(self.expression, sp.Add):
            # Like sympy's Sum printer: parenthesize a summand that is itself a
            # sum, otherwise "\sum_i x + y" would read as "(\sum_i x) + y".
            expression = Rf"\left({expression}\right)"
        # Indices without values render as "" and are skipped.
        non_empty = [symbol for symbol in sum_symbols if symbol]
        return " ".join(non_empty) + f"{{{expression}}}"

    def cleanup(self) -> Union[sp.Expr, "PoolSum"]:
        """Remove redundant summations, like indices with one or no value.

        Indices with a single value are substituted into the expression.
        Indices that do not appear in the expression, or that have no values,
        are dropped. NOTE(review): dropping those changes what `doit` would
        return (it counts such terms once per value, or gives zero for an
        index without values); this is the intended notation-level behavior
        pinned by the examples below.

        >>> x, i = sp.symbols("x i")
        >>> PoolSum(x**i, (i, [0, 1, 2])).cleanup().doit()
        x**2 + x + 1
        >>> PoolSum(x, (i, [0, 1, 2])).cleanup()
        x
        >>> PoolSum(x).cleanup()
        x
        >>> PoolSum(x**i, (i, [0])).cleanup()
        1
        """
        substitutions = {}
        new_indices = []
        for idx, values in self.indices:
            if idx not in self.expression.free_symbols:
                continue
            if len(values) == 0:
                continue
            if len(values) == 1:
                substitutions[idx] = values[0]
            else:
                new_indices.append((idx, values))
        new_expression = self.expression.xreplace(substitutions)
        if len(new_indices) == 0:
            return new_expression
        return PoolSum(new_expression, *new_indices)


def _render_sum_symbol(
    printer: LatexPrinter, idx: sp.Symbol, values: Sequence[float]
) -> str:
    r"""Render the LaTeX sum sign for one index of a `PoolSum`.

    - no values: empty string (no sum sign at all)
    - one value: ``\sum_{i=v}``
    - consecutive values with step one: ``\sum_{i=first}^{last}``
    - otherwise: ``\sum_{i\in\left\{v_1,v_2,...\right\}}``

    All values are formatted with ``printer`` rather than `str`, so every
    branch renders numbers the same way. For example, ``sp.Float(0.5)``
    becomes ``0.5`` instead of ``0.500000000000000``.
    """
    if len(values) == 0:
        return ""
    idx_latex = printer._print(idx)
    if len(values) == 1:
        value = printer._print(values[0])
        return Rf"\sum_{{{idx_latex}={value}}}"
    if _is_regular_series(values):
        sorted_values = sorted(values)
        first_value = printer._print(sorted_values[0])
        last_value = printer._print(sorted_values[-1])
        return Rf"\sum_{{{idx_latex}={first_value}}}^{{{last_value}}}"
    idx_values = ",".join(map(printer._print, values))
    return Rf"\sum_{{{idx_latex}\in\left\{{{idx_values}\right\}}}}"


def _is_regular_series(values: Sequence[float]) -> bool:
    if len(values) <= 1:
        return False
    sorted_values = sorted(values)
    for val, next_val in zip(sorted_values, sorted_values[1:]):
        difference = float(next_val - val)
        if difference != 1.0:
            return False
    return True

In [None]:
half = sp.S.Half
phi, theta = sp.symbols("phi theta")
spin_parent, spin_resonance = (
    sp.Symbol(R"s_{\Lambda_c}"),
    sp.Symbol(R"s_\Delta"),
)
lambda_parent, lambda_resonance, lambda_p, lambda_k, lambda_pi = (
    sp.Symbol(name)
    for name in (
        R"\lambda_{\Lambda_c}",
        R"\lambda_\Delta",
        R"\lambda_p",
        R"\lambda_K",
        R"\lambda_\pi",
    )
)
# Product of the Wigner-D functions of the two decay nodes. Both use the same
# angles (phi, theta, 0), presumably a deliberate simplification for this
# notation demo.
sum_expr = PoolSum(
    sp.Mul(
        Wigner.D(
            spin_parent, lambda_parent, lambda_k - lambda_resonance, phi, theta, 0
        ),
        Wigner.D(
            spin_resonance, lambda_resonance, lambda_p - lambda_pi, phi, theta, 0
        ),
    ),
    # indices that are really summed over
    (lambda_parent, (-half, half)),
    (lambda_resonance, (-half, half)),
    (lambda_p, (-half, half)),
    # single-valued indices; cleanup() substitutes these into the summand
    (lambda_pi, (0,)),
    (lambda_k, (0,)),
    (spin_parent, (half,)),
    (spin_resonance, (sp.Rational(3, 2),)),
)
# summand, sum without single-valued indices, fully expanded sum
display(
    sum_expr.expression,
    sum_expr.cleanup(),
    sum_expr.doit(),
)