# [TR-001] Custom lambdification

:::{seealso}

{doc}`SymPy's tutorial page on the printing modules <sympy:modules/printing>`

:::

<!-- cspell:disable -->

In [None]:
%%sh
pip install jax==0.2.13 jaxlib==0.1.67 matplotlib==3.4.2 numpy==1.19.5 sympy==1.8 > /dev/null

In [None]:
import inspect
from typing import Any

import jax
import numpy as np
import sympy as sp

## Overwriting printer methods

As noted in {doc}`/report/000`, it's hard to lambdify a {func}`sympy.sqrt <sympy.functions.elementary.miscellaneous.sqrt>` to {doc}`JAX <jax:index>`. One possible way out is to define a custom class that derives from {class}`sympy.Expr <sympy.core.expr.Expr>` and {doc}`override its printer methods <sympy:modules/printing>`.
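A minimal sketch of how this hook mechanism works, assuming SymPy's standard printer classes: every printer declares a `printmethod` name, and when an expression defines a method with that name, the printer calls it instead of its own printing logic. The toy class `Twice` is hypothetical and only illustrates the hook.

```python
import sympy as sp
from sympy.printing.latex import LatexPrinter
from sympy.printing.pycode import PythonCodePrinter

# Each printer class names the method it looks for on an expression
print(LatexPrinter.printmethod)       # _latex
print(PythonCodePrinter.printmethod)  # _pythoncode


class Twice(sp.Expr):
    """Hypothetical toy expression, used only to show the printer hook."""

    def _pythoncode(self, printer, *args):
        # Hand the argument back to the printer so nested expressions print correctly
        return f"2*({printer._print(self.args[0])})"


x = sp.Symbol("x")
print(sp.pycode(Twice(x + 1)))  # 2*(x + 1)
```

The `ComplexSqrt` class below uses the same mechanism, with one hook per printer it needs to support (`_latex`, `_numpycode`, `_pythoncode`).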

In [None]:
from sympy.printing.printer import Printer


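class ComplexSqrt(sp.Expr):
    """Square root that returns complex values for negative real input."""

    def __new__(cls, x, *args: Any, **kwargs: Any):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        # Evaluate immediately if the argument is numeric (no free symbols)
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        # If x is not known to be real, fall back to SymPy's principal square root
        if not x.is_real:
            return sp.sqrt(x)
        # For real x, make the imaginary branch for x < 0 explicit
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )

    def _latex(self, printer: Printer, *args: Any) -> str:
        # Render as a square root with a "c" (complex) index
        x = printer._print(self.args[0])
        return fR"\sqrt[\mathrm{{c}}]{{{x}}}"

    def _numpycode(self, printer: Printer, *args: Any) -> str:
        # Adds `from numpy.lib import scimath` to the lambdified namespace
        printer.module_imports["numpy.lib"].add("scimath")
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"

    def _pythoncode(self, printer: Printer, *args: Any) -> str:
        # Adds `from cmath import sqrt as csqrt` to the lambdified namespace
        printer.module_imports["cmath"].add("sqrt as csqrt")
        x = printer._print(self.args[0])
        return f"csqrt({x})"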

As opposed to the {doc}`derivation of a sympy.Expr </adr/002/expr>`, however, this class evaluates directly as soon as its argument contains no free symbols:

In [None]:
ComplexSqrt(-4)
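To see both branches of `evaluate()`, here is a stripped-down, self-contained copy of the class (construction and evaluation only, printer hooks left out) applied to a number, to a symbol declared real, and to a symbol without assumptions. The symbol names are chosen for illustration.

```python
import sympy as sp


class ComplexSqrt(sp.Expr):
    """Reduced copy of the class above: construction and evaluation only."""

    def __new__(cls, x):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x)
        if not x.free_symbols:  # numeric argument: evaluate right away
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:  # real-ness unknown or false: principal square root
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )


x = sp.Symbol("x")             # no assumptions
y = sp.Symbol("y", real=True)  # declared real

print(ComplexSqrt(-4))            # evaluated on construction
print(ComplexSqrt(y))             # symbolic argument: stays unevaluated
print(ComplexSqrt(y).evaluate())  # real symbol: Piecewise with an imaginary branch
print(ComplexSqrt(x).evaluate())  # unknown domain: plain sp.sqrt(x)
```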

The `_latex()` method ensures that `ComplexSqrt` renders nicely in notebooks:

In [None]:
x = sp.Symbol("x")
ComplexSqrt(x)
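The raw LaTeX string behind that rendering can be inspected with `sp.latex()`. The sketch below keeps only the `_latex()` hook of the class; because the argument is printed with the same printer, nested expressions are rendered consistently inside the root.

```python
from typing import Any

import sympy as sp
from sympy.printing.printer import Printer


class ComplexSqrt(sp.Expr):
    """Minimal copy of the class above that keeps only the LaTeX hook."""

    def _latex(self, printer: Printer, *args: Any) -> str:
        # Print the argument with the same LaTeX printer, then wrap it
        x = printer._print(self.args[0])
        return fR"\sqrt[\mathrm{{c}}]{{{x}}}"


x = sp.Symbol("x")
print(sp.latex(ComplexSqrt(x)))         # \sqrt[\mathrm{c}]{x}
print(sp.latex(ComplexSqrt(x**2 + 1)))  # \sqrt[\mathrm{c}]{x^{2} + 1}
```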

## Plot custom class

In addition, one may modify the `Lambdifier` class from `sympy.plotting.experimental_lambdify`, so that {func}`sympy.plot() <sympy.plotting.plot.plot>` also works on this custom class:

In [None]:
from sympy.plotting.experimental_lambdify import Lambdifier

Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"

In [None]:
x = sp.Symbol("x")
expr = ComplexSqrt(x)
p1 = sp.plot(sp.re(expr), (x, -1, 2), show=False, line_color="red")
p2 = sp.plot(sp.im(expr), (x, -1, 2), show=False)
p1.append(p2[0])
p1.show()
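The shape of the two curves can be cross-checked numerically. This sketch assumes that the expected behavior is that of {mod}`numpy.lib.scimath` (the module that the NumPy printer hook emits) and evaluates it on a few points of the same domain: left of zero the result is purely imaginary, from zero onward it equals the ordinary real square root.

```python
import numpy as np
from numpy.lib import scimath

x = np.linspace(-1, 2, 7)  # same domain as the plot above
z = scimath.sqrt(x)        # complex output, because x contains negative values

negative = x < 0
# For x < 0: real part vanishes, imaginary part is sqrt(-x)
assert np.allclose(z.real[negative], 0)
assert np.allclose(z.imag[negative], np.sqrt(-x[negative]))
# For x >= 0: identical to the real square root
assert np.allclose(z.real[~negative], np.sqrt(x[~negative]))
assert np.allclose(z.imag[~negative], 0)
print(z)
```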

## Lambdifying

More importantly, lambdifying to {mod}`numpy` or {mod}`math` now works as well:

In [None]:
lambdified_py = sp.lambdify(x, ComplexSqrt(x), "math")
source = inspect.getsource(lambdified_py)
print(source)

In [None]:
lambdified_np = sp.lambdify(x, ComplexSqrt(x), "numpy")
source = inspect.getsource(lambdified_np)
print(source)
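Beyond inspecting the generated source, one can call the two lambdified functions on negative input. The self-contained sketch below reduces the class to its two code-printer hooks. With `"math"`, `sympy.lambdify` falls back to the plain Python code printer, so the `_pythoncode()` hook applies and `cmath.sqrt` is called; with `"numpy"`, the `_numpycode()` hook applies and the input is processed element-wise by `scimath.sqrt`.

```python
from typing import Any

import numpy as np
import sympy as sp
from sympy.printing.printer import Printer


class ComplexSqrt(sp.Expr):
    """Copy of the class above, reduced to its two code-printer hooks."""

    def _numpycode(self, printer: Printer, *args: Any) -> str:
        printer.module_imports["numpy.lib"].add("scimath")
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"

    def _pythoncode(self, printer: Printer, *args: Any) -> str:
        printer.module_imports["cmath"].add("sqrt as csqrt")
        x = printer._print(self.args[0])
        return f"csqrt({x})"


x = sp.Symbol("x")
lambdified_py = sp.lambdify(x, ComplexSqrt(x), "math")
lambdified_np = sp.lambdify(x, ComplexSqrt(x), "numpy")

print(lambdified_py(-4.0))                        # cmath.sqrt on a scalar
print(lambdified_np(np.array([-1.0, 0.0, 4.0])))  # scimath.sqrt, element-wise
```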

Just as noted in {ref}`report/000:Complex square root`, though, {mod}`numpy.lib.scimath` is not provided by the NumPy API of {doc}`JAX <jax:index>`. As discussed there, the best we can do is to decorate the {mod}`numpy.lib.scimath` version with {func}`jax.jit` and work with static arguments only:

In [None]:
jax_lambdified = jax.jit(lambdified_np, backend="cpu", static_argnums=0)
jax_lambdified(-1)

Unhashable input, such as a NumPy array of samples, cannot be passed as a static argument and is therefore rejected:

In [None]:
sample = np.linspace(-1, +1, 5)
jax_lambdified(sample)
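The JAX documentation requires arguments marked with `static_argnums` to be hashable, since their values become part of the key under which compiled functions are cached. The plain-Python check below (the helper `is_hashable` is hypothetical, written only for this illustration) shows why the scalar call above passes while the array does not. A tuple of floats would be hashable, but as a static argument every new tuple would trigger a fresh compilation.

```python
import numpy as np


def is_hashable(value) -> bool:
    """Return True if ``value`` could serve as a dictionary or cache key."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


sample = np.linspace(-1, +1, 5)
print(is_hashable(-1))             # Python int: hashable
print(is_hashable(sample))         # NumPy array: unhashable
print(is_hashable(tuple(sample)))  # tuple of floats: hashable
```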

## Handle for JAX