In [None]:
"""
---------------------------------------------------------------------------------------------------------------------------------------------------------
This document contains code for calculating the Hessian using two approaches: a "manual" approach (nested loops over torch.autograd.grad) and a 
library-based approach (torch.autograd.functional.hessian). The two were compared using the network ODEFunc and appear to give identical results. 
The class Simple_Net is designed to allow these methods to be compared against analytical results.
---------------------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
%%bash
# NOTE(review): in a %%bash cell, `cd` runs inside a separate bash subprocess, so the directory
# change does not persist for later (Python) cells. If the intent is to change the notebook's
# working directory, the IPython line magic `%cd <path>` is probably what is wanted.
cd drive/MyDrive/colab_notebooks/calculating_hessians/testing_on_normal_nets/network_1

In [None]:
import torch
from torch import nn
import numpy as np
from torch.nn import Module
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import torch.optim as optim

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
class ODEFunc(nn.Module):
    """
    Small MLP, 2 -> 50 -> Tanh -> 2, applied to the element-wise cube of its input.

    Every Linear layer is re-initialised with weights drawn from N(0, 0.1^2) and zero biases.
    """

    def __init__(self):
        super().__init__()

        hidden_width = 50
        self.net = nn.Sequential(
            nn.Linear(2, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, 2),
        )

        # Visit the Linear layers in module order and overwrite their default initialisation.
        linear_layers = (mod for mod in self.net.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, y):
        # The net sees y**3 rather than y. Per the original author's note this is so the net only
        # has to learn the matrix (see class Lambda, which is defined elsewhere, not in this section).
        cubed = y ** 3
        return self.net(cubed)


func = ODEFunc()

In [None]:
"""
-----------------------------------------------------------------------------------------------------------------------------------------
This produces a symmetric square matrix identical to the one obtained with the library-based approach given below.
-----------------------------------------------------------------------------------------------------------------------------------------
"""

def get_manual_hessian(grads, parameters, out_device=None):
  """
  Calculation of the Hessian using nested for loops.

  Every entry of every first-order gradient is differentiated again with respect to all of the
  parameters. The results are written row by row into a dense (n_params x n_params) matrix.
  Each parameter is flattened with torch.reshape(..., [-1]), and parameters are laid out in the
  order given.

  Inputs:
    grads:      tuple of gradient tensors, one per parameter tensor and in the same order as
                `parameters`. Created using something like
                grads = torch.autograd.grad(loss, parameters, create_graph=True),
                or by collecting param.grad after loss.backward(create_graph=True).
    parameters: iterable of parameter tensors. Created using something like
                parameters = optimizer.param_groups[0]['params']. One-shot iterables such as
                net.parameters() are also accepted (they are turned into a list first).
    out_device: device the Hessian is returned on. Defaults to the module-level `device`.
  Returns:
    Tensor of shape (n_params, n_params). Entry [i, j] is the derivative of flattened gradient
    entry i with respect to flattened parameter entry j. Pairs with no dependence are left at zero.
  """
  start = time.time()                                   #Begin timer.

  # Materialise the parameters. They are iterated once per gradient entry, and a generator
  # would be exhausted after the first pass, leaving the remaining rows silently zero.
  parameters = list(parameters)
  sizes = [torch.numel(param) for param in parameters]
  n_params = sum(sizes)
  hessian = torch.zeros(n_params, n_params)             #Matrix of zeros with the shape of the Hessian.

  row = 0                                               #Row index in the Hessian.
  for grad in grads:
    for g in torch.reshape(grad, [-1]):                 #Walk the gradient as a flat vector.
      if not g.requires_grad:
        # This gradient entry is constant w.r.t. every parameter, so its row stays zero.
        # (torch.autograd.grad would raise on a tensor that does not require grad.)
        row += 1
        continue
      # A single call differentiates this entry w.r.t. all parameters at once. With
      # allow_unused=True, autograd returns None (instead of raising) for any parameter
      # this entry does not depend on.
      second = torch.autograd.grad(g, parameters, retain_graph=True, allow_unused=True)
      col = 0                                           #Column index in the Hessian.
      for size, g2 in zip(sizes, second):
        if g2 is not None:
          hessian[row, col:col + size] = torch.reshape(g2, [-1])
        col += size
      row += 1

  # Move to the target device once, at the end, rather than on every pass through the loop.
  hessian = hessian.to(device if out_device is None else out_device)

  print('Time used is ', time.time() - start)

  return hessian

#Prepare optimizer and network.
input = torch.tensor([1.0,1.0])
target = torch.tensor([2.0,0.0])
optimizer = optim.RMSprop(func.parameters(), lr=1e-3) #func.parameters are the parameters to optimise.
optimizer.zero_grad()
parameters = optimizer.param_groups[0]['params']
pred_y = func(input)
loss = torch.linalg.norm(pred_y - target)
#create_graph=True keeps the graph of the first derivatives so they can be differentiated again.
loss.backward(create_graph=True)

#Obtaining the gradients from param.grad (rather than torch.autograd.grad()) works after loss.backward().
#One gradient is collected per parameter tensor, so this works for any network, whatever its number of layers.
grads = tuple(param.grad for param in func.parameters())

get_manual_hessian(grads, parameters)

In [None]:
"""
-------------------------------------------------------------------------------------------------------------------------------------
This approach works and produces a 252 x 252 matrix (ODEFunc has 2*50 + 50 + 50*2 + 2 = 252 parameters). It is built using library functions and an nn.Module, which means 
it can more easily be used in the context of NODEs. This is the method that I have implemented during the training of 
simple NODEs.
Note: it must be tailored to each neural network, since get_loss_square_2 hard-codes how the flat parameter vector is split into weights and biases.
-------------------------------------------------------------------------------------------------------------------------------------
"""

class Network(nn.Module):
  """
  Functional two-layer network built from explicitly supplied tensors:
      y -> F.linear(tanh(F.linear(y, a, b)), c, d)

  a, b: weight and bias of the first layer; c, d: weight and bias of the second layer.
  """

  def __init__(self, a, b, c, d):
    super().__init__()
    # Stored directly as attributes so that the forward pass is a function of exactly these tensors.
    self.a, self.b, self.c, self.d = a, b, c, d

  def forward(self, y):
    hidden = torch.tanh(F.linear(y, self.a, self.b))
    return F.linear(hidden, self.c, self.d)


def get_loss_square_2(params_vector):
  """
  Loss of the 2-50-2 tanh network evaluated at a flat 252-element parameter vector.

  Intended as the function passed to torch.autograd.functional.hessian. The vector is split as
    [0:100]   -> first-layer weight  (50 x 2)
    [100:150] -> first-layer bias    (50)
    [150:250] -> second-layer weight (2 x 50)
    [250:252] -> second-layer bias   (2)
  The returned loss is the (unsquared) Euclidean norm of prediction - target, for the single
  example with input [1, 1] and target [2, 0].

  NOTE(review): unlike ODEFunc.forward, the input is not cubed here. The two only agree because
  1**3 == 1 for this particular input.
  """
  w1 = params_vector[:100].reshape([50, 2])
  b1 = params_vector[100:150].reshape([50])
  w2 = params_vector[150:250].reshape([2, 50])
  b2 = params_vector[250:252].reshape([2])

  model = Network(w1, b1, w2, b2).to(device)
  x_in = torch.tensor([1.0, 1.0]).to(device)
  y_true = torch.tensor([2.0, 0.0]).to(device)

  residual = model(x_in) - y_true
  return torch.linalg.norm(residual)

def get_library_hessian(net):
  """
  Hessian of get_loss_square_2 w.r.t. all of `net`'s parameters, via torch.autograd.functional.hessian.

  The parameters are flattened and concatenated (in net.parameters() order) into a single vector on
  `device`, and the Hessian is evaluated at that point. Returns a square tensor whose side equals the
  length of that vector.
  """
  flat_pieces = [torch.tensor([]).to(device)]
  flat_pieces += [torch.reshape(p, (-1,)).to(device) for p in net.parameters()]
  params_vector = torch.cat(flat_pieces)
  return torch.autograd.functional.hessian(get_loss_square_2, params_vector)

In [None]:
"""
--------------------------------------------------------------------------------------------------------------------------------------------
The following code is used to test some very simple examples that can be compared to analytical methods.
--------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
class Simple_Net(nn.Module):
    """
    Minimal network with no hidden layer, y -> sigmoid(W y + b). It is small enough for its
    gradient and Hessian to be worked out analytically.

    W is 2 x 2 with entries drawn from N(0, 0.1^2); b has 2 entries initialised to zero.
    """

    def __init__(self):
        super().__init__()

        layers = [nn.Linear(2, 2), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

        # Re-initialise each Linear layer: normally distributed weights (std 0.1), zero biases.
        for module in filter(lambda mod: isinstance(mod, nn.Linear), self.net.modules()):
            nn.init.normal_(module.weight, mean=0, std=0.1)
            nn.init.constant_(module.bias, val=0)

    def forward(self, y):
        # Unlike ODEFunc, the input is passed to the net unchanged (it is not cubed).
        return self.net(y)


simple_func = Simple_Net()

In [None]:
#Show the parameters of the network.
# Print each parameter tensor of simple_func in turn (the weight, then the bias, of its Linear layer).
for tensor in simple_func.parameters():
  print(tensor)

In [None]:
#Data input and model set up.
# Single training example plus an optimizer, used only to fetch the parameter list.
input, target = torch.tensor([1.0, 1.0]), torch.tensor([2.0, 0.0])
optimizer = optim.RMSprop(simple_func.parameters(), lr=1e-3)  # optimises simple_func's weight and bias
optimizer.zero_grad()
parameters = optimizer.param_groups[0]['params']

# One forward pass, reused for both the loss and the printed output.
simple_out = simple_func(input)
loss = torch.linalg.norm(simple_out - target)
print(f"Loss is {loss.item()}")
print(f"Output is {simple_out[0].item()}, {simple_out[1].item()}")

# First-order gradients, kept differentiable (create_graph=True) so they can be differentiated again.
grads = torch.autograd.grad(loss, parameters, create_graph=True)
print(f"Gradients are: {grads[0]}")  # grads[0] is the gradient w.r.t. the weight matrix only
hessian = get_manual_hessian(grads, parameters)

In [None]:
"""
-------------------------------------------------------------------------------------------------------------------------------------
This approach works and produces a 6 x 6 matrix (the simple network has a 2 x 2 weight matrix and 2 biases). It is built using library functions and an nn.Module, which means 
it can more easily be used in the context of NODEs. This is the method that I have implemented during the training of 
simple NODEs.
It is not quite "automated", since it still requires that you define the function with the correct number of parameter 'groups',
which is individual to each network.
-------------------------------------------------------------------------------------------------------------------------------------
"""

class SimpleNetwork(nn.Module):
  """
  Functional single-layer network built from explicitly supplied tensors:
      y -> tanh(F.linear(y, a, b))

  Named SimpleNetwork rather than Network so that running this cell does not overwrite the
  four-tensor Network class that get_loss_square_2 instantiates. Reusing the name made
  get_loss_square_2 fail with a TypeError once both cells had run.
  """

  def __init__(self, a, b):
    super().__init__()
    self.a = a   # weight, used as F.linear's weight argument
    self.b = b   # bias, used as F.linear's bias argument

  def forward(self, y):
    # torch.tanh is exactly what nn.Tanh applies, without building a new module on every call.
    return torch.tanh(F.linear(y, self.a, self.b))


def get_simple_loss_square(params_vector):
  """
  Squared-norm loss of SimpleNetwork evaluated at a flat parameter vector.

  Intended as the function passed to torch.autograd.functional.hessian. The vector is read as
  [0:4] -> 2 x 2 weight and [4:6] -> bias (2). The loss is ||prediction - target||^2 for the
  single example with input [1, 1] and target [2, 0].

  NOTE(review): this uses tanh and a *squared* norm, whereas Simple_Net uses nn.Sigmoid and the
  manual-Hessian cell for simple_func uses an unsquared norm. Hessians from the two approaches
  will therefore differ for Simple_Net; confirm which function is intended before comparing.
  """
  weight = params_vector[0:4].reshape([2, 2])
  bias = params_vector[4:6].reshape([2])

  neural_net = SimpleNetwork(weight, bias).to(device)
  x_in = torch.tensor([1.0, 1.0]).to(device)
  target = torch.tensor([2.0, 0.0]).to(device)
  pred_y = neural_net(x_in)

  return torch.linalg.norm(pred_y - target) ** 2

def get_simple_library_hessian(net):
  """
  Hessian of get_simple_loss_square w.r.t. all of `net`'s parameters, via torch.autograd.functional.hessian.

  `net`'s parameters are flattened and concatenated (in net.parameters() order) into one vector on
  `device`, and the Hessian is evaluated at that point. get_simple_loss_square reads the first
  6 entries as a 2 x 2 weight followed by 2 biases.
  """
  flat_pieces = [torch.tensor([]).to(device)]
  flat_pieces.extend(torch.reshape(p, (-1,)).to(device) for p in net.parameters())
  params_vector = torch.cat(flat_pieces)
  return torch.autograd.functional.hessian(get_simple_loss_square, params_vector)