# Quasi-Newton methods: BFGS

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import scipy.optimize
import scipy as sp
import matplotlib.pyplot as plt

# Enable double precision (float64) in JAX. `jax.config` is the documented entry
# point; the old submodule import `from jax.config import config` is deprecated
# and removed in newer JAX releases.
config = jax.config  # keep the `config` name bound for any later cell that uses it
config.update("jax_enable_x64", True)

Consider the [Rosenbrock function](https://en.wikipedia.org/wiki/Rosenbrock_function), which attains its global minimum $\mathcal{L} = 0$ at $\mathbf{x} = (1,1,\dots,1)^T$:

$$\mathcal{L}(\mathbf{x}) = \sum_{i=1}^{N-1} [100 (x_{i+1} - x_i^2 )^2 + (1-x_i)^2]$$

In [2]:
def loss(x):
    """Rosenbrock function L(x) = sum_{i=1}^{N-1} [100 (x_{i+1} - x_i^2)^2 + (1 - x_i)^2].

    Parameters
    ----------
    x : array_like
        Point(s) at which to evaluate the function. The sum runs over the
        first axis, so a 1-D array of length N gives a scalar. A 2-D array of
        shape (N, B) gives one value per column.

    Returns
    -------
    JAX array with the function value(s). The value is 0 at x = (1, ..., 1).
    """
    # jnp.sum instead of the Python builtin `sum`. The builtin iterates element by
    # element, which under jax.jit / jax.grad unrolls into N-1 separate traced
    # additions. axis=0 keeps the builtin's "sum over the first axis" semantics.
    return jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2, axis=0)

Use `jax` to differentiate the Rosenbrock function, and JIT-compile both the function and its gradient.

In [3]:
# JIT-compile the loss and its gradient. jax.grad differentiates with respect
# to the first positional argument (x) by default. Each wrapper is compiled on
# its first call for a given input shape/dtype and reused afterwards.
grad_jit = jax.jit(jax.grad(loss))
loss_jit = jax.jit(loss)

Implement the BFGS method (with line search) for the minimization of the Rosenbrock function.
Set a maximum of 1000 epochs (iterations) and a stopping tolerance of $10^{-8}$ on the Euclidean norm of the gradient. Use as initial guess a vector $\mathbf{x}$ whose entries are drawn uniformly at random from the interval $[0,2]$.

In [6]:
# BFGS with a strong-Wolfe line search on the N-dimensional Rosenbrock function.
N = 100            # problem dimension
max_epochs = 1000  # maximum number of BFGS iterations
tol = 1e-8         # stop when the Euclidean norm of the gradient drops below this
log_every = 1      # print progress every `log_every` epochs (raise it to shorten the output)

np.random.seed(0)
x = np.random.uniform(0, 2, N)   # initial guess with entries drawn uniformly from [0, 2)
dx = np.asarray(grad_jit(x))
I = np.eye(N)
B_inv = I.copy()                 # inverse-Hessian approximation, started at the identity
history = [float(loss_jit(x))]

epoch = 0
while epoch < max_epochs and np.linalg.norm(dx) > tol:
    epoch += 1
    # Search direction: quasi-Newton step -H_k g_k.
    direction = -(B_inv @ dx)
    # Line search: scipy returns a step satisfying the strong Wolfe conditions,
    # or None if it did not converge. gfk=dx reuses the gradient we already have.
    lr = sp.optimize.line_search(loss_jit, grad_jit, x, direction, gfk=dx, maxiter=1000)[0]
    if lr is None:
        # The current inverse-Hessian approximation produced a poor direction:
        # restart from steepest descent before giving up.
        B_inv = I.copy()
        direction = -dx
        lr = sp.optimize.line_search(loss_jit, grad_jit, x, direction, gfk=dx, maxiter=1000)[0]
        if lr is None:
            print(f"epoch {epoch}: line search failed, stopping")
            break
    x_new = x + lr * direction
    dx_new = np.asarray(grad_jit(x_new))
    # BFGS update of the inverse Hessian (the rank-two formula that follows from
    # Sherman-Morrison-Woodbury):
    #   H_{k+1} = (I - rho s y^T) H_k (I - rho y s^T) + rho s s^T,  rho = 1 / (y^T s)
    s = x_new - x
    y = dx_new - dx
    sy = np.inner(y, s)
    # The update keeps H positive definite only if y^T s > 0 (guaranteed by the
    # Wolfe conditions in exact arithmetic). Skip it otherwise to avoid dividing
    # by ~0 or producing an indefinite H.
    if sy > 1e-10 * np.linalg.norm(s) * np.linalg.norm(y):
        E = I - np.outer(y, s) / sy
        B_inv = E.T @ B_inv @ E + np.outer(s, s) / sy
    l = float(loss_jit(x_new))
    history.append(l)
    x = x_new
    dx = dx_new
    if epoch % log_every == 0:
        print(f"epoch {epoch}")
        print(f"loss: {l}")
        print(f"gradient norm: {np.linalg.norm(dx)}")

print(f"\n\nNorm of the difference between real solution and found solution: {np.linalg.norm(x - np.ones(x.shape))}")

epoch 1
loss: 6280.107001306367
gradient norm: 4530.527872576332
epoch 2
loss: 3546.6756039477614
gradient norm: 2669.8769399107296
epoch 3
loss: 3139.207524653895
gradient norm: 2314.7958658662646
epoch 4
loss: 2539.008813292769
gradient norm: 1821.56594629532
epoch 5
loss: 2202.563142547062
gradient norm: 1471.8908975532315
epoch 6
loss: 2010.1138228553934
gradient norm: 1367.0924362488195
epoch 7
loss: 1922.158432616825
gradient norm: 1258.1102578730665
epoch 8
loss: 1872.907222708664
gradient norm: 1252.859661396601
epoch 9
loss: 1833.9544431531815
gradient norm: 1201.1761340820083
epoch 10
loss: 1821.7746362871774
gradient norm: 1175.8313284017452
epoch 11
loss: 1801.7180391071145
gradient norm: 1153.8701641850214
epoch 12
loss: 1759.0696873448665
gradient norm: 1094.1213272292946
epoch 13
loss: 1720.2874311914754
gradient norm: 1058.6321362372305
epoch 14
loss: 1679.5276586090877
gradient norm: 1092.0510684027167
epoch 15
loss: 1625.7924188156742
gradient norm: 1048.6642791424238