In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Subset of {"EXECUTE_NB", "READTHEDOCS"} that is present as an environment variable.
STATIC_WEB_PAGE = {
    var for var in ("EXECUTE_NB", "READTHEDOCS") if var in os.environ
}

```{autolink-concat}
```

<!-- cspell:ignore argnums -->

::::{margin}
:::{card} Gradient of an amplitude model with autodiff
TR-999
^^^
In this report, we investigate whether autodiff can be used to analytically compute the gradient of an amplitude model. The suspicion is that autodiff cannot handle large expressions well, because the chain rule results in an excessive number of computational nodes for the gradient of the function.
+++
WIP
:::
::::

# Gradient with autodiff

In [None]:
%pip install -q "tensorwaves[jax,pwa]@git+https://github.com/ComPWA/tensorwaves@order-function-args" ampform~=0.14 psutil==5.9.6 qrules~=0.9.8

In [None]:
from __future__ import annotations

import inspect
import os
from textwrap import dedent

import ampform
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import psutil
import qrules
import sympy as sp
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)
from ampform.io import aslatex
from IPython.display import Latex, Markdown
from matplotlib import cm
from tensorwaves.data import (
    IntensityDistributionGenerator,
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
    TFWeightedPhaseSpaceGenerator,
)
from tensorwaves.function.sympy import create_parametrized_function


def display_memory_usage() -> None:
    """Show the resident set size (RSS) of this process as a MyST hint box.

    The value is scaled to kB, MB or GB (powers of 1024) and printed with two
    decimals. Nothing is returned; the result is rendered via `display`.
    """
    rss = psutil.Process(os.getpid()).memory_info().rss
    # Check the largest unit first; anything below 1 MiB is reported in kB.
    if rss >= 1024**3:
        memory_str = f"{rss / 1024**3:.2f} GB"
    elif rss >= 1024**2:
        memory_str = f"{rss / 1024**2:.2f} MB"
    else:
        memory_str = f"{rss / 1024**1:.2f} kB"
    hint_lines = [
        ":::{hint}",
        f"Memory Usage: **{memory_str}**",
        ":::",
    ]
    display(Markdown("\n".join(hint_lines)))


# Set the TF_CPP_MIN_LOG_LEVEL environment variable to "3".
# NOTE(review): this assignment runs after the tensorwaves imports above; if
# those already import TensorFlow, the setting may come too late — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Baseline memory reading, taken right after the imports.
display_memory_usage()

## Formulate model

In [None]:
# Generate the transitions for B0 -> K+ pi- pi0, restricted to intermediate
# particles matching "K*(892)" and "rho", in the helicity formalism.
# NOTE(review): mass_conservation_factor=0 presumably tightens the mass check
# on intermediate states — confirm against the qrules documentation.
REACTION = qrules.generate_transitions(
    initial_state="B0",
    final_state=["K+", "pi-", "pi0"],
    allowed_intermediate_particles=["K*(892)", "rho"],
    formalism="helicity",
    mass_conservation_factor=0,
)

In [None]:
import graphviz

# Convert the reaction to DOT source (with collapse_graphs=True) and render it
# with graphviz as the cell output.
dot = qrules.io.asdot(REACTION, collapse_graphs=True)
graphviz.Source(dot)

In [None]:
# Keep the first initial-state entry; any further entries are discarded via `*_`.
INITIAL_STATE, *_ = REACTION.initial_state.values()
BUILDER = ampform.get_builder(REACTION)
BUILDER.adapter.permutate_registered_topologies()
# The initial state gets the `create_non_dynamic_with_ff` dynamics builder ...
BUILDER.set_dynamics(INITIAL_STATE.name, create_non_dynamic_with_ff)
# ... and every intermediate particle gets the relativistic Breit-Wigner one.
for name in REACTION.get_intermediate_particles().names:
    BUILDER.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
    del name  # do not leave the loop variable behind in the notebook namespace
