In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Brent's Method: A Better Root-Finding Algorithm

Our workhorse routine for finding a root is *bisection*. Given a continuous function $f(x)$ and an interval $x\in[a,b]$ such that $f(a) < 0, f(b) > 0$ or $f(a) > 0, f(b) < 0$, bisection is guaranteed to converge to a root. As you know, this method works by finding the midpoint $m = (a+b)/2$. If $f(a)$ and $f(m)$ have opposite signs, the root lies in $[a,m]$ and the routine sets $b=m$; otherwise the root lies in $[m,b]$ and the routine sets $a=m$. The process then repeats until $|b-a|$ is smaller than some specified tolerance.
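The loop described above fits in a few lines. The sketch below is a minimal version, assuming $f$ is continuous with a sign change on $[a,b]$; the helper name `bisect_to_tol` and its default tolerance are illustrative choices rather than part of this notebook.

```python
import numpy as np

def bisect_to_tol(f, a, b, tol=1e-10, max_iter=200):
    """Hypothetical helper: repeat bisection until |b - a| < tol.

    Assumes f is continuous on [a, b] and f(a), f(b) have opposite signs.
    Returns the final midpoint and the number of halvings performed.
    """
    fa = f(a)
    if fa * f(b) > 0:
        raise ValueError('f(a) and f(b) must have opposite signs')
    n = 0
    while abs(b - a) >= tol and n < max_iter:
        m = 0.5 * (a + b)
        fm = f(m)
        if fa * fm <= 0:
            # sign change in [a, m]: keep the left half
            b = m
        else:
            # otherwise the sign change is in [m, b]: keep the right half
            a, fa = m, fm
        n += 1
    return 0.5 * (a + b), n

root_b, n_b = bisect_to_tol(np.sin, 1.5, 4.0)
print(root_b, n_b)
```

Each step halves the bracket, so reaching width $\mathrm{tol}$ takes $\lceil \log_2[(b-a)/\mathrm{tol}]\rceil$ halvings: 35 for $[1.5,4]$ with $\mathrm{tol}=10^{-10}$, no matter how smooth $f$ is.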

Although bisection is guaranteed to converge, it isn't particularly smart: it uses only the signs of $f$, never its values. In this notebook, we'll demonstrate a smarter algorithm, known as *Brent's method*. Let's begin by defining a suitable trial function, $f(x) = \sin(x)$, on the interval $x\in[1.5,4]$, which contains the root $x=\pi$. We'll use this $f(x)$ to explore different root-finding strategies.

In [None]:
def f(x):
    return np.sin(x)

# we're making a lot of variations of this plot, so we'll collect everything into a single
# function
def base_plot(x):
    fig = plt.figure(figsize=(8,5))
    ax = fig.add_subplot(1,1,1)
    xmin = x.min()
    xmax = x.max()
    y = f(x)
    ymin = y.min()
    ymax = y.max()
    pad = 0.05*(ymax-ymin)
    ymin -= pad
    ymax += pad
    ax.set_xlim(xmin,xmax)
    ax.set_ylim(ymin,ymax)
    plt.plot(x,y,'k-')
    plt.vlines(np.pi,ymin,ymax,linestyle='-',linewidth=0.5,color='0.7')
    plt.hlines(0,xmin,xmax,linestyle='-',linewidth=0.5,color='0.7')
    plt.xlabel(r'$x$')
    plt.ylabel(r'$y=f(x)$')
    return ax

xl = 1.5
xr = 4.0
x = np.linspace(xl,xr,100)
base_plot(x)

## Interpolating to predict the root

As a first step, we could *predict* the location of the root by drawing a line between the endpoints, known as a *secant*, and finding where that secant crosses zero. We've marked on the plot (*red dot*) this guess for the root, and also indicated where the bisection guess (*black dot*) would fall.

In [None]:
def secant(f,a,b):
    """Return the x where the line through (a, f(a)) and (b, f(b)) crosses zero."""
    fa = f(a)
    fb = f(b)
    # the expression is symmetric under swapping a and b, so their order doesn't matter
    return (fb*a - fa*b)/(fb-fa)

ax = base_plot(x)
ax.plot([xl,xr],[f(xl),f(xr)],color='r',linewidth=0.5,linestyle='-')
r = secant(f,xl,xr)
ax.plot(r,f(r),'ro')

xm = 0.5*(xl+xr)
ax.plot(xm,f(xm),'ko')

This guess is a bit closer to the root than the bisection guess. Since $f$ is still positive at the guess, the root lies between it and the right endpoint; we could draw a new secant through those two points to make a new, better guess. By repeating this process, we would gradually home in on the root.
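Repeating the secant step while always keeping the root bracketed is the classical *method of false position* (regula falsi). Here is a minimal sketch, assuming $f(a)$ and $f(b)$ have opposite signs; the helper name `false_position` and the stopping rule on $|f(r)|$ are illustrative assumptions.

```python
import numpy as np

def false_position(f, a, b, ftol=1e-12, max_iter=100):
    """Hypothetical helper: iterate the secant step while keeping [a, b] bracketing a root."""
    fa, fb = f(a), f(b)
    if fa * fb > 0:
        raise ValueError('f(a) and f(b) must have opposite signs')
    r = a
    for n in range(1, max_iter + 1):
        # zero crossing of the secant through (a, fa) and (b, fb)
        r = (fb * a - fa * b) / (fb - fa)
        fr = f(r)
        if abs(fr) < ftol:
            break
        # replace the endpoint whose f has the same sign as f(r)
        if fa * fr < 0:
            b, fb = r, fr
        else:
            a, fa = r, fr
    return r, n

r_fp, n_fp = false_position(np.sin, 1.5, 4.0)
print(r_fp, n_fp)
```

One known weakness of false position is that one endpoint can stay fixed for many steps, which limits it to linear convergence for many functions.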

Of course, we can make a better approximation than a secant. Through any two points you can draw a straight line; through any three points with distinct $x$ values you can draw a parabola. As it turns out, we do have three points: left, right, and middle.
Just as we constructed a straight line through the endpoints, we could construct a parabola ($y(x) = ax^2+bx+c$) through the endpoints of the interval and the midpoint, and then find $r$ such that $y(x=r) = 0$.

The formulas for $a$, $b$, and $c$ follow from a bit of algebra:
\begin{eqnarray}
    a &=& \left(\frac{y_2-y_0}{x_2-x_0} - \frac{y_1-y_0}{x_1-x_0}\right)\frac{1}{x_2-x_1}\\
    b &=& \frac{y_1-y_0}{x_1-x_0} - a\left(x_1+x_0\right)\\
    c &=& y_0 - ax_0^2 -bx_0.
\end{eqnarray}
Here $x_0$ is the left endpoint, $x_1$ the midpoint, and $x_2$ the right endpoint of the interval, with $y_i = f(x_i)$. These formulas are implemented in the function `quadfit` below.
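As a sanity check on the algebra, the closed-form coefficients can be compared with `numpy.polyfit`, which, for exactly three points with distinct $x$ and degree 2, returns the interpolating parabola (coefficients ordered from $x^2$ down). In this sketch `quad_coeffs` is a hypothetical standalone copy of the formulas above, applied to $f(x)=\sin(x)$ at $x=1.5, 2.75, 4$.

```python
import numpy as np

def quad_coeffs(x, y):
    """Coefficients (a, b, c) of y = a x^2 + b x + c through three points (formulas above)."""
    a = ((y[2]-y[0])/(x[2]-x[0]) - (y[1]-y[0])/(x[1]-x[0]))/(x[2]-x[1])
    b = (y[1]-y[0])/(x[1]-x[0]) - a*(x[1]+x[0])
    c = y[0] - a*x[0]**2 - b*x[0]
    return a, b, c

x3 = np.array([1.5, 2.75, 4.0])          # left endpoint, midpoint, right endpoint
y3 = np.sin(x3)
closed_form = np.array(quad_coeffs(x3, y3))
numeric = np.polyfit(x3, y3, 2)          # ordered [a, b, c], highest power first
print(closed_form, numeric, np.allclose(closed_form, numeric))
```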

In [None]:
def quadfit(x,y):
    """
    Given 3 points, find a parabola y = ax^2 + bx + c going through 
    the points and return the coefficients a,b,c.
    
    Arguments
        x, y (arrays)
            x and y values of the points. The three x values must be distinct
            (they need not be sorted).
    Returns
        a, b, c
            coefficients of parabola y = ax^2 + bx +c
    """
    # the formulas divide by differences of the x values, so those must be nonzero
    if x[0] == x[1] or x[1] == x[2] or x[0] == x[2]:
        raise ValueError('unable to compute a, b, c: x values must be distinct')
    a = ((y[2]-y[0])/(x[2]-x[0]) - (y[1]-y[0])/(x[1]-x[0]))/(x[2]-x[1])
    b = (y[1]-y[0])/(x[1]-x[0]) - (x[1]+x[0])*a
    c = y[0] - a*x[0]**2 - b*x[0]
    return (a,b,c)

ax = base_plot(x)

# construct the quadratic
xm = 0.5*(xl+xr)
xp = [xl,xm,xr]
yp = [f(xl),f(xm),f(xr)]
a,b,c = quadfit(xp,yp)

# plot the parabola
y_quad = a*x**2 + b*x + c
ax.plot(x,y_quad,color='b')

# solution to quadratic equation that lies in the interval [1.5,4]
rq = 0.5*b/a*(-np.sqrt(1-4*a*c/b**2)-1)
ax.plot(rq,f(rq),'bo')

# show for comparison the guess from the secant method
ax.plot([xl,xr],[f(xl),f(xr)],color='r',linewidth=0.5,linestyle='-')
r = secant(f,xl,xr)
ax.plot(r,f(r),'ro')

# show for comparison the bisection guess at midpoint
ax.plot(xm,f(xm),'ko')

In the plot, the blue curve is our parabola through the midpoint (black dot) and the two endpoints. As you can see, its zero crossing (blue dot) is closer to the root than either the secant guess or the bisection guess.

It is actually simpler, however, to invert the parabola: instead of making $y$ a quadratic function of $x$, treat $x$ as a quadratic function of $y$ ($x=ay^2+by+c$). To predict the root, just set $y=0$, which gives $x=c$. This is known as *inverse quadratic interpolation*. Given three points $(a,f(a))$, $(b, f(b))$, and $(r, f(r))$, with $a < r < b$ and distinct values $f(a)$, $f(b)$, $f(r)$ (here $a$ and $b$ again denote the endpoints of the interval, not the coefficients of the parabola), a prediction for $x$ such that $f(x) = 0$ is
$$
    x = \frac{f(b)f(r)\,a}{[f(b)-f(a)][f(r)-f(a)]} - \frac{f(a)f(r)\,b}{[f(b)-f(a)][f(r)-f(b)]}
        + \frac{f(a)f(b)\,r}{[f(r)-f(a)][f(r)-f(b)]}.
$$
This formula is implemented in the function `inverse_quad` below.
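To confirm that this Lagrange-form expression is the $y=0$ value of the sideways parabola, we can compare it with the constant term of a degree-2 fit of $x$ as a function of $y$ using `numpy.polyfit`. In this sketch `iqi` is a hypothetical standalone copy of the formula, evaluated for $f(x)=\sin(x)$ with $a=1.5$, $r=2.75$, $b=4$.

```python
import numpy as np

def iqi(f, r, a, b):
    """Inverse quadratic interpolation through (a, f(a)), (b, f(b)), (r, f(r))."""
    fa, fb, fr = f(a), f(b), f(r)
    return (fb*fr*a/((fb-fa)*(fr-fa))
            - fa*fr*b/((fb-fa)*(fr-fb))
            + fa*fb*r/((fr-fa)*(fr-fb)))

left, guess, right = 1.5, 2.75, 4.0
xs = np.array([left, guess, right])
ys = np.sin(xs)
# fit x = A*y**2 + B*y + C; at y = 0 the sideways parabola gives x = C
C = np.polyfit(ys, xs, 2)[-1]
x_iqi = iqi(np.sin, guess, left, right)
print(x_iqi, C)
```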

In [None]:
def inverse_quad(f,r,a,b):
    """
    Performs inverse quadratic interpolation on the three points
    (r, f(r)), (a, f(a)), and (b, f(b)) to predict x s.t. f(x) = 0.
    The prediction is not guaranteed to lie between a and b.
    
    Arguments
        f
            function for which we seek the root f(x) = 0
        r
            guess for the root: a < r < b
        a, b
            endpoints of interval; f(a), f(b), and f(r) must be distinct
    Returns
        x
            prediction for the root
    """
    fa = f(a)
    fb = f(b)
    fr = f(r)
    fba = fb-fa
    fra = fr-fa
    frb = fr-fb
    
    return fb*fr*a/fba/fra - fa*fr*b/fba/frb + fa*fb*r/fra/frb

xl = x.min()
xr = x.max()
ax = base_plot(x)

# construct the sideways parabola through the endpoints and the midpoint
xm = 0.5*(xl+xr)
xp = [xl,xm,xr]
yp = [f(xl),f(xm),f(xr)]

# notice how we switch the arguments, so that x = a*y**2 + b*y + c now.
a,b,c = quadfit(yp,xp)

# plot the parabola
y_quad = np.linspace(-0.75,1.0)
x_quad = a*y_quad**2 + b*y_quad + c
ax.plot(x_quad,y_quad,color='b')

# inverse quadratic prediction for the root: the value of x on the parabola at y = 0
rq = inverse_quad(f,xm,xl,xr)
ax.plot(rq,f(rq),'bo')

# show for comparison the guess from the secant method
ax.plot([xl,xr],[f(xl),f(xr)],color='r',linewidth=0.5,linestyle='-')
r = secant(f,xl,xr)
ax.plot(r,f(r),'ro')

# show for comparison the bisection guess at midpoint
ax.plot(xm,f(xm),'ko')

So now we have a (hopefully) better estimate for the root (blue dot). On the next step (labeled '1' in the plot below), we can use the midpoint (*black dot*), our guess for the root (*blue dot*), and the right endpoint to construct the next guess; since $f$ changes sign between the midpoint and the right endpoint, that half of the interval brackets the root.

In [None]:
ax = base_plot(x)

xl = x.min()
xr = x.max()
xm = 0.5*(xl+xr)

# inverse quadratic prediction for the root, using the endpoints and the midpoint
rq = inverse_quad(f,xm,xl,xr)
ax.plot(rq,f(rq),'bo')
ax.annotate('0',xy=(rq,f(rq)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')
# next prediction, using the guess, the midpoint, and the right endpoint
rq = inverse_quad(f,rq,xm,xr)
ax.plot(rq,f(rq),'bo')
ax.annotate('1',xy=(rq,f(rq)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')

# show for comparison the bisection guess at midpoint
ax.plot(xm,f(xm),'ko')
ax.annotate('0',xy=(xm,f(xm)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')
xm = 0.5*(xm+xr)
ax.plot(xm,f(xm),'ko')
ax.annotate('1',xy=(xm,f(xm)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')

## Trade-offs: speed vs robustness

As you'll notice in the plot above, after one more iteration of inverse quadratic interpolation we are nearly on top of the root (blue dot marked '1'). In contrast, bisection overshoots the root (black dot marked '1').

Inverse quadratic interpolation doesn't always converge faster than bisection, however. Generally speaking, it works well when the interval in question is small enough that a parabola is a good approximation to the function.  For example, let's widen our interval to $x\in[1.5,5.5]$ and again step through two iterations. We'll see that the upswing in $f(x)$ at $x > 3\pi/2$ forces the guess from inverse quadratic interpolation wide of the mark.

In [None]:
xl = 1.5
xr = 5.5
x = np.linspace(xl,xr,100)
ax = base_plot(x)

xm = 0.5*(xl+xr)

# inverse quadratic prediction for the root on the wider interval [1.5,5.5]
rq = inverse_quad(f,xm,xl,xr)
ax.plot(rq,f(rq),'bo')
ax.annotate('0',xy=(rq,f(rq)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')
# next prediction, reusing the midpoint and the right endpoint as in the previous example
rq = inverse_quad(f,rq,xm,xr)
ax.plot(rq,f(rq),'bo')
ax.annotate('1',xy=(rq,f(rq)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')

# show for comparison the bisection guess at midpoint
ax.plot(xm,f(xm),'ko')
ax.annotate('0',xy=(xm,f(xm)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')
xm = 0.5*(xl+xm)
ax.plot(xm,f(xm),'ko')
ax.annotate('1',xy=(xm,f(xm)),
            va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')

In fact, there isn't even a guarantee that the guess provided by inverse quadratic interpolation will remain in the interval $[a,b]$, so we can't use that method without some safeguards. The simplest recipe is: if the interpolated guess is worse than the bisection guess, or if it falls outside the bracket, take a bisection step; otherwise keep the inverse quadratic interpolation guess. This is a simplified version of *Brent's method* (the full algorithm also uses secant steps and additional tests on the step size). The function `brent_lite` defined below performs one iteration of this simplified recipe.
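Here is a small sketch of that failure mode, using the same $f(x)=\sin(x)$ on $[1.5,5.5]$. The first inverse quadratic step (through the endpoints and the midpoint $3.5$) gives a guess $r_0$ near $2.1$; since $f(3.5)<0<f(1.5)$, bisection keeps the bracket $[1.5,3.5]$. Interpolating through $1.5$, $3.5$, and $r_0$ then predicts a point beyond $3.5$, outside that bracket. The helper `iqi` is a hypothetical standalone copy of the `inverse_quad` formula.

```python
import numpy as np

def iqi(f, r, a, b):
    """Inverse quadratic interpolation through (a, f(a)), (b, f(b)), (r, f(r))."""
    fa, fb, fr = f(a), f(b), f(r)
    return (fb*fr*a/((fb-fa)*(fr-fa))
            - fa*fr*b/((fb-fa)*(fr-fb))
            + fa*fb*r/((fr-fa)*(fr-fb)))

lo, hi = 1.5, 5.5
mid = 0.5*(lo + hi)                      # 3.5
r0 = iqi(np.sin, mid, lo, hi)            # first guess, from the endpoints and midpoint
# sin(lo) > 0 and sin(mid) < 0, so bisection keeps the bracket [lo, mid]
r1 = iqi(np.sin, r0, lo, mid)            # interpolate through lo, mid, and r0
print(r0, r1, lo < r1 < mid)
```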

For comparison, we'll also define a function `bisect` that performs one iteration of bisection.

In [None]:
def brent_lite(f,r,a,b):
    """
    Perform one step of a simplified Brent's method.

    f := function whose root is bracketed by a and b, i.e., f(a) and f(b) have opposite signs
    r := current best guess for the root, strictly between a and b
    a, b := endpoints of the bracketing interval (in either order)

    returns
    r, a, b, method := new best guess for root r, new endpoints of interval a,b, and whether
        the guess was obtained by bisection or inverse quadratic interpolation.
    """
    fb = f(b)
    # orient so that f(a) < 0
    if fb < 0:
        a,b = b,a
    m = (a+b)/2
    r = inverse_quad(f,r,a,b)

    # move the endpoints of the interval 
    if (f(m) < 0):
        a = m
    else:
        b = m

    # use the bisection guess if r is a worse guess than m or lies outside
    # the new bracket; otherwise keep r
    inside = min(a,b) < r < max(a,b)
    if np.abs(f(r)) > np.abs(f(m)) or not inside:
        r = (a+b)/2
        method = 'bisection'
    else:
        method = 'inverse quadratic interpolation'

    return (r,a,b,method)

In [None]:
def bisect(f,a,b):
    fa = f(a) # left side of bracket
    fb = f(b) # right side of bracket
    # orient so that f(a) < 0
    if fb < fa:
        a, b = b, a
    m = (a+b)/2
    if (f(m) < 0):
        a = m
    else:
        b = m
    return a, b

## Evaluation

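Let's compare our two methods, starting with bisection. We'll run seven iterations and, at each step, mark the current estimate of the root (the midpoint of the bracket) with a black dot, skipping estimates too close to the previous one to label clearly.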

In [None]:
xl = 1.5; xr = 5.5
x = np.linspace(xl,xr,100)

ax = base_plot(x)
a = xl; b = xr
rlast = a
print('{0:>9} {1:>9} {2:>11} {3:>11}'.format('iteration','root r','|r-pi|','f(r)'))
for i in range(7):
    a,b = bisect(f,a,b)
    r = (a+b)/2
    print('{0:9d} {1:9.6f} {2:11.4e} {3:11.4e}'.format(i,r,np.abs(r-np.pi),f(r)))
    # only mark points that moved noticeably, to keep the plot readable
    if np.abs(r-rlast) > 0.05:
        plt.plot(r,f(r),'ko')
        plt.annotate('{0:d}'.format(i),xy=(r,f(r)),va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')
    rlast = r

Now let's use our rudimentary implementation of Brent's method. As we saw before, bisection is the better choice at first; after that, inverse quadratic interpolation converges rapidly to the root, so that we approach machine precision after only a few iterations.

In [None]:
xl = 1.5; xr = 5.5
xm = 0.5*(xl+xr)

x = np.linspace(xl,xr,100)
a = xl; b = xr; r = xm

ax = base_plot(x)

print('{0:>9} {1:>9} {2:>11} {3:>11} {4}'.format('iteration','root r','|r-pi|','f(r)','method'))
for i in range(7):
    d = r
    r,a,b,method = brent_lite(f,r,a,b)
    print('{0:9d} {1:9.6f} {2:11.4e} {3:11.4e} {4}'.format(i,r,np.abs(r-np.pi),f(r),method))
    if method == 'bisection':
        clr = 'k'
    else:
        clr = 'b'
    # only mark points that moved noticeably, to keep the plot readable
    if np.abs(d-r) > 0.01:
        plt.plot(r,f(r),marker='o',color=clr)
        plt.annotate('{0:d}'.format(i),xy=(r,f(r)),va='bottom',ha='left',xytext=(4,4),textcoords='offset points',size='x-small')
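To compare the methods at a fixed accuracy rather than a fixed number of steps, here is a self-contained sketch that repeats a `brent_lite`-style step until $|f(r)| < 10^{-10}$. The names `iqi` and `safeguarded_iqi`, the tolerance, and the explicit check that the interpolated guess stays inside the bracket are assumptions of this sketch; it is not the full Brent algorithm.

```python
import numpy as np

def iqi(f, r, a, b):
    """Inverse quadratic interpolation through (a, f(a)), (b, f(b)), (r, f(r))."""
    fa, fb, fr = f(a), f(b), f(r)
    return (fb*fr*a/((fb-fa)*(fr-fa))
            - fa*fr*b/((fb-fa)*(fr-fb))
            + fa*fb*r/((fr-fa)*(fr-fb)))

def safeguarded_iqi(f, a, b, ftol=1e-10, max_iter=100):
    """Hypothetical helper: repeat brent_lite-style steps until |f(r)| < ftol.

    Assumes f(a) and f(b) have opposite signs. Returns (root estimate, steps taken).
    """
    if f(b) < 0:
        a, b = b, a                      # orient so that f(a) < 0 < f(b)
    r = 0.5*(a + b)
    for n in range(1, max_iter + 1):
        m = 0.5*(a + b)
        x = iqi(f, r, a, b)              # interpolated guess
        if f(m) < 0:                     # halve the bracket, keeping f(a) < 0 < f(b)
            a = m
        else:
            b = m
        inside = min(a, b) < x < max(a, b)
        # keep x only if it lies in the new bracket and beats the bisection point m
        if inside and abs(f(x)) <= abs(f(m)):
            r = x
        else:
            r = 0.5*(a + b)              # fall back to the new midpoint
        if abs(f(r)) < ftol:
            return r, n
    return r, max_iter

root_sg, n_sg = safeguarded_iqi(np.sin, 1.5, 5.5)
print(root_sg, abs(root_sg - np.pi), n_sg)
```

For comparison, bisection needs $\lceil\log_2(4/10^{-10})\rceil = 36$ halvings to shrink $[1.5,5.5]$ to width $10^{-10}$.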

## Summary

Brent's method combines a rapidly converging sequence of inverse quadratic interpolations with a fallback bisection method. The method is therefore a popular choice for finding the root of a function on a bracketed interval. The algorithm is available in `scipy.optimize.brentq`.

In [None]:
from scipy.optimize import brentq
xl = 1.5
xr = 5.5
r, info = brentq(f,xl,xr,full_output=True)
print('root of sin(x) on interval [1.5,5.5] is {0:9.6f}'.format(r))
print('number of iterations = {0}'.format(info.iterations))
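Beyond `full_output`, `brentq` accepts an absolute tolerance `xtol`, a relative tolerance `rtol`, and an iteration cap `maxiter`. The `RootResults` object returned with `full_output=True` also records whether the solver converged and how many function evaluations it used. A short sketch with a tighter `xtol` (the value is an arbitrary choice):

```python
import numpy as np
from scipy.optimize import brentq

# tighter absolute tolerance than the default, plus diagnostic output
r_tight, info_tight = brentq(np.sin, 1.5, 5.5, xtol=1e-14, full_output=True)
print(r_tight, abs(r_tight - np.pi))
print(info_tight.converged, info_tight.iterations, info_tight.function_calls)
```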

## Exercise

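Using the cell above as a guide, solve for the root of $f(x)$ on $x\in[1.5,5.5]$ using `scipy.optimize.bisect`. How many iterations are required for `bisect` compared with `brentq`? (Because this notebook defines its own function named `bisect`, you may want to import SciPy's version under another name, e.g. `from scipy.optimize import bisect as sp_bisect`.)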