Skip to content

Commit

Permalink
Merge pull request #417 from chrishyland/root-finding
Browse files Browse the repository at this point in the history
ENH: Root finding
  • Loading branch information
mmcky committed Jul 25, 2018
2 parents 2a612cc + 876c797 commit 1155267
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 1 deletion.
2 changes: 1 addition & 1 deletion quantecon/optimize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""

from .scalar_maximization import brent_max

from .root_finding import newton, newton_halley, newton_secant
257 changes: 257 additions & 0 deletions quantecon/optimize/root_finding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import numpy as np
from numba import jit, njit
from collections import namedtuple

__all__ = ['newton', 'newton_halley', 'newton_secant']

_ECONVERGED = 0
_ECONVERR = -1

results = namedtuple('results',
('root function_calls iterations converged'))

@njit
def _results(r):
r"""Select from a tuple of(root, funccalls, iterations, flag)"""
x, funcalls, iterations, flag = r
return results(x, funcalls, iterations, flag == 0)

@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
disp=True):
"""
Find a zero from the Newton-Raphson method using the jitted version of
Scipy's newton for scalars. Note that this does not provide an alternative
method such as secant. Thus, it is important that `fprime` can be provided.
Note that `func` and `fprime` must be jitted via Numba.
They are recommended to be `njit` for performance.
Parameters
----------
func : callable and jitted
The function whose zero is wanted. It must be a function of a
single variable of the form f(x,a,b,c...), where a,b,c... are extra
arguments that can be passed in the `args` parameter.
x0 : float
An initial estimate of the zero that should be somewhere near the
actual zero.
fprime : callable and jitted
The derivative of the function (when available and convenient).
args : tuple, optional
Extra arguments to be used in the function call.
tol : float, optional
The allowable error of the zero value.
maxiter : int, optional
Maximum number of iterations.
disp : bool, optional
If True, raise a RuntimeError if the algorithm didn't converge
Returns
-------
results : namedtuple
root - Estimated location where function is zero.
function_calls - Number of times the function was called.
iterations - Number of iterations needed to find the root.
converged - True if the routine converged
"""

if tol <= 0:
raise ValueError("tol is too small <= 0")
if maxiter < 1:
raise ValueError("maxiter must be greater than 0")

# Convert to float (don't use float(x0); this works also for complex x0)
p0 = 1.0 * x0
funcalls = 0
status = _ECONVERR

# Newton-Raphson method
for itr in range(maxiter):
# first evaluate fval
fval = func(p0, *args)
funcalls += 1
# If fval is 0, a root has been found, then terminate
if fval == 0:
status = _ECONVERGED
p = p0
itr -= 1
break
fder = fprime(p0, *args)
funcalls += 1
# derivative is zero, not converged
if fder == 0:
p = p0
break
newton_step = fval / fder
# Newton step
p = p0 - newton_step
if abs(p - p0) < tol:
status = _ECONVERGED
break
p0 = p

if disp and status == _ECONVERR:
msg = "Failed to converge"
raise RuntimeError(msg)

return _results((p, funcalls, itr + 1, status))

@njit
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
maxiter=50, disp=True):
"""
Find a zero from Halley's method using the jitted version of
Scipy's.
`func`, `fprime`, `fprime2` must be jitted via Numba.
Parameters
----------
func : callable and jitted
The function whose zero is wanted. It must be a function of a
single variable of the form f(x,a,b,c...), where a,b,c... are extra
arguments that can be passed in the `args` parameter.
x0 : float
An initial estimate of the zero that should be somewhere near the
actual zero.
fprime : callable and jitted
The derivative of the function (when available and convenient).
fprime2 : callable and jitted
The second order derivative of the function
args : tuple, optional
Extra arguments to be used in the function call.
tol : float, optional
The allowable error of the zero value.
maxiter : int, optional
Maximum number of iterations.
disp : bool, optional
If True, raise a RuntimeError if the algorithm didn't converge
Returns
-------
results : namedtuple
root - Estimated location where function is zero.
function_calls - Number of times the function was called.
iterations - Number of iterations needed to find the root.
converged - True if the routine converged
"""

if tol <= 0:
raise ValueError("tol is too small <= 0")
if maxiter < 1:
raise ValueError("maxiter must be greater than 0")

# Convert to float (don't use float(x0); this works also for complex x0)
p0 = 1.0 * x0
funcalls = 0
status = _ECONVERR

# Halley Method
for itr in range(maxiter):
# first evaluate fval
fval = func(p0, *args)
funcalls += 1
# If fval is 0, a root has been found, then terminate
if fval == 0:
status = _ECONVERGED
p = p0
itr -= 1
break
fder = fprime(p0, *args)
funcalls += 1
# derivative is zero, not converged
if fder == 0:
p = p0
break
newton_step = fval / fder
# Halley's variant
fder2 = fprime2(p0, *args)
p = p0 - newton_step / (1.0 - 0.5 * newton_step * fder2 / fder)
if abs(p - p0) < tol:
status = _ECONVERGED
break
p0 = p

