In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

# Subset of {"EXECUTE_NB", "READTHEDOCS"} that is present as a key in the
# process environment (empty set when neither variable is defined).
STATIC_WEB_PAGE = {
    name for name in ("EXECUTE_NB", "READTHEDOCS") if name in os.environ
}

```{autolink-concat}
```

<!-- cspell:ignore argnums jacrev -->

````{margin}
```{spec} Gradient with autodiff
:id: TR-999
:status: WIP
:tags: tensorwaves

In this report, we investigate whether autodiff can be used to analytically compute the gradient of an amplitude model. The suspicion is that autodiff cannot handle large expressions well, because the chain rule results in an excessive number of computational nodes for the gradient of the function.
```
````

# Gradient with autodiff

In [None]:
%pip install -q "tensorwaves[jax,pwa]@git+https://github.com/ComPWA/tensorwaves@order-function-args" ampform~=0.14 qrules~=0.9.8

In [None]:
import inspect
import os

import ampform
import jax
import matplotlib.pyplot as plt
import numpy as np
import qrules
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)
from ampform.io import aslatex
from IPython.display import Latex
from jax.tree_util import Partial
from matplotlib import cm
from tensorwaves.data import (
    IntensityDistributionGenerator,
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
    TFWeightedPhaseSpaceGenerator,
)
from tensorwaves.function.sympy import create_function, create_parametrized_function

# TF_CPP_MIN_LOG_LEVEL controls the verbosity of TensorFlow's native logging.
# NOTE(review): this is assigned after the tensorwaves imports above; if those
# already import TensorFlow, the setting may come too late — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

## Formulate model

In [None]:
# Generate the transitions for J/psi(1S) -> gamma pi0 pi0 in the helicity
# formalism, restricted to intermediate particles matching "a(0)", "f(0)" and
# "omega" and to strong and EM interactions.
# The list [-1, +1] next to the initial state presumably selects its spin
# projections — confirm against the qrules documentation.
reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["a(0)", "f(0)", "omega"],
    allowed_interaction_types=["strong", "EM"],
    formalism="helicity",
)

In [None]:
import graphviz

# Convert the generated transitions to DOT source (with collapse_graphs=True)
# and render that source with graphviz as the cell output.
dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)

In [None]:
# Configure the amplitude builder before formulating the model:
# - the initial state J/psi(1S) gets the "non-dynamic with form factor" builder,
# - every intermediate particle gets a relativistic Breit-Wigner with form factor.
model_builder = ampform.get_builder(reaction)
# NOTE(review): presumably registers the permutations of the decay topologies
# on the kinematics adapter — confirm against the ampform documentation.
model_builder.adapter.permutate_registered_topologies()
model_builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
# All dynamics must be assigned before this call, which produces the model.
model = model_builder.formulate()

In [None]:
# Display the `intensity` attribute of the formulated model as the cell output.
model.intensity

In [None]:
# Render only the first three entries of `model.amplitudes` as LaTeX.
selection = dict(list(model.amplitudes.items())[:3])
src = aslatex(selection)
del selection  # the temporary dict is not needed after rendering
Latex(src)

## Generate data

In [None]:
# Random number generator with a fixed seed; it is reused later for the data
# sample, so the generated samples depend on the cell execution order.
rng = TFUniformRealNumberGenerator(seed=0)
# Phase-space generator configured with the masses from the reaction.
# `initial_state` is looked up with key -1 (presumably the state ID that qrules
# assigns to the initial state — confirm).
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
# Request 100,000 phase-space events.
phsp_momenta = phsp_generator.generate(100_000, rng)

In [None]:
# Build a data transformer from the model's kinematic-variable expressions,
# using the JAX backend. It is applied to the generated momenta further below.
helicity_transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)

In [None]:
# Evaluate (`doit`) the model expression, substitute every parameter by its
# default value, and turn the result into a JAX-backed function. This fixed
# function is used as the distribution for data generation.
unfolded_expression = model.expression.doit()
substituted_expression = unfolded_expression.xreplace(model.parameter_defaults)
fixed_intensity_func = create_function(substituted_expression, backend="jax")

In [None]:
# Weighted phase-space generator with the same masses as the phase-space sample.
weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
# Generate events distributed according to `fixed_intensity_func`; the domain
# transformer converts generated momenta into the function's input variables.
data_generator = IntensityDistributionGenerator(
    domain_generator=weighted_phsp_generator,
    function=fixed_intensity_func,
    domain_transformer=helicity_transformer,
)
# Request 10,000 events, drawing from the same `rng` as the phase-space sample.
data_momenta = data_generator.generate(10_000, rng)

