# Cross-check with LHCb data

In [None]:
%%capture
# Execute the amplitude-model notebook in the current kernel namespace; %%capture
# hides its output. Names used below without a local definition (e.g. decays,
# σ1, parameter_defaults, formulate_dynamics) are presumably defined there —
# hence the pyright reportUndefinedVariable suppression in the imports cell.
%run ./amplitude-model.ipynb

```{autolink-concat}
```

In [None]:
# pyright: reportUndefinedVariable=false
from __future__ import annotations

import json

import numpy as np
import sympy as sp
from IPython.display import Markdown, Math
from tqdm.notebook import tqdm

from polarization.io import as_latex, display_latex

In [None]:
# Load the cross-check reference file. Its "mainvars"/"lineshapes" sections are
# used for the lineshape comparison and its "chainvars"/"chains" sections for the
# amplitude comparison. JSON is read explicitly as UTF-8 so that the result does
# not depend on the platform's locale encoding.
with open("../data/crosscheck.json", encoding="utf-8") as stream:
    crosscheck_data = json.load(stream)

## Lineshape comparison

In [None]:
# Kinematic point for the lineshape check: the "m2kpi" and "m2pk" entries of the
# "mainvars" section (presumably squared invariant masses) are substituted for σ1
# and σ2. The model's default parameter values come last, so they take precedence
# on any key clash, exactly as with the ** unpacking they replace.
lineshape_vars = dict(crosscheck_data["mainvars"])
lineshape_subs = {σ1: lineshape_vars["m2kpi"], σ2: lineshape_vars["m2pk"]}
lineshape_subs.update(parameter_defaults)

In [None]:
# Pick the decay chains of the three resonances whose lineshapes are compared.
# next() raises StopIteration if a resonance name is not found in `decays`.
K892_decay, L1405_decay, L1690_decay = [
    next(decay for decay in decays if decay.resonance.name == resonance_name)
    for resonance_name in ("K(892)", "L(1405)", "L(1690)")
]
Math(as_latex([K892_decay, L1405_decay, L1690_decay]))

In [None]:
# Raw reference lineshape values as stored in the cross-check file.
crosscheck_data["lineshapes"]

In [None]:
# Evaluate the dynamics expression of each selected decay numerically at the
# substitution point defined in lineshape_subs.
K892_bw_val, L1405_bw_val, L1690_bw_val = [
    formulate_dynamics(selected_decay).doit().xreplace(lineshape_subs).n()
    for selected_decay in (K892_decay, L1405_decay, L1690_decay)
]
display_latex([K892_bw_val, L1405_bw_val, L1690_bw_val])

In [None]:
# Compare reference and computed lineshape values with numpy's decimal=13
# criterion.
# NOTE(review): the comparison is positional — it assumes the entries of
# crosscheck_data["lineshapes"] are stored in the order K(892), L(1405),
# L(1690). Confirm against the keys of the JSON file.
expected_lineshape_values = np.array(
    [complex(value) for value in crosscheck_data["lineshapes"].values()]
)
computed_lineshape_values = np.array(
    [complex(value) for value in (K892_bw_val, L1405_bw_val, L1690_bw_val)]
)
np.testing.assert_array_almost_equal(
    expected_lineshape_values,
    computed_lineshape_values,
    decimal=13,
)

## Amplitude comparison

In [None]:
# Substitutions for the amplitude check, in increasing order of precedence:
# the three Mandelstam-like variables from the "chainvars" section, the masses,
# and the default model parameters.
amplitude_vars = dict(crosscheck_data["chainvars"])
amplitude_subs = {
    σ1: amplitude_vars["m2kpi"],
    σ2: amplitude_vars["m2pk"],
    σ3: amplitude_vars["m2ppi"],
}
amplitude_subs.update(masses)
amplitude_subs.update(parameter_defaults)
# Switch off every production coupling. The cross-check loop substitutes 1 for
# the coupling of the amplitude under test *before* applying these substitutions
# (presumably those couplings are among prod_couplings).
amplitude_subs.update(dict.fromkeys(prod_couplings, 0))

