Skip to content

Commit

Permalink
fix!: add control over complex square roots (#72)
Browse files Browse the repository at this point in the history
Implement ComPWA-org TR-001 by defining a ComplexSqrt class

* refactor: split out ff_squared
* test: add test for ComplexSqrt
  • Loading branch information
redeboer committed Jun 10, 2021
1 parent 50d21be commit 7838ac0
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 20 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Expand Up @@ -180,6 +180,7 @@
"noqa",
"nrows",
"nsimplify",
"numpycode",
"pandoc",
"permalinks",
"phsp",
Expand Down
1 change: 1 addition & 0 deletions .flake8
Expand Up @@ -25,6 +25,7 @@ rst-roles =
class,
doc,
download,
eq,
file,
func,
meth,
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Expand Up @@ -155,6 +155,7 @@
# Intersphinx settings
intersphinx_mapping = {
"attrs": ("https://www.attrs.org/en/stable", None),
"compwa-org": ("https://compwa-org.readthedocs.io", None),
"expertsystem": ("https://expertsystem.readthedocs.io/en/stable", None),
"ipywidgets": ("https://ipywidgets.readthedocs.io/en/stable", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
Expand Down
13 changes: 13 additions & 0 deletions docs/extend_docstrings.py
Expand Up @@ -21,6 +21,7 @@
relativistic_breit_wigner,
relativistic_breit_wigner_with_ff,
)
from ampform.dynamics.math import ComplexSqrt


def update_docstring(
Expand Down Expand Up @@ -55,6 +56,18 @@ def render_breakup_momentum() -> None:
)


def render_complex_sqrt() -> None:
x = sp.Symbol("x", real=True)
complex_sqrt = ComplexSqrt(x)
update_docstring(
ComplexSqrt,
fR"""
.. math:: {sp.latex(complex_sqrt)} = {sp.latex(complex_sqrt.evaluate())}
:label: ComplexSqrt
""",
)


def render_coupled_width() -> None:
L = sp.Symbol("L", integer=True)
s, m0, w0, m_a, m_b, d = sp.symbols("s m0 Gamma0 m_a m_b d")
Expand Down
1 change: 0 additions & 1 deletion docs/usage/dynamics.ipynb
Expand Up @@ -628,7 +628,6 @@
" start=0,\n",
" stop=4,\n",
" num=500,\n",
" dtype=np.complex64, # analytic continuation\n",
")\n",
"sliders.set_ranges(\n",
" m0=(0, 5, 500),\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/usage/dynamics/analytic-continuation.ipynb
Expand Up @@ -253,14 +253,14 @@
},
"outputs": [],
"source": [
"plot_domain = np.linspace(0, 3, 1_000, dtype=np.complex64)\n",
"plot_domain = np.linspace(0, 3, 1_000)\n",
"sliders.set_ranges(\n",
" m_a=(0, 2, 200),\n",
" m_b=(0, 2, 200),\n",
")\n",
"sliders.set_values(\n",
" m_a=0.45,\n",
" m_b=1.4,\n",
" m_a=0.8,\n",
" m_b=1.25,\n",
")"
]
},
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Expand Up @@ -12,6 +12,7 @@ addopts =
--ignore=docs/conf.py
filterwarnings =
error
ignore:.*invalid value encountered in sqrt.*:RuntimeWarning
nb_diff_ignore =
/cells/*/execution_count
/cells/*/outputs
Expand Down
21 changes: 9 additions & 12 deletions src/ampform/dynamics/__init__.py
Expand Up @@ -16,6 +16,7 @@
implement_doit_method,
verify_signature,
)
from .math import ComplexSqrt

try:
from typing import Protocol
Expand Down Expand Up @@ -194,14 +195,12 @@ def phase_space_factor(
) -> sp.Expr:
"""Standard phase-space factor, using `breakup_momentum_squared`.
See :pdg-review:`2020; Resonances; p.4`, Equation (49.8).
.. warning:: This function uses a
:func:`~sympy.functions.elementary.miscellaneous.sqrt`. In order to
enable analytic continuation, input data needs to be complex valued.
See :pdg-review:`2020; Resonances; p.4`, Equation (49.8), with a slight
adaptation: instead of a normal square root, this phase space factor make
use of :eq:`ComplexSqrt` (`.ComplexSqrt`).
"""
q_squared = breakup_momentum_squared(s, m_a, m_b)
return sp.sqrt(q_squared) / (8 * sp.pi * sp.sqrt(s))
return ComplexSqrt(q_squared) / (8 * sp.pi * sp.sqrt(s))


