# [TR-009] Relativistic K-matrix

:::{note}

This report is a sequel to {doc}`/report/005`.

:::

The **covariant description of the $\boldsymbol{T}$-matrix** is:

$$
\boldsymbol{T} = \sqrt{\boldsymbol{\rho}} \; \boldsymbol{\hat{T}} \sqrt{\boldsymbol{\rho}}
$$ (covariant T-matrix)

with the phase space factor matrix $\boldsymbol{\rho}$ defined as:

$$
\boldsymbol{\rho} = \begin{pmatrix}
\rho_0 & \cdots & 0      \\
\vdots & \ddots & \vdots \\
0      & \cdots & \rho_{n-1}
\end{pmatrix}
$$ (rho matrix)

and

$$
\rho_i = \frac{2q_i}{m} = \sqrt{
  \left[1-\left(\frac{m_{i,a}+m_{i,b}}{m}\right)^2\right]
  \left[1-\left(\frac{m_{i,a}-m_{i,b}}{m}\right)^2\right]
}
$$ (phase space factor)

This results in a similar transformation for the $\boldsymbol{K}$-matrix:

$$
\boldsymbol{K} = \sqrt{\boldsymbol{\rho}} \; \boldsymbol{\hat{K}} \sqrt{\boldsymbol{\rho}}
$$ (covariant K-matrix)

with (compare Eq. {eq}`T-matrix` in {doc}`/report/005`):

$$
\boldsymbol{\hat{T}} = \boldsymbol{\hat{K}}(\boldsymbol{I} - i\boldsymbol{\rho}\boldsymbol{\hat{K}})^{-1}
$$ (covariant T-matrix as K)

It's common to integrate these phase space factors into the parametrization of $K_{ij}$ as well:

$$
K_{ij} = \sum_R \frac{g_{R,i}(m)g_{R,j}(m)}{\left(m_R^2-m^2\right)\sqrt{\rho_i\rho_j}}
$$ (covariant parametrization)

In addition, one often uses an "energy dependent" {func}`~ampform.dynamics.coupled_width` $\Gamma(m)$ instead of a fixed width $\Gamma_0$ as done in {doc}`/report/005`.

In [None]:
%matplotlib widget
import os
import re
import warnings
from typing import Any, Union

import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
import numpy as np
import symplot
import sympy as sp
from ampform.dynamics import coupled_width, phase_space_factor
from ampform.dynamics.decorator import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
)
from IPython.display import Image, Math
from ipywidgets import widgets as ipywidgets
from matplotlib import cm
from mpl_interactions.controller import Controls
from sympy.printing.latex import LatexPrinter

# Keep library warnings out of the rendered notebook output.
warnings.filterwarnings("ignore")
# Set of the listed environment variables that are present; it is non-empty
# (truthy) only when at least one of them is set. The last cell uses it to
# decide whether to save and embed a static image of the figure.
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"} & set(os.environ)

## Wrapping expressions

To keep a nice rendering, we wrap the expressions for {func}`~ampform.dynamics.phase_space_factor` and {func}`~ampform.dynamics.coupled_width` into a class that derives from {class}`~sympy.core.expr.Expr` (see e.g. the implementation of {class}`~ampform.dynamics.BlattWeisskopfSquared`). Note that we need to use {func}`~symplot.partial_doit` to keep these expression symbols after evaluating the {class}`~sympy.concrete.summations.Sum`.

In [None]:
@implement_doit_method()
class PhaseSpaceFactor(UnevaluatedExpression):
    R"""Phase space factor :math:`\rho_i(m)` of decay channel :math:`i`.

    The first argument is the invariant mass :math:`m` itself, **not**
    :math:`s = m^2`: :meth:`evaluate` squares it before handing it to
    :func:`ampform.dynamics.phase_space_factor`. The channel index ``i`` is
    not used in the evaluation; it only appears in the LaTeX rendering.
    """

    is_commutative = True

    def __new__(
        cls,
        m: sp.Symbol,
        m_a: sp.Symbol,
        m_b: sp.Symbol,
        i: int,
        evaluate: bool = False,
        **hints: Any,
    ) -> "PhaseSpaceFactor":
        # All four arguments are stored as SymPy objects, in this order.
        sympified = sp.sympify((m, m_a, m_b, i))
        return create_expression(cls, evaluate, *sympified, **hints)

    def evaluate(self) -> sp.Expr:
        mass, mass_a, mass_b = self.args[:3]
        return phase_space_factor(mass ** 2, mass_a, mass_b)

    def _latex(self, printer: LatexPrinter, *args: Any) -> str:
        mass_latex = printer._print(self.args[0])
        channel = self.args[-1]
        return fR"\rho_{{{channel}}}({mass_latex})"


