In [None]:
"""
---------------------------------------------------------------------------------------------------------------------------------------------------------
This notebook assesses the performance of "manual" Hessian calculations for a loss that contains the odeint() function from torchdiffeq.
---------------------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
# Use the %pip line magic so the package is installed into the environment of
# the running kernel (a bare `pip` inside %%bash may target a different Python).
%pip install -q torchdiffeq

In [3]:
import torch
from torch import nn
import numpy as np
from torch.nn import Module
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt
import torch.optim as optim

In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
# Problem set-up: ground-truth initial state, time grid and dynamics matrix.

true_y0 = torch.tensor([[2.0, 0.0]]).to(device)                  # initial state, shape (1, 2)
t = torch.linspace(0.0, 25.0, 1000).to(device)                   # 1000 evenly spaced time points on [0, 25]
true_A = torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]).to(device)    # 2x2 matrix of the ground-truth dynamics

# Selects which torchdiffeq solver entry point is bound to the name `odeint`.
adjoint = False

if not adjoint:
    from torchdiffeq import odeint
else:
    from torchdiffeq import odeint_adjoint as odeint

class Lambda(nn.Module):
    """Right-hand side of the ground-truth ODE: dy/dt = (y ** 3) @ true_A."""

    def forward(self, t, y):
        """Return the time derivative at state ``y``.

        ``t`` is accepted because the solver passes it, but it is not used.
        ``y`` has to be a 2-D tensor since the product is taken with
        ``torch.mm`` against the module-level ``true_A``.
        """
        cubed_state = y.pow(3)
        return torch.mm(cubed_state, true_A)

# Integrate the ground-truth dynamics once, without tracking gradients.
# The resulting trajectory traces a spiral in the x-y plane (not obvious from
# the equations, but verified computationally by the author).
# One 2-D state is produced per entry of `t`, i.e. the result has shape [1000, 1, 2].
with torch.no_grad():
    true_y = odeint(func=Lambda(), y0=true_y0, t=t, method='dopri5')

In [6]:
# Sampling configuration read by get_batch().
data_size = 1000   # number of time indices from which segment starts are drawn
batch_size = 20    # number of trajectory segments per mini-batch
batch_time = 10    # number of consecutive time steps in each segment

def get_batch():
    """Draw a random mini-batch of short trajectory segments from ``true_y``.

    Reads the module-level ``true_y``, ``t``, ``data_size``, ``batch_time``,
    ``batch_size`` and ``device``.

    Returns a tuple ``(batch_y0, batch_t, batch_y)``, each moved to ``device``:
      * ``batch_y0`` -- ``true_y`` at ``batch_size`` distinct random start indices;
      * ``batch_t``  -- the first ``batch_time`` entries of ``t``;
      * ``batch_y``  -- ``true_y`` at ``start + offset`` for every offset in
        ``range(batch_time)``, stacked along a new leading dimension. This is
        the ground truth when optimising the neural net.

    NOTE(review): if ``true_y`` has shape (1000, 1, 2) as stated where it is
    created, then ``batch_y0`` is (M, 1, D) and ``batch_y`` is (T, M, 1, D)
    with M = batch_size, T = batch_time -- confirm.
    """
    # Distinct start indices drawn from [0, data_size - batch_time), so that
    # start + offset never runs past index data_size - 1.
    candidate_starts = np.arange(data_size - batch_time, dtype=np.int64)
    starts = torch.from_numpy(np.random.choice(candidate_starts, batch_size, replace=False))

    initial_states = true_y[starts]
    segment_times = t[:batch_time]

    shifted_states = []
    for offset in range(batch_time):
        shifted_states.append(true_y[starts + offset])
    segment_states = torch.stack(shifted_states, dim=0)

    return initial_states.to(device), segment_times.to(device), segment_states.to(device)

In [7]:
class ODEFunc(nn.Module):
    """Learnable ODE right-hand side: a 2 -> 50 -> 2 tanh network applied to ``y ** 3``."""

    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Tanh(),
            nn.Linear(50, 2),
        )

        # Override the default initialisation of every linear layer:
        # weights ~ N(0, 0.1^2), biases set to zero.
        linear_layers = [layer for layer in self.net.modules() if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.normal_(layer.weight, mean=0, std=0.1)
            nn.init.constant_(layer.bias, val=0)

    def forward(self, t, y):
        """Return the network output for the cubed state; ``t`` is not used."""
        cubed_state = y ** 3
        return self.net(cubed_state)


# Module-level instance of the network.
func = ODEFunc()

In [8]:
class Network(nn.Module):
  """Linear -> tanh -> linear network on ``y ** 3`` with externally supplied weights.

  ``a``/``b`` are the weight and bias of the first linear map, ``c``/``d``
  those of the second. They are stored exactly as given rather than being
  created by the module.
  """

  def __init__(self, a, b, c, d):
    super().__init__()
    self.a, self.b, self.c, self.d = a, b, c, d

  def forward(self, t, y):
    """Return ``linear(tanh(linear(y ** 3, a, b)), c, d)``; ``t`` is not used."""
    hidden = torch.tanh(F.linear(y.pow(3), self.a, self.b))
    return F.linear(hidden, self.c, self.d)


def get_loss_square(params_vector):
  """Evaluate the trajectory loss as a function of one flat parameter vector.

  The first 252 entries of ``params_vector`` are unpacked, in order, into
  ``a`` (50, 2), ``b`` (50,), ``c`` (2, 50) and ``d`` (2,), which become the
  weights of a ``Network``. That network is integrated with ``odeint`` from
  the module-level ``batch_y0`` over ``batch_t`` and compared with ``batch_y``.

  Note: despite the function's name, the returned scalar is the mean
  *absolute* difference between prediction and target.
  """
  # (number of entries, target shape) for each tensor, in storage order.
  layout = (
      (100, [50, 2]),
      (50, [50]),
      (100, [2, 50]),
      (2, [2]),
  )

  pieces = []
  offset = 0
  for count, shape in layout:
    pieces.append(params_vector[offset:offset + count].reshape(shape))
    offset += count
  a, b, c, d = pieces

  neural_net = Network(a, b, c, d).to(device)
  pred_y = odeint(neural_net, batch_y0, batch_t)

  return (pred_y - batch_y).abs().mean()

def get_hessian(net):
  """Return the Hessian of ``get_loss_square`` at the current parameters of ``net``.

  Every tensor from ``net.parameters()`` is flattened and the pieces are
  concatenated, in iteration order, into a single 1-D vector on ``device``.
  That vector is handed to ``torch.autograd.functional.hessian`` together
  with ``get_loss_square``, and the result is returned unchanged.
  """
  # Seed with an empty tensor so the concatenation is defined even when the
  # module has no parameters.
  flat_chunks = [torch.tensor([]).to(device)]
  flat_chunks.extend(torch.reshape(p, (-1,)).to(device) for p in net.parameters())
  params_vector = torch.cat(flat_chunks)

  return torch.autograd.functional.hessian(get_loss_square, params_vector)

In [11]:
# Fresh model on the target device and one random mini-batch.
# batch_y0 / batch_t / batch_y are read as globals by get_loss_square().
func = ODEFunc().to(device)
batch_y0, batch_t, batch_y = get_batch()

# func.parameters() are the parameters to optimise.
optimizer = optim.RMSprop(params=func.parameters(), lr=1e-3)
optimizer.zero_grad()

# The parameter tensors as held by the optimizer's first parameter group.
parameters = optimizer.param_groups[0]["params"]

# Hessian of the loss with respect to the flattened parameters of func.
hessian = get_hessian(func)