```{autolink-concat}

```

# Numerical integration

In [None]:
%matplotlib widget

In [None]:
import os
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import cache
from typing import Literal, get_args

import ipywidgets as w
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import quadax
from ipympl.backend_nbagg import Canvas
from IPython.display import SVG, display
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection, QuadMesh
from matplotlib.lines import Line2D

# Names of the quadax integration routines that can be selected in the widget;
# each name is mapped onto the `quadax` function of the same name.
# cspell:disable-next-line
Algorithm = Literal["quadcc", "quadgk", "quadts", "romberg", "rombergts"]
# Run JAX in 64-bit precision (JAX defaults to 32-bit floats).
jax.config.update("jax_enable_x64", True)


def hide_toolbars(canvas: Canvas) -> None:
    """Switch off the header, footer, and toolbar of an ipympl canvas.

    Args:
        canvas: The figure canvas whose decorations should be hidden. It is
            modified in place.
    """
    for decoration in ("header_visible", "footer_visible", "toolbar_visible"):
        setattr(canvas, decoration, False)

In [None]:
def integrate_numerically(
    s: npt.NDArray[np.float64],
    m1: float,
    m2: float,
    epsilon: float = 0.05,
    start_offset: float = 0,
    algorithm: Callable = quadax.quadcc,
    **kwargs,
):
    """Integrate `integrand` numerically from threshold to infinity.

    The integral over ``s'`` runs from ``(m1 + m2)**2 + start_offset`` to
    infinity and the result is multiplied by ``(s - (m1 + m2)**2) / pi``.

    Args:
        s: Values of ``s`` at which to evaluate the integral.
        m1: First mass.
        m2: Second mass.
        epsilon: Imaginary shift ``i * epsilon`` used in the denominator of
            the integrand.
        start_offset: Distance above threshold at which the integration starts.
        algorithm: A `quadax` integration routine. It is called as
            ``algorithm(fun, interval=..., **kwargs)`` and has to return a
            tuple whose first element is the value of the integral.
        **kwargs: Extra keyword arguments for ``algorithm``. Callers may pass
            both ``order`` and ``divmax``; the one that does not apply to the
            chosen algorithm is dropped (``order`` for the Romberg routines,
            ``divmax`` for all others).

    Returns:
        The integral scaled by ``(s - (m1 + m2)**2) / pi``.
    """
    # Pop with a default: neither key is mandatory, so a caller that relies on
    # the algorithm's own defaults must not trigger a KeyError here.
    if algorithm in {quadax.romberg, quadax.rombergts}:
        kwargs.pop("order", None)
    else:
        kwargs.pop("divmax", None)
    s_thr = (m1 + m2) ** 2
    integral, _ = algorithm(
        jax.tree_util.Partial(integrand, s=s, m1=m1, m2=m2, epsilon=epsilon),
        interval=[s_thr + start_offset, jnp.inf],
        **kwargs,
    )
    return (s - s_thr) * integral / jnp.pi


@jax.jit
def integrand(sp, s, m1, m2, epsilon):
    """Evaluate the function that is integrated over ``sp``.

    Returns ``rho(sp) / ((sp - s_thr) * (sp - s - i * epsilon))`` with
    ``s_thr = (m1 + m2)**2``.
    """
    threshold = (m1 + m2) ** 2
    denominator = (sp - threshold) * (sp - s - 1j * epsilon)
    return rho(sp, m1, m2) / denominator


@jax.jit
def rho(s, m1, m2):
    """Compute ``sqrt((s - (m1 - m2)**2) * (s - (m1 + m2)**2)) / s``."""
    distance_to_pseudothreshold = s - (m1 - m2) ** 2
    distance_to_threshold = s - (m1 + m2) ** 2
    return jnp.sqrt(distance_to_pseudothreshold * distance_to_threshold) / s