@implement_doit_method()
class CoupledWidth(UnevaluatedExpression):
    R"""Energy-dependent width :math:`\Gamma_{R,i}(m)` of resonance R in channel i.

    Evaluates to :func:`ampform.dynamics.coupled_width` with meson radius 1.
    The phase space factors inside that expression are represented by
    ``PhaseSpaceFactor``, so they keep a compact rendering as well.

    Arguments, stored in this order in ``self.args``:
        m: invariant mass.
        mass0: resonance masses, indexed by ``R``.
        gamma0: fixed widths, indexed by ``[R, i]``.
        m_a, m_b: final-state masses, indexed by channel ``i``.
        angular_momentum: orbital angular momentum.
        R: resonance index.
        i: channel index.
    """

    is_commutative = True

    def __new__(
        cls,
        m: sp.Symbol,
        mass0: sp.IndexedBase,
        gamma0: sp.IndexedBase,
        m_a: sp.IndexedBase,
        m_b: sp.IndexedBase,
        angular_momentum: int,
        R: Union[int, sp.Symbol],
        i: int,
        evaluate: bool = False,
        **hints: Any,
    ) -> "CoupledWidth":
        args = sp.sympify((m, mass0, gamma0, m_a, m_b, angular_momentum, R, i))
        return create_expression(cls, evaluate, *args, **hints)

    def evaluate(self) -> sp.Expr:
        m, mass0, gamma0, m_a, m_b, angular_momentum, R, i = self.args

        def to_mass(s: sp.Expr) -> sp.Expr:
            # Undo the squaring: m**2 -> m and m_R**2 -> m_R. Masses are
            # non-negative, so sqrt(x**2) == x. Any other form falls back
            # to an explicit square root, which is still numerically correct.
            s = sp.sympify(s)
            if isinstance(s, sp.Pow) and s.exp == 2:
                return s.base
            return sp.sqrt(s)

        def phsp_factor(s, m_a, m_b):
            # ampform's coupled_width() calls this hook with squared masses:
            # its first argument is s = m**2 (see the call below). But
            # PhaseSpaceFactor expects the mass itself and squares it again
            # in its evaluate(). Forwarding s unchanged would therefore
            # evaluate the phase space factor at s**2 = m**4.
            return PhaseSpaceFactor(to_mass(s), m_a, m_b, i)

        return coupled_width(
            m ** 2,
            mass0[R],
            gamma0[R, i],
            m_a[i],
            m_b[i],
            angular_momentum=angular_momentum,
            meson_radius=1,
            phsp_factor=phsp_factor,
        )

    def _latex(self, printer: LatexPrinter, *args: Any) -> str:
        m = printer._print(self.args[0])
        R = self.args[-2]
        i = self.args[-1]
        return fR"{{\Gamma_{{{R},{i}}}}}({m})"

And here is what the equations look like:

In [None]:
n_channels = 2
# n_R: number of resonances, i: channel index, R: resonance index,
# L: angular momentum (all non-negative integers).
n_resonances, i, R, L = sp.symbols(
    "n_R, i, R, L", integer=True, negative=False
)
m = sp.Symbol("m", real=True)
# Resonance masses m_R. This must not be named "m_a", which is the name of
# the channel masses below. With that name, both IndexedBases rendered
# identically. "m" matches the naming used in relativistic_k_matrix().
M = sp.IndexedBase("m", shape=(n_resonances,))
Gamma = sp.IndexedBase("Gamma", shape=(n_resonances, n_channels))
gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))
m_a = sp.IndexedBase("m_a", shape=(n_channels,))
m_b = sp.IndexedBase("m_b", shape=(n_channels,))

