# Intensity distribution for $\Lambda_c^+ \to p K^- \pi^+$ default model

In [None]:
import os
import pickle

import jax.numpy as jnp
import sympy as sp

#
from tensorwaves.function.sympy import create_function

#### Build the model

The model is saved in a python dictionary dumped to a pickle file.
The dictionary contains the sympy expression for the model and the default values of its constants, which are substituted in the `fully_substitute` function.

In [None]:
def load_model(filename):
    """Load a pickled model description from disk.

    Args:
        filename: Path to the pickle file holding the model description.

    Returns:
        The object stored in the pickle file.

    Raises:
        ValueError: If no file or directory exists at ``filename``.
    """
    if not os.path.exists(filename):
        raise ValueError(f"The input file not found at {filename}")
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, "rb") as f:
        return pickle.load(f)

In [None]:
def fully_substitute(model_description):
    """Apply every substitution map of the model to its intensity expression.

    Starting from ``model_description["intensity"]``, the mappings stored
    under ``"variables"``, ``"Hproduction"``, ``"sigma3"`` and
    ``"parameter_defaults"`` are applied one after another, in that order,
    via ``xreplace``.

    Args:
        model_description: Mapping holding the intensity expression and the
            four substitution mappings named above.

    Returns:
        The expression obtained after all four replacements.
    """
    substitution_keys = (
        "variables",
        "Hproduction",
        "sigma3",
        "parameter_defaults",
    )
    substituted = model_description["intensity"]
    for key in substitution_keys:
        substituted = substituted.xreplace(model_description[key])
    return substituted

In [None]:
# The cached model is looked up in the current working directory.
docs_dir = os.getcwd()
filename = f"{docs_dir}/.sympy-cache-default-model.pkl"
# Load the pickled model description, then substitute its definitions and
# default parameter values into the intensity expression.
model_description = load_model(filename)
intensity_on_2vars = fully_substitute(model_description)

### Compilation

The symbolic expression depends on two variables:

- $\sigma_1 = m_{K\pi}^2$, the squared invariant mass of the $K^- \pi^+$ system, and
- $\sigma_2 = m_{pK}^2$, the squared invariant mass of the $p K^-$ system.

The expression is turned into a numerical function either with `sp.lambdify` or with the `jax` backend.

For the `sympy` backend, positional arguments are used.

In [None]:
# "sigma1:3" is sympy range syntax: it yields the two symbols sigma1 and sigma2.
s12 = sp.symbols("sigma1:3", nonnegative=True)
# Sanity check: after substitution, exactly these two symbols must remain free.
assert intensity_on_2vars.free_symbols == set(s12)
# Build a numerical function that takes (sigma1, sigma2) as positional arguments.
lf = sp.lambdify(s12, intensity_on_2vars)

In [None]:
# Evaluate the lambdified intensity at two (sigma1, sigma2) points.
lf(1.0, 3.0), lf(1.1, 3.2)

Compilation with the `jax` backend is done using `tensorwaves`.

In [None]:
# Compile the same two-variable expression with tensorwaves, using the jax backend.
density = create_function(intensity_on_2vars, backend="jax")

In [None]:
# Inputs are passed as a dict keyed by symbol name, one array per variable;
# these are the same two (sigma1, sigma2) points evaluated with `lf`.
density({"sigma1": jnp.array([1.0, 1.1]), "sigma2": jnp.array([3.0, 3.2])})