```{autolink-concat}
```

::::{margin}
:::{card} Coupled-channel fit with $P$-vector dynamics for a single pole
TR-032
^^^
Illustration of how to formulate an amplitude model for two channels with $P$-vector dynamics. A combined fit is performed by minimizing the sum of the negative log-likelihoods of both distributions. The example uses a single pole, but can easily be extended to multiple poles.
+++
🚧&nbsp;[compwa.github.io#278](https://github.com/ComPWA/compwa.github.io/pull/278)
:::
::::

# P-vector fit comparison

In [None]:
%pip install -q 'qrules[viz]==0.10.2' 'tensorwaves[jax,phsp]==0.4.12' ampform==0.15.4 pandas==2.2.2 sympy==1.12

In [None]:
from __future__ import annotations

import re
from collections import defaultdict
from functools import lru_cache
from itertools import product
from typing import Any, Iterable, Mapping

import ampform
import attrs
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import qrules
import sympy as sp
from ampform.dynamics.builder import TwoBodyKinematicVariableSet
from ampform.helicity import HelicityModel, ParameterValues
from ampform.io import aslatex
from ampform.kinematics.phasespace import Kallen
from ampform.sympy import perform_cached_doit, unevaluated
from attrs import define, field, frozen
from IPython.display import Math, display
from matplotlib import cm
from qrules.particle import Particle, ParticleCollection
from qrules.transition import ReactionInfo
from sympy import Abs
from sympy.matrices.expressions.matexpr import MatrixElement
from tensorwaves.data import (
    IntensityDistributionGenerator,
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
    TFWeightedPhaseSpaceGenerator,
)
from tensorwaves.estimator import UnbinnedNLL
from tensorwaves.function.sympy import create_parametrized_function
from tensorwaves.interface import DataSample, Estimator, Function, ParameterValue
from tensorwaves.optimizer import Minuit2
from tensorwaves.optimizer.callbacks import CSVSummary

# Silence NumPy "invalid value" floating-point warnings (e.g. NaN from sqrt/log of
# negative numbers), presumably triggered when evaluating the dynamics near or below
# threshold. The previous error settings returned by np.seterr are discarded.
_ = np.seterr(invalid="ignore")

In [None]:
%config InlineBackend.figure_formats = ['svg']

## Studied decay

In [None]:
@lru_cache(maxsize=1)
def create_particle_database() -> ParticleCollection:
    """Return the default qrules particle database with its N resonances replaced.

    Every particle whose name starts with ``"N"`` is removed, and a single custom
    resonance created with :func:`create_nstar` is added. The result is cached, so
    repeated calls return the same collection object.
    """
    database = qrules.load_default_particles()
    nucleon_like = database.filter(lambda particle: particle.name.startswith("N"))
    for resonance in nucleon_like:
        database.remove(resonance)
    custom_resonance = create_nstar(mass=1.82, width=0.6, parity=+1, spin=1.5, idx=1)
    database += custom_resonance
    return database


def create_nstar(
    mass: float, width: float, parity: int, spin: float, idx: int
) -> Particle:
    """Create a custom positively charged nucleon resonance.

    Args:
        mass: Mass of the resonance.
        width: Width of the resonance.
        parity: Intrinsic parity; a positive value means ``+``, anything else ``-``.
        spin: Spin as a (half-)integer number, e.g. ``1.5``.
        idx: Index used in the name, LaTeX label and PID. It must be in 0–9,
            because it is rendered as a single Unicode subscript digit.

    Returns:
        A :class:`~qrules.particle.Particle` with charge +1, baryon number 1 and
        isospin (1/2, +1/2).
    """
    spin_fraction = sp.Rational(spin)
    is_positive = parity > 0
    parity_symbol = "⁺" if is_positive else "⁻"
    unicode_subscripts = list("₀₁₂₃₄₅₆₇₈₉")
    return Particle(
        name=f"N{unicode_subscripts[idx]}({spin_fraction}{parity_symbol})",
        # NOTE(review): the LaTeX label hard-codes "^-" regardless of `parity`. It is
        # left unchanged because parameter names used elsewhere in this notebook
        # (e.g. "m_{N_1(3/2^-)}") are built from this label.
        latex=Rf"N_{idx}({spin_fraction.numerator}/{spin_fraction.denominator}^-)",
        # Same sign test as the name, so PID and name agree on the parity.
        pid=2024_05_00_00 + 100 * int(is_positive) + idx,
        mass=mass,
        width=width,
        baryon_number=1,
        charge=+1,
        isospin=(0.5, +0.5),
        parity=parity,
        # Previously hard-coded to 1.5, silently ignoring the `spin` argument.
        spin=float(spin_fraction),
    )

In [None]:
# Final states of the two channels; the entries are lists, so annotate them as such.
FINAL_STATES: list[list[str]] = [
    ["K0", "Sigma+", "p~"],
    ["eta", "p", "p~"],
]
# One set of helicity-formalism transitions from J/psi(1S) per channel, with the
# allowed intermediate particles restricted to "N" in the custom particle database.
REACTIONS: list[ReactionInfo] = [
    qrules.generate_transitions(
        initial_state="J/psi(1S)",
        final_state=final_state,
        allowed_intermediate_particles=["N"],
        allowed_interaction_types=["strong"],
        formalism="helicity",
        particle_db=create_particle_database(),
    )
    for final_state in FINAL_STATES
]

In [None]:
# Render the collapsed transition graphs of every channel.
for reaction in REACTIONS:
    dot_source = qrules.io.asdot(reaction, collapse_graphs=True)
    display(graphviz.Source(dot_source))
    del dot_source, reaction

## Amplitude builder

In [None]:
@define
class DynamicsSymbolBuilder:
    """Dynamics builder that inserts a placeholder symbol instead of a lineshape.

    Resonances with the same :math:`J^P`, charge and (if set) angular momentum share
    one placeholder symbol :math:`X`. The builder records which (resonance, kinematic
    variable set) pairs each symbol stands for, so that the symbols can later be
    replaced by the actual coupled-channel dynamics expressions.
    """

    # Placeholder symbol -> set of (resonance, kinematic variable set) pairs it was
    # created for. (Previously mis-annotated as a `set`.)
    collected_symbols: dict[
        sp.Symbol, set[tuple[Particle, TwoBodyKinematicVariableSet]]
    ] = field(factory=lambda: defaultdict(set))

    def __call__(
        self, resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
    ) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
        """Return the placeholder symbol for `resonance` and no parameter defaults.

        The symbol name contains the resonance's :math:`J^P`, its charge and, if
        ``variable_pool.angular_momentum`` is set, the angular momentum ``l``.
        """
        jp = render_jp(resonance)
        charge = resonance.charge
        if variable_pool.angular_momentum is not None:
            L = sp.Rational(variable_pool.angular_momentum)
            X = sp.Symbol(Rf"X_{{{jp}, Q={charge:+d}}}^{{l={L}}}")
        else:
            X = sp.Symbol(Rf"X_{{{jp}, Q={charge:+d}}}")
        self.collected_symbols[X].add((resonance, variable_pool))
        parameter_defaults = {}
        return X, parameter_defaults


def render_jp(particle: Particle) -> str:
    """Render the spin (and parity, if defined) of a particle as a LaTeX string.

    Returns ``J=<spin>`` when the parity is undefined and ``J^P={<spin>}^{<sign>}``
    otherwise. Non-integer spins are rendered as a LaTeX fraction.
    """
    spin = sp.Rational(particle.spin)
    if spin.denominator != 1:
        j = Rf"\frac{{{spin.numerator}}}{{{spin.denominator}}}"
    else:
        j = str(spin)
    if particle.parity is not None:
        p = "-" if particle.parity < 0 else "+"
        return f"J^P={{{j}}}^{{{p}}}"
    return f"J={j}"

In [None]:
# Formulate one helicity model per channel. Every intermediate resonance gets a
# placeholder dynamics symbol from a DynamicsSymbolBuilder instead of a lineshape.
MODELS: list[HelicityModel] = []
for reaction in REACTIONS:
    builder = ampform.get_builder(reaction)
    builder.adapter.permutate_registered_topologies()
    builder.config.scalar_initial_state_mass = True
    builder.config.stable_final_state_ids = [0, 1, 2]
    # NOTE(review): a new symbol builder is created per reaction, so after this loop
    # `create_dynamics_symbol` only holds the symbols of the *last* reaction. Later
    # cells read its `collected_symbols` for all channels.
    create_dynamics_symbol = DynamicsSymbolBuilder()
    for resonance in reaction.get_intermediate_particles():
        builder.set_dynamics(resonance.name, create_dynamics_symbol)
    MODELS.append(builder.formulate())
    # NOTE(review): this raises NameError if a reaction has no intermediate
    # particles, because `resonance` is then unbound.
    del builder, reaction, resonance

In [None]:
# Show only the first amplitude of the first channel as an example.
selected_amplitudes = dict(list(MODELS[0].amplitudes.items())[:1])
Math(aslatex(selected_amplitudes, terms_per_line=1))

In [None]:
# Tabulate each placeholder dynamics symbol with the resonance(s) it stands for.
# NOTE(review): the loop variable `resonances` stays bound after this cell and is
# consumed by the K-matrix and P-vector cells below.
table_rows = [R"\begin{array}{cll}"]
for symbol, resonances in create_dynamics_symbol.collected_symbols.items():
    table_rows.append(Rf"  {symbol} \\")
    for p, _ in resonances:
        table_rows.append(
            Rf"  {p.latex} & m={p.mass:g}\text{{ GeV}} & \Gamma={p.width:g}\text{{ GeV}} \\"
        )
table_rows.append(R"\end{array}")
src = "\n".join(table_rows)
Math(src)

## Dynamics parametrization

### Phasespace factor

In [None]:
@unevaluated(real=False)
class PhaseSpaceCM(sp.Expr):
    r"""Phase-space factor built from the Chew-Mandelstam function.

    Evaluates to :math:`-16\pi i\,\Sigma(s)` for a two-body channel with masses
    ``m1`` and ``m2``, where :math:`\Sigma` is :class:`ChewMandelstam`.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"\rho^\mathrm{{CM}}_{{{m1},{m2}}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return -16 * sp.pi * sp.I * ChewMandelstam(s, m1, m2)


@unevaluated(real=False)
class ChewMandelstam(sp.Expr):
    r"""Chew-Mandelstam function :math:`\Sigma(s)` for a two-body channel.

    Uses the breakup momentum :math:`q(s)` of :class:`BreakupMomentum` and is
    normalized by :math:`1/(16\pi^2)`.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"\Sigma\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        # NOTE(review): Abs() inside the logarithm discards its phase. For real
        # s > (m1 + m2)**2 the log argument is negative, so without Abs the log
        # contributes i*pi. After multiplication by -16*pi*i in PhaseSpaceCM, that
        # term gives the real phase space 2*q/sqrt(s). Verify that dropping it is
        # intended.
        return (
            (2 * q / sp.sqrt(s))
            * sp.log(Abs((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2)))
            - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
        ) / (16 * sp.pi**2)


@unevaluated(real=False)
class BreakupMomentum(sp.Expr):
    r"""Two-body breakup momentum :math:`q(s) = \sqrt{\lambda(s, m_1^2, m_2^2)} / (2\sqrt{s})`.

    :math:`\lambda` is the Källén function. For :math:`(m_1-m_2)^2 < s < (m_1+m_2)^2`
    the Källén function is negative, so :math:`q` becomes imaginary.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"q\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt(Kallen(s, m1**2, m2**2)) / (2 * sp.sqrt(s))

In [None]:
# Show the one-level-deep definition of each phase-space building block.
s, m1, m2 = sp.symbols("s m1 m2", nonnegative=True)
exprs = [cls(s, m1, m2) for cls in (PhaseSpaceCM, ChewMandelstam, BreakupMomentum)]
definitions = {expr: expr.doit(deep=False) for expr in exprs}
Math(aslatex(definitions))

### $K$-matrix formalism

In [None]:
# Matrix symbols of the coupled-channel formalism, sized by the number of channels.
n_channels = len(REACTIONS)
square_shape = (n_channels, n_channels)
column_shape = (n_channels, 1)
I = sp.Identity(n_channels)
K = sp.MatrixSymbol("K", *square_shape)
rho = sp.MatrixSymbol("rho", *square_shape)
P = sp.MatrixSymbol("P", *column_shape)
F = sp.MatrixSymbol("F", *column_shape)

In [None]:
def get_decay_products(reaction: ReactionInfo) -> DecayProducts:
    """Return the two final-state particles emerging from node 1 of the topology.

    All transitions of `reaction` must share the same outgoing edges at node 1;
    otherwise a :class:`ValueError` is raised. Unpacking the sorted edge IDs also
    raises :class:`ValueError` if node 1 does not have exactly two outgoing edges.
    """
    first_transition, *_ = reaction.transitions
    expected_ids = first_transition.topology.get_edge_ids_outgoing_from_node(1)
    has_other_subsystem = any(
        transition.topology.get_edge_ids_outgoing_from_node(1) != expected_ids
        for transition in reaction.transitions
    )
    if has_other_subsystem:
        msg = "Reaction contains multiple sub-systems"
        raise ValueError(msg)
    low_id, high_id = sorted(expected_ids)
    final_state = reaction.final_state
    return DecayProducts(child1=final_state[low_id], child2=final_state[high_id])


@frozen
class DecayProducts:
    """Immutable pair of final-state particles forming a channel's two-body subsystem."""

    child1: Particle  # particle with the lower edge ID (as filled by get_decay_products)
    child2: Particle  # particle with the higher edge ID


# Two-body decay products of the resonance subsystem, one entry per channel model.
DECAYS = tuple(
    get_decay_products(channel_model.reaction_info) for channel_model in MODELS
)

In [None]:
# Merge the default parameter values of all channel models (later channels win).
PARAMETERS_DEFAULTS = {
    symbol: value
    for channel_model in MODELS
    for symbol, value in channel_model.parameter_defaults.items()
}

#### $K$-matrix parametrization

In [None]:
def formulate_k_matrix(
    resonances: Iterable[tuple[Particle, Any]], n_channels: int
) -> dict[MatrixElement, sp.Expr]:
    r"""Formulate the elements of a pole-only, symmetric K-matrix.

    Every distinct resonance :math:`R` contributes the pole term
    :math:`g_{R,i}\,g_{R,j} / (m_R^2 - s)` to element :math:`K_{ij}`, with
    :math:`s = m_{01}^2`. Default values for the pole masses (the particle masses)
    and couplings (0.1) are registered in the module-level ``PARAMETERS_DEFAULTS``.

    Args:
        resonances: Pairs whose first element is a resonance; the second element
            (e.g. a kinematic variable set) is ignored. A resonance that occurs in
            several pairs contributes only one pole.
        n_channels: Number of coupled channels, i.e. the size of the K-matrix.

    Returns:
        Mapping from every matrix element ``K[i, j]`` to its expression.
    """
    s = sp.Symbol("m_01", real=True) ** 2
    # The same particle can be paired with several kinematic variable sets, but it is
    # a single K-matrix pole, so deduplicate (keeping first-seen order).
    unique_resonances = list(dict.fromkeys(res for res, _ in resonances))
    poles = []  # (mass symbol, per-channel coupling symbols) per resonance
    for res in unique_resonances:
        m_R = sp.Symbol(Rf"m_{{{res.latex}}}")
        g_R = [sp.Symbol(Rf"g_{{{res.latex},{c}}}") for c in range(n_channels)]
        PARAMETERS_DEFAULTS[m_R] = res.mass
        PARAMETERS_DEFAULTS.update(dict.fromkeys(g_R, 0.1))
        poles.append((m_R, g_R))
    expressions = {}
    for i, j in product(range(n_channels), repeat=2):
        expressions[K[i, j]] = sp.Add(
            *((g_R[i] * g_R[j]) / (m_R**2 - s) for m_R, g_R in poles)
        )
    return expressions


# NOTE(review): `resonances` is not defined in this cell; it is the loop variable
# left bound by the cell that tabulates the dynamics symbols.
K_matrix = K.as_explicit()
K_expressions = formulate_k_matrix(resonances, n_channels=len(REACTIONS))
K_matrix.xreplace(K_expressions)

#### $P$-vector parametrization

In [None]:
def formulate_p_vector(
    resonances: Iterable[tuple[Particle, Any]], n_channels: int
) -> dict[MatrixElement, sp.Expr]:
    r"""Formulate the elements of a pole-only production vector :math:`P`.

    Every distinct resonance :math:`R` contributes
    :math:`\beta_R\,g_{R,i} / (m_R^2 - s)` to element :math:`P_i`, with
    :math:`s = m_{01}^2`. Defaults for the masses (the particle masses), the
    production strengths :math:`\beta_R` (``1 + 0j``) and the couplings (1) are
    registered in the module-level ``PARAMETERS_DEFAULTS``.

    Args:
        resonances: Pairs whose first element is a resonance; the second element
            is ignored. A resonance that occurs in several pairs contributes once.
        n_channels: Number of coupled channels, i.e. the length of the vector.

    Returns:
        Mapping from every matrix element ``P[i, 0]`` to its expression.
    """
    s = sp.Symbol("m_01", real=True) ** 2
    # Deduplicate by particle so that each resonance is a single pole.
    unique_resonances = list(dict.fromkeys(res for res, _ in resonances))
    poles = []  # (mass symbol, beta symbol, per-channel coupling symbols)
    for res in unique_resonances:
        m_R = sp.Symbol(Rf"m_{{{res.latex}}}")
        beta_R = sp.Symbol(Rf"\beta_{{{res.latex}}}")
        g_R = [sp.Symbol(Rf"g_{{{res.latex},{c}}}") for c in range(n_channels)]
        PARAMETERS_DEFAULTS[m_R] = res.mass
        PARAMETERS_DEFAULTS[beta_R] = 1 + 0j
        PARAMETERS_DEFAULTS.update(dict.fromkeys(g_R, 1))
        poles.append((m_R, beta_R, g_R))
    return {
        P[i, 0]: sp.Add(
            *((beta_R * g_R[i]) / (m_R**2 - s) for m_R, beta_R, g_R in poles)
        )
        for i in range(n_channels)
    }


# NOTE(review): `resonances` is the loop variable left bound by the cell that
# tabulates the dynamics symbols; it is not defined in this cell.
P_vector = P.as_explicit()
P_expressions = formulate_p_vector(resonances, n_channels=len(REACTIONS))
P_vector.xreplace(P_expressions)

#### Phase space factor parametrization

In [None]:
def formulate_phsp_factor_matrix(n_channels: int) -> dict[MatrixElement, sp.Expr]:
    r"""Formulate the elements of the diagonal phase-space factor matrix ``rho``.

    Element ``rho[i, i]`` is :class:`PhaseSpaceCM` at :math:`s = m_{01}^2` with the
    masses :math:`m_{0,i}` and :math:`m_{1,i}` of channel ``i``; every off-diagonal
    element is zero. Default values for these masses are taken from ``DECAYS[i]``
    and registered in the module-level ``PARAMETERS_DEFAULTS``.

    Args:
        n_channels: Number of coupled channels; ``DECAYS`` must have at least this
            many entries.

    Returns:
        Mapping from every matrix element ``rho[i, j]`` to its expression.
    """
    s = sp.Symbol("m_01", real=True) ** 2
    expressions = {}
    for i, j in product(range(n_channels), repeat=2):
        if i != j:
            expressions[rho[i, j]] = 0
            continue
        m_a_i = sp.Symbol(Rf"m_{{0,{i}}}")
        m_b_i = sp.Symbol(Rf"m_{{1,{i}}}")
        expressions[rho[i, j]] = PhaseSpaceCM(s, m_a_i, m_b_i)
        PARAMETERS_DEFAULTS.update({
            m_a_i: DECAYS[i].child1.mass,
            m_b_i: DECAYS[i].child2.mass,
        })
    return expressions


# Build the phase-space matrix parametrization and show it as an explicit matrix.
rho_expressions = formulate_phsp_factor_matrix(n_channels=len(REACTIONS))
rho_matrix = rho.as_explicit()
rho_matrix.xreplace(rho_expressions)

### $F$-vector construction

:::{note}
For some reason one has to leave out the multiplication of $\rho$ by $i$ within the calculation of the $F$ vector
:::

In [None]:
# Symbolic production vector F = (1 - i K rho)^(-1) P. This rebinds `F`, which
# previously held the plain F MatrixSymbol.
# NOTE(review): the note above says the factor i has to be left out, but this
# expression does include `sp.I`; confirm which form is intended.
F = (I - sp.I * K * rho).inv() * P
F

In [None]:
# Expand the symbolic matrix product into an explicit (n_channels x 1) matrix.
F_vector = F.as_explicit()

In [None]:
# Substitute the K-matrix, phase-space and P-vector parametrizations into F.
parametrizations = K_expressions | rho_expressions | P_expressions
F_exprs = F_vector.xreplace(parametrizations)
F_exprs[0].simplify(doit=False)

In [None]:
# Fully evaluate every F-vector element (cached) and keep them in an object array.
F_unfolded_exprs = np.array(list(map(perform_cached_doit, F_exprs)))

In [None]:
# Per channel, map every placeholder dynamics symbol to that channel's F element.
DYNAMICS_EXPRESSIONS_FVECTOR = []
for i in range(n_channels):
    exprs = dict.fromkeys(
        create_dynamics_symbol.collected_symbols, F_unfolded_exprs[i]
    )
    DYNAMICS_EXPRESSIONS_FVECTOR.append(exprs)

# Channel models whose defaults also contain all dynamics parameters.
MODELS_FVECTOR = [
    attrs.evolve(
        channel_model,
        parameter_defaults=ParameterValues({
            **channel_model.parameter_defaults,
            **PARAMETERS_DEFAULTS,
        }),
    )
    for channel_model in MODELS[:n_channels]
]

In [None]:
# Insert each channel's F-vector dynamics into its (evaluated) amplitude model.
FULL_EXPRESSIONS_FVECTOR = [
    perform_cached_doit(channel_model.expression).xreplace(substitutions)
    for channel_model, substitutions in zip(
        MODELS_FVECTOR, DYNAMICS_EXPRESSIONS_FVECTOR
    )
]

### Create numerical functions

In [None]:
# Fully unfold every channel expression and lambdify it with the JAX backend.
UNFOLDED_EXPRESSIONS_FVECTOR = [
    perform_cached_doit(full_expr) for full_expr in FULL_EXPRESSIONS_FVECTOR
]
INTENSITY_FUNCS_FVECTOR = [
    create_parametrized_function(
        expression=unfolded_expr,
        backend="jax",
        parameters=channel_model.parameter_defaults,
    )
    for unfolded_expr, channel_model in zip(
        UNFOLDED_EXPRESSIONS_FVECTOR, MODELS_FVECTOR
    )
]

## Update parameters

In [None]:
# Parameter values with which the pseudo-data below is generated.
# NOTE(review): `m_res` is not used anywhere in this notebook; the generation mass
# is the literal 1.71 below, while the N₁ database mass is 1.82. Confirm which is
# intended.
m_res = 1.82
g_res_ch0 = 1.8
g_res_ch1 = 2.5

new_parameters_fvector = {
    R"m_{N_1(3/2^-)}": 1.71,
    R"\beta_{N_1(3/2^-)}": 1 + 0j,
    R"g_{N_1(3/2^-),0}": g_res_ch0,
    R"g_{N_1(3/2^-),1}": g_res_ch1,
}

In [None]:
# Switch every channel's intensity function to the generation parameters.
for intensity_func in INTENSITY_FUNCS_FVECTOR[:n_channels]:
    intensity_func.update_parameters(new_parameters_fvector)

## Generate data with $F$ vector

### Generate phase space sample

In [None]:
# Transformers from four-momenta to each channel model's kinematic variables.
HELICITY_TRANSFORMERS = [
    SympyDataTransformer.from_sympy(channel_model.kinematic_variables, backend="jax")
    for channel_model in MODELS_FVECTOR[:n_channels]
]

In [None]:
# Phase-space sample (100 000 events) per channel, converted to kinematic variables.
# The real parts are kept, and invariant masses m_ij get a tiny imaginary shift
# `epsilon` (presumably to pick a branch of the complex sqrt/log in the dynamics).
PHSP = []
epsilon = 1e-8
invariant_mass_pattern = re.compile(r"^m_\d\d$")
for i in range(n_channels):
    # A fresh seed-0 generator per channel; `rng` stays bound after the loop and is
    # reused for the data generation below.
    rng = TFUniformRealNumberGenerator(seed=0)
    reaction_info = REACTIONS[i]
    phsp_generator = TFPhaseSpaceGenerator(
        initial_state_mass=reaction_info.initial_state[-1].mass,
        final_state_masses={
            edge_id: particle.mass
            for edge_id, particle in reaction_info.final_state.items()
        },
    )
    phsp_momenta = phsp_generator.generate(100_000, rng)
    phsp = {}
    for key, values in HELICITY_TRANSFORMERS[i](phsp_momenta).items():
        real_values = values.real
        if invariant_mass_pattern.match(key):
            phsp[key] = real_values + epsilon * 1j
        else:
            phsp[key] = real_values
    PHSP.append(phsp)

### Dynamics expressions

In [None]:
# Every dynamics symbol of a channel maps to the same F element, so the first
# value of each mapping is representative.
DYNAMICS_EXPR_FVECTOR = []
for i in range(n_channels):
    first_dynamics_expr, *_ = DYNAMICS_EXPRESSIONS_FVECTOR[i].values()
    DYNAMICS_EXPR_FVECTOR.append(first_dynamics_expr)

In [None]:
# Numerical (JAX) functions for the bare F-vector dynamics of each channel.
DYNAMICS_FUNCS_FVECTOR = [
    create_parametrized_function(
        expression=dynamics_expr.doit(),
        backend="jax",
        parameters=channel_model.parameter_defaults,
    )
    for dynamics_expr, channel_model in zip(DYNAMICS_EXPR_FVECTOR, MODELS_FVECTOR)
]

### Weighted data with $F$ vector 

In [None]:
# Plot the phase-space sample weighted by the generation intensity versus m_01^2.
for i in range(n_channels):
    fig, ax = plt.subplots(figsize=(6, 5))
    intensity = np.real(INTENSITY_FUNCS_FVECTOR[i](PHSP[i]))
    ax.hist(
        np.real(PHSP[i]["m_01"]) ** 2,
        bins=100,
        weights=intensity,
    )
    # m_01 is the invariant mass of final-state particles 0 and 1 of *this* channel,
    # so label the axis with them instead of a hard-coded "eta p" for every channel.
    particle0 = REACTIONS[i].final_state[0]
    particle1 = REACTIONS[i].final_state[1]
    ax.set_xlabel(
        Rf"$M^2\left({particle0.latex} {particle1.latex}\right)\, \mathrm{{[(GeV/c)^2]}}$"
    )
    ax.set_ylabel(R"Intensity [a.u.]")
    fig.tight_layout()
    plt.show()

In [None]:
# Generate 50 000 pseudo-data events per channel from the F-vector intensities,
# drawing from weighted phase space with the `rng` left over from the phase-space cell.
DATA = []
for i in range(n_channels):
    reaction_info = MODELS[i].reaction_info
    weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
        initial_state_mass=reaction_info.initial_state[-1].mass,
        final_state_masses={
            edge_id: particle.mass
            for edge_id, particle in reaction_info.final_state.items()
        },
    )
    data_generator = IntensityDistributionGenerator(
        domain_generator=weighted_phsp_generator,
        function=INTENSITY_FUNCS_FVECTOR[i],
        domain_transformer=HELICITY_TRANSFORMERS[i],
    )
    data_momenta = data_generator.generate(50_000, rng)
    # The discarded DataFrame and the transform of the last channel's `phsp_momenta`
    # with this channel's transformer were dead code and have been removed.
    # The name `data` is kept because a later cell may read it after the loop.
    data = HELICITY_TRANSFORMERS[i](data_momenta)
    DATA.append(data)