In [None]:
# Unevaluated expressions for the coupled width of resonance R in channel i
# (angular momentum 0) and for the phase space factor of channel i.
width_expr = CoupledWidth(m, M, Gamma, m_a, m_b, 0, R, i)
# PhaseSpaceFactor takes the mass m itself (it squares m when evaluated).
phsp_expr = PhaseSpaceFactor(m, m_a[i], m_b[i], i)

In [None]:
# Render the coupled width next to its definition. evaluate() expands only
# the CoupledWidth itself; the phase space factors inside it stay wrapped as
# PhaseSpaceFactor expressions.
Math(
    sp.multiline_latex(
        lhs=width_expr,
        rhs=width_expr.evaluate(),
    )
)

In [None]:
# Render the phase space factor next to its fully evaluated form. Because m
# is declared real, SymPy may write sqrt(m**2) as Abs(m); it is replaced by
# m for readability (this assumes m >= 0).
Math(
    sp.multiline_latex(
        lhs=phsp_expr,
        rhs=phsp_expr.doit().simplify().subs(sp.Abs(m), m),
    )
)

## Implementation

The implementation is quite similar to {ref}`that of TR-005 <report/005:Generalization>`, the only differences being the additional $\boldsymbol{\rho}$-matrix and the insertion of the coupled width.

In [None]:
def Kij_relativistic(
    m: sp.Symbol,
    M: sp.IndexedBase,
    Gamma: sp.IndexedBase,
    gamma: sp.IndexedBase,
    i: int,
    j: int,
    n_resonances: Union[int, sp.Symbol],
    angular_momentum: Union[int, sp.Symbol] = 0,
) -> sp.Expr:
    """Element K_ij of the relativistic K-matrix as a Sum over resonances R.

    Each resonance R contributes ``g_{R,i} g_{R,j} / (m_R**2 - m**2)``, with
    the coupling ``g_{R,c} = gamma[R, c] * Gamma_{R,c}(m)``. Here
    ``Gamma_{R,c}(m)`` is a CoupledWidth.

    The resonance index ``R`` and the channel masses ``m_a``, ``m_b`` are
    taken from the notebook's global scope, not from the arguments.

    NOTE(review): unlike Eq. (covariant parametrization) in the text, this
    does not divide by sqrt(rho_i * rho_j). Confirm whether that is intended.
    """

    def coupling(channel: int) -> sp.Expr:
        width = CoupledWidth(
            m, M, Gamma, m_a, m_b, angular_momentum, R, channel
        )
        return gamma[R, channel] * width

    numerator = coupling(i) * coupling(j)
    denominator = M[R] ** 2 - m ** 2
    return sp.Sum(numerator / denominator, (R, 0, n_resonances - 1))


def relativistic_k_matrix(
    n_resonances: Union[int, sp.Symbol],
    n_channels: int,
    angular_momentum: Union[int, sp.Symbol] = 0,
) -> sp.Matrix:
    """Build the covariant T-matrix T_hat = K_hat (I - i rho K_hat)^-1.

    Despite the name, the returned matrix is the T-matrix. The symbolic
    K-matrix elements are replaced by ``Kij_relativistic`` sums over
    ``n_resonances`` resonances with coupled widths.

    Args:
        n_resonances: Number of resonances. May be a Symbol (e.g. ``n_R``),
            in which case the sums stay unevaluated.
        n_channels: Number of decay channels, i.e. the matrix dimension.
        angular_momentum: Forwarded to the coupled widths.

    NOTE(review): ``Kij_relativistic`` reads ``m_a``, ``m_b`` and ``R`` from
    the notebook's global scope, not the IndexedBases created here. They
    share names, but their shapes can differ.
    """
    # Define symbols
    m = sp.Symbol("m", real=True)
    M = sp.IndexedBase("m", shape=(n_resonances,))
    Gamma = sp.IndexedBase("Gamma", shape=(n_resonances, n_channels))
    gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))
    m_a = sp.IndexedBase("m_a", shape=(n_channels,))
    m_b = sp.IndexedBase("m_b", shape=(n_channels,))
    # Diagonal phase space matrix. PhaseSpaceFactor expects the mass m
    # itself (it squares m when evaluated), so pass m and not m**2.
    # Passing m**2 would evaluate rho at s = m**4.
    rho = sp.zeros(n_channels, n_channels)
    for i in range(n_channels):
        rho[i, i] = PhaseSpaceFactor(m, m_a[i], m_b[i], i)
    # Define K-matrix and T-matrix
    K = create_symbol_matrix("K", n_channels)
    T = K * (sp.eye(n_channels) - sp.I * rho * K).inv()
    # Substitute the symbolic elements K[i, j] by their parametrization
    return T.subs(
        {
            K[i, j]: Kij_relativistic(
                m=m,
                M=M,
                Gamma=Gamma,
                gamma=gamma,
                i=i,
                j=j,
                n_resonances=n_resonances,
                angular_momentum=angular_momentum,
            )
            for i in range(n_channels)
            for j in range(n_channels)
        }
    )


