#   Python seems to have numerical problems when dealing with large numbers and highly ill-conditioned linear systems.
#   The following code is an example.
#   Given $x$, define a function $f(B) = |x^TBx|^2 \sum_{i=1}^2\lambda_i^3$.
#   We use jax to compute the Hessian of $f$ w.r.t. $B$.
#   In the implementation, we view $B$ as a vector.
#   We try different implementations of $f$ and test the accuracy of jax.hessian.

In [5]:
import jax.numpy as np
import numpy as onp
from jax import hessian, vmap, grad, jacrev
import scipy
from jax.config import config
config.update("jax_enable_x64", True)

# Fixed vector used in the quadratic form x^T B x.
x = np.array([-2.0, 4.89132319])
# 2x2 zero matrix.
B = np.zeros((2, 2))
# Weights lambda_i; the functions below raise them to the power 3/2.
lbds = np.array([1.43588479e06, 5.27642883e03])

def f0(params):
    """Evaluate f with the quadratic form x^T B x written out term by term.

    ``params`` is read at indices 0..3; the two off-diagonal entries
    (indices 1 and 2) share the ``x[0] * x[1]`` coefficient. The fixed
    vector ``x`` and the weights ``lbds`` come from module scope.
    """
    p00, p01, p10, p11 = params[0], params[1], params[2], params[3]
    quad = x[0] ** 2 * p00 + x[0] * x[1] * (p01 + p10) + x[1] ** 2 * p11
    weights = lbds ** (3 / 2)
    return np.dot(weights, weights) * quad ** 2
def f1(params):
    """Evaluate f as the squared norm of the weighted quadratic form.

    ``params`` is reshaped to a 2x2 matrix; the scalar x^T B x scales
    ``lbds ** (3/2)`` elementwise and the resulting vector is dotted with
    itself. ``x`` and ``lbds`` come from module scope.
    """
    mat = np.reshape(params, (2, 2))
    quad = x.dot(mat.dot(x))
    weighted = (lbds ** (3 / 2)) * quad
    return np.dot(weighted, weighted)

def f2(params):
    """Evaluate f as (sum of cubed weights) times (x^T B x) squared.

    ``params`` is reshaped to a 2x2 matrix and the quadratic form is taken
    with ``np.dot``. ``x`` and ``lbds`` come from module scope.
    """
    weights = lbds ** (3 / 2)
    rho = np.dot(weights, weights)
    mat = np.reshape(params, (2, 2))
    quad = np.dot(x, np.dot(mat, x))
    return rho * quad ** 2

def mtx_mul_vec(mtx, vec):
    """Multiply a row-major flattened square matrix by a vector.

    Args:
        mtx: Length ``n * n`` vector holding an ``n x n`` matrix row by row.
        vec: Length ``n`` vector.

    Returns:
        Length ``n`` vector equal to ``reshape(mtx, (n, n)) @ vec``.
    """
    n = len(vec)
    # kron(I_n, vec) is the (n*n, n) block-diagonal matrix whose k-th diagonal
    # block is the column `vec`, so `mtx @ vec_ext` contracts each row of the
    # flattened matrix with `vec`.  It is built from the `vec` argument (not
    # the module-level `x`) so the function honours its parameter.
    vec_ext = np.kron(np.eye(n), np.reshape(vec, (-1, 1)))
    return np.dot(mtx, vec_ext)

def vec_mtx_vec(mtx, vec):
    """Evaluate the quadratic form vec^T M vec for a flattened square matrix.

    Args:
        mtx: Length ``n * n`` vector holding an ``n x n`` matrix M row by row.
        vec: Length ``n`` vector.

    Returns:
        The scalar ``vec @ reshape(mtx, (n, n)) @ vec``.
    """
    n = len(vec)
    # kron(I_n, vec) is (n*n, n); `vec @ vec_ext.T` is therefore the length
    # n*n vector of products vec[i] * vec[j], which is dotted with the
    # flattened matrix.  Both factors come from the `vec` argument (not the
    # module-level `x`) so the function honours its parameter.
    vec_ext = np.kron(np.eye(n), np.reshape(vec, (-1, 1)))
    return np.dot(np.dot(vec, vec_ext.T), mtx)

def f3(params):
    """Evaluate f using ``mtx_mul_vec`` for the inner matrix-vector product.

    The flattened ``params`` is passed to ``mtx_mul_vec`` together with the
    module-level ``x``; the result is dotted with ``x``, squared, and scaled
    by the dot product of ``lbds ** (3/2)`` with itself.
    """
    weights = lbds ** (3 / 2)
    mat_times_x = mtx_mul_vec(params, x)
    return np.dot(weights, weights) * np.dot(x, mat_times_x) ** 2

def f4(params):
    """Evaluate f using ``vec_mtx_vec`` for the whole quadratic form.

    The flattened ``params`` and the module-level ``x`` are passed to
    ``vec_mtx_vec``; its result is squared and scaled by the dot product of
    ``lbds ** (3/2)`` with itself.
    """
    weights = lbds ** (3 / 2)
    quad = vec_mtx_vec(params, x)
    return np.dot(weights, weights) * quad ** 2

def Hessian():
    """Return the closed-form Hessian of f with respect to the flattened B.

    Built from the module-level ``x`` and ``lbds``.  The result is a square
    matrix of side ``len(x) ** 2``.
    """
    weights = lbds ** (3 / 2)
    rho = np.dot(weights, weights)
    col = np.reshape(x, (-1, 1))
    lift = np.kron(np.eye(len(x)), col)
    # `*` and `@` share precedence and associate left to right, so the scalar
    # scales `lift` first and the matrix products are then applied in
    # sequence; this order is kept deliberately.
    scaled = 2 * rho * lift
    return scaled @ col @ col.T @ lift.T

In [6]:
# Random flattened 2x2 matrix at which all implementations are compared.
Bvec = np.array(onp.random.randn(4))

print("Examine the accuracy of the functions")
# Pairwise differences between implementations of f at the same point.
for lhs, rhs in ((f0, f1), (f1, f2), (f3, f4), (f4, f0)):
    print(lhs(Bvec) - rhs(Bvec))

# Closed-form reference first, then one autodiff Hessian per implementation.
H = Hessian()
H0, H1, H2, H3, H4 = (hessian(impl)(Bvec) for impl in (f0, f1, f2, f3, f4))
autodiff_hessians = (H0, H1, H2, H3, H4)

print("")
print("Examine the symmetry of the Hessian")
for Hk in autodiff_hessians:
    print(np.linalg.norm(Hk - Hk.T))

print("")
print("Examine the accuracy of the Hessian")
for Hk in autodiff_hessians:
    print(np.linalg.norm(Hk - H))

print("")
print("Check the eigenvalues of H")
# Eigenvalues of the closed-form Hessian, then of the autodiff Hessian of f4.
for mat in (H, H4):
    e, v = scipy.linalg.eigh(mat)
    print(e)

Examine the accuracy of the functions
0.0
0.0
-65536.0
65536.0

Examine the symmetry of the Hessian
0.0
532416.9808862223
532416.9808862223
532416.9808862223
0.0

Examine the accuracy of the Hessian
747225.366592971
645454.7448985095
645454.7448985095
645454.7448985095
747225.366592971

Check the eigenvalues of H
[-5.24288000e+05  6.55360000e+04  1.57286400e+06  4.61717664e+21]
[1.04857600e+05 4.19430400e+05 1.04857600e+06 4.61717664e+21]