In [None]:
# Histogram the generated m_01 distribution of every channel and mark the
# generation mass(es).
for i in range(n_channels):
    # One color per function parameter, sampled evenly from the rainbow colormap.
    # (The unused `resonances` assignment, which overwrote the global consumed by the
    # K-matrix/P-vector cells, has been dropped.)
    n_parameters = len(INTENSITY_FUNCS_FVECTOR[i].parameters)
    colors = [cm.rainbow(x) for x in np.linspace(0, 1, n_parameters)]
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.hist(
        np.real(DATA[i]["m_01"]),
        bins=200,
        alpha=0.5,
        density=True,
    )
    ax.set_xlabel("$m$ [GeV]")
    for (name, value), color in zip(new_parameters_fvector.items(), colors):
        if name.startswith("m_{"):
            ax.axvline(
                x=value,
                linestyle="dotted",
                label=r"$" + name + "$",
                color=color,
            )
    ax.legend()
    plt.show()

## Perform fit

### Set initial parameters 

In [None]:
# Starting values for the fit. The mass and couplings differ from the generation
# values in `new_parameters_fvector`; beta starts at its generation value.
initial_parameters = {
    R"m_{N_1(3/2^-)}": 1.9,
    R"\beta_{N_1(3/2^-)}": 1 + 0j,
    R"g_{N_1(3/2^-),0}": 2.8,
    R"g_{N_1(3/2^-),1}": 1.6,
}