def create_symbol_matrix(name: str, n: int) -> sp.Matrix:
    """Create an n x n matrix whose elements are ``name[i, j]``.

    The elements are entries of an IndexedBase labelled *name*. Previously
    the label was hard-coded to "K" and *name* was ignored.
    """
    symbol = sp.IndexedBase(name, shape=(n, n))
    return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(n)])

Single channel, one resonance:

In [None]:
# Single channel, one resonance: show the T-matrix element next to the form
# in which only the Sum over resonances has been evaluated (partial_doit
# leaves the wrapped width and phase space expressions intact).
expr = relativistic_k_matrix(n_resonances=1, n_channels=1)[0, 0]
Math(
    sp.multiline_latex(
        lhs=expr,
        rhs=symplot.partial_doit(expr, sp.Sum).simplify(doit=False),
    )
)

Two channels, one resonance:

In [None]:
# Two channels: element [0, 0] with only the Sum over the single resonance
# evaluated.
expr = relativistic_k_matrix(n_resonances=1, n_channels=2)[0, 0]
symplot.partial_doit(expr, sp.Sum).simplify(doit=False)

Single channel, $n_R$ resonances:

In [None]:
# n_resonances is the Symbol n_R here, so the Sum over R stays unevaluated.
relativistic_k_matrix(n_resonances, n_channels=1)[0, 0]

Two channels, $n_R$ resonances:

In [None]:
# Two channels, n_R resonances. The expression is long, so it is rendered
# over multiple lines without a left-hand side.
expr = relativistic_k_matrix(n_resonances, n_channels=2)[0, 0]
Math(sp.multiline_latex("", expr))

## Visualization

Again, let's use {mod}`symplot` to visualize the relativistic $\boldsymbol{K}$-matrix for arbitrary $n_R$.

:::{tip}

{doc}`/report/008` explains the need for {func}`symplot.substitute_indexed_symbols`.

:::

