# Custom lambdification

As noted in {doc}`/report/000`, it is hard to lambdify a {func}`sympy.sqrt <sympy.functions.elementary.miscellaneous.sqrt>` to {doc}`JAX <jax:index>`. One possible way out is to define a custom class that derives from {class}`sympy.Expr <sympy.core.expr.Expr>` and {doc}`overrides its printer methods <sympy:modules/printing>`.

In [None]:
from typing import Any

import sympy as sp
from sympy.printing.printer import Printer


class ComplexSqrt(sp.Expr):
    def __new__(cls, x, *args: Any, **kwargs: Any):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )

    def _latex(self, printer: Printer, *args: Any) -> str:
        x = printer._print(self.args[0])
        return fR"\sqrt[\mathrm{{c}}]{{{x}}}"

    def _numpycode(self, printer: Printer, *args: Any) -> str:
        printer.module_imports["numpy.lib"].add("scimath")
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"

    def _pythoncode(self, printer: Printer, *args: Any) -> str:
        printer.module_imports["cmath"].add("sqrt as csqrt")
        x = printer._print(self.args[0])
        return f"csqrt({x})"
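To check which printer picks up which method, the following sketch prints the same expression with SymPy's LaTeX, Python, and NumPy printers. It also lambdifies the expression with the `"math"` module, in which case `sympy.lambdify` uses the plain Python code printer and therefore `_pythoncode()`, so the generated function calls {mod}`cmath`. The class definition is repeated from above so that the cell runs on its own; importing `NumPyPrinter` from `sympy.printing.numpy` assumes a reasonably recent SymPy version.

```python
from typing import Any

import sympy as sp
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.printer import Printer


# Same ComplexSqrt as defined above, repeated so that this cell is self-contained
class ComplexSqrt(sp.Expr):
    def __new__(cls, x, *args: Any, **kwargs: Any):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )

    def _latex(self, printer: Printer, *args: Any) -> str:
        x = printer._print(self.args[0])
        return rf"\sqrt[\mathrm{{c}}]{{{x}}}"

    def _numpycode(self, printer: Printer, *args: Any) -> str:
        printer.module_imports["numpy.lib"].add("scimath")
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"

    def _pythoncode(self, printer: Printer, *args: Any) -> str:
        printer.module_imports["cmath"].add("sqrt as csqrt")
        x = printer._print(self.args[0])
        return f"csqrt({x})"


x = sp.Symbol("x")
expr = ComplexSqrt(x)

# Each printer dispatches to the ComplexSqrt method that matches its printmethod
latex_code = sp.latex(expr)  # calls ComplexSqrt._latex
python_code = sp.pycode(expr)  # calls ComplexSqrt._pythoncode
numpy_code = NumPyPrinter().doprint(expr)  # calls ComplexSqrt._numpycode
print(latex_code, python_code, numpy_code, sep="\n")

# With the "math" module, lambdify uses the Python code printer and executes the
# registered import "from cmath import sqrt as csqrt" in the function namespace
math_func = sp.lambdify(x, expr, "math")
print(math_func(-4))  # evaluates cmath.sqrt(-4)
```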

In addition, one can register `ComplexSqrt` with the `Lambdifier` class of `sympy.plotting.experimental_lambdify`, so that {func}`sympy.plot() <sympy.plotting.plot.plot>` also works:

In [None]:
from sympy.plotting.experimental_lambdify import Lambdifier

# Map the function name ComplexSqrt to sqrt in the translation used by sympy.plot()
Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"

The `_latex()` method ensures that `ComplexSqrt` renders nicely in notebooks:

In [None]:
x = sp.Symbol("x")
ComplexSqrt(x)

Unlike the {doc}`derivation of a sympy.Expr </adr/002/expr>`, however, this class evaluates directly if its argument is numeric, that is, if it has no free symbols:

In [None]:
ComplexSqrt(-4)
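The evaluation logic can be probed a bit further. The sketch below (class repeated from above so the cell runs on its own) shows that numeric arguments collapse the `Piecewise` to a single branch, that a symbol declared `real=True` yields the full `Piecewise`, and that a symbol without assumptions has `is_real` equal to `None`, so `evaluate()` falls back to the principal square root {func}`sympy.sqrt <sympy.functions.elementary.miscellaneous.sqrt>`. The symbol name `y` is only illustrative.

```python
from typing import Any

import sympy as sp


# Same construction and evaluation logic as ComplexSqrt above (printer methods omitted)
class ComplexSqrt(sp.Expr):
    def __new__(cls, x, *args: Any, **kwargs: Any):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )


# Numeric arguments: the condition x < 0 is decidable, so one branch remains
print(ComplexSqrt(-4))  # first branch: I * sqrt(4)
print(ComplexSqrt(9))  # first condition is False, so the sqrt(9) branch remains

# Symbolic arguments: __new__ returns the unevaluated ComplexSqrt
x = sp.Symbol("x")
unevaluated = ComplexSqrt(x)

# A real symbol: evaluate() returns the Piecewise with both branches
y = sp.Symbol("y", real=True)
piecewise = ComplexSqrt(y).evaluate()
print(piecewise)

# A symbol without assumptions: x.is_real is None, so `not x.is_real` is True
fallback = ComplexSqrt(x).evaluate()
print(fallback)
```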

Most importantly, lambdifying to {mod}`numpy` works as well:

In [None]:
import inspect

lambdified = sp.lambdify(x, ComplexSqrt(x), "numpy")
source = inspect.getsource(lambdified)
print(source)
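Calling the lambdified function on an array with negative entries illustrates what `scimath.sqrt` buys us: the result becomes complex instead of `nan`. The sketch below repeats a trimmed version of the class (only the NumPy printer method is needed here) and compares the output with plain {func}`numpy.sqrt`; the sample values are arbitrary.

```python
from typing import Any

import numpy as np
import sympy as sp
from sympy.printing.printer import Printer


# Trimmed ComplexSqrt: only the parts needed to lambdify to NumPy
class ComplexSqrt(sp.Expr):
    def __new__(cls, x, *args: Any, **kwargs: Any):
        x = sp.sympify(x)
        return sp.Expr.__new__(cls, x, *args, **kwargs)

    def _numpycode(self, printer: Printer, *args: Any) -> str:
        printer.module_imports["numpy.lib"].add("scimath")
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"


x = sp.Symbol("x")
lambdified = sp.lambdify(x, ComplexSqrt(x), "numpy")

sample = np.array([-4.0, 0.0, 9.0])
# scimath.sqrt switches to complex output because one entry is negative
complex_result = lambdified(sample)
print(complex_result)

# For comparison: plain numpy.sqrt returns nan for the negative entry
# (the RuntimeWarning it would emit is silenced here)
with np.errstate(invalid="ignore"):
    real_result = np.sqrt(sample)
print(real_result)
```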

As noted in {ref}`report/000:Complex square root`, however, {mod}`numpy.lib.scimath` cannot be used with {doc}`JAX <jax:index>`:

In [None]:
import jax
import numpy as np

jax_lambdified = jax.jit(lambdified)
sample = np.linspace(-1, +1, 5)
# Expected to fail: scimath.sqrt cannot handle the traced JAX array
jax_lambdified(sample)
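One way to see why this cannot work is that the output dtype of `scimath.sqrt` depends on the *values* of its input, not just on its shape and dtype. Under {func}`jax.jit`, the function is traced with abstract arrays that carry a shape and dtype but no concrete values, so such a value-dependent decision is not possible; `scimath.sqrt` instead tries to convert the tracer to a NumPy array, which (as far as we can tell) is what raises the error above. The following NumPy-only sketch shows the value dependence:

```python
import numpy as np
from numpy.lib import scimath

# Both inputs are float64 arrays of the same shape...
non_negative = np.array([0.0, 1.0, 4.0])
with_negative = np.array([-4.0, 1.0, 4.0])

# ...but the output dtype differs, because scimath.sqrt inspects the values
# to decide whether a complex result is needed
print(scimath.sqrt(non_negative).dtype)
print(scimath.sqrt(with_negative).dtype)
```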