# Comparing Neural ODEs and Neural Conjugate Flows

Neural Conjugate Flows...

`torchdyn` implements a variety of continuous-depth models out of the box, including the following Neural ODE variants (this notebook uses only the vanilla depth-invariant one, as the Neural ODE baseline):

* **Vanilla** (depth-invariant)
* **Vanilla** (depth-variant)
* **Galerkin**
* **Data-controlled**


In [None]:
# Reload edited modules (e.g. the local `models` / `utils.ode` imported below)
# automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Import the necessary libraries
import time 

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.utils.benchmark as benchmark

import torchdyn
import numpy as np
import matplotlib.pyplot as plt

# from scipy.integrate import solve_ivp
from torchdyn.numerics import odeint

In [None]:
from torchdyn.core import NeuralODE
from torchdyn.nn import DataControl, DepthCat, Augmenter, GalLinear, Fourier
from torchdyn.datasets import *
from torchdyn.utils import *

In [None]:
from models import *
import utils.ode as ode

In [None]:
# Prefer the GPU when available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any later cell shown here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Optional: switch to 64-bit precision by uncommenting one of these lines.
# torch.set_default_tensor_type(torch.DoubleTensor)
# torch.set_default_dtype(torch.float64)

# Optional: release cached GPU memory.
# torch.cuda.empty_cache()

# Global seed so the cells below are reproducible.
torch.manual_seed(0)

In [None]:
# Reference system: rescaled Hodgkin-Huxley model, integrated on t in [0, 1].
speed = 32.0
eq = ode.RescaledHodgkinHuxley(speed=speed)
# eq = ode.FitzHughNagumo(speed = speed)
eq_name = 'Hodgkin-Huxley'
order = eq.order  # state dimension; used as the models' input/output size
nt = 200
t = torch.linspace(0, 1, nt + 1, requires_grad=False)

# Initial state: 0 for the first component and eq.n_inf/m_inf/h_inf evaluated
# at 0 for the rest (presumably steady-state gating values — confirm in
# utils.ode), all scaled by 10.
zero = torch.zeros(1, requires_grad=False)
x0 = 10 * torch.tensor(
    [zero, eq.n_inf(zero), eq.m_inf(zero), eq.h_inf(zero)],
    requires_grad=False,
)

In [None]:
# Reference simulation with torchdyn's adaptive dopri5 solver (not SciPy).
# odeint returns (t_eval, solution); keep only the solution.
sol = odeint(eq.f_solver, x0, t, solver='dopri5')[1]

# The labels were previously set but never shown: draw a legend.
for k, label in enumerate(['X', 'Y', 'Z', 'ZZ']):
    plt.plot(speed * t, sol[:, k], label=label)
plt.legend()
plt.title(eq_name)

In [None]:
# Re-parameterise the problem for the actual experiments: a different time
# scale (speed) and a coarser grid, starting from the final state of the
# simulation above.
speed = 14.0
nt = 100
eq = ode.RescaledHodgkinHuxley(speed=speed)
# `sol` is a tensor: copy it with clone().detach() (torch.tensor(tensor) emits
# a UserWarning) and cast to float32 as before.
x0 = sol[-1].clone().detach().to(torch.float)
t = torch.linspace(0, 1, nt + 1, requires_grad=False)

In [None]:

# Reference solutions (torchdyn dopri5) on the training horizon [0, 1] and on
# the extrapolation horizon [0, 2]. odeint returns (t_eval, solution).
sol = odeint(eq.f_solver, x0, t, solver='dopri5')[1]
sol_extrapolated = odeint(eq.f_solver, x0, 2 * t, solver='dopri5')[1]

for k, label in enumerate(['X', 'Y', 'Z', 'ZZ']):
    plt.plot(speed * t, sol[:, k], label=label)
plt.legend()
plt.title(eq_name)

In [None]:
# Reference trajectory over the extrapolation horizon (physical time 2*speed*t).
for k, label in enumerate(['X', 'Y', 'Z', 'ZZ']):
    plt.plot(2 * speed * t, sol_extrapolated[:, k], label=label)
