# Cross-check with LHCb data

In [None]:
%%capture
# Execute the phase-space notebook and load its top-level names into this kernel.
# Presumably it provides the names used below without a local definition
# (σ1, σ2, σ3, decays, parameter_defaults, formulate_dynamics, transformer, ...)
# — TODO confirm. %%capture suppresses that notebook's output here.
%run ./phase-space.ipynb

```{autolink-concat}
```

In [None]:
# pyright: reportUndefinedVariable=false
from __future__ import annotations

import json

import numpy as np
import sympy as sp
from IPython.display import Markdown, Math, display
from tensorwaves.function.sympy import create_parametrized_function
from tqdm.notebook import tqdm

from polarization.io import as_latex, display_latex

In [None]:
# Load the reference values used throughout this cross-check. JSON text is UTF-8,
# so decode it explicitly instead of relying on the platform's locale encoding.
with open("../data/crosscheck.json", encoding="utf-8") as stream:
    crosscheck_data = json.load(stream)

## Lineshape comparison

In [None]:
# Shallow copy of the kinematic values at which the reference lineshapes were
# evaluated (keys presumably m²(Kπ) and m²(pK) — see "m2kpi"/"m2pk").
lineshape_vars = dict(crosscheck_data["mainvars"])
# Substitution map for the lineshape expressions. Because `parameter_defaults` is
# unpacked last, its entries take precedence over σ1/σ2 on any key collision.
lineshape_subs = {
    σ1: lineshape_vars["m2kpi"],
    σ2: lineshape_vars["m2pk"],
    **parameter_defaults,
}

In [None]:
def _find_decay(resonance_name: str):
    """Return the first decay in ``decays`` whose resonance has the given name.

    Raises:
        KeyError: if no decay matches. The message names the missing resonance,
            unlike the bare ``StopIteration`` from ``next(filter(...))``.
    """
    for candidate in decays:
        if candidate.resonance.name == resonance_name:
            return candidate
    raise KeyError(f"No decay with resonance {resonance_name!r} in the model")


K892_decay = _find_decay("K(892)")
L1405_decay = _find_decay("L(1405)")
L1690_decay = _find_decay("L(1690)")
Math(as_latex([K892_decay, L1405_decay, L1690_decay]))

In [None]:
# Display the expected lineshape values as stored in the JSON file, in file order.
crosscheck_data["lineshapes"]

In [None]:
# Numerically evaluate the dynamics (lineshape) of each selected decay at the
# kinematic point and parameter values collected in `lineshape_subs`.
K892_bw_val, L1405_bw_val, L1690_bw_val = (
    formulate_dynamics(chain).doit().xreplace(lineshape_subs).n()
    for chain in (K892_decay, L1405_decay, L1690_decay)
)
display_latex([K892_bw_val, L1405_bw_val, L1690_bw_val])

In [None]:
# NOTE(review): values are matched *positionally*. This assumes the keys of
# crosscheck_data["lineshapes"] are ordered K(892), L(1405), L(1690) — confirm.
expected_lineshapes = np.array(
    [complex(value) for value in crosscheck_data["lineshapes"].values()]
)
computed_lineshapes = np.array(
    [complex(value) for value in (K892_bw_val, L1405_bw_val, L1690_bw_val)]
)
# Absolute tolerance 1.5e-13 reproduces assert_array_almost_equal(decimal=13);
# rtol=0 keeps the check purely absolute, as before.
np.testing.assert_allclose(
    computed_lineshapes, expected_lineshapes, rtol=0, atol=1.5e-13
)

## Amplitude comparison

### SymPy expressions

In [None]:
# One aligned amplitude expression per combination of the two helicity arguments,
# each ranging over (-half, +half); keys are the (first, second) helicity pair.
amplitude_exprs = {
    (nu, lam): formulate_aligned_amplitude(nu, lam)
    for lam in (-half, +half)
    for nu in (-half, +half)
}

In [None]:
# Evaluate each expression, substitute the amplitude definitions, then evaluate
# again so the substituted definitions are evaluated too (progress shown via tqdm).
unfolded_amplitude_exprs = {
    helicities: amp_expr.doit().xreplace(amp_definitions).doit()
    for helicities, amp_expr in tqdm(amplitude_exprs.items())
}

### Numerical functions