In [None]:
def plusminus_to_helicity(plusminus: str) -> sp.Rational:
    """Translate a helicity sign character into a helicity value.

    ``"+"`` yields ``+half`` and ``"-"`` yields ``-half``, where ``half`` is the
    helicity magnitude defined by the amplitude model. Any other input raises
    :class:`NotImplementedError` carrying the offending value.
    """
    if plusminus not in ("+", "-"):
        raise NotImplementedError(plusminus)
    return +half if plusminus == "+" else -half


def render_difference(
    computed: complex, expected: complex, threshold_percent: float = 5.0
) -> str:
    """Render the relative deviation in magnitude of a computed value.

    Only the magnitudes ``abs(computed)`` and ``abs(expected)`` are compared, so
    a difference in complex phase (e.g. an overall sign) is *not* detected.

    Args:
        computed: Value obtained from the model.
        expected: Reference value to compare against.
        threshold_percent: Absolute relative deviation in percent from which the
            comparison counts as failed. Defaults to 5, the previously
            hard-coded value.

    Returns:
        - ``"✅"`` if both values are exactly zero.
        - ``"❌"`` if ``expected`` is zero but ``computed`` is not, if the
          deviation reaches ``threshold_percent``, or if it is undefined (NaN).
        - Otherwise, the signed deviation formatted like ``"1.0%"``.
    """
    if abs(expected) == 0.0:
        if abs(computed) == 0.0:
            return "✅"
        return "❌"
    diff = 100 * (abs(computed) / abs(expected)) - 100
    # Written as `not <` instead of `>=` so that a NaN deviation (a NaN computed
    # or expected value) is flagged as a failure instead of rendering "nan%".
    if not abs(diff) < threshold_percent:
        return "❌"
    return f"{diff:.1f}%"


# Map the first letter of a resonance name to the builder of its sub-system
# amplitude.
id_to_amp_builder = {
    "D": formulate_Δ_amplitude,
    "K": formulate_K_amplitude,
    "L": formulate_Λ_amplitude,
}
# Only the chain entries whose key starts with "Ar" are compared here.
real_amp_crosscheck = {
    k: v for k, v in crosscheck_data["chains"].items() if k.startswith("Ar")
}
src = """
|     | Computed | Expected |     |
| ---:| --------:| --------:|:---:|
"""
table_rows = []
# The context manager closes the progress bar even if an amplitude fails to
# evaluate. The total counts the actual number of helicity entries instead of
# assuming four per chain.
with tqdm(
    desc="Computing amplitude for cross-check",
    total=sum(len(entry) for entry in real_amp_crosscheck.values()),
) as progress_bar:
    for amp_identifier, entry in real_amp_crosscheck.items():
        # Drop the two-character "Ar" prefix and the trailing character to get
        # the resonance name.
        resonance_name = amp_identifier[2:-1]
        decay = next(d for d in decays if d.resonance.name == resonance_name)
        subsystem_identifier = resonance_name[0]
        amp_builder = id_to_amp_builder[subsystem_identifier]
        table_rows.append(f"| **`{amp_identifier}`** |\n")
        for matrix_key, raw_expected in entry.items():
            progress_bar.set_postfix_str(f"{amp_identifier}, {matrix_key}")
            matrix_suffix = matrix_key[1:]  # ++, +-, -+, --
            λ_Λc, λ_p = map(plusminus_to_helicity, matrix_suffix)
            amp_expr = amp_builder(λ_Λc, λ_p, [decay]).doit()
            amp_expr = amp_expr.xreplace(angles).doit()
            # Set the production coupling of this amplitude to 1. This happens
            # before amplitude_subs is applied, so it is not zeroed there.
            coupling_subs = {
                s: 1
                for s in amp_expr.free_symbols
                if isinstance(s, sp.Indexed) and s.base == H_prod
            }
            if len(coupling_subs) > 1:
                raise ValueError(
                    f"Multiple production couplings for {amp_identifier}, {matrix_key}"
                )
            amp_expr = amp_expr.xreplace(coupling_subs)
            amp_expr = amp_expr.xreplace(amplitude_subs)
            computed = complex(amp_expr.n())
            expected = complex(raw_expected)
            diff = render_difference(computed, expected)
            table_rows.append(
                f"| `{matrix_key}` | {computed:>16.4f} | {expected:>16.4f} | {diff} |\n"
            )
            progress_bar.update()
    progress_bar.set_postfix_str("")
src += "".join(table_rows)
Markdown(src)