# [TR-000] Lambdified square roots

In [None]:
%%sh
pip install jax==0.2.13 jaxlib==0.1.67 numpy==1.19.5 sympy==1.8 > /dev/null

## Negative input values

When {mod}`numpy` is used as backend, a {func}`~sympy.functions.elementary.miscellaneous.sqrt` in {mod}`sympy` is lambdified to {doc}`numpy.sqrt() <reference/generated/numpy.sqrt>`:

In [None]:
import sympy as sp

x = sp.Symbol("x")
sqrt_expr = sp.sqrt(x)
sqrt_expr

In [None]:
import inspect

np_sqrt = sp.lambdify(x, sqrt_expr, "numpy")
source = inspect.getsource(np_sqrt)
print(source)

When the input values for {doc}`numpy.sqrt() <reference/generated/numpy.sqrt>` are negative, it emits a {class}`RuntimeWarning` and returns `NaN`:

In [None]:
import numpy as np

sample = np.linspace(-1, 1, 5)
np_sqrt(sample)
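The {class}`RuntimeWarning` is a warning, not an exception, so it goes through Python's {mod}`warnings` machinery and can be caught or silenced. A minimal NumPy-only sketch (the variable names are illustrative; `np.errstate` is standard NumPy):

```python
import warnings

import numpy as np

sample = np.linspace(-1, 1, 5)

# np.sqrt emits a RuntimeWarning (it does not raise) for negative real input
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = np.sqrt(sample)

print(any(issubclass(w.category, RuntimeWarning) for w in caught))  # True
print(np.isnan(result))  # [ True  True False False False]

# np.errstate(invalid="ignore") suppresses the warning; the NaN values remain
with np.errstate(invalid="ignore"):
    quiet_result = np.sqrt(sample)
```

Silencing the warning does not change the result: the negative entries still come out as `NaN`.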

As a work-around, one can use {class}`complex` input data instead (e.g. {doc}`numpy.complex64 <numpy:reference/arrays.scalars>`): negative values are treated as lying just above the real axis, so that their square root is a positive imaginary number:

In [None]:
complex_sample = sample.astype(np.complex64)
np_sqrt(complex_sample)
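The statement that negative values are treated as lying just above the real axis can be checked with NumPy alone. In the sketch below, casting to complex gives each element an imaginary part of `+0.0`, which reproduces $i\sqrt{|x|}$ for negative $x$. Giving that zero a negative sign instead places the point just below the branch cut on the negative real axis, and the root flips sign. This last part assumes that NumPy's complex square root respects signed zeros (as C99 `csqrt` does); it is not taken from the original text:

```python
import numpy as np

sample = np.linspace(-1, 1, 5)

# Casting to complex sets the imaginary part of every element to +0.0
complex_root = np.sqrt(sample.astype(np.complex128))

# The same values built by hand: i*sqrt(|x|) for x < 0, sqrt(x) otherwise
magnitude = np.sqrt(np.abs(sample))
manual_root = np.where(sample < 0, 1j * magnitude, magnitude)
print(np.allclose(complex_root, manual_root))  # True

# An imaginary part of -0.0 puts the input just *below* the branch cut
print(np.sqrt(complex(-1.0, +0.0)).imag)  # 1.0
print(np.sqrt(complex(-1.0, -0.0)).imag)  # -1.0
```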

Lambdifying a {func}`sympy.sqrt <sympy.functions.elementary.miscellaneous.sqrt>` to [JAX](https://jax.rtfd.io) also works without problems:

In [None]:
import jax
import jax.numpy as jnp

jax_sqrt = jax.jit(sp.lambdify(x, sqrt_expr, jnp))
source = inspect.getsource(jax_sqrt)
print(source)

In [None]:
jax_sqrt(sample)

In [None]:
jax_sqrt(complex_sample)

There is a problem with this approach, though: once the input data is cast to complex values, _all_ square roots in a larger expression (such as an amplitude model) return imaginary results for negative arguments, which is not always the desired behavior. Take for instance the two square roots appearing in {func}`~ampform.dynamics.phase_space_factor`: does the $\sqrt{s}$ also have to be evaluatable for negative $s$?
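The difference shows up as soon as an expression mixes a root that should continue into the complex plane with one that should not. The sketch below uses a hypothetical expression $f(s) = \sqrt{s - 1} / \sqrt{s}$ (an illustration only, not the actual {func}`~ampform.dynamics.phase_space_factor`) and a hypothetical helper `selective_sqrt`. Casting the input to complex makes _both_ roots imaginary for $s < 0$, so the factors of $i$ cancel and a finite, real-looking value comes out. A selective approach only continues the numerator and leaves $\sqrt{s}$ undefined (`NaN`) there:

```python
import numpy as np


def selective_sqrt(x):
    """Hypothetical helper: square root returning i*sqrt(|x|) for negative x."""
    x = np.asarray(x, dtype=float)
    magnitude = np.sqrt(np.abs(x))
    return np.where(x < 0, 1j * magnitude, magnitude)


s = np.array([-2.0, 0.5, 2.0])

# Approach 1: cast the input to complex, so *every* root continues analytically
all_complex = np.sqrt(s.astype(complex) - 1) / np.sqrt(s.astype(complex))

# Approach 2: only the numerator may become imaginary; np.sqrt(s) stays real
with np.errstate(invalid="ignore"):
    selective = selective_sqrt(s - 1) / np.sqrt(s)

print(all_complex)  # s = -2 yields a finite value, sqrt(3/2)
print(selective)    # s = -2 yields NaN, flagging the negative s
```

For $s \geq 0$ both approaches agree; they only differ where the denominator's argument is negative.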

## Complex square root

NumPy also offers a function that does evaluate negative values, returning complex results for them: {func}`numpy.lib.scimath.sqrt`:

In [None]:
from numpy.lib.scimath import sqrt as csqrt

csqrt(-1)
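One detail of {func}`numpy.lib.scimath.sqrt` is worth noting: it only switches to a complex output when the input actually contains negative values, so the output _dtype_ depends on the input _values_. A function traced by {func}`jax.jit` presumably cannot reproduce this, since JAX fixes output shapes and dtypes before the values are known. A short NumPy check:

```python
import numpy as np
from numpy.lib.scimath import sqrt as csqrt

non_negative = csqrt(np.array([1.0, 4.0]))
with_negative = csqrt(np.array([-1.0, 4.0]))

# Same input dtype (float64), but the output dtype depends on the values
print(non_negative.dtype)   # float64
print(with_negative.dtype)  # complex128
```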

However, the {mod}`jax.numpy` API offers no counterpart to {mod}`numpy.lib.scimath`, nor can {func}`numpy.lib.scimath.sqrt` be decorated with {func}`jax.jit`:

In [None]:
jax_csqrt = jax.jit(csqrt, backend="cpu")
jax_csqrt(1j)

## Conditional square root

To be able to control which square roots in the complete expression should be evaluatable for negative values, one could use {class}`~sympy.functions.elementary.piecewise.Piecewise`:

In [None]:
def complex_sqrt(x: sp.Symbol) -> sp.Expr:
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


complex_sqrt(x)

In [None]:
from IPython.display import display

display(
    complex_sqrt(-4),
    complex_sqrt(+4),
)

Be careful when lambdifying this expression, though: _do not use the `__dict__` of the {mod}`numpy` module as backend_, but the module itself. When given the `__dict__`, {func}`~sympy.utilities.lambdify.lambdify` prints the {class}`~sympy.functions.elementary.piecewise.Piecewise` as a Python `if`-`else` expression, which cannot handle array input and, worse, results in problems with {doc}`JAX <jax:index>`:

In [None]:
np_complex_sqrt_no_select = sp.lambdify(x, complex_sqrt(x), np.__dict__)
source = inspect.getsource(np_complex_sqrt_no_select)
print(source)

In [None]:
np_complex_sqrt_no_select(-1)

In [None]:
import jax

jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
jax_complex_sqrt_no_select(-1)
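Why a conditional expression is a problem can be shown without SymPy or JAX. The function below is a hand-written, hypothetical stand-in for the kind of `if`-`else` expression printed above. A Python conditional needs a single `bool`: that works for a scalar, but an array comparison yields many booleans and NumPy refuses to collapse them. Under {func}`jax.jit`, the condition is an abstract tracer rather than a concrete `bool`, which is presumably why the call above fails:

```python
import numpy as np


def piecewise_with_if_else(x):
    # Hypothetical stand-in for a Piecewise printed as a Python conditional
    # expression: the condition `x < 0` must evaluate to one single bool
    return 1j * np.sqrt(-x) if x < 0 else np.sqrt(x)


print(piecewise_with_if_else(-4.0))  # 2j

array_error = None
try:
    piecewise_with_if_else(np.array([-4.0, 4.0]))
except ValueError as error:
    # "The truth value of an array with more than one element is ambiguous"
    array_error = error
print(type(array_error).__name__)  # ValueError
```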

When the {mod}`numpy` module itself (or the string `"numpy"`) is used instead, {func}`~sympy.utilities.lambdify.lambdify` correctly represents the cases with {func}`numpy.select`:

In [None]:
np_complex_sqrt = sp.lambdify(x, complex_sqrt(x), np)
source = inspect.getsource(np_complex_sqrt)
print(source)
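One thing to keep in mind with {func}`numpy.select`: unlike an `if`-`else`, it evaluates _every_ choice on the full input before picking values element by element. The `sqrt(x)` branch therefore still sees the negative values and emits the {class}`RuntimeWarning` from above, even though those entries are discarded. A NumPy-only sketch of the same structure (written by hand, not the exact lambdified source):

```python
import warnings

import numpy as np

x = np.linspace(-1, 1, 5)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Both choices are computed for all x before select() picks per element
    result = np.select(
        [x < 0, True],
        [1j * np.sqrt(-x), np.sqrt(x)],
        default=np.nan,
    )

print(result)  # no NaN survives: negative x map to i*sqrt(|x|)
print(len(caught) > 0)  # True: both np.sqrt calls saw negative arguments
```

Wrapping the call in `np.errstate(invalid="ignore")` silences these spurious warnings without changing the result.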

Still, JAX does not handle this correctly: the lambdified function can be created and wrapped with {func}`jax.jit`, but it crashes once it is called:

In [None]:
import jax.numpy as jnp

jax_complex_sqrt_error = jax.jit(sp.lambdify(x, complex_sqrt(x), jnp))
source = inspect.getsource(jax_complex_sqrt_error)
print(source)

In [None]:
jax_complex_sqrt_error(-1)

The same function written purely with {mod}`jax.numpy` does work without problems, so the issue seems to lie with SymPy:

In [None]:
@jax.jit
def jax_complex_sqrt(x):
    return jnp.select(
        [jnp.less(x, 0), True],
        [1j * jnp.sqrt(-x), jnp.sqrt(x)],
        default=jnp.nan,
    )

In [None]:
jax_complex_sqrt(sample)
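A possible cause on the SymPy side, stated as an assumption rather than a verified diagnosis: {func}`~sympy.utilities.lambdify.lambdify` appears to choose its code printer by comparing module names, and `jax.numpy` (whose `__name__` is `"jax.numpy"`) is not recognized as NumPy. The {class}`~sympy.functions.elementary.piecewise.Piecewise` may then be printed by the plain-Python printer as an `if`-`else` expression, just like in the `__dict__` case. The output of the two printers can be compared directly with SymPy:

```python
import sympy as sp
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.pycode import pycode

x = sp.Symbol("x")
complex_sqrt_expr = sp.Piecewise(
    (sp.sqrt(-x) * sp.I, x < 0),
    (sp.sqrt(x), True),
)

# Plain-Python printer: the Piecewise becomes a conditional expression
python_code = pycode(complex_sqrt_expr)
# NumPy printer: the Piecewise becomes a numpy.select() call
numpy_code = NumPyPrinter().doprint(complex_sqrt_expr)

print(python_code)
print(numpy_code)
```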