def phase_space_factor_ac(
Expand All @@ -215,8 +214,7 @@ def phase_space_factor_ac(
**Warning**: The PDG specifically derives this formula for a two-body decay
*with equal masses*.
"""
q_squared = breakup_momentum_squared(s, m_a, m_b)
rho = sp.sqrt(sp.Abs(q_squared)) / (8 * sp.pi * sp.sqrt(s))
rho = phase_space_factor(s, m_a, m_b)
s_threshold = (m_a + m_b) ** 2
return _analytic_continuation(rho, s, s_threshold)

Expand Down Expand Up @@ -296,11 +294,10 @@ def relativistic_breit_wigner_with_ff( # pylint: disable=too-many-arguments
:pdg-review:`2020; Resonances; p.6`.
"""
q_squared = breakup_momentum_squared(s, m_a, m_b)
form_factor = sp.sqrt(
BlattWeisskopfSquared(
angular_momentum, z=q_squared * meson_radius ** 2
)
ff_squared = BlattWeisskopfSquared(
angular_momentum, z=q_squared * meson_radius ** 2
)
form_factor = sp.sqrt(ff_squared)
mass_dependent_width = coupled_width(
s, mass0, gamma0, m_a, m_b, angular_momentum, meson_radius, phsp_factor
)
Expand Down
66 changes: 66 additions & 0 deletions src/ampform/dynamics/math.py
@@ -0,0 +1,66 @@
"""A collection of basic mathematical operations, used in `ampform.dynamics`."""
# cspell:ignore cmath compwa csqrt lambdifier
# pylint: disable=no-member, protected-access, unused-argument

from typing import Any

import sympy as sp
from sympy.plotting.experimental_lambdify import Lambdifier
from sympy.printing.printer import Printer


class ComplexSqrt(sp.Expr):
"""Square root that returns positive imaginary values for negative input.
A special version :func:`~sympy.functions.elementary.miscellaneous.sqrt`
that renders nicely as LaTeX and and can be used as a handle for lambdify
printers. See :doc:`compwa-org:report/000`, :doc:`compwa-org:report/001`,
and :doc:`sympy:modules/printing` for how to implement a custom
:func:`~sympy.utilities.lambdify.lambdify` printer.
"""

is_commutative = True

def __new__(cls, x: sp.Expr, *args: Any, **kwargs: Any) -> sp.Expr:
x = sp.sympify(x)
expr = sp.Expr.__new__(cls, x, *args, **kwargs)
if hasattr(x, "free_symbols") and not x.free_symbols:
return expr.evaluate()
return expr

def evaluate(self) -> sp.Expr:
x = self.args[0]
if not x.is_real:
return sp.sqrt(x)
return self._evaluate_complex(x)

@staticmethod
def _evaluate_complex(x: sp.Expr) -> sp.Expr:
return sp.Piecewise(
(sp.I * sp.sqrt(-x), x < 0),
(sp.sqrt(x), True),
)

def _latex(self, printer: Printer, *args: Any) -> str:
x = printer._print(self.args[0])
return fR"\sqrt[\mathrm{{c}}]{{{x}}}"

def _numpycode(self, printer: Printer, *args: Any) -> str:
return self.__print_complex(printer)

def _pythoncode(self, printer: Printer, *args: Any) -> str:
printer.module_imports["cmath"].add("sqrt as csqrt")
x = printer._print(self.args[0])
return (
f"(((1j*sqrt(-{x}))"
f" if isinstance({x}, (float, int)) and ({x} < 0)"
f" else (csqrt({x}))))"
)

def __print_complex(self, printer: Printer) -> str:
x = self.args[0]
expr = self._evaluate_complex(x)
return printer._print(expr)


Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"
53 changes: 49 additions & 4 deletions tests/test_dynamics.py
@@ -1,8 +1,11 @@
# pylint: disable=no-self-use
import numpy as np
import pytest
import qrules as q
import sympy as sp
from sympy import preorder_traversal

from ampform.dynamics import ComplexSqrt
from ampform.helicity import HelicityModel


Expand Down Expand Up @@ -77,12 +80,14 @@ def test_generate(
expression = sp.piecewise_fold(expression)
assert isinstance(expression, sp.Add)
a1, a2 = tuple(map(str, expression.args))
a1 = a1.replace("ComplexSqrt", "sqrt")
a2 = a2.replace("ComplexSqrt", "sqrt")
if formalism == "canonical":
assert a1 == "0.08/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)"
assert a2 == "0.23/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)"
assert a1 == "0.08/(-m**2 + 0.98 - 0.12*I*sqrt(m**2/4 - 0.02)/Abs(m))"
assert a2 == "0.23/(-m**2 + 2.27 - 0.34*I*sqrt(m**2/4 - 0.02)/Abs(m))"
elif formalism == "helicity":
assert a1 == "0.17/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)"
assert a2 == "0.06/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)"
assert a1 == "0.17/(-m**2 + 2.27 - 0.34*I*sqrt(m**2/4 - 0.02)/Abs(m))"
assert a2 == "0.06/(-m**2 + 0.98 - 0.12*I*sqrt(m**2/4 - 0.02)/Abs(m))"
else:
raise NotImplementedError

Expand All @@ -96,3 +101,43 @@ def round_nested(expression: sp.Expr, n_decimals: int) -> sp.Expr:
if isinstance(node, sp.Pow) and node.args[1] == 1 / 2:
expression = expression.subs(node, round(node.n(), n_decimals))
return expression


class TestComplexSqrt:
@pytest.mark.parametrize("real", [False, True])
def test_evaluate(self, real: bool):
x = sp.Symbol("x", real=real)
expr = ComplexSqrt(x).evaluate()
if real:
assert expr == sp.Piecewise(
(sp.I * sp.sqrt(-x), x < 0),
(sp.sqrt(x), True),
)
else:
assert expr == sp.sqrt(x)

def test_latex(self):
x = sp.Symbol("x")
expr = ComplexSqrt(x)
assert sp.latex(expr) == R"\sqrt[\mathrm{c}]{x}"

@pytest.mark.parametrize("real", [False, True])
@pytest.mark.parametrize("backend", ["math", "numpy"])
def test_lambdify(self, backend: str, real: bool):
x = sp.Symbol("x", real=real)
expression = ComplexSqrt(x)
lambdified = sp.lambdify(x, expression, backend)
assert lambdified(np.array(-1)) == 1j

@pytest.mark.parametrize(
("input_value", "expected"),
[
(sp.Symbol("x", real=True), "ComplexSqrt(x)"),
(sp.Symbol("x"), "ComplexSqrt(x)"),
(+4, "2"),
(-4, "2*I"),
],
)
def test_new(self, input_value, expected: str):
expr = ComplexSqrt(input_value)
assert str(expr) == expected

0 comments on commit 7838ac0

Please sign in to comment.