In [None]:
import os

# Subset of the docs-build environment variables that are currently set
# (empty set when neither EXECUTE_NB nor READTHEDOCS is defined).
STATIC_WEB_PAGE = {name for name in ("EXECUTE_NB", "READTHEDOCS") if name in os.environ}

```{autolink-concat}
```

::::{margin}
:::{card} Visualization of the second Riemann sheet with a rotated branch cut
TR-979
^^^
To reproduce the Riemann sheets shown in [Fig.&nbsp;50.1](https://pdg.lbl.gov/2023/reviews/rpp2023-rev-resonances.pdf#page=2) and [50.2](https://pdg.lbl.gov/2023/reviews/rpp2023-rev-resonances.pdf#page=4) of the PDG review on resonances, we derive how to get from the first (physical) sheet to the second (unphysical) sheet for an amplitude computed within the K-matrix formalism.
:::
::::

# Riemann sheets for one channel 

In [None]:
%pip install -q ampform==0.14.8 plotly==5.17.0 sympy==1.12

In [None]:
from __future__ import annotations

import warnings
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.io import aslatex
from ampform.kinematics.phasespace import Kallen
from ampform.sympy import unevaluated
from IPython.display import Math, display
from ipywidgets import FloatSlider, VBox, interactive_output

warnings.filterwarnings("ignore")

## Phase space factors

In [None]:
@unevaluated(real=False)
class PhaseSpaceFactorRotated(sp.Expr):
    """Two-body phase space factor built on a square root with a rotated branch cut.

    Evaluates to RotatedSqrt((s - (m1 + m2)^2) * (s - (m1 - m2)^2), phi) / s,
    where phi is the rotation angle passed on to RotatedSqrt.
    """

    s: Any
    m1: Any
    m2: Any
    phi: Any
    _latex_repr_ = R"\rho^{{{phi}}}_{{{m1}, {m2}}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2, phi = self.args
        # Product of the threshold and (m1 - m2)^2 factors under the square root.
        radicand = (s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2)
        return RotatedSqrt(radicand, phi) / s


@unevaluated
class RotatedSqrt(sp.Expr):
    """Square root of z whose branch cut is rotated by the angle phi.

    Evaluates to exp(-i phi / 2) * sqrt(z * exp(i phi)): z is rotated by phi
    before taking the principal square root, and half the rotation is undone
    afterwards.
    """

    z: Any
    phi: Any
    _latex_repr_ = R"\sqrt[{phi}]{{{z}}}"

    def evaluate(self) -> sp.Expr:
        z, phi = self.args
        phase_correction = sp.exp(-phi * sp.I / 2)
        return phase_correction * sp.sqrt(z * sp.exp(phi * sp.I))


# Show each rotated-branch-cut expression next to its one-level expansion.
s, m1, m2, z, phi = sp.symbols(["s", "m1", "m2", "z", "phi"])
rho_expr_rot = PhaseSpaceFactorRotated(s, m1, m2, phi)
sqrt_expr = RotatedSqrt(z, phi)
Math(aslatex({expr: expr.doit(deep=False) for expr in (rho_expr_rot, sqrt_expr)}))

In [None]:
@unevaluated(real=False)
class PhaseSpaceFactorCM(sp.Expr):
    """Phase space factor derived from the Chew-Mandelstam function.

    Evaluates to -16 * pi * i * ChewMandelstam(s, m1, m2).
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"\rho^\mathrm{{CM}}_{{{m1},{m2}}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        prefactor = -16 * sp.pi * sp.I
        return prefactor * ChewMandelstam(s, m1, m2)


@unevaluated(real=False)
class ChewMandelstam(sp.Expr):
    """Chew-Mandelstam function Sigma(s) for two particles with masses m1 and m2.

    Sigma(s) = 1 / (16 pi^2) * [ (2 q / sqrt(s)) * log((m1^2 + m2^2 - s + 2 sqrt(s) q) / (2 m1 m2))
                                 - (m1^2 - m2^2) * (1/s - 1/(m1 + m2)^2) * log(m1 / m2) ],
    with q the break-up momentum.
    """

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"\Sigma\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        log_term = (2 * q / sp.sqrt(s)) * sp.log(
            (m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2)
        )
        mass_term = (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
        return 1 / (16 * sp.pi**2) * (log_term - mass_term)


@unevaluated(real=False)
class BreakupMomentum(sp.Expr):
    """Break-up momentum q(s) = sqrt(Kallen(s, m1^2, m2^2)) / (2 sqrt(s))."""

    s: Any
    m1: Any
    m2: Any
    _latex_repr_ = R"q\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        numerator = sp.sqrt(Kallen(s, m1**2, m2**2))
        denominator = 2 * sp.sqrt(s)
        return numerator / denominator


# Show the Chew-Mandelstam building blocks next to their one-level expansions.
s, m1, m2 = sp.symbols(["s", "m1", "m2"])
rho_cm_expr = PhaseSpaceFactorCM(s, m1, m2)
cm_expr = ChewMandelstam(s, m1, m2)
q_expr = BreakupMomentum(s, m1, m2)
kallen = Kallen(*sp.symbols("x:z"))
Math(
    aslatex(
        {expr: expr.doit(deep=False) for expr in (rho_cm_expr, cm_expr, q_expr, kallen)}
    )
)

## T matrix

### First sheet

In [None]:
class DiagonalMatrix(sp.DiagonalMatrix):
    """Diagonal matrix expression whose LaTeX output is that of the wrapped matrix.

    Only LaTeX printing is customized: the printer is applied to the wrapped
    argument, so the diagonal wrapper does not show up in rendered formulas.
    """

    def _latex(self, printer):
        wrapped_matrix = self.args[0]
        return printer._print(wrapped_matrix)


# Single-channel case: every matrix below is n_channels x n_channels = 1x1.
n_channels = 1
I = sp.Identity(n_channels)
K = sp.MatrixSymbol("K", n_channels, n_channels)
# Diagonal placeholder for the Chew-Mandelstam phase-space matrix; the entry is
# substituted with a concrete expression later via `definitions`.
CM = DiagonalMatrix(sp.MatrixSymbol(R"\rho^\mathrm{CM}", n_channels, n_channels))
# First-sheet K-matrix amplitude: T = (I - i K CM)^{-1} K.
T1 = (I - sp.I * K * CM).inv() * K
T1

In [None]:
# Expand the matrix expression into an explicit matrix and display its (0, 0) entry.
T1_explicit = T1.as_explicit()
T1_explicit[0, 0]

In [None]:
# Resonance parameters (the width w0 is rendered as Gamma) plus the branch-cut angle phi.
m0, w0, gamma, phi = sp.symbols(["m0", "Gamma", "gamma", "phi"])
# Argument order used for lambdify: s first, then all model parameters.
symbols = sp.Tuple(s, m1, m2, m0, w0, gamma, phi)

In [None]:
# Concrete single-channel model: a one-pole K matrix and the Chew-Mandelstam phase space.
definitions = {
    K[0, 0]: gamma**2 * m0 * w0 / (s - m0**2),
    CM[0, 0]: -PhaseSpaceFactorCM(s, m1, m2),
}
T1_expr = T1_explicit.xreplace(definitions)
T1_expr[0, 0].simplify(doit=False)

### Second Riemann sheet

In [None]:
# Second sheet from the first: T_II^{-1} = T_I^{-1} + 2 i rho, where rho is a
# diagonal phase-space placeholder (later replaced by the rotated phase space factor).
rho = DiagonalMatrix(sp.MatrixSymbol("rho", n_channels, n_channels))
T2 = (T1.inv() + 2 * sp.I * rho).inv()
T2

In [None]:
# Expand the second-sheet matrix expression and display its (0, 0) entry.
T2_explicit = T2.as_explicit()
T2_explicit[0, 0]

In [None]:
# Extend a copy of the substitutions with the rotated phase space factor for rho.
definitions = dict(definitions)
definitions[rho[0, 0]] = PhaseSpaceFactorRotated(s, m1, m2, phi)

In [None]:
# Substitute the model (including the rotated phase space) into the second-sheet amplitude.
T2_expr = T2_explicit.xreplace(definitions)
T2_expr[0, 0].simplify(doit=False)

In [None]:
# Re-display the first-sheet amplitude for comparison with the second-sheet result above.
T1_expr[0, 0].simplify(doit=False)

## Visualization

In [None]:
# Numerical versions of both sheets; the argument order follows `symbols` (s first).
T1_func = sp.lambdify(symbols, T1_expr[0, 0].doit())
T2_func = sp.lambdify(symbols, T2_expr[0, 0].doit())
parameter_defaults = {
    m1: 0.1,
    m2: 0.7,
    m0: 1.5,
    w0: 0.8,
    gamma: 1,
    phi: 0,
}
# Plain Python values for every lambdified argument after s, in `symbols` order.
# Looked up directly instead of eval(str(...)) of a SymPy Tuple, which round-trips
# the numbers through a 15-digit string representation and an `eval` call.
args = tuple(parameter_defaults[symbol] for symbol in symbols[1:])
s_thr = ((m1 + m2) ** 2).xreplace(parameter_defaults)

In [None]:
%matplotlib widget


def plot3d(phi, T_max) -> None:
    """Draw (or refresh) Im T on both halves of the complex s plane.

    The upper half (Im s >= 0) shows the first-sheet amplitude ``T1_func`` evaluated
    on ``Zp``; the lower half shows the second-sheet amplitude ``T2_func`` evaluated
    on ``Zn``. The two meshes are created on the first call and updated in place on
    later calls.

    Args:
        phi: Branch-cut rotation angle; replaces the last entry of the global ``args``.
        T_max: Symmetric colour limit; the colour scale spans [-T_max, +T_max].
    """
    global DATA
    rotated_args = [*args[:-1], phi]
    # "neg" = lower half plane (second sheet), "pos" = upper half plane (first sheet).
    im_values = {
        "neg": T2_func(Zn, *rotated_args).imag,
        "pos": T1_func(Zp, *rotated_args).imag,
    }
    if DATA is not None:
        for key in ("neg", "pos"):
            DATA[key].set_array(im_values[key])
    else:
        DATA = {
            "neg": ax.pcolormesh(X, -Y, im_values["neg"], **style),
            "pos": ax.pcolormesh(X, +Y, im_values["pos"], **style),
        }
    for key in ("neg", "pos"):
        DATA[key].set_clim(-T_max, +T_max)
    fig.canvas.draw_idle()


# Complex s grid: Zp covers the upper half plane (Im s >= 0), Zn its mirror image below.
X, Y = np.meshgrid(
    np.linspace(0, 5, num=200),
    np.linspace(0, 4, num=100),
)
Zn = X - Y * 1j
Zp = X + Y * 1j

# Initial symmetric colour limit. It also seeds the T_max slider below, so the
# initial mesh style and the slider cannot disagree.
T_max = 1.0
DATA = None  # mesh handles, created lazily by plot3d
style = dict(
    vmin=-T_max,
    vmax=+T_max,
    cmap=plt.cm.coolwarm,
)

fig, ax = plt.subplots()
# Dotted vertical markers at (m1 - m2)^2 and at the threshold (m1 + m2)^2.
for threshold in ((m1 - m2) ** 2, (m1 + m2) ** 2):
    ax.axvline(
        float(threshold.xreplace(parameter_defaults)),
        c="black",
        ls="dotted",
        zorder=99,
    )
ax.set_xlabel(R"$\mathrm{Re}\,s$")
ax.set_ylabel(R"$\mathrm{Im}\,s$")
ax.set_title(R"$\mathrm{Im}\,T$ (upper: sheet I, lower: sheet II)")

sliders = dict(
    phi=FloatSlider(
        description="phi",
        min=-2 * np.pi,
        max=+2 * np.pi,
        step=np.pi / 16,
    ),
    T_max=FloatSlider(
        description="T_max",
        min=0.01,
        max=5.0,
        step=0.01,
        value=T_max,
    ),
)

ui = VBox(tuple(sliders.values()))
output = interactive_output(plot3d, controls=sliders)
fig.tight_layout()
display(ui, output)