ODENet: Neural Ordinary Differential Equations (Chen et al., 2018)
    https://arxiv.org/abs/1806.07366

In [2]:
import torch
import torch.nn as nn
from torch.nn import init
import math as m

In [3]:
# Compute device used by the notebook: the GPU when CUDA is available, else the CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
def runge_kutta(f, t0, t1, x0, h):
    """Integrate ``dx/dt = f(x, t)`` from ``t0`` to ``t1`` with classic RK4.

    Args:
        f: Callable ``f(x, t)`` returning the time derivative of ``x``.
        t0: Start time (0-dim tensor or float).
        t1: End time (0-dim tensor or float). May be smaller than ``t0``,
            in which case the system is integrated backwards in time.
        x0: State at ``t0``.
        h: Maximum step size; must be positive.

    Returns:
        The approximate state at ``t1``. ``x0`` is returned unchanged when
        ``t0 == t1``.

    Raises:
        ValueError: If ``h`` is not positive.
    """
    max_step = float(h)
    if max_step <= 0:
        raise ValueError("step size h must be positive")

    span = t1 - t0
    n_steps = m.ceil(abs(float(span)) / max_step)
    if n_steps == 0:
        return x0

    # Signed, uniform step: negative when integrating backwards, and sized so
    # that the final step lands exactly on t1 instead of overshooting it.
    step = span / n_steps
    half_step = 0.5 * step

    t = t0
    x = x0
    for _ in range(n_steps):
        k1 = f(x, t)
        k2 = f(x + half_step * k1, t + half_step)
        k3 = f(x + half_step * k2, t + half_step)
        k4 = f(x + step * k3, t + step)

        x = x + step * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t = t + step
    return x

In [5]:
#super class for function
class ODE_Function(torch.nn.Module):
    """Base class for the right-hand side ``f(z, t)`` of an ODE.

    Subclasses implement ``forward(z, t)``. This class adds the helpers the
    adjoint method needs: vector-Jacobian products of ``f`` and a flat view
    of the parameters.
    """

    def forward_and_grads(self, z_in, t_in, grad_outputs):
        """Evaluate ``f`` and its vector-Jacobian products with ``grad_outputs``.

        Args:
            z_in: State tensor passed to ``forward``.
            t_in: Time tensor passed to ``forward``.
            grad_outputs: Vector ``a`` with the same shape as ``f(z_in, t_in)``.

        Returns:
            Tuple ``(f, a*df/dz, a*df/dp, a*df/dt)``. ``a*df/dp`` is a single
            flat tensor ordered like ``self.parameters()``, with zeros for
            parameters that ``f`` does not use. ``a*df/dz`` and ``a*df/dt``
            are ``None`` when ``f`` does not depend on the respective input.
        """
        # The inputs must track gradients BEFORE forward() runs, otherwise the
        # graph never records them and their gradients come back as None.
        # Fresh leaves are used so the caller's tensors are not mutated.
        if not z_in.requires_grad:
            z_in = z_in.detach().requires_grad_(True)
        if not t_in.requires_grad:
            t_in = t_in.detach().requires_grad_(True)

        params = tuple(self.parameters())

        with torch.enable_grad():
            f = self.forward(z_in, t_in)
            adfdz, adfdt, *adfdp = torch.autograd.grad(
                (f,),
                (z_in, t_in) + params,
                grad_outputs=grad_outputs,
                allow_unused=True,
                retain_graph=True,
            )

        if params:
            # Unused parameters contribute an exact zero gradient.
            adfdp = torch.cat([
                (torch.zeros_like(p) if g is None else g).flatten()
                for p, g in zip(params, adfdp)
            ]).to(z_in)
        else:
            adfdp = z_in.new_zeros(0)

        return f, adfdz, adfdp, adfdt

    def flat_params(self):
        """Return all parameters concatenated into one 1-D tensor."""
        return torch.cat([p.flatten() for p in self.parameters()])

