In [None]:
# Render inline-backend figures as SVG.
%config InlineBackend.figure_formats = ['svg']
import os

# Set of those names among EXECUTE_NB / READTHEDOCS that are present in the
# environment (empty if neither is set). Not referenced in the visible cells;
# presumably used to detect a documentation build — TODO confirm.
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

```{autolink-concat}
```

````{margin}
```{spec} Square root over arrays with negative values
:id: TR-999
:status: Implemented
:tags: lambdification;sympy

This report has been implemented in {pr}`ComPWA/tensorwaves#284`.

```
````

# Riemann sheets

<!-- cspell:disable -->

In [None]:
# Quietly (-q) install the pinned versions of ampform, plotly and sympy.
%pip install -q ampform==0.14.6 plotly==5.17.0 sympy==1.12

This is a follow-up to {need}`TR-004`. Here, we investigate and reproduce the Riemann sheets shown in [Fig.&nbsp;50.1](https://pdg.lbl.gov/2023/reviews/rpp2022-rev-resonances.pdf#page=2) and [50.2](https://pdg.lbl.gov/2023/reviews/rpp2022-rev-resonances.pdf#page=4) of the PDG review on resonances.

First, we formulate the $T$&nbsp;matrix in terms of a $K$&nbsp;matrix. There are two ways to do this: we associate the one with the $+$&nbsp;sign with **Sheet I** and the one with the $-$&nbsp;sign with **Sheet II**.

In [None]:
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np
import plotly.graph_objects as go
import sympy as sp
from ampform.io import aslatex
from ampform.sympy import UnevaluatedExpression, implement_doit_method
from ampform.sympy.math import NumPyPrintable, create_expression
from IPython.display import Math

if TYPE_CHECKING:
    from sympy.printing.numpy import NumPyPrinter
    from sympy.printing.printer import Printer
    from sympy.printing.pycode import PythonCodePrinter

# Ignore all warnings for the rest of the session (presumably to keep the
# rendered notebook output clean).
# NOTE(review): this also hides potentially useful warnings, e.g. from NumPy
# when evaluating the lambdified functions — consider a narrower filter.
warnings.filterwarnings("ignore")

In [None]:
# Symbolic n×n K and ρ matrices and the two T-matrix definitions built from them.
n = 1
K = sp.MatrixSymbol("K", n, n)
ρ = sp.MatrixSymbol("rho", n, n)
# Identity matrix of the same dimension n as K and ρ, so it stays consistent
# if n is changed. (Not used in the visible cells.)
I = sp.Identity(n)
# Sheet I uses (K + iρ)^-1, Sheet II uses (K - iρ)^-1.
T1 = (K + sp.I * ρ).inv()
T2 = (K - sp.I * ρ).inv()

In [None]:
# Render both T-matrix definitions as LaTeX, keyed by their sheet labels.
T1_symbol, T2_symbol = sp.symbols("T^{I} T^{II}")
src = aslatex(dict(zip((T1_symbol, T2_symbol), (T1, T2))))
Math(src)

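Note that the above inversion is equivalent to [Eq.&nbsp;(50.31)](https://pdg.lbl.gov/2023/reviews/rpp2022-rev-resonances.pdf#page=4) (omitting the form factors&nbsp;$n$),

$$
T = (1 \pm iK\rho)^{-1}K,
$$

once the $K$ in that equation is identified with the inverse of the $K$ used above, since $(1 \pm iK\rho)^{-1}K = \left(K^{-1} \pm i\rho\right)^{-1}$.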

As an aside, we also define a special expression class for a square root where you can choose the sign for negative arguments. This can be used later for the phase space factor&nbsp;$\rho$.

In [None]:
class SignedSqrt(NumPyPrintable):
    """Square root whose sign for negative arguments can be chosen.

    For a non-negative argument ``x`` this is the ordinary square root. For a
    negative argument it is ``sign * I * sqrt(-x)``, so ``sign`` selects which
    of the two imaginary roots is returned.

    Args:
        x: Argument of the square root.
        sign: Sign applied for negative arguments, for example ``+1``, ``-1``,
            or a symbol standing for either.
    """

    is_commutative = True
    # NOTE(review): declared non-real, although the definition reduces to
    # sqrt(x), a real value, for x >= 0 — confirm this assumption is intended.
    is_real = False

    def __new__(cls, x, sign, *args, **kwargs):
        x = sp.sympify(x)
        expr = create_expression(cls, x, sign, *args, **kwargs)
        # A numeric argument can be evaluated right away via the Piecewise form.
        if isinstance(x, sp.Number):
            return expr.get_definition()
        return expr

    def _latex(self, printer: Printer, *args) -> str:
        x = printer._print(self.args[0])
        sign = _render_sign(self.args[1], printer)
        return Rf"\sqrt[{sign}]{{{x}}}"

    def _numpycode(self, printer: NumPyPrinter, *args) -> str:
        # NumPy code is generated from the symbolic Piecewise definition.
        return self.__print_complex(printer)

    def _pythoncode(self, printer: PythonCodePrinter, *args) -> str:
        # Negative Python real numbers get the explicitly signed imaginary root.
        # Everything else (including complex input) gets the principal root
        # from cmath.
        printer.module_imports["cmath"].add("sqrt as csqrt")
        # Resolve the real square root through the printer so that its import
        # is registered (and printed as `math.sqrt` when the printer uses fully
        # qualified names) instead of emitting an unregistered bare `sqrt`.
        real_sqrt = printer._module_format("math.sqrt")
        x = printer._print(self.args[0])
        sign = printer._print(self.args[1])
        return (
            f"((({sign}*1j*{real_sqrt}(-{x}))"
            f" if isinstance({x}, (float, int)) and ({x} < 0)"
            f" else (csqrt({x}))))"
        )

    def __print_complex(self, printer: Printer) -> str:
        expr = self.get_definition()
        return printer._print(expr)

    def get_definition(self) -> sp.Piecewise:
        """Return the explicit Piecewise form of this square root."""
        x: sp.Expr = self.args[0]
        sign: sp.Expr = self.args[1]
        return sp.Piecewise(
            (sign * sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )


def _render_sign(sign, printer: Printer) -> str:
    if sign == +1:
        return "+"
    if sign == -1:
        return "-"
    return printer._print(sign)

In [None]:
# For a symbolic argument both sign conventions stay unevaluated.
x = sp.Symbol("x")
SignedSqrt(x, +1) + SignedSqrt(x, -1)

This gives us all the ingredients to formulate expressions for the parametrization of the matrix elements.

In [None]:
@implement_doit_method
class EqualMassPhspFactor(UnevaluatedExpression):
    """Phase-space factor ρ(s) for two particles of equal mass ``m``.

    Evaluates to ``SignedSqrt(1 - 4 * m**2 / s, sign)``, so ``sign`` selects
    the value of the square root where ``1 - 4 * m**2 / s`` is negative.

    Args:
        s: Mandelstam variable.
        m: Mass of each of the two particles.
        sign: Sign passed on to the signed square root.
    """

    is_commutative = True
    is_real = False

    def __new__(cls, s, m, sign, **hints) -> EqualMassPhspFactor:
        return create_expression(cls, s, m, sign, **hints)

    def evaluate(self) -> sp.Expr:
        s, m, sign = self.args
        return SignedSqrt(1 - 4 * m**2 / s, sign)

    def _latex(self, printer: Printer, *args) -> str:
        # Annotated with the base Printer (imported under TYPE_CHECKING), like
        # SignedSqrt._latex; LatexPrinter is not imported in this file.
        s = printer._print(self.args[0])
        sign = _render_sign(self.args[2], printer)
        return Rf"\rho^{{{sign}}}\left({s}\right)"

In [None]:
# Complex Mandelstam variable and the real, non-negative model parameters.
s = sp.Symbol("s", complex=True)
m0, m1, Γ = sp.symbols("m0 m1 Gamma", real=True, nonnegative=True)
# Single K-matrix element, linear in s and vanishing at s = m0**2.
K_expr = (m0**2 - s) / (m0 * Γ)
# Symbolic sign that the phase-space square root uses for negative arguments.
sign = sp.Symbol(R"\pm", integer=True)
ρ_expr = EqualMassPhspFactor(s, m1, sign)

In [None]:
# Substitutions for the single matrix elements of K and ρ.
definitions = {}
definitions[K[0, 0]] = K_expr
definitions[ρ[0, 0]] = ρ_expr
# Show them together with the evaluated phase-space factor.
Math(aslatex({**definitions, ρ_expr: ρ_expr.doit()}))

In [None]:
# Write both T matrices out explicitly and substitute the parametrizations.
T1_expr, T2_expr = (
    matrix.as_explicit().xreplace(definitions) for matrix in (T1, T2)
)

In [None]:
T1_symbol, T2_symbol = sp.Symbol("T^{I}"), sp.Symbol("T^{II}")
# Simplify each single matrix element without calling doit() first, so the
# unevaluated phase-space factor is kept in the rendered output.
src = aslatex(
    {
        label: matrix[0].simplify(doit=False)
        for label, matrix in zip((T1_symbol, T2_symbol), (T1_expr, T2_expr))
    }
)
Math(src)

Finally, we convert these expressions to numerical functions and use these to plot the Riemann sheets over a complex plane for the Mandelstam variable&nbsp;$s$.

In [None]:
args = (s, m0, m1, Γ, sign)
# NumPy functions for the single T-matrix element of each sheet. Unevaluated
# expressions are resolved with doit() before generating code.
T1_func, T2_func = (
    sp.lambdify(args, matrix[0].doit(), modules="numpy")
    for matrix in (T1_expr, T2_expr)
)

In [None]:
# Grid over the complex plane. The lower (Yn) and upper (Yp) halves are built
# separately so that each half can be evaluated with a different T expression.
x = np.linspace(0, 4, num=200)
y = np.linspace(1e-5, 1.5, num=100)
X, Yn = np.meshgrid(x, -y)
X, Yp = np.meshgrid(x, +y)
Zn = X + Yn * 1j
Zp = X + Yp * 1j
# Values for (m0, m1, Γ, sign), shared by both surfaces and the line so that
# they cannot drift apart.
parameter_values = (1.5, 0.5, 0.4, +1)
# NOTE(review): the amplitudes are evaluated at s = Z**2, so the plotted x/y
# coordinates are Re/Im of Z (that is, of sqrt(s)), while the axis titles
# below read "Re s" / "Im s" — confirm which labelling is intended.
Tn = T1_func(Zn**2, *parameter_values)
Tp = T2_func(Zp**2, *parameter_values)

vmax = 1


def sty(t):
    """Return shared surface styling: the color shows Im t on a fixed range."""
    return {
        "cmin": -vmax,
        "cmax": +vmax,
        "colorscale": "RdBu_r",
        "surfacecolor": t.imag,
    }


# Height encodes Re T and color encodes Im T (see `sty`), so the colorbar is
# labelled with the imaginary part.
Sn = go.Surface(x=X, y=Yn, z=Tn.real, **sty(Tn), name="Unphysical")
Sp = go.Surface(
    x=X, y=Yp, z=Tp.real, **sty(Tp), name="Physical", colorbar_title="Im T"
)
# Line just above the real axis: the first row of Yp, where every value is y[0].
y = Yp[0]
z = x + y * 1j
line = go.Scatter3d(
    x=x,
    y=y,
    z=T1_func(z**2, *parameter_values).real,
    marker={"size": 0},
    line={"color": "darkgreen", "width": 1},
)
fig = go.Figure(data=[Sn, Sp, line])
fig.update_layout(height=550, width=600)
fig.update_scenes(
    xaxis_title_text="Re s",
    yaxis_title_text="Im s",
    zaxis_title_text="Re T",
    zaxis_range=[-vmax, +vmax],
)
fig.show()

Note that we matched the colors [of the PDG](https://pdg.lbl.gov/2023/reviews/rpp2022-rev-resonances.pdf#page=4), but that we had to remap the sheet associations. As can be seen in the plotting code, we have the following associations:

|                      | Sheet I  | Sheet II |
|----------------------|----------|----------|
| $\mathrm{Im}(s) < 0$ | $T^{I}$  | $T^{II}$ |
| $\mathrm{Im}(s) > 0$ | $T^{II}$ | $T^{I}$  |

So the sheet numbers 'flip' for $\mathrm{Im}(s) > 0$ and what we see in the third figure is just $T^{II}$.