Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: add control over complex square roots #72

Merged
merged 4 commits into from Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .cspell.json
Expand Up @@ -180,6 +180,7 @@
"noqa",
"nrows",
"nsimplify",
"numpycode",
"pandoc",
"permalinks",
"phsp",
Expand Down
1 change: 1 addition & 0 deletions .flake8
Expand Up @@ -25,6 +25,7 @@ rst-roles =
class,
doc,
download,
eq,
file,
func,
meth,
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Expand Up @@ -155,6 +155,7 @@
# Intersphinx settings
intersphinx_mapping = {
"attrs": ("https://www.attrs.org/en/stable", None),
"compwa-org": ("https://compwa-org.readthedocs.io", None),
"expertsystem": ("https://expertsystem.readthedocs.io/en/stable", None),
"ipywidgets": ("https://ipywidgets.readthedocs.io/en/stable", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
Expand Down
13 changes: 13 additions & 0 deletions docs/extend_docstrings.py
Expand Up @@ -21,6 +21,7 @@
relativistic_breit_wigner,
relativistic_breit_wigner_with_ff,
)
from ampform.dynamics.math import ComplexSqrt


def update_docstring(
Expand Down Expand Up @@ -55,6 +56,18 @@ def render_breakup_momentum() -> None:
)


def render_complex_sqrt() -> None:
x = sp.Symbol("x", real=True)
complex_sqrt = ComplexSqrt(x)
update_docstring(
ComplexSqrt,
fR"""
.. math:: {sp.latex(complex_sqrt)} = {sp.latex(complex_sqrt.evaluate())}
:label: ComplexSqrt
""",
)


def render_coupled_width() -> None:
L = sp.Symbol("L", integer=True)
s, m0, w0, m_a, m_b, d = sp.symbols("s m0 Gamma0 m_a m_b d")
Expand Down
1 change: 0 additions & 1 deletion docs/usage/dynamics.ipynb
Expand Up @@ -628,7 +628,6 @@
" start=0,\n",
" stop=4,\n",
" num=500,\n",
" dtype=np.complex64, # analytic continuation\n",
")\n",
"sliders.set_ranges(\n",
" m0=(0, 5, 500),\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/usage/dynamics/analytic-continuation.ipynb
Expand Up @@ -253,14 +253,14 @@
},
"outputs": [],
"source": [
"plot_domain = np.linspace(0, 3, 1_000, dtype=np.complex64)\n",
"plot_domain = np.linspace(0, 3, 1_000)\n",
"sliders.set_ranges(\n",
" m_a=(0, 2, 200),\n",
" m_b=(0, 2, 200),\n",
")\n",
"sliders.set_values(\n",
" m_a=0.45,\n",
" m_b=1.4,\n",
" m_a=0.8,\n",
" m_b=1.25,\n",
")"
]
},
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Expand Up @@ -12,6 +12,7 @@ addopts =
--ignore=docs/conf.py
filterwarnings =
error
ignore:.*invalid value encountered in sqrt.*:RuntimeWarning
nb_diff_ignore =
/cells/*/execution_count
/cells/*/outputs
Expand Down
21 changes: 9 additions & 12 deletions src/ampform/dynamics/__init__.py
Expand Up @@ -16,6 +16,7 @@
implement_doit_method,
verify_signature,
)
from .math import ComplexSqrt

try:
from typing import Protocol
Expand Down Expand Up @@ -194,14 +195,12 @@ def phase_space_factor(
) -> sp.Expr:
"""Standard phase-space factor, using `breakup_momentum_squared`.

See :pdg-review:`2020; Resonances; p.4`, Equation (49.8).

.. warning:: This function uses a
:func:`~sympy.functions.elementary.miscellaneous.sqrt`. In order to
enable analytic continuation, input data needs to be complex valued.
See :pdg-review:`2020; Resonances; p.4`, Equation (49.8), with a slight
adaptation: instead of a normal square root, this phase space factor make
use of :eq:`ComplexSqrt` (`.ComplexSqrt`).
"""
q_squared = breakup_momentum_squared(s, m_a, m_b)
return sp.sqrt(q_squared) / (8 * sp.pi * sp.sqrt(s))
return ComplexSqrt(q_squared) / (8 * sp.pi * sp.sqrt(s))


