In [None]:
%%bash 
pip install torchdiffeq

In [None]:
import torch
from torch import nn
import numpy as np
from torch.nn import Module
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import torch.optim as optim

# Select the ODE solver entry point; either way it is bound to the name `odeint`.
# Per the torchdiffeq documentation, odeint_adjoint back-propagates with the
# adjoint method, whereas plain odeint back-propagates through the solver's
# internal operations.
adjoint = True

# Plain truthiness with if/else (rather than `== True` / `== False` tests) so
# that `odeint` is always defined, whatever value `adjoint` is given.
if adjoint:
  from torchdiffeq import odeint_adjoint as odeint
else:
  from torchdiffeq import odeint

In [None]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Initial state: a single 2-d state stored as a (1, 2) row.
true_y0 = torch.tensor([[2., 0.]], device=device)
# 1000 evenly spaced time points on [0, 25] (generated on the CPU, then moved).
t = torch.linspace(0., 25., 1000).to(device)
# NOTE(review): presumably the ground-truth dynamics matrix that ODEFunc is
# meant to learn (cf. its "learning to represent the matrix" note) — confirm
# against the training cells.
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]], device=device)

In [None]:
class ODEFunc(nn.Module):
    """Small multilayer perceptron used as the right-hand side of the ODE.

    The network maps a 2-vector to a 2-vector through a single hidden layer of
    50 tanh units. Inputs are cubed before entering the network (see ``forward``).
    """

    def __init__(self):
        super().__init__()

        # 2 -> 50 -> 2 network with one tanh hidden layer.
        layers = [nn.Linear(2, 50), nn.Tanh(), nn.Linear(50, 2)]
        self.net = nn.Sequential(*layers)

        # Re-initialise every linear layer: weights ~ N(0, 0.1^2), biases = 0.
        # modules() also yields the Sequential container and the Tanh, hence the filter.
        linear_layers = (layer for layer in self.net.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, y):
        """Apply the network to the element-wise cube of ``y``.

        Per the original author's note, the input is cubed so that the network
        only has to learn to represent the matrix (a ``Lambda`` class, not
        visible here, is referenced for this).

        NOTE(review): torchdiffeq's ``odeint`` calls ``func(t, y)``; this
        single-argument signature matches the direct ``func(input)`` call used
        for the Hessian, but would need a ``t`` argument before passing this
        module to ``odeint`` — confirm intended usage.
        """
        cubed = y ** 3
        return self.net(cubed)


func = ODEFunc()

In [None]:
"""
------------------------------------------------------------------------------------------------------------------------
This is a "manual" approach to calculating the Hessian, which uses nested for loops. It produces a symmetric square 
matrix that is the same as that obtained with the approach given below (that uses library functions).
------------------------------------------------------------------------------------------------------------------------
"""

def get_manual_hessian(grads, parameters, device=None):
  """
  Calculation of the Hessian using nested for loops.

  Every element of the first-order gradients is differentiated again with
  respect to all parameters; each such second derivative becomes one row of
  the Hessian, laid out in flattened parameter order.

  Args:
    grads: first-order gradients of a scalar loss, one tensor per parameter and
      in the same order as ``parameters``, computed with ``create_graph=True``
      so that they can be differentiated again.
    parameters: the tensors the gradients were taken with respect to. Any
      iterable is accepted; it is materialised once, so generators such as
      ``net.parameters()`` work.
    device: optional device for the returned matrix. By default the matrix
      stays on the parameters' device. (Earlier versions always moved it to the
      notebook-level ``device``; pass ``device=device`` to reproduce that.)

  Returns:
    A (P, P) tensor, P = total number of parameter elements (252 for ODEFunc),
    whose entry (i, k) is d^2 loss / (d theta_i d theta_k).

  Raises:
    ValueError: if ``grads`` and ``parameters`` hold different numbers of elements.
  """
  start = time.time()                                   # Begin timer.
  parameters = list(parameters)                         # Re-used for every row, so it must not be a one-shot generator.
  grads = list(grads)

  n_params = sum(p.numel() for p in parameters)         # Side length of the Hessian (previously hard-coded as 252).
  n_grads = sum(g.numel() for g in grads)
  if n_grads != n_params:
    raise ValueError(f'grads hold {n_grads} elements but parameters hold {n_params}')

  # Allocate on the parameters' dtype/device so every row is written without a cross-device copy.
  like = parameters[0] if parameters else torch.zeros(())
  hessian = torch.zeros(n_params, n_params, dtype=like.dtype, device=like.device)

  row = 0                                               # Row index in the Hessian.
  for grad in grads:
    for g in torch.reshape(grad, [-1]):                 # Each gradient element gives one Hessian row.
      if g.requires_grad:
        # One backward pass per row covers all parameters at once. allow_unused=True
        # yields None for parameters this element does not depend on; those blocks are zero.
        second = torch.autograd.grad(g, parameters, retain_graph=True, allow_unused=True)
        hessian[row] = torch.cat([
            (s if s is not None else torch.zeros_like(p)).reshape(-1)
            for s, p in zip(second, parameters)
        ])
      # Otherwise the gradient element is constant in the parameters and its row stays zero.
      row += 1

  if device is not None:
    hessian = hessian.to(device)

  print('Time used is ', time.time() - start)

  return hessian