In [None]:
def plot_relativistic_k_matrix(
    n_channels: int,
    n_resonances: int,
    angular_momentum: Union[int, sp.Symbol] = 0,
    title: str = "",
) -> None:
    """Display an interactive plot of the relativistic T-matrix.

    Takes the diagonal element T[i, i] of
    relativistic_k_matrix(n_resonances, n_channels, angular_momentum) with a
    symbolic channel index i. The remaining free symbols are turned into
    sliders with symplot. The figure has two panels:

    - top: |T|^2 along the real axis, evaluated at m + i*epsilon, with one
      curve per channel and red vertical lines at the resonance masses
      m0, m1, ...;
    - bottom: a colour map of Im T on a grid of the complex m-plane
      (evaluated with epsilon = 0). It marks the resonance masses and, when
      the channel masses are sliders, the thresholds m_a + m_b (and
      m_a - m_b when there is room).

    Args:
        n_channels: Number of decay channels (matrix dimension).
        n_resonances: Number of resonances.
        angular_momentum: Forwarded to the coupled widths. A Symbol named
            "L" becomes a slider with range 0 to 8.
        title: Figure title. If empty, a title stating the matrix size and
            the number of resonances is used.

    The slider UI and the colour-map output are shown with display();
    nothing is returned.
    """
    # Imaginary offset for the real-axis curves: T is evaluated at
    # m + i*epsilon, with epsilon controlled by a slider.
    epsilon = sp.Symbol("epsilon")
    # Symbolic channel index. j = i selects the diagonal element T[i, i].
    # "i" becomes a slider, and plot() below overrides it per channel.
    # (Indexing a SymPy Matrix with Symbols presumably yields an unevaluated
    # matrix element.)
    i, j = sp.symbols("i, j", integer=True, negative=False)
    j = i
    expr = relativistic_k_matrix(
        n_resonances, n_channels, angular_momentum=angular_momentum
    ).doit()[i, j]
    # Convert to Symbol: symplot cannot handle IndexedBase
    expr = symplot.substitute_indexed_symbols(expr)
    # m is the global mass Symbol; it equals the one created inside
    # relativistic_k_matrix() (same name, same real=True assumption).
    expr = expr.subs(m, m + epsilon * sp.I)
    # np_expr is a numerical function of m; the remaining free symbols get
    # sliders, whose ranges and values are set in set_default_values().
    np_expr, sliders = symplot.prepare_sliders(expr, m)
    # Reverse lookup: symbol name -> slider argument name. This relies on
    # a private attribute of symplot's slider container.
    symbol_to_arg = {
        symbol: arg for arg, symbol in sliders._arg_to_symbol.items()
    }

    # Set plot domain
    # x: Re m, y: Im m. plot_domain is the real axis for the |T|^2 curves;
    # X + iY is the complex grid for the colour map.
    x_min, x_max = 1e-3, 3
    y_min, y_max = -0.5, +0.5
    z_min, z_max = -2, +2

    plot_domain = np.linspace(x_min, x_max, num=500)
    x_values = np.linspace(x_min, x_max, num=160)
    y_values = np.linspace(y_min, y_max, num=80)
    X, Y = np.meshgrid(x_values, y_values)
    plot_domain_complex = X + Y * 1j

    # Clip Im T to [z_cut_min, z_cut_max] for the colour map, presumably so
    # that values near poles do not dominate the colour scale. Because of how
    # the comparisons are written, NaN values end up at z_cut_max.
    z_cut_min = 0.75 * z_min
    z_cut_max = 0.75 * z_max
    cut_off_min = np.vectorize(lambda z: z if z > z_cut_min else z_cut_min)
    cut_off_max = np.vectorize(lambda z: z if z < z_cut_max else z_cut_max)

    # Set slider values and ranges
    # Default resonance masses: n_resonances evenly spaced points strictly
    # inside (x_min, x_max).
    m0_values = np.linspace(x_min, x_max, num=n_resonances + 2)
    m0_values = m0_values[1:-1]

    def set_default_values() -> None:
        # Slider ranges and initial values. Slider names follow the symbols
        # produced by symplot.substitute_indexed_symbols, e.g. "m0",
        # "\Gamma_{0,1}", "m_a0".
        if "L" in sliders:
            sliders.set_ranges(L=(0, 8))
        sliders.set_ranges(
            {
                "i": (0, n_channels - 1),
                "epsilon": (y_min * 0.2, y_max * 0.2, 0.01),
            }
        )
        for R in range(n_resonances):
            # ranges
            sliders.set_ranges({f"m{R}": (0, 3, 100)})
            for i in range(n_channels):
                sliders.set_ranges(
                    {
                        fR"\Gamma_{{{R},{i}}}": (0, 2, 100),
                        fR"\gamma_{{{R},{i}}}": (0, 5, 100),
                        f"m_a{i}": (0, 1, 0.01),
                        f"m_b{i}": (0, 1, 0.01),
                    }
                )
            # values
            sliders.set_values({f"m{R}": m0_values[R]})
            for i in range(n_channels):
                sliders.set_values(
                    {
                        fR"\Gamma_{{{R},{i}}}": 0.4 + R * 0.2 - i * 0.3,
                        fR"\gamma_{{{R},{i}}}": 2.5 - 0.4 * R + 0.3 * i,
                        f"m_a{i}": (i + 1) * 0.25,
                        f"m_b{i}": (i + 1) * 0.25,
                    }
                )

    set_default_values()

    # Create interactive plots
    controls = Controls(**sliders)
    nrows = 2  # set to 3 for imag+real
    fig, axes = plt.subplots(
        nrows=nrows,
        figsize=(8, nrows * 3.0),
        sharex=True,
        tight_layout=True,
    )
    for ax in axes:
        ax.set_xlim(x_min, x_max)
    if not title:
        title = (
            fR"${n_channels} \times {n_channels}$ $K$-matrix"
            f" with {n_resonances} resonances"
        )
    fig.suptitle(title)

    # 2D plot
    # Top panel: |T|^2 on the real axis, one curve per channel.
    axes[0].set_ylabel("$|T|^{2}$")
    axes[0].set_yticks([])

    def plot(channel: int):
        # Returns a plotting callback with the channel index i fixed to
        # `channel`, overriding the "i" slider value.
        def wrapped(*args, **kwargs) -> np.ndarray:
            kwargs["i"] = channel
            return np.abs(np_expr(*args, **kwargs)) ** 2

        return wrapped

    for i in range(n_channels):
        iplt.plot(
            plot_domain,
            plot(i),
            ax=axes[0],
            controls=controls,
            ylim="auto",
            label=f"channel {i}",
        )
    if n_channels > 1:
        axes[0].legend(loc="upper right")
    mass_line_style = dict(
        c="red",
        alpha=0.3,
    )
    # Vertical lines that follow the resonance-mass sliders m0, m1, ...
    for name in controls.params:
        if not re.match(r"^m[0-9]+$", name):
            continue
        iplt.axvline(controls[name], ax=axes[0], **mass_line_style)

    # 3D plot
    # Lower panel(s): colour map of T over the complex m-plane. This is
    # redrawn from scratch on every slider change (see interactive_output).
    def plot3(**kwargs) -> None:
        # The grid's y-axis already is Im m, so evaluate without the
        # epsilon offset; the slider value is only drawn as a line.
        epsilon = kwargs["epsilon"]
        kwargs["epsilon"] = 0
        Z = np_expr(plot_domain_complex, **kwargs)
        Z_imag = cut_off_min(cut_off_max(Z.imag))
        for ax in axes[1:]:
            ax.clear()
        axes[-1].pcolormesh(X, Y, Z_imag, cmap=cm.coolwarm)
        i = kwargs["i"]
        if n_channels == 1:
            axes[-1].set_title("Im $T$")
        else:
            axes[-1].set_title(f"Im $T$, channel {i}")
        if len(axes) == 3:
            Z_real = cut_off_min(cut_off_max(Z.real))
            axes[-2].pcolormesh(X, Y, Z_real, cmap=cm.coolwarm)
            axes[-2].set_title("Re $T$")
        for ax in axes[1:]:
            ax.axhline(0, linewidth=0.5, c="black", linestyle="dotted")
            if epsilon != 0.0:
                ax.axhline(
                    epsilon,
                    linewidth=0.5,
                    c="blue",
                    linestyle="dotted",
                    label=R"$\epsilon$",
                )
                axes[-1].text(
                    x=x_min + 0.008,
                    y=epsilon + 0.01,
                    s=R"$\epsilon$",
                    c="blue",
                )
            for R in range(n_resonances):
                mass = kwargs[f"m{R}"]
                ax.axvline(mass, **mass_line_style)
            # Channel thresholds, only when the final-state masses are
            # sliders: m_a + m_b (dotted), and m_a - m_b (dashed) if it is
            # far enough from the left edge and from the m_a + m_b line.
            if "m_a0" in kwargs:
                colors = cm.plasma(np.linspace(0, 1, n_channels))
                for i, color in enumerate(colors):
                    m_a = kwargs[f"m_a{i}"]
                    m_b = kwargs[f"m_b{i}"]
                    s_thr = m_a + m_b
                    ax.axvline(s_thr, c=color, linestyle="dotted")
                    ax.text(
                        x=s_thr,
                        y=0.95 * y_min,
                        s=f"$m_{{a{i}}}+m_{{b{i}}}$",
                        c=color,
                        rotation=-90,
                    )
                    m_diff = m_a - m_b
                    x_offset = (x_max - x_min) * 0.015
                    if (
                        m_diff > x_offset + 0.01
                        and s_thr - abs(m_diff) > x_offset
                    ):
                        ax.axvline(
                            m_diff,
                            c=color,
                            linestyle="dashed",
                            alpha=0.5,
                        )
                        ax.text(
                            x=m_diff - x_offset,
                            y=0.95 * y_min,
                            s=f"$m_{{a{i}}}-m_{{b{i}}}$",
                            c=color,
                            rotation=+90,
                        )
            ax.set_ylabel("Im $m$")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor("white")
        for R in range(n_resonances):
            mass = kwargs[f"m{R}"]
            axes[-1].text(
                x=mass + (x_max - x_min) * 0.008,
                y=0.95 * y_min,
                s=f"$m_{R}$",
                c="red",
            )
        axes[-1].set_xlabel("Re $m$")
        fig.canvas.draw_idle()

    # Create GUI
    # One row per resonance: its mass slider and, for a single channel, its
    # Gamma and gamma sliders. All other sliders follow, sorted by label.
    sliders_copy = dict(sliders)
    h_boxes = []
    for R in range(n_resonances):
        buttons = [sliders_copy.pop(f"m{R}")]
        if n_channels == 1:
            buttons.append(
                sliders_copy.pop(symbol_to_arg[fR"\Gamma_{{{R},0}}"])
            )
            buttons.append(
                sliders_copy.pop(symbol_to_arg[fR"\gamma_{{{R},0}}"])
            )
        h_box = ipywidgets.HBox(buttons)
        h_boxes.append(h_box)
    remaining_sliders = sorted(
        sliders_copy.values(), key=lambda s: s.description
    )
    ui = ipywidgets.VBox(h_boxes + remaining_sliders)
    output = ipywidgets.interactive_output(plot3, controls=sliders)
    display(ui, output)


