Although I wrote this code myself after studying the background, the logical structure (and some function names) of the implementation follows implementations found online, especially msurtsukov's implementation in the GitHub repository "neural-ode".

In [2]:
#import modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch 
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

In [3]:
# Check whether a CUDA device is visible to PyTorch
use_cuda = torch.cuda.is_available()
print(use_cuda)  # the flag is only reported here; this cell does not move anything to a device

False


We can use any ODE solver we wish, but let us start with the simple Euler method. Euler's method is good enough for simulating our spiral (Archimedean spiral), but
when simulating other dynamical systems we will use a better method such as RK4 (Runge-Kutta) or Dopri5!

In [4]:
#ODE initial value problem solver - Euler's method 
def ode_solve(z0, t0, t1, f, h_max=0.05):
    """Integrate dz/dt = f(z, t) from t0 to t1 with the explicit Euler method.

    The interval is split into the smallest number of equal steps such that no
    step is longer than ``h_max``. Integrating backwards in time (t1 < t0) is
    supported because the signed step is derived from ``t1 - t0``.

    Args:
        z0: State at time ``t0``.
        t0: Start time.
        t1: End time.
        f: Callable ``f(z, t)`` returning the time derivative of the state.
        h_max: Upper bound on the absolute step size (must be positive).

    Returns:
        The Euler approximation of the state at ``t1``. When ``t0 == t1`` the
        initial state ``z0`` is returned unchanged.

    Raises:
        ValueError: If ``h_max`` is not strictly positive.
    """
    if h_max <= 0:
        raise ValueError("h_max must be strictly positive")

    n_steps = int(np.ceil(abs(t1 - t0) / h_max).max())
    if n_steps == 0:
        # Zero-length interval: nothing to integrate, and h would be 0 / 0.
        return z0

    h = (t1 - t0) / n_steps
    t = t0
    z = z0

    for _ in range(n_steps):
        z = z + h * f(z, t)
        t = t + h

    return z

In [5]:
# Example using the Euler ODE solver: dz/dt = -z with z(0) = 1, integrated from t = 0 to t = 1
def f(z,t):
    # Right-hand side of the test ODE; the time argument is not used
    return -z
ode_solve(1,0,1,f)

0.3584859224085422

