#### Newton's Method
Use Newton's method with Jax autodifferentiation to solve for the zeros of a nonlinear function

#### Preliminaries

In [1]:
# load some packages
import matplotlib.pyplot as plt
import jax.numpy as jnp
import warnings
import jax
jax.config.update('jax_enable_x64', True)  # jax uses 32 bit numbers for some reason

#### 1D Implementation
The OG Newton's method solves 1D problems.

We start with an initial guess $x_0$ and iterate such that
$$
x_n = x_{n-1} - \frac{f(x_{n-1})}{f'(x_{n-1})}
$$
until two successive guesses are within some tolerance of each other.

In [70]:
def newtons_method_1d(f, x0=0., tol=1e-16, maxiter=1000):
    '''
    Uses Newton's method to find zeros of a 1D function
    '''
    # get derivative function
    df = jax.grad(f)

    # initial conditions
    x = x_l = x0
    for _ in range(maxiter):
        # update rule
        x, x_l = x - f(x) / df(x), x

        # exit condition
        if jnp.abs(x - x_l) < tol:
            break
    else:
        warnings.warn('Max iteration count exceeded')
    
    return x

## test cases
# sin
f = jnp.sin
%timeit x = newtons_method_1d(f)
print(x, f(x))

# cos
f = jnp.cos
%timeit x = newtons_method_1d(f, 1.)  # needs to be 1 since df(0) = 0
print(x, f(x))

# exp
f = lambda x: jnp.exp(x) - 10
%timeit x = newtons_method_1d(f)
print(x, f(x))

# polynomial
f = lambda x: x**2 + 2*x + 1
%timeit x = newtons_method_1d(f)
print(x, f(x))

407 μs ± 3.54 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
-0.9999999925494194 -0.8414709807823306
1.94 ms ± 18.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
-0.9999999925494194 0.5403023121375871
8.27 ms ± 199 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
-0.9999999925494194 -9.632120556087642
32 ms ± 663 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
-0.9999999925494194 0.0


#### Higher Dimensional Implementation
We can also implement the same method in higher dimensions using the Jacobian $J_F$. The update rule then becomes
$$
\vec{x_n} = \vec{x_{n-1}} - J_F(\vec{x_{n-1}})^{-1} F(\vec{x_{n-1}})
$$

In [3]:
def newtons_method_inv(F, x0, tol=1e-10, maxiter=1000):
    '''
    Uses Newton's method to find zeros of a higher dimensional function
    '''
    # get derivative function
    jac = jax.jacfwd(F)

    # initial conditions
    X = X_l = x0  # x0 has to be defined so we know the sixe of input 
    for _ in range(maxiter):
        # update rule
        X, X_l = X - jnp.linalg.inv(jac(X)) @ F(X), X

        # exit condition
        if jnp.abs(X - X_l).max() < tol:
            break
    else:
        warnings.warn('Max iteration count exceeded')
    
    return X

Since $J_F(\vec{x_{n-1}})^{-1} F(\vec{x_{n-1}})$ is the solution to $J_F(\vec{x_{n-1}}) \vec{x} =  F(\vec{x_{n-1}})$ we can also write this using the `solve` or `lstsq` method.

In [4]:
def newtons_method_solve(F, X0, tol=1e-10, maxiter=1000):
    '''
    Uses Newton's method to find zeros of a higher dimensional function
    '''
    # get derivative function
    jac = jax.jacfwd(F)

    # initial conditions
    X = X_l = X0  # x0 has to be defined so we know the sixe of inputt
    for _ in range(maxiter):
        # update rule
        X, X_l = X - jnp.linalg.lstsq(jac(X), F(X))[0], X

        # exit condition
        if jnp.abs(X - X_l).max() < tol:
            break
    else:
        warnings.warn('Max iteration count exceeded')
    
    return X

Finally, both of these methods can be updated to use the `@jax.jit` tag to make them faster for larger systems.

In [5]:
def newtons_method_inv_jit(F, x0, tol=1e-10, maxiter=1000):
    '''
    Uses Newton's method to find zeros of a higher dimensional function
    '''
    # get derivative function
    jac = jax.jacfwd(F)

    # update rule
    @jax.jit
    def q(X):
        return X - jnp.linalg.inv(jac(X)) @ F(X)

    # initial conditions
    X = X_l = x0  # x0 has to be defined so we know the sixe of inputt
    for _ in range(maxiter):
        # update rule
        X, X_l = q(X), X

        # exit condition
        if jnp.abs(X - X_l).max() < tol:
            break
    else:
        warnings.warn('Max iteration count exceeded')
    
    return X


def newtons_method_solve_jit(F, x0, tol=1e-10, maxiter=1000):
    '''
    Uses Newton's method to find zeros of a higher dimensional function
    '''
    # get derivative function
    jac = jax.jacfwd(F)

    # update rule
    @jax.jit
    def q(X):
        return X - jnp.linalg.lstsq(jac(X), F(X))[0]

    # initial conditions
    X = X_l = x0  # x0 has to be defined so we know the sixe of inputt
    for _ in range(maxiter):
        # update rule
        X, X_l = q(X), X

        # exit condition
        if jnp.abs(X - X_l).max() < tol:
            break
    else:
        warnings.warn('Max iteration count exceeded')
    
    return X

In [6]:
def F(X):
    return jnp.array([
            5*X[0]**2 + X[0]*X[1]**2 * jnp.sin(2*X[1])**2 - 2,
            jnp.exp(X[0] - X[1]) + 4 * X[1] - 3
        ])
%timeit newtons_method_inv(F, jnp.ones(2))
%timeit newtons_method_solve(F, jnp.ones(2))
%timeit newtons_method_inv_jit(F, jnp.ones(2))
%timeit newtons_method_solve_jit(F, jnp.ones(2))

28.8 ms ± 914 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
28.5 ms ± 961 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
53.8 ms ± 184 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
38.4 ms ± 253 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
def F(X):
    return (jnp.outer(X, jnp.cos(X[::-1])+jnp.linspace(1, 2, 12))**(-1)).sum(axis=0) - jnp.prod(X)
%timeit newtons_method_inv(F, jnp.ones(12))
%timeit newtons_method_solve(F, jnp.ones(12))
%timeit newtons_method_inv_jit(F, jnp.ones(12))
%timeit newtons_method_solve_jit(F, jnp.ones(12))

12.7 ms ± 371 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
12.1 ms ± 184 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
71.4 ms ± 311 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.3 ms ± 698 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
def F(X):
    return jnp.exp(X * jnp.hstack((X[1:], X[0]))) - 10 * jnp.hstack((jnp.cos(X[:10]), X[10:]**2))
%timeit newtons_method_inv(F, jnp.ones(28))
%timeit newtons_method_solve(F, jnp.ones(28))
%timeit newtons_method_inv_jit(F, jnp.ones(28))
%timeit newtons_method_solve_jit(F, jnp.ones(28))

18.7 ms ± 162 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
19 ms ± 22.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
64.2 ms ± 285 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
44.2 ms ± 383 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [68]:
def F(X):
    resid_1 = 0.5 * X[:100]**2. - X[100:]**2.
    resid_2 = jnp.linspace(0, 20**(1/4), 100)**4 + 2 * X[:100] - X[100:]
    return jnp.hstack((resid_1, resid_2))
%timeit newtons_method_inv(F, jnp.ones(200))
%timeit newtons_method_solve(F, jnp.ones(200))
%timeit newtons_method_inv_jit(F, jnp.ones(200))
%timeit newtons_method_solve_jit(F, jnp.ones(200))

162 ms ± 9.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
765 ms ± 32.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
101 ms ± 946 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
657 ms ± 22.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