In [None]:
@jax.jit
def sigma0(s, m1, m2):
    """Closed-form expression that the numerical integral is compared with.

    Evaluates::

        (1 / pi) * (
            (2 q / sqrt(s)) * log((m1^2 + m2^2 - s + 2 q sqrt(s)) / (2 m1 m2))
            - (m1^2 - m2^2) * (1 / s - 1 / (m1 + m2)^2) * log(m1 / m2)
        )

    where ``q`` is the value returned by `q(s, m1, m2)`.
    """
    momentum = q(s, m1, m2)
    sqrt_s = jnp.sqrt(s)
    log_term = (2 * momentum / sqrt_s) * jnp.log(
        (m1**2 + m2**2 - s + 2 * momentum * sqrt_s) / (2 * m1 * m2)
    )
    mass_term = (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * jnp.log(m1 / m2)
    return (1 / jnp.pi) * (log_term - mass_term)


@jax.jit
def q(s, m1, m2):
    """Compute ``sqrt((s - (m1 - m2)**2) * (s - (m1 + m2)**2)) / (2 sqrt(s))``."""
    product = (s - (m1 - m2) ** 2) * (s - (m1 + m2) ** 2)
    return jnp.sqrt(product) / (2 * jnp.sqrt(s))

In [None]:
@dataclass(kw_only=True)
class PlotContent:
    """Handles to the matplotlib artists that the widget updates in place."""

    # Color mesh of rho over the complex plane.
    mesh: QuadMesh
    # Real parts of the (exact, analytic, numerical) curves.
    real: tuple[Line2D, Line2D, Line2D]
    # Imaginary parts of the (exact, analytic, numerical) curves.
    imag: tuple[Line2D, Line2D, Line2D]
    # Vertical line at (m1 - m2)**2, one per axes.
    pseudothreshold: tuple[Line2D, Line2D, Line2D]
    # Vertical line at (m1 + m2)**2, one per axes.
    threshold: tuple[Line2D, Line2D, Line2D]
    # Horizontal line that marks the interval over which is integrated.
    integrated_interval: LineCollection


@dataclass
class DispersionIntegralWidget:
    """Callable that draws, and afterwards refreshes, the comparison figure.

    Each call receives the current widget values as keyword arguments. The
    first call creates all artists; later calls only update their data.
    """

    # Artists created by the first call; `None` until then.
    _c: PlotContent | None = field(init=False, default=None)
    # Declared, but not assigned anywhere in this class.
    s: npt.NDArray[np.float64] = field(init=False)
    # Complex grid on which rho is shown; built in `__post_init__`.
    S: npt.NDArray[np.complex128] = field(init=False)
    # Axes for the color mesh of rho.
    ax_rho: Axes
    # Axes for the real parts of the curves.
    ax_real: Axes
    # Axes for the imaginary parts of the curves.
    ax_imag: Axes
    # Range of Re(s) for both the grid and the one-dimensional domain.
    real_lim: tuple[float, float]
    # The grid spans Im(s) from -imag_max to +imag_max.
    imag_max: float
    # Number of grid points along (Re s, Im s).
    grid: tuple[int, int]

    def __post_init__(self):
        """Build the complex grid `S` from the configured limits."""
        real_axis = np.linspace(*self.real_lim, self.grid[0])
        imag_axis = np.linspace(-self.imag_max, +self.imag_max, self.grid[1])
        real_grid, imag_grid = np.meshgrid(real_axis, imag_axis)
        self.S = real_grid + 1j * imag_grid

    @staticmethod
    def _configure_algorithm(
        algorithm_name: Algorithm, order: int
    ) -> tuple[Callable, int]:
        """Pick the quadax routine and enable the widgets that apply to it.

        Returns the integration function together with the (possibly
        adjusted) integration order.
        """
        if algorithm_name == "quadcc":
            divmax_slider.disabled = True
            order_slider.disabled = False
            order = update_order_slider(order, [8, 16, 32, 64, 128, 256])
            return quadax.quadcc, order
        if algorithm_name == "quadgk":
            divmax_slider.disabled = True
            order = update_order_slider(order, [15, 21, 31, 41, 51, 61])
            order_slider.disabled = False
            return quadax.quadgk, order
        if algorithm_name == "quadts":
            divmax_slider.disabled = True
            order = update_order_slider(order, [41, 61, 81, 101])
            return quadax.quadts, order
        if algorithm_name == "romberg":
            divmax_slider.disabled = False
            order_slider.disabled = True
            return quadax.romberg, order
        if algorithm_name == "rombergts":
            divmax_slider.disabled = False
            order_slider.disabled = True
            return quadax.rombergts, order
        msg = f"Unknown algorithm: {algorithm_name}"
        raise ValueError(msg)

    def __call__(
        self,
        *,
        projection: Literal["real", "imag", "abs"],
        m1: float,
        m2: float,
        epsilon: float,
        epsilon_prime: float,
        start_offset: float,
        z_max: float,
        algorithm_name: Algorithm,
        epsabs: float,
        epsrel: float,
        divmax: int,
        order: int,
        resolution: int,
        y_lim: tuple[float, float],
    ) -> None:
        """Recompute all curves and draw them, or update the existing artists.

        Args:
            projection: Which part of rho to show in the color mesh.
            m1: First mass.
            m2: Second mass.
            epsilon: Imaginary offset of ``s`` for the analytic and numerical
                curves.
            epsilon_prime: Imaginary shift passed to `integrate_numerically`.
            start_offset: Distance above threshold where the integration starts.
            z_max: The color scale of the mesh spans ``-z_max`` to ``+z_max``.
            algorithm_name: Name of the quadax routine to integrate with.
            epsabs: Forwarded to the integration routine.
            epsrel: Forwarded to the integration routine.
            divmax: Forwarded to the integration routine (Romberg variants).
            order: Forwarded to the integration routine (other variants).
            resolution: Number of points on the one-dimensional domain.
            y_lim: Vertical range of the two curve panels.

        Raises:
            ValueError: If ``algorithm_name`` is not a known algorithm.
        """
        domain = generate_domain(*self.real_lim, resolution)
        s_values = domain + epsilon * 1j
        algorithm, order = self._configure_algorithm(algorithm_name, order)

        y_exact = sigma0(domain + 1e-10j, m1, m2)
        y_ana = sigma0(s_values, m1, m2)
        tic = time.perf_counter()
        y_num = integrate_numerically(
            s_values,
            m1,
            m2,
            epsilon_prime,
            start_offset,
            algorithm,
            epsabs=epsabs,
            epsrel=epsrel,
            divmax=divmax,
            order=order,
        ).block_until_ready()
        elapsed = time.perf_counter() - tic
        timer_box.value = f"Computation time: <b>{format_time(elapsed)}</b> for {resolution:,} points"
        curves = (y_exact, y_ana, y_num)

        Z = rho(self.S, m1, m2)
        Z = jnp.abs(Z) if projection == "abs" else getattr(Z, projection)
        s_neg = (m1 - m2) ** 2
        s_pos = (m1 + m2) ** 2
        panels = (self.ax_rho, self.ax_real, self.ax_imag)

        if self._c is not None:
            # Artists exist already: only swap out their data.
            content = self._c
            content.mesh.set_array(Z)
            content.mesh.set_clim(-z_max, +z_max)
            for imag_line, real_line, curve in zip(content.imag, content.real, curves):
                imag_line.set_data(domain, curve.imag)
                real_line.set_data(domain, curve.real)
            for marker in content.pseudothreshold:
                marker.set_xdata([s_neg])
            for marker in content.threshold:
                marker.set_xdata([s_pos])
            segment = [
                [s_pos + start_offset, epsilon_prime],
                [self.ax_rho.get_xlim()[1], epsilon_prime],
            ]
            content.integrated_interval.set_segments([segment])
        else:
            # First call: create every artist and remember the handles.
            def plot_curves(ax: Axes, part: str) -> tuple[Line2D, Line2D, Line2D]:
                exact, analytic, numerical = (getattr(c, part) for c in curves)
                return (
                    ax.plot(domain, exact, color="black", lw=0.2)[0],
                    ax.plot(domain, analytic, alpha=0.5, color="C0")[0],
                    ax.plot(domain, numerical, color="C1", lw=0.3)[0],
                )

            mesh = self.ax_rho.pcolormesh(
                self.S.real,
                self.S.imag,
                Z,
                cmap="RdBu_r",
                rasterized=True,
                vmin=-z_max,
                vmax=+z_max,
            )
            imag_lines = plot_curves(self.ax_imag, "imag")
            real_lines = plot_curves(self.ax_real, "real")
            pseudothreshold = tuple(
                ax.axvline(s_neg, color="C2", linestyle="dotted") for ax in panels
            )
            threshold = tuple(
                ax.axvline(s_pos, color="C3", linestyle="dotted") for ax in panels
            )
            interval = self.ax_rho.hlines(
                y=epsilon_prime,
                xmin=s_pos + start_offset,
                xmax=self.ax_rho.get_xlim()[1],
                color="black",
                linewidth=0.5,
            )
            self._c = PlotContent(
                mesh=mesh,
                real=real_lines,
                imag=imag_lines,
                pseudothreshold=pseudothreshold,
                threshold=threshold,
                integrated_interval=interval,
            )

        for ax in (self.ax_real, self.ax_imag):
            ax.set_ylim(*y_lim)


@cache
def generate_domain(start, stop, resolution: int) -> npt.NDArray[np.float64]:
    """Return `resolution` evenly spaced float64 points from start to stop.

    The result is memoized per argument combination, so repeated calls with
    the same arguments return the very same array object.
    """
    return np.linspace(start, stop, num=resolution, dtype=np.float64)


def format_time(seconds: float) -> str:
    """Render a duration with a unit that suits its magnitude.

    Durations below a millisecond are shown in µs, below a second in ms,
    below a minute in s, and anything longer as minutes plus seconds.
    """
    # (exclusive upper bound in seconds, conversion factor, unit label)
    scales = ((1e-3, 1e6, "µs"), (1, 1e3, "ms"), (60, 1, "s"))
    for upper_bound, factor, unit in scales:
        if seconds < upper_bound:
            return f"{factor * seconds:,.1f} {unit}"
    minutes, remainder = divmod(seconds, 60)
    return f"{int(minutes)} min {remainder:,.1f} s"


def update_order_slider(order: int, options: list[int]) -> int:
    """Enable the order selector, set its options, and pick a valid order.

    Args:
        order: The currently requested integration order.
        options: The orders that are offered by the order selector.

    Returns:
        ``order`` itself if it is one of the options, otherwise the option
        that is numerically closest to it.
    """
    order_slider.disabled = False
    order_slider.options = options

    def distance(candidate: int) -> int:
        return abs(candidate - order)

    return order if order in options else min(options, key=distance)

In [None]:
# Shared slider option: only report a new value once the slider is released.
cont = {"continuous_update": False}

# The keys of the slider dictionaries have to match the keyword arguments of
# `DispersionIntegralWidget.__call__`.
physics_sliders = {
    "projection": w.RadioButtons(
        description="Projection",
        options=["real", "imag", "abs"],
        value="imag",
        layout=w.Layout(width="max-content"),
    ),
    "m1": w.FloatSlider(
        value=0.13, min=0.0, max=2.0, step=0.01, description="m1", **cont
    ),
    "m2": w.FloatSlider(
        value=0.98, min=0.0, max=2.0, step=0.01, description="m2", **cont
    ),
    "y_lim": w.FloatRangeSlider(
        description="y range",
        min=-5,
        max=10,
        value=(-1, +1),
        readout_format=".1f",
        **cont,
    ),
    "z_max": w.FloatLogSlider(
        value=1.0,
        min=-3,
        max=3,
        description="Color scale",
        step=0.25,
        readout_format=".3g",
        **cont,
    ),
    "resolution": w.IntSlider(
        value=200,
        min=100,
        max=5000,
        description="Resolution",
        step=100,
        **cont,
    ),
    "epsilon": w.FloatLogSlider(
        value=1e-4,
        min=-12,
        max=2,
        description="s + iϵ",
        step=0.5,
        readout_format=".0e",
        **cont,
    ),
    "epsilon_prime": w.FloatLogSlider(
        value=1e-8,
        min=-12,
        max=2,
        description="s' + iϵ",
        step=0.5,
        readout_format=".0e",
        **cont,
    ),
    "start_offset": w.FloatLogSlider(
        value=1e-20,
        min=-20,
        max=2,
        description="thr + ϵ",
        step=0.5,
        readout_format=".0e",
        **cont,
    ),
}

# `order_slider` and `divmax_slider` are module-level names because the plot
# callback enables and disables them depending on the selected algorithm.
order_slider = w.RadioButtons(
    description="Integration order",
    layout=w.Layout(width="100px"),
    options=[8, 16, 32, 64, 128, 256],
    value=256,
)
divmax_slider = w.IntSlider(
    value=20,
    min=1,
    max=30,
    description="divmax",
    **cont,
)
algorithm_sliders = {
    "algorithm_name": w.RadioButtons(
        description="Integration algorithm",
        value="romberg",
        layout=w.Layout(width="150px"),
        # Public typing API instead of the undocumented `__args__` attribute.
        options=get_args(Algorithm),
    ),
    "order": order_slider,
    "epsabs": w.FloatLogSlider(
        value=1e-5,
        min=-12,
        max=0,
        description="epsabs",
        step=0.5,
        readout_format=".0e",
        **cont,
    ),
    "epsrel": w.FloatLogSlider(
        value=1e-5,
        min=-12,
        max=0,
        description="epsrel",
        step=0.5,
        readout_format=".0e",
        **cont,
    ),
    "divmax": divmax_slider,
}
sliders = {**physics_sliders, **algorithm_sliders}

# Lay out the widgets by key, so that adding or reordering an entry in the
# dictionaries above cannot silently move a widget to another column.
physics_tab = w.HBox([
    physics_sliders["projection"],
    w.VBox([physics_sliders[key] for key in ("m1", "m2", "y_lim", "z_max")]),
    w.VBox([
        physics_sliders[key]
        for key in ("resolution", "epsilon", "epsilon_prime", "start_offset")
    ]),
])
integration_tab = w.HBox([
    algorithm_sliders["algorithm_name"],
    order_slider,
    w.VBox([algorithm_sliders[key] for key in ("epsabs", "epsrel", "divmax")]),
])
tabs = w.Tab([physics_tab, integration_tab])
tabs.titles = ["Physics", "Integration"]
# The plot callback writes the computation time into this box.
timer_box = w.HTML()
ui = w.VBox([tabs, timer_box])

In [None]:
plt.rc("font", size=12)
fig, axes = plt.subplots(figsize=(10, 8), nrows=3, sharex=True)
hide_toolbars(fig.canvas)
fig.subplots_adjust(bottom=0.1, hspace=0.1, left=0.15, right=0.95, top=0.95)
ax1, ax2, ax3 = axes
for ax in axes.ravel():
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
ax1.spines["bottom"].set_visible(False)
ax1.set_title(R"$\rho(s)$")
ax1.set_ylabel("Im $s$")
ax2.set_ylabel(R"Im $\Sigma_0(s)$")
ax3.set_ylabel(R"Re $\Sigma_0(s)$")
ax3.set_xlabel("Re $s$")
for ax in (ax2, ax3):
    ax.axhline(0, color="gray", lw=0.5)
# Pass the axes by keyword. The positional order of the dataclass fields is
# (ax_rho, ax_real, ax_imag), so unpacking `*axes` would draw the real part on
# the panel labelled "Im" and the imaginary part on the panel labelled "Re".
plot_widget = DispersionIntegralWidget(
    ax_rho=ax1,
    ax_imag=ax2,
    ax_real=ax3,
    real_lim=(0, 6),
    imag_max=2,
    grid=(500, 300),
)
# Re-run `plot_widget` with the current values whenever a widget changes.
out = w.interactive_output(plot_widget, sliders)
plt.show()
display(out, ui)

In [None]:
# When the EXECUTE_NB environment variable is set, also save the figure as SVG
# and display that file together with the widget panel.
if os.environ.get("EXECUTE_NB") is not None:
    output_path = "numerical-integration-widget.svg"
    fig.savefig(output_path, bbox_inches="tight")
    static_figure = SVG(output_path)
    display(static_figure, ui)