In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

```{autolink-concat}
```

````{margin}
```{spec} Intensity distribution generator with importance sampling
:id: TR-018
:status: WIP
:tags: physics;tensorwaves

This report sets out how data generation with TensorWaves works and what the best approach would be to tackle [ComPWA/tensorwaves#402](https://github.com/ComPWA/tensorwaves/issues/402).
```
````

# Importance sampling

In [None]:
# Pin exact package versions (and a specific phasespace commit) so that this
# report can be re-executed against the same software stack.
%pip install -q ampform==0.14.1 git+https://github.com/zfit/phasespace@7131fbd qrules[viz]==0.9.7 scipy==1.9.0 sympy==1.10.1 tensorwaves[jax,pwa]==0.4.6

## Model definition

In [None]:
import logging
import os
import warnings

import jax.numpy as jnp
import numpy as np

# Keep the rendered report free of third-party log noise and numeric warnings.
warnings.filterwarnings("ignore")  # sqrt negative argument
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # no TF warnings
absl_logger = logging.getLogger("absl")
absl_logger.setLevel(logging.ERROR)  # no JAX warnings

In [None]:
import qrules

# J/psi(1S) -> K0 Sigma+ p~ via strong interactions only, restricted to the
# Sigma(1660)~- and N(1650)+ resonances. The list in the initial-state tuple
# gives the J/psi spin projections that are considered.
reaction = qrules.generate_transitions(
    formalism="canonical-helicity",
    allowed_interaction_types="strong",
    allowed_intermediate_particles=["Sigma(1660)~-", "N(1650)+"],
    final_state=["K0", "Sigma+", "p~"],
    initial_state=("J/psi(1S)", [-1, +1]),
)

In [None]:
import graphviz

# Render the generated transitions as a DOT graph (collapse_graphs=True merges
# the individual transition graphs into a compact overview).
src = qrules.io.asdot(
    reaction,
    collapse_graphs=True,
)
graphviz.Source(source=src)

In [None]:
import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

# Configure the amplitude builder: use all permutations of the registered
# topologies, a scalar initial-state mass and stable final-state particles 0-2.
builder = ampform.get_builder(reaction)
builder.adapter.permutate_registered_topologies()
builder.scalar_initial_state_mass = True
builder.stable_final_state_ids = [0, 1, 2]

# The J/psi gets a form factor without dynamics; every intermediate resonance
# gets a relativistic Breit-Wigner with form factor.
dynamics_per_particle = {"J/psi(1S)": create_non_dynamic_with_ff}
for resonance_name in reaction.get_intermediate_particles().names:
    dynamics_per_particle[resonance_name] = create_relativistic_breit_wigner_with_ff
for particle_name, dynamics_builder in dynamics_per_particle.items():
    builder.set_dynamics(particle_name, dynamics_builder)
model = builder.formulate()

## Phase space distribution

In [None]:
from tensorwaves.data import SympyDataTransformer

# Converts generated phase-space events into the kinematic variables of the
# model (the transformed samples below are indexed with keys such as "m_01"),
# using the JAX backend.
kinematic_expressions = model.kinematic_variables
transformer = SympyDataTransformer.from_sympy(kinematic_expressions, backend="jax")

An evenly distributed phase space sample can be generated with a {class}`~tensorwaves.data.TFPhaseSpaceGenerator`:

In [None]:
from tensorwaves.data import TFPhaseSpaceGenerator, TFUniformRealNumberGenerator

# A fixed seed makes the samples reproducible. The same generator instance is
# reused by the later sampling cells, so their output also depends on
# execution order.
rng = TFUniformRealNumberGenerator(seed=0)
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={
        index: particle.mass for index, particle in reaction.final_state.items()
    },
)
phsp = transformer(phsp_generator.generate(1_000_000, rng))

In [None]:
# Switch inline figures from SVG (configured in the first cell) to PNG for the
# dense 2D histograms below. NOTE(review): presumably to keep the rendered page
# size manageable.
%config InlineBackend.figure_formats = ['png']

In [None]:
import matplotlib.pyplot as plt


def convert_zero_to_nan(array):
    """Return a float copy of ``array`` in which every exact zero is NaN.

    Used on 2D histograms before plotting, presumably so that empty bins are
    left blank instead of being coloured as zero. The input is not modified.
    """
    as_float = np.asarray(array, dtype="float")
    return jnp.array(np.where(as_float == 0, np.nan, as_float))


# jnp.histogram2d (like numpy.histogram2d) returns counts indexed as
# Z[x_bin, y_bin], whereas pcolormesh(X, Y, Z) expects Z[y_row, x_column].
# The transpose is required: with the square 100x100 binning a missing
# transpose raises no error but draws the Dalitz plot mirrored along the
# diagonal, i.e. with m_01^2 and m_12^2 swapped relative to the axis labels.
Z, x_edges, y_edges = jnp.histogram2d(
    phsp["m_01"].real ** 2,
    phsp["m_12"].real ** 2,
    bins=100,
)
X, Y = jnp.meshgrid(x_edges, y_edges)
Z = convert_zero_to_nan(Z.T)

# Bin widths of the squared-mass grid. The qrules particle masses are in
# GeV/c^2, so these widths are in GeV^2/c^4. Converting them to MeV^2/c^4
# needs a factor 1e6 (1 GeV^2 = 1e6 MeV^2), not 1e3.
bin_width_x = X[0, 1] - X[0, 0]
bin_width_y = Y[1, 0] - Y[0, 0]
bar_title = (
    Rf"events per ${1e6*bin_width_x:.0f} \times {1e6*bin_width_y:.0f}$ MeV$^2/c^4$"
)
xlabel = R"$M^2\left(K^0\Sigma^+\right)$"
ylabel = R"$M^2\left(\Sigma^+\bar{p}\right)$"

fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set_title("TFPhaseSpaceGenerator sample")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = plt.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel(bar_title)
plt.show()

Under the hood, this {class}`~tensorwaves.data.TFPhaseSpaceGenerator` applies a **hit-and-miss** strategy to a weighted distribution (events plus their weights) generated by a {class}`~tensorwaves.data.TFWeightedPhaseSpaceGenerator`. The weighted generator in turn interfaces to the [`phasespace`](https://phasespace.readthedocs.io) package.

::::{margin}
:::{seealso}
[ComPWA/tensorwaves#16](https://github.com/ComPWA/tensorwaves/issues/16) on a Python interface for [`EvtGen`](https://gitlab.cern.ch/evtgen/evtgen).
:::
::::

In [None]:
from tensorwaves.data import TFWeightedPhaseSpaceGenerator

# Same kinematics as the hit-and-miss generator above, but this generator
# returns the raw events together with their phase-space weights.
masses_of_final_state = {
    index: particle.mass for index, particle in reaction.final_state.items()
}
weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses=masses_of_final_state,
)
unweighted_phsp, phsp_weights = weighted_phsp_generator.generate(1_000_000, rng)
unweighted_phsp = transformer(unweighted_phsp)

In [None]:
from scipy.interpolate import griddata

# Compare the raw output of the weighted phase-space generator, its weights and
# the weighted histogram, side by side.
n_bins = 100
x = unweighted_phsp["m_01"].real ** 2
y = unweighted_phsp["m_12"].real ** 2
X, Y = jnp.meshgrid(
    jnp.linspace(x.min(), x.max(), num=n_bins),
    jnp.linspace(y.min(), y.max(), num=n_bins),
)

Z_unweighted, x_edges, y_edges = jnp.histogram2d(x, y, bins=n_bins)
Z_weighted, x_edges, y_edges = jnp.histogram2d(
    x, y, bins=n_bins, weights=phsp_weights
)
# Interpolate the per-event weights onto the regular (X, Y) grid. The result
# has the same shape and [y_row, x_column] layout as X and Y, so it needs no
# transpose.
Z_weights = griddata(np.transpose([x, y]), phsp_weights, (X, Y), method="linear")

X_edges, Y_edges = np.meshgrid(x_edges, y_edges)
# histogram2d returns counts indexed as [x_bin, y_bin], while pcolormesh
# expects [y_row, x_column]. Transpose so that the axes match their labels.
Z_unweighted = convert_zero_to_nan(Z_unweighted.T)
Z_weighted = convert_zero_to_nan(Z_weighted.T)

fig, axes = plt.subplots(
    dpi=200,
    figsize=(16, 5),
    ncols=3,
    tight_layout=True,
)
for ax in axes:
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
ax1, ax2, ax3 = axes
ax1.set_title("Unweighted distribution")
ax2.set_title("Weights")
ax3.set_title("Weighted phase space distribution")

mesh = ax1.pcolormesh(X_edges, Y_edges, Z_unweighted)
c_bar = plt.colorbar(mesh, ax=ax1)
c_bar.ax.set_ylabel(bar_title)

mesh = ax2.pcolormesh(X, Y, Z_weights)
c_bar = plt.colorbar(mesh, ax=ax2)
c_bar.ax.set_ylabel("phase space weight")

mesh = ax3.pcolormesh(X_edges, Y_edges, Z_weighted)
c_bar = plt.colorbar(mesh, ax=ax3)
c_bar.ax.set_ylabel(bar_title)

plt.show()

## Intensity distribution

In [None]:
from tensorwaves.function.sympy import create_function

# Evaluate the unevaluated model expression, substitute the default parameter
# values, and lambdify the result into a JAX-backed function.
intensity_expr = model.expression.doit()
fixed_parameter_expr = intensity_expr.xreplace(model.parameter_defaults)
intensity_func = create_function(expression=fixed_parameter_expr, backend="jax")

In [None]:
from tensorwaves.data import IntensityDistributionGenerator

# Importance sampling: candidate events from the weighted phase-space generator
# are passed through `transformer` so that `intensity_func` can be evaluated
# on them.
data_generator = IntensityDistributionGenerator(
    domain_generator=weighted_phsp_generator,
    function=intensity_func,
    domain_transformer=transformer,
)
# The generated sample is transformed again afterwards, as for the phase-space
# samples above. NOTE(review): presumably generate() returns the untransformed
# domain (four-momenta).
data = transformer(data_generator.generate(100_000, rng))

In [None]:
# Histogram the generated intensity sample on the same Dalitz-plot axes.
# histogram2d returns counts indexed as [x_bin, y_bin], while pcolormesh
# expects [y_row, x_column]. Transpose so that m_01^2 really lies on the
# horizontal axis.
Z, x_edges, y_edges = jnp.histogram2d(
    data["m_01"].real ** 2,
    data["m_12"].real ** 2,
    bins=100,
)
X, Y = jnp.meshgrid(x_edges, y_edges)
Z = convert_zero_to_nan(Z.T)

fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set_title("IntensityDistributionGenerator sample")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = plt.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel("intensity")
plt.show()