Suppose we want to compute $\vec{x} = H^{-1}\vec{v}$, where $H$ is the Hessian of some function $f$.

*   One option is to compute the Hessian first, invert it, and then multiply by $\vec{v}$.

*   Both steps are expensive: storing the $N\times N$ Hessian takes $O(N^2)$ memory, and inverting it takes $O(N^3)$ time.

*   Iterative methods such as Conjugate Gradient (CG) approximate $\vec{x}$ in relatively few iterations without ever forming $H^{-1}$.

    This amounts to solving the linear system

$$H \vec{x} = \vec{v}$$

The CG algorithm for a symmetric positive-definite system $A \vec{x} = \vec{b}$ consists of the following steps:
*   $$\vec{d}_0 = \vec{r}_0 = \vec{b} - A \vec{x}_0 \ \text{ (Initialization)}$$
*   $$\alpha_i = \frac{\vec{r}_i^T\vec{r}_i}{\vec{d}_i^T A \vec{d}_i} \ \text{ (Line-search step size)}$$
*   $$\vec{x}_{i+1}  = \vec{x}_i  + \alpha_i \vec{d}_i \ \text{ (Line search)}$$
*   $$\vec{r}_{i+1} = \vec{r}_i - \alpha_i A\vec{d}_i \ \text{ (Residual update)}$$
*   $$\beta_{i+1} = \frac{\vec{r}_{i+1}^T\vec{r}_{i+1}}{\vec{r}_{i}^T \vec{r}_{i}}\ \text{ (Orthogonalization coefficient)} $$
*   $$\vec{d}_{i+1} = \vec{r}_{i+1} + \beta_{i+1} \vec{d}_i \ \text{ (Orthogonalization)} $$
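The steps above map almost line for line onto code. Below is a minimal, dependency-free sketch (plain Python lists; the names `conjugate_gradient`, `matvec` and `dot` are illustrative, not from any library). The matrix is touched only through the callable `matvec`, which is exactly the place where a Hessian-vector product can later be plugged in instead of an explicit matrix:

```python
from typing import Callable, List, Optional

Vector = List[float]

def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))

def conjugate_gradient(matvec: Callable[[Vector], Vector], b: Vector,
                       x0: Optional[Vector] = None,
                       max_iters: Optional[int] = None,
                       tol: float = 1e-20) -> Vector:
    """Solve A x = b for symmetric positive-definite A, using A only via matvec."""
    x = [0.0] * len(b) if x0 is None else list(x0)
    Ax = matvec(x)
    r = [bi - axi for bi, axi in zip(b, Ax)]   # r_0 = b - A x_0
    d = list(r)                                # d_0 = r_0
    rr = dot(r, r)
    for _ in range(len(b) if max_iters is None else max_iters):
        if rr < tol:                           # tol is a threshold on r^T r
            break
        Ad = matvec(d)
        alpha = rr / dot(d, Ad)                # line-search step size
        x = [xi + alpha * di for xi, di in zip(x, d)]
        r = [ri - alpha * adi for ri, adi in zip(r, Ad)]
        rr_new = dot(r, r)
        beta = rr_new / rr                     # orthogonalization coefficient
        d = [ri + beta * di for ri, di in zip(r, d)]
        rr = rr_new
    return x

A_demo = [[4.0, 1.0], [1.0, 3.0]]
b_demo = [1.0, 2.0]
x_demo = conjugate_gradient(lambda v: [dot(row, v) for row in A_demo], b_demo)
print(x_demo)  # approximately [1/11, 7/11]
```

In exact arithmetic CG reaches the solution of an $n \times n$ system in at most $n$ iterations; here $n = 2$.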

Apart from the initialization, the Hessian ($A$) enters each iteration only through the product $A\vec{d}_i$ with the search direction $\vec{d}_i$.

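Recall that the Hessian is the Jacobian of the gradient (and the gradient is itself the Jacobian of the scalar $f$, with a single row):
$$H(\cdot) = J(\nabla(\cdot))$$
Thus, treating $\vec{v}$ as a constant and using the symmetry of second derivatives,
$$H \vec{v} = J\big(\nabla f\big) \vec{v} = J\big(\nabla f \cdot \vec{v}\big) $$
So we never have to form the $N\times N$ Hessian: we take the dot product of the $N\times 1$ gradient with $\vec{v}$ and differentiate that scalar once more, which yields the vector $H \vec{v}$ directly.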
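Written out component-wise (with $w_j$ denoting the variables of $f$ and $\vec{v}$ held constant), the identity is just the symmetry of mixed second partial derivatives:

$$\Big[\nabla\big(\nabla f\cdot\vec{v}\big)\Big]_i = \frac{\partial}{\partial w_i}\sum_{j=1}^N \frac{\partial f}{\partial w_j} v_j = \sum_{j=1}^N \frac{\partial^2 f}{\partial w_i\,\partial w_j} v_j = \sum_{j=1}^N H_{i,j} v_j = \big[H\vec{v}\big]_i$$

Each product therefore costs a small constant multiple of one gradient evaluation, instead of the $N$ backward passes needed to build the full Hessian row by row.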

***
As an example, consider a model $f(x,\theta)$ parametrized by
$$\theta = \{w_1,w_2,\dots,w_N\}$$
$$f(x,\theta) = x_1 w_1^2 + x_2 w_2^2 + \dots = \sum_{i=1}^N x_i w_i^2$$
A Hessian entry is
$$H_{i,j} = \partial_{w_i}\partial_{w_j} \sum_{k=1}^N x_k w_k^2 = \partial_{w_i} \sum_{k=1}^N x_k \frac{\partial}{\partial w_j}\big(w_k^2\big) = \partial_{w_i} \sum_{k=1}^N 2 x_k w_k \frac{\partial w_k}{\partial w_j} = 2\,\partial_{w_i} \sum_{k=1}^N x_k w_k \delta_{k,j} = 2 x_j \frac{\partial w_j}{\partial w_i} = 2 x_j \delta_{i,j}$$
So the Hessian is a diagonal matrix with entries
$$H_{i,i} = 2 x_i$$


Of course, this is the same as computing the Jacobian of the gradient:
$$[\nabla f]_i = 2 x_i w_i$$
$$J(\nabla f)_{i,j} = \partial_{w_j} (2 x_i w_i) = 2 x_i \delta_{i,j}$$
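The diagonal result can be sanity-checked without autograd. A small sketch (plain Python; `f_model`, `shifted`, `fd_hessian`, the step `h` and the test values `x_vals`, `w_vals` are my own choices) estimates every $H_{i,j}$ with central finite differences, using $x \ne \mathbf{1}$ so that the diagonal entries $2x_i$ differ from each other:

```python
def f_model(x, w):
    # f(x, w) = sum_i x_i * w_i**2, the model from the derivation above
    return sum(xi * wi ** 2 for xi, wi in zip(x, w))

def shifted(func, w, i, j, si, sj):
    # evaluate func at w with w[i] moved by si and w[j] moved by sj
    p = list(w)
    p[i] += si
    p[j] += sj
    return func(p)

def fd_hessian(func, w, h=1e-3):
    # central-difference estimate of d^2 func / (dw_i dw_j) for every pair (i, j)
    n = len(w)
    return [[(shifted(func, w, i, j, h, h) - shifted(func, w, i, j, h, -h)
              - shifted(func, w, i, j, -h, h) + shifted(func, w, i, j, -h, -h))
             / (4 * h * h)
             for j in range(n)] for i in range(n)]

x_vals = [1.0, 2.0, 3.0]
w_vals = [0.5, -1.0, 2.0]
H_fd = fd_hessian(lambda p: f_model(x_vals, p), w_vals)
for row in H_fd:
    print([round(val, 6) for val in row])   # approximately diag(2*x) = diag(2, 4, 6)
```

Because $f$ is quadratic in $w$, these central differences are exact up to floating-point round-off.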

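```python
import torch
_ = torch.manual_seed(0)

# parameters w = (0, 1, 2, 3, 4); gradients are tracked with respect to them
params = torch.arange(5, requires_grad=True, dtype=torch.float64)

def model(inp, weights):
    # f(x, w) = sum_i x_i * w_i**2
    return torch.dot(inp, weights**2)

x = torch.ones_like(params)   # inputs x_i = 1, so H should be 2 * I
f = model(x, params)
# create_graph=True keeps the graph of the gradient so it can be differentiated again
grad = torch.autograd.grad(f, params, create_graph=True)[0]
# dense Hessian: differentiate each gradient component separately (N backward passes)
H = torch.stack([torch.autograd.grad(s, params, create_graph=True)[0] for s in grad])
H.detach()
```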

```
tensor([[2., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.],
        [0., 0., 2., 0., 0.],
        [0., 0., 0., 2., 0.],
        [0., 0., 0., 0., 2.]], dtype=torch.float64)
```
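As a cross-check, PyTorch's `torch.autograd.functional.hessian(func, inputs)` builds the same dense matrix in one call. This is a sketch that re-creates the model as a function of the weights only (`f_of_w`, `x_in`, `w0` are illustrative names); like the loop above, it still materializes all $N^2$ entries, so it only makes sense for small $N$:

```python
import torch
from torch.autograd.functional import hessian

x_in = torch.ones(5, dtype=torch.float64)    # same inputs as above
w0 = torch.arange(5, dtype=torch.float64)    # same parameter values as above

def f_of_w(weights):
    # the model above viewed as a function of the weights only
    return torch.dot(x_in, weights ** 2)

H_ref = hessian(f_of_w, w0)   # dense 5x5 Hessian
print(torch.allclose(H_ref, 2 * torch.eye(5, dtype=torch.float64)))  # True
```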

### Check that the trick works
With this approach the gradient is computed once with `create_graph=True`, which keeps its computational graph; that graph can then be differentiated again for as many different vectors $\vec{v}$ as needed.
$$H \vec{v} = J\big(\nabla f\big) \vec{v} = J\big(\nabla f \cdot \vec{v}\big) $$

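```python
grad = torch.autograd.grad(f, params, create_graph=True)[0]
v = torch.randn(5, dtype=torch.float64)
# Hessian-vector product as the gradient of the scalar grad . v
Hv = torch.autograd.grad(grad @ v, params, create_graph=True)[0]
# compare with the explicit product; allclose tolerates floating-point round-off
print('Hv = J(Jf.v):', torch.allclose(H @ v, Hv))
```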

```
Hv = J(Jf.v): True
```
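PyTorch also packages this double-backward product as `torch.autograd.functional.hvp(func, inputs, v)`, which returns the function value together with $H\vec{v}$. A small sketch on the same model (`f_of_w`, `x_in`, `w0`, `v_test` are illustrative names, not from the notebook):

```python
import torch
from torch.autograd.functional import hvp

x_in = torch.ones(5, dtype=torch.float64)
w0 = torch.arange(5, dtype=torch.float64)
v_test = torch.randn(5, dtype=torch.float64)

def f_of_w(weights):
    # f(w) = sum_i x_i * w_i**2 with x fixed
    return torch.dot(x_in, weights ** 2)

# hvp returns a tuple: (f evaluated at w0, H @ v_test)
value, Hv_ref = hvp(f_of_w, w0, v_test)
print(value.item(), torch.allclose(Hv_ref, 2 * v_test))  # 30.0 True
```

For a twice continuously differentiable $f$ the Hessian is symmetric, so the related `torch.autograd.functional.vhp`, which computes $\vec{v}^T H$, yields the same vector.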


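```python
def conjugate_gradient_hess(f, params, b, x0=None, max_iters=None, tol=1e-10):
    """Solve H x = b with CG, where H is the Hessian of the scalar f w.r.t. params.

    H is never formed: each product H d is computed as the gradient of grad(f) . d.
    """
    # gradient with its graph kept, so it can be differentiated again in Av
    grad = torch.autograd.grad(f, params, create_graph=True)[0]

    def Av(vec):
        # Hessian-vector product; retain_graph=True lets us reuse grad's graph
        return torch.autograd.grad(grad @ vec, params, retain_graph=True)[0]

    # initial guess (a copy, so the caller's x0 is not modified in place)
    x = torch.zeros_like(b) if x0 is None else x0.clone()
    r = b - Av(x)       # residual
    d = r.clone()       # search direction
    rr = r @ r
    num_iters = len(b) if max_iters is None else max_iters
    for _ in range(num_iters):
        if rr < tol:    # stop once the squared residual norm is small
            break
        Ad = Av(d)
        step_size = rr / (d @ Ad)
        x = x + step_size * d
        r = r - step_size * Ad
        rr_new = r @ r
        d = r + (rr_new / rr) * d
        rr = rr_new
    return x.detach()

if __name__ == "__main__":
    torch.manual_seed(0)
    A = H.detach()              # explicit Hessian, used only to check the result
    b = torch.ones_like(params)

    def ff(x):
        # quadratic whose minimizer solves A x = b
        return 0.5 * x @ A @ x - torch.dot(b, x)

    x0 = torch.rand(size=b.shape, dtype=torch.float64)
    x = conjugate_gradient_hess(f, params, b, x0)
    with torch.no_grad():
        print(f'{x = }; {ff(x).item() = }; {A @ x - b}')
```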

```
x = tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000], dtype=torch.float64); ff(x).item() = -1.25; tensor([0., 0., 0., 0., 0.], dtype=torch.float64)
```
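With $A = 2I$ the run above is an easy case: every vector is an eigenvector of $A$, so CG lands on $\vec{x} = \vec{b}/2$ after a single iteration. A slightly harder sketch (the quadratic `loss`, the random SPD matrix `A_spd` and the helpers `hess_vec` and `cg` are illustrative assumptions) uses a non-diagonal Hessian, still accessed only through double-backward Hessian-vector products, and compares the CG result with `torch.linalg.solve` on the explicit matrix:

```python
import torch

torch.manual_seed(0)
n = 4
M = torch.randn(n, n, dtype=torch.float64)
A_spd = M @ M.T + n * torch.eye(n, dtype=torch.float64)   # symmetric positive definite
w = torch.randn(n, dtype=torch.float64, requires_grad=True)

def loss(weights):
    # 0.5 * w^T A w; its Hessian is exactly A_spd (non-diagonal)
    return 0.5 * weights @ A_spd @ weights

# gradient with its graph kept for the second differentiation
g = torch.autograd.grad(loss(w), w, create_graph=True)[0]

def hess_vec(d):
    # H d = gradient of (g . d); retain_graph reuses g's graph on every call
    return torch.autograd.grad(g @ d, w, retain_graph=True)[0]

def cg(matvec, b, iters, tol=1e-20):
    x = torch.zeros_like(b)
    r = b - matvec(x)
    d = r.clone()
    rr = r @ r
    for _ in range(iters):
        if rr < tol:    # threshold on the squared residual norm
            break
        Ad = matvec(d)
        alpha = rr / (d @ Ad)
        x = x + alpha * d
        r = r - alpha * Ad
        rr_new = r @ r
        d = r + (rr_new / rr) * d
        rr = rr_new
    return x

b_vec = torch.ones(n, dtype=torch.float64)
x_cg = cg(hess_vec, b_vec, iters=n)
x_ref = torch.linalg.solve(A_spd, b_vec)
print(torch.allclose(x_cg, x_ref))
```

Here `retain_graph=True` (rather than `create_graph=True`) is enough inside `hess_vec`, because the products themselves are never differentiated again, so the iterations do not keep extending the autograd graph.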