In [None]:
def indicate_masses(ax, function, colors=None):
    """Mark the current resonance-mass parameters of `function` on `ax`.

    Sets the x-axis label and draws a dotted vertical line at the value of every
    parameter whose name starts with ``m_{N``.

    Args:
        ax: Matplotlib axes to draw on.
        function: Parametrized function whose ``parameters`` are inspected.
        colors: One color per parameter of `function`, in parameter order. If
            omitted, colors are sampled evenly from the rainbow colormap for this
            function, instead of relying on a global from another cell's loop.
    """
    parameters = function.parameters
    if colors is None:
        colors = [cm.rainbow(x) for x in np.linspace(0, 1, len(parameters))]
    ax.set_xlabel("$m$ [GeV]")
    for (name, value), color in zip(parameters.items(), colors):
        if name.startswith("m_{N"):
            ax.axvline(
                x=value,
                linestyle="dotted",
                label=r"$" + name + "$" "(F vector)",
                color=color,
            )


def compare_model(
    variable_name: str,
    data: DataSample,
    phsp: DataSample,
    function: Function[DataSample, np.ndarray],
    bins: int = 100,
):
    """Overlay a data histogram with the phase-space sample weighted by `function`.

    Both histograms are normalized (``density=True``). The resonance masses of
    `function` and the two-body threshold of every channel in ``DECAYS`` are marked
    with dotted vertical lines.

    Args:
        variable_name: Key of the kinematic variable to histogram, e.g. ``"m_01"``.
        data: Data sample to compare with.
        phsp: Phase-space sample on which `function` is evaluated.
        function: Intensity function whose values weight the `phsp` histogram.
        bins: Number of histogram bins.
    """
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.hist(
        data[variable_name].real,
        bins=bins,
        alpha=0.5,
        label="data",
        density=True,
    )
    # Take the real part, as the weighted-phase-space plot does, in case the complex
    # mass shift yields a complex-typed intensity array.
    intensities = np.real(function(phsp))
    ax.hist(
        phsp[variable_name].real,
        weights=intensities,
        bins=bins,
        histtype="step",
        color="red",
        label="Fit model with $F$ vector",
        density=True,
    )
    indicate_masses(ax, function)
    # One threshold line per channel instead of hard-coding DECAYS[0] and DECAYS[1].
    for decay in DECAYS:
        ax.axvline(
            decay.child1.mass + decay.child2.mass,
            color="grey",
            linestyle="dotted",
            label=rf"${decay.child1.latex} \, {decay.child2.latex}$ threshold",
        )
    ax.legend()
    fig.show()

