# Model serialization

```{autolink-concat}
```

This page demonstrates a strategy for exporting an amplitude model, together with its suggested parameter defaults, to disk. It then shows how to load the model back into memory later on for numerical computations with a computational backend.

In [None]:
from __future__ import annotations

import os
import pickle
from textwrap import shorten

import cloudpickle
import jax.numpy as jnp
import sympy as sp
from IPython.display import Markdown
from tensorwaves.function.sympy import create_function

from polarimetry.io import mute_jax_warnings, perform_cached_doit
from polarimetry.lhcb import load_model
from polarimetry.lhcb.particle import load_particles

mute_jax_warnings()

## Export model

In [None]:
data_dir = "../../data"
particles = load_particles(f"{data_dir}/particle-definitions.yaml")
model = load_model(f"{data_dir}/model-definitions.yaml", particles, model_id=0)
unfolded_intensity_expr = perform_cached_doit(model.full_expression)

In [None]:
free_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if isinstance(symbol, sp.Indexed)
    if "production" in str(symbol)
}
fixed_parameters = {
    symbol: value
    for symbol, value in model.parameter_defaults.items()
    if symbol not in free_parameters
}
subs_intensity_expr = unfolded_intensity_expr.xreplace(fixed_parameters)

In [None]:
def sigma3via12() -> dict[sp.Symbol, sp.Expr]:
    s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
    m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
    return {s3: m1**2 + m2**2 + m3**2 + m0**2 - s1 - s2}

In [None]:
dict_forms = {
    "intensity_expr": unfolded_intensity_expr,
    "variables": {k: v.doit() for (k, v) in model.variables.items()},
    "parameter_defaults": model.parameter_defaults,
    "sigma3": sigma3via12(),
}

In [None]:
filename = "exported-model.pkl"
with open(filename, "wb") as f:
    cloudpickle.dump(dict_forms, f)

## Import model

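The model is stored as a Python {obj}`dict` in a {mod}`pickle` file. The file is written with `cloudpickle`, but it can be read back with the standard {mod}`pickle` module.

The dictionary contains:

- the SymPy expression for the intensity,
- the definitions of its variables,
- the suggested parameter default values, and
- a substitution rule that expresses $\sigma_3$ in terms of $\sigma_1$ and $\sigma_2$.

The `fully_substitute()` function substitutes all of these variable and parameter symbols, so that only $\sigma_1$ and $\sigma_2$ remain.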

In [None]:
def load_model(filename: str) -> dict:
    if not os.path.exists(filename):
        msg = f"Input file not found at {filename}"
        raise FileNotFoundError(msg)
    with open(filename, "rb") as f:
        return pickle.load(f)

In [None]:
def fully_substitute(model_description: dict) -> sp.Expr:
    return (
        model_description["intensity_expr"]
        .xreplace(model_description["variables"])
        .xreplace(model_description["sigma3"])
        .xreplace(model_description["parameter_defaults"])
    )

In [None]:
imported_model = load_model(filename)
intensity_on_2vars = fully_substitute(imported_model)
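The export and import steps can be illustrated on a much smaller scale. The following sketch makes several substitutions:

- A hypothetical toy expression with a parameter `c` and a derived variable `q` stands in for the LHCb model.
- The round trip happens in memory with {func}`pickle.dumps` and {func}`pickle.loads`, not through a file.
- Standard {mod}`pickle` replaces `cloudpickle`.

It applies the substitutions in the same order as `fully_substitute()`: variables first, then $\sigma_3$, then parameter defaults.

```python
import pickle

import sympy as sp

# Hypothetical toy model: depends on a derived variable q and on sigma3
s1, s2, s3 = sp.symbols("sigma1:4", nonnegative=True)
m0, m1, m2, m3 = sp.symbols("m:4", nonnegative=True)
c, q = sp.symbols("c q")  # hypothetical parameter and variable symbols

toy_description = {
    "intensity_expr": c * q**2 + s3,
    "variables": {q: s1 - s2},
    "parameter_defaults": {c: 2, m0: 2, m1: 1, m2: 0, m3: 0},
    "sigma3": {s3: m0**2 + m1**2 + m2**2 + m3**2 - s1 - s2},
}

# In-memory round trip instead of writing to and reading from disk
restored = pickle.loads(pickle.dumps(toy_description))

# Same substitution order as fully_substitute(): variables, sigma3, parameters
toy_intensity = (
    restored["intensity_expr"]
    .xreplace(restored["variables"])
    .xreplace(restored["sigma3"])
    .xreplace(restored["parameter_defaults"])
)
print(toy_intensity.free_symbols)
```

The parameter defaults are substituted last because the $\sigma_3$ rule introduces the mass symbols $m_0, \dots, m_3$, which also need values.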

### Compilation

The resulting symbolic expression depends on two variables:

- $\sigma_1 = m_{K\pi}^2$, the squared invariant mass of the $K^- \pi^+$ system, and
- $\sigma_2 = m_{pK}^2$, the squared invariant mass of the $p K^-$ system.
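Two variables suffice because the third squared invariant mass, $\sigma_3 = m_{p\pi}^2$, is fixed by the three-body kinematics. The `sigma3via12()` function above substitutes exactly this relation. Here $m_0$ is the mass of the decaying particle and $m_1, m_2, m_3$ are the final-state masses:

$$
\sigma_1 + \sigma_2 + \sigma_3 = m_0^2 + m_1^2 + m_2^2 + m_3^2
\quad \Longrightarrow \quad
\sigma_3 = m_0^2 + m_1^2 + m_2^2 + m_3^2 - \sigma_1 - \sigma_2 .
$$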

This expression is turned into a numerical function in one of two ways: with SymPy's {func}`~sympy.utilities.lambdify.lambdify`, or with {mod}`tensorwaves`, using {doc}`JAX<jax:index>` as a computational backend.

The function generated by {func}`~sympy.utilities.lambdify.lambdify` takes the variables as positional arguments.

In [None]:
s12 = sp.symbols("sigma1:3", nonnegative=True)
assert intensity_on_2vars.free_symbols == set(s12)

func = sp.lambdify(s12, intensity_on_2vars)

In [None]:
func(1.0, 3.0), func(1.1, 3.2)
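The same positional calling convention can be shown on a hypothetical toy expression in place of the intensity. The order of the symbol tuple passed to {func}`~sympy.utilities.lambdify.lambdify` fixes the order of the arguments. With the NumPy module, the generated function also accepts arrays.

```python
import numpy as np
import sympy as sp

s1, s2 = sp.symbols("sigma1:3", nonnegative=True)
toy_expr = s1**2 + 2 * s2  # hypothetical stand-in for intensity_on_2vars

# The order of the symbol tuple fixes the order of the positional arguments
toy_func = sp.lambdify((s1, s2), toy_expr, modules="numpy")

scalar_value = toy_func(1.0, 3.0)  # sigma1=1.0, sigma2=3.0 → 7.0
array_values = toy_func(np.array([1.0, 1.1]), np.array([3.0, 3.2]))
print(scalar_value, array_values)
```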

The compilation to JAX is facilitated by {mod}`tensorwaves`:

In [None]:
density = create_function(intensity_on_2vars, backend="jax")

In [None]:
density({"sigma1": jnp.array([1.0, 1.1]), "sigma2": jnp.array([3.0, 3.2])})
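Unlike the lambdified function, the JAX function created with {mod}`tensorwaves` is called with a {obj}`dict` of arrays keyed by the names of the free symbols. The wrapper `make_dict_function()` below is hypothetical. It sketches that calling convention on top of plain `lambdify()` with NumPy, purely for illustration; this is not how {mod}`tensorwaves` implements it.

```python
import numpy as np
import sympy as sp


def make_dict_function(expr: sp.Expr):
    """Hypothetical helper: lambdify and accept a dict keyed by symbol names."""
    # Sort the free symbols by name to get a reproducible argument order
    symbols = sorted(expr.free_symbols, key=lambda s: s.name)
    names = [s.name for s in symbols]
    numpy_func = sp.lambdify(symbols, expr, modules="numpy")

    def wrapped(data: dict) -> np.ndarray:
        # Look up each positional argument by the name of its symbol
        return numpy_func(*(data[name] for name in names))

    return wrapped


s1, s2 = sp.symbols("sigma1:3", nonnegative=True)
toy_density = make_dict_function(s1**2 + 2 * s2)
toy_values = toy_density({"sigma1": np.array([1.0, 1.1]), "sigma2": np.array([3.0, 3.2])})
print(toy_values)
```

Keying the input by symbol name makes the call independent of argument order, which matters for expressions with many free symbols.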

## Serialization with `srepr`

SymPy expressions can directly be serialized to Python code as well, with the function [`srepr()`](https://docs.sympy.org/latest/modules/printing.html#sympy.printing.repr.srepr). For the full intensity expression, we can do so with:

In [None]:
%%time
eval_str = sp.srepr(unfolded_intensity_expr)

In [None]:
n_ops = sp.count_ops(unfolded_intensity_expr)
byt = len(eval_str.encode("utf-8"))
mb = f"{1e-6*byt:.2f}"
rendering = shorten(eval_str, placeholder=" ...", width=85)
src = f"""
This serializes the intensity expression, which contains {n_ops:,d} operations,
to a string of **{mb} MB**.

```
{rendering}
```
"""
Markdown(src)
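A hypothetical toy expression shows what `srepr()` produces on a small scale. Every node is written as a call to its SymPy class, and the assumptions on a symbol appear as keyword arguments.

```python
import sympy as sp

toy_x = sp.Symbol("x", positive=True)
toy_expr = sp.sqrt(toy_x) + 2 * toy_x

# Each node becomes a constructor call, e.g. Pow(...), Mul(...), Symbol(...)
toy_str = sp.srepr(toy_expr)
print(toy_str)
```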

It is up to the user, however, to import the class of each exported node before the string can be evaluated with [`eval()`](https://docs.python.org/3/library/functions.html#eval) (see [this comment](https://github.com/ComPWA/polarimetry/issues/20#issuecomment-1809840854)).

In [None]:
eval(eval_str)  # raises a NameError: the SymPy classes have not been imported yet

In the case of this intensity expression, it is sufficient to import all definitions from the main `sympy` module, plus the `Str` class from `sympy.core.symbol`.

In [None]:
from sympy import *  # noqa: F403
from sympy.core.symbol import Str  # noqa: F401

In [None]:
%%time
eval_imported_intensity_expr = eval(eval_str)

Notice how the imported expression is **exactly the same** as the serialized one, including assumptions:

In [None]:
assert eval_imported_intensity_expr == unfolded_intensity_expr
assert hash(eval_imported_intensity_expr) == hash(unfolded_intensity_expr)
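A star import places every SymPy name in the global namespace of the notebook. As an alternative, the string can be evaluated against an explicit namespace built from the `sympy` module and `Str`. The sketch below uses a hypothetical toy expression; for the intensity expression, `eval_str` would take its place. The last check shows why preserving assumptions matters: a symbol without assumptions does not compare equal to one with them.

```python
import sympy as sp
from sympy.core.symbol import Str

# All public SymPy names plus Str, instead of a star import into globals()
namespace = {**vars(sp), "Str": Str}

toy_x = sp.Symbol("x", positive=True)
toy_expr = sp.exp(-toy_x) * sp.cos(toy_x)
toy_str = sp.srepr(toy_expr)

# Evaluate against the explicit namespace only
restored_expr = eval(toy_str, namespace)
assert restored_expr == toy_expr

# Assumptions are part of a symbol's identity
assert sp.Symbol("x") != toy_x
```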

Optionally, the `import` statements can be embedded into the string. The parsing is then done with [`exec()`](https://docs.python.org/3/library/functions.html#exec) instead:

In [None]:
exec_str = f"""\
from sympy import *
from sympy.core.symbol import Str

def get_intensity_function() -> Expr:
    return {eval_str}
"""

In [None]:
exec_filename = "exported_intensity_model.py"
with open(exec_filename, "w") as f:
    f.write(exec_str)

In [None]:
src = f"""
See {{download}}`{exec_filename}` for the exported model.
"""
Markdown(src)
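Because the exported string is a complete Python module, it could also be loaded with {mod}`importlib` instead of {func}`exec`. The sketch below writes a small hypothetical module to a temporary directory and imports it. The file name and module name are placeholders; with the real export, `exec_filename` would take their place.

```python
import importlib.util
import tempfile
from pathlib import Path

import sympy as sp

# Hypothetical exported module with the same structure as exec_str
toy_srepr = sp.srepr(sp.Symbol("sigma1", nonnegative=True) + 1)
module_source = f"""\
from sympy import *

def get_intensity_function() -> Expr:
    return {toy_srepr}
"""

with tempfile.TemporaryDirectory() as tmp_dir:
    module_path = Path(tmp_dir) / "toy_exported_model.py"
    module_path.write_text(module_source)
    # Load the file as a module without adding its directory to sys.path
    spec = importlib.util.spec_from_file_location("toy_exported_model", module_path)
    toy_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(toy_module)

toy_loaded_expr = toy_module.get_intensity_function()
print(toy_loaded_expr)
```

Importing the file this way keeps the SymPy names inside the module's own namespace instead of the notebook's globals.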

In [None]:
%%time
exec(exec_str)
exec_imported_intensity_expr = get_intensity_function()  # noqa: F405

In [None]:
assert exec_imported_intensity_expr == unfolded_intensity_expr
assert hash(exec_imported_intensity_expr) == hash(unfolded_intensity_expr)

:::{note}
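The load time with `exec()` is shorter than with `eval()` above because of caching within SymPy: the same expression was already constructed during the earlier `eval()` call.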
:::
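For a timing comparison that is not affected by this cache, the SymPy cache can be cleared before each load. The following sketch times a hypothetical toy expression three times:

- cold, right after {func}`sympy.core.cache.clear_cache`,
- warm, immediately after the first load, and
- after clearing the cache again.

It measures with {func}`time.perf_counter`. Substituting `eval_str` for `toy_str` would time the real model.

```python
import time

import sympy as sp
from sympy.core.cache import clear_cache

toy_x, toy_y = sp.symbols("x y", positive=True)
toy_str = sp.srepr(sp.expand((toy_x + toy_y) ** 8))
namespace = {**vars(sp)}  # built once, outside the timed region


def timed_eval(source: str) -> tuple[sp.Expr, float]:
    start = time.perf_counter()
    expr = eval(source, namespace)
    return expr, time.perf_counter() - start


clear_cache()  # start from an empty cache
first, cold_time = timed_eval(toy_str)
second, warm_time = timed_eval(toy_str)  # may reuse nodes cached by the first call
clear_cache()
third, cleared_time = timed_eval(toy_str)
print(f"cold: {cold_time:.4f} s, warm: {warm_time:.4f} s, cleared: {cleared_time:.4f} s")
```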