# Newton's method with a bisection line search

In [None]:
import jax, jax.numpy as jnp
import numpy as np

In [None]:
def func(x):
    """Shifted quadratic bowl ``(x0 - 1)^2 + (x1 + 2)^2``.

    Its minimum value 0 is attained at ``(1, -2)``.
    """
    dx = x[0] - 1
    dy = x[1] + 2
    return dx ** 2 + dy ** 2

In [None]:
# JIT-compiled autodiff derivatives of ``func``.
# The gradient differentiates w.r.t. the first (and only) argument.
grad = jax.jit(jax.grad(func, argnums=0))
# The Hessian is formed forward-over-reverse, the same composition
# ``jax.hessian`` uses internally.
hess = jax.jit(jax.jacfwd(jax.jacrev(func)))

In [None]:
def bisection_line_search(f, df, a=0.0, b=1.0, tol=1e-6, max_iter=20):
    """Pick a step size in ``[a, b]`` that approximately minimises ``f``.

    Bisects on the derivative ``df`` to find a root of ``df``, i.e. a
    stationary point of ``f`` inside the interval. Bisection is only valid
    when ``df`` changes sign between ``a`` and ``b``. If it does not, no
    stationary point is bracketed; for a unimodal ``f`` the minimum on
    ``[a, b]`` is then at an endpoint, so the endpoint with the smaller
    ``f`` value is returned.

    Args:
        f: Scalar line objective ``alpha -> f(alpha)``.
        df: Derivative of ``f`` with respect to ``alpha``.
        a, b: Interval endpoints (``a < b`` assumed).
        tol: Stop as soon as ``|df(mid)| < tol``.
        max_iter: Maximum number of bisection steps.

    Returns:
        The chosen step size.
    """
    fa, fb = df(a), df(b)
    if fa == 0:
        return a
    if fb == 0:
        return b
    if fa * fb > 0:
        # No sign change, so there is nothing to bisect. Without this guard the
        # loop below always moves ``a`` toward ``b`` and returns ~b, even when
        # f increases over the whole interval (e.g. a non-descent direction).
        return a if f(a) <= f(b) else b
    for _ in range(max_iter):
        mid = 0.5 * (a + b)
        fm = df(mid)
        if abs(fm) < tol:
            return mid
        if fa * fm < 0:
            # Root lies in [a, mid].
            b = mid
        else:
            # Root lies in [mid, b]; track df at the new left end.
            a, fa = mid, fm
    return 0.5 * (a + b)

In [None]:
def newton(x0, max_iter=5, tol=1e-6):
    """Minimise ``func`` with Newton's method and a bisection line search.

    Uses the notebook-level ``func``, ``grad``, ``hess`` and
    ``bisection_line_search``. Progress is printed once per iteration.

    Args:
        x0: Starting point (converted to a float NumPy array).
        max_iter: Maximum number of Newton iterations.
        tol: Stop when the gradient norm drops below this value.

    Returns:
        The final iterate as a NumPy array.
    """
    x = np.array(x0, dtype=float)
    for i in range(max_iter):
        g = np.array(grad(x))
        g_norm = np.linalg.norm(g)
        print(f'iter {i}, x={x}, grad_norm={g_norm:.2e}')
        if g_norm < tol:
            break
        H = np.array(hess(x))
        try:
            step = np.linalg.solve(H, g)
        except np.linalg.LinAlgError:
            # Singular Hessian: no Newton step exists, use steepest descent.
            step = g
        if np.dot(g, step) <= 0:
            # We move along -step; it is only a descent direction when
            # g . step > 0. If the Hessian is not positive definite, fall back
            # to steepest descent so the line search can still make progress.
            step = g

        # Bind the current x/step explicitly (avoids late-binding surprises).
        def line_obj(a, x=x, step=step):
            return func(x - a * step)

        def line_grad(a, x=x, step=step):
            # Chain rule: d/da func(x - a*step) = -grad(x - a*step) . step.
            # Reuses the jitted ``grad`` rather than re-tracing jax.grad per call.
            return -float(np.dot(np.asarray(grad(x - a * step)), step))

        alpha = bisection_line_search(line_obj, line_grad)
        x = x - alpha * step
    return x

In [None]:
# Start from the origin; the analytic minimiser of ``func`` is (1, -2).
res = newton(np.zeros(2))
print('optimum:', res)