In [1]:
import numpy as onp
import jax
import jax.numpy as np
import jax.flatten_util
import scipy
import logging

from jax_fem import logger
from hessian.hessp import incremental_forward_and_adjoint
from hessian.utils import tree_l2_norm_error

# Let the jax_fem logger emit DEBUG-level records (the "[DEBUG] jax_fem: ..." lines in the cell outputs).
logger.setLevel(logging.DEBUG)

       __       ___      ___   ___                _______  _______ .___  ___. 
      |  |     /   \     \  \ /  /               |   ____||   ____||   \/   | 
      |  |    /  ^  \     \  V  /      ______    |  |__   |  |__   |  \  /  | 
.--.  |  |   /  /_\  \     >   <      |______|   |   __|  |   __|  |  |\/|  | 
|  `--'  |  /  _____  \   /  .  \                |  |     |  |____ |  |  |  | 
 \______/  /__/     \__\ /__/ \__\               |__|     |_______||__|  |__| 
                                                                              



# Problem statement

This tutorial solves the following problem:

Given a parameter vector $\boldsymbol{\theta}\in\mathbb{R}^2$, solve the residual equations $\boldsymbol{F}(\boldsymbol{u}, \boldsymbol{\theta})=\boldsymbol{0}$ to get the solution vector $\boldsymbol{u}\in\mathbb{R}^2$.

Therefore, $\boldsymbol{u}$ is implicitly a function of $\boldsymbol{\theta}$, i.e., $\boldsymbol{u}(\boldsymbol{\theta})$.

Given an objective function $J(\boldsymbol{u},\boldsymbol{\theta})$, the goal is to compute the Hessian-vector product $\frac{\textrm{d}^2 J}{\textrm{d}\boldsymbol{\theta}^2} \hat{\boldsymbol{\theta}}$ for any vector $\hat{\boldsymbol{\theta}}\in\mathbb{R}^2$.

We will define the specific forms of the residual function $\boldsymbol{F}$ and the objective function $J$ in the code below.
Three methods for computing the Hessian-vector product are presented and compared:

- **AD method**: Implicit differentiation approach with customized AD rules
- **FD method**: Finite difference approximate approach
- **AN method**: Analytical approach (available only for this particular simple problem)   


## Residual equations and adjoint method

The relationship between solution vector $\boldsymbol{u} = [u_0, u_1]$ and parameter vector $\boldsymbol{\theta} = [\theta_0, \theta_1]$ is defined by the residual equations:

$$
\boldsymbol{F}(\boldsymbol{u}, \boldsymbol{\theta}) = \mathbf{0} \quad \Rightarrow \quad
\begin{cases}
F_0(u, \theta) = \theta_0^2 u_0 + \theta_1 - 1 = 0 \\
F_1(u, \theta) = \theta_1^2 u_0^2 + \theta_1 u_1^2 - 2 = 0
\end{cases}
$$

- Newton's method is used to solve $\boldsymbol{F}(\boldsymbol{u}, \boldsymbol{\theta}) = \mathbf{0}$ for $\boldsymbol{u}$ with given $\boldsymbol{\theta}$.
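- The adjoint method is used to obtain the adjoint vector $\boldsymbol{\lambda}$ by solving $\left(\frac{\partial \boldsymbol{F}}{\partial \boldsymbol{u}}\right)^{\top} \boldsymbol{\lambda} = -\frac{\partial J}{\partial \boldsymbol{u}}$; $\boldsymbol{\lambda}$ is needed later for the gradient and the Hessian-vector product.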

In [2]:
def forward_step(θ, u_init, tol=1e-8, max_iter=1000):
    """Solve the residual equations F(u, θ) = 0 for u with Newton's method.

    Args:
        θ: Parameter vector, indexed as θ[0] and θ[1] by the residual function.
        u_init: Initial guess for the solution. It is flattened with
            `jax.flatten_util.ravel_pytree`, and the returned solution has the
            same pytree structure.
        tol: Convergence tolerance on the l2 norm of the residual.
        max_iter: Maximum number of Newton updates.

    Returns:
        A tuple `(u, F_fn)` with the computed solution and the residual
        function `F_fn(u, θ)`.

    Warns:
        RuntimeWarning: If the residual norm is still not below `tol` after
            `max_iter` Newton updates. The last iterate is returned anyway.
    """
    import warnings

    def F_fn(u, θ):
        return np.array([θ[0]**2 * u[0] + θ[1] - 1, θ[1]**2 * u[0]**2 + θ[1] * u[1]**2 - 2])

    u_flat, unflatten = jax.flatten_util.ravel_pytree(u_init)

    def flat_F_fn(u_flat):
        # Residual as a function of the flattened state only (θ is held fixed).
        u = unflatten(u_flat)
        F = F_fn(u, θ)
        return jax.flatten_util.ravel_pytree(F)[0]

    # Solve forward problem
    print("\n################## Solve forward problem...")

    # Main Newton loop
    for _ in range(max_iter):
        # Compute current residual
        F_flat = flat_F_fn(u_flat)
        residual_norm = np.linalg.norm(F_flat)

        print(f"res = {residual_norm}")
        if residual_norm < tol:
            break

        # Compute dense Jacobian dF/du
        J = jax.jacfwd(flat_F_fn)(u_flat)

        # Solve the (unregularized) linear system J Δu = -F for the Newton update
        Δu_flat = np.linalg.solve(J, -F_flat)

        # Update solution (JAX arrays are immutable, so rebind instead of mutating)
        u_flat = u_flat + Δu_flat
    else:
        # The loop ran out of iterations; the last update has not been checked yet.
        final_norm = np.linalg.norm(flat_F_fn(u_flat))
        if not final_norm < tol:  # written this way so that NaN also triggers the warning
            warnings.warn(
                f"Newton solver did not converge within {max_iter} iterations "
                f"(residual norm = {final_norm}, tol = {tol})",
                RuntimeWarning,
                stacklevel=2,
            )

    u = unflatten(u_flat)
    print("################## End of forward problem\n")

    return u, F_fn


def adjoint_step(u, θ, J_fn, F_fn):
    """Solve the adjoint system (dF/du)^T λ = -dJ/du at the state u.

    Args:
        u: Solution of the forward problem (a pytree).
        θ: Parameter vector at which u was computed.
        J_fn: Objective function `J_fn(u, θ)` returning a scalar.
        F_fn: Residual function `F_fn(u, θ)`.

    Returns:
        A tuple `(λ, A)`: the adjoint vector with the same pytree structure as
        u, and the dense Jacobian A = dF/du with respect to the flattened state.
    """
    def flatten(tree):
        return jax.flatten_util.ravel_pytree(tree)[0]

    state_vec, to_state = jax.flatten_util.ravel_pytree(u)

    def residual_of_flat_state(vec):
        # Residual as a function of the flattened state, with θ held fixed.
        return flatten(F_fn(to_state(vec), θ))

    print("\n################## Solve adjoint problem...")

    A = jax.jacfwd(residual_of_flat_state)(state_vec)

    # jax.grad differentiates with respect to the first argument, i.e. u.
    dJ_du_vec = flatten(jax.grad(J_fn)(u, θ))
    λ = to_state(np.linalg.solve(A.transpose(), -dJ_du_vec))

    print("################## End of adjoint problem\n")

    return λ, A


def forward_and_adjoint(θ, J_fn, u_init):
    """Run `forward_step` and then `adjoint_step` at the parameter point θ.

    Args:
        θ: Parameter vector.
        J_fn: Objective function `J_fn(u, θ)`.
        u_init: Initial guess handed to `forward_step`.

    Returns:
        A tuple `(u, λ, F_fn, A)`: the forward solution, the adjoint vector,
        the residual function, and the matrix A returned by `adjoint_step`.
    """
    state, residual_fn = forward_step(θ, u_init)
    adjoint, jacobian = adjoint_step(state, θ, J_fn, residual_fn)
    return state, adjoint, residual_fn, jacobian

## AD method

We build a `HessVecProduct` class as a manager, which can be handy for optimization (e.g., interface with the `scipy` package). Optimization itself is not covered in this tutorial, since we only focus on Hessian-vector products here.

Note that the function `incremental_forward_and_adjoint` is imported from the `hessian.hessp` module.

In [3]:
class HessVecProduct:
    """Manager exposing the objective, its gradient and Hessian-vector products in θ.

    The forward/adjoint solution computed in `hessp` is cached in
    `internal_vars`, so repeated `hessp` calls at the same θ (with different
    directions θ_hat) do not re-solve the forward and adjoint problems.

    Attributes are documented with comments in `__init__`.
    """

    def __init__(self, u_init, J_fn):
        """Store the Newton initial guess and the objective function.

        Args:
            u_init: Initial guess passed to the forward solver on every solve.
            J_fn: Objective function `J_fn(u, θ)` returning a scalar.
        """
        # Kept on the instance so the methods do not depend on a global `u_init`.
        self.u_init = u_init
        # Cache of the last forward/adjoint solve performed by `hessp`.
        self.internal_vars = {'θ': None, 'u': None, 'λ': None, 'F_fn': None, 'A': None}
        self.J_fn = J_fn
        # Dense linear solvers handed to `incremental_forward_and_adjoint`.
        self.state_linear_solver = lambda A, b: np.linalg.solve(A, b)
        self.adjoint_linear_solver = lambda A, b: np.linalg.solve(A, b)

    def J(self, θ):
        """Evaluate the objective at θ after solving the forward problem."""
        u, _ = forward_step(θ, self.u_init)
        return self.J_fn(u, θ)

    def grad(self, θ):
        """Return the total derivative dJ/dθ via the adjoint method.

        Computes dJ/dθ_k = λ_i ∂F_i/∂θ_k + ∂J/∂θ_k, with u and λ obtained from
        a fresh forward and adjoint solve at θ.
        """
        u, λ, F_fn, _ = forward_and_adjoint(θ, self.J_fn, self.u_init)
        _, f_vjp = jax.vjp(lambda θ: F_fn(u, θ), θ) # λ_i * (∂/∂θ_k)F_i
        vjp_θ, = f_vjp(λ)
        dJ_dθ = jax.grad(self.J_fn, argnums=1)(u, θ) # ∂J/∂θ_k
        vjp_result = jax.tree_util.tree_map(lambda x, y: x + y, vjp_θ, dJ_dθ) # dJ/dθ_k
        return vjp_result

    def hessp(self, θ, θ_hat):
        """Return the Hessian-vector product (d²J/dθ²) θ_hat.

        The forward and adjoint problems are re-solved only when no solution is
        cached or when `tree_l2_norm_error(cached θ, θ)` exceeds 1e-8
        (presumably the l2 norm of the difference; see `hessian.utils`).
        Otherwise the cached θ, u, λ, F_fn and A are reused.

        Args:
            θ: Parameter vector at which the Hessian is evaluated.
            θ_hat: Direction the Hessian is applied to.

        Returns:
            The first item returned by `incremental_forward_and_adjoint`.
        """
        tol = 1e-8
        if (self.internal_vars['θ'] is None) or tree_l2_norm_error(self.internal_vars['θ'], θ) > tol:
            print("hessp needs to solve forward and adjoint problem...")
            u, λ, F_fn, A = forward_and_adjoint(θ, self.J_fn, self.u_init)
            self.internal_vars['θ'] = θ
            self.internal_vars['u'] = u
            self.internal_vars['λ'] = λ
            self.internal_vars['F_fn'] = F_fn
            self.internal_vars['A'] = A
        else:
            print("hessp does NOT need to solve forward and adjoint problem...")
            θ = self.internal_vars['θ']
            u = self.internal_vars['u']
            λ = self.internal_vars['λ']
            F_fn = self.internal_vars['F_fn']
            A = self.internal_vars['A']

        print("\n################## Solve incremental forward and adjoint problem...")
        # The second return value (profiling information) is not used here.
        dθ_dθ_J_θ_hat, _ = incremental_forward_and_adjoint(u, θ, λ, θ_hat, self.J_fn, F_fn, A,
                                                            self.state_linear_solver, self.adjoint_linear_solver,
                                                            option='rev_rev') # change 'option' for other modes
        print("################## End of incremental forward and adjoint problem\n")
        return dθ_dθ_J_θ_hat


## FD method

Finite difference gives approximate Hessian-vector products.

In [4]:
def finite_difference_hessp(hess_vec_prod, θ, θ_hat):
    """Approximate the Hessian-vector product by central differences of the gradient.

    Evaluates `hess_vec_prod.grad` at θ + h θ_hat and θ - h θ_hat with
    h = 1e-6 and returns (grad_plus - grad_minus) / (2 h), leaf by leaf.

    Args:
        hess_vec_prod: Object providing a `grad(θ)` method.
        θ: Parameter pytree.
        θ_hat: Direction pytree with the same structure as θ.

    Returns:
        A pytree with the structure of the gradient.
    """
    step = 1e-6
    tree_map = jax.tree_util.tree_map

    θ_forward = tree_map(lambda base, direction: base + step*direction, θ, θ_hat)
    θ_backward = tree_map(lambda base, direction: base - step*direction, θ, θ_hat)

    # The forward-shifted gradient is evaluated first, then the backward one.
    grad_forward = hess_vec_prod.grad(θ_forward)
    grad_backward = hess_vec_prod.grad(θ_backward)

    return tree_map(lambda g_hi, g_lo: (g_hi - g_lo)/(2*step), grad_forward, grad_backward)

## AN method

For this particular problem, it is easy to analytically obtain

$$
\boldsymbol{u}(\boldsymbol{\theta}) = 
\begin{bmatrix}
\frac{1 - \theta_1}{\theta_0^2} \\
\sqrt{\frac{2 - \left(\frac{1 - \theta_1}{\theta_0^2}\right)^2 \theta_1^2}{\theta_1}}
\end{bmatrix}
$$

Then, we can let JAX compute the Hessian-vector products for us.

In [5]:
def analytical_hessp(θ, θ_hat, J_fn):
    """Hessian-vector product of J(u(θ), θ) using the closed-form solution u(θ).

    The closed-form u(θ) is substituted into `J_fn`, and JAX differentiates the
    resulting function of θ. The product is computed forward-over-reverse
    (`jax.jvp` of `jax.grad`), which avoids building the dense Hessian.

    Args:
        θ: Parameter array, indexed as θ[0] and θ[1].
        θ_hat: Direction the Hessian is applied to. If its shape differs from
            θ's (e.g. several directions stacked as columns), the dense Hessian
            is formed and multiplied by θ_hat instead.
        J_fn: Objective function `J_fn(u, θ)` returning a scalar.

    Returns:
        The Hessian-vector product (d²J/dθ²) θ_hat.
    """
    def u_fn(θ):
        # Closed-form solution of F(u, θ) = 0 (the non-negative root for u[1]).
        u0 = (1. - θ[1])/θ[0]**2
        u1 = np.sqrt((2. - u0**2 * θ[1]**2) / θ[1])
        return np.array([u0, u1])

    def J(θ):
        return J_fn(u_fn(θ), θ)

    θ = np.asarray(θ)
    θ_hat = np.asarray(θ_hat)

    if θ_hat.shape != θ.shape:
        # Not a single direction shaped like θ: keep the dense-Hessian product.
        return jax.jacrev(jax.grad(J))(θ) @ θ_hat

    # jax.jvp requires the tangent to have the same dtype as the primal.
    θ_hat = θ_hat.astype(θ.dtype)
    _, dθ_dθ_J_θ_hat = jax.jvp(jax.grad(J), (θ,), (θ_hat,))
    return dθ_dθ_J_θ_hat

## Objective function

Let us define the objective function as

$$
J(\boldsymbol{u},\boldsymbol{\theta}) = u_0^3 + u_1^3 + \theta_0^3 + \theta_1^3 + (u_0^2 + u_1^2)(e^{\theta_0} + e^{\theta_1})
$$

In [6]:
def J_fn(u, θ):
    """Objective J(u, θ) = Σ u_i³ + Σ θ_k³ + (Σ u_i²)(Σ exp(θ_k)).

    Both arguments may be arbitrary pytrees; they are flattened to vectors
    with `jax.flatten_util.ravel_pytree` before the sums are taken.
    """
    def flatten(tree):
        return jax.flatten_util.ravel_pytree(tree)[0]

    u_vec = flatten(u)
    θ_vec = flatten(θ)

    cubic_terms = np.sum(u_vec**3) + np.sum(θ_vec**3)
    coupling_term = np.sum(u_vec**2) * np.sum(np.exp(θ_vec))
    return cubic_terms + coupling_term

## Results

We show the three approaches and the corresponding Hessian-vector products.

In [10]:
# Problem data: parameter point, direction for the Hessian-vector product,
# and the initial guess for the Newton solver.
θ = np.array([3., .2])
θ_hat = np.array([0.2, 0.3])
u_init = np.array([0.1, 0.1])

hess_vec_prod = HessVecProduct(u_init, J_fn)

_BANNER = "************************************************************************"


def _run_with_banner(tag, compute):
    """Call `compute()` between the start/end banners of the `tag`-based approach."""
    print(f"\n\n{_BANNER}")
    print(f"Running {tag}-based approach to find Hessian-vector product...")
    outcome = compute()
    print(f"\nEnd of {tag}-based approach to find Hessian-vector product")
    print(f"{_BANNER}\n\n")
    return outcome


# Automatic differentiation
dθ_dθ_J_θ_hat_ad = _run_with_banner("AD", lambda: hess_vec_prod.hessp(θ, θ_hat))

# Finite difference
dθ_dθ_J_θ_hat_fd = _run_with_banner("FD", lambda: finite_difference_hessp(hess_vec_prod, θ, θ_hat))

# Analytical
dθ_dθ_J_θ_hat_an = _run_with_banner("AN", lambda: analytical_hessp(θ, θ_hat, J_fn))

# Side-by-side summary of the three Hessian-vector products.
print("\nResults:")
print(f"AD = {dθ_dθ_J_θ_hat_ad} (automatic differentiation)")
print(f"FD = {dθ_dθ_J_θ_hat_fd} (finite difference)")
print(f"AN = {dθ_dθ_J_θ_hat_an} (analytical)")

[05-18 14:07:43][DEBUG] jax_fem: ################## Solve incremental forward problem...
[05-18 14:07:43][DEBUG] jax_fem: ################## Solve incremental adjoint problem...
[05-18 14:07:43][DEBUG] jax_fem: rev_rev: time elapsed for J-related evaluation is 0.0391595340000066




************************************************************************
Running AD-based approach to find Hessian-vector product...
hessp needs to solve forward and adjoint problem...

################## Solve forward problem...
res = 2.0001014374276123
res = 498.8451170370369
res = 124.21335426711808
res = 30.561338186930122
res = 7.171115420075269
res = 1.401865156171163
res = 0.14443609767126464
res = 0.0024324414958898366
res = 7.388146427977915e-07
res = 6.88338275267597e-14
################## End of forward problem


################## Solve adjoint problem...
################## End of adjoint problem


################## Solve incremental forward and adjoint problem...


[05-18 14:07:43][DEBUG] jax_fem: rev_rev: time elapsed for F-related evaluation is 0.06895638000000304
[05-18 14:07:43][DEBUG] jax_fem: ################## Find hessian-vector product...
[05-18 14:07:43][DEBUG] jax_fem: ################## Finshed using AD to find HVP.



################## End of incremental forward and adjoint problem


End of AD-based approach to find Hessian-vector product
************************************************************************




************************************************************************
Running FD-based approach to find Hessian-vector product...

################## Solve forward problem...
res = 2.000101454231843
res = 498.84436698668696
res = 124.21316675767527
res = 30.56129132148854
res = 7.1711037480425155
res = 1.4018623769967844
res = 0.1444356430226601
res = 0.0024324266989759202
res = 7.388056602053439e-07
res = 6.794564910705958e-14
################## End of forward problem


################## Solve adjoint problem...
################## End of adjoint problem


################## Solve forward problem...
res = 2.0001014206234675
res = 498.8458670896375
res = 124.2135417771237
res = 30.561385052512392
res = 7.171127092143186
res = 1.4018679353542702
res = 0.1444365523216904
res = 0.00243245