In [6]:
#Class of parameterized dynamics
##Adjoints calculated with PyTorch
class ODEF(nn.Module):
    """Base class for parameterised dynamics ``f(z, t)`` of an ODE dz/dt = f(z, t).

    Subclasses implement ``forward(z, t)``. This class adds the helpers needed
    by the adjoint method: a flat view of all parameters and the
    vector-Jacobian products of ``f`` with respect to ``z``, ``t`` and the
    parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, t):
        """Return dz/dt at state ``z`` and time ``t``; must be overridden."""
        raise NotImplementedError('`forward` method must be implemented in subclass.')

    def flatten_parameters(self):
        """Return all parameters concatenated into one 1-D tensor.

        An empty 1-D tensor is returned when the module has no parameters,
        because ``torch.cat`` rejects an empty list.
        """
        flat = [p.flatten() for p in self.parameters()]
        if not flat:
            return torch.zeros(0)
        return torch.cat(flat)

    def forward_with_grad(self, z, t, grad_outputs):
        """Evaluate ``f(z, t)`` together with its vector-Jacobian products.

        Args:
            z: State tensor whose first dimension is the batch; must require grad.
            t: Time tensor passed to ``forward``; must require grad.
            grad_outputs: The vector ``a`` (same shape as the output of
                ``forward``) that multiplies the Jacobians.

        Returns:
            Tuple ``(out, adfdz, adfdt, adfdp)``:
            ``out`` is ``f(z, t)``;
            ``adfdz`` is ``a * df/dz`` (``None`` if ``z`` is unused);
            ``adfdt`` is ``a * df/dt`` expanded to ``(batch, 1)`` (``None`` if
            ``t`` is unused);
            ``adfdp`` is ``a * df/dp`` flattened and expanded to
            ``(batch, n_params)`` (``None`` if the module has no parameters).
        """
        batch_size = z.shape[0]

        # Always evaluate on the tensors passed in: the gradients below must be
        # taken through a graph that actually contains this `z` and this `t`.
        out = self.forward(z, t)

        a = grad_outputs
        params = tuple(self.parameters())
        # The input order (z, t, *params) must match the unpacking order.
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + params,
            grad_outputs=(a,), allow_unused=True, retain_graph=True
        )

        # autograd sums the gradients over the batch; spread them back evenly
        # so that summing over the batch dimension recovers the total.
        if params:
            adfdp = torch.cat([
                (p_grad if p_grad is not None else torch.zeros_like(p)).flatten()
                for p_grad, p in zip(adfdp, params)
            ]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        else:
            adfdp = None
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size

        return out, adfdz, adfdt, adfdp


In [7]:
##main adjoint method
#This adjoint method is almost the same as the one used by msurtsukov, which in turn is very similar to the original authors' implementation in torchdiffeq.
#In the original code, the ODE adjoint implementation is written in C++, which is supposedly more efficient.

class ODEAdjoint(torch.autograd.Function):
    """Autograd function that solves an ODE forward and differentiates it with the adjoint method.

    The forward pass integrates the state through every requested time point
    without recording an autograd graph. The backward pass integrates an
    augmented system (state, adjoint of the state, adjoint of the parameters,
    adjoint of the time) backwards in time, adding the direct loss gradient at
    every observation time.

    NOTE(review): the tail of `backward` (the backward-in-time loop and the
    return statement) was missing from the notebook cell and has been
    reconstructed from the standard adjoint algorithm; verify it against the
    reference implementation this notebook follows.
    """

    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):  ##forward pass
        """Integrate `func` from `z0` through all time points in `t`.

        Args:
            ctx: Autograd context used to stash tensors for `backward`.
            z0: Initial state, shape (batch_size, *z_shape).
            t: Time points; indexed along dimension 0.
            flat_parameters: Flattened parameters of `func`; passed in so that
                autograd routes the parameter gradient back to them.
            func: The `ODEF` dynamics.

        Returns:
            Tensor of shape (time_len, batch_size, *z_shape) with the state at
            every time point (index 0 holds `z0`).
        """
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.shape
        time_len = t.shape[0]

        with torch.no_grad():
            z = torch.zeros((time_len, bs, *z_shape), device=z0.device, dtype=z0.dtype)
            z[0] = z0
            for i in range(time_len - 1):
                z0 = ode_solve(z0, t[i], t[i+1], func)
                z[i+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod  #make these methods static
    def backward(ctx, dLdz):    ##backward pass
        """Compute gradients w.r.t. `z0`, `t` and the flat parameters.

        dLdz shape: time_len, batch_size, *z_shape

        Returns:
            Tuple `(adj_z, adj_t, adj_p, None)` matching the inputs of
            `forward`; `func` itself receives no gradient.
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.shape
        # Plain Python int so it can be used safely for slicing and tensor sizes.
        n_dim = int(np.prod(z_shape))
        n_params = flat_parameters.shape[0]

        # Dynamics of Aug system which are to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            # Only the state and its adjoint are read; the parameter and time
            # adjoints do not influence the derivative.
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros((bs, *z_shape), device=z_i.device, dtype=z_i.dtype)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros((bs, n_params), device=z_i.device, dtype=z_i.dtype)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros((bs, 1), device=z_i.device, dtype=z_i.dtype)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.reshape(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros((bs, n_dim), device=dLdz.device, dtype=dLdz.dtype)
            adj_p = torch.zeros((bs, n_params), device=dLdz.device, dtype=dLdz.dtype)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros((time_len, bs, 1), device=dLdz.device, dtype=dLdz.dtype)

            for i_t in range(time_len - 1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Direct gradients at this observation time: dL/dt = dL/dz . f
                dLdz_i = dLdz[i_t]
                dLdt_i = (dLdz_i * f_i).sum(dim=1, keepdim=True)

                # Adjust the running adjoints with the direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack the augmented state; the parameter adjoint starts from
                # zero on every interval and is accumulated into adj_p below.
                aug_z = torch.cat(
                    (
                        z_i.view(bs, n_dim),
                        adj_z,
                        torch.zeros((bs, n_params), device=z.device, dtype=z.dtype),
                        adj_t[i_t],
                    ),
                    dim=-1,
                )

                # Solve the augmented system backwards to the previous time point
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                # Unpack the solution
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust the adjoint at the first time point with the direct gradients
            f_0 = func(z[0], t[0]).view(bs, n_dim)
            dLdz_0 = dLdz[0]
            dLdt_0 = (dLdz_0 * f_0).sum(dim=1, keepdim=True)

            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0

        # adj_p has a leading batch dimension; autograd sum-reduces it to the
        # shape of `flat_parameters`.
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None


In [8]:
class NeuralODE(nn.Module):
    """Wrap an `ODEF` so that solving the ODE behaves like an ordinary layer.

    Being an `nn.Module` makes the instance callable and exposes the
    parameters of `func` through `.parameters()`, which an optimizer needs.
    """

    def __init__(self, func):
        super().__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=np.array([0., 1.]), return_whole_sequence=False):
        """Integrate the dynamics starting from `z0` over the time points `t`.

        Args:
            z0: Initial state with the batch in dimension 0. A torch tensor
                keeps the computation differentiable; anything else is
                converted to a float32 tensor and the result is handed back as
                a NumPy array.
            t: Time points (array-like or tensor); converted to the dtype and
                device of `z0`.
            return_whole_sequence: If True return the state at every time
                point, otherwise only the state at the last one.

        Returns:
            The trajectory of shape (len(t), *z0.shape), or its last entry.
        """
        return_numpy = not isinstance(z0, torch.Tensor)
        if return_numpy:
            # np.array copies, so the caller's array is never aliased.
            z0 = torch.from_numpy(np.array(z0, dtype=np.float32))
        t = torch.as_tensor(t).to(z0)

        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
        if return_numpy:
            # The output normally requires grad, so it has to be detached first.
            z = z.detach().cpu().numpy()

        if return_whole_sequence:
            return z
        return z[-1]


In [9]:
class LinearODEF(ODEF):
    """Linear dynamics dz/dt = z @ W with a trainable matrix `W`."""

    def __init__(self, W):
        super(LinearODEF, self).__init__()
        # Register W as a parameter so that it shows up in `.parameters()` and
        # can be trained; clone so the caller's tensor is not updated in place.
        self.W = nn.Parameter(
            torch.as_tensor(W, dtype=torch.get_default_dtype()).detach().clone()
        )

    def forward(self, x, t):
        """Return x @ W; the time argument is not used."""
        # torch.matmul (rather than np.matmul) keeps the operation on the
        # autograd graph.
        return torch.matmul(x, self.W)


In [10]:
class SpiralFunctionExample(LinearODEF):
    """Linear dynamics with the fixed matrix [[-0.1, -1.], [1., -0.1]]."""

    def __init__(self):
        spiral_matrix = Tensor([[-0.1, -1.], [1., -0.1]])
        super().__init__(spiral_matrix)

In [11]:
class RandomLinearODEF(LinearODEF):
    """Linear dynamics whose 2x2 matrix is drawn from a standard normal and halved."""

    def __init__(self):
        random_matrix = torch.randn(2, 2).div(2.)
        super().__init__(random_matrix)

In [12]:
class TestODEF(ODEF):
    """Non-linear 2-D test dynamics gated by a sigmoid of <x, x0>.

    dx/dt = sigmoid(<x, x0>) * (A (x - x0) + B (x + x0))
    """

    def __init__(self, A, B, x0):
        super(TestODEF, self).__init__()
        self.A = nn.Linear(2, 2, bias=False)
        self.A.weight = nn.Parameter(A)
        self.B = nn.Linear(2, 2, bias=False)
        self.B.weight = nn.Parameter(B)
        self.x0 = nn.Parameter(x0)

    def forward(self, x, t):
        """Return dx/dt for a batch `x` of shape (batch, 2); `t` is not used."""
        # keepdim=True gives a (batch, 1) gate that scales each sample's row.
        # A (batch,) gate would instead broadcast along the feature axis,
        # which is wrong for batch == 2 and fails for larger batches.
        xTx0 = torch.sum(x*self.x0, dim=1, keepdim=True)
        sigmoid_xTx0 = torch.sigmoid(xTx0)
        A_diff = sigmoid_xTx0 * self.A(x - self.x0)
        B_diff = sigmoid_xTx0 * self.B(x + self.x0)
        dxdt = A_diff + B_diff
        return dxdt


In [13]:
class NNODEF(ODEF):
    """Three-layer fully connected network with ELU activations used as ODE dynamics.

    Unless `time_invariant` is set, the time is concatenated to the state
    before the first layer, which therefore takes `in_dim + 1` inputs.
    """

    def __init__(self, in_dim, hid_dim, time_invariant=False):
        super().__init__()
        self.time_invariant = time_invariant

        # Identity check on purpose: only the literal True drops the time input.
        first_layer_inputs = in_dim if time_invariant is True else in_dim + 1
        self.lin1 = nn.Linear(first_layer_inputs, hid_dim)
        self.lin2 = nn.Linear(hid_dim, hid_dim)
        self.lin3 = nn.Linear(hid_dim, in_dim)
        self.elu = nn.ELU(inplace=True)

    def forward(self, x, t):
        """Return the network output for state `x` (and time `t` when time-dependent)."""
        net_input = x if self.time_invariant else torch.cat((x, t), dim=-1)

        hidden = self.elu(self.lin1(net_input))
        hidden = self.elu(self.lin2(hidden))
        return self.lin3(hidden)


In [14]:
def to_np(x):
    """Convert a torch tensor to a NumPy array; return anything else unchanged.

    Tensors are detached from the autograd graph and moved to the CPU first.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.detach().cpu().numpy()

In [15]:
def plot_trajectories(obs=None, times=None, trajs=None, save=None, figsize=(16, 8)):
    """Plot observed points and model trajectories in the plane.

    Args:
        obs: Optional iterable of observation arrays/tensors indexed as
            [time, batch, coordinate]; coordinates 0 and 1 are scattered.
        times: Optional iterable matching `obs`, indexed as [time, batch, 0],
            used to colour the points with the "plasma" colormap. Without it
            the points are drawn in the default colour.
        trajs: Optional iterable of trajectories indexed as
            [time, batch, coordinate]; only batch element 0 is drawn.
        save: Optional target for the figure. For a path, missing parent
            directories are created.
        figsize: Figure size passed to matplotlib.
    """
    import os

    fig, ax = plt.subplots(figsize=figsize)

    if obs is not None:
        if times is None:
            times = [None] * len(obs)
        for o, t in zip(obs, times):
            o, t = to_np(o), to_np(t)
            for b_i in range(o.shape[1]):
                if t is None:
                    # No time information: cannot colour by time.
                    ax.scatter(o[:, b_i, 0], o[:, b_i, 1])
                else:
                    # The colormap is given by name so no extra import is needed.
                    ax.scatter(o[:, b_i, 0], o[:, b_i, 1], c=t[:, b_i, 0], cmap="plasma")

    if trajs is not None:
        for z in trajs:
            z = to_np(z)
            ax.plot(z[:, 0, 0], z[:, 0, 1], lw=1.5)

    # Saving is independent of whether trajectories were supplied.
    if save is not None:
        if isinstance(save, (str, os.PathLike)):
            parent = os.path.dirname(os.fspath(save))
            if parent:
                os.makedirs(parent, exist_ok=True)
        fig.savefig(save)

    plt.show()
    # Release the figure; this function is called repeatedly during training.
    plt.close(fig)


In [16]:
def conduct_experiment(ode_true, ode_trained, n_steps, name, plot_freq=10, save_dir="/home/rmogalap/FP"):
    """Fit `ode_trained` to noisy observations generated by `ode_true`.

    A trajectory starting at (0.6, 0.3) is generated with `ode_true`, perturbed
    with Gaussian noise, and `ode_trained` is optimised with Adam on random
    sub-windows of that trajectory. Every `plot_freq` steps the current fit is
    plotted and saved to ``{save_dir}/{name}/{step}.png``.

    Args:
        ode_true: Callable ``(z0, t, return_whole_sequence=True)`` producing
            the ground-truth trajectory as a tensor.
        ode_trained: Callable with the same signature that also provides
            ``.parameters()`` for the optimizer.
        n_steps: Number of optimisation steps.
        name: Sub-directory of `save_dir` for the saved plots.
        plot_freq: Plot every `plot_freq`-th step.
        save_dir: Root directory for the saved plots.

    NOTE(review): `clear_output` is not imported in any visible cell; it is
    presumably `IPython.display.clear_output` and must be in scope.
    """
    # Create data
    z0 = torch.tensor([[0.6, 0.3]], requires_grad=True)

    t_max = 6.29*5
    n_points = 200

    # Builtin int replaces the alias np.int, which NumPy no longer provides.
    index_np = np.arange(0, n_points, 1, dtype=int)[:, None]
    times_np = np.linspace(0, t_max, num=n_points)[:, None]

    times = torch.from_numpy(times_np[:, :, None]).to(z0)
    obs = ode_true(z0, times, return_whole_sequence=True).detach()
    obs = obs + torch.randn_like(obs) * 0.01

    # Get trajectory of random timespan
    min_delta_time = 1.0
    max_delta_time = 5.0
    max_points_num = 32

    def create_batch():
        """Pick up to `max_points_num` time-ordered observations from a random window."""
        t0 = np.random.uniform(0, t_max - max_delta_time)
        t1 = t0 + np.random.uniform(min_delta_time, max_delta_time)

        in_window = (times_np > t0) & (times_np < t1)
        idx = sorted(np.random.permutation(index_np[in_window])[:max_points_num].tolist())

        obs_ = obs[idx]
        ts_ = times[idx]
        return obs_, ts_

    # Train Neural ODE
    optimizer = torch.optim.Adam(ode_trained.parameters(), lr=0.01)
    for i in range(n_steps):
        obs_, ts_ = create_batch()

        z_ = ode_trained(obs_[0], ts_, return_whole_sequence=True)
        loss = F.mse_loss(z_, obs_.detach())

        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()

        if i % plot_freq == 0:
            # Evaluation only: no graph is needed for plotting.
            with torch.no_grad():
                z_p = ode_trained(z0, times, return_whole_sequence=True)

            # Use fresh names here: rebinding `times_np` would change the
            # array that `create_batch` masks with on the next iteration.
            obs_plot = to_np(obs)
            times_plot = to_np(times)
            z_p_plot = to_np(z_p)
            plot_trajectories(obs=[obs_plot], times=[times_plot], trajs=[z_p_plot], save=f"{save_dir}/{name}/{i}.png")
            clear_output(wait=True)


In [None]:
# Linear spiral experiment: the fixed spiral dynamics generate the data and a
# randomly initialised linear model is fitted to it. Both models are defined
# here so that this cell also runs on a fresh kernel.
ode_true = NeuralODE(SpiralFunctionExample())
ode_trained = NeuralODE(RandomLinearODEF())

conduct_experiment(ode_true, ode_trained, 900, "linear")

In [None]:
## Other spiral graph: data from the non-linear TestODEF dynamics, fitted by a neural network
# Ground truth: TestODEF built from the matrices A, B and the reference point x0
func = TestODEF(Tensor([[-0.1, -0.5], [0.5, -0.1]]), Tensor([[0.2, 1.], [-1, 0.2]]), Tensor([[-1., 0.]]))
ode_true = NeuralODE(func)

# Model to train: time-invariant network, 2-D state, 16 hidden units (`func` is rebound here)
func = NNODEF(2, 16, time_invariant=True)
ode_trained = NeuralODE(func)

In [None]:
# Train for 3000 steps, plotting every 30th step; plots are saved under the "comp" sub-directory
conduct_experiment(ode_true, ode_trained, 3000, "comp", plot_freq=30)