In [None]:
"""
---------------------------------------------------------------------------------------------------------------------------------------------------------
This notebook assesses the performance of a "manual" (nested-loop) Hessian calculation for a loss that contains the odeint() function from torchdiffeq.
---------------------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
%%bash 
pip install torchdiffeq

In [None]:
#Libraries
import torch
from torch import nn
import numpy as np
from torch.nn import Module
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import torch.optim as optim

device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')   #First GPU when available, otherwise CPU.

In [None]:
#Problem set up: ground-truth initial state, time grid and dynamics matrix (all created directly on `device`).

true_y0 = torch.tensor([[2., 0.]], device=device)                     #Shape (1, 2): a single 2D initial state.
t = torch.linspace(0., 25., 1000, device=device)                      #1000 evaluation times in [0, 25].
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]], device=device)     #Matrix used by the ground-truth dynamics (class Lambda).

#Set to True to use torchdiffeq's odeint_adjoint in place of odeint; either way the solver is bound to the name `odeint`.
adjoint = False

if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

class Lambda(nn.Module):
    """Ground-truth dynamics dy/dt = (y**3) @ true_A, used to generate the reference trajectory true_y."""

    def forward(self, t, y):
        # t is unused (the dynamics do not depend on time), but odeint calls func(t, y).
        cubed = y.pow(3)
        return torch.mm(cubed, true_A)

#Ground-truth trajectory: odeint returns one state per entry of t, stacked along dim 0, i.e. an array of
#shape [1000, 1, 2] here (the evolution of y at 1000 points through time). No gradients are needed for it.
#It's not obvious, but this true_y solution defines a spiral in the x-y plane (I verified this computationally).
with torch.no_grad():
    ground_truth_dynamics = Lambda()
    true_y = odeint(ground_truth_dynamics, true_y0, t, method='dopri5')

In [None]:
batch_time = 10                 #Number of consecutive time points in each sampled sequence (T).
batch_size = 20                 #Number of sequences per batch (M).
data_size = true_y.shape[0]     #Length of the ground-truth trajectory; derived from true_y so it stays in sync with t.

def get_batch():
    """
    Sample a random mini-batch of short ground-truth trajectories from `true_y`.

    Returns:
      batch_y0: starting states true_y[s], shape (M, 1, D) for true_y of shape (N, 1, D).
      batch_t:  the first T = batch_time entries of t, shape (T,), shared by every sequence.
      batch_y:  true_y[s + i] for i = 0..T-1 stacked along dim 0, shape (T, M, 1, D).
                batch_y is the ground truth when optimising the neural net.
    All three are moved to `device`.
    """
    #Draw batch_size distinct starting indices from range(data_size - batch_time), so that every
    #sequence s, s+1, ..., s+batch_time-1 stays inside the trajectory.
    s = torch.from_numpy(np.random.choice(np.arange(data_size - batch_time, dtype=np.int64), batch_size, replace=False))

    batch_y0 = true_y[s]      # (M, 1, D)
    batch_t = t[:batch_time]  # (T,)
    batch_y = torch.stack([true_y[s + i] for i in range(batch_time)], dim=0)  # (T, M, 1, D)
    return batch_y0.to(device), batch_t.to(device), batch_y.to(device)

In [None]:
class ODEFunc(nn.Module):
    """Learned right-hand side f(t, y) passed to odeint.

    A 2 -> 50 -> 2 MLP with one tanh hidden layer, applied to y**3 so that (like the
    ground-truth dynamics in class Lambda) the network only has to learn the map that
    acts on the cubed state.
    """

    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )

        # Small Gaussian weights and zero biases for every linear layer. Iterating the
        # Sequential visits the layers in the same order as self.net.modules() does.
        linear_layers = [layer for layer in self.net if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.zeros_(layer.bias)

    def forward(self, t, y):
        """Return the predicted dy/dt; t is unused, but odeint calls func(t, y)."""
        cubed = y.pow(3)
        return self.net(cubed)

In [None]:
def get_manual_hessian(grads, parameters):
  """
  Calculation of the Hessian using nested for loops.

  Every element of the first-order gradients is differentiated again with respect to all
  parameters (one backward pass per Hessian row).

  Inputs:
    grads:      sequence of gradient tensors, one per entry of `parameters` and in the same order.
                They must carry a graph, e.g. grads = torch.autograd.grad(loss, parameters, create_graph=True),
                or the param.grad fields after loss.backward(create_graph=True). A None entry
                (parameter not used by the loss) contributes rows of zeros.
    parameters: list of parameter objects. Created using something like parameters = optimizer.param_groups[0]['params'].

  Returns:
    Tensor of shape (n_params, n_params), n_params being the total number of parameter elements.
    Entry [i, j] is the derivative of flattened gradient element i with respect to flattened
    parameter element j; rows and columns follow the order of `parameters`. It is created on the
    device and with the dtype of the first non-None gradient (or, failing that, parameter).

  Raises:
    ValueError: if `grads` and `parameters` have different lengths.
  """
  start = time.perf_counter()                                           #Begin timer.

  parameters = list(parameters)
  grads = tuple(grads)
  if len(grads) != len(parameters):
    raise ValueError("grads and parameters must have the same length, got %d and %d." % (len(grads), len(parameters)))

  n_params = sum(torch.numel(param) for param in parameters)

  #Allocate the Hessian once, directly on the inputs' device/dtype, so no per-row device transfer is needed.
  like = next((x for x in grads + tuple(parameters) if x is not None), torch.empty(0))
  hessian = torch.zeros(n_params, n_params, dtype=like.dtype, device=like.device)

  row = 0                                                               #Row number in the Hessian.
  for grad, param in zip(grads, parameters):
    if grad is None:
      row += torch.numel(param)                                         #Parameter unused by the loss: its rows stay zero.
    else:
      for g in torch.reshape(grad, [-1]):                               #Walk the gradient as a flat vector.
        if g.requires_grad:
          #One backward pass w.r.t. all parameters at once; allow_unused=True yields None (instead of
          #raising) for parameters this gradient element does not depend on.
          row_grads = torch.autograd.grad(g, parameters, retain_graph=True, allow_unused=True)
        else:
          row_grads = (None,) * len(parameters)                         #Constant element (no graph): its row is zero.

        col = 0                                                         #Column number in the Hessian.
        for p, g2 in zip(parameters, row_grads):
          size = torch.numel(p)
          if g2 is not None:
            hessian[row, col:col + size] = torch.reshape(g2, [-1])      #Place this block of second derivatives.
          col += size
        row += 1
    print("Gradients calculated for row number " + str(row) + ".")

  print('Time used is ', time.perf_counter() - start)

  return hessian

In [None]:
#Seed both RNGs used below (torch for the weight initialisation, numpy inside get_batch)
#so the benchmark is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

func = ODEFunc().to(device)
batch_y0, batch_t, batch_y = get_batch()

optimizer = optim.RMSprop(func.parameters(), lr=1e-3)   #func.parameters() are the parameters to optimise.
optimizer.zero_grad()
pred_y = odeint(func, batch_y0, batch_t)                #Already on `device`: func, batch_y0 and batch_t all live there.

parameters = optimizer.param_groups[0]['params']
loss = torch.mean(torch.abs(pred_y - batch_y))

#create_graph=True keeps the graph of the first derivatives (stored in each param.grad)
#so that get_manual_hessian can differentiate them a second time.
loss.backward(create_graph=True)

#Collect the first-order gradients in the same order as `parameters`, without hard-coding
#how many parameter tensors the model has.
grads = tuple(param.grad for param in parameters)

hessian = get_manual_hessian(grads, parameters)
optimizer.step()

#backward(create_graph=True) creates a reference cycle between each parameter and its .grad;
#PyTorch recommends resetting .grad to None after use to break it.
for param in parameters:
    param.grad = None