In [None]:
%%capture
"""
-----------------------------------------------------------------------------------------------------------------------------------------------------------
This notebook computes the Hessian of the loss during training of a small neural ODE that is fitted to a simple 1D exponential ODE.
The Hessian can be obtained either with a "manual" approach (nested autograd calls) or with a library-based approach.
-----------------------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
%%capture
%%bash 
pip install torchdiffeq

In [None]:
import os
import argparse
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Experiment configuration.
# The options are declared argparse-style, but a notebook has no real command
# line, so the parser is evaluated against an empty argument list and the
# values used for this run are applied explicitly afterwards.
ARGUMENT_SPECS = [
    ('--method', dict(type=str, choices=['dopri5', 'adams'], default='dopri5')),
    ('--data_size', dict(type=int, default=1000)),
    ('--batch_time', dict(type=int, default=10)),
    ('--batch_size', dict(type=int, default=20)),
    ('--niters', dict(type=int, default=100)),
    ('--test_freq', dict(type=int, default=20)),
    ('--viz', dict(action='store_true')),
    ('--gpu', dict(type=int, default=0)),
    ('--adjoint', dict(action='store_true')),
    ('--manual_hessian', dict(action='store_true')),
    ('--library_hessian', dict(action='store_true')),
    ('--hessian_freq', dict(type=int, default=20)),
]

parser = argparse.ArgumentParser()
for flag, options in ARGUMENT_SPECS:
    parser.add_argument(flag, **options)
args = parser.parse_args(args=[])

# Settings for this particular run; these replace the parser defaults above.
RUN_OVERRIDES = {
    'batch_size': 750,
    'batch_time': 50,
    'niters': 6000,
    'test_freq': 100,
    'library_hessian': True,
    'manual_hessian': False,
    'viz': True,
    'hessian_freq': 100,
    'method': 'dopri5',
}
for option_name, option_value in RUN_OVERRIDES.items():
    setattr(args, option_name, option_value)

In [None]:
#The technique only works when the adjoint method is not used. If it is used, the Hessian returned is a matrix of zeros.
#The choice follows the parsed --adjoint option (False by default) rather than a
#separate hard-coded flag, so the option is no longer silently ignored.
adjoint = args.adjoint

if adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

#Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

#Initial condition and time grid of the reference trajectory.
true_y0 = torch.tensor([2.]).to(device)
t_0, t_1 = 0., 2.
t = torch.linspace(t_0, t_1, args.data_size).to(device)

class Lambda(nn.Module):
    """Right-hand side of the reference ODE: dy/dt = exp(t)."""

    def forward(self, t, y):
        # The derivative depends on time only; the state y is ignored.
        derivative = torch.exp(t)
        return derivative

#Reference trajectory: integrate dy/dt = exp(t) from true_y0 over the time grid t.
#With y(0) = 2 the analytic solution is y(t) = exp(t) + 1.
#The target data needs no gradients, hence torch.no_grad().
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method = args.method)

def get_batch():
    """Draw a random mini-batch of short trajectory segments from true_y.

    args.batch_size distinct start indices are sampled (without replacement)
    from [0, args.data_size - args.batch_time), so every segment of
    args.batch_time consecutive points fits inside true_y.

    Returns a tuple (batch_y0, batch_t, batch_y), all moved to `device`:
      batch_y0: true_y at the sampled start indices, (M, D).
      batch_t:  the first args.batch_time entries of t, (T). The same time
                window is used for every segment regardless of where it starts.
      batch_y:  the args.batch_time consecutive points following each start
                index, stacked along a new leading axis, (T, M, D).
    """
    candidate_starts = np.arange(args.data_size - args.batch_time, dtype=np.int64)
    chosen = np.random.choice(candidate_starts, args.batch_size, replace=False)
    s = torch.from_numpy(chosen)

    batch_y0 = true_y[s]
    batch_t = t[:args.batch_time]
    segments = [true_y[s + offset] for offset in range(args.batch_time)]
    batch_y = torch.stack(segments, dim=0)

    return batch_y0.to(device), batch_t.to(device), batch_y.to(device)

In [None]:
%%capture
def makedirs(dirname):
    """Create directory `dirname` (and any missing parents) if it does not exist yet.

    exist_ok=True replaces the former "check, then create" sequence, which could
    fail if the directory appeared between the check and the creation.
    Raises FileExistsError if `dirname` exists but is not a directory.
    """
    os.makedirs(dirname, exist_ok=True)
        
if args.viz:
    # Output directory matching the 'png/...' path of the (currently disabled)
    # plt.savefig call in visualize(). pyplot is already imported at the top of
    # the notebook and visualize() opens its own figure for every plot, so no
    # figure is created here (the previous one was never drawn on or closed).
    makedirs('png')

In [None]:
def visualize(true_y, pred_y, odefunc, itr):
  """
  Plot the true trajectory and the predicted trajectory against the global time grid t.

  A new figure is opened for every call and closed again at the end, rather than
  re-using a single figure across calls. Nothing happens unless args.viz is set.

  Inputs:
    true_y:   tensor of reference values, plotted as a solid green line.
    pred_y:   tensor of predicted values, plotted as a dashed blue line.
    odefunc:  not used by the function body.
    itr:      only referenced by the disabled savefig call.
  """
  if not args.viz:
    return

  time_axis = t.cpu().numpy()
  reference = true_y.cpu().numpy()
  prediction = pred_y.cpu().detach().numpy()

  plt.figure(figsize=(12, 4), facecolor='white')  #facecolor is the background colour.
  plt.plot(time_axis, reference, 'g-', label='True_y')
  plt.plot(time_axis, prediction, 'b--', label='Predicted y')
  plt.xlabel('t')
  plt.ylabel('y')
  plt.legend()

  #plt.savefig('png/{:03d}'.format(itr))
  plt.draw()
  plt.pause(0.001)
  plt.close()

In [None]:
class ODEFunc(nn.Module):
    """
    Minimal neural network used as the learnable right-hand side dy/dt = net(y).

    Architecture: Linear(1, 1) -> Tanh -> Linear(1, 1), i.e. 4 scalar parameters
    (2 weights and 2 biases). Weights are drawn from N(0, 0.1^2); biases start at 0.
    """

    def __init__(self):
        super().__init__()

        hidden = nn.Linear(1, 1)
        readout = nn.Linear(1, 1)
        self.net = nn.Sequential(hidden, nn.Tanh(), readout)

        # Re-initialise the two linear layers, in network order.
        for layer in (hidden, readout):
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        # The time argument t is not used: the network depends on the state only.
        return self.net(y)

#Instantiate the model on the selected device. A fresh instance is created again in the training cell.
func = ODEFunc().to(device)

In [None]:
"""
------------------------------------------------------------------------------------------------------------------------
This code can be used to compare the computationally-obtained gradient values to those obtained analytically.
See journal entry from 26/02 for more details.
------------------------------------------------------------------------------------------------------------------------
"""

# Sanity check: squared error between the predicted and the true end point of the
# trajectory, followed by the gradient of that error w.r.t. each model parameter.

# The optimizer is only needed here to clear any previously accumulated gradients.
optimizer = optim.RMSprop(func.parameters(), lr=1e-3) #func.parameters are the parameters to optimise.
optimizer.zero_grad()

# Value at the final time point: model prediction vs. reference.
output = odeint(func, true_y0, t)[-1]
target = true_y[-1]
print(output.item())
print(target.item())

loss = torch.linalg.norm(output-target)**2
print('Loss is: ' + str(loss.item()))

loss.backward()

print('-------------------\nGradients are:')
for param in func.parameters():
  # Every parameter of this model holds exactly one element, so .item() is valid.
  print(param.grad.item())

In [None]:
class Network(nn.Module):
  """
  Linear -> Tanh -> Linear network whose weights and biases are passed in explicitly.

  Inputs:
    a, b: weight and bias of the first linear layer.
    c, d: weight and bias of the second linear layer.
  """

  def __init__(self, a, b, c, d):
    super().__init__()
    self.a, self.b, self.c, self.d = a, b, c, d

  def forward(self, t, y):
    # t is accepted for the odeint interface but does not enter the computation.
    hidden = torch.tanh(F.linear(y, self.a, self.b))
    return F.linear(hidden, self.c, self.d)


def get_loss_square(params_vector):
  """
  Loss over the full trajectory as a function of a flat parameter vector.

  The first four entries of params_vector are interpreted, in order, as the
  first-layer weight, first-layer bias, second-layer weight and second-layer bias
  of a Network. The ODE is solved from true_y0 over the whole time grid t.

  Note: despite the name, the value returned is the mean ABSOLUTE error between
  the predicted and the true trajectory, not a squared error.
  """
  first, second, third, fourth = (params_vector[i:i + 1] for i in range(4))

  neural_net = Network(
      first.reshape([1, 1]),
      second.reshape([1]),
      third.reshape([1, 1]),
      fourth.reshape([1]),
  ).to(device)

  pred_y = odeint(neural_net, true_y0, t, method= args.method)
  return torch.mean(torch.abs(pred_y - true_y))

def get_library_hessian(net):
  """
  Hessian of get_loss_square with respect to the parameters of `net`.

  All parameters of `net` are flattened and concatenated, in the order given by
  net.parameters(), into one vector, which is handed to
  torch.autograd.functional.hessian together with get_loss_square.
  get_loss_square only reads the first four entries of that vector.
  """
  flattened = [torch.reshape(param, (-1,)).to(device) for param in net.parameters()]
  # Start from an empty tensor on the target device, then append every flattened parameter.
  empty_start = torch.tensor([]).to(device)
  params_vector = torch.cat([empty_start, *flattened])

  return torch.autograd.functional.hessian(get_loss_square, params_vector)

In [None]:
def get_manual_hessian(grads, parameters):
  """
  Calculation of the Hessian using nested for loops.

  Each scalar entry of the first-order gradients is differentiated once more with
  respect to every parameter tensor; the results fill one row of the Hessian.

  Inputs:
    grads:      iterable of gradient tensors, one per parameter tensor and in the same order.
                Must carry a graph, e.g. grads = torch.autograd.grad(loss, parameters, create_graph=True).
    parameters: iterable of parameter tensors, e.g. optimizer.param_groups[0]['params'].

  Returns:
    Square tensor with one row/column per scalar parameter, allocated with the dtype
    and device of the gradients (default dtype on the CPU if grads is empty).
  """
  start = time.time()                       #Begin timer.

  #Both arguments are iterated more than once below. A generator (such as net.parameters())
  #would be exhausted after the first pass and silently give an all-zero result.
  grads = list(grads)
  parameters = list(parameters)

  n_params = sum(torch.numel(param) for param in parameters)

  #Allocate once, directly where the gradients live, instead of on the CPU with a later move.
  if grads:
    hessian = torch.zeros(n_params, n_params, dtype=grads[0].dtype, device=grads[0].device)
  else:
    hessian = torch.zeros(n_params, n_params)

  row_start = 0                             #Hessian row at which the current gradient tensor starts.

  for grad in grads:
    flat_grad = torch.reshape(grad, [-1])   #Rearrange the gradient information into a vector.

    for row_offset, g in enumerate(flat_grad):
      col_start = 0                         #Hessian column at which the current parameter tensor starts.

      for param in parameters:
        #Gradient of one element of the gradient wrt one parameter tensor.
        #allow_unused: if g does not depend on param the derivative is zero, not an error.
        g2 = torch.autograd.grad(g, param, retain_graph=True, allow_unused=True)[0]
        if g2 is None:
          g2 = torch.zeros_like(param)
        g2 = torch.reshape(g2, [-1])
        width = g2.shape[0]
        #Indexing places the second-order derivatives in the matching row and columns.
        hessian[row_start + row_offset, col_start:col_start + width] = g2
        col_start += width

    row_start += flat_grad.shape[0]
    print("Gradients calculated for row number " + str(row_start) + ".")

  print('Time used was ', time.time() - start)

  return hessian

In [None]:
if __name__ == '__main__':
    # Executes the programme. This includes doing the following:
    #
    #   - Trains the network on random mini-batches of trajectory segments;
    #   - Periodically evaluates and plots the fit on the full trajectory (if args.viz is set);
    #   - Periodically stores hessian matrix information in list form.

    ii = 0    #Number of evaluation rounds so far; passed to visualize().

    func = ODEFunc().to(device)

    optimizer = optim.RMSprop(func.parameters(), lr=1e-3) #func.parameters are the parameters to optimise.

    #Lists in which to store hessian data.
    #The hessian lists hold tuples like (iteration number, time, loss, hessian data);
    #loss_data holds tuples like (iteration number, loss).
    manual_hessian_data = []
    library_hessian_data = []
    loss_data = []

    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        batch_y0, batch_t, batch_y = get_batch()
        pred_y = odeint(func, batch_y0, batch_t, method=args.method)
        loss = torch.mean(torch.abs(pred_y - batch_y))
        #Plain first-order backward pass. Both hessian routes below run their own
        #forward pass, so building a second-order graph here (create_graph=True) only
        #cost memory and kept the graph alive through the parameters' .grad fields.
        loss.backward()

        if itr % args.hessian_freq == 0 or itr == 1:
            if args.library_hessian:
                print('Obtaining library hessian...')
                library_start = time.time()
                library_hessian = get_library_hessian(func)                       #get hessian with library functions
                library_end = time.time()
                print("Time taken for library-based approach was " + str(round(library_end-library_start,2)) + "s.")
                #The loss stored here is the mini-batch loss of this iteration.
                library_hessian_data.append((itr, library_end-library_start, loss.item(), library_hessian))

            if args.manual_hessian:
                print('Obtaining manual hessian...')
                manual_start = time.time()

                #Loss over the full trajectory, with a graph that supports second derivatives.
                pred_y = odeint(func, true_y0, t, method=args.method)
                loss = torch.mean(torch.abs(pred_y - true_y))
                grads = torch.autograd.grad(loss, func.parameters(), create_graph=True)
                parameters = optimizer.param_groups[0]['params']

                manual_hessian = get_manual_hessian(grads, parameters)           #get hessian with manual approach.
                manual_end = time.time()
                print("Time taken for manual approach was " + str(round(manual_end-manual_start,2)) + "s.")
                manual_hessian_data.append((itr, manual_end-manual_start, loss.item(), manual_hessian))

        if itr % args.test_freq == 0:
            ii += 1
            with torch.no_grad():
                #Evaluate on the full trajectory, record the loss and plot the fit.
                pred_y = odeint(func, true_y0, t, method=args.method)
                loss = torch.mean(torch.abs(pred_y - true_y))
                loss_data.append((itr, loss.item()))
                print('Iter {:04d} | Total Loss {:.6f}'.format(itr, loss.item()))
                visualize(true_y, pred_y, func, ii)

        #Uses the gradients accumulated by loss.backward() above; the hessian and
        #evaluation steps in between do not write to the parameters' .grad fields.
        optimizer.step()

In [None]:
#Create a plot of the loss curve.

#loss_data holds (iteration number, loss) tuples recorded during training.
itrs = [item[0] for item in loss_data]
data = [item[1] for item in loss_data]

plt.figure(figsize=(10,8))
plt.rcParams.update({'font.size': 14})
plt.plot(itrs, data)
#The option is called batch_time; args.batch_t does not exist and raised an AttributeError.
plt.title('Loss function for ' + str(args.method) + ' solver\nBatch Size, Time = ' + str(args.batch_size) + ', ' + str(args.batch_time) )
plt.xlabel('Iterations')
plt.ylabel('Loss')
#plt.savefig('/content/drive/MyDrive/colab_notebooks/calculating_hessians/testing_on_simple_nodes/exponential_curve/batch_size_investigation/batch_size_' 
 #           + str(args.batch_size) + '/loss_curve_' + str(args.batch_size) + '.png')
plt.show()

In [None]:
#Create histogram plots of the eigenvalue density.

for item in library_hessian_data:
  #item is (iteration number, time, loss, hessian).
  #torch.symeig has been removed from PyTorch; torch.linalg.eigvalsh returns the same
  #ascending eigenvalues. UPLO='U' keeps symeig's default of reading the upper triangle.
  e = torch.linalg.eigvalsh(item[3], UPLO='U')
  plt.hist(e.cpu().numpy(), bins=150)
  plt.title("Iteration: " + str(item[0]))
  plt.xlabel('Eigenvalue')
  plt.ylabel('Density')
  #plt.savefig('/content/drive/MyDrive/colab_notebooks/calculating_hessians/testing_on_simple_nodes/exponential_curve/batch_size_investigation/batch_size_' 
              #+ str(args.batch_size) + '/eigenvalue_density_plots_' + str(args.batch_size) + '/eigenvalue_density_' + str(args.batch_size) + '_'
              #+ str(item[0]) + '.png')
  plt.show()

In [None]:
#Create a plot of the extremal eigenvalue through training.

itrs = []
values = []
for item in library_hessian_data:
  #item is (iteration number, time, loss, hessian).
  itrs.append(item[0])
  #torch.symeig has been removed from PyTorch; eigvalsh with UPLO='U' matches its defaults.
  e = torch.linalg.eigvalsh(item[3], UPLO='U')
  #Eigenvalue of largest magnitude (sign kept), stored as a plain Python float so that
  #matplotlib can plot it even when the hessian lives on the GPU.
  value = e[torch.argmax(torch.abs(e))].item()
  values.append(value)

plt.figure(figsize=(10,8))
plt.rcParams.update({'font.size': 14})
plt.plot(itrs, values)
plt.title('Extremal Eigenvalues for ' + str(args.method).title() + ' Solver\nBatch Size = ' + str(args.batch_size))
plt.xlabel('Iterations')
plt.ylabel('Extremal Eigenvalue')
#plt.savefig('/content/drive/MyDrive/colab_notebooks/calculating_hessians/testing_on_simple_nodes/exponential_curve/batch_size_investigation/batch_size_' 
           # + str(args.batch_size) + '/extremal_eigenvalues_' + str(args.batch_size) + '.png')
plt.show()