In [6]:
class ODEAdjoint(torch.autograd.Function):
    """Autograd function that solves an ODE and backpropagates via the adjoint method.

    The forward pass integrates ``dz/dt = f(z, t)`` through the time points in
    ``t`` without building an autograd graph. The backward pass integrates the
    augmented adjoint system backwards in time to obtain the gradients with
    respect to ``z0``, ``t`` and the flat parameter vector.
    """

    @staticmethod
    def forward(ctx, z0, t, params, f, h):
        """Solve the ODE at every time point.

        Args:
            ctx: Autograd context.
            z0: Initial state.
            t: 1-D tensor of time points; ``t[0]`` is the time of ``z0``.
            params: Flat parameter tensor of ``f`` (so autograd tracks it).
            f: ``ODE_Function`` giving the dynamics.
            h: Step size handed to ``runge_kutta``.

        Returns:
            Tensor of shape ``(len(t), *z0.shape)`` with the state at each time.
        """
        assert isinstance(f, ODE_Function)
        time_len = t.shape[0]

        with torch.no_grad():
            z = z0.new_empty((time_len, *z0.shape))
            z[0] = z0
            for i in range(time_len - 1):
                z[i + 1] = runge_kutta(f, t[i], t[i + 1], z[i], h)

        ctx.f = f
        # backward() only receives output gradients, so the step size has to
        # travel on the context.
        ctx.h = h
        ctx.save_for_backward(z.clone(), t, params)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """Compute gradients by integrating the augmented adjoint system.

        Args:
            ctx: Autograd context populated by ``forward``.
            dLdz: Gradient of the loss w.r.t. the forward output.

        Returns:
            Gradients for ``(z0, t, params, f, h)``; the last two are ``None``.
        """
        f = ctx.f
        h = ctx.h
        z, t, params = ctx.saved_tensors
        time_len, *z_shape = z.shape
        z_dim = z[0].numel()
        params_dim = params.shape[0]

        def aug_dynamics(aug_v, t_cur):
            """Right-hand side of the augmented system.

            ``aug_v`` is the flat vector ``(z, a, a_params, a_t)`` of length
            ``2*z_dim + params_dim + 1``.
            """
            with torch.enable_grad():
                # Fresh leaves so the vector-Jacobian products w.r.t. the
                # state and the time are actually tracked.
                z_cur = aug_v[:z_dim].reshape(z_shape).detach().requires_grad_(True)
                t_leaf = t_cur.detach().requires_grad_(True)
                a_cur = aug_v[z_dim:2 * z_dim].reshape(z_shape)
                f_eval, adfdz, adfdp, adfdt = f.forward_and_grads(z_cur, t_leaf, a_cur)

            f_eval = f_eval.detach().reshape(-1).to(aug_v)
            adfdz = (aug_v.new_zeros(z_dim) if adfdz is None
                     else adfdz.detach().reshape(-1).to(aug_v))
            adfdp = (aug_v.new_zeros(params_dim) if adfdp is None
                     else adfdp.detach().reshape(-1).to(aug_v))
            # The time gradient is 0-dim; torch.cat needs at least one dim.
            adfdt = (aug_v.new_zeros(1) if adfdt is None
                     else adfdt.detach().reshape(1).to(aug_v))

            return torch.cat((f_eval, -adfdz, -adfdp, -adfdt))

        dLdz = dLdz.reshape(time_len, z_dim)
        z_flat = z.reshape(time_len, z_dim).to(dLdz)
        with torch.no_grad():
            adj_z = dLdz.new_zeros(z_dim)
            adj_params = dLdz.new_zeros(params_dim)
            adj_t = dLdz.new_zeros(time_len)

            for i in range(time_len - 1, 0, -1):
                f_i = f(z[i], t[i]).reshape(-1).to(dLdz)
                dLdz_i = dLdz[i]

                # Add the direct loss gradient at this time point.
                adj_z += dLdz_i
                adj_t[i] -= torch.dot(dLdz_i, f_i)

                aug_v = torch.cat((
                    z_flat[i],
                    adj_z,
                    dLdz.new_zeros(params_dim),
                    adj_t[i].unsqueeze(0),
                ))

                aug_solution = runge_kutta(aug_dynamics, t[i], t[i - 1], aug_v, h)

                adj_z[:] = aug_solution[z_dim:2 * z_dim]
                # The parameter slot starts from zero on every interval, so
                # the per-interval contributions must be summed.
                adj_params += aug_solution[2 * z_dim:2 * z_dim + params_dim]
                adj_t[i - 1] = aug_solution[2 * z_dim + params_dim]

                del aug_v, aug_solution

            # Direct gradient at the initial time, using the dynamics at t[0].
            f_0 = f(z[0], t[0]).reshape(-1).to(dLdz)
            dLdz_0 = dLdz[0]

            adj_z += dLdz_0
            adj_t[0] -= torch.dot(dLdz_0, f_0)
        return adj_z.reshape(z_shape), adj_t, adj_params, None, None

In [8]:
#Wrapper class of ODEAdjoint 
class ODELayer(torch.nn.Module):
    """Module wrapper that runs an ``ODE_Function`` through ``ODEAdjoint``."""

    def __init__(self, f, h):
        """Store the dynamics ``f`` and the solver step size ``h``."""
        super().__init__()
        assert isinstance(f, ODE_Function)
        self.f = f
        self.h = h

    def forward(self, z0, t=torch.Tensor([0., 1.])):
        """Solve the ODE from ``z0`` over the time points ``t``.

        ``t`` is converted to the dtype and device of ``z0``; the result of
        ``ODEAdjoint.apply`` is returned.
        """
        times = t.to(z0)
        flat = self.f.flat_params()
        return ODEAdjoint.apply(z0, times, flat, self.f, self.h)