```{autolink-concat}
```

::::{margin}
:::{card} Square root over arrays with negative values
TR-000
^^^
This notebook investigates how to write a square root function in {mod}`sympy` that computes the positive imaginary square root for negative values. The lambdified version of this 'complex square root' should behave the same for each computational backend.
+++
✅&nbsp;[tensorwaves#284](https://github.com/ComPWA/tensorwaves/pull/284)
:::
::::

# Complex square roots

<!-- cspell:disable -->

In [None]:
%pip install -q black==21.5b2 jax==0.4.28 jaxlib==0.4.28 numpy==1.23 sympy==1.8

In [None]:
import inspect

import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp
from black import FileMode, format_str
from IPython.display import display

## Negative input values

When using {mod}`numpy` as backend, {mod}`sympy` lambdifies {func}`~sympy.functions.elementary.miscellaneous.sqrt` to {obj}`numpy.sqrt`:

In [None]:
x = sp.Symbol("x")
sqrt_expr = sp.sqrt(x)
sqrt_expr

sqrt(x)

In [None]:
np_sqrt = sp.lambdify(x, sqrt_expr, "numpy")
source = inspect.getsource(np_sqrt)
print(source)

def _lambdifygenerated(x):
    return (sqrt(x))



As expected, if input values for the {obj}`numpy.sqrt` are negative, {mod}`numpy` raises a {class}`RuntimeWarning` and returns `NaN`:

In [None]:
sample = np.linspace(-1, 1, 5)
np_sqrt(sample)


array([       nan,        nan, 0.        , 0.70710678, 1.        ])
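
To inspect this behavior without the warning cluttering the output, the warning can be silenced locally with {obj}`numpy.errstate`. This NumPy-only sketch calls {obj}`numpy.sqrt` directly instead of the lambdified `np_sqrt`, first recording the warning and then suppressing it:

```python
import warnings

import numpy as np

sample = np.linspace(-1, 1, 5)

# By default, NumPy emits a RuntimeWarning for the sqrt of a negative float
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    np.sqrt(sample)
print([w.category.__name__ for w in caught])

# numpy.errstate(invalid="ignore") suppresses that warning inside the block
with np.errstate(invalid="ignore"):
    result = np.sqrt(sample)
print(np.isnan(result))  # [ True  True False False False]
```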

If we want {mod}`numpy` to return imaginary numbers for negative input values, we can use {class}`complex` input data instead (e.g. {doc}`numpy.complex64 <numpy:reference/arrays.scalars>`). Negative values are then treated as lying just above the real axis, so that their square root is a positive imaginary number:

In [None]:
complex_sample = sample.astype(np.complex64)
np_sqrt(complex_sample)

array([0.        +1.j        , 0.        +0.70710677j,
       0.        +0.j        , 0.70710677+0.j        ,
       1.        +0.j        ], dtype=complex64)
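
The phrase 'just above the real axis' refers to the branch cut of the complex square root, which runs along the negative real axis. Casting real data to {class}`complex` gives every entry an imaginary part of `+0.0`, which places negative values on the upper side of that cut. A small NumPy sketch (assuming NumPy follows the C99 `csqrt` convention for signed zeros, as Python's {mod}`cmath` does) suggests that flipping the sign of that zero flips the sign of the root:

```python
import numpy as np

# Same real part, with signed-zero imaginary parts on either side of the branch cut
above = np.complex128(complex(-1.0, 0.0))
below = np.complex128(complex(-1.0, -0.0))

# Upper side of the cut: positive imaginary root
print(np.sqrt(above))
# Lower side of the cut: negative imaginary root
print(np.sqrt(below))
```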

A {func}`sympy.sqrt <sympy.functions.elementary.miscellaneous.sqrt>` lambdified to [JAX](https://jax.rtfd.io) exhibits the same behavior, up to single-precision rounding errors (note the tiny nonzero real parts in the complex output):

In [None]:
jax_sqrt = jax.jit(sp.lambdify(x, sqrt_expr, jnp))
source = inspect.getsource(jax_sqrt)
print(source)

def _lambdifygenerated(x):
    return (sqrt(x))



In [None]:
jax_sqrt(sample)



DeviceArray([       nan,        nan, 0.        , 0.70710677, 1.        ],
            dtype=float32)

In [None]:
jax_sqrt(complex_sample)

DeviceArray([-4.3711388e-08+1.j        , -3.0908620e-08+0.70710677j,
              0.0000000e+00+0.j        ,  7.0710677e-01+0.j        ,
              1.0000000e+00+0.j        ], dtype=complex64)

**There is a problem with this approach, though**: once the input data is complex, _all_ square roots in a larger expression (such as an amplitude model) return imaginary results for negative values, which is not always the desired behavior.

Take for instance the two square roots that appear in {class}`~ampform.dynamics.phasespace.PhaseSpaceFactor`: does $\sqrt{s}$ also have to be evaluatable for negative $s$?

## Complex square root

NumPy also offers a function that returns imaginary results for negative values, even if the input values are real: {func}`numpy.emath.sqrt`:

In [None]:
np.emath.sqrt(-1)

1j
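
Unlike {obj}`numpy.sqrt`, {func}`numpy.emath.sqrt` chooses its output type based on the input *values*: it only switches to complex output if there are negative entries. A quick check on the `sample` array from above (re-created here so that the snippet runs on its own):

```python
import numpy as np

sample = np.linspace(-1, 1, 5)

# Real input with negative values: the output is promoted to complex
print(np.emath.sqrt(sample))
print(np.emath.sqrt(sample).dtype)  # complex128

# Real input without negative values: the output stays real
print(np.emath.sqrt(np.abs(sample)).dtype)  # float64
```

So the output dtype depends on the data. A function compiled with {func}`jax.jit`, by contrast, has to fix its output dtype while tracing, before any concrete values are known.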

Unfortunately, the {mod}`jax.numpy` API does not offer an interface to {mod}`numpy.emath`. It is possible to decorate {func}`numpy.emath.sqrt` with {func}`jax.jit`, but that **only works with static, hashable arguments**:

In [None]:
jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
jax_csqrt_error(-1)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)

In [None]:
jax_csqrt = jax.jit(np.emath.sqrt, backend="cpu", static_argnums=0)
jax_csqrt(-1)

DeviceArray(0.+1.j, dtype=complex64)

In [None]:
jax_csqrt(sample)

ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1.  -0.5  0.   0.5  1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'
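
The `static_argnums` route fails for arrays because JAX uses static arguments as part of its compilation cache key, so they have to be hashable. The following NumPy-only sketch illustrates this: an array cannot be hashed, while a tuple of its values can (and {func}`numpy.emath.sqrt` accepts such a tuple). Passing a tuple to `jax_csqrt` should therefore avoid the error, but every new tuple would trigger a new compilation, so this is not a practical route for data arrays.

```python
import numpy as np

sample = np.linspace(-1, 1, 5)

# NumPy arrays cannot be hashed, but static_argnums requires hashable arguments
try:
    hash(sample)
except TypeError as exc:
    print(exc)

# A tuple of Python floats is hashable and still accepted by numpy.emath.sqrt
static_sample = tuple(sample.tolist())
key = hash(static_sample)  # works, unlike hash(sample)
print(np.emath.sqrt(static_sample))
```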


## Conditional square root

To be able to control which square roots in the complete expression should be evaluatable for negative values, one could use {class}`~sympy.functions.elementary.piecewise.Piecewise`:

In [None]:
def complex_sqrt(x: sp.Symbol) -> sp.Expr:
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


complex_sqrt(x)

Piecewise((I*sqrt(-x), x < 0), (sqrt(x), True))

In [None]:
display(
    complex_sqrt(-4),
    complex_sqrt(+4),
)

2*I

2
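
As a quick sanity check of this definition, exact SymPy arithmetic can confirm that the result squares back to its argument and never has a negative imaginary part. This sketch re-declares `complex_sqrt` so that it runs on its own; the test values are arbitrary:

```python
import sympy as sp


def complex_sqrt(x):
    # Same Piecewise definition as above
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


values = [sp.Integer(-4), sp.Rational(-1, 4), sp.Integer(0), sp.Integer(9)]
for v in values:
    root = complex_sqrt(v)
    # Squaring recovers the input, and the imaginary part is never negative
    assert sp.expand(root**2) == v
    assert sp.im(root) >= 0
    print(v, "->", root)
```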

Be careful, though, when lambdifying this expression: do not use the `__dict__` of the {mod}`numpy` module as backend, but use the module itself instead. When given `__dict__`, {func}`~sympy.utilities.lambdify.lambdify` returns a Python `if`-`else` expression. That is inefficient, only works for scalar input (for an array, `x < 0` has no single truth value), and, worse, results in problems with {doc}`JAX <jax:index>`:

:::{warning}

Do not use the module `__dict__` for the `modules` argument of {func}`~sympy.utilities.lambdify.lambdify`.

:::

In [None]:
np_complex_sqrt_no_select = sp.lambdify(x, complex_sqrt(x), np.__dict__)
source = inspect.getsource(np_complex_sqrt_no_select)
print(source)

def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))



In [None]:
np_complex_sqrt_no_select(-1)

1j
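
The `if`-`else` version only works for scalar input. This can be reproduced without SymPy by copying the printed source into a plain function (the name `complex_sqrt_if_else` is made up for this sketch): for an array, the comparison `x < 0` yields a boolean array, and Python's `if` cannot reduce that to a single truth value.

```python
import numpy as np
from numpy import sqrt


# Hand-copied version of the source printed above (lambdified with np.__dict__)
def complex_sqrt_if_else(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))


print(complex_sqrt_if_else(-1))  # scalar input works

sample = np.linspace(-1, 1, 5)
raised = False
try:
    complex_sqrt_if_else(sample)
except ValueError as exc:
    # `sample < 0` is a boolean array, which has no single truth value for `if`
    raised = True
    print(exc)
```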

In [None]:
jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
jax_complex_sqrt_no_select(-1)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function _lambdifygenerated at <lambdifygenerated-3>:1, transformed by jit., this concrete value was not available in Python because it depends on the value of the arguments to _lambdifygenerated at <lambdifygenerated-3>:1, transformed by jit. at flattened positions [0], and the computation of these values is being staged out (that is, delayed rather than executed eagerly).
 (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)

When the {mod}`numpy` module itself (or the string `"numpy"`) is used instead, {func}`~sympy.utilities.lambdify.lambdify` correctly represents the cases with {func}`numpy.select`:

In [None]:
np_complex_sqrt = sp.lambdify(x, complex_sqrt(x), np)
source = inspect.getsource(np_complex_sqrt)

In [None]:
print(format_str(source.replace("nan)", "nan,)"), mode=FileMode()))

def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
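
Unlike the `if`-`else` expression, {func}`numpy.select` works element-wise, but it receives *both* candidate arrays fully evaluated. The following NumPy-only sketch re-creates the printed function by hand (the name `complex_sqrt_select` is made up here) and shows that the discarded branches still raise {class}`RuntimeWarning`s, while the final result is correct:

```python
import warnings

import numpy as np


# Hand-written NumPy equivalent of the lambdified select() source above
def complex_sqrt_select(x):
    return np.select(
        [np.less(x, 0), True],
        [1j * np.sqrt(-x), np.sqrt(x)],
        default=np.nan,
    )


sample = np.linspace(-1, 1, 5)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = complex_sqrt_select(sample)

# Both sqrt branches are computed for all entries before select() picks one,
# so the invalid branches still trigger RuntimeWarnings...
print([w.category.__name__ for w in caught])
# ...but the NaNs they produce are discarded by select()
print(result)
```

JAX also computes both candidate arrays before `jnp.select` picks from them; it just does not emit NumPy-style floating-point warnings for the invalid entries.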



Still, JAX does not handle this correctly. First, lambdifying to {mod}`jax.numpy` again results in the `if`-`else` syntax:

In [None]:
jnp_complex_sqrt = sp.lambdify(x, complex_sqrt(x), jnp)
source = inspect.getsource(jnp_complex_sqrt)
print(source)

def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))



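But even if we lambdify to {mod}`numpy` and decorate the result with {func}`jax.jit`, the resulting function fails, because the generated code calls NumPy functions, which attempt to convert the JAX tracer into a NumPy array (see the `__array__()` call in the error message):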

In [None]:
jax_complex_sqrt_error = jax.jit(np_complex_sqrt)
source = inspect.getsource(jax_complex_sqrt_error)

In [None]:
print(format_str(source.replace("nan)", "nan,)"), mode=FileMode()))

def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )



In [None]:
jax_complex_sqrt_error(-1)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)

The very same function created purely with {mod}`jax.numpy` works without problems, so the problem seems to lie in how SymPy lambdifies the expression:

In [None]:
@jax.jit
def jax_complex_sqrt(x):
    return jnp.select(
        [jnp.less(x, 0), True],
        [1j * jnp.sqrt(-x), jnp.sqrt(x)],
        default=jnp.nan,
    )

In [None]:
jax_complex_sqrt(sample)

DeviceArray([0.        +1.j        , 0.        +0.70710677j,
             0.        +0.j        , 0.70710677+0.j        ,
             1.        +0.j        ], dtype=complex64)

A solution to this is presented in {ref}`001/index:Handle for JAX`.