In [None]:
"""
---------------------------------------------------------------------------------------------------------------------------------------------------------
This notebook contains a "manual" method for calculating the Hessian of a loss with respect to a network's parameters, built from nested for loops over the first-order gradients.
---------------------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
import torch
from torch import nn
import numpy as np
from torch.nn import Module
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import torch.optim as optim

# Run on the first CUDA GPU when one is present; otherwise fall back to the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class ODEFunc(nn.Module):
    """Small fully connected network: 2 inputs -> 50 tanh hidden units -> 2 outputs.

    The forward pass cubes its input element-wise before handing it to the network.
    """

    def __init__(self):
        super().__init__()

        # A very simple architecture with a single hidden layer.
        hidden_width = 50
        self.net = nn.Sequential(
            nn.Linear(2, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, 2),
        )

        # Re-initialise every linear layer: weights drawn from N(0, 0.1^2), biases set to 0.
        # The layers are visited in their Sequential order, so the random draws happen in the
        # same order as construction.
        linear_layers = (layer for layer in self.net if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, y):
        """Apply the network to the element-wise cube of `y`.

        Per the author's original note, acting on y**3 means the net only has to learn to
        represent the matrix part of the model (cf. class Lambda, which is not shown here).
        """
        cubed = y ** 3
        return self.net(cubed)


func = ODEFunc()

In [None]:
"""
-----------------------------------------------------------------------------------------------------------------------------------------
This produces a symmetric square matrix that matches the one obtained with the approach given below (which uses library functions).
-----------------------------------------------------------------------------------------------------------------------------------------
"""

def get_manual_hessian(grads, parameters, device=None):
  """
  Calculation of the Hessian using nested for loops.

  Row i of the result is the gradient of the i-th element of the flattened, concatenated
  first-order gradients with respect to every parameter. Columns use the same ordering:
  parameters in the order given, each flattened.

  Inputs:
    grads:      sequence of first-order gradient tensors, one per parameter tensor and in the
                same order as `parameters`. They must carry a graph, e.g.
                grads = torch.autograd.grad(loss, parameters, create_graph=True).
    parameters: iterable of parameter tensors, e.g. parameters = optimizer.param_groups[0]['params'].
                Any iterable is accepted, including a generator such as model.parameters().
    device:     device for the returned matrix. Defaults to the device of the first parameter
                (CPU when there are no parameters).
  Returns:
    A (n_params, n_params) tensor with the dtype of the first parameter (the default dtype when
    there are no parameters). It holds plain values and no autograd graph.
  Raises:
    ValueError: if the total number of gradient elements differs from the total number of
                parameter elements.
  """
  start = time.perf_counter()                                       #Begin timer.

  # Materialise the iterable once. It is traversed for the size count and again for every row,
  # so a generator would otherwise be exhausted after the first pass and leave the Hessian zero.
  parameters = list(parameters)
  flat_grads = [torch.reshape(grad, [-1]) for grad in grads]        #Each gradient as a vector.

  n_params = sum(torch.numel(param) for param in parameters)
  n_rows = sum(flat_grad.shape[0] for flat_grad in flat_grads)
  if n_rows != n_params:
    raise ValueError(
        f'grads contain {n_rows} elements but parameters contain {n_params}; '
        'expected one gradient tensor per parameter tensor.')

  if device is None:
    device = parameters[0].device if parameters else torch.device('cpu')
  dtype = parameters[0].dtype if parameters else torch.get_default_dtype()
  hessian = torch.zeros(n_params, n_params, dtype=dtype, device=device)

  row = 0                                                           #Row number in the Hessian.
  for flat_grad in flat_grads:
    for g in flat_grad:
      # A gradient element without a graph (e.g. one that is constant in the parameters) has
      # zero second derivatives; autograd.grad would raise on it, so its row is left as zeros.
      if g.requires_grad:
        # One backward pass differentiates g with respect to all parameters at once.
        # allow_unused=True yields None (i.e. a zero block) for parameters g does not depend on.
        second = torch.autograd.grad(g, parameters, retain_graph=True, allow_unused=True)
        hessian[row] = torch.cat([
            torch.reshape(torch.zeros_like(param) if d2 is None else d2, [-1])
            for d2, param in zip(second, parameters)
        ]).to(device=device, dtype=dtype)
      row += 1

  print('Time used is ', time.perf_counter() - start)

  return hessian

# Prepare the optimizer and evaluate the loss for a single (input, target) pair.
# NOTE(review): `input` shadows the Python builtin; the name is kept in case later cells use it.
input = torch.tensor([1.0, 1.0])
target = torch.tensor([2.0, 0.0])
optimizer = optim.RMSprop(func.parameters(), lr=1e-3)  # func.parameters() are the parameters to optimise.
optimizer.zero_grad()
parameters = optimizer.param_groups[0]['params']
loss = torch.linalg.norm(func(input) - target)

# backward(create_graph=True) stores differentiable first-order gradients in each param.grad,
# which get_manual_hessian then differentiates a second time.
# NOTE(review): PyTorch documents that backward(create_graph=True) on parameters creates a
# reference cycle between each parameter and its .grad. The graph-only alternative, which does
# not populate .grad, is:
#   grads = torch.autograd.grad(loss, parameters, create_graph=True)
loss.backward(create_graph=True)

# One gradient per parameter tensor, in the same order as the `parameters` list passed to
# get_manual_hessian, rather than assuming the network has exactly four parameter tensors.
grads = tuple(param.grad for param in parameters)

hessian = get_manual_hessian(grads, parameters)