In [None]:
%%capture
"""
-----------------------------------------------------------------------------------------------------------------------------------------------------------
This code can be used to calculate a numerical approximation to the Hessian using the method of finite differences (MOFD). This allows comparison with other 
techniques developed thus far. It does so using 64-bit floating point arithmetic.
In contrast to the document used to draw detailed comparisons between the MOFD and the library-Hessian approach, this file computes the library-Hessian 
using 32-bit arithmetic. This allows it to be easily combined with the optimization process.
The code used in the function get_mofd_hessian() automatically converts all floats to 64-bit. This means that the code is slightly altered in comparison
to that used in the document mofd_hessian_float_64.ipynb.
-----------------------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
%%capture
%%bash 
pip install torchdiffeq

In [None]:
"""
-----------------------------------------------------------------------------------------------------------------------------------------------------------
System set up.
-----------------------------------------------------------------------------------------------------------------------------------------------------------
"""
import argparse
import copy
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def _build_parser():
    """Create the argument parser that declares every tunable of this notebook."""
    p = argparse.ArgumentParser()
    # Solver and data/batching options.
    p.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='dopri5')
    p.add_argument('--data_size', type=int, default=1000)
    p.add_argument('--batch_time', type=int, default=10)
    p.add_argument('--batch_size', type=int, default=20)
    p.add_argument('--niters', type=int, default=100)
    p.add_argument('--test_freq', type=int, default=20)
    # Boolean switches and hardware selection.
    p.add_argument('--viz', action='store_true')
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--adjoint', action='store_true')
    # Hessian-related switches.
    p.add_argument('--manual_hessian', action='store_true')
    p.add_argument('--library_hessian', action='store_true')
    p.add_argument('--mofd_hessian', action='store_true')
    p.add_argument('--hessian_freq', type=int, default=20)
    return p


parser = _build_parser()
# An explicit empty argument list is parsed, so only the declared defaults apply.
args = parser.parse_args(args=[])

# Notebook-specific overrides of the defaults declared above.
vars(args).update(
    batch_size=100,
    batch_time=20,
    niters=60,
    test_freq=100,
    library_hessian=True,
    viz=True,
    hessian_freq=100,
    method='dopri5',
)

# Switch between the adjoint and the direct-backpropagation ODE solver.
adjoint = False

if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

# Use the GPU selected by args.gpu when CUDA is available, otherwise the CPU.
device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

# Initial state and evaluation times; both are 32-bit floats.
true_y0 = torch.tensor([2.]).to(device)
t_0, t_1 = 0., 2.
t = torch.linspace(t_0, t_1, args.data_size).to(device)

In [None]:
"""
-----------------------------------------------------------------------------------------------------------------------------------------------------------
Obtain information about the true solution to the equation of motion.
-----------------------------------------------------------------------------------------------------------------------------------------------------------
"""
class Lambda(nn.Module):
    """Right-hand side of the reference equation of motion, dy/dt = exp(t)."""

    def forward(self, t, y):
        # The state y is part of the solver's call signature but is not used here.
        return t.exp()

#True solution defines an exponential: Lambda integrates dy/dt = exp(t) from true_y0 over the times in t.
#No gradients are needed for the reference trajectory, hence torch.no_grad().
with torch.no_grad():
    #32-bit float (true_y0 and t are created as 32-bit tensors).
    true_y = odeint(Lambda(), true_y0, t, method=args.method)

def get_batch():
    """
    Draws a random mini-batch of trajectory segments from the true solution.

    args.batch_size distinct start indices are sampled (without replacement) from
    range(args.data_size - args.batch_time), and for every start index the next
    args.batch_time consecutive samples of true_y are collected.

    Returns (batch_y0, batch_t, batch_y), all moved to `device`:
        - batch_y0: true_y at the sampled start indices, (M, D).
        - batch_t: the first args.batch_time entries of t, (T).
        - batch_y: the stacked segments, (T, M, D).
    """
    candidate_starts = np.arange(args.data_size - args.batch_time, dtype=np.int64)
    s = torch.from_numpy(np.random.choice(candidate_starts, args.batch_size, replace=False))

    batch_y0 = true_y[s]
    batch_t = t[:args.batch_time]

    # One slice of the true solution per time offset, stacked along a new leading axis.
    segments = [true_y[s + offset] for offset in range(args.batch_time)]
    batch_y = torch.stack(segments, dim=0)

    return batch_y0.to(device), batch_t.to(device), batch_y.to(device)

In [None]:
class ODEFunc(nn.Module):
    """
    Neural net that parametrizes the derivative of the hidden state in the NODE.

    The architecture is Linear(1, 1) -> Tanh -> Linear(1, 1). Linear weights are
    drawn from a normal distribution (mean 0, std 0.1) and biases are set to zero.
    """

    def __init__(self):
        super().__init__()

        layers = [nn.Linear(1, 1), nn.Tanh(), nn.Linear(1, 1)]
        self.net = nn.Sequential(*layers)

        # Re-initialise only the linear layers; Tanh has nothing to initialise.
        linear_layers = (layer for layer in self.net.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        # The time t is accepted for the solver's call signature but is not used.
        return self.net(y)

#Network whose Hessian is studied below, moved to the selected device.
#Parameters are 32-bit floats.
test_net = ODEFunc().to(device)

In [None]:
class Network(nn.Module):
  """
  Functional counterpart of ODEFunc() used for the library-function and MOFD Hessian calculations.
  It evaluates c * tanh(a * y + b) + d as two linear maps, with (a, b) and (c, d) playing the
  roles of the weight and bias of the first and second linear layer of ODEFunc().

  NB: a, b, c and d are stored as plain attributes exactly as they are passed in, so an instance
  built from ordinary tensors does not possess any parameter attributes.
  """

  def __init__(self, a, b, c, d):
    super().__init__()
    self.a, self.b, self.c, self.d = a, b, c, d

  def forward(self, t, y):
    #The time t is accepted for the solver's call signature but is not used.
    hidden = torch.tanh(F.linear(y, self.a, self.b))
    return F.linear(hidden, self.c, self.d)

In [None]:
def get_loss(params_vector):
  """
  Obtains the loss according to the parameters of the NODE.

  Inputs:
        - params_vector: flat tensor holding the four scalar parameters in the order
          first weight, first bias, second weight, second bias.
  Returns: the mean absolute difference between the predicted trajectory and true_y.

  NB: Each individual NODE architecture must be specified here, along with a loss function.
  """

  #Unpack the flat vector into the shapes expected by Network().
  a = params_vector[:1].reshape([1, 1])
  b = params_vector[1:2].reshape([1])
  c = params_vector[2:3].reshape([1, 1])
  d = params_vector[3:4].reshape([1])

  #Solve the NODE from true_y0 over the times in t and compare with the true solution.
  neural_net = Network(a, b, c, d).to(device)
  pred_y = odeint(neural_net, true_y0, t, method=args.method)
  loss = torch.mean(torch.abs(pred_y - true_y))
  return loss

def get_library_hessian(net):
  """
  Computes the Hessian of get_loss() with respect to the flattened parameters of a network,
  using torch.autograd.functional.hessian().
  Inputs:
        - net: the network whose parameters define the point at which the Hessian is evaluated.
  Returns: the Hessian as returned by torch.autograd.functional.hessian().
  NB: get_loss() must describe the same NODE architecture as net, otherwise the
  Hessian is not calculated correctly.
  """

  #Flatten every parameter group and join them, in net.parameters() order, behind an empty seed tensor.
  pieces = [torch.tensor([]).to(device)]
  pieces.extend(torch.reshape(group, (-1,)).to(device) for group in net.parameters())
  params_vector = torch.cat(pieces)

  return torch.autograd.functional.hessian(get_loss, params_vector)

In [None]:
def get_mofd_hessian(p_vec, shapes, base_loss, h=1e-4, show_iters=False):
  """
  Calculates the full Hessian of the loss using the method of finite differences (MOFD).

  Diagonal elements use the central second difference
      (L(p + h*e_i) - 2*L(p) + L(p - h*e_i)) / h**2
  and off-diagonal elements use the four-point formula
      (L(+h,+h) - L(+h,-h) - L(-h,+h) + L(-h,-h)) / (4*h**2),
  where L is the mean absolute difference between the NODE solution and true_y.
  Only the upper triangle is evaluated; the lower triangle is filled by symmetry.

  Inputs: - p_vec: the parameters of the network organized into a vector.
          - shapes: a list of torch.Size() objects describing the shapes of each parameter group.
          - base_loss: loss of the unperturbed system. Used in calculating diagonal Hessian elements.
          - h: the size of the perturbation applied to each parameter.
          - show_iters: True or False according to whether the iteration number is to be displayed during calculation.
  Returns: a (len(p_vec), len(p_vec)) float64 tensor created with torch.zeros() (not moved to `device`).

  NB: The flat vector is split according to `shapes` and the groups are passed positionally to
  Network(), so the number of groups must match the arguments that Network() accepts.

  This code is designed to convert all floats to 64-bit automatically.
  All evaluations run under torch.no_grad(): finite differences need no autograd graph, and
  p_vec is typically built from parameters that require grad.
  """
  #Start/stop offsets of each parameter group inside the flat parameter vector.
  bounds = []
  start = 0
  for shape in shapes:
    stop = start + torch.Size(shape).numel()
    bounds.append((start, stop))
    start = stop

  w = len(p_vec)
  hessian = torch.zeros((w, w), dtype=torch.float64)

  #The number of Hessian entries required to compute a triangular block.
  counter = w*(w+1)//2

  with torch.no_grad():
    #64-bit copies of everything that enters the loss.
    base_vec = p_vec.detach().double()
    y0 = true_y0.double()
    times = t.double()
    target = true_y.double()

    def perturbed_loss(*shifts):
      """Loss obtained after adding every (index, delta) shift to a fresh copy of the parameters."""
      vec = base_vec.clone()
      for index, delta in shifts:
        vec[index] += delta
      groups = [vec[lo:hi].reshape(shape) for (lo, hi), shape in zip(bounds, shapes)]
      neural_net = Network(*groups).to(device)
      pred_y = odeint(neural_net, y0, times, method=args.method)
      return torch.mean(torch.abs(pred_y - target)).double()

    for i in range(w):

      #Calculate the diagonal element.
      pert_loss_up = perturbed_loss((i, h))
      pert_loss_low = perturbed_loss((i, -h))
      hessian[i,i] = ((pert_loss_up - 2*base_loss + pert_loss_low)/(h**2)).double()

      if show_iters:
        counter -= 1
        print('\r' + str(counter) + ' iterations remaining.', end = '')

      #Calculate the off-diagonal elements of row i (upper triangle only).
      for k in range(i + 1, w):
        pert_loss_up_up = perturbed_loss((i, h), (k, h))
        pert_loss_up_low = perturbed_loss((i, h), (k, -h))
        pert_loss_low_up = perturbed_loss((i, -h), (k, h))
        pert_loss_low_low = perturbed_loss((i, -h), (k, -h))

        #MOFD formula to estimate second order gradient.
        grad2 = ((pert_loss_up_up - pert_loss_up_low - pert_loss_low_up + pert_loss_low_low)/(4*h**2)).double()
        hessian[i,k] = grad2
        hessian[k,i] = grad2

        if show_iters:
          counter -= 1
          print('\r' + str(counter) + ' iterations remaining.', end='')

  return hessian

In [None]:
#Evaluate the Hessian of get_loss() at the current parameters of test_net via autograd.
lib_hess = get_library_hessian(test_net)
print('Library Hessian is:')
print(lib_hess)

Library Hessian is:
tensor([[-3.3839e-02, -1.6081e-02, -1.9672e+00,  2.3338e-02],
        [-1.6081e-02, -7.6203e-03, -9.8616e-01, -2.9410e-04],
        [-1.9672e+00, -9.8616e-01, -7.4903e-03, -3.5142e-02],
        [ 2.3336e-02, -2.9491e-04, -3.5140e-02,  6.3032e-04]], device='cuda:0')


In [None]:
#Model set up and loss function.
double_net = test_net.double()
optimizer = optim.RMSprop(double_net.parameters(), lr=1e-3)
optimizer.zero_grad()
pred_y = odeint(double_net, true_y0.double(), t.double(), method=args.method).double()
base_loss = torch.mean(torch.abs(pred_y-true_y))    #64-bit float.

#Create vector of parameters.
param_tensors = double_net.parameters()
params_vec = torch.tensor([]).to(device)
for param in param_tensors:
  vec = torch.reshape(param, (-1,)).to(device)
  params_vec = torch.cat((params_vec, vec)) 

#Example of how to prepare shapes for input into MOFD hessian function.
shapes = []
for param in double_net.parameters():
  shapes.append(param.shape)

mofd_hess = get_mofd_hessian(params_vec, shapes, base_loss, show_iters=True)
print('')
print('-----------------------------------------------------')
print('MOFD Hessian: ')
print(mofd_hess)
print('-----------------------------------------------------')
difference = torch.sum(torch.abs(mofd_hess.to(device)-lib_hess.to(device)))
print('Difference is ' + str(difference.item()))

In [None]:
#Report the dtype of every parameter tensor of test_net.
for param in test_net.parameters():
  print(param.dtype)

torch.float32
torch.float32
torch.float32
torch.float32


In [None]:
#Experiment: convert test_net to 32-bit and then to 64-bit, keeping the returned module in double_net.
double_net = test_net.to(torch.float32)
double_net = test_net.to(torch.float64)


#The dtypes printed here are those of test_net, not double_net; the recorded output shows float64,
#i.e. calling .to() changed test_net itself.
for param in test_net.parameters():
  print(param.dtype)

#Convert back to 32-bit so that test_net is left with 32-bit parameters.
double_net = test_net.to(torch.float32)


#Confirms that test_net is 32-bit again (see the recorded output).
for param in test_net.parameters():
  print(param.dtype)




torch.float64
torch.float64
torch.float64
torch.float64
torch.float32
torch.float32
torch.float32
torch.float32
