<div style="background-color: #008B8B; padding: 15px; border-radius: 5px; font-size: 28px; color: black; font-weight: bold;">
Basics on Neural Latent SDE
</div>

<div style="background-color: #008B8B; padding: 15px; border-radius: 5px; font-size: 20px; color: black; font-weight: bold;">
Biblio
</div>

Neural ODEs:

**Neural ODEs (https://arxiv.org/abs/1806.07366) (2019)** : introduces the Neural ODE as the continuous-depth limit of a ResNet stack and shows how to train it with the adjoint sensitivity method. Seminal paper for Neural ODEs.

**Latent ODEs for Irregularly-Sampled Time Series (https://arxiv.org/abs/1907.03907) (2019)** : extends the Neural ODE model to a Neural ODE-RNN model, in which the approximate posterior is built with an RNN over past observations.
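
The ResNet/ODE correspondence from the first entry can be illustrated with a small NumPy sketch. It assumes a toy vector field `tanh(W h)` with a random matrix `W`; all names here (`vector_field`, `resnet_forward`) are illustrative and do not come from the paper. A residual block `h + step * f(h)` is one explicit Euler step of `dh/dt = f(h)`, so stacking more blocks with a smaller step should approach a fixed continuous-time solution.

```python
import numpy as np

def vector_field(h, W):
    """Toy 'layer' f(h) = tanh(W h); W is a hypothetical weight matrix."""
    return np.tanh(W @ h)

def resnet_forward(h0, W, n_layers, step):
    """Residual stack h_{k+1} = h_k + step * f(h_k), i.e. explicit Euler."""
    h = h0.copy()
    for _ in range(n_layers):
        h = h + step * vector_field(h, W)
    return h

rng = np.random.default_rng(0)
W = 0.5 * rng.standard_normal((3, 3))
h0 = rng.standard_normal(3)

# Same total "time" T = 1, split into more and more residual layers.
coarse = resnet_forward(h0, W, n_layers=10, step=0.1)
fine = resnet_forward(h0, W, n_layers=1000, step=0.001)
reference = resnet_forward(h0, W, n_layers=20000, step=5e-5)

err_coarse = np.linalg.norm(coarse - reference)
err_fine = np.linalg.norm(fine - reference)
print(f"error with 10 layers   : {err_coarse:.2e}")
print(f"error with 1000 layers : {err_fine:.2e}")
```

The deeper stack should land closer to the reference, which is the sense in which a Neural ODE is the limit of a ResNet. Real Neural ODEs replace the fixed-step loop with an adaptive solver and a learned vector field.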

Neural SDEs:

**SDE Matching: Scalable and Simulation-Free Training of Latent Stochastic Differential Equations (https://arxiv.org/abs/2502.02472 , 2025)** : good background section (Section 2) explaining Neural SDEs. Proposes a new method, SDE Matching, inspired by score and flow matching, as an alternative to the adjoint sensitivity method. SDE Matching is claimed to compute gradients and train latent SDEs more efficiently.

**Scalable Gradients for Stochastic Differential Equations (https://arxiv.org/abs/2001.01328) (2020)** : generalizes the adjoint sensitivity method to SDEs and combines it with gradient-based stochastic variational inference for infinite-dimensional VAEs.

**Neural SDEs (https://www.researchgate.net/publication/333418188_Neural_Stochastic_Differential_Equations) (2019)** : establishes the link between infinitely deep residual networks and solutions of stochastic differential equations.

**Stable Neural SDEs in analyzing irregular time series data (https://arxiv.org/abs/2402.14989) (2025)** : stresses the need for careful design of the drift and diffusion neural nets in latent SDEs. Introduces three latent SDE models with performance guarantees.

**Generative Modeling of Neural Dynamics via Latent Stochastic Differential Equations (https://arxiv.org/abs/2412.12112) (2024)** : application of neural SDEs to a biological use case (brain activity). Details the model, architecture, ELBO/loss computation. Takes into account inputs/commands in the model. 
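
All of the entries above build on an SDE of the form `dX = drift(t, X) dt + diffusion(t, X) dW`, where a Neural SDE uses neural networks for the two functions. As a minimal sketch of how such an equation is simulated, the NumPy code below applies the Euler-Maruyama scheme to closed-form stand-ins (geometric Brownian motion) rather than networks. The helper `euler_maruyama` and the constants `mu`, `sigma` are illustrative assumptions, chosen because the exact mean `exp(mu * t)` is known for this process.

```python
import numpy as np

def euler_maruyama(drift, diffusion, x0, ts, rng):
    """Simulate dX = drift dt + diffusion dW on the grid ts (diagonal noise)."""
    x = x0
    xs = [x]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        dw = rng.standard_normal(x.shape) * np.sqrt(dt)  # Brownian increment
        x = x + drift(t0, x) * dt + diffusion(t0, x) * dw
        xs.append(x)
    return np.stack(xs)  # (len(ts), batch)

# Stand-ins for the drift and diffusion networks of a Neural SDE.
mu, sigma = 0.5, 0.2
drift = lambda t, x: mu * x
diffusion = lambda t, x: sigma * x

rng = np.random.default_rng(0)
ts = np.linspace(0.0, 1.0, 201)
x0 = np.ones(20000)
paths = euler_maruyama(drift, diffusion, x0, ts, rng)

empirical_mean = paths[-1].mean()
exact_mean = np.exp(mu * ts[-1])
print(f"empirical mean at t=1 : {empirical_mean:.4f}")
print(f"exact mean at t=1     : {exact_mean:.4f}")
```

The empirical mean should be close to the exact one, up to Monte Carlo noise and the first-order discretization bias of the scheme.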

General/Misc:

**Efficient gradient computation for dynamical models (https://www.fil.ion.ucl.ac.uk/~wpenny/publications/efficient_revised.pdf) (2014)** : summary of the finite difference, forward sensitivity, and adjoint sensitivity methods for computing gradients of a cost functional. Applies to Neural ODE training.

**Cyclical Annealing Schedule: A Simple Approach to Mitigating KL Vanishing (https://arxiv.org/abs/1903.10145) (2019)** : explanation of the posterior collapse/KL vanishing problem, introduces different KL annealing schedules for VAE training.
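
As an illustration of the last entry, a cyclical KL weight can be sketched in a few lines. This is a simplified linear-ramp variant written from the general idea (the weight ramps from 0 to 1 during the first part of each cycle, then stays at 1 until the cycle restarts). The function name `cyclical_beta` and its parameters `n_cycles` and `ramp_fraction` are assumptions made for this sketch, not the paper's notation.

```python
import math

def cyclical_beta(step, total_steps, n_cycles=4, ramp_fraction=0.5):
    """KL weight in [0, 1] for a 0-based training step (linear ramp per cycle)."""
    cycle_len = math.ceil(total_steps / n_cycles)
    tau = (step % cycle_len) / cycle_len  # position inside the current cycle
    return min(1.0, tau / ramp_fraction)

betas = [cyclical_beta(s, total_steps=100, n_cycles=4) for s in range(100)]

# Typical use inside a training loop (hypothetical names):
#   loss = reconstruction_term + cyclical_beta(step, total_steps) * kl_term
print([round(b, 2) for b in betas[:30]])
```

Restarting the weight at 0 periodically gives the decoder several chances to use the latent code before the KL term is fully enforced, which is the mechanism the paper proposes against posterior collapse.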


<div style="background-color: #008B8B; padding: 15px; border-radius: 5px; font-size: 20px; color: black; font-weight: bold;">
Code : torchsde library by Google Research
</div>

https://github.com/google-research/torchsde

[1] Xuechen Li, Ting-Kam Leonard Wong, Ricky T. Q. Chen, David Duvenaud. "Scalable Gradients for Stochastic Differential Equations". International Conference on Artificial Intelligence and Statistics. 2020.

[2] Patrick Kidger, James Foster, Xuechen Li, Harald Oberhauser, Terry Lyons. "Neural SDEs as Infinite-Dimensional GANs". International Conference on Machine Learning. 2021.

[3] Patrick Kidger, James Foster, Xuechen Li, Terry Lyons. "Efficient and Accurate Gradients for Neural SDEs". 2021.

[4] Patrick Kidger, James Morrill, James Foster, Terry Lyons. "Neural Controlled Differential Equations for Irregular Time Series". Neural Information Processing Systems. 2020.



<div style="background-color: #008B8B; padding: 15px; border-radius: 5px; font-size: 20px; color: black; font-weight: bold;">
Basic manipulations of the torchsde library
</div>

See also https://github.com/google-research/torchsde/blob/master/examples/demo.ipynb

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torchsde
from torchdiffeq import odeint, odeint_adjoint

# from mpl_toolkits.mplot3d import Axes3D
import timeit

In [2]:
def seed_everything(seed=42):
    """
    Set seed for reproducibility.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything()

In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    dtype = torch.cuda.FloatTensor
else:
    device = torch.device('cpu')
    dtype = torch.FloatTensor

print(f"Using {device}")

torch.set_default_dtype(torch.float32)

if device.type == 'cuda':
    print('GPU Name:', torch.cuda.get_device_name(0))
    print('Total GPU Memory:', round(torch.cuda.get_device_properties(0).total_memory/1024**3,1), 'GB')

Using cuda
GPU Name: NVIDIA GeForce RTX 3080 Ti
Total GPU Memory: 11.8 GB


<div style="background-color: #008B8B; padding: 15px; border-radius: 5px; font-size: 20px; color: black; font-weight: bold;">
Verifying the home-made computation of the path KL: OK up to a factor of 2...
</div>

In [134]:
class LatentSDE(nn.Module):
    
    def __init__(self, prior_theta=10.0, posterior_theta=23.0, prior_mu=1.0, posterior_mu=2.0, sigma=0.5):
        super().__init__()
        self.noise_type="diagonal"
        self.sde_type="ito"
        
        # prior drift
        self.register_buffer("prior_theta", torch.tensor(prior_theta))
        self.register_buffer("prior_mu", torch.tensor(prior_mu))
        self.register_buffer("sigma", torch.tensor(sigma))
        
        # posterior drift
        self.posterior_theta = nn.Parameter(torch.tensor(posterior_theta), requires_grad=True)
        self.posterior_mu = nn.Parameter(torch.tensor(posterior_mu), requires_grad=True)
        
    # approximate posterior drift (OU form; time-independent, so t is unused)
    def f(self,t,z):
        return self.posterior_theta*(self.posterior_mu - z)
        
    # prior drift (OU form; time-independent, so t is unused)
    def h(self,t,z):
        return self.prior_theta*(self.prior_mu - z)
        
    # shared diffusion: constant sigma, shaped for diagonal noise
    def g(self,t,z):
        if t.dim()==0:
            return self.sigma.repeat(z.size(0), 1)
        else:
            # allocate on z's device instead of relying on the global `device`
            return self.sigma * torch.ones((t.size(0), z.size(1)), device=z.device)
        
sde = LatentSDE().to(device)
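
Written out, the three methods of `LatentSDE` define a pair of Ornstein-Uhlenbeck processes that share one diffusion. The symbols $\theta_p, \mu_p$ (for `prior_theta`, `prior_mu`) and $\theta_q, \mu_q$ (for `posterior_theta`, `posterior_mu`) are notation introduced here for readability; they do not appear in the code.

$$
\begin{aligned}
\text{prior:}\quad & dz_t = h(t, z_t)\,dt + g(t, z_t)\,dW_t, & h(t, z) &= \theta_p\,(\mu_p - z),\\
\text{posterior:}\quad & dz_t = f(t, z_t)\,dt + g(t, z_t)\,dW_t, & f(t, z) &= \theta_q\,(\mu_q - z),\\
\text{diffusion:}\quad & g(t, z) = \sigma.
\end{aligned}
$$

With the default arguments, $\theta_p = 10$, $\mu_p = 1$, $\theta_q = 23$, $\mu_q = 2$ and $\sigma = 0.5$. Only $\theta_q$ and $\mu_q$ are trainable (`nn.Parameter`); the prior parameters and $\sigma$ are registered as buffers.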

In [135]:
LENGTH = 100
t_start = 0.0
t_end = 1.0
times = torch.linspace(t_start,t_end,LENGTH).to(device)
print(f"times : {times.shape}")

times : torch.Size([100])


In [136]:
# sample K posterior paths and compute logqp along them
K = 30
z0s = torch.zeros((K,1)).to(device)
print(f"z0s : {z0s.shape}")

# compute SDEs
zs, logqp = torchsde.sdeint(sde, z0s, times, method="euler", dt=1e-3, logqp=True)

# compute KL by averaging over batch, summing over time
kl_path_1 = logqp.mean(dim=1).sum()

# report
print(f"zs : {zs.shape}")
print(f"logqp : {logqp.shape}")
print(f"KL via logqp : {kl_path_1}")

z0s : torch.Size([30, 1])
zs : torch.Size([100, 30, 1])
logqp : torch.Size([99, 30])
KL via logqp : 279.9403381347656
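
For reference, and assuming `logqp` follows the Girsanov-based formula of the "Scalable Gradients" paper listed in the Biblio, the quantity being estimated is the KL divergence between the posterior and prior path measures:

$$
\mathrm{KL}(Q \,\|\, P) = \mathbb{E}_Q\!\left[\int_{t_0}^{t_1} \tfrac{1}{2}\,\big|u(t, z_t)\big|^2\,dt\right],
\qquad
u(t, z) = \frac{f(t, z) - h(t, z)}{g(t, z)}.
$$

With the default parameters, $u(z) = \big(23\,(2 - z) - 10\,(1 - z)\big)/0.5 = 72 - 26\,z$. Starting from $z_0 = 0$, the posterior is an Ornstein-Uhlenbeck process with $\mathbb{E}[z_t] = 2\,(1 - e^{-23t})$ and $\mathrm{Var}[z_t] = \tfrac{\sigma^2}{2 \cdot 23}\,(1 - e^{-46t})$, so the expectation has a closed form:

$$
\mathrm{KL} = \int_0^1 \tfrac{1}{2}\Big[\big(20 + 52\,e^{-23t}\big)^2 + 26^2\,\tfrac{0.25}{46}\,\big(1 - e^{-46t}\big)\Big]\,dt \approx 276.4 .
$$

This is in the same range as the value printed above. The remaining gap is plausibly due to the Euler discretization (`dt=1e-3`) and to Monte Carlo noise with only `K = 30` paths.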


In [137]:
# compute logqp home made
epsilon = 1e-6

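# evaluate prior drift, posterior drift and shared diffusion along the sampled paths
prior_drifts = sde.h(times, zs)
posterior_drifts = sde.f(times, zs)
diff = sde.g(times, zs).unsqueeze(-1)
# guard against division by a (near-)zero diffusion
diff2 = torch.where(diff.abs().detach() > epsilon, diff, torch.full_like(diff, fill_value=epsilon) * diff.sign())
# squared scaled drift mismatch u^2 = ((f - h) / g)^2
# NOTE: the path-KL integrand is 0.5 * u^2; the 0.5 is not applied here, which most likely explains the factor of 2 with respect to logqp
deltas = torch.div(posterior_drifts - prior_drifts, diff2)**2
# trapezoidal rule over the time grid (the 1/2 below is the trapezoid weight, not the 0.5 of the integrand)
approx_int = torch.stack( [1/2 * (times[i+1]-times[i]) * (deltas[i,:,:] + deltas[i+1,:,:]) for i in range(times.shape[0]-1) ] )  # n_steps-1 x K x 1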

print(f"Home made KL : {approx_int.shape}")
kl_path_2 = approx_int.mean(dim=1).sum()
print(f"KL via home made : {kl_path_2}")

Home made KL : torch.Size([99, 30, 1])
KL via home made : 558.3657836914062
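
The factor of 2 can be checked without torchsde. The NumPy sketch below assumes that the path KL is the expected time integral of `0.5 * u**2` with `u = (f - h) / g`, reuses the default constants of `LatentSDE`, and compares three numbers: an Euler accumulation of the integrand, a trapezoidal estimate on a coarser grid, and the closed-form value for this Ornstein-Uhlenbeck pair. All variable names (`kl_euler`, `kl_trapz`, `kl_exact`, ...) are illustrative, and the simulation is a plain Euler-Maruyama loop, not torchsde's solver.

```python
import numpy as np

# Same constants as the LatentSDE defaults used in this notebook.
prior_theta, prior_mu = 10.0, 1.0
post_theta, post_mu = 23.0, 2.0
sigma = 0.5

def u(z):
    """Scaled drift mismatch (f - h) / g."""
    f = post_theta * (post_mu - z)
    h = prior_theta * (prior_mu - z)
    return (f - h) / sigma

# Euler-Maruyama simulation of the posterior SDE, starting from z0 = 0.
rng = np.random.default_rng(0)
n_paths, dt, n_steps = 2000, 1e-3, 1000
z = np.zeros(n_paths)
path = [z]
kl_euler = 0.0
for _ in range(n_steps):
    kl_euler += np.mean(0.5 * u(z) ** 2) * dt  # integrand WITH the 1/2
    dw = rng.standard_normal(n_paths) * np.sqrt(dt)
    z = z + post_theta * (post_mu - z) * dt + sigma * dw
    path.append(z)
path = np.stack(path)  # (n_steps + 1, n_paths)

# Trapezoidal rule on a coarser grid (every 10th solver step).
idx = np.arange(0, n_steps + 1, 10)
ts = idx * dt
u2 = u(path[idx]) ** 2  # (101, n_paths)
widths = np.diff(ts)[:, None]
trapz_u2 = np.sum(0.5 * widths * (u2[:-1] + u2[1:]), axis=0).mean()
kl_trapz = 0.5 * trapz_u2    # with the 1/2 of the integrand
kl_trapz_no_half = trapz_u2  # without it, as in the cell above

# Closed form for z0 = 0, using the OU mean and variance of the posterior.
T = n_steps * dt
c = u(post_mu)                                    # long-run value of E[u]
d = (post_theta - prior_theta) * post_mu / sigma  # transient amplitude of E[u]
b = (post_theta - prior_theta) / sigma            # |slope| of u with respect to z
var_inf = sigma ** 2 / (2 * post_theta)           # stationary posterior variance
e1 = (1 - np.exp(-post_theta * T)) / post_theta
e2 = (1 - np.exp(-2 * post_theta * T)) / (2 * post_theta)
kl_exact = (0.5 * c**2 * T + c * d * e1 + 0.5 * d**2 * e2
            + 0.5 * b**2 * var_inf * (T - e2))

print(f"Euler accumulation, with 1/2 : {kl_euler:.2f}")
print(f"trapezoid, with 1/2          : {kl_trapz:.2f}")
print(f"trapezoid, without 1/2       : {kl_trapz_no_half:.2f}")
print(f"closed form                  : {kl_exact:.2f}")
```

If the assumption holds, the two estimates that include the 1/2 should agree with the closed form up to discretization and sampling error, while the estimate without it should come out about twice as large, mirroring the gap between the two cells above.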
