In [6]:
import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Flag (and report) whether PyTorch can see a CUDA-capable GPU.
use_cuda = bool(torch.cuda.is_available())
print("Cuda Available?  ", use_cuda, sep=" ")


Cuda Available?   True


In [7]:
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

In [8]:
import math

In [9]:
def euler_ode_solver(z0, t0, t1, f, h_max=0.05):
    """
    Integrate dz/dt = f(z, t) from t0 to t1 with the fixed-step explicit Euler method.

    z0: value of z at the initial time t0
    t0: initial time (tensor or Python number)
    t1: final time at which z is evaluated; may be smaller than t0 to integrate backwards
    f: callable f(z, t) returning dz/dt (same shape as z)
    h_max: upper bound on the absolute step size. The interval is split into
        ceil(|t1 - t0| / h_max) equal steps; if t0/t1 are tensors, the largest
        element of |t1 - t0| decides the number of steps.

    Returns the Euler approximation of z(t1). If t0 == t1, z0 is returned unchanged.
    Raises ValueError if h_max is not positive.
    """
    if h_max <= 0:
        raise ValueError("h_max must be positive")

    # Number of equal steps so that no step exceeds h_max in magnitude.
    # torch.as_tensor lets plain Python numbers work as well as tensors.
    n_steps = math.ceil(torch.as_tensor(abs(t1 - t0) / h_max).max().item())
    if n_steps == 0:
        # Zero-length interval: nothing to integrate (also avoids dividing by zero below).
        return z0

    h = (t1 - t0) / n_steps  # signed step; negative when integrating backwards in time
    t = t0
    z = z0
    for _ in range(n_steps):
        z = z + h * f(z, t)  # Euler update: z_{k+1} = z_k + h * f(z_k, t_k)
        t = t + h            # advance time towards t1
    return z

In [10]:
class BaseODESolver(nn.Module):
    """
    Base class for parameterised ODE dynamics f(z, t; p) used by the Neural ODE code.

    Subclasses implement ``forward(z, t)`` returning dz/dt with the same shape as ``z``.
    This class adds the helpers needed by the adjoint-method backward pass.
    """

    def forward_with_grad(self, z, t, grad_ouptuts):
        """
        Evaluate f(z, t) together with the vector-Jacobian products needed by the adjoint method.

        z: state tensor whose first dimension is the batch; must require grad
        t: time tensor; must require grad for a·df/dt to be computed
        grad_ouptuts: the adjoint ``a`` (same shape as the output of ``forward``).
            The misspelled name is kept so existing keyword callers keep working.

        Returns (out, adfdz, adfdt, adfdp):
        - out:   f(z, t)
        - adfdz: a · df/dz, shaped like z (None if f does not depend on z)
        - adfdt: a · df/dt as a (batch_size, 1) tensor (None if f does not depend on t)
        - adfdp: a · df/dp for all parameters, flattened and expanded to
                 (batch_size, n_params); None if the module has no parameters
        """
        batch_size = z.size(0)

        out = self.forward(z, t)
        a = grad_ouptuts
        params = tuple(self.parameters())

        # One autograd call computes a·df/dz, a·df/dt and a·df/dp.
        # allow_unused=True: inputs f does not depend on come back as None instead of raising.
        # retain_graph=True keeps the graph of `out` alive after this call.
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + params,
            grad_outputs=(a,),
            allow_unused=True, retain_graph=True
        )

        if params:
            # Unused parameters get None from autograd; treat them as a zero gradient.
            adfdp = torch.cat([
                (p_grad if p_grad is not None else torch.zeros_like(p)).flatten()
                for p_grad, p in zip(adfdp, params)
            ]).unsqueeze(0)
            # Parameters are shared across the batch, so autograd returns the batch-summed
            # gradient; give each row an equal share so summing rows recovers the total.
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        else:
            adfdp = None

        if adfdt is not None:
            if adfdt.numel() == 1:
                # A single time shared by the whole batch: autograd summed over the batch,
                # so spread it evenly over the rows.
                adfdt = adfdt.expand(batch_size, 1) / batch_size
            else:
                # One time per sample: the gradient is already per-sample.
                adfdt = adfdt.reshape(batch_size, 1)
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        """
        Concatenate all parameters of this module into one 1-D tensor.

        The result is built with torch.cat from the parameters, so it stays attached to
        the autograd graph and gradients w.r.t. it flow back to the parameters.
        Returns an empty tensor if the module has no parameters.
        """
        flat_parameters = [p.flatten() for p in self.parameters()]
        if not flat_parameters:
            return torch.zeros(0)
        return torch.cat(flat_parameters)

The class below encapsulates the forward and backward passes of a *Neural ODE*. It is kept separate from the main `torch.nn.Module` because a custom backward function cannot be defined on a Module, but it can be defined on a `torch.autograd.Function` subclass. This is a simple workaround.

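This `autograd.Function` is the core of the Neural ODE method. Its forward pass runs the ODE solver without recording a graph. Its backward pass computes gradients with the adjoint method: it solves an augmented ODE backwards in time instead of backpropagating through the solver's internal steps.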

In [11]:
class BackPropagation(torch.autograd.Function):
    """
    autograd.Function for a Neural ODE: Euler forward solve plus adjoint-method backward pass.

    forward(z0, t, flat_parameters, ode) integrates dz/dt = ode(z, t) from t[0] through each
    t[i] and returns the states stacked along a new leading time axis,
    shape (time_steps, batch_size, *z_shape).
    """

    @staticmethod
    def forward(ctx, z0, t, flat_parameters, ode):
        """
        z0: initial state, shape (batch_size, *z_shape)
        t: time points; t[i] is the time of the i-th output state
        flat_parameters: flattened parameters of `ode`; not read by the forward solve, but
            saved for backward and receives the parameter gradient there
        ode: the BaseODESolver defining dz/dt
        """
        assert isinstance(ode, BaseODESolver)
        bs, *z_shape = z0.size()
        time_steps = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_steps, bs, *z_shape).to(z0)
            z[0] = z0

            z_prev = z0
            for t_i in range(time_steps - 1):
                # Integrate one interval starting from the previous state
                # (not from the whole trajectory buffer `z`).
                z_prev = euler_ode_solver(z_prev, t[t_i], t[t_i + 1], ode)
                z[t_i + 1] = z_prev

        ctx.func = ode
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        Adjoint-method backward pass.

        dLdz: gradient of the loss w.r.t. the forward output, shape (time_steps, batch_size, *z_shape)

        Returns gradients for (z0, t, flat_parameters, ode); `ode` gets None.
        """
        ode = ctx.func
        t, z, flat_parameters = ctx.saved_tensors

        time_steps, bs, *z_shape = z.size()
        # int(): np.prod([]) is the float 1.0, which would break slicing/torch.zeros for 1-D states.
        n_dim = int(np.prod(z_shape))
        n_params = flat_parameters.size(0)

        def augmented_dynamics(aug_z_i, t_i):
            """
            Dynamics of the augmented state [z, a, a_p, a_t], integrated backwards in time.

            aug_z_i: (bs, 2*n_dim + n_params + 1)
            t_i: current time (a slice of t)
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2 * n_dim]

            # Restore the state shape expected by the ode.
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                # Fresh leaf tensors so autograd can differentiate f w.r.t. t and z.
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)

                ode_output, adfdz, adfdt, adfdp = ode.forward_with_grad(z_i, t_i, grad_ouptuts=a)
                # None means f does not depend on that input: use a zero contribution.
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            ode_output = ode_output.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((ode_output, -adfdz, -adfdp, -adfdt), dim=1)

        # reshape (not view) tolerates a non-contiguous incoming gradient.
        dLdz = dLdz.reshape(time_steps, bs, n_dim)
        with torch.no_grad():
            # Accumulators for the adjoints of z, parameters and time.
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            adj_t = torch.zeros(time_steps, bs, 1).to(dLdz)

            for i_t in range(time_steps - 1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                out_i = ode(z_i, t_i).view(bs, n_dim)

                # Direct gradients at this output time: dL/dz_i and dL/dt_i = dL/dz_i · f(z_i, t_i)
                dLdz_i = dLdz[i_t]
                dLdt_i = (dLdz_i * out_i).sum(dim=1, keepdim=True)

                # Add the direct gradients to the running adjoints.
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Augmented state [z, a_z, a_p (reset to 0 per interval), a_t].
                aug_z = torch.cat(
                    (z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]),
                    dim=-1
                )

                # Solve the augmented system backwards from t[i_t] to t[i_t - 1].
                aug_ans = euler_ode_solver(aug_z, t_i, t[i_t - 1], augmented_dynamics)

                # Unpack the backward solution.
                adj_z[:] = aug_ans[:, n_dim:2 * n_dim]
                adj_p[:] += aug_ans[:, 2 * n_dim:2 * n_dim + n_params]
                adj_t[i_t - 1] = aug_ans[:, 2 * n_dim + n_params:]

                del aug_z, aug_ans

            # Direct gradients at t[0]. f must be evaluated at (z[0], t[0]); reusing out_i from
            # the loop would use time index 1, and out_i does not exist when time_steps == 1.
            dLdz_0 = dLdz[0]
            out_0 = ode(z[0], t[0]).view(bs, n_dim)
            dLdt_0 = (dLdz_0 * out_0).sum(dim=1, keepdim=True)

            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0

        # A 1-D time grid is shared by the whole batch: sum the per-sample time
        # gradients so the returned gradient matches t's shape.
        if t.dim() == 1:
            adj_t = adj_t.sum(dim=(1, 2))
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None



For convenience, wrap `BackPropagation` in an `nn.Module` subclass, `NeuralODE`, so a Neural ODE can be used like any other PyTorch layer.

In [12]:
class NeuralODE(nn.Module):
    """
    nn.Module wrapper that runs BackPropagation for a given BaseODESolver.

    ode: the BaseODESolver defining dz/dt; stored as a submodule.
    """

    def __init__(self, ode):
        super().__init__()
        assert isinstance(ode, BaseODESolver)
        self.ode = ode

    def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):
        """
        Integrate from z0 over the time points t.

        z0: initial state
        t: time points (moved to z0's device/dtype before solving)
        return_whole_sequence: if True, return the states at every time point;
            otherwise only the state at the last time point.
        """
        times = t.to(z0)
        trajectory = BackPropagation.apply(z0, times, self.ode.flatten_parameters(), self.ode)
        if return_whole_sequence:
            return trajectory
        return trajectory[-1]