In [None]:
%%bash 
pip install torchdiffeq

In [None]:
#Libraries
import os
import argparse
import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Command-line options, so the same code can also run as a script.
# parse_args(args=[]) parses an empty list instead of sys.argv, so inside the
# notebook every option simply keeps its default value.
_INT_OPTIONS = (
    ('--data_size', 1000),   # number of time points in the ground-truth trajectory t
    ('--batch_time', 10),    # number of consecutive time points per training window
    ('--batch_size', 20),    # number of windows sampled per mini-batch
    ('--niters', 100),       # number of training iterations
    ('--test_freq', 20),     # evaluate on the full trajectory every test_freq iterations
)

parser = argparse.ArgumentParser()
parser.add_argument('--method', type=str, choices=['dopri5', 'adams'], default='dopri5')
for _flag, _default in _INT_OPTIONS:
    parser.add_argument(_flag, type=int, default=_default)
parser.add_argument('--viz', action='store_true')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--adjoint', action='store_true')
args = parser.parse_args(args=[])

In [None]:
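#Chooses which ODE solver function to import, depending on whether the adjoint method is requested (--adjoint).
#By default, autograd backpropagates by applying the chain rule through every operation the solver performed, so memory grows with the number of solver steps.
#Importing odeint_adjoint under the same name makes backpropagation use the adjoint method instead: gradients are obtained by solving an augmented (adjoint) ODE backwards in time, trading extra computation for memory that does not grow with the number of steps.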

if args.adjoint:
    from torchdiffeq import odeint_adjoint as odeint
else:
    from torchdiffeq import odeint

In [None]:
#A torch.device is an object representing the device on which a torch.Tensor is or will be allocated.
#torch.device is a class, and the line of code below creates an instance of that class.

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')

In [None]:
# Toy-problem data, each tensor moved to `device` with .to():
#   true_A  - 2x2 matrix of the true dynamics dy/dt = (y**3) @ true_A
#   true_y0 - initial state (x, y) = (2, 0), stored with shape (1, 2)
#   t       - args.data_size evenly spaced time points from 0 to 25 (torch.linspace)
true_A = torch.tensor([[-0.1, 2.0],
                       [-2.0, -0.1]]).to(device)
true_y0 = torch.tensor([[2.0, 0.0]]).to(device)
t = torch.linspace(0.0, 25.0, args.data_size).to(device)

In [None]:
#torch.mm performs matrix multiplication of the 2 inputs. 
#Lambda is the function that defines the ODE, i.e. dy/dt = Lambda(y). It is clearly non-linear.

class Lambda(nn.Module):

    def forward(self, t, y):
        return torch.mm(y**3, true_A)

In [None]:
#torch.no_grad() disables gradient calculation, i.e. it disables the autograd engine.
#This will reduce memory usage and speed up computations but you won’t be able to backprop.
#This is useful for all tensors that don't require gradients.
#odeint() solves an ode up to a time, t.

#It's not obvious, but this true_y solution defines a spiral in the x-y plane (I verified this computationally).
with torch.no_grad():
    true_y = odeint(Lambda(), true_y0, t, method='dopri5')

In [None]:
#torch.from_numpy() creates a tensor from a numpy ndarray.
#torch.stack concatenates a sequence of tensors along a new dimension.

def get_batch():

    #Generates a random list of integers in the range (args.data_size - args.batch_time), of length (args.batch_size).
    s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False)) 

    #Creates the random batch. batch_y will be our ground truth when optimising the neural net.
    batch_y0 = true_y[s]  # (M, D)
    batch_t = t[:args.batch_time]  # (T)
    batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0)  # (T, M, D)
    return batch_y0.to(device), batch_t.to(device), batch_y.to(device)


In [None]:
#This is used to make a new directory if the results of the experiment are to be saved.

def makedirs(dirname):
    if not os.path.exists(dirname):
        os.makedirs(dirname)

In [None]:
# Optional plotting set-up, only when --viz is given: create the 'png' output
# folder and a 1x3 figure whose three panels visualize() redraws in place.
if args.viz:
    import matplotlib.pyplot as plt
    makedirs('png')
    fig = plt.figure(figsize=(12, 4), facecolor='white')  # white background
    # Subplot codes 131, 132, 133 = 1 row x 3 columns, panels 1..3.
    ax_traj, ax_phase, ax_vecfield = (fig.add_subplot(131 + k, frameon=False) for k in range(3))

In [None]:
def visualize(true_y, pred_y, odefunc, itr):
    """Redraw trajectories, phase portrait and learned vector field on the shared figure.

    This is the authors' original plotting routine. It draws into the
    module-level ``fig`` / ``ax_traj`` / ``ax_phase`` / ``ax_vecfield`` created
    in the --viz set-up cell. The author found it did not work in the notebook,
    so the training loop calls visualize_3() instead.

    Saving and redrawing go through ``fig`` explicitly rather than through
    ``plt.savefig`` / ``plt.draw``. Those act on pyplot's *current* figure,
    which need not be ``fig`` (for example after an inline backend has closed
    it at the end of the cell that created it).

    Does nothing unless ``args.viz`` is set.

    Args:
        true_y: ground-truth trajectory, indexed as [time, 0, component]
            (component 0 = x, 1 = y).
        pred_y: model trajectory with the same layout as ``true_y``.
        odefunc: callable ``odefunc(t, y)`` evaluated on a 21x21 grid over
            [-2, 2]^2 to draw the learned vector field.
        itr: integer used to name the saved image ``png/<itr:03d>``
            (matplotlib appends its default image extension).
    """
    if not args.viz:
        return

    # Convert once instead of calling .cpu().numpy() for every plot call.
    t_np = t.cpu().numpy()
    true_np = true_y.cpu().numpy()
    pred_np = pred_y.cpu().numpy()

    ax_traj.cla()                       # clear the previous iteration's drawing
    ax_traj.set_title('Trajectories')
    ax_traj.set_xlabel('t')
    ax_traj.set_ylabel('x,y')
    ax_traj.plot(t_np, true_np[:, 0, 0], t_np, true_np[:, 0, 1], 'g-')
    ax_traj.plot(t_np, pred_np[:, 0, 0], '--', t_np, pred_np[:, 0, 1], 'b--')
    ax_traj.set_xlim(t_np.min(), t_np.max())
    ax_traj.set_ylim(-2, 2)

    ax_phase.cla()
    ax_phase.set_title('Phase Portrait')
    ax_phase.set_xlabel('x')
    ax_phase.set_ylabel('y')
    ax_phase.plot(true_np[:, 0, 0], true_np[:, 0, 1], 'g-')
    ax_phase.plot(pred_np[:, 0, 0], pred_np[:, 0, 1], 'b--')
    ax_phase.set_xlim(-2, 2)
    ax_phase.set_ylim(-2, 2)

    ax_vecfield.cla()
    ax_vecfield.set_title('Learned Vector Field')
    ax_vecfield.set_xlabel('x')
    ax_vecfield.set_ylabel('y')

    y, x = np.mgrid[-2:2:21j, -2:2:21j]
    grid = torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)
    with torch.no_grad():  # plotting only; no autograd graph needed
        dydt = odefunc(0, grid).cpu().numpy()
    mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
    # Normalise to unit length so streamplot shows direction only. Where the
    # field vanishes, keep the zero vector instead of computing 0/0 = NaN.
    dydt = (dydt / np.where(mag > 0, mag, 1.0)).reshape(21, 21, 2)

    ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
    ax_vecfield.set_xlim(-2, 2)
    ax_vecfield.set_ylim(-2, 2)

    fig.tight_layout()
    fig.savefig('png/{:03d}'.format(itr))
    fig.canvas.draw_idle()
    plt.pause(0.001)  # let the GUI/backend process the redraw

In [None]:
def visualize_2(true_y, pred_y, odefunc, itr):
    """Debugging helper: save only the phase portrait (y against x) to ``png_2/<itr:03d>``.

    The author wrote this while diagnosing visualize(); the training loop does
    not call it. ``odefunc`` is accepted for signature parity with the other
    visualize functions but is not used. Does nothing unless ``args.viz`` is set.
    """
    if not args.viz:
        return

    makedirs('png_2')
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    # Ground truth as a solid green line, prediction as a dashed blue line.
    for traj, fmt in ((true_y, 'g-'), (pred_y, 'b--')):
        xy = traj.cpu().numpy()[:, 0, :]
        ax.plot(xy[:, 0], xy[:, 1], fmt)
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set(xlabel='x', ylabel='y', title='Phase Portrait')
    fig.savefig('png_2/{:03d}'.format(itr))
    plt.close(fig)
    




In [None]:
def visualize_3(true_y, pred_y, odefunc, itr):
    """Plot trajectories, phase portrait and learned vector field on a new figure.

    Variant of visualize() that builds a fresh 1x3 figure on every call instead
    of redrawing one shared figure. This is the plotting routine the training
    loop uses. Does nothing unless ``args.viz`` is set.

    The figure is always closed before returning, even if plotting raises, so
    repeated calls do not accumulate open figures.

    Args:
        true_y: ground-truth trajectory, indexed as [time, 0, component]
            (component 0 = x, 1 = y).
        pred_y: model trajectory with the same layout as ``true_y``.
        odefunc: callable ``odefunc(t, y)`` evaluated on a 21x21 grid over
            [-2, 2]^2 to draw the learned vector field.
        itr: integer used to name the saved image ``png/<itr:03d>``
            (matplotlib appends its default image extension).
    """
    if not args.viz:
        return

    fig = plt.figure(figsize=(12, 4), facecolor='white')  # white background
    try:
        # Subplot codes 131, 132, 133 = 1 row x 3 columns, panels 1..3.
        ax_traj = fig.add_subplot(131, frameon=False)
        ax_phase = fig.add_subplot(132, frameon=False)
        ax_vecfield = fig.add_subplot(133, frameon=False)

        # Convert once instead of calling .cpu().numpy() for every plot call.
        t_np = t.cpu().numpy()
        true_np = true_y.cpu().numpy()
        pred_np = pred_y.cpu().numpy()

        ax_traj.set_title('Trajectories')
        ax_traj.set_xlabel('t')
        ax_traj.set_ylabel('x,y')
        ax_traj.plot(t_np, true_np[:, 0, 0], t_np, true_np[:, 0, 1], 'g-')
        ax_traj.plot(t_np, pred_np[:, 0, 0], '--', t_np, pred_np[:, 0, 1], 'b--')
        ax_traj.set_xlim(t_np.min(), t_np.max())
        ax_traj.set_ylim(-2, 2)

        ax_phase.set_title('Phase Portrait')
        ax_phase.set_xlabel('x')
        ax_phase.set_ylabel('y')
        ax_phase.plot(true_np[:, 0, 0], true_np[:, 0, 1], 'g-')
        ax_phase.plot(pred_np[:, 0, 0], pred_np[:, 0, 1], 'b--')
        ax_phase.set_xlim(-2, 2)
        ax_phase.set_ylim(-2, 2)

        ax_vecfield.set_title('Learned Vector Field')
        ax_vecfield.set_xlabel('x')
        ax_vecfield.set_ylabel('y')

        y, x = np.mgrid[-2:2:21j, -2:2:21j]
        grid = torch.Tensor(np.stack([x, y], -1).reshape(21 * 21, 2)).to(device)
        with torch.no_grad():  # plotting only; no autograd graph needed
            dydt = odefunc(0, grid).cpu().numpy()
        mag = np.sqrt(dydt[:, 0]**2 + dydt[:, 1]**2).reshape(-1, 1)
        # Normalise to unit length so streamplot shows direction only. Where the
        # field vanishes (e.g. at the origin with zero biases), keep the zero
        # vector instead of computing 0/0 = NaN.
        dydt = (dydt / np.where(mag > 0, mag, 1.0)).reshape(21, 21, 2)

        ax_vecfield.streamplot(x, y, dydt[:, :, 0], dydt[:, :, 1], color="black")
        ax_vecfield.set_xlim(-2, 2)
        ax_vecfield.set_ylim(-2, 2)

        fig.tight_layout()
        os.makedirs('png', exist_ok=True)  # do not rely on another cell having created it
        fig.savefig('png/{:03d}'.format(itr))
        fig.canvas.draw_idle()
        plt.pause(0.001)  # lets the backend display/update the figure
    finally:
        plt.close(fig)

In [None]:
"""
------------------------------------------------------------------------------------------------------------------------------------------------------
From this point on, the code is no longer specific to this toy example: it shows how to implement and train a neural ODE (NODE) in general,
and is therefore the more important part to understand.
------------------------------------------------------------------------------------------------------------------------------------------------------
"""

In [None]:
class ODEFunc(nn.Module):
    """Neural network that parameterises the vector field of the neural ODE.

    ``forward(t, y)`` returns ``net(y ** 3)``: a one-hidden-layer MLP
    (Linear -> Tanh -> Linear) applied to the elementwise cube of the state.

    Feeding ``y ** 3`` mirrors the cube in the true dynamics (Lambda), so the
    network only has to learn the remaining map from ``y ** 3`` to dy/dt. In
    the true system that map is linear (multiplication by ``true_A``). The MLP
    itself is non-linear and only approximates it.

    Args:
        dim: size of the state vector y (default 2: the x-y plane).
        hidden_dim: width of the hidden layer (default 50).
    """

    def __init__(self, dim=2, hidden_dim=50):
        super().__init__()

        # One hidden layer with a tanh non-linearity.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim),
        )

        # Small random weights (N(0, 0.1^2)) and zero biases. The untrained
        # field therefore starts small, and is exactly zero at y = 0.
        for m in self.net.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.1)
                nn.init.constant_(m.bias, val=0)

    def forward(self, t, y):
        """Return dy/dt at state ``y``. ``t`` is ignored (autonomous system) but required by odeint."""
        return self.net(y**3)

In [None]:
class RunningAverageMeter(object):
    """Track the most recent value of a scalar stream and its exponential moving average.

    ``avg`` is seeded with the first value passed to ``update``. After that,
    ``avg <- momentum * avg + (1 - momentum) * val``. ``val`` is ``None``
    until the first update.
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history: no latest value, average back to 0."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Fold ``val`` into the moving average and remember it as the latest value."""
        first_update = self.val is None
        self.avg = val if first_update else self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val

In [None]:
if __name__ == '__main__':

    ii = 0  # index used to name saved figures; incremented at each evaluation

    func = ODEFunc().to(device)

    # RMSprop over the parameters of the learned vector field.
    optimizer = optim.RMSprop(func.parameters(), lr=1e-3)
    end = time.time()

    # Exponential moving averages of per-iteration wall time and mini-batch loss.
    time_meter = RunningAverageMeter(0.97)
    loss_meter = RunningAverageMeter(0.97)

    for itr in range(1, args.niters + 1):
        # Clear gradients from the previous step; loss.backward() accumulates into .grad.
        optimizer.zero_grad()
        batch_y0, batch_t, batch_y = get_batch()
        # Forward pass: integrate the learned ODE from each window's start state.
        # Use the solver selected with --method (the option was previously ignored).
        # NOTE(review): newer torchdiffeq releases may name the Adams solver
        # 'implicit_adams'; confirm 'adams' is accepted by the installed version.
        # odeint returns its result on the device of batch_y0, which is already `device`.
        pred_y = odeint(func, batch_y0, batch_t, method=args.method)
        loss = torch.mean(torch.abs(pred_y - batch_y))  # mean absolute (L1) error
        loss.backward()    # gradients w.r.t. func's parameters (adjoint method if --adjoint)
        optimizer.step()   # update parameters from their .grad

        time_meter.update(time.time() - end)
        loss_meter.update(loss.item())

        # Periodically evaluate on the whole trajectory (from true_y0 over all of t)
        # instead of a random mini-batch, report the full-trajectory L1 loss,
        # and plot the result.
        if itr % args.test_freq == 0:
            with torch.no_grad():
                pred_y = odeint(func, true_y0, t, method=args.method)
                loss = torch.mean(torch.abs(pred_y - true_y))
                print('Iter {:04d} | Total Loss {:.6f} | Running Batch Loss {:.6f} | Avg Step Time {:.4f}s'.format(
                    itr, loss.item(), loss_meter.avg, time_meter.avg))
                visualize_3(true_y, pred_y, func, ii)
                ii += 1

        end = time.time()