# [TR-005] K-matrix

<!-- cspell:ignore Cayley -->

This report investigates how to implement $K$-matrix dynamics with {doc}`SymPy <sympy:index>`. The challenge is to generate a correct parametrization for an arbitrary **number of coupled channels $n$** and an arbitrary **number of resonances $n_R$**.

We followed the physics as described by {pdg-review}`Resonances` and {cite}`chungPartialWaveAnalysis1995,petersPartialWaveAnalysis2004,meyerMatrixTutorial2008`.

We simply construct an $n \times n$ {class}`sympy.Matrix <sympy.matrices.dense.MutableDenseMatrix>` with {class}`~sympy.core.symbol.Symbol`s as its elements. We then substitute these {class}`~sympy.core.symbol.Symbol`s with certain parametrizations using {meth}`~sympy.core.basic.Basic.subs`. In order to generate symbols for $n_R$ resonances and $n$ channels, we use {doc}`indexed symbols <modules/tensor/indexed>`.

This approach is less elegant and (theoretically) slower than using {class}`~sympy.matrices.expressions.MatrixSymbol`s. That approach is explored in {doc}`/report/007`.

In [None]:
%matplotlib widget
import os
import warnings
from typing import Tuple

import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
import numpy as np
import symplot
import sympy as sp
from IPython.display import Image, Math
from ipywidgets import widgets as ipywidgets
from matplotlib import cm
from mpl_interactions.controller import Controls

warnings.filterwarnings("ignore")
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

## Non-relativistic

The non-relativistic case is simplest and allows us to check whether the case $n_R=1, n=1$ (single resonance, single channel) reduces to a non-relativistic Breit-Wigner function.[^1]

[^1]: Of course, there is no need to work with matrices in this $1 \times 1$ case. To keep things general, however, we keep using matrices.

### Procedure

Ideally, we would use a {class}`~sympy.core.symbol.Symbol` to represent $n$ and $n_R$, but this does not work well in the {class}`~sympy.concrete.summations.Sum` class: a sum with a symbolic limit cannot be unfolded into explicit terms.

In [None]:
n, n_resonances = sp.symbols("n n_R", integer=True, positive=True)
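
As a minimal sketch of this limitation, compare a `Sum` with a symbolic upper limit to one with an integer upper limit. The names `R_demo`, `n_sym`, and `g_demo` are illustrative and not part of the report. With a symbolic limit, `doit()` is expected to leave the sum unevaluated, so its terms cannot be substituted one by one:

```python
import sympy as sp

# Illustrative names, chosen only for this sketch
R_demo = sp.Symbol("R", integer=True, negative=False)
n_sym = sp.Symbol("n_R", integer=True, positive=True)
g_demo = sp.IndexedBase("g")

symbolic_sum = sp.Sum(g_demo[R_demo], (R_demo, 0, n_sym - 1))
explicit_sum = sp.Sum(g_demo[R_demo], (R_demo, 0, 1))

# A symbolic upper limit cannot be unfolded into individual terms
still_a_sum = symbolic_sum.doit()
# An integer upper limit unfolds into g[0] + g[1]
unfolded = explicit_sum.doit()
```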

We therefore set these variables to a specific {obj}`int` value and define some other {class}`~sympy.core.symbol.Symbol`s for the rest of the implementation.[^2]

[^2]: We use {class}`~sympy.core.symbol.Symbol`s as indices, because that renders more nicely.

In [None]:
n = 1
n_resonances = 1
i, j, R = sp.symbols("i j R", integer=True, negative=False)
m = sp.Symbol("m", real=True)
M = sp.IndexedBase("M", shape=(n_resonances,))
Gamma = sp.IndexedBase("Gamma", shape=(n_resonances,))
g = sp.IndexedBase("g", shape=(n_resonances, n))
gamma = sp.IndexedBase("gamma", shape=(n_resonances, n))

The $\boldsymbol{T}$-matrix, which contains the transition amplitudes in which we are interested, is expressed in terms of the $\boldsymbol{K}$-matrix as follows:

$$
\boldsymbol{T} = \boldsymbol{K}(\boldsymbol{I} - i\boldsymbol{K})^{-1}
$$ (T-matrix)
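
As a quick numerical sanity check of Eq. {eq}`T-matrix`, the sketch below assumes the common convention $\boldsymbol{S} = \boldsymbol{I} + 2i\boldsymbol{T}$ and a randomly generated real, symmetric $\boldsymbol{K}$ (both are assumptions of this sketch, not taken from the report). Under those assumptions, the resulting $\boldsymbol{S}$-matrix should be unitary:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Random real, symmetric 2x2 K-matrix (illustrative values only)
A = rng.normal(size=(2, 2))
K_num = (A + A.T) / 2
identity = np.eye(2)

# T = K (I - iK)^-1
T_num = K_num @ np.linalg.inv(identity - 1j * K_num)

# Assumed convention: S = I + 2iT
S_num = identity + 2j * T_num

# Unitarity check: S^dagger S should equal the identity
is_unitary = np.allclose(S_num.conj().T @ S_num, identity)
```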

Eq. {eq}`T-matrix` follows from a [Cayley transformation](https://en.wikipedia.org/wiki/Cayley_transform) of the $\boldsymbol{S}$-matrix, which ensures that the $\boldsymbol{K}$-matrix is real. The challenge is now to choose a correct parametrization. There are several choices, but a common one is the following summation over the resonances $R$:

In [None]:
def Kij(i, j, n_resonances) -> sp.Expr:
    parametrization = (g[R, i] * g[R, j]) / (M[R] ** 2 - m ** 2)
    return sp.Sum(parametrization, (R, 0, n_resonances - 1))

In [None]:
n_R = sp.Symbol("n_R")
kij = Kij(i, j, n_R)
Math(
    "K_{ij} = "
    + f"{sp.latex(kij)} = {sp.latex(kij.subs(n_R, 1).doit())}"
)

where the residue constants are often further parametrized by (see {cite}`chungPartialWaveAnalysis1995`, Eq. (82)):

$$
g_{R,i}=\gamma_{R,i}\sqrt{m_R\Gamma_R}
$$ (residue-constant-g)

In {mod}`sympy`, we now define the $\boldsymbol{K}$-matrix in terms of a {class}`Matrix <sympy.matrices.dense.MutableDenseMatrix>` with {class}`~sympy.tensor.indexed.IndexedBase` instances as elements that can serve as {class}`~sympy.core.symbol.Symbol`s.

In [None]:
K_symbol = sp.IndexedBase("K", shape=(n, n))
K = sp.Matrix([[K_symbol[i, j] for j in range(n)] for i in range(n)])
display(K_symbol[i, j], K)

The $\boldsymbol{T}$-matrix can now be computed from Eq. {eq}`T-matrix`:

In [None]:
T = K * (sp.eye(n) - sp.I * K).inv()
T

Next, we need to substitute the elements $K_{i,j}$ with the parametrization we defined above. Since we chose $n=1$, let's focus only on $i=0$ and $j=0$ and investigate what the $T$-matrix looks like after substituting $K_{i,j}$:

In [None]:
i_eval, j_eval = 0, 0
T_channel = T[i_eval, j_eval]
T_channel_subs = T_channel.subs(
    {
        K[i, j]: Kij(i, j, n_resonances)
        for i in range(n)
        for j in range(n)
    }
).subs({i: i_eval, j: j_eval})
T_channel_subs

Next, we evaluate the {class}`~sympy.concrete.summations.Sum` and substitute the residue constants with Eq. {eq}`residue-constant-g`:

In [None]:
T_channel_subs = T_channel_subs.doit()
T_channel_subs = T_channel_subs.subs(
    {
        g[R, i]: gamma[R, i] * sp.sqrt(M[R] * Gamma[R])
        for R in range(n_resonances)
        for i in range(n)
    }
)
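
Note that the substitution keys are built with concrete integer indices. The sketch below illustrates why, under the assumption that `subs` matches `Indexed` objects structurally. The symbols are redefined here only to keep the example self-contained, and `x` is a placeholder:

```python
import sympy as sp

# Redefined for this sketch, mirroring the symbols defined earlier
R, i = sp.symbols("R i", integer=True, negative=False)
M = sp.IndexedBase("M", shape=(1,))
Gamma = sp.IndexedBase("Gamma", shape=(1,))
g = sp.IndexedBase("g", shape=(1, 1))
gamma = sp.IndexedBase("gamma", shape=(1, 1))
x = sp.Symbol("x")  # placeholder replacement value

expr = g[0, 0] ** 2

# A key with symbolic indices is a different object than g[0, 0]: no match
unchanged = expr.subs(g[R, i], x)

# A key with concrete integer indices does match
replaced = expr.subs({g[0, 0]: gamma[0, 0] * sp.sqrt(M[0] * Gamma[0])})
```

This is why the dictionary comprehension loops over `range(n_resonances)` and `range(n)` instead of using the symbols $R$ and $i$ directly.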

And indeed, the resulting element from the $\boldsymbol{T}$-matrix looks like a non-relativistic Breit-Wigner function!

In [None]:
if n_resonances == 1 or n == 2:
    T_channel_subs = T_channel_subs.simplify()
T_channel_subs
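
For reference, the $n_R=1$, $n=1$ case can also be worked out by hand. Writing the resonance mass as $m_0$ (`M[0]` in the code above) and inserting Eq. {eq}`residue-constant-g` into Eq. {eq}`T-matrix` gives:

$$
K_{00} = \frac{g_{0,0}^2}{m_0^2 - m^2},
\qquad
T_{00} = \frac{K_{00}}{1 - iK_{00}}
= \frac{g_{0,0}^2}{m_0^2 - m^2 - i g_{0,0}^2}
= \frac{\gamma_{0,0}^2 \, m_0 \Gamma_0}{m_0^2 - m^2 - i \gamma_{0,0}^2 \, m_0 \Gamma_0}
$$

For $\gamma_{0,0}=1$, this has the familiar form $m_0\Gamma_0 / (m_0^2 - m^2 - i m_0 \Gamma_0)$.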

### Generalization

The above procedure has been condensed into a function that can handle an arbitrary number of resonances and an arbitrary number of channels.

In [None]:
def k_matrix(
    n_resonances: int,
    n_channels: int,
    *,
    channel: Tuple[int, int] = (0, 0),
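) -> sp.Expr:
    """Compute one T-matrix element for a non-relativistic K-matrix."""
    # Define symbols (R is the summation index over the resonances)
    R = sp.Symbol("R", integer=True, negative=False)
    m = sp.Symbol("m", real=True)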
    M = sp.IndexedBase("m", shape=(n_resonances,))
    Gamma = sp.IndexedBase("Gamma", shape=(n_resonances,))
    gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))
    # Define K-matrix and T-matrix
    K_symbol = sp.IndexedBase("K", shape=(n_channels, n_channels))
    K = sp.Matrix(
        [
            [K_symbol[i, j] for j in range(n_channels)]
            for i in range(n_channels)
        ]
    )
    T = K * (sp.eye(n_channels) - sp.I * K).inv()

    # Define parametrization
    def Kij(i, j) -> sp.Expr:
        g_i = gamma[R, i] * sp.sqrt(M[R] * Gamma[R])
        g_j = gamma[R, j] * sp.sqrt(M[R] * Gamma[R])
        parametrization = (g_i * g_j) / (M[R] ** 2 - m ** 2)
        return sp.Sum(parametrization, (R, 0, n_resonances - 1))

    # Substitute elements
    T_channel = T[channel[0], channel[1]]
    T_channel = T_channel.subs(
        {
            K[i, j]: Kij(i, j)
            for i in range(n_channels)
            for j in range(n_channels)
        }
    )
    # Evaluate summation
    T_channel = T_channel.doit()
    # Replace IndexedBase with Symbols
    T_channel = (
        T_channel.subs(
            {M[R]: sp.Symbol(f"m{R}") for R in range(n_resonances)}
        )
        .subs(
            {
                Gamma[R]: sp.Symbol(f"Gamma{R}")
                for R in range(n_resonances)
            }
        )
        .subs(
            {
                gamma[R, i]: sp.Symbol(fR"\gamma_{{{R},{i}}}")
                for R in range(n_resonances)
                for i in range(n_channels)
            }
        )
    )
    return T_channel

In [None]:
k_matrix(n_resonances=1, n_channels=1).simplify()

In [None]:
k_matrix(n_resonances=2, n_channels=1)

In [None]:
k_matrix(n_resonances=1, n_channels=2).simplify()
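
The symbolic result can be cross-checked numerically. The sketch below is an independent NumPy implementation of the same parametrization. The helper `t_matrix_numeric` and all parameter values are hypothetical and only serve this illustration. For one resonance and one channel, it should agree with the Breit-Wigner form:

```python
import numpy as np


def t_matrix_numeric(m, masses, widths, gammas):
    """Numerically evaluate the non-relativistic T-matrix at mass m.

    Hypothetical helper for this sketch: gammas has shape (n_R, n).
    """
    masses = np.asarray(masses, dtype=float)
    widths = np.asarray(widths, dtype=float)
    gammas = np.asarray(gammas, dtype=float)
    # Residue constants g[R, i] = gamma[R, i] * sqrt(m_R * Gamma_R)
    g = gammas * np.sqrt(masses * widths)[:, np.newaxis]
    # K[i, j] = sum_R g[R, i] * g[R, j] / (m_R**2 - m**2)
    poles = masses**2 - m**2
    K = np.einsum("ri,rj,r->ij", g, g, 1 / poles)
    identity = np.eye(K.shape[0])
    # T = K (I - iK)^-1
    return K @ np.linalg.inv(identity - 1j * K)


# Single resonance, single channel: compare with the Breit-Wigner form
m_val, m0, Gamma0, gamma00 = 1.2, 1.5, 0.3, 0.8
T_single = t_matrix_numeric(m_val, [m0], [Gamma0], [[gamma00]])
breit_wigner = (
    gamma00**2 * m0 * Gamma0
    / (m0**2 - m_val**2 - 1j * gamma00**2 * m0 * Gamma0)
)
agrees = np.isclose(T_single[0, 0], breit_wigner)
```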

### Visualization

Now, let's use {mod}`symplot` to visualize the single-channel $\boldsymbol{K}$-matrix for arbitrary $n_R$.

In [None]:
def plot_k_matrix(
    n_channels: int,
    n_resonances: int,
    channel: Tuple[int, int] = (0, 0),
    render_math: bool = False,
    simplify: bool = False,
    title: str = "",
) -> None:
    # Convert to Symbol: symplot cannot handle IndexedBase
    expr = k_matrix(
        n_channels=n_channels,
        n_resonances=n_resonances,
        channel=channel,
    )
    np_expr, sliders = symplot.prepare_sliders(expr, plot_symbol=m)

    # Set plot domain
    x_min, x_max = 1e-3, 3
    y_min, y_max = -0.5, +0.5
    z_min, z_max = -2, +2

    plot_domain = np.linspace(x_min, x_max, num=500)
    x_values = np.linspace(x_min, x_max, num=160)
    y_values = np.linspace(y_min, y_max, num=80)
    X, Y = np.meshgrid(x_values, y_values)
    plot_domain_complex = X + Y * 1j

    z_cut_min = 0.75 * z_min
    z_cut_max = 0.75 * z_max
    cut_off_min = np.vectorize(
        lambda z: z if z > z_cut_min else z_cut_min
    )
    cut_off_max = np.vectorize(
        lambda z: z if z < z_cut_max else z_cut_max
    )

    # Set slider values and ranges
    m0_values = np.linspace(x_min, x_max, num=n_resonances + 2)
    m0_values = m0_values[1:-1]

    def set_default_values():
        for R in range(n_resonances):
            # ranges
            sliders.set_ranges({f"m{R}": (0, 3, 100)})
            sliders.set_ranges({f"Gamma{R}": (-1, 1, 100)})
            for i in range(n_channels):
                sliders.set_ranges(
                    {fR"\gamma_{{{R},{i}}}": (0, 2, 100)}
                )
            # values
            sliders.set_values({f"m{R}": m0_values[R]})
            sliders.set_values({f"Gamma{R}": (R + 1) * 0.1})
            for i in range(n_channels):
                sliders.set_values({fR"\gamma_{{{R},{i}}}": 1})

    set_default_values()

    # Create interactive plots
    controls = Controls(**sliders)
    nrows = 2  # set to 3 for imag+real
    fig, axes = plt.subplots(
        nrows=nrows,
        figsize=(8, nrows * 3.0),
        sharex=True,
        tight_layout=True,
    )
    if not title:
        title = (
            fR"${n_channels} \times {n_channels}$ $K$-matrix"
            f" with {n_resonances} resonances"
            f" ― channel {channel}"
        )
    fig.suptitle(title)

    # 2D plot
    axes[0].set_ylabel("$|T|^{2}$")
    iplt.plot(
        plot_domain,
        lambda *args, **kwargs: np.abs(np_expr(*args, **kwargs)) ** 2,
        ax=axes[0],
        controls=controls,
    )
    mass_line_style = dict(
        c="red",
        alpha=0.3,
    )
    for name in controls.params:
        if not name.startswith("m"):
            continue
        iplt.axvline(controls[name], ax=axes[0], **mass_line_style)

    # 3D plot
    def plot3(**kwargs):
        Z = np_expr(plot_domain_complex, **kwargs)
        Z_imag = cut_off_min(cut_off_max(Z.imag))
        for ax in axes[1:]:
            ax.clear()
        axes[-1].pcolormesh(X, Y, Z_imag, cmap=cm.coolwarm)
        axes[-1].set_title("Im $T$")
        if len(axes) == 3:
            Z_real = cut_off_min(cut_off_max(Z.real))
            axes[-2].pcolormesh(X, Y, Z_real, cmap=cm.coolwarm)
            axes[-2].set_title("Re $T$")
        for ax in axes[1:]:
            ax.axhline(0, linewidth=0.5, c="black", linestyle="dotted")
            for R in range(n_resonances):
                mass = kwargs[f"m{R}"]
                ax.axvline(mass, **mass_line_style)
            ax.set_ylabel("Im $m$")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor("white")
        for R in range(n_resonances):
            mass = kwargs[f"m{R}"]
            axes[-1].text(
                x=mass + (x_max - x_min) * 0.008,
                y=0.95 * y_min,
                s=f"$m_{R}$",
                c="red",
            )
        axes[-1].set_xlabel("Re $m$")
        fig.canvas.draw_idle()

    # Create GUI
    sliders_copy = dict(sliders)
    h_boxes = []
    for R in range(n_resonances):
        buttons = [
            sliders_copy.pop(f"m{R}"),
            sliders_copy.pop(f"Gamma{R}"),
        ]
        if n_channels == 1:
            symbol_to_arg = {
                symbol: arg
                for arg, symbol in sliders._SliderKwargs__arg_to_symbol.items()
            }
            dummy_name = symbol_to_arg[fR"\gamma_{{{R},0}}"]
            buttons.append(sliders_copy.pop(dummy_name))
        h_box = ipywidgets.HBox(buttons)
        h_boxes.append(h_box)
    remaining_sliders = sorted(
        sliders_copy.values(), key=lambda s: s.description
    )
    ui = ipywidgets.VBox(h_boxes + remaining_sliders)
    output = ipywidgets.interactive_output(plot3, controls=sliders)
    display(ui, output)
    if render_math:
        if simplify:
            expr = expr.simplify()
        display(expr)

In [None]:
plot_k_matrix(
    n_resonances=3, n_channels=1, channel=(0, 0), render_math=True
)

In [None]:
if STATIC_WEB_PAGE:
    output_file = "005-K-matrix-n1-r3.png"
    plt.savefig(output_file, dpi=150)
    display(Image(output_file))
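
The lower panel of the plot shows $\mathrm{Im}\,T$ over the complex $m$ plane. As a non-interactive sketch of that evaluation, the block below uses `sympy.lambdify` directly instead of {mod}`symplot` and hard-codes a single-channel, single-resonance amplitude with $\gamma_{0,0}=1$. The parameter values and the use of `np.clip` for the cut-off are choices made only for this sketch:

```python
import numpy as np
import sympy as sp

m, m0, Gamma0 = sp.symbols("m m0 Gamma0")
# Single-channel, single-resonance amplitude with gamma = 1
expr = m0 * Gamma0 / (m0**2 - m**2 - sp.I * m0 * Gamma0)
func = sp.lambdify((m, m0, Gamma0), expr, "numpy")

# Grid over the complex mass plane (same ranges as in the plot above)
x_values = np.linspace(1e-3, 3, num=160)
y_values = np.linspace(-0.5, +0.5, num=80)
X, Y = np.meshgrid(x_values, y_values)

# Evaluate the amplitude for illustrative parameter values
Z = func(X + 1j * Y, 1.5, 0.3)

# Clip the imaginary part so that the pole does not dominate the color scale
Z_imag = np.clip(Z.imag, -1.5, +1.5)

# Locate the grid point with the largest |T|, which lies closest to the pole
iy, ix = np.unravel_index(np.argmax(np.abs(Z)), Z.shape)
pole_estimate = X[iy, ix] + 1j * Y[iy, ix]
```

With these values, the largest $|T|$ on the grid lies at negative $\mathrm{Im}\,m$, as expected for a pole at $m^2 = m_0^2 - i m_0 \Gamma_0$.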

In [None]:
plot_k_matrix(n_resonances=2, n_channels=2, channel=(0, 0))

In [None]:
if STATIC_WEB_PAGE:
    output_file = "005-K-matrix-n2-r2-00.png"
    plt.savefig(output_file, dpi=150)
    display(Image(output_file))