```{autolink-concat}
```

::::{margin}
:::{card} Symbolic expressions and model serialization
TR-024
^^^
<!--- cspell:ignore COMAP dodgerblue dummify fillcolor indianred srepr --->
Investigation into dumping SymPy expressions to human-readable format for model preservation. The notebook was motivated by the [COMAP-V workshop on analysis preservation](https://indico.cern.ch/event/1348003/). See also SymPy [printing](https://docs.sympy.org/latest/modules/printing.html), [parsing](https://docs.sympy.org/latest/modules/parsing.html), and [expression manipulation](https://docs.sympy.org/latest/tutorials/intro-tutorial/manipulation.html).
+++
🚧&nbsp;[polarimetry#319](https://github.com/ComPWA/polarimetry/pull/319)
:::
::::

# Symbolic model serialization

In [None]:
from pathlib import Path
from textwrap import shorten

import graphviz
import polarimetry
import sympy as sp
from ampform.io import aslatex
from ampform.sympy import unevaluated
from IPython.display import Markdown, Math
from polarimetry.amplitude import simplify_latex_rendering
from polarimetry.io import perform_cached_doit
from polarimetry.lhcb import load_model
from polarimetry.lhcb.particle import load_particles
from sympy.printing.mathml import MathMLPresentationPrinter

simplify_latex_rendering()

## Expression trees

SymPy expressions are built up from symbols and mathematical operations as follows:

In [None]:
x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z
expression

Under the hood, SymPy represents these expressions as **trees**. There are a few ways to visualize the tree for this particular example:

In [None]:
sp.printing.tree.print_tree(expression, assumptions=False)

In [None]:
src = sp.dotprint(
    expression,
    styles=[
        (sp.Number, {"color": "grey", "fontcolor": "grey"}),
        (sp.Symbol, {"color": "royalblue", "fontcolor": "royalblue"}),
    ],
)
graphviz.Source(src)
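
The same tree can also be walked programmatically, because every node exposes its class and its child nodes through `.func` and `.args`. The following is a minimal sketch with a hypothetical helper `walk()` that prints one indented line per node:

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z


def walk(expr: sp.Basic, depth: int = 0) -> None:
    """Print one line per node: the class name for operations, the value for leaves."""
    label = type(expr).__name__ if expr.args else str(expr)
    print("  " * depth + label)
    for arg in expr.args:
        walk(arg, depth + 1)


walk(expression)
```

Each node is fully determined by its class and its arguments, so `expression.func(*expression.args)` rebuilds the same expression. Printers and serializers rely on this property.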

Expression trees are powerful because we can use them as templates for any human-readable presentation we are interested in. In fact, the LaTeX representation that we saw when constructing the expression was generated by SymPy's LaTeX printer, which walks this tree and formats each node.

In [None]:
src = sp.latex(expression)
Markdown(f"```latex\n{src}\n```")

:::{hint} SymPy expressions can serve as a template for generating code!
:::
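
To illustrate how a printer uses the tree as a template, here is a minimal sketch of a custom printer. It assumes SymPy's convention that a printer dispatches each node to a `_print_<ClassName>` method; the class `PrefixPrinter` and its Lisp-like output format are purely illustrative:

```python
import sympy as sp
from sympy.printing.str import StrPrinter


class PrefixPrinter(StrPrinter):
    """Hypothetical printer that writes sums and products in Lisp-like prefix notation."""

    def _print_Add(self, expr, order=None):
        # Recursively print every child node of the sum
        return "(+ " + " ".join(self._print(arg) for arg in expr.args) + ")"

    def _print_Mul(self, expr):
        # Recursively print every child node of the product
        return "(* " + " ".join(self._print(arg) for arg in expr.args) + ")"


x, y, z = sp.symbols("x y z")
print(PrefixPrinter().doprint(sp.sin(x * y) / 2 - x**2 + 1 / z))
```

All node types that are not overridden (functions, powers, numbers) fall back to the behavior of the base `StrPrinter`.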

Here are a number of other representations:

In [None]:
def to_mathml(expr: sp.Expr) -> str:
    printer = MathMLPresentationPrinter()
    xml = printer._print(expr)
    return xml.toprettyxml().replace("\t", "  ")


Markdown(
    f"""
```python
# Python
{sp.pycode(expression)}
```
```cpp
// C++
{sp.cxxcode(expression, standard="c++17")}
```
```fortran
! Fortran
{sp.fcode(expression).strip()}
```
```matlab
% Matlab / Octave
{sp.octave_code(expression)}
```
```julia
# Julia
{sp.julia_code(expression)}
```
```rust
// Rust
{sp.rust_code(expression)}
```
```xml
<!-- MathML -->
{to_mathml(expression)}
```
"""
)
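
The `to_mathml()` helper above calls the private `_print()` method of the printer in order to get pretty indentation. If a compact string is sufficient, the public `sp.mathml()` function with `printer="presentation"` should give Presentation MathML as well. A minimal sketch:

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z

# Public printing function; printer="presentation" selects Presentation MathML
mathml_src = sp.mathml(expression, printer="presentation")
print(mathml_src)
```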

## Foldable expressions

The previous example is quite simple, but SymPy works just as well with huge expressions, as we will see in [Large expressions](#large-expressions). First, though, let's have a look at how to define these larger expressions in such a way that we can still read them. A nice solution is to define {class}`sp.Expr <sympy.core.expr.Expr>` classes with the `@unevaluated` decorator (see [ComPWA/ampform#364](https://github.com/ComPWA/ampform/issues/364)). Here, we define a Chew-Mandelstam function $\rho^\text{CM}$ for $S$ waves, which requires the definition of a break-up momentum $q$.

In [None]:
@unevaluated(real=False)
class PhspFactorSWave(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"\rho^\text{{CM}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        cm = (
            (2 * q / sp.sqrt(s))
            * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))
            - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
        ) / (16 * sp.pi**2)
        return 16 * sp.pi * sp.I * cm


@unevaluated(real=False)
class BreakupMomentum(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"q\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))
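
The `@unevaluated` decorator takes care of the boilerplate. As a rough sketch of what such a 'foldable' class needs, assuming only plain SymPy (the class `BreakupMomentumSketch` is hypothetical and this is not how ampform implements the decorator), one can store the arguments as child nodes, unfold the node in `doit()`, and provide a folded LaTeX rendering through a `_latex()` method:

```python
import sympy as sp


class BreakupMomentumSketch(sp.Expr):
    """Hypothetical plain-SymPy stand-in for an expression class built with @unevaluated."""

    def __new__(cls, s, m1, m2):
        # Store the arguments as the node's children in the expression tree
        return super().__new__(cls, *map(sp.sympify, (s, m1, m2)))

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

    def doit(self, deep: bool = True, **hints) -> sp.Expr:
        # Unfold this node into its definition, then optionally everything below it
        expr = self.evaluate()
        return expr.doit(deep=deep, **hints) if deep else expr

    def _latex(self, printer, **kwargs) -> str:
        # Folded LaTeX rendering, analogous to _latex_repr_
        s = printer._print(self.args[0])
        return Rf"q\left({s}\right)"


s, m1, m2 = sp.symbols("s m1 m2")
q = BreakupMomentumSketch(s, m1, m2)
print(sp.latex(q))  # folded
print(sp.latex(q.doit()))  # unfolded
```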

We now have a very clean mathematical representation of how the $\rho^\text{CM}$ function is defined in terms of $q$:

In [None]:
s, m1, m2 = sp.symbols("s m1 m2")
q_expr = BreakupMomentum(s, m1, m2)
ρ_expr = PhspFactorSWave(s, m1, m2)
Math(aslatex({e: e.evaluate() for e in [ρ_expr, q_expr]}))
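
As a quick sanity check of the unfolded definition of $q$ (a sketch that is independent of the classes above), one can verify that the break-up momentum vanishes at threshold, $s = (m_1+m_2)^2$, that it reduces to $\sqrt{s}/2$ for massless daughters, and that it can be evaluated numerically:

```python
import sympy as sp

s, m1, m2 = sp.symbols("s m1 m2", positive=True)
q = sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

# Break-up momentum at threshold, s = (m1 + m2)^2
q_threshold = q.subs(s, (m1 + m2) ** 2)
print(q_threshold)

# Limit of massless daughters
q_massless = sp.simplify(q.subs({m1: 0, m2: 0}))
print(q_massless)

# Numerical evaluation through a lambdified function
q_func = sp.lambdify((s, m1, m2), q, modules="math")
print(q_func(4.0, 0.5, 0.5))
```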

Now, let's build up a more complicated expression that contains this phase space factor. Here, we use SymPy to derive a Breit-Wigner using a single-channel [$K$ matrix](https://doi.org/10.1002/andp.19955070504) {cite}`Chung:1995dx`:

In [None]:
I = sp.Identity(n=1)
K = sp.MatrixSymbol("K", m=1, n=1)
ρ = sp.MatrixSymbol("rho", m=1, n=1)
T = (I - sp.I * K * ρ).inv() * K
T

In [None]:
T.as_explicit()[0, 0]

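Here, we need to provide definitions for the matrix elements of $K$ and $\rho$. For $\rho$, a suitable choice is the phase space factor for $S$ waves that we defined above. For $K$, we take a single pole with mass $m_0$, width $\Gamma_0$, and coupling $\gamma_0$: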

In [None]:
m0, Γ0, γ0 = sp.symbols("m0 Gamma0 gamma0")
K_expr = (γ0**2 * m0 * Γ0) / (s - m0**2)

In [None]:
substitutions = {
    K[0, 0]: K_expr,
    ρ[0, 0]: ρ_expr,
}
Math(aslatex(substitutions))
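
For a single channel, $K$ and $\rho$ are $1\times 1$ matrices, so the matrix inversion reduces to a division. Worked out by hand with the $K$ defined above, the amplitude takes the familiar Breit-Wigner form, which the SymPy simplification is expected to reproduce up to a rearrangement of terms:

$$
T = \left(1 - iK\rho\right)^{-1} K
  = \frac{K}{1 - iK\rho}
  = \frac{\gamma_0^2 m_0 \Gamma_0}{s - m_0^2 - i\,\gamma_0^2 m_0 \Gamma_0\,\rho^\text{CM}(s)}
$$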

And there we have it! After some [algebraic simplifications](https://docs.sympy.org/latest/tutorials/intro-tutorial/simplification.html), we get a Breit-Wigner with a Chew-Mandelstam phase space factor for $S$ waves:

In [None]:
T_expr = T.as_explicit().xreplace(substitutions)
BW_expr = T_expr[0, 0].simplify(doit=False)
BW_expr

The expression tree now contains a 'folded' node, the unevaluated `PhspFactorSWave` (highlighted in red):

In [None]:
dot_style = [
    (sp.Basic, {"style": "filled", "fillcolor": "white"}),
    (sp.Atom, {"color": "gray", "style": "filled", "fillcolor": "white"}),
    (sp.Symbol, {"color": "dodgerblue1"}),
    (PhspFactorSWave, {"color": "indianred2"}),
]
dot = sp.dotprint(BW_expr, bgcolor=None, styles=dot_style)
graphviz.Source(dot)

After unfolding with `doit()`, we get the full expression tree of fundamental mathematical operations:

In [None]:
dot = sp.dotprint(BW_expr.doit(), bgcolor=None, styles=dot_style)
graphviz.Source(dot)

## Large expressions

Here, we import the large symbolic intensity expression that was used for [![10.1007/JHEP07(2023)228](<https://zenodo.org/badge/doi/10.1007/JHEP07(2023)228.svg>)](<https://doi.org/10.1007/JHEP07(2023)228>) and see how well SymPy serialization performs on a much more complicated model.

In [None]:
DATA_DIR = Path(polarimetry.__file__).parent / "lhcb"
PARTICLES = load_particles(DATA_DIR / "particle-definitions.yaml")
MODEL = load_model(DATA_DIR / "model-definitions.yaml", PARTICLES, model_id=0)
unfolded_intensity_expr = perform_cached_doit(MODEL.full_expression)

In [None]:
Markdown(
    f"""
The model contains **{sp.count_ops(unfolded_intensity_expr):,d}** mathematical operations.
See [ComPWA/polarimetry#319](https://github.com/ComPWA/polarimetry/pull/319) for the origin
of this investigation.
"""
)
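
Note that {func}`sp.count_ops <sympy.core.function.count_ops>` counts mathematical operations, which is not the same as the number of nodes in the expression tree: leaves such as symbols and numbers are nodes, but not operations. A minimal sketch comparing both numbers on the small example from the beginning of this notebook:

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z

n_ops = sp.count_ops(expression)
# Every node in the tree, including symbols and numbers
n_nodes = sum(1 for _ in sp.preorder_traversal(expression))
print(f"{n_ops} operations, {n_nodes} nodes")
```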

## Serialization with `srepr`

SymPy expressions can also be serialized directly to Python code with the function [`srepr()`](https://docs.sympy.org/latest/modules/printing.html#sympy.printing.repr.srepr). For the full intensity expression, we can do so with:

In [None]:
%%time

eval_str = sp.srepr(unfolded_intensity_expr)

In [None]:
n_ops = sp.count_ops(unfolded_intensity_expr)
byt = len(eval_str.encode("utf-8"))
mb = f"{1e-6 * byt:.2f}"
rendering = shorten(eval_str, placeholder=" ...", width=85)
src = f"""
This serializes the intensity expression of {n_ops:,d} operations
to a string of **{mb} MB**.

```python
{rendering} {")" * (rendering.count("(") - rendering.count(")"))}
```
"""
Markdown(src)

It is up to the user, however, to import the classes of each exported node before the string can be parsed with [`eval()`](https://docs.python.org/3/library/functions.html#eval) (see [this comment](https://github.com/ComPWA/polarimetry/issues/20#issuecomment-1809840854)).

In [None]:
# Requires the SymPy classes used in eval_str (Add, Mul, Symbol, ...) to be in scope
imported_intensity_expr = eval(eval_str)
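
Instead of importing into the global namespace, the required names can also be supplied to `eval()` through an explicit namespace dictionary. A minimal sketch on the small expression from the beginning of this notebook (the variable `namespace` is only illustrative):

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z
src = sp.srepr(expression)
print(src)

# Collect all public SymPy names in a dedicated namespace instead of the global one
namespace: dict = {}
exec("from sympy import *", namespace)
restored = eval(src, namespace)
```

For the full intensity model, `Str` from `sympy.core.symbol` has to be added to this namespace as well.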

In the case of this intensity expression, it is sufficient to import all definitions from the main `sympy` module and the `Str` class. Optionally, the required `import` statements can be embedded into the string:

In [None]:
exec_str = f"""\
from sympy import *
from sympy.core.symbol import Str

def get_intensity_function() -> Expr:
    return {eval_str}
"""

In [None]:
exec_filename = Path("../_static/exported_intensity_model.py")
with open(exec_filename, "w") as f:
    f.write(exec_str)

In [None]:
Markdown(f"See [`{exec_filename.name}`]({exec_filename}) for the exported model.")

The parsing is then done with [`exec()`](https://docs.python.org/3/library/functions.html#exec) instead of the [`eval()`](https://docs.python.org/3/library/functions.html#eval) function:

In [None]:
%%time

exec(exec_str)
imported_intensity_expr = get_intensity_function()
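
Since the exported string is a complete Python module, an alternative to `exec()` is to load the written file as a module with {mod}`importlib`, which keeps the SymPy names imported by the file out of the notebook's namespace. A minimal sketch with a small stand-in expression and a temporary file (the module name `exported_model` is arbitrary):

```python
import importlib.util
import tempfile
from pathlib import Path

import sympy as sp

x = sp.Symbol("x")
module_src = f"""\
from sympy import *

def get_intensity_function() -> Expr:
    return {sp.srepr(sp.cos(x) ** 2)}
"""

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "exported_model.py"
    path.write_text(module_src)
    # Load the file as a module, without adding it to sys.path
    spec = importlib.util.spec_from_file_location("exported_model", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    restored = module.get_intensity_function()
```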

Notice how the imported expression is **exactly the same** as the serialized one, including assumptions:

In [None]:
assert imported_intensity_expr == unfolded_intensity_expr
assert hash(imported_intensity_expr) == hash(unfolded_intensity_expr)

### Common sub-expressions

A problem is that the expression generated with [`srepr()`](https://docs.sympy.org/latest/modules/printing.html#sympy.printing.repr.srepr) is, in practice, not human-readable for large expressions. One way out may be to extract common components of the main expression with [Foldable expressions](#foldable-expressions). Another may be to use SymPy to [detect and collect common sub-expressions](https://docs.sympy.org/latest/modules/rewriting.html#common-subexpression-detection-and-collection).

In [None]:
sub_exprs, common_expr = sp.cse(unfolded_intensity_expr, order="none")
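
On a small example, the output of {func}`sp.cse <sympy.simplify.cse_main.cse>` is easier to inspect: it returns a list of `(symbol, sub-expression)` replacements plus a list of reduced expressions. Substituting the replacements back in reverse order recovers the original expression, which shows that no information is lost. A minimal sketch:

```python
import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x + y) ** 2 + sp.cos(x + y) * sp.sin(x + y)

# replacements: list of (symbol, sub-expression); reduced: list of reduced expressions
replacements, reduced = sp.cse(expr)
print(replacements)
print(reduced)

# Undo the CSE: substitute the definitions back, last one first
rebuilt = reduced[0]
for symbol, sub_expr in reversed(replacements):
    rebuilt = rebuilt.subs(symbol, sub_expr)
```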

In [None]:
Math(sp.multiline_latex(sp.Symbol("I"), common_expr[0], environment="eqnarray"))

In [None]:
Math(aslatex(dict(sub_exprs[:10])))

This already works quite well with {func}`sp.lambdify <sympy.utilities.lambdify.lambdify>` (without `cse=True`, this would take minutes):

In [None]:
%%time

args = sorted(unfolded_intensity_expr.free_symbols, key=str)
_ = sp.lambdify(args, unfolded_intensity_expr, cse=True, dummify=True)
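
To see what `cse=True` does, one can inspect the source code that `lambdify()` generates for a small expression: each common sub-expression is assigned to an intermediate variable once and then reused. A minimal sketch, using the {mod}`math` module as backend:

```python
import inspect
import math

import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x + y) ** 2 + sp.cos(x + y) * sp.sin(x + y)

# cse=True makes lambdify assign common sub-expressions to intermediate variables
func = sp.lambdify((x, y), expr, modules="math", cse=True)
print(inspect.getsource(func))
```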

Still, as can be seen above, many of the sub-expressions have the same structure and differ only in the symbols they contain. It would be better to identify such structurally similar expressions, so that we can serialize them as functions or custom sub-definitions.

In SymPy, the equivalence between expressions can be determined with the [`match()`](https://docs.sympy.org/latest/modules/core.html#sympy.core.basic.Basic.match) method using [`Wild`](https://docs.sympy.org/latest/modules/core.html#sympy.core.symbol.Wild) symbols. We therefore first have to make all symbols in the common sub-expressions 'wild'. In addition, in the case of this intensity expression, some of the symbols are [indexed](https://docs.sympy.org/latest/modules/tensor/indexed.html) and need to be replaced first.

In [None]:
pure_symbol_expr = unfolded_intensity_expr.replace(
    query=lambda z: isinstance(z, sp.Indexed),
    value=lambda z: sp.Symbol(sp.latex(z), **z.assumptions0),
)
sub_exprs, common_expr = sp.cse(pure_symbol_expr, order="none")

Note that, for example, the following two common sub-expressions are structurally equivalent:

In [None]:
Math(aslatex({k: v for i, (k, v) in enumerate(sub_exprs) if i in {5, 8}}))

[`Wild`](https://docs.sympy.org/latest/modules/core.html#sympy.core.symbol.Wild) symbols now allow us to find how these expressions relate to each other.

In [None]:
is_symbol = lambda z: isinstance(z, sp.Symbol)
make_wild = lambda z: sp.Wild(z.name)
X = [x.replace(is_symbol, make_wild) for _, x in sub_exprs]
Math(aslatex(X[5].match(X[8])))

:::{hint}
This can be used to define functions for larger, common expression blocks.
:::