# Formulate the amplitude model with the dynamics configured above.
MODEL = BUILDER.formulate()

In [None]:
# Display the model's top-level intensity expression as the cell output.
MODEL.intensity

In [None]:
# Render only the first three amplitude definitions of the model as LaTeX.
selection = dict(list(MODEL.amplitudes.items())[:3])
src = aslatex(selection)
del selection  # temporary; only `src` is needed for the output
Latex(src)

In [None]:
# Memory reading after the model has been formulated.
display_memory_usage()

## Generate data

In [None]:
# Seeded random number generator, shared by the generation steps that follow.
rng = TFUniformRealNumberGenerator(seed=0)
# Phase-space generator configured with the mass of the initial-state entry
# stored under key -1 and the masses of all final-state entries, keyed by ID.
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=REACTION.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in REACTION.final_state.items()},
)
# presumably 100_000 is the number of phase-space events — confirm
phsp_momenta = phsp_generator.generate(100_000, rng)

In [None]:
# Build a JAX-backed data transformer from the model's kinematic variable
# expressions; it is applied to the generated momenta further down.
helicity_transformer = SympyDataTransformer.from_sympy(
    MODEL.kinematic_variables, backend="jax"
)

In [None]:
# Evaluate the model expression with `.doit()` and turn the result into a
# JAX-backed function, using all parameter defaults as its parameters.
unfolded_expression = MODEL.expression.doit()
intensity_func = create_parametrized_function(
    unfolded_expression,
    parameters=MODEL.parameter_defaults,
    backend="jax",
)

In [None]:
# Weighted phase-space generator with the same masses as `phsp_generator`.
weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    initial_state_mass=REACTION.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in REACTION.final_state.items()},
)
# Generate a data sample distributed according to `intensity_func`, which is
# evaluated on the helicity-transformed phase-space domain.
data_generator = IntensityDistributionGenerator(
    domain_generator=weighted_phsp_generator,
    function=intensity_func,
    domain_transformer=helicity_transformer,
)
# presumably 10_000 is the number of data events — confirm
data_momenta = data_generator.generate(10_000, rng)

In [None]:
# Transform both momentum samples with the helicity transformer; the results
# are the inputs for `intensity_func` in the plots and the estimator below.
phsp = helicity_transformer(phsp_momenta)
data = helicity_transformer(data_momenta)

## Gradient creation with autodiff

In [None]:
# Parameters whose name starts with "C" or "d" are kept free ...
free_symbols = dict(
    filter(
        lambda item: item[0].name[0] in ("C", "d"),
        MODEL.parameter_defaults.items(),
    )
)
# ... and all remaining parameters are fixed to their default values.
fixed_symbols = {
    par: default
    for par, default in MODEL.parameter_defaults.items()
    if par not in free_symbols
}

In [None]:
# Rebind `intensity_func`: substitute the fixed parameters by their values so
# that only the free parameters remain as parameters of the JAX function.
# Note that this replaces the function created earlier for data generation.
intensity_func = create_parametrized_function(
    unfolded_expression.xreplace(fixed_symbols),
    parameters=free_symbols,
    backend="jax",
)

In [None]:
# Memory reading after creating the function with free parameters only.
display_memory_usage()

In [None]:
%config InlineBackend.figure_formats = ['png']

In [None]:
def plot_mass_projection(ax, decay_ids: set[int]):
    """Draw an intensity-weighted mass histogram for one subsystem on `ax`.

    The histogram uses the `phsp` entry ``m_<sorted ids>`` weighted with
    `intensity_func(phsp)`. The masses of the resonances found for this
    subsystem are marked with dotted vertical lines and listed in the legend.
    """
    resonances = get_resonances(decay_ids)
    product_label = "".join(
        particle.latex for particle in get_decay_products(decay_ids)
    )
    mass_key = "m_" + "".join(str(i) for i in sorted(decay_ids))
    ax.hist(
        phsp[mass_key].real,
        bins=200,
        alpha=0.5,
        density=True,
        weights=intensity_func(phsp),
    )
    ax.set_xlabel(f"$m_{{{product_label}}}$ [GeV]")
    # Spread the line colours evenly over the rainbow colour map.
    fractions = np.linspace(0, 1, len(resonances))
    for fraction, resonance in zip(fractions, resonances):
        ax.axvline(
            x=resonance.mass,
            linestyle="dotted",
            label=f"${resonance.latex}$",
            color=cm.rainbow(fraction),
        )
    ax.legend()


