# Example of JAX's Automatic Differentiation for Computation of Gradient and Hessian Matrix

This document demonstrates how to use the JAX library to compute the gradient and Hessian matrix of a multivariable function and verify the correctness of the results. 

We use a simple quadratic objective function $f(x) = \|Ax - b\|^2$ as an example.

## Main Contents

1. **Objective Function Definition**: Define a multivariable quadratic function $f(x) = \|Ax - b\|^2$, where $A$ is a matrix, and $x$ and $b$ are vectors.
2. **Random Data Generation**: Randomly generate the matrix $A$ and the vectors $b$ and $x$, with dimensions that the user can set.
3. **Gradient Computation**: Use JAX's `grad` function to compute the gradient of the objective function.
4. **Hessian Matrix Computation**: Use JAX's `hessian` function to compute the Hessian matrix of the objective function.
5. **Result Verification**: Calculate the analytical form of the gradient and Hessian matrix and compare them with the results computed by JAX to verify their consistency.
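
The analytical forms used in step 5 can be derived by expanding the squared norm. The short derivation below is a sketch that uses only the symbols $A$, $x$, and $b$ defined above:

$$
f(x) = (Ax - b)^{\top}(Ax - b) = x^{\top} A^{\top} A x - 2 b^{\top} A x + b^{\top} b
$$

$$
\nabla f(x) = 2 A^{\top} A x - 2 A^{\top} b = 2 A^{\top}(Ax - b), \qquad \nabla^2 f(x) = 2 A^{\top} A
$$

The Hessian does not depend on $x$, so the comparison with JAX's result holds at any test point.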

The following code implementation and detailed comments will help readers understand each step's operations and principles.

In [1]:
import jax.numpy as jnp
from jax import grad, hessian
import numpy as np

# Define the dimensions of the matrix and vectors
m, n = 4, 3  # For example, A is a 4x3 matrix, b is a vector of length 4, and x is a vector of length 3

# Randomly generate the matrix A and vector b
A = jnp.array(np.random.randn(m, n))
b = jnp.array(np.random.randn(m))

# Randomly generate the test point x
x = jnp.array(np.random.randn(n))

# Define the objective function f(x) = ||Ax - b||^2
def objective_function(x):
    return jnp.sum((jnp.dot(A, x) - b) ** 2)

# Compute the gradient of the objective function using JAX's grad function
grad_function = grad(objective_function)

# Compute the Hessian matrix of the objective function using JAX's hessian function
hess_function = hessian(objective_function)

# Compute the analytical gradient
# The analytical gradient formula is 2 * A^T * (A * x - b)
analytical_grad = 2 * jnp.dot(A.T, jnp.dot(A, x) - b)

# Compute the gradient using JAX
jax_grad = grad_function(x)

# Compute the analytical Hessian matrix
# The analytical Hessian matrix formula is 2 * A^T * A
analytical_hess = 2 * jnp.dot(A.T, A)

# Compute the Hessian matrix using JAX
jax_hess = hess_function(x)

# Print the analytical gradient and the gradient computed by JAX
print(f"Analytical Gradient: {analytical_grad}")
print(f"JAX Gradient: {jax_grad}")

# Print the analytical Hessian and the Hessian computed by JAX
print(f"Analytical Hessian: \n{analytical_hess}")
print(f"JAX Hessian: \n{jax_hess}")

# Verify if the results are consistent
print('='*50)
assert jnp.allclose(analytical_grad, jax_grad, rtol=1e-05, atol=1e-06), "The gradients do not match!"
assert jnp.allclose(analytical_hess, jax_hess, rtol=1e-05, atol=1e-06), "The Hessians do not match!"
print("The gradients and Hessians match!")
print('='*50)


Analytical Gradient: [8.0490675 5.0572586 8.928235 ]
JAX Gradient: [8.0490675 5.0572586 8.928235 ]
Analytical Hessian: 
[[4.31184    0.53901273 1.7090865 ]
 [0.53901273 7.4374714  3.4608386 ]
 [1.7090865  3.4608386  5.548139  ]]
JAX Hessian: 
[[4.31184    0.53901273 1.7090865 ]
 [0.53901273 7.4374714  3.4608386 ]
 [1.7090865  3.4608386  5.548139  ]]
The gradients and Hessians match!
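
For readers curious about what `hessian` does, the sketch below compares it with the composition `jacfwd(jacrev(f))`, which is how the JAX documentation describes it. The seeded generator `rng` and the names `hess_direct` and `hess_composed` are illustrative choices, not part of the cell above.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Illustrative data with the same shapes as the example above (seed is an assumption)
rng = np.random.default_rng(0)
A = jnp.array(rng.standard_normal((4, 3)))
b = jnp.array(rng.standard_normal(4))
x = jnp.array(rng.standard_normal(3))

def objective_function(x):
    # f(x) = ||Ax - b||^2
    return jnp.sum((A @ x - b) ** 2)

# Hessian via the convenience wrapper
hess_direct = jax.hessian(objective_function)(x)

# Hessian via forward-mode differentiation of the reverse-mode Jacobian
hess_composed = jax.jacfwd(jax.jacrev(objective_function))(x)

print(hess_direct.shape)
print(bool(jnp.allclose(hess_direct, hess_composed)))
```

Both routes return an `n x n` matrix; for a scalar-valued function the reverse-mode Jacobian is the gradient, so the outer `jacfwd` differentiates the gradient once more.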


## Results Do Not Match?

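If an assertion fails, the most likely cause is floating-point rounding. JAX uses 32-bit floats by default, and the analytical formula and automatic differentiation evaluate the same quantity through different sequences of operations, so the two results can differ in their last few digits.

Increasing the tolerance parameters `rtol` and `atol` in `allclose` resolves this. The rounding errors tend to grow with the scale of the variables, but in exact arithmetic the two results are identical.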
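
An alternative to loosening the tolerances is to run the computation in 64-bit precision. The sketch below assumes the `jax_enable_x64` configuration flag is set before any arrays are created; the seeded generator and the name `max_abs_diff` are illustrative.

```python
import jax

# Assumption: this flag is set at startup, before any JAX arrays exist
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
A = jnp.array(rng.standard_normal((4, 3)))
b = jnp.array(rng.standard_normal(4))
x = jnp.array(rng.standard_normal(3))

def objective_function(x):
    # f(x) = ||Ax - b||^2
    return jnp.sum((A @ x - b) ** 2)

jax_grad = jax.grad(objective_function)(x)
analytical_grad = 2 * A.T @ (A @ x - b)

# Largest absolute difference between the two gradients
max_abs_diff = float(jnp.max(jnp.abs(jax_grad - analytical_grad)))
print(A.dtype, max_abs_diff)
```

With 64-bit floats the discrepancy is typically many orders of magnitude smaller than in the default 32-bit mode, so much tighter tolerances can be used in `allclose`.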

## Summary

The key point in using JAX is to write the objective function with `jax.numpy` instead of `numpy`, so that JAX can trace the computation. Beyond that, `grad` requires the function to return a scalar, which both objectives in this document do.

## Another Example

Consider a more complex objective function, $f(X) = \operatorname{tr}\left(A \cdot X^{\top} \cdot B \cdot X \cdot C\right)$, where $A$ is a symmetric matrix. According to [Matrix Calculus](https://www.matrixcalculus.org/), the analytical expression for the gradient is:

$$
\frac{\partial}{\partial X}\left(\operatorname{tr}\left(A \cdot X^{\top} \cdot B \cdot X \cdot C\right)\right)=B \cdot X \cdot C \cdot A+B^{\top} \cdot X \cdot A \cdot C^{\top}
$$
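
One way to see where this expression comes from, and where the symmetry of $A$ enters, is to take the differential of the trace and use its cyclic and transpose invariance. This is a sketch of the standard argument, not something taken from the Matrix Calculus site:

$$
\begin{aligned}
d\,\operatorname{tr}\left(A X^{\top} B X C\right)
&= \operatorname{tr}\left(A \, dX^{\top} B X C\right) + \operatorname{tr}\left(A X^{\top} B \, dX \, C\right) \\
&= \operatorname{tr}\left(dX^{\top} \, B X C A\right) + \operatorname{tr}\left(dX^{\top} \, B^{\top} X A^{\top} C^{\top}\right)
\end{aligned}
$$

Reading off the coefficient of $dX^{\top}$ gives the gradient $B X C A + B^{\top} X A^{\top} C^{\top}$ for a general $A$, which reduces to the expression above when $A^{\top} = A$.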

In [2]:
import jax.numpy as jnp
from jax import grad
import numpy as np

# Define the dimensions of the matrices
n = 5

# Randomly generate matrices A, B, C, and ensure A is symmetric
# np.random.seed(0)  # Uncomment to fix the random seed for reproducible results
A = np.random.randn(n, n)
A = (A + A.T) / 2  # Ensure A is symmetric
A = jnp.array(A)
B = jnp.array(np.random.randn(n, n))
C = jnp.array(np.random.randn(n, n))

# Randomly generate the test point X
X = jnp.array(np.random.randn(n, n))

# Define the objective function f(X) = tr(A * X^T * B * X * C)
def objective_function(X):
    return jnp.trace(jnp.dot(A, jnp.dot(X.T, jnp.dot(B, jnp.dot(X, C)))))

# Compute the gradient of the objective function using JAX's grad function
grad_function = grad(objective_function)

# Compute the analytical gradient
# The analytical gradient formula is B * X * C * A + B.T * X * A * C.T
analytical_grad = jnp.dot(B, jnp.dot(X, jnp.dot(C, A))) + jnp.dot(B.T, jnp.dot(X, jnp.dot(A, C.T)))

# Compute the gradient using JAX
jax_grad = grad_function(X)

# Print the analytical gradient and the gradient computed by JAX
print(f"Analytical Gradient: \n{analytical_grad}")
print(f"JAX Gradient: \n{jax_grad}")

# Verify if the results are consistent
print('='*50)
assert jnp.allclose(analytical_grad, jax_grad, rtol=1e-05, atol=1e-05), "The gradients do not match!"
print("The gradients match!")
print('='*50)


Analytical Gradient: 
[[  3.7105083 -19.493462    8.29455   -15.67075    -1.8323956]
 [ -5.695397   13.323085   -3.8245463   3.06843     8.027785 ]
 [  3.8335967  -8.7685      7.812145  -22.5878     -9.239664 ]
 [  6.060854  -17.89624     1.8304543  15.615094   -5.914414 ]
 [ -6.886673   26.397444   -7.059568  -13.699003    5.1156425]]
JAX Gradient: 
[[  3.7105088 -19.493464    8.294551  -15.670748   -1.8323936]
 [ -5.695398   13.323086   -3.8245468   3.0684295   8.027785 ]
 [  3.8335967  -8.768501    7.8121448 -22.587801   -9.239666 ]
 [  6.060856  -17.89624     1.8304546  15.615095   -5.9144144]
 [ -6.886676   26.397442   -7.0595675 -13.699005    5.1156445]]
The gradients match!
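
As an additional check that does not rely on the analytical formula, the JAX gradient can be compared with a central finite difference along a random direction. The sketch below is one possible way to do this; the direction `D`, the step size `eps`, the seeded generator, and the use of 64-bit mode are all assumptions made for illustration.

```python
import jax

# Assumption: 64-bit mode keeps finite-difference rounding error small
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
n = 5
A = rng.standard_normal((n, n))
A = jnp.array((A + A.T) / 2)  # Ensure A is symmetric
B = jnp.array(rng.standard_normal((n, n)))
C = jnp.array(rng.standard_normal((n, n)))
X = jnp.array(rng.standard_normal((n, n)))
D = jnp.array(rng.standard_normal((n, n)))  # Random perturbation direction

def objective_function(X):
    # f(X) = tr(A X^T B X C)
    return jnp.trace(A @ X.T @ B @ X @ C)

# Directional derivative from automatic differentiation: <grad f(X), D>
ad_directional = float(jnp.sum(jax.grad(objective_function)(X) * D))

# Central finite-difference estimate of the same directional derivative
eps = 1e-4
fd_directional = float(
    (objective_function(X + eps * D) - objective_function(X - eps * D)) / (2 * eps)
)

print(ad_directional, fd_directional)
```

Because the objective is quadratic in $X$, the central difference has no truncation error, so any remaining gap between the two numbers comes from rounding alone.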
