# B-matrix extension of polarimeter

In [None]:
from __future__ import annotations

import logging
from warnings import filterwarnings

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.sympy import PoolSum
from IPython.display import display
from matplotlib import cm
from sympy.physics.matrices import msigma
from tqdm.auto import tqdm

from polarimetry import _to_index
from polarimetry.data import create_data_transformer, generate_meshgrid_sample
from polarimetry.io import (
    mute_jax_warnings,
    perform_cached_doit,
    perform_cached_lambdify,
)
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles
from polarimetry.plot import use_mpl_latex_fonts

# Silence every Python warning for the remainder of the notebook.
filterwarnings("ignore")
# Let INFO-level records of the "polarimetry.function" logger through.
logging.getLogger("polarimetry.function").setLevel(logging.INFO)
# NOTE(review): presumably suppresses JAX start-up warnings — confirm in polarimetry.io.
mute_jax_warnings()

## Formulate expressions

Reference subsystem 1 is defined as:

![](../docs/_images/orientation-K.svg)

In [None]:
# Selector that is passed to both loader functions below.
# NOTE(review): presumably 0 picks the first (default) model in the YAML file — confirm.
model_choice = 0
model_file = "../data/model-definitions.yaml"
particles = load_particles("../data/particle-definitions.yaml")
builder = load_model_builder(model_file, particles, model_choice)
imported_parameter_values = load_model_parameters(
    model_file, builder.decay, model_choice, particles
)
# The same subsystem ID is reused further down when the B-matrix is formulated.
reference_subsystem = 1
model = builder.formulate(reference_subsystem, cleanup_summations=True)
# Values read from the model file take precedence over the formulated defaults.
model.parameter_defaults.update(imported_parameter_values)

