Porting `brentq` to jax. Probably not necessary given existence of [jaxopt](https://github.com/google/jaxopt)...

In [48]:
from functools import partial
from typing import Callable, Tuple, Union

from jax import jit
from jax.experimental.host_callback import id_print
import jax.numpy as jnp
from scipy.optimize import root_scalar

Array = jnp.ndarray
ArrOrF = Union[Array, float]

In [139]:
EPS = jnp.array(jnp.finfo(float).eps)

@partial(jit, static_argnums=(0,))
def brent(
    f: Callable[[ArrOrF], ArrOrF],
    x0: ArrOrF,
    x1: ArrOrF,
    xtol: ArrOrF = jnp.array(1e-7),
    ytol: ArrOrF = 2 * EPS,
    maxiter: int = jnp.array(50, dtype=int),
) -> ArrOrF:
    """
    Brent's method for root-finding. Ported from this julia implementation:
    https://mmas.github.io/brent-julia.
    """
    y0 = f(x0)
    y1 = f(x1)
    x0, y0, x1, y1 = jax.lax.cond(
        jnp.abs(y0) < jnp.abs(y1),
        lambda _: (x1, y1, x0, y0),
        lambda _: (x0, y0, x1, y1),
        None,
    )
    x2 = x0
    y2 = y0
    x3 = x2

    # i=1, x=nan, y=nan
    init_val = (1, jnp.nan, x0, x1, x2, x3, jnp.inf, y0, y1, y2, True)

    def cond_fun(val):
        i, x, x0, x1, _, _, y, _, _, _, _ = val
        # stop if i > maxiter or |x1-x0| < xtol or |y| < ytol
        return jnp.logical_and(
            i <= maxiter, jnp.logical_and(jnp.abs(x1 - x0) >= xtol, jnp.abs(y) >= ytol)
        )

    def body_fun(val):
        i, _, x0, x1, x2, x3, _, y0, y1, y2, bisection = val

        # Use inverse quadratic interpolation if f(x0) != f(x1) != f(x2)
        # and linear interpolation (secant method) otherwise
        x = jax.lax.cond(
            jnp.logical_and(jnp.abs(y0 - y2) > ytol, jnp.abs(y1 - y2) > ytol),
            lambda _: x0 * y1 * y2 / ((y0 - y1) * (y0 - y2))
            + x1 * y0 * y2 / ((y1 - y0) * (y1 - y2))
            + x2 * y0 * y1 / ((y2 - y0) * (y2 - y1)),
            lambda _: x1 - y1 * (x1 - x0) / (y1 - y0),
            None,
        )

        # Use bisection method if satisfies the conditions
        delta = jnp.abs(2 * EPS * jnp.abs(x1))
        min1 = jnp.abs(x - x1)
        min2 = jnp.abs(x1 - x2)
        min3 = jnp.abs(x2 - x3)
        bisection_cond = jnp.logical_or(
            jnp.logical_and(x < (3 * x0 + x1) / 4, x > x1),
            jnp.logical_or(
                jnp.logical_and(bisection, min1 >= min2 / 2),
                jnp.logical_or(
                    jnp.logical_and(jnp.logical_not(bisection), min1 >= min3 / 2),
                    jnp.logical_or(
                        jnp.logical_and(bisection, min2 < delta),
                        jnp.logical_and(jnp.logical_not(bisection), min3 < delta),
                    ),
                ),
            ),
        )
        x, bisection = jax.lax.cond(
            bisection_cond, lambda _: ((x0 + x1) / 2, True), lambda _: (x, False), None
        )

        y = f(x)

        # Update running variables
        x3 = x2
        x2 = x1
        x0, y0, x1, y1 = jax.lax.cond(
            jnp.sign(y0) != jnp.sign(y),
            lambda _: (x0, y0, x, y),
            lambda _: (x, y, x1, y1),
            None,
        )
        x0, y0, x1, y1 = jax.lax.cond(
            jnp.abs(y0) < jnp.abs(y1),
            lambda _: (x1, y1, x0, y0),
            lambda _: (x0, y0, x1, y1),
            None,
        )
        return (i + 1, x, x0, x1, x2, x3, y, y0, y1, y2, bisection)

    val = jax.lax.while_loop(cond_fun, body_fun, init_val)
    i, x = val[0], val[1]
    x = jax.lax.cond(i == maxiter, lambda _: jnp.nan, lambda _: x, None)
    return {"converged": jnp.logical_not(jnp.isnan(x)), "iterations": i, "root": x}

In [140]:
@jit
def f(x):
    return x**4 - 2 * x**2 + 1 / 4

In [141]:
jnp.sqrt(1 - jnp.sqrt(3) / 2)

DeviceArray(0.36602542, dtype=float32)

In [142]:
@partial(jit, static_argnums=(0,))
def apply(f, x):
    return f(x)

In [143]:
%timeit apply(f, jnp.array(0.5))

155 µs ± 3.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [144]:
%timeit brent(f, jnp.array(0.0), jnp.array(1.0))["root"].block_until_ready()

339 µs ± 3.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [128]:
%prun -s cumulative [brent(f, jnp.array(0.0), jnp.array(1.0))["root"].block_until_ready() for _ in range(100)]

 

In [111]:
%timeit root_scalar(f, bracket=(0.0, 1.0), maxiter=50, method="brentq").root

11.2 µs ± 105 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