def get_decay_products(decay_ids: set[int]) -> tuple[Particle, Particle]:
    """Return the final-state particles for `decay_ids`, ordered by ascending ID."""
    final_state = REACTION.final_state
    return tuple([final_state[state_id] for state_id in sorted(decay_ids)])


def get_resonances(decay_ids: set[int]) -> list[Particle]:
    """Collect the distinct particles on edge 3 that decay into `decay_ids`.

    A transition is selected when the edges leaving node 1 of its topology are
    exactly `decay_ids`. The result is sorted by the first character of the
    particle name and then by mass.
    """
    unique = set()
    for transition in REACTION.transitions:
        if transition.topology.get_edge_ids_outgoing_from_node(1) == decay_ids:
            unique.add(transition.states[3].particle)
    ordered = list(unique)
    ordered.sort(key=lambda particle: (particle.name[0], particle.mass))
    return ordered


# One row per two-body subsystem of the three final-state particles.
fig, axes = plt.subplots(figsize=(9, 12), nrows=3)
plot_mass_projection(axes[0], decay_ids={0, 1})
plot_mass_projection(axes[1], decay_ids={1, 2})
plot_mass_projection(axes[2], decay_ids={2, 0})
plt.show()

In [None]:
def indicate_resonances(ax_func, decay_ids) -> None:
    """Mark the squared resonance masses of a subsystem with dotted lines.

    `ax_func` is the line-drawing callable (for example `ax.axvline` or
    `ax.axhline`); it is called once per resonance found for `decay_ids`.
    """
    resonances = get_resonances(decay_ids)
    # Spread the line colours evenly over the rainbow colour map.
    fractions = np.linspace(0, 1, len(resonances))
    for fraction, resonance in zip(fractions, resonances):
        ax_func(
            resonance.mass**2,
            linestyle="dotted",
            label=f"${resonance.latex}$",
            color=cm.rainbow(fraction),
        )


# Dalitz plot: intensity-weighted 2D histogram of the squared invariant masses
# of two subsystems, with the squared resonance masses overlaid as dotted lines.
x_subsystem = {0, 1}
y_subsystem = {1, 2}
x_products = get_decay_products(x_subsystem)
y_products = get_decay_products(y_subsystem)
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist2d(
    phsp[f"m_{''.join(map(str, sorted(x_subsystem)))}"].real ** 2,
    phsp[f"m_{''.join(map(str, sorted(y_subsystem)))}"].real ** 2,
    bins=100,
    cmin=1,  # bins with a content below 1 are left blank
    weights=intensity_func(phsp),
)
# Both axes show squared masses (see `** 2` above and `p.mass**2` in
# `indicate_resonances`), so the labels carry the square and squared units.
ax.set_xlabel(f"$m^2_{{{' '.join(p.latex for p in x_products)}}}$ [GeV$^2$]")
ax.set_ylabel(f"$m^2_{{{' '.join(p.latex for p in y_products)}}}$ [GeV$^2$]")
indicate_resonances(ax.axvline, x_subsystem)
indicate_resonances(ax.axhline, y_subsystem)
plt.show()

In [None]:
# Memory reading after evaluating the intensity function for the plots.
display_memory_usage()

## Optimize parameters

### Numerical gradient descent

