# Newton method

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time
import jax.numpy as jnp
import jax

# Enable double precision in JAX (single precision is the default).
# The flag is set through `jax.config` directly: the older
# `from jax.config import config` form was deprecated and later removed
# from JAX, so it fails on a fresh kernel with a current install.
jax.config.update("jax_enable_x64", True)

We consider a random matrix $A \in \mathbb{R}^{n\times n}$, with $n = 100$ and a random vector $\mathbf{x}_{\text{ex}} \in \mathbb{R}^n$.
We then define $\mathbf{b} = A \, \mathbf{x}_{\text{ex}}$.

In [16]:
# Problem size: A is n-by-n, the vectors have length n.
n = 100

# Fixed seed so that the same linear system is generated on every run.
np.random.seed(0)
A = np.random.standard_normal((n, n))
x_ex = np.random.standard_normal(n)

# Right-hand side built from the exact solution, so that A x_ex = b.
b = A @ x_ex

Define the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_2^2
$$

In [17]:
def loss(x):
  """Return the squared Euclidean norm of the residual A x - b."""
  residual = A @ x - b
  return jnp.square(residual).sum()

By using the `jax` library, implement and compile functions returning the gradient ($\nabla \mathcal{L}(\mathbf{x})$) and the Hessian ($\nabla^2 \mathcal{L}(\mathbf{x})$) of the loss function (*Hint*: use the `jacrev` or the `jacfwd` function).

In [18]:
# Gradient of the scalar loss.
grad = jax.grad(loss)
# Hessian as the forward-mode Jacobian of the reverse-mode Jacobian.
hess = jax.jacfwd(jax.jacrev(loss))

# Compiled versions of the loss and of its first and second derivatives.
loss_jit, grad_jit, hess_jit = (jax.jit(fn) for fn in (loss, grad, hess))

Check that the results are correct (up to machine precision).

In [19]:
# Random test point.
# NOTE(review): re-seeding with 0 makes this draw repeat the first n samples
# that were used to fill A, so x_guess appears to coincide with the first row
# of A -- confirm whether an independent seed was intended.
np.random.seed(0)
x_guess = np.random.randn(n)

# Gradient: automatic differentiation vs. the closed form 2 A^T (A x - b).
residual_guess = A @ x_guess - b
G_ad = grad_jit(x_guess)
G_ex = 2 * A.T @ residual_guess
print(np.linalg.norm(G_ad - G_ex))

# Hessian: automatic differentiation vs. the closed form 2 A^T A.
H_ad, H_ex = hess_jit(x_guess), 2 * A.T @ A
print(np.linalg.norm(H_ad - H_ex))

2.1758506481219884e-12
5.159031549839615e-13


Exploit the formula
$$
\nabla^2 \mathcal{L}(\mathbf{x}) \mathbf{v} = \nabla_{\mathbf{x}} \phi(\mathbf{x}, \mathbf{v})
$$
where
$$
\phi(\mathbf{x}, \mathbf{v}) := \nabla \mathcal{L}(\mathbf{x}) \cdot \mathbf{v}
$$
to write an optimized function returning the Hessian-vector product
$$
(\mathbf{x}, \mathbf{v}) \mapsto \nabla^2 \mathcal{L}(\mathbf{x}) \mathbf{v}.
$$
Compare the computational performance w.r.t. the full Hessian computation.

In [20]:
# Random direction along which the Hessian is applied.
np.random.seed(1)
v = np.random.randn(n)


def hvp_basic(x, v):
  """Hessian-vector product obtained by assembling the full Hessian first."""
  return hess(x) @ v


def _grad_dot_v(x, v):
  """Scalar phi(x, v): the gradient of the loss at x dotted with v."""
  return jnp.dot(grad(x), v)


# Differentiating phi with respect to x yields the Hessian-vector product
# without building the Hessian matrix.
hvp = jax.grad(_grad_dot_v, argnums=0)

hvp_basic_jit = jax.jit(hvp_basic)
hvp_jit = jax.jit(hvp)

# Compare with the closed-form Hessian applied to v.
Hv_ad = hvp_jit(x_guess, v)
Hv_ex = H_ex @ v
print(np.linalg.norm(Hv_ad - Hv_ex))

1.2669597472924609e-12


In [21]:
%timeit hvp_basic_jit(x_guess, v)
%timeit hvp_jit(x_guess, v)

