<!-- cspell:ignore jquery ndigits nowrap -->

# AmpForm physics mini-demo

This demo highlights the main features of AmpForm and its companion packages, QRules and TensorWaves.

_Initialize all cells first by clicking the cell below and hitting `Shift+Enter`:_

In [None]:
from IPython.display import Javascript

Javascript("IPython.notebook.execute_cells_below()")

In [None]:
%config Completer.use_jedi = False
%config InlineBackend.figure_formats = ['svg']

import logging
import warnings

import ampform
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import qrules
import sympy as sp
import tensorflow as tf
from ampform.dynamics.builder import (
    create_analytic_breit_wigner,
    create_relativistic_breit_wigner_with_ff,
)
from IPython.display import HTML, Math, display
from matplotlib import cm
from tensorwaves.data import generate_data, generate_phsp
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.model import LambdifiedFunction, SympyModel

tf.get_logger().setLevel("ERROR")
logging.getLogger().setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

HTML("<style>div.prompt {display:none}</style>")

In [None]:
def indicate_masses(reaction, ax=None):
    if ax is None:
        ax = plt.gca()
    intermediate_states = sorted(
        (
            p
            for p in model.particles
            if p not in reaction.final_state.values()
            and p not in reaction.initial_state.values()
        ),
        key=lambda p: p.mass,
    )
    evenly_spaced_interval = np.linspace(0, 1, len(intermediate_states))
    colors = [cm.rainbow(x) for x in evenly_spaced_interval]
    for i, p in enumerate(intermediate_states):
        ax.axvline(x=p.mass, linestyle="dotted", label=p.name, color=colors[i])


def render_side_by_side(reaction, render_final_state_id=False):
    for i, transition_group in enumerate(reaction.transition_groups):
        dot = qrules.io.asdot(
            transition_group,
            collapse_graphs=True,
            render_final_state_id=render_final_state_id,
            render_node=False,
        )
        graphviz.Source(dot).render(f"reaction_{i}", format="svg")

    image_paths = [
        f"reaction_{i}.svg" for i in range(len(reaction.transition_groups))
    ]
    html_source = '<div style="white-space: nowrap">\n'
    for path in image_paths:
        with open(path) as stream:
            svg = stream.read()
        html_source += svg
    html_source += "\n</div>"
    display(HTML(html_source))


def round_floats(expression: sp.Expr, ndigits=2) -> sp.Expr:
    new_expression = expression
    for node in sp.preorder_traversal(expression):
        if isinstance(node, sp.Float):
            new_expression = new_expression.subs(node, round(node, ndigits))
    return new_expression

In [None]:
PDG = qrules.load_pdg()
particle_selection = qrules.ParticleCollection()
particle_selection += PDG["D0"]
particle_selection += PDG.filter(
    lambda p: p.name.startswith("K") and len(p.name) == 2
)
particle_selection += PDG.filter(lambda p: p.name.startswith("a("))
particle_selection += PDG.filter(lambda p: p.name.startswith("rho"))
particle_selection += PDG.filter(lambda p: p.name.startswith("phi"))

## Common Partial Wave Analysis

ComPWA provides an _ecosystem_ of three Python packages:
1. `QRules`: quantum number conservation rules ('resonance finder')
2. `AmpForm`: spin formalisms and dynamics
3. `TensorWaves`: toy MC and model optimization with multiple computational back-ends

The main goals are:
- to make PWA intelligible with **modern and interactive documentation** ([example](https://ampform.rtfd.io/en/stable/usage/dynamics/k-matrix.html))
- to ensure **academic continuity** through long-term, open-source development that can be picked up by successors
- to perform PWA with the latest tools from the Python data science community

## 1. [QRules](https://qrules.rtfd.io): generate allowed state transitions

- **Quantum number conservation rules** are encoded in `QRules`
- This makes it possible to find which intermediate resonances are allowed in a given reaction
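At its simplest, finding resonances means asking which intermediate states a pair of final-state particles could come from. The toy sketch below applies only electric-charge conservation to each two-body subsystem of D⁰ → K⁻K⁺K⁰. The `Particle` class, the candidate list, and `charge_allowed_resonances` are hypothetical illustrations, not the QRules API; QRules checks many more quantum numbers (spin, parity, isospin, ...).

```python
from dataclasses import dataclass
from itertools import combinations
from typing import Dict, List, Sequence, Tuple


@dataclass(frozen=True)
class Particle:
    """Hypothetical minimal particle record: name and electric charge only."""

    name: str
    charge: int


# Final state of the D0 decay studied in this demo
FINAL_STATE = [Particle("K-", -1), Particle("K+", +1), Particle("K0", 0)]

# Hand-picked candidate resonances (charges as listed by the PDG)
CANDIDATES = [
    Particle("a(0)(980)0", 0),
    Particle("a(0)(980)+", +1),
    Particle("a(0)(980)-", -1),
    Particle("a(2)(1320)+", +1),
    Particle("a(2)(1320)-", -1),
    Particle("phi(1020)", 0),
]


def charge_allowed_resonances(
    final_state: Sequence[Particle], candidates: Sequence[Particle]
) -> Dict[Tuple[str, str], List[str]]:
    """For each two-body subsystem, list the candidates that conserve charge."""
    allowed = {}
    for a, b in combinations(final_state, 2):
        pair_charge = a.charge + b.charge
        allowed[(a.name, b.name)] = [
            c.name for c in candidates if c.charge == pair_charge
        ]
    return allowed


for pair, names in charge_allowed_resonances(FINAL_STATE, CANDIDATES).items():
    print(pair, names)
```

Charge alone leaves several candidates per subsystem; the additional rules in QRules are what prune this list further.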

In [None]:
reaction = qrules.generate_transitions(
    initial_state="D0",
    final_state=["K-", "K+", "K0"],
    particle_db=particle_selection,
)
render_side_by_side(reaction)

`QRules` can also:
- check **conservation rule violations** in an invalid reaction
- serve as a **quantum number database** (PDG listings are the default input)
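For additive quantum numbers, a violation check boils down to comparing totals of the initial and final state. The sketch below uses a hypothetical, hard-coded table of PDG values and a hypothetical `violated_rules` helper (not the QRules API). Note that D⁰ decays weakly and the weak interaction does not conserve strangeness, so the strangeness "violation" in D⁰ → K⁻K⁺K⁰ is allowed; which rules apply depends on the interaction type. The charge violation in D⁰ → K⁻K⁺K⁺, by contrast, rules that reaction out.

```python
from typing import Dict, List, Sequence

# Hypothetical mini-table of additive quantum numbers (values from the PDG):
# Q = electric charge, B = baryon number, S = strangeness
QUANTUM_NUMBERS: Dict[str, Dict[str, int]] = {
    "D0": {"Q": 0, "B": 0, "S": 0},
    "K-": {"Q": -1, "B": 0, "S": -1},
    "K+": {"Q": +1, "B": 0, "S": +1},
    "K0": {"Q": 0, "B": 0, "S": +1},
}


def violated_rules(
    initial_state: Sequence[str],
    final_state: Sequence[str],
    rules: Sequence[str] = ("Q", "B", "S"),
) -> List[str]:
    """Return the additive quantum numbers whose totals differ between states."""
    violations = []
    for rule in rules:
        total_in = sum(QUANTUM_NUMBERS[name][rule] for name in initial_state)
        total_out = sum(QUANTUM_NUMBERS[name][rule] for name in final_state)
        if total_in != total_out:
            violations.append(rule)
    return violations


print(violated_rules(["D0"], ["K-", "K+", "K0"]))  # ['S']
print(violated_rules(["D0"], ["K-", "K+", "K+"]))  # ['Q', 'S']
```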

## 2. [AmpForm](https://ampform.rtfd.io): express transitions as amplitude models

- `AmpForm` encodes **spin formalisms and dynamics**
- It formulates the state transitions found by `QRules` as an **amplitude model**
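In the helicity formalism, contributions that share the same initial- and final-state helicities are summed coherently before taking the absolute square, so resonances can interfere. A toy NumPy illustration with two unit-magnitude complex amplitudes (hypothetical numbers, not AmpForm code) shows how the relative phase changes the result compared to an incoherent sum:

```python
import numpy as np


def coherent_intensity(amplitudes: np.ndarray) -> np.ndarray:
    """|A_1 + A_2 + ...|^2: amplitudes are summed before squaring (interference)."""
    return np.abs(np.sum(amplitudes, axis=0)) ** 2


def incoherent_intensity(amplitudes: np.ndarray) -> np.ndarray:
    """|A_1|^2 + |A_2|^2 + ...: no interference between the contributions."""
    return np.sum(np.abs(amplitudes) ** 2, axis=0)


# Two toy amplitudes of equal magnitude with a varying relative phase
phases = np.array([0.0, np.pi / 2, np.pi])
toy_amplitudes = np.array([np.ones_like(phases), np.exp(1j * phases)])

print(coherent_intensity(toy_amplitudes))  # approximately [4, 2, 0]
print(incoherent_intensity(toy_amplitudes))  # [2, 2, 2]
```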

To simplify things, we focus on these two resonances only:

In [None]:
reaction = qrules.generate_transitions(
    initial_state="D0",
    final_state=["K-", "K+", "K0"],
    allowed_intermediate_particles=["a(0)(980)0", "a(2)(1320)+"],
    formalism="helicity",
)
render_side_by_side(reaction, render_final_state_id=True)

In this example, we use the **helicity formalism** with the following dynamics:
- 𝑎₂(1320) lies well within phase space: relativistic Breit-Wigner with form factor
- 𝑎₀(980) lies below the K⁻K⁺ threshold: Breit-Wigner with an **analytic continuation** of the phase space factor
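Both line shapes can be sketched in plain NumPy under one common set of conventions, which are assumptions here: ρ(s) = 2q(s)/√s, an energy-dependent width Γ(s) = Γ₀ · B²_L(q)/B²_L(q₀) · ρ(s)/ρ(m₀²), normalized Blatt-Weisskopf factors with z = (q·d)², and a meson radius d = 1 GeV⁻¹. AmpForm's own expressions may differ in normalization and in how the continuation is done. The key point is that a complex square root turns the negative q² below threshold into an imaginary ρ(s) instead of NaN.

```python
import numpy as np

M_K_CHARGED = 0.493677  # GeV (PDG value)
M_K_NEUTRAL = 0.497611  # GeV (PDG value)


def breakup_momentum_squared(s, m_a, m_b):
    """q^2(s) for a decay into masses m_a and m_b; negative below threshold."""
    return (s - (m_a + m_b) ** 2) * (s - (m_a - m_b) ** 2) / (4 * s)


def phase_space_factor(s, m_a, m_b):
    """rho(s) = 2 q(s) / sqrt(s), evaluated with a complex square root.

    Below threshold q^2 < 0, so instead of NaN the complex square root
    yields a purely imaginary rho(s): an analytic continuation.
    """
    q_squared = np.asarray(breakup_momentum_squared(s, m_a, m_b), dtype=complex)
    return 2 * np.sqrt(q_squared) / np.sqrt(s)


def blatt_weisskopf_squared(z, angular_momentum):
    """Normalized Blatt-Weisskopf barrier factor B_L^2(z) with z = (q d)^2."""
    if angular_momentum == 0:
        return np.ones_like(z)
    if angular_momentum == 1:
        return 2 * z / (z + 1)
    if angular_momentum == 2:
        return 13 * z**2 / ((z - 3) ** 2 + 9 * z)
    raise NotImplementedError("this sketch only covers L <= 2")


def breit_wigner_with_ff(s, m0, gamma0, m_a, m_b, angular_momentum, d=1.0):
    """Relativistic Breit-Wigner with energy-dependent width and form factor."""
    z = breakup_momentum_squared(s, m_a, m_b) * d**2
    z0 = breakup_momentum_squared(m0**2, m_a, m_b) * d**2
    ff = blatt_weisskopf_squared(z, angular_momentum)
    ff0 = blatt_weisskopf_squared(z0, angular_momentum)
    rho_ratio = phase_space_factor(s, m_a, m_b) / phase_space_factor(
        m0**2, m_a, m_b
    )
    width = gamma0 * ff / ff0 * rho_ratio
    return m0 * gamma0 * np.sqrt(ff) / (m0**2 - s - 1j * m0 * width)


# a(2)(1320)+ -> K+ K0 with L = 2 (approximate PDG mass and width in GeV)
sqrt_s = np.linspace(1.0, 1.6, 601)
line_shape = breit_wigner_with_ff(
    sqrt_s**2,
    m0=1.3182,
    gamma0=0.107,
    m_a=M_K_CHARGED,
    m_b=M_K_NEUTRAL,
    angular_momentum=2,
)
print("peak near", sqrt_s[np.argmax(np.abs(line_shape))], "GeV")

# a(0)(980)0 -> K- K+: the nominal mass lies below the K- K+ threshold,
# so the phase space factor is purely imaginary there
print(phase_space_factor(0.98**2, M_K_CHARGED, M_K_CHARGED))
```

The energy-dependent width and the form factor shift the peak of |BW| slightly away from m₀, which is why the printed peak position is only close to 1.318 GeV.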

In [None]:
builder = ampform.get_builder(reaction)
builder.set_dynamics("a(2)(1320)+", create_relativistic_breit_wigner_with_ff)
builder.set_dynamics("a(0)(980)0", create_analytic_breit_wigner)
model = builder.formulate()

The model is formulated with SymPy (a computer algebra system), so the math can be inspected:

In [None]:
assert len(model.components) == 3
amplitude_to_symbol = {
    expression: sp.Symbol(name)
    for name, expression in model.components.items()
    if not name.startswith("I")
}
top_expression = model.expression.subs(amplitude_to_symbol)
sym_a980, sym_a1320, sym_intensity = tuple(
    map(sp.Symbol, sorted(model.components))
)
Math(sp.latex(sym_intensity) + "=" + sp.latex(top_expression))

In [None]:
amplitudes = tuple(amplitude_to_symbol)
amp_a980, amp_a1320 = amplitudes
amp_a1320_case1 = sp.piecewise_fold(amp_a1320).args[1][0]
Math(sp.latex(sym_a1320) + "=" + sp.latex(amp_a1320_case1))

...and after substituting the default parameter values suggested by `AmpForm`:

In [None]:
extended_parameter_defaults = dict(model.parameter_defaults)
for state_id, particle in reaction.final_state.items():
    var_name = f"m_{state_id}"
    symbol = sp.Symbol(var_name, real=True)
    extended_parameter_defaults[symbol] = particle.mass
amp_a1320_case1_evaluated = round_floats(
    amp_a1320_case1.doit().subs(extended_parameter_defaults),
    ndigits=3,
)
Math("=" + sp.latex(sp.piecewise_fold(amp_a1320_case1_evaluated)))
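As a toy illustration of inspecting and substituting with standard SymPy calls, here is a hand-written, constant-width Breit-Wigner. It is a deliberate simplification, not the expression AmpForm formulates, and the parameter values are approximate PDG values for the 𝑎₂(1320):

```python
import sympy as sp

# Hypothetical constant-width Breit-Wigner, NOT AmpForm's expression
s, m0, gamma0 = sp.symbols("s m_0 Gamma_0", positive=True)
breit_wigner = m0 * gamma0 / (m0**2 - s - sp.I * m0 * gamma0)

# Inspect the math: LaTeX rendering and the symbols it depends on
print(sp.latex(breit_wigner))
print(breit_wigner.free_symbols)

# Substitute parameter values; only the kinematic variable s remains free
parameter_defaults = {m0: 1.3182, gamma0: 0.107}
evaluated = breit_wigner.subs(parameter_defaults)
print(evaluated)
```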

The math is fancy, but SymPy offers another major advantage:<br>
_the amplitude model is an **expression tree** that can be converted to **several computational back-ends**!_

In [None]:
dot = sp.dotprint(amp_a1320_case1_evaluated)
filename = "expression_tree"
graphviz.Source(dot).render(filename, format="svg")
HTML(f'<img src="{filename}.svg" style="width:1300px">')
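Both properties can be illustrated with standard SymPy functions on a hypothetical toy intensity (a squared constant-width Breit-Wigner, not the model above): `preorder_traversal` walks the expression tree, and `lambdify` converts the same tree into a function for a numerical back-end. Here the target is NumPy; this kind of conversion is what lets one symbolic model run on back-ends such as NumPy, JAX, or TensorFlow.

```python
import numpy as np
import sympy as sp

s, m0, gamma0 = sp.symbols("s m_0 Gamma_0", positive=True)
# Hypothetical toy intensity: squared constant-width Breit-Wigner
intensity_expr = sp.Abs(m0 * gamma0 / (m0**2 - s - sp.I * m0 * gamma0)) ** 2

# The expression is a tree: walk it and collect the leaf symbols
leaf_symbols = {
    node
    for node in sp.preorder_traversal(intensity_expr)
    if isinstance(node, sp.Symbol)
}

# The same tree converted into a vectorized NumPy function
toy_intensity_function = sp.lambdify(
    (s, m0, gamma0), intensity_expr, modules="numpy"
)
s_values = np.linspace(1.0, 2.5, 7)
print(toy_intensity_function(s_values, 1.3182, 0.107))
```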

## 3. [TensorWaves](https://tensorwaves.rtfd.io): generate toy MC and optimize model

As an example, here is a data sample generated from the amplitude model, which is evaluated with the **JAX** back-end:

In [None]:
sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
initial_state_mass = reaction.initial_state[-1].mass
final_state_masses = {i: p.mass for i, p in reaction.final_state.items()}

phsp_sample = generate_phsp(100_000, initial_state_mass, final_state_masses)
data_sample = generate_data(
    2_000, initial_state_mass, final_state_masses, data_converter, intensity
)

phsp_set = data_converter.transform(phsp_sample)
data_set = data_converter.transform(data_sample)
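The usual principle behind turning an intensity into events is accept-reject ("hit-and-miss") sampling: draw uniform candidates, then keep each one with probability proportional to the intensity. TensorWaves works with multidimensional phase-space events, so the one-dimensional sketch below only illustrates the idea; `hit_and_miss` and `toy_intensity` are hypothetical helpers, and estimating the maximum on a grid with a 5% margin is an assumption that suits smooth 1D functions.

```python
import numpy as np


def hit_and_miss(intensity, n_events, low, high, rng, batch_size=10_000):
    """Sample n_events values of x in [low, high] distributed like intensity(x).

    Uniform candidates x are accepted with probability intensity(x) / max_value.
    """
    # Estimate the maximum on a dense grid, with a 5% safety margin
    grid = np.linspace(low, high, 10_001)
    max_value = 1.05 * intensity(grid).max()
    accepted = []
    n_accepted = 0
    while n_accepted < n_events:
        x = rng.uniform(low, high, size=batch_size)
        y = rng.uniform(0.0, max_value, size=batch_size)
        selected = x[y < intensity(x)]
        accepted.append(selected)
        n_accepted += selected.size
    return np.concatenate(accepted)[:n_events]


def toy_intensity(m, m0=1.3182, gamma0=0.107):
    """Hypothetical toy model: squared constant-width Breit-Wigner in mass m."""
    return np.abs(m0 * gamma0 / (m0**2 - m**2 - 1j * m0 * gamma0)) ** 2


rng = np.random.default_rng(seed=0)
toy_sample = hit_and_miss(toy_intensity, 2_000, low=1.0, high=1.6, rng=rng)
print(toy_sample.size, np.median(toy_sample))
```

The acceptance rate equals the average intensity divided by `max_value`, so sharply peaked models waste many candidates; this is one reason real generators sample from phase space first and weight afterwards.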

In [None]:
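def show_distribution(var, ax, bin_width=0.01):
    sample = data_set[var]
    # number of bins that gives (approximately) the requested bin width in GeV
    n_bins = int((sample.max() - sample.min()) / bin_width)
    # density=True normalizes the histogram to unit area
    ax.hist(sample, bins=n_bins, alpha=0.5, density=True)
    ax.set_ylabel(f"normalized events / {round(bin_width * 1e3)} MeV")
    indicate_masses(reaction, ax=ax)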


fig, axes = plt.subplots(
    nrows=2, figsize=(12, 8), sharex=True, tight_layout=True
)
show_distribution(var="m_02", ax=axes[0])
show_distribution(var="m_12", ax=axes[1])
final_state = reaction.final_state
axes[0].set_xlabel(
    f"$m_{{{final_state[0].latex}{final_state[2].latex}}}$ [GeV]"
)
axes[1].set_xlabel(
    f"$m_{{{final_state[1].latex}{final_state[2].latex}}}$ [GeV]"
)
axes[0].legend();
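Because `show_distribution` passes `density=True`, each histogram has unit area, so the y values are a probability density per GeV rather than raw counts per bin. Rounding `n_bins` down also makes the actual bins slightly wider than `bin_width`. A toy NumPy sketch, using a Gaussian stand-in sample (hypothetical values, not the generated data), shows how to recover event counts from the density:

```python
import numpy as np

# Toy sample standing in for a column of data_set (hypothetical values)
rng = np.random.default_rng(seed=42)
sample = rng.normal(loc=1.32, scale=0.05, size=2_000)

bin_width = 0.01  # GeV, i.e. the 10 MeV requested in show_distribution
n_bins = int((sample.max() - sample.min()) / bin_width)
density, edges = np.histogram(sample, bins=n_bins, density=True)
# Rounding n_bins down makes the bins at least as wide as bin_width
actual_widths = np.diff(edges)

# Convert the unit-area density back to event counts per bin
events_per_bin = density * sample.size * actual_widths
counts, _ = np.histogram(sample, bins=edges)
print(actual_widths[0], events_per_bin[:5], counts[:5])
```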