if disp and status == _ECONVERR:
msg = "Failed to converge"
raise RuntimeError(msg)

return _results((p, funcalls, itr + 1, status))

@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
disp=True):
"""
Find a zero from the secant method using the jitted version of
Scipy's secant method.
Note that `func` must be jitted via Numba.
Parameters
----------
func : callable and jitted
The function whose zero is wanted. It must be a function of a
single variable of the form f(x,a,b,c...), where a,b,c... are extra
arguments that can be passed in the `args` parameter.
x0 : float
An initial estimate of the zero that should be somewhere near the
actual zero.
args : tuple, optional
Extra arguments to be used in the function call.
tol : float, optional
The allowable error of the zero value.
maxiter : int, optional
Maximum number of iterations.
disp : bool, optional
If True, raise a RuntimeError if the algorithm didn't converge.
Returns
-------
results : namedtuple
root - Estimated location where function is zero.
function_calls - Number of times the function was called.
iterations - Number of iterations needed to find the root.
converged - True if the routine converged
"""

if tol <= 0:
raise ValueError("tol is too small <= 0")
if maxiter < 1:
raise ValueError("maxiter must be greater than 0")

# Convert to float (don't use float(x0); this works also for complex x0)
p0 = 1.0 * x0
funcalls = 0
status = _ECONVERR

# Secant method
if x0 >= 0:
p1 = x0 * (1 + 1e-4) + 1e-4
else:
p1 = x0 * (1 + 1e-4) - 1e-4
q0 = func(p0, *args)
funcalls += 1
q1 = func(p1, *args)
funcalls += 1
for itr in range(maxiter):
if q1 == q0:
p = (p1 + p0) / 2.0
status = _ECONVERGED
break
else:
p = p1 - q1 * (p1 - p0) / (q1 - q0)
if np.abs(p - p1) < tol:
status = _ECONVERGED
break
p0 = p1
q0 = q1
p1 = p
q1 = func(p1, *args)
funcalls += 1

if disp and status == _ECONVERR:
msg = "Failed to converge"
raise RuntimeError(msg)

return _results((p, funcalls, itr + 1, status))
123 changes: 123 additions & 0 deletions quantecon/optimize/tests/test_root_finding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import numpy as np
from numpy.testing import assert_almost_equal, assert_allclose
from numba import njit

from quantecon.optimize import newton, newton_halley, newton_secant

@njit
def func(x):
"""
Function for testing on.
"""
return (x**3 - 1)


@njit
def func_prime(x):
"""
Derivative for func.
"""
return (3*x**2)

@njit
def func_prime2(x):
"""
Second order derivative for func.
"""
return 6*x

@njit
def func_two(x):
"""
Harder function for testing on.
"""
return np.sin(4 * (x - 1/4)) + x + x**20 - 1


@njit
def func_two_prime(x):
"""
Derivative for func_two.
"""
return 4*np.cos(4*(x - 1/4)) + 20*x**19 + 1

@njit
def func_two_prime2(x):
"""
Second order derivative for func_two
"""
return 380*x**18 - 16*np.sin(4*(x - 1/4))


def test_newton_basic():
"""
Uses the function f defined above to test the scalar maximization
routine.
"""
true_fval = 1.0
fval = newton(func, 5, func_prime)
assert_almost_equal(true_fval, fval.root, decimal=4)


def test_newton_basic_two():
"""
Uses the function f defined above to test the scalar maximization
routine.
"""
true_fval = 1.0
fval = newton(func, 5, func_prime)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0)


def test_newton_hard():
"""
Harder test for convergence.
"""
true_fval = 0.408
fval = newton(func_two, 0.4, func_two_prime)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)

def test_halley_basic():
"""
Basic test for halley method
"""
true_fval = 1.0
fval = newton_halley(func, 5, func_prime, func_prime2)
assert_almost_equal(true_fval, fval.root, decimal=4)

def test_halley_hard():
"""
Harder test for halley method
"""
true_fval = 0.408
fval = newton_halley(func_two, 0.4, func_two_prime, func_two_prime2)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)

def test_secant_basic():
"""
Basic test for secant option.
"""
true_fval = 1.0
fval = newton_secant(func, 5)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.001)


def test_secant_hard():
"""
Harder test for convergence for secant function.
"""
true_fval = 0.408
fval = newton_secant(func_two, 0.4)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)


# executing testcases.

if __name__ == '__main__':
import sys
import nose

argv = sys.argv[:]
argv.append('--verbose')
argv.append('--nocapture')
nose.main(argv=argv, defaultTest=__file__)

0 comments on commit 1155267

Please sign in to comment.