plt.legend()
plt.title(eq_name)

In [None]:
# Training targets: every `undersampling`-th point of the reference solution,
# cast to the dtype/device of `t` and given a singleton axis at dim 1.
undersampling = 2
t_train = t[::undersampling]
y_train = sol[::undersampling].to(t)
y_train = y_train.unsqueeze(1)

# Initial condition = first training sample.
x_train = y_train[0]
train = data.TensorDataset(y_train)
# batch_size == number of time points: the whole trajectory is a single batch.
trainloader = data.DataLoader(train, batch_size=len(t_train), shuffle=False)

In [None]:
# Independent training runs (one seed each) per model, and epochs per run.
experiments, epochs = 5, 2000

**Learner**

In [None]:
import torch.nn as nn
import pytorch_lightning as pl

class Learner(pl.LightningModule):
    """Fit a trajectory model to a reference trajectory with PyTorch Lightning.

    Each training step rolls ``model`` out from the fixed initial condition
    ``x0`` over ``t_span`` and regresses the prediction onto the batch target
    with a mean-squared error, optimised by Adam.

    Args:
        x0: Initial condition; registered as a buffer so it moves with the module.
        t_span: Evaluation times; registered as a buffer for the same reason.
        model: Module called as ``model(x, t_span)``.
        trainloader: Loader whose batches are 1-tuples ``(y,)`` of targets.
            NOTE: the default is the global ``trainloader`` as it was when this
            cell was executed; re-run this cell after rebuilding
            ``trainloader``, or pass the loader explicitly.
        lr: Adam learning rate.
    """

    def __init__(self, x0:torch.Tensor, t_span:torch.Tensor, model:nn.Module, *,
                  trainloader: data.DataLoader = trainloader,
                  lr = 0.005):
        super().__init__()
        self.register_buffer('x0', x0)
        self.register_buffer('t_span', t_span)
        self.model = model
        self.trainloader = trainloader
        self.lr = lr

    def forward(self, x, t_span=None):
        """Roll ``model`` out from ``x`` over ``t_span`` (default: ``self.t_span``).

        Every model in this notebook is invoked as ``model(x, t)``; previously
        this method called ``self.model(x)`` and silently dropped the time grid.
        """
        if t_span is None:
            t_span = self.t_span
        return self.model(x, t_span)

    def training_step(self, batch, batch_idx):
        x = self.x0
        y = batch[0]  # TensorDataset yields 1-tuples
        y_hat = self.model(x, self.t_span)
        loss = nn.MSELoss()(y_hat, y)
        self.log("loss", loss, prog_bar=True, on_step=False, on_epoch=True, logger = False)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def train_dataloader(self):
        return self.trainloader

    def train_dataloader(self):
        return self.trainloader

**Note:** In this notebook we consider the depth domain $[0,1]$, i.e. $t\in[0,1]$. For most architectures in *static* settings (i.e. when we do not deal with dynamic data), any other depth domain does not actually affect the expressiveness of Neural ODEs, since it can be seen as a rescaled/shifted version of $[0,1]$. Note, however, that other choices of the depth domain can indeed affect the training phase.

The depth domain can be accessed and modified through the `t_span` setting of `NeuralODE` instances.

## Vanilla Neural ODE (Depth-Invariant)