def _translate_to_latex(name: str) -> str:
    """Return the LaTeX form of a symbol name, e.g. "Gamma" -> "\\Gamma"."""
    return sp.latex(sp.Symbol(name))


def to_symbol(idx: sp.Indexed) -> sp.Symbol:
    """Convert an Indexed element such as ``Gamma[0, 1]`` to a plain Symbol.

    Single-index elements get the index appended directly (``m_a[0]`` ->
    ``m_a0``). Multi-index elements get a LaTeX base name and a LaTeX
    subscript (``Gamma[0, 1]`` -> ``\\Gamma_{0,1}``). This matches the
    slider names used by plot_relativistic_k_matrix.
    """
    base_name, _, _ = str(idx).rpartition("[")
    subscript = ",".join(map(str, idx.indices))
    if len(idx.indices) > 1:
        # Greek names become LaTeX commands, e.g. Gamma -> \Gamma. The
        # previous code called an undefined `translate` (NameError).
        base_name = _translate_to_latex(base_name)
        subscript = "_{" + subscript + "}"
    return sp.Symbol(f"{base_name}{subscript}")


def replace_indexed_symbols(expression: sp.Expr) -> sp.Expr:
    """Substitute every Indexed free symbol of *expression* by a Symbol.

    See to_symbol() for how the new symbol names are formed.
    """
    indexed_symbols = [
        symbol
        for symbol in expression.free_symbols
        if isinstance(symbol, sp.Indexed)
    ]
    substitutions = {idx: to_symbol(idx) for idx in indexed_symbols}
    return expression.subs(substitutions)

In [None]:
# Single channel with two resonances. The angular momentum is passed as the
# Symbol L, so it becomes a slider (range 0 to 8).
plot_relativistic_k_matrix(
    n_resonances=2,
    n_channels=1,
    angular_momentum=L,
    title="Relativistic $K$-matrix, single channel",
)

{{ run_interactive }}

In [None]:
# When building the static documentation, save the current figure and embed
# it as an image (presumably because the interactive widgets cannot be
# rendered there).
if STATIC_WEB_PAGE:
    output_file = "009-relativistic-K-matrix-n1-r2.png"
    plt.savefig(output_file, dpi=150)
    display(Image(output_file))