258 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
5.29 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


Implement the Newton method for the minimization of the loss function $\mathcal{L}$. Set a maximum number of 100 iterations and a tolerance on the increment norm of $\epsilon = 10^{-8}$.

In [23]:
# Stopping criteria: iteration cap and tolerance on the Newton step norm.
num_epochs = 100
tolerance = 1e-8

# Work on a copy so that the in-place updates leave x_guess untouched.
x = x_guess.copy()
for k in range(num_epochs):
  # Loss, gradient and Hessian at the current iterate (before the update).
  l = loss_jit(x)
  G = grad_jit(x)
  H = hess_jit(x)

  # Newton step: solve H incr = -G, then move the iterate.
  incr = np.linalg.solve(H, -G)
  incr_norm = np.linalg.norm(incr)
  x += incr

  print(f'=== epoch {k:d}')
  print('loss: {:1.3e}'.format(float(l)))
  print('incr: {:1.3e}'.format(incr_norm))

  # Stop as soon as the step is smaller than the tolerance.
  if incr_norm < tolerance:
    break

=== epoch 0
loss: 3.739e+04
incr: 1.548e+01
=== epoch 1
loss: 6.545e-22
incr: 2.677e-09


Repeat the optimization loop for the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_4^4
$$

In [24]:
def loss(x):
  """Return the sum of the fourth powers of the residual A x - b."""
  residual = A @ x - b
  return jnp.sum(residual**4)
# Rebuild the derivatives for the loss defined in this cell: the gradient,
# and the Hessian as the forward-mode Jacobian of the reverse-mode Jacobian.
grad = jax.grad(loss)
hess = jax.jacfwd(jax.jacrev(loss))

# Compile again so that the *_jit names refer to the new loss.
loss_jit, grad_jit, hess_jit = (jax.jit(fn) for fn in (loss, grad, hess))

In [25]:
# Same stopping criteria as before: iteration cap and step-norm tolerance.
num_epochs = 100
tolerance = 1e-8

# Restart from the same initial guess, working on a copy.
x = x_guess.copy()
for k in range(num_epochs):
  # Quantities evaluated at the current iterate, before it is updated.
  l = loss_jit(x)
  G = grad_jit(x)
  H = hess_jit(x)

  # Newton step: solve the linear system H incr = -G and update x in place.
  incr = np.linalg.solve(H, -G)
  incr_norm = np.linalg.norm(incr)
  x += incr

  print('=== epoch {:d}'.format(k))
  print('loss: {:1.3e}'.format(float(l)))
  print('incr: {:1.3e}'.format(incr_norm))

  # Converged once the step norm falls below the tolerance.
  if incr_norm < tolerance:
    break

=== epoch 0
loss: 2.433e+08
incr: 5.158e+00
=== epoch 1
loss: 4.806e+07
incr: 3.439e+00
=== epoch 2
loss: 9.493e+06
incr: 2.293e+00
=== epoch 3
loss: 1.875e+06
incr: 1.528e+00
=== epoch 4
loss: 3.704e+05
incr: 1.019e+00
=== epoch 5
loss: 7.317e+04
incr: 6.793e-01
=== epoch 6
loss: 1.445e+04
incr: 4.529e-01
=== epoch 7
loss: 2.855e+03
incr: 3.019e-01
=== epoch 8
loss: 5.639e+02
incr: 2.013e-01
=== epoch 9
loss: 1.114e+02
incr: 1.342e-01
=== epoch 10
loss: 2.200e+01
incr: 8.945e-02
=== epoch 11
loss: 4.346e+00
incr: 5.964e-02
=== epoch 12
loss: 8.585e-01
incr: 3.976e-02
=== epoch 13
loss: 1.696e-01
incr: 2.651e-02
=== epoch 14
loss: 3.350e-02
incr: 1.767e-02
=== epoch 15
loss: 6.617e-03
incr: 1.178e-02
=== epoch 16
loss: 1.307e-03
incr: 7.853e-03
=== epoch 17
loss: 2.582e-04
incr: 5.236e-03
=== epoch 18
loss: 5.100e-05
incr: 3.490e-03
=== epoch 19
loss: 1.007e-05
incr: 2.327e-03
=== epoch 20
loss: 1.990e-06
incr: 1.551e-03
=== epoch 21
loss: 3.931e-07
incr: 1.034e-03
=== epoch 22
loss: 7