In [None]:
import os

# Subset of the build-related environment variables that are actually set;
# a non-empty set marks a non-interactive (documentation) execution.
STATIC_WEB_PAGE = {name for name in ("EXECUTE_NB", "READTHEDOCS") if name in os.environ}

```{autolink-concat}
```

::::{margin}
:::{card} Square root Riemann sheets
TR-025
:::
::::

# Rotating square root cuts

In [None]:
# Pin exact package versions so a fresh kernel gets the same APIs (-q keeps pip quiet).
%pip install -q ampform==0.14.8 ipywidgets==8.1.1 sympy==1.12

In [None]:
%matplotlib widget
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.io import aslatex
from ampform.sympy import unevaluated
from IPython.display import Math, display
from ipywidgets import FloatSlider, VBox, interactive_output

In [None]:
@unevaluated
class RotatedSqrt(sp.Expr):
    # Square root whose branch cut is moved off the negative real axis:
    # sqrt(z * exp(i*phi)) is cut where arg(z) = pi - phi, and the prefactor
    # exp(-i*phi/2) undoes the phase rotation. For phi = 0 this is sqrt(z).
    z: Any
    phi: Any = 0
    _latex_repr_ = R"\sqrt[{phi}]{{{z}}}"

    def evaluate(self) -> sp.Expr:
        """Rewrite the rotated root in terms of SymPy's principal ``sqrt``.

        The argument is rotated by ``exp(i*phi)`` before taking the principal
        root, and the result is multiplied by ``exp(-i*phi/2)`` because the
        square root halves the phase of its argument.
        """
        arg, angle = self.args
        back_rotation = sp.exp(-angle * sp.I / 2)
        rotated_arg = arg * sp.exp(angle * sp.I)
        return back_rotation * sp.sqrt(rotated_arg)


# Symbolic instance used throughout the rest of the notebook.
z, phi = sp.symbols("z phi")
expr = RotatedSqrt(z, phi)
# Render the unevaluated form next to its top-level (deep=False) evaluation.
Math(aslatex({expr: expr.doit(deep=False)}))

In [None]:
# Compile the fully evaluated expression into a numerical callable func(z, phi),
# which the plotting cell evaluates on NumPy arrays.
symbols = z, phi
func = sp.lambdify(args=symbols, expr=expr.doit())

In [None]:
# Two stacked panels: the top one shows the rotated square root along the real
# axis (real and imaginary parts), the bottom one its imaginary part over the
# complex plane.
fig, axes = plt.subplots(
    figsize=(6, 8.5),
    gridspec_kw=dict(
        height_ratios=[1, 2],
    ),
    nrows=2,
)
ax1, ax2 = axes
ax1.set_ylabel(f"${sp.latex(expr)}$")
# Raw strings: "\m" and "\," are invalid escape sequences in normal string
# literals (SyntaxWarning since Python 3.12).
ax2.set_ylabel(R"$\mathrm{Im}\,z$")
ax2.axhline(0, c="black", ls="dotted", zorder=99)
for ax in axes:
    ax.set_xlabel(R"$\mathrm{Re}\,z$")
    ax.set_xticks([-1, 0, +1])
ax2.set_yticks([-1, 0, +1])

# Artists are created lazily on the first call to plot() and updated afterwards.
data = None
x = np.linspace(-1, +1, num=400)
X, Y = np.meshgrid(x, x)
Z = X + Y * 1j

sliders = dict(
    phi=FloatSlider(
        min=-3 * np.pi,
        max=+3 * np.pi,
        step=np.pi / 8,
        description="phi",
    ),
)


def plot(phi):
    """Draw or redraw both panels for the rotation angle ``phi``.

    On the first call the line and mesh artists are created and stored in the
    global ``data``. Later calls only replace their data, so the figure is not
    rebuilt on every slider change.
    """
    global data
    ax1.set_title(Rf"${sp.latex(expr)} \qquad \phi={phi / np.pi:.4g}\pi$")
    t = func(x, phi)  # values along the real axis
    T = func(Z, phi)  # values on the complex grid
    if data is None:
        data = {
            "real": ax1.plot(x, t.real, label="real")[0],
            "imag": ax1.plot(x, t.imag, label="imag")[0],
            "2D": ax2.pcolormesh(X, Y, T.imag, cmap=plt.cm.coolwarm),
        }
    else:
        data["real"].set_ydata(t.real)
        data["imag"].set_ydata(t.imag)
        data["2D"].set_array(T.imag)
    data["2D"].set_clim(vmin=-1, vmax=+1)
    ax1.set_ylim(-1.2, +1.2)
    fig.canvas.draw_idle()


ui = VBox(tuple(sliders.values()))
# interactive_output() renders once with the initial slider value, which calls
# plot() and populates `data` before the legend and colorbar are attached below.
output = interactive_output(plot, controls=sliders)
ax1.legend(loc="lower left")
cbar = plt.colorbar(data["2D"], ax=ax2, pad=0.02)
cbar.ax.set_yticks([-1, 0, +1])
# The mesh shows the imaginary part, so label the colorbar accordingly.
cbar.ax.set_ylabel(Rf"$\mathrm{{Im}}\,{sp.latex(expr)}$")
fig.tight_layout()
display(ui, output)