# Newton method

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time
import jax.numpy as jnp
import jax

# We enable double precision in JAX
from jax.config import config
config.update("jax_enable_x64", True)

We consider a random matrix $A \in \mathbb{R}^{n\times n}$, with $n = 100$ and a random vector $\mathbf{x}_{\text{ex}} \in \mathbb{R}^n$.
We define then $\mathbf{b} = A \, \mathbf{x}_{\text{ex}}$.

In [3]:
n = 100

np.random.seed(0)
A = np.random.randn(n,n)
x_ex = np.random.randn(n)
b = A @ x_ex

Define the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_2^2
$$

In [4]:
def loss(x):
    return jnp.sum(jnp.square(A @ x - b))

By using the `jax` library, implement and compile functins returning the gradient ($\nabla \mathcal{J}(\mathbf{x})$) and the hessian ($\nabla^2 \mathcal{J}(\mathbf{x})$) of the loss function (*Hint*: use the `jacrev` or the `jacfwd`) function.

In [5]:
grad = jax.grad(loss, argnums=0)
hess = jax.jacrev(jax.jacrev(loss))

loss_jit = jax.jit(loss)
grad_jit = jax.jit(grad)
hess_jit = jax.jit(hess)

Check that the results are correct (up to machine precision).

In [6]:
np.random.seed(0)
x_guess = np.random.randn(n)

G_ad = grad_jit(x_guess)
G_ex = 2 * A.T @ (A @ x_guess - b)
print(np.linalg.norm(G_ad - G_ex))

H_ad = hess_jit(x_guess)
H_ex = 2 * A.T @ A
print(np.linalg.norm(H_ad - H_ex))

2.191240070771156e-12
5.490834437635708e-13


Exploit the formula
$$
\nabla^2 \mathcal{J}(\mathbf{x}) \mathbf{v} = \nabla_{\mathbf{x}} \phi(\mathbf{x}, \mathbf{v})
$$
where 
$$
\phi(\mathbf{x}, \mathbf{v}) := \nabla \mathcal{J}(\mathbf{x}) \cdot \mathbf{v}
$$
to write an optimized function returning the hessian-vector-product
$$
(\mathbf{x}, \mathbf{v}) \mapsto \nabla^2 \mathcal{J}(\mathbf{x}) \mathbf{v}.
$$
Compare the computational performance w.r.t. the full hessian computation.

In [8]:
np.random.seed(1)
v = np.random.randn(n)

hvp_basic = lambda x, v: hess(x) @ v
phi = lambda x, v: grad(x) @ v
hvp = jax.grad(phi, argnums=0)

hvp_basic_jit = jax.jit(hvp_basic)
hvp_jit = jax.jit(hvp)

Hv_ad = hvp_jit(x_guess, v)
Hv_ex = H_ex @ v
print(np.linalg.norm(Hv_ad - Hv_ex))

1.255129127373075e-12


In [9]:
%timeit hvp_basic_jit(x_guess, v)
%timeit hvp_jit(x_guess, v)

177 µs ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
17.5 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Implement the Newton method for the minimization of the loss function $\mathcal{L}$. Set a maximim number of 100 iterations and a tolerance on the increment norm of $\epsilon = 10^{-8}$.

In [25]:
def loss(x):
    return jnp.sum(jnp.power(A @ x - b, 2))

grad = jax.grad(loss, argnums=0)
hess = jax.jacrev(jax.jacrev(loss))

loss_jit = jax.jit(loss)
grad_jit = jax.jit(grad)
hess_jit = jax.jit(hess)

epochs = range(100)
eps = 1e-7
w = np.random.normal(0, 1e6, x_ex.shape)
for epoch in epochs:
    H = hess_jit(w)
    g = grad_jit(w)
    if np.linalg.norm(g) <= eps:
        print(f"Early Stopping at epoch {epoch}")
        break
    w -= np.linalg.inv(H) @ g

np.linalg.norm(x_ex - w) # convergenza velocissima --> tipica dei metodi del secondo ordine

Early Stopping at epoch 2


5.811046573607706e-14

Repeat the optimization loop for the loss function

$$
\mathcal{L}(\mathbf{x}) = \| \mathbf{b} - A \, \mathbf{x} \|_4^4
$$

In [22]:
def loss(x):
    return jnp.sum(jnp.power(A @ x - b, 4))

grad = jax.grad(loss, argnums=0)
hess = jax.jacrev(jax.jacrev(loss))

loss_jit = jax.jit(loss)
grad_jit = jax.jit(grad)
hess_jit = jax.jit(hess)

epochs = range(100)
eps = 1e-7
w = np.random.normal(0, 1e5, x_ex.shape)
for epoch in epochs:
    H = hess_jit(w)
    g = grad_jit(w)
    if np.linalg.norm(g) <= eps:
        print(f"Early Stopping at epoch {epoch}")
        break
    w -= np.linalg.inv(H) @ g

np.linalg.norm(x_ex - w)

Early Stopping at epoch 54


0.0002935209989837215

In [29]:
def loss(x):
    return jnp.sum(jnp.absolute(A @ x - b)) # la norma 1 non va bene --> matrice hessiana non invertibile

grad = jax.grad(loss, argnums=0)
hess = jax.jacrev(jax.jacrev(loss))

loss_jit = jax.jit(loss)
grad_jit = jax.jit(grad)
hess_jit = jax.jit(hess)

epochs = range(100)
eps = 1e-7
w = np.random.normal(0, 1e5, x_ex.shape)
for epoch in epochs:
    H = hess_jit(w)
    g = grad_jit(w)
    if np.linalg.norm(g) <= eps:
        print(f"Early Stopping at epoch {epoch}")
        break
    w -= np.linalg.inv(H) @ g

np.linalg.norm(x_ex - w)

LinAlgError: Singular matrix