In [None]:
%%time
fixed_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol not in prod_couplings
}
amplitude_funcs = {
    k: create_parametrized_function(
        expr.xreplace(fixed_parameters),
        parameters=prod_couplings,
        backend="numpy",
    )
    for k, expr in unfolded_amplitude_exprs.items()
}

### Input data

In [None]:
# Shallow copy of the kinematic values at which the reference chain amplitudes
# were evaluated.
amplitude_vars = dict(crosscheck_data["chainvars"])
display(amplitude_vars)

In [None]:
# Pass the reference values of σ1, σ2, σ3 through `transformer` and keep every
# variable it returns as a plain Python float.
input_data = {
    name: float(value)
    for name, value in transformer(
        {
            str(σ1): amplitude_vars["m2kpi"],
            str(σ2): amplitude_vars["m2pk"],
            str(σ3): amplitude_vars["m2ppi"],
        }
    ).items()
}

In [None]:
# Render the input variables as LaTeX, keyed by SymPy symbols of their names.
display_latex({sp.Symbol(name): value for name, value in input_data.items()})

### Comparison table

In [None]:
def plusminus_to_helicity(plusminus: str) -> sp.Rational:
    """Convert a sign character to a helicity value.

    ``"+"`` maps to ``+half`` and ``"-"`` maps to ``-half``.

    Raises:
        NotImplementedError: for any input other than ``"+"`` or ``"-"``.
    """
    if plusminus not in ("+", "-"):
        raise NotImplementedError(plusminus)
    return +half if plusminus == "+" else -half


# Compare each reference chain amplitude (entries starting with "Ar") with the
# amplitude computed here when only that chain's coupling is set to 1.
real_amp_crosscheck = {
    k: v for k, v in crosscheck_data["chains"].items() if k.startswith("Ar")
}
couplings_to_zero = {str(symbol): 0 for symbol in prod_couplings}

src = """
|     | Computed | Expected | Difference |
| ---:| --------:| --------:| ----------:|
"""
for amp_identifier, entry in real_amp_crosscheck.items():
    # Strip the "Ar" prefix and the final character. NOTE(review): this assumes the
    # trailing index is always a single character — confirm.
    resonance_name = amp_identifier[2:-1]
    # Fail with a readable message (instead of a bare StopIteration) when the
    # reference file lists a resonance that the model does not contain.
    if not any(d.resonance.name == resonance_name for d in decays):
        raise KeyError(f"No decay with resonance {resonance_name!r} in the model")
    # Swap only the leading "Ar" for "A"; str.replace would rewrite every occurrence.
    coupling = to_symbol("A" + amp_identifier[len("Ar") :])
    src += f"| **`{amp_identifier}`** | ${sp.latex(coupling)}$ |\n"
    for matrix_key, expected in entry.items():
        matrix_suffix = matrix_key[1:]  # ++, +-, -+, --
        λ_Λc, λ_p = map(plusminus_to_helicity, matrix_suffix)
        # NOTE(review): the proton helicity is looked up negated, and the result is
        # divided by (-1)^(1/2 + λ_p) below. Presumably this converts between the
        # reference and local helicity conventions — confirm.
        func = amplitude_funcs[(λ_Λc, -λ_p)]
        # Isolate this chain: switch off every coupling, then set this one to 1.
        # This mutates `func`, which keeps these parameter values after the cell.
        func.update_parameters(couplings_to_zero)
        func.update_parameters({str(coupling): 1})
        computed = complex(func(input_data))
        computed /= float((-1) ** (half + λ_p))
        expected = complex(expected)
        if abs(expected) != 0.0:
            rel_diff = abs(computed - expected) / abs(expected)
            if rel_diff < 1e-6:
                diff_cell = f"{rel_diff:.2e}"
            else:  # highlights large relative differences (and NaN)
                diff_cell = f'<span style="color:red;">{rel_diff:.2e}</span>'
        else:
            diff_cell = ""  # relative difference is undefined for a zero reference
        # The hard check is absolute (4 decimals). It is independent of the 1e-6
        # relative threshold used only for highlighting in the table.
        np.testing.assert_array_almost_equal(computed, expected, decimal=4)
        src += f"| `{matrix_key}` | {computed:>.6f} | {expected:>.6f} | {diff_cell} |\n"
Markdown(src)