In [None]:
# List the entries of `helicity_transformer.functions` (iterating yields the
# keys if this is a mapping).
list(helicity_transformer.functions)

In [None]:
# Transform both momentum samples with the same transformer, so that `phsp`
# and `data` are produced by identical kinematic-variable definitions.
phsp = helicity_transformer(phsp_momenta)
data = helicity_transformer(data_momenta)

In [None]:
# Symbols that remain free after substituting the parameter defaults, sorted by
# their string representation for a stable display order.
sorted(substituted_expression.free_symbols, key=str)

In [None]:
# List the entries of `model.kinematic_variables`, the same object the data
# transformer was built from.
list(model.kinematic_variables)

In [None]:
# Intermediate particles ordered by mass, each paired with its own colour taken
# at evenly spaced points of the rainbow colormap.
resonances = sorted(reaction.get_intermediate_particles(), key=lambda p: p.mass)
evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = list(map(cm.rainbow, evenly_spaced_interval))

# Normalised histogram of the real part of the "m_12" data column.
fig, ax = plt.subplots(figsize=(9, 4))
ax.hist(np.real(data["m_12"]), bins=200, alpha=0.5, density=True)
ax.set_xlabel("$m$ [GeV]")

# One dotted vertical line per resonance at its mass.
for p, color in zip(resonances, colors):
    ax.axvline(x=p.mass, linestyle="dotted", label=p.name, color=color)
ax.legend()
plt.show()

## Gradient creation with autodiff

In [None]:
# Free parameters: every model parameter whose name does not start with "d".
free_symbols = {
    par: default
    for par, default in model.parameter_defaults.items()
    if not par.name.startswith("d")
}
# Take the first free parameter whose name starts with "C" out of the free set,
# so that it ends up among the fixed parameters.
some_coefficient = next(par for par in free_symbols if par.name.startswith("C"))
del free_symbols[some_coefficient]
# Fixed parameters: everything that is not (or no longer) free, at its default.
fixed_symbols = {
    par: default
    for par, default in model.parameter_defaults.items()
    if par not in free_symbols
}

In [None]:
# Show the free parameters and their values, ordered by the string form of the symbol.
src = aslatex(dict(sorted(free_symbols.items(), key=lambda item: str(item[0]))))
Latex(src)

In [None]:
# Substitute the fixed parameters into the unfolded expression and create a
# JAX-backed function that is parametrized by the free symbols only.
expression = unfolded_expression.xreplace(fixed_symbols)
intensity_func = create_parametrized_function(
    expression,
    backend="jax",
    parameters=free_symbols,
)

In [None]:
from IPython.display import Markdown

# Report how many parameters the parametrized function exposes.
Markdown(f"Function has **{len(intensity_func.parameters)} free parameters**.")

In [None]:
# Names of the positional arguments of the underlying function, in order.
sig = inspect.signature(intensity_func.function)
arg_names = tuple(sig.parameters.keys())
# Pair each argument name with the corresponding entry of `argument_order` and
# keep those entries that are columns of the data sample.
data_columns = {
    arg_name: data[symbol_name]
    for arg_name, symbol_name in zip(arg_names, intensity_func.argument_order)
    if symbol_name in data
}
data_columns

In [None]:
# Map each function argument that corresponds to a parameter onto the real part
# of that parameter's current value.
parameter_values = {
    arg_name: complex(intensity_func.parameters[symbol_name]).real
    for arg_name, symbol_name in zip(arg_names, intensity_func.argument_order)
    if symbol_name in intensity_func.parameters
}
parameter_values

In [None]:
# Bind the data columns as the leading positional arguments of the function.
# NOTE(review): this assumes that the data columns come first in the function
# signature, so that the remaining positional arguments are exactly the
# parameters in `parameter_values` — confirm against `argument_order`.
func_with_data_inserted = Partial(intensity_func.function, *data_columns.values())
# Reverse-mode Jacobian with respect to every remaining (parameter) argument.
gradient_func = jax.jacrev(
    func_with_data_inserted,
    argnums=range(len(parameter_values)),
)
gradient_func

In [None]:
%%time  # compilation
_ = tuple(v.block_until_ready() for v in gradient_func(*parameter_values.values()))

In [None]:
# Evaluate the gradient at the current parameter values and show the shape of
# the entry that belongs to the first parameter.
gradient_values = gradient_func(*parameter_values.values())
gradient_values[0].shape

## Optimize parameters

### Numerical gradient descent

In [None]:
# Select the phase-space columns that the intensity function takes as
# positional arguments. Membership is tested against `phsp` itself (the mapping
# that is indexed), so the lookup cannot raise a KeyError if the phase-space
# and data samples ever carry different keys.
phsp_columns = {
    arg: phsp[key]
    for arg, key in zip(arg_names, intensity_func.argument_order)
    if key in phsp
}
# Bind the phase-space columns as leading positional arguments, in the same way
# as is done for the data sample.
func_with_phsp_inserted = Partial(intensity_func.function, *phsp_columns.values())

