# Cross-check with LHCb data

```{autolink-concat}
```

In [None]:
from __future__ import annotations

import json
import logging
import os

import numpy as np
import sympy as sp
from IPython.display import Markdown, Math, display
from tqdm.auto import tqdm

from polarimetry.amplitude import simplify_latex_rendering
from polarimetry.data import create_data_transformer
from polarimetry.io import (
    as_latex,
    display_latex,
    mute_jax_warnings,
    perform_cached_doit,
    perform_cached_lambdify,
)
from polarimetry.lhcb import (
    get_conversion_factor,
    load_model,
    load_model_builder,
    parameter_key_to_symbol,
)
from polarimetry.lhcb.particle import load_particles

model_file = "../data/model-definitions.yaml"

# JAX warnings are muted before anything below loads the model.
mute_jax_warnings()
particles = load_particles("../data/particle-definitions.yaml")
model = load_model(model_file, particles, model_id=0)
simplify_latex_rendering()

# EXECUTE_NB in the environment signals a non-interactive notebook run:
# progress bars are disabled (see the `tqdm(..., disable=NO_TQDM)` calls)
# and the root logger only reports errors.
NO_TQDM = os.environ.get("EXECUTE_NB") is not None
if NO_TQDM:
    logging.root.setLevel(logging.ERROR)

In [None]:
# Reference values from LHCb used for the comparisons below. JSON text is
# UTF-8, so decode explicitly instead of relying on the platform locale codec.
with open("../data/crosscheck.json", encoding="utf-8") as stream:
    crosscheck_data = json.load(stream)

## Lineshape comparison

In [None]:
σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
# Kinematic point at which LHCb evaluated the reference lineshapes
# (keys m2kpi / m2pk — presumably squared invariant masses; TODO confirm).
# dict(...) makes the same shallow copy the identity comprehension did.
lineshape_vars = dict(crosscheck_data["mainvars"])
# Substitutions: the two kinematic variables plus every default model parameter.
lineshape_subs = {
    σ1: lineshape_vars["m2kpi"],
    σ2: lineshape_vars["m2pk"],
    **model.parameter_defaults,
}

In [None]:
# The three decay chains for which crosscheck.json provides lineshape values.
K892_chain, L1405_chain, L1690_chain = (
    model.decay.find_chain(resonance_name)
    for resonance_name in ("K(892)", "L(1405)", "L(1690)")
)
Math(as_latex([K892_chain, L1405_chain, L1690_chain]))

In [None]:
# Show the reference lineshape values stored in crosscheck.json.
crosscheck_data["lineshapes"]

In [None]:
amplitude_builder = load_model_builder(model_file, particles, model_id=0)


def build_dynamics(c):
    """Return the evaluated dynamics (lineshape) expression for decay chain *c*.

    The dynamics builder registered for *c* is called on it, and the first
    element of its result is unfolded with ``doit()``. Presumably the remaining
    elements hold parameter defaults (TODO confirm against the builder API).
    """
    dynamics_builder = amplitude_builder.dynamics_choices.get_builder(c)
    expression = dynamics_builder(c)[0]
    return expression.doit()


# Evaluate each lineshape numerically at the LHCb kinematic point.
K892_bw_val = build_dynamics(K892_chain).xreplace(lineshape_subs).n()
L1405_bw_val = build_dynamics(L1405_chain).xreplace(lineshape_subs).n()
L1690_bw_val = build_dynamics(L1690_chain).xreplace(lineshape_subs).n()
display_latex([K892_bw_val, L1405_bw_val, L1690_bw_val])

In [None]:
lineshape_decimals = 13
# NOTE(review): values are compared positionally, so this relies on the
# "lineshapes" entries in crosscheck.json being ordered K(892), L(1405),
# L(1690) — the same order as the computed values below.
np.testing.assert_array_almost_equal(
    np.array([complex(value) for value in crosscheck_data["lineshapes"].values()]),
    np.array(
        [complex(value) for value in (K892_bw_val, L1405_bw_val, L1690_bw_val)]
    ),
    decimal=lineshape_decimals,
)
# Only reached when the assertion above passes.
src = f"""
:::{{tip}}
These values are **equal up to {lineshape_decimals} decimals**.
:::
"""
Markdown(src)

## Amplitude comparison

### SymPy expressions

In [None]:
half = sp.Rational(1, 2)
# One aligned amplitude expression per (λ_Λc, λ_p) helicity pair, with the
# Λc helicity varying slowest. The two trailing arguments are fixed to 0
# (presumably the helicities of the spinless final-state particles).
amplitude_exprs = {
    (hel_Λc, hel_p): amplitude_builder.formulate_aligned_amplitude(
        hel_Λc, hel_p, 0, 0
    )[0]
    for hel_Λc in (-half, +half)
    for hel_p in (-half, +half)
}

