In [None]:
import os

# Names of the documentation-build environment variables that are set (presumably
# non-empty only when the notebook is executed for the static web page).
STATIC_WEB_PAGE = {name for name in ("EXECUTE_NB", "READTHEDOCS") if name in os.environ}

```{autolink-concat}
```

# [R20] 2-dim K matrix <br>
## Plot of $T$-matrix elements, Argand plot, and Riemann-sheet plot

In [None]:
%pip install -q ampform==0.14.8 plotly==5.17.0 sympy==1.12

## Generate transitions

In [None]:
from collections import defaultdict
from typing import Dict, List, Tuple

import ampform
import graphviz
import sympy as sp
from ampform import ReactionInfo
from ampform.dynamics.builder import TwoBodyKinematicVariableSet
from ampform.io import aslatex
from IPython.display import Latex, Markdown, Math
from qrules.particle import Particle
from qrules.transition import ReactionInfo

In [None]:
from __future__ import annotations

import qrules
from qrules.particle import ParticleCollection
from qrules.transition import ReactionInfo

# Final states of the two channels; each entry is passed to qrules as a list of
# particle names.
FINAL_STATES: list[list[str]] = [
    ["pi0", "eta", "gamma"],
    ["K0", "K~0", "gamma"],
]
# One set of transitions per final state, all restricted to the a0(980) resonance
# so that both channels share the same intermediate state.
REACTIONS: list[ReactionInfo] = [
    qrules.generate_transitions(
        initial_state="J/psi(1S)",
        final_state=final_state,
        # Different quantum numbers
        allowed_intermediate_particles=["a(0)(980)"],
        # Same quantum numbers (alternative selection, kept for reference)
        # allowed_intermediate_particles=["N(1895)+", "N(1650)+"],
        allowed_interaction_types=["strong", "em"],
        formalism="canonical-helicity",
    )
    for final_state in FINAL_STATES
]

## Model formulation
Reactions $J/\psi \rightarrow \pi^0 \eta \gamma$ and $J/\psi \rightarrow K^0 \bar{K}^0 \gamma$, with the $a_0(980)$ as the only allowed intermediate resonance.

In [None]:
from collections import defaultdict

import ampform
import sympy as sp
from ampform.helicity import TwoBodyKinematicVariableSet
from qrules.particle import Particle

# Maps each placeholder symbol X to the (resonance, angular momentum) pairs it
# stands for. Filled as a side effect of `create_dynamics_symbol`.
COLLECTED_X_SYMBOLS: dict[sp.Symbol, set[tuple[Particle, int]]] = defaultdict(set)


def create_dynamics_symbol(
    resonance: Particle, variable_pool: TwoBodyKinematicVariableSet
) -> tuple[sp.Expr, dict[sp.Symbol, float]]:
    """Dynamics builder that replaces a resonance lineshape by a symbol X.

    The symbol is labelled by the angular momentum L of the decay, the spin of the
    resonance (written as S), its charge Q and its parity P. Resonances that share
    these labels therefore share the same symbol. Each (resonance, L) pair is
    recorded in the global COLLECTED_X_SYMBOLS under its symbol.

    Args:
        resonance: The intermediate particle whose dynamics are replaced.
        variable_pool: Kinematic variables of the two-body decay; only its
            ``angular_momentum`` is used.

    Returns:
        The symbol X and an (empty) dict of parameter defaults.
    """
    L = sp.Rational(variable_pool.angular_momentum)
    J = sp.Rational(resonance.spin)
    Q = resonance.charge
    P = sp.Rational(resonance.parity)
    # Note: the spin J of the resonance is written as "S" in the label.
    superscript = f"L={L},S={J}"
    subscript = f"Q={Q:+d},P={int(P):+d}"
    X = sp.Symbol(f"X^{{{superscript}}}_{{{subscript}}}")
    COLLECTED_X_SYMBOLS[X].add((resonance, variable_pool.angular_momentum))
    parameter_defaults: dict[sp.Symbol, float] = {}
    return X, parameter_defaults


# Formulate one amplitude model per reaction. Every intermediate resonance gets the
# placeholder dynamics of `create_dynamics_symbol`; that builder (presumably called
# from `builder.formulate()`) is what fills COLLECTED_X_SYMBOLS.
MODELS = []
for reaction in REACTIONS:
    builder = ampform.get_builder(reaction)
    builder.adapter.permutate_registered_topologies()
    # Builder configuration: scalar initial-state mass, final-state IDs 0, 1, 2
    # listed as stable.
    builder.scalar_initial_state_mass = True
    builder.stable_final_state_ids = [0, 1, 2]
    for resonance in reaction.get_intermediate_particles():
        builder.set_dynamics(resonance.name, create_dynamics_symbol)
    MODELS.append(builder.formulate())

In [None]:
from ampform.io import aslatex
from IPython.display import Math

# Render only the first three amplitudes of the first reaction to keep the output
# short. The LaTeX is the last expression so that the cell actually displays it.
selected_amplitudes = {
    k: v for i, (k, v) in enumerate(MODELS[0].amplitudes.items()) if i < 3
}
Math(aslatex(selected_amplitudes))

## K-matrix dynamics

In [None]:
# Show every placeholder symbol X once, followed by the resonances it represents
# (name, mass and width in GeV).
for X, resonance_info in COLLECTED_X_SYMBOLS.items():
    display(X)
    for res, L in sorted(resonance_info):
        print(f"  {res.name:<20s} {res.mass:>8g} GeV  {res.width:>8g} GeV")

In [None]:
from dataclasses import dataclass


@dataclass
class TwoBodyDecay:  # specific to the channel
    """Pair of final-state particles that defines one channel of the K matrix."""

    # First daughter; its mass is the default of the symbol m_{0,i} of channel i.
    child1: Particle
    # Second daughter; its mass is the default of the symbol m_{1,i} of channel i.
    child2: Particle


# One TwoBodyDecay per reaction (channel), built from final-state IDs 0 and 1.
DECAYS = tuple(
    [
        TwoBodyDecay(child1=rx.final_state[0], child2=rx.final_state[1])
        for rx in REACTIONS
    ]
)
# s is the square of the real symbol m_01.
s = sp.Symbol("m_01", real=True) ** 2

In [None]:
# Merge the parameter defaults of all models; for a symbol that appears in several
# models, the value of the last model wins.
PARAMETERS_DEFAULTS = {
    symbol: default
    for amplitude_model in MODELS
    for symbol, default in amplitude_model.parameter_defaults.items()
}

In [None]:
# Take the (resonance, L) pairs of the first collected X symbol only; any further
# symbols are discarded into `_`. Raises ValueError if nothing was collected.
resonances, *_ = COLLECTED_X_SYMBOLS.values()

## Formulate $K$ matrix

In [None]:
n_channels = len(REACTIONS)
# Symbolic n_channels x n_channels matrices: identity, K matrix and phase-space
# matrix rho.
I = sp.Identity(n_channels)
K, rho = (sp.MatrixSymbol(label, n_channels, n_channels) for label in ("K", "rho"))

In [None]:
from ampform.dynamics import BlattWeisskopfSquared, BreakupMomentumSquared
from ampform.dynamics.builder import TwoBodyKinematicVariableSet
from sympy.matrices.expressions.matexpr import MatrixElement
import numpy as np


def formulate_K_matrix(
    resonances: set[tuple[Particle, int]], n_channels: int
) -> dict[MatrixElement, sp.Expr]:
    """Formulate the elements K[i, j] of the K matrix as sums over resonances.

    Each resonance R with angular momentum L_R contributes

        gamma_Ri * gamma_Rj * B_Ri * B_Rj * m_R * Gamma_R / (s - m_R^2)

    with complex couplings gamma_Ri = g_Ri * exp(i * phi_Ri) and form factors
    B_Ri = sqrt(BlattWeisskopfSquared(L_R, q_i^2 * d_R^2)). Here q_i^2 is the squared
    breakup momentum of channel i with masses m_{0,i} and m_{1,i}.

    With the default phases phi = pi/2, gamma_Ri * gamma_Rj = -g_Ri * g_Rj. The
    default element then equals g_Ri g_Rj B_Ri B_Rj m_R Gamma_R / (m_R^2 - s).

    Side effect: default values of all introduced symbols are added to the global
    PARAMETERS_DEFAULTS (channel masses from DECAYS, resonance mass and width).

    Args:
        resonances: (resonance, angular momentum) pairs, for instance one of the
            sets collected in COLLECTED_X_SYMBOLS.
        n_channels: Dimension of the K matrix.

    Returns:
        Mapping from each element K[i, j] of the global MatrixSymbol K to its
        expression (0 if ``resonances`` is empty).
    """
    s = sp.Symbol("m_01", real=True) ** 2
    Kmatrix_expressions = {}
    for i in range(n_channels):
        for j in range(n_channels):
            resonance_contributions = []
            for res, L_R in resonances:
                m_a_i = sp.Symbol(Rf"m_{{0,{i}}}")
                m_b_i = sp.Symbol(Rf"m_{{1,{i}}}")
                m_a_j = sp.Symbol(Rf"m_{{0,{j}}}")
                m_b_j = sp.Symbol(Rf"m_{{1,{j}}}")
                w_R = sp.Symbol(Rf"\Gamma_{{{res.latex}}}")
                g_Ri = sp.Symbol(Rf"\gamma_{{{res.latex},{i}}}")
                g_Rj = sp.Symbol(Rf"\gamma_{{{res.latex},{j}}}")
                phi_Ri = sp.Symbol(Rf"\phi_{{{res.latex},{i}}}^\gamma")
                phi_Rj = sp.Symbol(Rf"\phi_{{{res.latex},{j}}}^\gamma")
                gamma_Ri = g_Ri * sp.exp(sp.I * phi_Ri)
                gamma_Rj = g_Rj * sp.exp(sp.I * phi_Rj)
                d_R = sp.Symbol(Rf"R_{{{res.latex}}}")
                # BreakupMomentumSquared returns q^2, not q.
                q_squared_i = BreakupMomentumSquared(s, m_a_i, m_b_i)
                q_squared_j = BreakupMomentumSquared(s, m_a_j, m_b_j)
                ff_Ri = sp.sqrt(BlattWeisskopfSquared(L_R, q_squared_i * d_R**2))
                ff_Rj = sp.sqrt(BlattWeisskopfSquared(L_R, q_squared_j * d_R**2))
                m_R = sp.Symbol(Rf"m_{{{res.latex}}}")

                # Default parameter values
                parameter_defaults = {
                    m_a_i: DECAYS[i].child1.mass,
                    m_b_i: DECAYS[i].child2.mass,
                    m_a_j: DECAYS[j].child1.mass,
                    m_b_j: DECAYS[j].child2.mass,
                    w_R: res.width,
                    d_R: 1,
                    m_R: res.mass,
                    g_Ri: 1,
                    g_Rj: 1,
                    phi_Ri: np.pi / 2,
                    phi_Rj: np.pi / 2,
                }
                PARAMETERS_DEFAULTS.update(parameter_defaults)
                # Each coupling gamma_Ri, gamma_Rj enters exactly once.
                expr = (gamma_Ri * gamma_Rj * ff_Ri * ff_Rj * m_R * w_R) / (
                    s - m_R**2
                )
                resonance_contributions.append(expr)
            Kmatrix_expressions[K[i, j]] = sum(resonance_contributions)

    return Kmatrix_expressions


K_expressions = formulate_K_matrix(resonances, n_channels=len(REACTIONS))
K_matrix = K.as_explicit()
# Only the last expression of a cell is rendered, so the LaTeX output comes last.
Math(aslatex(K_expressions))

In [None]:
# Explicit K matrix with its elements substituted by their resonance expressions.
K_matrix.xreplace(K_expressions)

In [None]:
# A __future__ import must be the first statement of the cell; placing it after
# other imports is a SyntaxError.
from __future__ import annotations

import warnings
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import sympy as sp
from ampform.io import aslatex
from ampform.kinematics.phasespace import Kallen
from ampform.sympy import unevaluated
from IPython.display import Math
from ipywidgets import FloatSlider, VBox, interactive_output

warnings.filterwarnings("ignore")


@unevaluated
class RotatedSqrt(sp.Expr):
    """Square root of ``z`` with its branch cut rotated by the angle ``phi``.

    Evaluates to exp(-i*phi/2) * sqrt(z * exp(i*phi)). For phi = 0 this reduces to
    the principal square root sqrt(z).
    """

    z: Any
    phi: Any
    _latex_repr_ = R"\sqrt[{phi}]{{{z}}}"

    def evaluate(self) -> sp.Expr:
        z, phi = self.args
        # Rotate z by phi, take the principal root, then rotate back by phi/2.
        return sp.exp(-phi * sp.I / 2) * sp.sqrt(z * sp.exp(phi * sp.I))


@unevaluated(real=False)
class PhaseSpaceFactor_Rotated(sp.Expr):
    """Two-body phase-space factor with a rotated branch cut.

    Evaluates to RotatedSqrt((s - (m1 + m2)^2) * (s - (m1 - m2)^2) / s^2, phi), i.e.
    2 q(s) / sqrt(s) computed with the square root of `RotatedSqrt`.
    """

    s: Any
    m1: Any
    m2: Any
    # NOTE(review): `phi: 0` uses the literal 0 as the annotation; it does not
    # declare a default value. Confirm whether `phi: Any = 0` was intended.
    phi: 0
    _latex_repr_ = R"\rho^\phi_{{{m1}, {m2}}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2, phi = self.args
        return RotatedSqrt((s - ((m1 + m2) ** 2)) * (s - (m1 - m2) ** 2) / s**2, phi)


# Module-level symbol named phi; the classes above read their angle from self.args.
phi = sp.Symbol("phi")


@unevaluated(real=False)
class PhaseSpaceFactor_Kaellen(sp.Expr):
    """Two-body phase-space factor rho(s) = 2 q(s) / sqrt(s).

    q(s) is `BreakupMomentum`, which uses the principal square root of the Kallen
    function, so unlike `PhaseSpaceFactor_Rotated` there is no branch-cut rotation.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"\rho_{{{m1}, {m2}}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return 2 * BreakupMomentum(s, m1, m2) / sp.sqrt(s)


@unevaluated(real=False)
class PhaseSpaceCM(sp.Expr):
    """Chew-Mandelstam phase-space factor rho^CM(s) = -16 pi i Sigma(s).

    Sigma(s) is the `ChewMandelstam` function for masses m1 and m2.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"\rho^\mathrm{{CM}}_{{{m1},{m2}}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return -16 * sp.pi * sp.I * ChewMandelstam(s, m1, m2)


@unevaluated(real=False)
class ChewMandelstam(sp.Expr):
    """Chew-Mandelstam function Sigma(s) for two particles with masses m1 and m2.

    Sigma(s) = 1/(16 pi^2) * [
        (2 q / sqrt(s)) * log((m1^2 + m2^2 - s + 2 sqrt(s) q) / (2 m1 m2))
        - (m1^2 - m2^2) * (1/s - 1/(m1 + m2)^2) * log(m1 / m2)
    ]
    with q = q(s) the `BreakupMomentum`.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"\Sigma\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        return (
            1
            / (16 * sp.pi**2)
            * (
                (2 * q / sp.sqrt(s))
                * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))
                # This second term vanishes for equal masses (m1 = m2).
                - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
            )
        )


@unevaluated(real=False)
class BreakupMomentum(sp.Expr):
    """Two-body breakup momentum q(s) = sqrt(lambda(s, m1^2, m2^2)) / (2 sqrt(s)).

    lambda is the Kallen function (`Kallen` from ampform); sp.sqrt takes the
    principal square root.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"q\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt(Kallen(s, m1**2, m2**2)) / (2 * sp.sqrt(s))


m1, m2 = sp.symbols("m1 m2")
# s is the square of the real symbol m_01 (same definition as in earlier cells).
s = sp.Symbol("m_01", real=True) ** 2

rho_expr_Kaellen = PhaseSpaceFactor_Kaellen(s, m1, m2)
rho_cm_expr = PhaseSpaceCM(s, m1, m2)
cm_expr = ChewMandelstam(s, m1, m2)
q_expr = BreakupMomentum(s, m1, m2)
kallen = Kallen(*sp.symbols("x:z"))
# Render each unevaluated expression next to its one-level evaluation.
Math(
    aslatex(
        {
            definition: definition.doit(deep=False)
            for definition in (
                rho_expr_Kaellen,
                rho_cm_expr,
                cm_expr,
                q_expr,
                kallen,
            )
        }
    )
)

In [None]:
import numpy as np
import sympy as sp
from sympy.matrices import MatrixSymbol


def formulate_Phsp_matrix(n_channels: int) -> dict[MatrixElement, sp.Expr]:
    """Formulate the diagonal phase-space matrix rho element by element.

    Diagonal elements rho[i, i] are `PhaseSpaceFactor_Rotated` expressions in the
    global ``s`` with the final-state masses m_{0,i} and m_{1,i} of channel i and a
    branch-cut rotation angle ``phi=0``. Off-diagonal elements are zero.

    Side effect: the default values of the channel masses (from the global DECAYS)
    are added to the global PARAMETERS_DEFAULTS.

    Args:
        n_channels: Number of channels, i.e. the dimension of rho.

    Returns:
        Mapping from each element rho[i, j] of the global MatrixSymbol rho to its
        expression.
    """
    matrix_expressions = {}

    for i in range(n_channels):
        for j in range(n_channels):
            if i == j:
                m_a_i = sp.Symbol(Rf"m_{{0,{i}}}")
                m_b_i = sp.Symbol(Rf"m_{{1,{i}}}")
                rho_i = PhaseSpaceFactor_Rotated(s, m_a_i, m_b_i, phi=0)
                matrix_expressions[rho[i, j]] = rho_i
                parameter_defaults = {
                    m_a_i: DECAYS[i].child1.mass,
                    m_b_i: DECAYS[i].child2.mass,
                }
                PARAMETERS_DEFAULTS.update(parameter_defaults)
            else:
                matrix_expressions[rho[i, j]] = 0

    return matrix_expressions


# n_channels (defined together with the matrix symbols) equals len(REACTIONS).
Phsp_expressions = formulate_Phsp_matrix(n_channels=n_channels)
# Show the explicit, evaluated phase-space matrix.
rho.as_explicit().xreplace(Phsp_expressions).doit()

## $T$ matrix 
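In this section, the $T$ matrix is formulated from the $K$ matrix and the phase-space matrix $\rho$ as $T = (I - iK\rho)^{-1}K$. Its element $T_{00}$, which belongs to the $\pi^0\eta$ channel, is visualized below to investigate the lineshape.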

In [None]:
# T matrix from the K matrix and the phase-space matrix: T = (I - i K rho)^{-1} K
T = (I - sp.I * K * rho).inv() * K
# Explicit n_channels x n_channels matrix whose elements are still expressed in
# terms of the K[i, j] and rho[i, j] matrix elements.
T_matrix = T.as_explicit()
T_matrix

## Merge dictionaries

In [None]:
# Substitution table for all K-matrix and phase-space matrix elements; a
# phase-space entry wins if the same key occurs in both.
combined_expressions = dict(K_expressions)
combined_expressions.update(Phsp_expressions)

In [None]:
from tensorwaves.function.sympy import create_parametrized_function
from tensorwaves.interface import ParametrizedFunction

%pip install tensorwaves[jax]


def create_function(expr: sp.Expr) -> ParametrizedFunction:
    """Turn ``expr`` into a JAX-backed parametrized function of ``m_01``.

    Every free symbol other than those of ``s`` is a parameter; its default is taken
    from PARAMETERS_DEFAULTS when available.
    """
    free_parameters = set(expr.free_symbols) - s.free_symbols
    defaults = {}
    for symbol, default in PARAMETERS_DEFAULTS.items():
        if symbol in free_parameters:
            defaults[symbol] = default
    return create_parametrized_function(expr, defaults, backend="jax")


def _create_function_matrix(matrix) -> np.ndarray:
    """Return an object array ``funcs`` with ``funcs[i, j]`` the function of ``matrix[i, j]``.

    Each element is substituted with ``combined_expressions``, evaluated with
    ``doit()`` and converted by `create_function`.
    """
    return np.array(
        [
            [
                create_function(matrix[i, j].xreplace(combined_expressions).doit())
                for j in range(n_channels)  # inner loop: column index
            ]
            for i in range(n_channels)  # outer loop: row index
        ],
        dtype=object,
    )


T_funcs = _create_function_matrix(T_matrix)
K_funcs = _create_function_matrix(K_matrix)
T_funcs[0, 0].parameters

## Create widgets

In [None]:
import ipywidgets as w

matrix_index = list(range(n_channels))
sliders = {
    "i": w.RadioButtons(
        description="T matrix K matrix index 0", options=matrix_index
    ),
    "j": w.RadioButtons(
        description="T matrix K matrix index 1", options=matrix_index
    ),
    R"\pm": w.RadioButtons(description=R"Sign", options=[-1, +1]),
    "complex_rendering": w.RadioButtons(
        description="Complex Rendering:", options=["imag", "real"]
    ),
    "z_cutoff": w.FloatSlider(
        description=R"z_cutoff", max=1, value=0.5, continuous_update=False
    ),
}


def _create_parameter_slider(par: str, value) -> w.FloatSlider:
    """Create a slider for the function parameter ``par`` starting at ``value``.

    The range depends on the parameter name: masses ``m_{...}`` get [0.01, 4],
    coupling phases ``\\phi_{...}`` get [0, pi] in steps of pi/8. Other parameters
    get [0, 2*value], or a non-empty range that still contains ``value`` when it
    is zero or negative.
    """
    value = complex(value).real
    step_ = 0.01
    if par.startswith("m_{"):
        min_, max_ = 0.01, 4
    elif par.startswith(R"\phi"):
        # Phase symbols are named \phi_{R,i}^\gamma (LaTeX), hence the backslash.
        min_, max_ = 0, np.pi
        step_ = np.pi / 8
    else:
        min_ = min(0, 2 * value)
        max_ = 2 * value if value > 0 else 1
    return w.FloatSlider(
        continuous_update=False,
        description=Rf"${par}$",
        max=max_,
        min=min_,
        step=step_,
        value=value,
    )


# One slider per distinct parameter of the T-matrix functions (all elements share
# the same defaults, so the first occurrence is enough).
for func in T_funcs.flat:
    for par, value in func.parameters.items():
        if par not in sliders:
            sliders[par] = _create_parameter_slider(par, value)

slider_group = w.VBox(list(sliders.values()))

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from ipywidgets import FloatSlider, VBox, interactive_output
from matplotlib import cm
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata
import jax.numpy as jnp

# Three stacked panels: Argand plot, T-matrix lineshape, K-matrix lineshape.
fig, axs = plt.subplots(
    nrows=3, ncols=1, figsize=(7, 10), tight_layout=True, sharex=False
)
ax_1, ax_2, ax_3 = axs

# Values of m_01 at which the T- and K-matrix functions are evaluated.
x_2D = np.linspace(-2.0, 2.0, 8000)
data = {"m_01": x_2D}

# Line artists: created by the first call of `plot()` and updated afterwards.
LINES: list[Line2D] | None = None


def set_parameters_KT(func, parameters):
    """Update ``func`` with only those entries of ``parameters`` that it defines.

    Widget values belonging to other functions or to non-parameter controls are
    dropped before ``func.update_parameters`` is called.
    """
    known = func.parameters
    func.update_parameters(
        dict(filter(lambda item: item[0] in known, parameters.items()))
    )


# Two-body threshold m_{0,i} + m_{1,i} of every channel i, read from the current
# parameter values of the diagonal T-matrix element T_ii (which depends on the
# phase space of channel i). Works for any number of channels, including one.
selected_threshs = [
    T_funcs[i, i].parameters[f"m_{{0,{i}}}"] + T_funcs[i, i].parameters[f"m_{{1,{i}}}"]
    for i in range(n_channels)
]


# Styles of the line artists created by `plot()` as (axes, colour, label). The
# order must match the order of the data tuples built inside `plot()`.
_LINE_STYLES = (
    (ax_1, "black", R"Real"),
    (ax_2, "red", R"Abs"),
    (ax_2, "blue", R"Imag"),
    (ax_2, "Magenta", R"Real"),
    (ax_3, "red", R"Abs"),
    (ax_3, "blue", R"Imag"),
    (ax_3, "Magenta", R"Real"),
)
# Resonance-mass parameter names, following the m_{<latex>} naming of
# `formulate_K_matrix`; sorted so that the marker order is stable across calls.
_RESONANCE_MASS_NAMES = sorted({Rf"m_{{{res.latex}}}" for res, _ in resonances})


def plot(*, i, j, z_cutoff, complex_rendering, **parameters):
    """Redraw the Argand plot and the T- and K-matrix lineshapes of element (i, j).

    Called by `interactive_output` whenever a widget changes. The first call creates
    the line artists and stores them in the global LINES; later calls only update
    their data.

    ``z_cutoff`` and ``complex_rendering`` are received from the widgets but are not
    used yet. The remaining keyword arguments are the parameter sliders and the sign
    selector; keys unknown to a function are filtered out by `set_parameters_KT`.
    """
    global LINES
    data[R"\pm"] = parameters[R"\pm"]
    ax_1.set_title(R"Argand plot")
    ax_1.set_xlabel(R"$Re(T)$")
    ax_1.set_ylabel(R"$Im(T)$")
    ax_2.set_title(Rf"T matrix element for channel ${{{i},{j}}}$")
    ax_2.set_xlabel(R"$m$ [GeV]")
    ax_2.set_ylabel(Rf"$T_{{{i},{j}}}$")
    ax_3.set_title(Rf"K matrix element for channel ${{{i},{j}}}$")
    ax_3.set_xlabel(R"$m$ [GeV]")
    ax_3.set_ylabel(Rf"$K_{{{i},{j}}}$")
    set_parameters_KT(T_funcs[i, j], parameters)
    set_parameters_KT(K_funcs[i, j], parameters)
    y_T = T_funcs[i, j](data)
    y_K = K_funcs[i, j](data)

    # (x, y) data per line, in the same order as _LINE_STYLES ("Abs" is |.|^2).
    line_data = [
        (y_T.real, y_T.imag),
        (x_2D, np.abs(y_T) ** 2),
        (x_2D, y_T.imag),
        (x_2D, y_T.real),
        (x_2D, np.abs(y_K) ** 2),
        (x_2D, y_K.imag),
        (x_2D, y_K.real),
    ]

    function_parameters = T_funcs[i, j].parameters
    resonance_masses = {
        name: complex(function_parameters[name]).real
        for name in _RESONANCE_MASS_NAMES
        if name in function_parameters
    }

    if LINES is None:
        LINES = [
            ax.plot(x_values, y_values, c=color, lw=2, label=label)[0]
            for (ax, color, label), (x_values, y_values) in zip(
                _LINE_STYLES, line_data
            )
        ]
        # Resonance masses are marked on the mass axis of the T-matrix panel.
        for color_index, (name, mass) in enumerate(resonance_masses.items()):
            LINES.append(
                ax_2.axvline(
                    mass,
                    linestyle="dashed",
                    color=f"C{color_index}",
                    label=Rf"${name}$",
                ),
            )
        ax_2.legend()
        ax_3.legend()
    else:
        for line, (x_values, y_values) in zip(LINES, line_data):
            line.set_data(x_values, y_values)
        mass_lines = LINES[len(line_data) :]
        for line, mass in zip(mass_lines, resonance_masses.values()):
            # axvline uses axes coordinates (0 to 1) for y.
            line.set_data([mass, mass], [0, 1])

    for ax in (ax_1, ax_2, ax_3):
        ax.relim()
        ax.autoscale_view()
    fig.canvas.draw()


# Output widget that reruns `plot` whenever one of the controls changes.
output = interactive_output(plot, controls=sliders)
UI = VBox([slider_group])
display(VBox([UI, output]))