In [None]:
# Store the generation parameters of every channel, switch each function to the
# initial fit parameters, and compare the resulting model with the data.
ORIGINAL_PARAMETERS_F = []
for i in range(n_channels):
    # (The unused `resonances` assignment, which overwrote the global consumed by the
    # K-matrix/P-vector cells, has been dropped.)
    evenly_spaced_interval = np.linspace(
        0, 1, len(INTENSITY_FUNCS_FVECTOR[i].parameters)
    )
    # Kept as a module-level name for `indicate_masses`.
    colors_F = [cm.rainbow(x) for x in evenly_spaced_interval]
    ORIGINAL_PARAMETERS_F.append(INTENSITY_FUNCS_FVECTOR[i].parameters)
    INTENSITY_FUNCS_FVECTOR[i].update_parameters(initial_parameters)
    compare_model("m_01", DATA[i], PHSP[i], INTENSITY_FUNCS_FVECTOR[i])

### Define estimator

In [None]:
class EstimatorSum(Estimator):
    """Estimator whose value is the sum of the values of several estimators.

    It is used to fit all channels at once by adding their estimator values.
    Analytic gradients are not supported.
    """

    def __init__(self, estimators: Iterable[Estimator]) -> None:
        # Materialize once, so that a generator can be passed and reused per call.
        self.__estimators = tuple(estimators)

    def __call__(self, parameters: Mapping[str, ParameterValue]) -> float:
        total = 0
        for estimator in self.__estimators:
            total = total + estimator(parameters)
        return total

    def gradient(
        self, parameters: Mapping[str, ParameterValue]
    ) -> dict[str, ParameterValue]:
        raise NotImplementedError