$$ \left\{
    \begin{aligned}
        \dot{z}(t) &= f(z(t), \theta)\\
        z(0) &= x\\
        \hat y & = z(1)
    \end{aligned}
    \right. \quad t\in[0,1]
$$

This model is the same one used in the [torchdyn quickstart](./00_quickstart.html) tutorial. The vector field is parametrized by a neural network $f$ with *static* parameters $\theta$ that takes as input only the state $z(t)$.

In [None]:

# Vanilla (depth-invariant) Neural ODE benchmark: train `experiments` models
# with different seeds and record the MSE against the reference on [0, 1]
# (interpolation) and [0, 2] (extrapolation), plus the wall-clock fit time.
def _mse(pred, ref):
    """Return the MSE between ``pred`` and ``ref`` as a Python float.

    ``nn.MSELoss`` silently broadcasts mismatched shapes (e.g. ``(T, 1, d)``
    against ``(T, d)``) and returns a meaningless number, so require equal shapes.
    """
    if pred.shape != ref.shape:
        raise ValueError(
            f'shape mismatch: prediction {tuple(pred.shape)} '
            f'vs reference {tuple(ref.shape)}'
        )
    return nn.MSELoss()(pred, ref).item()


losses = []
losses_extrapolation = []
fit_times = []
width = 128

for i in range(experiments):
    print(f'Experiment {i+1}')
    torch.manual_seed(i)
    f = CustomMLP(order, width, width, order)

    model = NeuralODE(f, sensitivity='autograd', solver='midpoint', interpolator=None,
                      atol=1e-3, rtol=1e-3, return_t_eval=False)
    # Train the Neural ODE.
    learn = Learner(x_train, t_train, model, lr=0.0025)
    trainer = pl.Trainer(max_epochs=epochs)

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start_time = time.perf_counter()
    trainer.fit(learn)
    fit_time = time.perf_counter() - start_time
    fit_times.append(fit_time)

    # Interpolation error on the full-resolution grid over [0, 1].
    trajectory = model(x0, t).squeeze(1).detach()
    loss_ground = _mse(trajectory, sol)
    losses.append(loss_ground)

    # Extrapolation error over [0, 2].
    trajectory_extrapolated = model(x0, 2 * t).squeeze(1).detach()
    loss_extra = _mse(trajectory_extrapolated, sol_extrapolated)
    losses_extrapolation.append(loss_extra)


mean_loss = np.mean(losses)
std_loss = np.std(losses)
mean_extra = np.mean(losses_extrapolation)
std_extra = np.std(losses_extrapolation)
mean_fit_time = np.mean(fit_times)
std_fit_time = np.std(fit_times)

mean_loss, std_loss, mean_extra, std_extra, mean_fit_time, std_fit_time

In [None]:
# Keep a handle on the last trained vanilla Neural ODE: `model` is reassigned by
# the later experiments, and this one is reused in the final comparison figure.
model_node = model

**Plots**

In [None]:
# Prediction of the last trained model on [0, 1], as a NumPy array for plotting.
trajectory = model(x_train, t).detach().numpy()

In [None]:
# Model prediction vs. reference on the training horizon (first two states).
plt.plot(t, sol[:, 0], label='X_baseline')
plt.plot(t, sol[:, 1], label='Y_baseline')
plt.plot(t, trajectory[..., 0], label='X')
plt.plot(t, trajectory[..., 1], label='Y')
plt.legend()
plt.title(eq_name)

In [None]:
# Prediction of the last trained model over the extrapolation horizon [0, 2].
trajectory_extrapolated = model(x_train, 2 * t).detach().numpy()

In [None]:
# Model prediction vs. reference over the extrapolation horizon (first two states).
plt.plot(2 * t, sol_extrapolated[:, 0], label='X_baseline')
plt.plot(2 * t, sol_extrapolated[:, 1], label='Y_baseline')
plt.plot(2 * t, trajectory_extrapolated[..., 0], label='X')
plt.plot(2 * t, trajectory_extrapolated[..., 1], label='Y')
plt.legend()
plt.title(eq_name)

## Pseudo-Flows (MLP-PINNs)

In [None]:

# Pseudo-flow (MLP-PINN) benchmark: an MLP with `order + 1` inputs (presumably
# state plus time — confirm in models.CustomMLP/SemiFlow) wrapped in SemiFlow,
# evaluated with the same protocol as the Neural ODE runs.
def _mse(pred, ref):
    """Return the MSE between ``pred`` and ``ref`` as a Python float.

    ``nn.MSELoss`` silently broadcasts mismatched shapes (e.g. ``(T, 1, d)``
    against ``(T, d)``) and returns a meaningless number, so require equal shapes.
    """
    if pred.shape != ref.shape:
        raise ValueError(
            f'shape mismatch: prediction {tuple(pred.shape)} '
            f'vs reference {tuple(ref.shape)}'
        )
    return nn.MSELoss()(pred, ref).item()


losses = []
losses_extrapolation = []
fit_times = []
width = 128

for i in range(experiments):
    print(f'Experiment {i+1}')
    torch.manual_seed(i)
    f = CustomMLP(order + 1, width, width, order, fourier_feature=True)

    model = SemiFlow(f)
    # Train the pseudo-flow.
    learn = Learner(x_train, t_train, model, lr=0.0025)
    trainer = pl.Trainer(max_epochs=epochs)

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start_time = time.perf_counter()
    trainer.fit(learn)
    fit_time = time.perf_counter() - start_time
    fit_times.append(fit_time)

    # squeeze(1) drops a singleton batch axis if present (no-op otherwise),
    # matching the other benchmarks.
    trajectory = model(x0, t).squeeze(1).detach()
    loss_ground = _mse(trajectory, sol)
    losses.append(loss_ground)

    # Extrapolation error over [0, 2].
    trajectory_extrapolated = model(x0, 2 * t).squeeze(1).detach()
    loss_extra = _mse(trajectory_extrapolated, sol_extrapolated)
    losses_extrapolation.append(loss_extra)


mean_loss = np.mean(losses)
std_loss = np.std(losses)
mean_extra = np.mean(losses_extrapolation)
std_extra = np.std(losses_extrapolation)
mean_fit_time = np.mean(fit_times)
std_fit_time = np.std(fit_times)

mean_loss, std_loss, mean_extra, std_extra, mean_fit_time, std_fit_time

In [None]:
# Keep a handle on the last trained pseudo-flow: `model` is reassigned by the
# NCF experiments, and this one is reused in the final comparison figure.
model_pseudo = model

**Plots**

In [None]:
# Pseudo-flow prediction on [0, 1], as a NumPy array for plotting.
trajectory = model(x_train, t).detach().numpy()

In [None]:
# Pseudo-flow prediction vs. reference on the training horizon (first two states).
plt.plot(t, sol[:, 0], label='X_baseline')
plt.plot(t, sol[:, 1], label='Y_baseline')
plt.plot(t, trajectory[..., 0], label='X')
plt.plot(t, trajectory[..., 1], label='Y')
plt.legend()
plt.title(eq_name)

In [None]:
# Pseudo-flow prediction over the extrapolation horizon [0, 2].
trajectory_extrapolated = model(x_train, 2 * t).detach().numpy()

In [None]:
# Pseudo-flow prediction vs. reference over the extrapolation horizon.
plt.plot(2 * t, sol_extrapolated[:, 0], label='X_baseline')
plt.plot(2 * t, sol_extrapolated[:, 1], label='Y_baseline')
plt.plot(2 * t, trajectory_extrapolated[..., 0], label='X')
plt.plot(2 * t, trajectory_extrapolated[..., 1], label='Y')
plt.legend()
plt.title(eq_name)

## Neural Conjugate Flows (No Topology)

In [None]:
# Twin padding: tile the state `twin_times` times along the last axis.
twin_times = 2
twin_order = twin_times * order

x_train_twinned = torch.cat(twin_times * [x_train], dim=-1)
y_train_twinned = torch.cat(twin_times * [y_train], dim=-1)
train_twinned = data.TensorDataset(y_train_twinned)
# One full-length batch, mirroring `trainloader`.
trainloader_twinned = data.DataLoader(
    train_twinned, batch_size=len(t_train), shuffle=False
)

In [None]:
# Initial matrix for the NCF LinearFlow, built with the same twin padding as
# the data; display it together with its shape.
# Alternative: M0 = eq.jacobian(x_train)
M0 = ncf_matrix_init(eq, x_train, pad_mode='twin', pad_times=twin_times)
M0, M0.shape


In [None]:
# Neural Conjugate Flow benchmark (no topology): NeuralConjugate built from two
# additive coupling layers and a LinearFlow `Psi` initialised with M0, trained
# on the twin-padded data. Only the first `order` output components are
# compared against the reference.
def _mse(pred, ref):
    """Return the MSE between ``pred`` and ``ref`` as a Python float.

    ``nn.MSELoss`` silently broadcasts mismatched shapes (e.g. ``(T, 1, d)``
    against ``(T, d)``) and returns a meaningless number, so require equal shapes.
    """
    if pred.shape != ref.shape:
        raise ValueError(
            f'shape mismatch: prediction {tuple(pred.shape)} '
            f'vs reference {tuple(ref.shape)}'
        )
    return nn.MSELoss()(pred, ref).item()


losses = []
losses_extrapolation = []
fit_times = []
width = 100

for i in range(experiments):
    print(f'Experiment {i+1}')
    torch.manual_seed(i)

    Psi = LinearFlow(twin_order * twin_order, M0=M0, omega_zero=1.0)
    s = CustomMLP(twin_order // 2, width, width, twin_order // 2)
    coupling = AdditiveCouplingLayer(s)
    s2 = CustomMLP(twin_order // 2, width, width, twin_order // 2)
    coupling2 = AdditiveCouplingLayer(s2, orientation='skew')

    layers = [coupling, coupling2]

    model = NeuralConjugate(layers, Psi, pad='no')
    # Train the NCF on the twin-padded targets.
    learn = Learner(x_train_twinned, t_train, model, lr=0.0025, trainloader=trainloader_twinned)
    trainer = pl.Trainer(max_epochs=epochs)

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start_time = time.perf_counter()
    trainer.fit(learn)
    fit_time = time.perf_counter() - start_time
    fit_times.append(fit_time)

    # The prediction keeps the singleton batch axis of x_train_twinned; without
    # squeeze(1) the comparison with `sol` broadcast (T, 1, d) against (T, d).
    trajectory = model(x_train_twinned, t).squeeze(1)[..., :order].detach()
    loss_ground = _mse(trajectory, sol)
    losses.append(loss_ground)

    # Extrapolation error over [0, 2].
    trajectory_extrapolated = model(x_train_twinned, 2 * t).squeeze(1)[..., :order].detach()
    loss_extra = _mse(trajectory_extrapolated, sol_extrapolated)
    losses_extrapolation.append(loss_extra)


mean_loss = np.mean(losses)
std_loss = np.std(losses)
mean_extra = np.mean(losses_extrapolation)
std_extra = np.std(losses_extrapolation)
mean_fit_time = np.mean(fit_times)
std_fit_time = np.std(fit_times)

mean_loss, std_loss, mean_extra, std_extra, mean_fit_time, std_fit_time

In [None]:
# torch.manual_seed(1)

# order = 2*twin_times
# width = 32

# Psi = LinearFlow(order*order, M0 = M0, lie_algebra='skew_symmetric',omega_zero=1.0)


# o = OrthogonalLayer(order,order,bias = True)
# s = CustomMLP(order//2,width,width,order//2)
# coupling = AdditiveCouplingLayer(s)
# s2 = CustomMLP(order//2,width,width,order//2)
# coupling2 = AdditiveCouplingLayer(s2, orientation='skew')

# # layers = [o,coupling, coupling2]
# layers = [coupling, coupling2]
# # layers = []

# model = NeuralConjugate(layers, Psi, pad='no')
# # model = NeuralConjugate(layers, Psi)

In [None]:
# plt.plot(2*t,Psi(x_train_twinned,2*t)[...,0].detach().numpy(), label = 'X_baseline')
# plt.plot(2*t,Psi(x_train_twinned,2*t)[...,1].detach().numpy(), label = 'X_baseline')
# plt.plot(2*t,Psi(x_train_twinned,2*t)[...,2].detach().numpy(), label = 'X_baseline')
# plt.plot(2*t,Psi(x_train_twinned,2*t)[...,3].detach().numpy(), label = 'X_baseline')
# # x_train_twinned.shape
# # hmm = x_train_twinned.unsqueeze(0)
# # hmm = hmm.transpose(1,2)

In [None]:
# # train the Neural ODE
# learn = Learner(x_train_twinned, t_train, model, trainloader = trainloader_twinned, lr = 0.0025)
# # learn = Learner(x_train, t_train, model)
# trainer = pl.Trainer(max_epochs=2000)
# trainer.fit(learn)

**Plots**

In [None]:
# NCF prediction on [0, 1] from the twin-padded initial condition; kept as a
# tensor (no NumPy conversion).
trajectory = model(x_train_twinned, t).detach()

In [None]:
# NCF prediction vs. reference on the training horizon (first two states).
plt.plot(t, sol[:, 0], label='X_baseline')
plt.plot(t, sol[:, 1], label='Y_baseline')
plt.plot(t, trajectory[..., 0], label='X')
plt.plot(t, trajectory[..., 1], label='Y')
plt.legend()
plt.title(eq_name)

In [None]:
# NCF prediction over the extrapolation horizon [0, 2]; kept as a tensor.
trajectory_extrapolated = model(x_train_twinned, 2 * t).detach()

In [None]:
# NCF prediction vs. reference over the extrapolation horizon (first state only).
extra = 3  # NOTE(review): not referenced anywhere else in this notebook

plt.plot(2 * t, sol_extrapolated[:, 0], label='X_baseline')
plt.plot(2 * t, trajectory_extrapolated[..., 0], label='X')
plt.legend()
plt.title(eq_name)

## Neural Conjugate Flows (With Topology)

In [None]:
# Twin padding for the "with topology" runs: tile the state `twin_times`
# times along the last axis.
twin_times = 2

x_train_twinned = torch.cat(twin_times * [x_train], dim=-1)
y_train_twinned = torch.cat(twin_times * [y_train], dim=-1)
train_twinned = data.TensorDataset(y_train_twinned)
# One full-length batch, mirroring `trainloader`.
trainloader_twinned = data.DataLoader(
    train_twinned, batch_size=len(t_train), shuffle=False
)

In [None]:
# Rebuild the LinearFlow initial matrix for the twin-padded state; show its shape.
# Alternative: M0 = eq.jacobian(x_train)
M0 = ncf_matrix_init(eq, x_train, pad_mode='twin', pad_times=twin_times)
M0.shape


In [None]:

# Fresh result accumulators for the NCF (with topology) runs in the next cell.
losses, losses_extrapolation, fit_times = [], [], []

# Dimension of the twin-padded state.
twin_order = order * twin_times
twin_order


In [None]:

# Neural Conjugate Flow benchmark (with topology). Appends to the
# `losses` / `losses_extrapolation` / `fit_times` lists created in the
# previous cell. Only the first `order` output components are compared
# against the reference.
def _mse(pred, ref):
    """Return the MSE between ``pred`` and ``ref`` as a Python float.

    ``nn.MSELoss`` silently broadcasts mismatched shapes (e.g. ``(T, 1, d)``
    against ``(T, d)``) and returns a meaningless number, so require equal shapes.
    """
    if pred.shape != ref.shape:
        raise ValueError(
            f'shape mismatch: prediction {tuple(pred.shape)} '
            f'vs reference {tuple(ref.shape)}'
        )
    return nn.MSELoss()(pred, ref).item()


width = 90

for i in range(experiments):
    print(f'Experiment {i+1}')
    torch.manual_seed(i)

    Psi = LinearFlow(twin_order * twin_order, M0=M0, omega_zero=5)

    s = CustomMLP(twin_order // 2, width, width, twin_order // 2)
    coupling = AdditiveCouplingLayer(s)
    s2 = CustomMLP(twin_order // 2, width, width, twin_order // 2)
    coupling2 = AdditiveCouplingLayer(s2, orientation='skew')

    layers = [coupling, coupling2]

    model = NeuralConjugate(layers, Psi, pad='no')
    # Train the NCF on the twin-padded targets.
    learn = Learner(x_train_twinned, t_train, model, lr=0.0025, trainloader=trainloader_twinned)
    trainer = pl.Trainer(max_epochs=epochs)

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start_time = time.perf_counter()
    trainer.fit(learn)
    fit_time = time.perf_counter() - start_time
    fit_times.append(fit_time)

    trajectory = model(x_train_twinned, t).squeeze(1)[..., :order].detach()
    loss_ground = _mse(trajectory, sol)
    losses.append(loss_ground)

    # Extrapolation: squeeze(1) was missing here, so (T, 1, d) was broadcast
    # against sol_extrapolated's (T, d).
    trajectory_extrapolated = model(x_train_twinned, 2 * t).squeeze(1)[..., :order].detach()
    loss_extra = _mse(trajectory_extrapolated, sol_extrapolated)
    losses_extrapolation.append(loss_extra)


mean_loss = np.mean(losses)
std_loss = np.std(losses)
mean_extra = np.mean(losses_extrapolation)
std_extra = np.std(losses_extrapolation)
mean_fit_time = np.mean(fit_times)
std_fit_time = np.std(fit_times)

mean_loss, std_loss, mean_extra, std_extra, mean_fit_time, std_fit_time

**Plots**

In [None]:
# NCF (with topology) prediction on [0, 1], as a NumPy array for plotting.
trajectory = model(x_train_twinned, t).detach().numpy()

In [None]:
# NCF (with topology) prediction vs. reference on the training horizon.
plt.plot(t, sol[..., 0], label='X_baseline')
plt.plot(t, sol[..., 1], label='Y_baseline')
plt.plot(t, trajectory[..., 0], label='X')
plt.plot(t, trajectory[..., 1], label='Y')
plt.legend()
plt.title(eq_name)

In [None]:
# NCF (with topology) prediction over the extrapolation horizon [0, 2].
trajectory_extrapolated = model(x_train_twinned, 2 * t).detach().numpy()

In [None]:
# NCF (with topology) prediction vs. reference over the extrapolation horizon
# (first state only).
extra = 3  # NOTE(review): not referenced anywhere else in this notebook

plt.plot(2 * t, sol_extrapolated[..., 0], label='X_baseline')
plt.plot(2 * t, trajectory_extrapolated[..., 0], label='X')
plt.legend()
plt.title(eq_name)

In [None]:
# Extrapolations of the saved pseudo-flow and vanilla-NODE models over [0, 2],
# as NumPy arrays for the final comparison figure.
mlp_extrapolated = model_pseudo(x_train, 2 * t).detach().numpy()
node_extrapolated = model_node(x_train, 2 * t).detach().numpy()

In [None]:
# Final comparison: first state component of every model vs. the reference
# over twice the training horizon, in physical time (speed * t).
# Use a separate NumPy copy instead of rebinding `t`, so the cell can be re-run
# (the old `t = t.numpy()` failed on a second execution).
t_np = np.asarray(t)
time_axis = 2 * speed * t_np

fig, ax = plt.subplots()
# Values are plotted scaled by 10 (NOTE(review): confirm this matches the mV
# axis label given the scaling of x0 and ode.RescaledHodgkinHuxley).
ax.plot(time_axis, 10 * sol_extrapolated[..., 0], label='Baseline')
ax.plot(time_axis, 10 * mlp_extrapolated[..., 0], label='MLP', marker="+", markevery=3, linestyle='--')
ax.plot(time_axis, 10 * node_extrapolated[..., 0], label='NODE', marker="x", markevery=3, linestyle='-.')
ax.plot(time_axis, 10 * trajectory_extrapolated[..., 0], label='NCF', marker="o", markevery=3, linestyle='-.')
lims = [0, 70]
ax.set_ylim(lims)
ax.set_xlim([0, 2 * speed])
# ax.set_aspect(3/4)
# Training data covers t in [0, 1], i.e. physical time [0, speed].
ax.vlines(speed, *lims, linestyles='dashed', colors='gray', label='Training Limit')
ax.set_title("Hodgkin-Huxley Neuron Model")
ax.legend()
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Neuron Potential (mV)")