# One fixed training example: the network input and the value it should produce.
input, target = torch.tensor([1.0, 1.0]), torch.tensor([2.0, 0.0])

# The optimiser is only used here as a handle on the list of trainable
# parameters (func.parameters are the parameters to optimise).
optimizer = optim.RMSprop(params=func.parameters(), lr=1e-3)
optimizer.zero_grad()
parameters = optimizer.param_groups[0]['params']

# Euclidean distance between the prediction and the target.
loss = torch.linalg.norm(func(input) - target)

# First-order gradients, one tensor per parameter; create_graph=True keeps them
# differentiable so each element can be differentiated again for the Hessian.
grads = torch.autograd.grad(outputs=loss, inputs=parameters, create_graph=True)

hessian = get_manual_hessian(grads=grads, parameters=parameters)





In [None]:
"""
------------------------------------------------------------------------------------------------------------------------
This approach works, and produces a 252 x 252 matrix. It is built using library functions and, whilst
there is no guarantee that it gives the correct answer, can be used to check the more "manual" approach used above.
------------------------------------------------------------------------------------------------------------------------
"""

def get_loss_square(params_vector):
  """
  Loss of the 2 -> 50 -> 2 tanh network as a pure function of a flat parameter vector.

  ``params_vector`` holds the 252 network parameters in ``net.parameters()``
  order: first-layer weight (50x2), first-layer bias (50), second-layer weight
  (2x50), second-layer bias (2). The fixed input (1, 1) is cubed, as in
  ``ODEFunc.forward``, and the Euclidean distance to the target (2, 0) is returned.
  """
  # Unpack the flat vector into the individual weight/bias tensors.
  w1 = params_vector[0:100].reshape([50, 2])
  b1 = params_vector[100:150].reshape([50])
  w2 = params_vector[150:250].reshape([2, 50])
  b2 = params_vector[250:252].reshape([2])

  sample = torch.tensor([1.0, 1.0])
  target = torch.tensor([2.0, 0.0])

  hidden = torch.tanh(F.linear(sample ** 3, w1, b1))
  prediction = F.linear(hidden, w2, b2)
  return torch.linalg.norm(target - prediction)

def get_hessian(net, loss_fn=None):
  """
  Hessian of a scalar loss with respect to all of ``net``'s parameters, computed
  with torch.autograd.functional.hessian.

  Args:
    net: module whose parameters define the point of evaluation. They are
      flattened, in ``net.parameters()`` order, into a single vector.
    loss_fn: callable mapping that flat vector to a scalar loss. Defaults to
      ``get_loss_square``, which assumes the 252-parameter ODEFunc layout.

  Returns:
    A (P, P) tensor, P = total number of parameter elements.
  """
  start = time.time()
  if loss_fn is None:
    loss_fn = get_loss_square          # Looked up at call time, so a redefined cell is picked up.

  # Flatten all parameter tensors with one concatenation instead of growing the
  # vector tensor-by-tensor; an empty module gives an empty vector.
  flat = [torch.reshape(param, (-1,)) for param in net.parameters()]
  params_vector = torch.cat(flat) if flat else torch.tensor([])

  hessian = torch.autograd.functional.hessian(loss_fn, params_vector)
  print('Time used is ', time.time() - start)
  return hessian

hessian = get_hessian(func)

# The Hessian is symmetric, so use the symmetric eigen-solver. torch.symeig is
# deprecated and removed in recent PyTorch releases; torch.linalg.eigh replaces it.
# UPLO='U' reads the upper triangle, matching symeig's default (upper=True).
# Unlike symeig's default (eigenvectors=False), `v` now holds the eigenvectors as columns.
eigenvalues, v = torch.linalg.eigh(hessian, UPLO='U')

fig, ax = plt.subplots()
n, bins, patches = ax.hist(eigenvalues.detach().cpu().numpy(), bins=150)
ax.set(title='Eigenvalue spectrum of the loss Hessian', xlabel='eigenvalue', ylabel='count');