In [None]:
# Evaluate the function on the phase-space sample at the current parameter values.
func_with_phsp_inserted(*parameter_values.values())

In [None]:
import jax.numpy as jnp


# @jax.jit  # Do not JIT here, otherwise jax.jacrev crashes!
def estimator(args):
    """Return the negative sum of log-likelihoods for the given parameter values.

    The intensities on the data sample are divided by the mean intensity over
    the phase-space sample before the logarithm is taken.

    Args:
        args: Iterable of parameter values. It is unpacked into the positional
            arguments of `func_with_data_inserted` and `func_with_phsp_inserted`.
    """
    intensities_on_data = func_with_data_inserted(*args)
    intensities_on_phsp = func_with_phsp_inserted(*args)
    normalization = jnp.mean(intensities_on_phsp)
    return -jnp.sum(jnp.log(intensities_on_data / normalization))


# Evaluate the estimator once at the starting parameter values.
estimator(parameter_values.values())

In [None]:
import iminuit
from tqdm.auto import tqdm

# Module-level progress bar without a total; `estimator_with_progress_bar`
# looks it up by name on every call, so re-assigning it resets the counter.
PROGRESS_BAR = tqdm()


def estimator_with_progress_bar(*args, **kwargs):
    """Call `estimator` and record the call on the global `PROGRESS_BAR`.

    All arguments are forwarded to `estimator` unchanged. After the evaluation,
    the progress bar is advanced by one step and its postfix is set to the
    estimator value that was just computed, which is also returned.
    """
    value = estimator(*args, **kwargs)
    PROGRESS_BAR.update()
    postfix = {"estimator": value}
    PROGRESS_BAR.set_postfix(postfix)
    return value


# Minimise the estimator without a user-supplied gradient.
starting_values = tuple(parameter_values.values())
optimizer = iminuit.Minuit(
    estimator_with_progress_bar,
    starting_values,
    name=tuple(parameter_values.keys()),
)
# Per-parameter `errors`: 10% of the magnitude of the starting value, or 0.1
# when the starting value is zero.
optimizer.errors = tuple(
    0.1 if abs(x) == 0.0 else 0.1 * abs(x) for x in starting_values
)
optimizer.errordef = iminuit.Minuit.LIKELIHOOD
optimizer.migrad()

### With analytic gradient

In [None]:
# Reverse-mode derivative of the estimator with respect to its single argument
# (the collection of parameter values).
estimator_gradient = jax.jacrev(estimator)

In [None]:
%%time
# Evaluate the autodiff gradient once at the starting parameter values.
estimator_gradient(tuple(parameter_values.values()))

In [None]:
PROGRESS_BAR = tqdm()  # reset
# Same minimisation as the numerical one, but with the autodiff gradient
# passed to Minuit through `grad`.
autodiff_optimizer = iminuit.Minuit(
    estimator_with_progress_bar,
    starting_values,
    grad=estimator_gradient,  # analytic!
    name=tuple(parameter_values.keys()),
)
# Per-parameter `errors`: 10% of the magnitude of the starting value, or 0.1
# when the starting value is zero.
autodiff_optimizer.errors = tuple(
    0.1 if abs(x) == 0.0 else 0.1 * abs(x) for x in starting_values
)
autodiff_optimizer.errordef = iminuit.Minuit.LIKELIHOOD
autodiff_optimizer.migrad()

## Conclusion

In [None]:
def compute_diff(minuit):
    """Return the offset between the starting and the optimized parameter values.

    The result is the Euclidean norm of the difference between the global
    `starting_values` and the values in `minuit.params`, divided by the number
    of parameters.

    Args:
        minuit: Object whose `params` attribute is a sized iterable of items
            that each have a `value` attribute.
    """
    n_pars = len(minuit.params)
    start = np.array(starting_values)
    final = np.array([par.value for par in minuit.params])
    squared_offsets = np.abs(start - final) ** 2
    return np.sqrt(np.sum(squared_offsets)) / n_pars


# Markdown table comparing run time and parameter offset of the two minimisations.
src = "\n".join(
    [
        "",
        "|  | numerical | autodiff |",
        "|--|-----------|----------|",
        f"| time (s) | {optimizer.fmin.time:.1f} | {autodiff_optimizer.fmin.time:.1f} |",
        f"| average parameter offset | {compute_diff(optimizer):.4f} | {compute_diff(autodiff_optimizer):.4f} |",
        "",
    ]
)
Markdown(src)