$$
\vec\alpha = \sum_{\nu',\nu,\{\lambda\}} A^*_{\nu',\{\lambda\}}\vec\sigma_{\nu',\nu} A_{\nu,\{\lambda\}} \,\big / I_0 \\
\vec\beta = \sum_{\nu,\lambda',\lambda} A^*_{\nu,\lambda'} \vec\sigma_{\lambda',\lambda} A_{\nu,\lambda} \,\big/ I_0 \\
B_{\tau,\rho} = \sum_{\nu,\nu',\lambda',\lambda} A^*_{\nu',\lambda'} \sigma_{\nu',\nu}^\tau A_{\nu,\lambda} \sigma_{\lambda',\lambda}^\rho
$$

In [None]:
# Exact rational 1/2, used for the ±1/2 summation values below.
half = sp.Rational(1, 2)
# Summation indices: unprimed and primed variants of ν and λ.
v, vp = sp.symbols(R"nu \nu^{\prime}", rational=True)
λ, λp = sp.symbols(R"lambda \lambda^{\prime}", rational=True)
# σ[0] is the 2×2 identity matrix, σ[1], σ[2], σ[3] come from `msigma`.
σ = [sp.Matrix([[1, 0], [0, 1]]), *map(msigma, (1, 2, 3))]

In [None]:
ref = reference_subsystem


def _formulate_B_element(τ: int, ρ: int) -> sp.Expr:
    """Build the symbolic entry (τ, ρ) of the B-matrix.

    The summand is the conjugated aligned amplitude for (ν', λ') times the
    (ν', ν) element of σ[τ] times the aligned amplitude for (ν, λ) times the
    (λ', λ) element of σ[ρ]. It is summed over ν, ν', λ, λ' ∈ {-1/2, +1/2}
    with a `PoolSum`, on which `.cleanup()` is called before returning.
    """
    bra = builder.formulate_aligned_amplitude(vp, λp, 0, 0, ref)[0].conjugate()
    ket = builder.formulate_aligned_amplitude(v, λ, 0, 0, ref)[0]
    summand = (
        bra
        * σ[τ][_to_index(vp), _to_index(v)]
        * ket
        * σ[ρ][_to_index(λp), _to_index(λ)]
    )
    summation_ranges = [(index, [-half, +half]) for index in (v, vp, λ, λp)]
    return PoolSum(summand, *summation_ranges).cleanup()


# Rows are indexed by τ, columns by ρ.
B = sp.Matrix([[_formulate_B_element(τ, ρ) for ρ in range(4)] for τ in range(4)])

In [None]:
# Matrix of placeholder symbols whose names spell out their own (τ, ρ)
# position: τ labels the row and ρ the column.
B_test = sp.Matrix(
    [[sp.Symbol(Rf"\tau={τ}, \rho={ρ}") for ρ in range(4)] for τ in range(4)]
)
B_test

In [None]:
# Show a single entry: row index 0 (τ) and column index 2 (ρ).
B_test[0, 2]

## Functions and data

In [None]:
# For every (τ, ρ): call `.doit()` on the symbolic element, substitute
# `model.amplitudes`, and pass the result through `perform_cached_doit`.
index_pairs = [(τ, ρ) for τ in range(4) for ρ in range(4)]
progress_bar = tqdm(desc="Unfolding expressions", total=16)
unfolded = {}
for τ, ρ in index_pairs:
    unfolded[τ, ρ] = perform_cached_doit(B[τ, ρ].doit().xreplace(model.amplitudes))
    progress_bar.update()
progress_bar.close()
# Collect the results in a 2D array, with τ as row and ρ as column index.
B_exprs = np.array([[unfolded[τ, ρ] for ρ in range(4)] for τ in range(4)])
B_exprs.shape

In [None]:
# For every (τ, ρ): substitute the default parameter values into the unfolded
# expression and lambdify it with the JAX backend.
index_pairs = [(τ, ρ) for τ in range(4) for ρ in range(4)]
progress_bar = tqdm(desc="Lambdifying", total=16)
lambdified = {}
for τ, ρ in index_pairs:
    lambdified[τ, ρ] = perform_cached_lambdify(
        B_exprs[τ, ρ].xreplace(model.parameter_defaults),
        backend="jax",
    )
    progress_bar.update()
progress_bar.close()
# Collect the functions in a 2D array, with τ as row and ρ as column index.
B_funcs = np.array([[lambdified[τ, ρ] for ρ in range(4)] for τ in range(4)])

In [None]:
data_sample = generate_meshgrid_sample(model.decay, resolution=400)
# Keep the two grid coordinates at hand as plot axes.
X, Y = data_sample["sigma1"], data_sample["sigma2"]
# Add whatever the model's data transformer computes from the sample to the
# sample itself, so that the lambdified functions can be called on it.
transformer = create_data_transformer(model)
transformed_variables = transformer(data_sample)
data_sample.update(transformed_variables)

In [None]:
# Evaluate every B-matrix function on the data sample. The outer list runs over
# τ (rows) and the inner list over ρ (columns), so that `B_arrays[τ, ρ]` is the
# output of `B_funcs[τ, ρ]`. Iterating the other way round would store the
# transpose and silently swap the roles of τ and ρ wherever the array is indexed.
B_arrays = jnp.array(
    [[B_funcs[τ, ρ](data_sample) for ρ in range(4)] for τ in range(4)]
)
# Normalise every element by the (0, 0) element.
B_norm = B_arrays / B_arrays[0, 0]
B_arrays.shape

In [None]:
decimals = 6
# Axes 2 and 3 are reduced, leaving one number per (row, column) element.
reduced_axes = (2, 3)
# NaN-ignoring summaries of the normalised matrix: maximum of the imaginary
# part, followed by maximum, mean and standard deviation of the real part.
summaries = (
    jnp.nanmax(B_norm.imag, axis=reduced_axes),
    jnp.nanmax(B_norm.real, axis=reduced_axes),
    jnp.nanmean(B_norm.real, axis=reduced_axes),
    jnp.nanstd(B_norm.real, axis=reduced_axes),
)
display(*(jnp.round(summary, decimals) for summary in summaries))

## Plots

In [None]:
%config InlineBackend.figure_formats = ['png']
plt.rcdefaults()
use_mpl_latex_fonts()
plt.rc("font", size=18)
s1_label = R"$m^2\left(K^-\pi^+\right)$ [GeV$^2$]"
s2_label = R"$m^2\left(pK^-\right)$ [GeV$^2$]"
fig, ax = plt.subplots(figsize=(8, 6.8))
ax.set_title("$I_0 = B_{0, 0}$")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)
ax.set_box_aspect(1)
ax.pcolormesh(X, Y, B_arrays[0, 0].real)
plt.show()

In [None]:
%config InlineBackend.figure_formats = ['png']
plt.rcdefaults()
use_mpl_latex_fonts()
plt.rc("font", size=10)
fig, axes = plt.subplots(
    dpi=200,
    figsize=(11, 10),
    ncols=4,
    nrows=4,
    sharex=True,
    sharey=True,
)
fig.suptitle(
    R"$B_{\tau,\rho} = \sum_{\nu,\nu',\lambda',\lambda} A^*_{\nu',\lambda'}"
    R" \sigma_{\nu',\nu}^\tau A_{\nu,\lambda} \sigma_{\lambda',\lambda}^\rho$"
)
progress_bar = tqdm(total=16)
for ρ in range(4):
    for τ in range(4):
        ax = axes[τ, ρ]
        ax.set_box_aspect(1)
        if τ == 0 and ρ == 0:
            Z = B_arrays[τ, ρ].real
            ax.set_title(f"$B_{{{τ}{ρ}}}$")
            cmap = cmap = cm.viridis
        else:
            Z = B_norm[τ, ρ].real
            ax.set_title(f"$B_{{{τ}{ρ}}} / B_{{00}}$")
            cmap = cmap = cm.coolwarm
        mesh = ax.pcolormesh(X, Y, Z, cmap=cmap)
        cbar = fig.colorbar(mesh, ax=ax, fraction=0.047, pad=0.01)
        if τ != 0 or ρ != 0:
            mesh.set_clim(vmin=-1, vmax=+1)
            cbar.set_ticks([-1, 0, +1])
            cbar.set_ticklabels(["-1", "0", "+1"])
        if τ == 3:
            ax.set_xlabel(s1_label)
        if ρ == 0:
            ax.set_ylabel(s2_label)
        progress_bar.update()
progress_bar.close()
fig.tight_layout()
plt.show()

**Hypothesis**:

$$
B_{0,\rho} = \vec\beta B_{00} \\
B_{\tau,0} = \vec\alpha B_{00} \\
B_{00} = I_0
$$

In [None]:
# Set the inline backend's figure format to SVG.
%config InlineBackend.figure_formats = ['svg']


def plot_field(vx, vy, v_abs, ax, strides=12, cmap=cm.viridis_r):
    """Draw a thinned-out arrow field on `ax`, coloured by `v_abs`.

    Arrow positions come from the notebook-level grids `X` and `Y`. Only every
    `strides`-th entry along the first two array axes is drawn. The arrow
    components are the real parts of `vx` and `vy`, and the colour limits are
    fixed to the range 0 to +1.

    Returns the object created by `ax.quiver`.
    """
    every_nth = (slice(None, None, strides), slice(None, None, strides))
    arrows = ax.quiver(
        X[every_nth],
        Y[every_nth],
        vx[every_nth].real,
        vy[every_nth].real,
        v_abs[every_nth],
        cmap=cmap,
    )
    arrows.set_clim(vmin=0, vmax=+1)
    return arrows


def plot(x, y, z, strides=14):
    """Create a new figure with an arrow field of (`x`, `y`) and a colour bar.

    The arrows are drawn with `plot_field` from the `x` and `y` components,
    while the colour is the length of the real parts of all three components,
    `sqrt(Re(x)^2 + Re(y)^2 + Re(z)^2)`. Matplotlib's rc settings are reset and
    the LaTeX fonts are re-applied before the figure is created.

    Returns a tuple of the axes and the colour bar, so that the caller can add
    a title and labels.
    """
    plt.rcdefaults()
    use_mpl_latex_fonts()
    plt.rc("font", size=18)
    fig, ax = plt.subplots(figsize=(8, 6.8), tight_layout=True)
    ax.set_box_aspect(1)
    length_squared = x.real**2 + y.real**2 + z.real**2
    arrows = plot_field(x, y, jnp.sqrt(length_squared), ax, strides)
    return ax, fig.colorbar(arrows, ax=ax, pad=0.01)

In [None]:
# Horizontal arrow component: element (3, 0); vertical: element (1, 0).
# Element (2, 0) only enters through the colour (vector length).
alpha_components = {
    "x": B_norm[3, 0],
    "y": B_norm[1, 0],
    "z": B_norm[2, 0],
}
ax, cbar = plot(**alpha_components, strides=10)
alpha_title = (
    R"$B_{\tau, 0} / B_{00} = \sum_{\nu',\nu,\{\lambda\}}"
    R" A^*_{\nu',\{\lambda\}}\vec\sigma_{\nu',\nu} A_{\nu,\{\lambda\}} \,\big / I_0$"
)
ax.set_title(alpha_title)
ax.set_xlabel(Rf"{s1_label}, $\quad\alpha_z$")
ax.set_ylabel(Rf"{s2_label}, $\quad\alpha_x$")
cbar.set_label(R"$\left|\vec{\alpha}\right|$")

In [None]:
# The axis labels below state that the horizontal arrow component is
# β_z = B_{03} and the vertical one β_x = B_{01}, so `x` must receive element
# (0, 3) and `y` element (0, 1) — the same layout as the α plot, which uses
# (3, 0) and (1, 0). Element (0, 2) only enters through the colour.
ax, cbar = plot(
    x=B_norm[0, 3],
    y=B_norm[0, 1],
    z=B_norm[0, 2],
    strides=10,
)
# Only the first amplitude is conjugated, in line with the definition of B.
ax.set_title(
    R"$B_{0,\rho} / B_{00} = \sum_{\nu,\lambda',\lambda} A^*_{\nu,\lambda'}"
    R" \vec\sigma_{\lambda',\lambda} A_{\nu,\lambda} \,\big/ I_0$"
)
ax.set_xlabel(Rf"{s1_label}, $\quad \beta_z = B_{{03}}$")
ax.set_ylabel(Rf"{s2_label}, $\quad \beta_x = B_{{01}}$")
cbar.set_label(R"$\left|\vec{\beta}\right|$")