In [None]:
# One unbinned NLL per channel, summed into a single estimator for the combined fit.
channel_estimators = [
    UnbinnedNLL(intensity_func, data=data_sample, phsp=phsp_sample, backend="jax")
    for intensity_func, data_sample, phsp_sample in zip(
        INTENSITY_FUNCS_FVECTOR[:n_channels], DATA, PHSP
    )
]
combined_estimators = EstimatorSum(channel_estimators)

## Optimized fit

In [None]:
# Minimize the combined estimator with Minuit2 without analytic gradients
# (EstimatorSum does not implement them); a CSV summary goes to fit_traceback.csv.
fit_callback = CSVSummary("fit_traceback.csv")
minuit2 = Minuit2(callback=fit_callback, use_analytic_gradient=False)
fit_result = minuit2.optimize(combined_estimators, initial_parameters)
fit_result

In [None]:
# Apply the fitted parameters and overlay each channel's model with its data.
for intensity_func, data_sample, phsp_sample in zip(
    INTENSITY_FUNCS_FVECTOR[:n_channels], DATA, PHSP
):
    intensity_func.update_parameters(fit_result.parameter_values)
    compare_model("m_01", data_sample, phsp_sample, intensity_func)

In [None]:
# Table of initial, fitted and generation ("original") values per fit parameter.
original_parameters = {
    **ORIGINAL_PARAMETERS_F[0],
    **ORIGINAL_PARAMETERS_F[1],
}
comparison = {}
for name in fit_result.parameter_values:
    comparison[f"${name}$"] = (
        initial_parameters[name],
        fit_result.parameter_values[name],
        original_parameters[name],
    )
df = pd.DataFrame(comparison).T
df.columns = ("initial", "fit result", "original")
df.round(decimals=3)

In [None]:
# Inputs for the information criteria below.
# Number of real-valued fit parameters (complex parameters are counted twice).
n_real_par = fit_result.count_number_of_parameters(complex_twice=True)
# The combined estimator sums the likelihoods of *all* channels, so the sample size
# is the total number of events in DATA, not only the last channel's `data`.
n_events = sum(len(next(iter(channel_data.values()))) for channel_data in DATA)
log_likelihood = -fit_result.estimator_value
log_likelihood

In [None]:
# Akaike information criterion: AIC = 2k - 2 ln L.
aic = 2 * (n_real_par - log_likelihood)
aic

In [None]:
# Bayesian information criterion: BIC = k ln(n) - 2 ln L.
bic = -2 * log_likelihood + n_real_par * np.log(n_events)
bic