def phase_space_factor_ac(
Expand All @@ -215,8 +214,7 @@ def phase_space_factor_ac(
**Warning**: The PDG specifically derives this formula for a two-body decay
*with equal masses*.
"""
q_squared = breakup_momentum_squared(s, m_a, m_b)
rho = sp.sqrt(sp.Abs(q_squared)) / (8 * sp.pi * sp.sqrt(s))
rho = phase_space_factor(s, m_a, m_b)
s_threshold = (m_a + m_b) ** 2
return _analytic_continuation(rho, s, s_threshold)

Expand Down Expand Up @@ -296,11 +294,10 @@ def relativistic_breit_wigner_with_ff( # pylint: disable=too-many-arguments
:pdg-review:`2020; Resonances; p.6`.
"""
q_squared = breakup_momentum_squared(s, m_a, m_b)
form_factor = sp.sqrt(
BlattWeisskopfSquared(
angular_momentum, z=q_squared * meson_radius ** 2
)
ff_squared = BlattWeisskopfSquared(
angular_momentum, z=q_squared * meson_radius ** 2
)
form_factor = sp.sqrt(ff_squared)
mass_dependent_width = coupled_width(
s, mass0, gamma0, m_a, m_b, angular_momentum, meson_radius, phsp_factor
)
Expand Down
66 changes: 66 additions & 0 deletions src/ampform/dynamics/math.py
@@ -0,0 +1,66 @@
"""A collection of basic mathematical operations, used in `ampform.dynamics`."""
# cspell:ignore cmath compwa csqrt lambdifier
# pylint: disable=no-member, protected-access, unused-argument

from typing import Any

import sympy as sp
from sympy.plotting.experimental_lambdify import Lambdifier
from sympy.printing.printer import Printer


class ComplexSqrt(sp.Expr):
"""Square root that returns positive imaginary values for negative input.

A special version :func:`~sympy.functions.elementary.miscellaneous.sqrt`
that renders nicely as LaTeX and and can be used as a handle for lambdify
printers. See :doc:`compwa-org:report/000`, :doc:`compwa-org:report/001`,
and :doc:`sympy:modules/printing` for how to implement a custom
:func:`~sympy.utilities.lambdify.lambdify` printer.
"""

is_commutative = True

def __new__(cls, x: sp.Expr, *args: Any, **kwargs: Any) -> sp.Expr:
x = sp.sympify(x)
expr = sp.Expr.__new__(cls, x, *args, **kwargs)
if hasattr(x, "free_symbols") and not x.free_symbols:
return expr.evaluate()
return expr

def evaluate(self) -> sp.Expr:
x = self.args[0]
if not x.is_real:
return sp.sqrt(x)
return self._evaluate_complex(x)

@staticmethod
def _evaluate_complex(x: sp.Expr) -> sp.Expr:
return sp.Piecewise(
(sp.I * sp.sqrt(-x), x < 0),
(sp.sqrt(x), True),
)

def _latex(self, printer: Printer, *args: Any) -> str:
x = printer._print(self.args[0])
return fR"\sqrt[\mathrm{{c}}]{{{x}}}"

def _numpycode(self, printer: Printer, *args: Any) -> str:
return self.__print_complex(printer)

def _pythoncode(self, printer: Printer, *args: Any) -> str:
printer.module_imports["cmath"].add("sqrt as csqrt")
x = printer._print(self.args[0])
return (
f"(((1j*sqrt(-{x}))"
f" if isinstance({x}, (float, int)) and ({x} < 0)"
f" else (csqrt({x}))))"
)

def __print_complex(self, printer: Printer) -> str:
x = self.args[0]
expr = self._evaluate_complex(x)
return printer._print(expr)


Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"
53 changes: 49 additions & 4 deletions tests/test_dynamics.py
@@ -1,8 +1,11 @@
# pylint: disable=no-self-use
import numpy as np
import pytest
import qrules as q
import sympy as sp
from sympy import preorder_traversal

from ampform.dynamics import ComplexSqrt
from ampform.helicity import HelicityModel


Expand Down Expand Up @@ -77,12 +80,14 @@ def test_generate(
expression = sp.piecewise_fold(expression)
assert isinstance(expression, sp.Add)
a1, a2 = tuple(map(str, expression.args))
a1 = a1.replace("ComplexSqrt", "sqrt")
a2 = a2.replace("ComplexSqrt", "sqrt")
if formalism == "canonical":
assert a1 == "0.08/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)"
assert a2 == "0.23/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)"
assert a1 == "0.08/(-m**2 + 0.98 - 0.12*I*sqrt(m**2/4 - 0.02)/Abs(m))"
assert a2 == "0.23/(-m**2 + 2.27 - 0.34*I*sqrt(m**2/4 - 0.02)/Abs(m))"
elif formalism == "helicity":
assert a1 == "0.17/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)"
assert a2 == "0.06/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)"
assert a1 == "0.17/(-m**2 + 2.27 - 0.34*I*sqrt(m**2/4 - 0.02)/Abs(m))"
assert a2 == "0.06/(-m**2 + 0.98 - 0.12*I*sqrt(m**2/4 - 0.02)/Abs(m))"
else:
raise NotImplementedError

Expand All @@ -96,3 +101,43 @@ def round_nested(expression: sp.Expr, n_decimals: int) -> sp.Expr:
if isinstance(node, sp.Pow) and node.args[1] == 1 / 2:
expression = expression.subs(node, round(node.n(), n_decimals))
return expression


class TestComplexSqrt:
@pytest.mark.parametrize("real", [False, True])
def test_evaluate(self, real: bool):
x = sp.Symbol("x", real=real)
expr = ComplexSqrt(x).evaluate()
if real:
assert expr == sp.Piecewise(
(sp.I * sp.sqrt(-x), x < 0),
(sp.sqrt(x), True),
)
else:
assert expr == sp.sqrt(x)

def test_latex(self):
x = sp.Symbol("x")
expr = ComplexSqrt(x)
assert sp.latex(expr) == R"\sqrt[\mathrm{c}]{x}"

@pytest.mark.parametrize("real", [False, True])
@pytest.mark.parametrize("backend", ["math", "numpy"])
def test_lambdify(self, backend: str, real: bool):
x = sp.Symbol("x", real=real)
expression = ComplexSqrt(x)
lambdified = sp.lambdify(x, expression, backend)
assert lambdified(np.array(-1)) == 1j

@pytest.mark.parametrize(
("input_value", "expected"),
[
(sp.Symbol("x", real=True), "ComplexSqrt(x)"),
(sp.Symbol("x"), "ComplexSqrt(x)"),
(+4, "2"),
(-4, "2*I"),
],
)
def test_new(self, input_value, expected: str):
expr = ComplexSqrt(input_value)
assert str(expr) == expected