In [None]:
# Read the argument names of the underlying function from its signature.
sig = inspect.signature(intensity_func.function)
arg_names = tuple(sig.parameters)
# Pair each argument name positionally with the entry of `argument_order` and
# keep only the pairs whose entry is a parameter of the function.
arg_to_par = {
    arg: par
    for arg, par in zip(arg_names, intensity_func.argument_order)
    if par in intensity_func.parameters
}
# Position in the flat parameter vector -> parameter name (used by `estimator`).
idx_to_par = dict(enumerate(arg_to_par.values()))
# Current parameter values keyed by argument name; only the real part is kept.
parameter_values = {
    arg: complex(intensity_func.parameters[par]).real
    for arg, par in arg_to_par.items()
}

In [None]:
def estimator(args):
    """Return the negative log-likelihood for a flat sequence of parameter values.

    The i-th entry of `args` is assigned to the parameter `idx_to_par[i]`.
    Note that this updates the parameters of `intensity_func` in place. The
    data intensities are normalised by the mean intensity over `phsp`.
    """
    intensity_func.update_parameters(
        {idx_to_par[position]: value for position, value in enumerate(args)}
    )
    data_intensities = intensity_func(data)
    normalization = jnp.mean(intensity_func(phsp))
    log_likelihoods = jnp.log(data_intensities / normalization)
    return -jnp.sum(log_likelihoods)


# Evaluate the estimator once at the current parameter values.
estimator(parameter_values.values())

In [None]:
# Probe array: its dtype shows the floating-point precision JAX uses by default.
arr = jnp.array([1.0])
# Summary hint box: precision, model size, number of free parameters, and the
# operation count of the unfolded expression.
msg = f"""
:::{{hint}}
JAX is using this precision: **{arr.dtype}**. For the model, we have:
- {len(REACTION.get_intermediate_particles())} resonances in {len(REACTION.transition_groups)} subsystems
- {len(parameter_values)} of {len(MODEL.parameter_defaults)} free parameters
- {sp.count_ops(unfolded_expression):,d} computational nodes
:::
"""
msg = dedent(msg).strip()
display(Markdown(msg))

In [None]:
import iminuit
from tqdm.auto import tqdm

# Progress bar without a fixed total; it advances once per estimator call.
PROGRESS_BAR = tqdm()


def estimator_with_progress_bar(*args, **kwargs):
    """Evaluate `estimator` and report the call on `PROGRESS_BAR`.

    All arguments are forwarded to `estimator` unchanged. After each call the
    bar advances by one step and shows the latest estimator value as postfix.

    Returns:
        The value returned by `estimator`.
    """
    estimator_value = estimator(*args, **kwargs)
    PROGRESS_BAR.update()
    # In a format spec the grouping option must come after the width and before
    # the precision. The previous spec ",10g" is therefore invalid and raises a
    # ValueError; ",.10g" gives thousands separators and 10 significant digits.
    PROGRESS_BAR.set_postfix({"estimator": f"{estimator_value:,.10g}"})
    return estimator_value


# Seeded generator so that the perturbed starting values are reproducible.
RNG = np.random.default_rng(seed=0)
# Relative size of the random perturbation applied to each starting value.
δ = 0.01
starting_values = tuple(
    p * RNG.uniform(1 - δ, 1 + δ) for p in parameter_values.values()
)
# The starting values are passed as one sequence, so the estimator receives
# all parameters as a single argument; the names come from `parameter_values`.
optimizer = iminuit.Minuit(
    estimator_with_progress_bar,
    starting_values,
    name=tuple(parameter_values),
)
# Initial step sizes: 10% of the starting value, or 0.1 where that value is zero.
optimizer.errors = tuple(
    0.1 * abs(x) if abs(x) != 0.0 else 0.1 for x in starting_values
)
# The estimator is a negative log-likelihood, hence the LIKELIHOOD error definition.
optimizer.errordef = iminuit.Minuit.LIKELIHOOD
optimizer.migrad()

### With analytic gradient

In [None]:
# Memory reading after the fit with numerical gradients.
display_memory_usage()