In [None]:
# Unfold each amplitude: evaluate it, substitute the model's amplitude
# definitions, then run the cached doit() on the result.
unfolded_amplitude_exprs = {
    helicities: perform_cached_doit(
        amplitude.doit().xreplace(model.amplitudes)
    )
    for helicities, amplitude in tqdm(amplitude_exprs.items(), disable=NO_TQDM)
}

### Numerical functions

In [None]:
%%time
# Indexed parameters whose name mentions "production" stay free so they can
# be toggled per amplitude; every other default parameter is substituted.
production_couplings = {
    par: val
    for par, val in model.parameter_defaults.items()
    if isinstance(par, sp.Indexed) and "production" in str(par)
}
fixed_parameters = {
    par: val
    for par, val in model.parameter_defaults.items()
    if par not in production_couplings
}
amplitude_funcs = {
    helicities: perform_cached_lambdify(
        amplitude.xreplace(fixed_parameters),
        parameters=production_couplings,
        backend="numpy",
    )
    for helicities, amplitude in unfolded_amplitude_exprs.items()
}

### Input data

In [None]:
# Kinematic point at which LHCb evaluated the reference amplitudes.
# dict(...) makes the same shallow copy the identity comprehension did.
amplitude_vars = dict(crosscheck_data["chainvars"])
display(amplitude_vars)

In [None]:
transformer = create_data_transformer(model)
# Map σ1, σ2, σ3 to the crosscheck variables m2kpi, m2pk, m2ppi, run them
# through the model's data transformer, and cast every output to float.
input_data = {
    str(symbol): amplitude_vars[key]
    for symbol, key in zip((σ1, σ2, σ3), ("m2kpi", "m2pk", "m2ppi"))
}
input_data = {name: float(value) for name, value in transformer(input_data).items()}

In [None]:
# Render the transformed input data, with keys converted back to SymPy symbols
# for LaTeX display.
display_latex({sp.Symbol(k): v for k, v in input_data.items()})

### Comparison table

In [None]:
def plusminus_to_helicity(plusminus: str) -> sp.Rational:
    """Convert a ``"+"`` / ``"-"`` character into the helicity ``+1/2`` / ``-1/2``.

    Raises:
        NotImplementedError: if *plusminus* is any other string.
    """
    if plusminus == "+":
        return +half
    if plusminus == "-":
        return -half
    raise NotImplementedError(plusminus)


def _format_relative_difference(computed: complex, expected: complex) -> str:
    """Format ``|computed - expected| / |expected|`` as a Markdown table cell.

    Differences of 1e-6 or more are highlighted in red. An empty string is
    returned when *expected* is zero, because the relative difference is
    undefined there.
    """
    if abs(expected) == 0.0:
        return ""
    relative_diff = abs(computed - expected) / abs(expected)
    if relative_diff < 1e-6:
        return f"{relative_diff:.2e}"
    return f'<span style="color:red;">{relative_diff:.2e}</span>'


amplitude_decimals = 13
# Only entries whose key starts with "Ar" are compared.
real_amp_crosscheck = {
    k: v for k, v in crosscheck_data["chains"].items() if k.startswith("Ar")
}
couplings_to_zero = {str(symbol): 0 for symbol in production_couplings}

src = f"""
:::{{tip}}
Computed amplitudes are equal to LHCb amplitudes up to **{amplitude_decimals} decimals**.
:::

|     | Computed | Expected | Difference |
| ---:| --------:| --------:| ----------:|
"""
for amp_identifier, entry in real_amp_crosscheck.items():
    # Drop the "Ar" prefix and the final character to get the name that
    # `find_chain` expects. The resonance is the same for every matrix entry,
    # so it is looked up once per identifier.
    resonance_name = amp_identifier[2:-1]
    resonance = model.decay.find_chain(resonance_name).resonance
    coupling = parameter_key_to_symbol(amp_identifier.replace("Ar", "A"))
    src += f"| **`{amp_identifier}`** | ${sp.latex(coupling)}$ |\n"
    for matrix_key, expected in entry.items():
        matrix_suffix = matrix_key[1:]  # ++, +-, -+, --
        λ_Λc, λ_p = map(plusminus_to_helicity, matrix_suffix)
        # The proton helicity from the JSON key enters with flipped sign, both
        # for the function lookup and for the conversion factor.
        func = amplitude_funcs[(λ_Λc, -λ_p)]
        # Isolate this coupling's contribution: switch all production
        # couplings off, then set this one to 1.
        # NOTE(review): this mutates the shared functions in `amplitude_funcs`;
        # later cells see the last coupling configuration set here.
        func.update_parameters(couplings_to_zero)
        func.update_parameters({str(coupling): 1})
        computed = complex(func(input_data))
        computed *= get_conversion_factor(resonance, -λ_p)
        expected = complex(expected)
        diff = _format_relative_difference(computed, expected)
        np.testing.assert_array_almost_equal(
            computed,
            expected,
            decimal=amplitude_decimals,
            err_msg=f"  {amp_identifier} {matrix_key}",
        )
        src += f"| `{matrix_key}` | {computed:>.6f} | {expected:>.6f} | {diff} |\n"
Markdown(src)