In [3]:
from typing import Union

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import optax

In [8]:
def lipswish(x):
    """LipSwish activation: SiLU (``x * sigmoid(x)``) shrunk by a factor of 0.909.

    SiLU's steepest slope is roughly 1.0998, so multiplying by 0.909 keeps the
    slope at or below ~1, i.e. the activation is (approximately) 1-Lipschitz.

    Args:
        x: scalar or array input; ``jax.nn.silu`` is applied elementwise.

    Returns:
        Array with the same shape as ``x``.
    """
    silu_x = jnn.silu(x)
    return 0.909 * silu_x

In [9]:
class VectorField(eqx.Module):
    """Drift vector field ``f(t, y)`` acting on a hidden state.

    The scalar time ``t`` is concatenated onto the hidden state ``y``. The result
    goes through an MLP whose output is squashed by ``tanh`` and then multiplied
    (elementwise) by ``scale``.
    """

    # Either the Python int ``1`` (no scaling) or a ``(hidden_size,)`` array of
    # per-channel multipliers drawn uniformly from [0.9, 1.1].
    scale: Union[int, jnp.ndarray]
    # Maps ``hidden_size + 1`` inputs (time + state) to ``hidden_size`` outputs.
    mlp: eqx.nn.MLP

    def __init__(self, hidden_size, width_size, depth, scale, *, key, **kwargs):
        """Build the vector field.

        Args:
            hidden_size: dimension of the hidden state ``y``.
            width_size: width of each hidden layer of the MLP.
            depth: number of hidden layers of the MLP.
            scale: if truthy, draw a random per-channel output scale in
                [0.9, 1.1]; otherwise use a constant scale of 1.
            key: JAX PRNG key used to initialise ``scale`` and the MLP.
            **kwargs: forwarded to ``eqx.Module.__init__``.
        """
        super().__init__(**kwargs)
        scale_key, mlp_key = jrandom.split(key)
        if scale:
            self.scale = jrandom.uniform(
                scale_key, (hidden_size,), minval=0.9, maxval=1.1
            )
        else:
            self.scale = 1
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size + 1,
            out_size=hidden_size,
            width_size=width_size,
            depth=depth,
            activation=lipswish,
            final_activation=jnn.tanh,
            # Use the key split off above (``mlp.key`` was an undefined name).
            key=mlp_key,
        )

    # Defined at class level (it was previously nested inside ``__init__``,
    # which left the module without a ``__call__``).
    def __call__(self, t, y, args):
        """Evaluate the drift at time ``t`` and state ``y``.

        Args:
            t: scalar time (``t[None]`` turns it into a length-1 vector).
            y: hidden state of length ``hidden_size``.
            args: unused; present to match diffrax's vector-field signature.

        Returns:
            Array of shape ``(hidden_size,)``.
        """
        return self.scale * self.mlp(jnp.concatenate([t[None], y]))

In [10]:
class ControlledVectorField(eqx.Module):
    """Matrix-valued vector field ``g(t, y)`` for a controlled (SDE/CDE) term.

    The scalar time is concatenated onto the hidden state ``y`` and passed through
    an MLP. The flat output is reshaped into a ``(hidden_size, control_size)``
    matrix and multiplied elementwise by ``scale``.
    """

    # Either the Python int ``1`` (no scaling) or a ``(hidden_size, control_size)``
    # array of multipliers drawn uniformly from [0.9, 1.1].
    scale: Union[int, jnp.ndarray]
    # Maps ``hidden_size + 1`` inputs to ``hidden_size * control_size`` outputs.
    mlp: eqx.nn.MLP
    control_size: int
    hidden_size: int

    def __init__(
        self, control_size, hidden_size, width_size, depth, scale, *, key, **kwargs
    ):
        """Build the controlled vector field.

        ``width_size`` is the third positional parameter, matching how
        ``NeuralSDE``/``NeuralCDE`` construct this class (it was previously read
        from an undefined name).

        Args:
            control_size: number of channels of the driving control path.
            hidden_size: dimension of the hidden state ``y``.
            width_size: width of each hidden layer of the MLP.
            depth: number of hidden layers of the MLP.
            scale: if truthy, draw a random output scale in [0.9, 1.1];
                otherwise use a constant scale of 1.
            key: JAX PRNG key used to initialise ``scale`` and the MLP.
            **kwargs: forwarded to ``eqx.Module.__init__``.
        """
        super().__init__(**kwargs)
        scale_key, mlp_key = jrandom.split(key)
        if scale:
            self.scale = jrandom.uniform(
                scale_key, (hidden_size, control_size), minval=0.9, maxval=1.1
            )
        else:
            self.scale = 1
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size + 1,
            out_size=hidden_size * control_size,
            width_size=width_size,
            depth=depth,
            activation=lipswish,
            final_activation=jnn.tanh,
            key=mlp_key,
        )
        self.control_size = control_size
        self.hidden_size = hidden_size

    def __call__(self, t, y, args):
        """Evaluate the field at time ``t`` and state ``y``.

        Args:
            t: scalar time (``t[None]`` turns it into a length-1 vector).
            y: hidden state of length ``hidden_size``.
            args: unused; present to match diffrax's vector-field signature.

        Returns:
            Array of shape ``(hidden_size, control_size)``.
        """
        flat = self.mlp(jnp.concatenate([t[None], y]))
        return self.scale * flat.reshape(self.hidden_size, self.control_size)

In [12]:
class NeuralSDE(eqx.Module):
    """Neural SDE that maps random noise to a sampled path observed at ``ts``.

    The initial hidden state is an MLP of Gaussian noise. The hidden state then
    evolves under a drift ``vf`` plus a diffusion ``cvf`` driven by Brownian
    motion, and each saved hidden state is mapped to data space by ``readout``.
    """

    initial: eqx.nn.MLP
    vf: VectorField
    cvf: ControlledVectorField
    readout: eqx.nn.Linear
    # Dimension of the Gaussian noise fed to ``initial``.
    initial_noise_size: int
    # Number of channels of the driving Brownian motion.
    noise_size: int

    def __init__(
        self,
        data_size,
        initial_noise_size,
        noise_size,
        hidden_size,
        width_size,
        depth,
        *,
        key,
        **kwargs,
    ):
        """Build the generator.

        Args:
            data_size: dimension of each generated observation.
            initial_noise_size: dimension of the noise sampled for ``y0``.
            noise_size: number of Brownian-motion channels.
            hidden_size: dimension of the SDE hidden state.
            width_size: hidden-layer width of every MLP.
            depth: number of hidden layers of every MLP.
            key: JAX PRNG key for parameter initialisation.
            **kwargs: forwarded to ``eqx.Module.__init__``.
        """
        super().__init__(**kwargs)
        initial_key, vf_key, cvf_key, readout_key = jrandom.split(key, 4)

        self.initial = eqx.nn.MLP(
            initial_noise_size, hidden_size, width_size, depth, key=initial_key
        )
        self.vf = VectorField(hidden_size, width_size, depth, scale=True, key=vf_key)
        self.cvf = ControlledVectorField(
            noise_size, hidden_size, width_size, depth, scale=True, key=cvf_key
        )
        self.readout = eqx.nn.Linear(hidden_size, data_size, key=readout_key)
        self.initial_noise_size = initial_noise_size
        self.noise_size = noise_size

    def __call__(self, ts, *, key):
        """Sample one path.

        The first parameter was previously misspelled ``slef``, which made every
        ``self.`` reference raise NameError.

        Args:
            ts: 1-D array of output times; ``ts[0]``/``ts[-1]`` bound the solve.
            key: JAX PRNG key for the initial noise and the Brownian motion.

        Returns:
            Array of shape ``(len(ts), data_size)``.
        """
        t0 = ts[0]
        t1 = ts[-1]
        dt0 = 1.0
        init_key, bm_key = jrandom.split(key, 2)
        init = jrandom.normal(init_key, (self.initial_noise_size,))
        # Brownian control sampled to half a step's resolution.
        control = diffrax.VirtualBrownianTree(
            t0=t0, t1=t1, tol=dt0 / 2, shape=(self.noise_size,), key=bm_key
        )
        vf = diffrax.ODETerm(self.vf)
        cvf = diffrax.ControlTerm(self.cvf, control)
        terms = diffrax.MultiTerm(vf, cvf)
        # ReversibleHeun is a cheap choice of SDE solver. We could also use Euler etc.
        solver = diffrax.ReversibleHeun()
        y0 = self.initial(init)
        saveat = diffrax.SaveAt(ts=ts)
        sol = diffrax.diffeqsolve(
            terms, solver, t0, t1, dt0, y0, saveat=saveat, max_steps=64
        )
        return jax.vmap(self.readout)(sol.ys)

In [None]:
class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    vf: VectorField
    cvf: ControlledVectorField
    readout: eqx.nn.Linear
        
    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        """Build the neural-CDE discriminator.

        Args:
            data_size: number of channels of each observation.
            hidden_size: dimension of the CDE hidden state.
            width_size: hidden-layer width of every MLP.
            depth: number of hidden layers of every MLP.
            key: JAX PRNG key for parameter initialisation.
            **kwargs: forwarded to ``eqx.Module.__init__``.
        """
        super().__init__(**kwargs)
        initial_key, vf_key, cvf_key, readout_key = jrandom.split(key, 4)

        # Takes ``data_size + 1`` inputs, presumably the start time concatenated
        # with the first observation.
        self.initial = eqx.nn.MLP(
            data_size + 1, hidden_size, width_size, depth, key=initial_key
        )
        self.vf = VectorField(hidden_size, width_size, depth, scale=True, key=vf_key)
        # The control is the interpolated data path (``LinearInterpolation(ts, ys)``
        # in ``__call__``), so its channel count is ``data_size``. The previous
        # ``noise_size`` is not defined in this scope.
        self.cvf = ControlledVectorField(
            data_size, hidden_size, width_size, depth, scale=True, key=cvf_key
        )
        # Scalar score per saved time.
        self.readout = eqx.nn.Linear(hidden_size, 1, key=readout_key)
    def __call__(self, ts, ys, *, key=None):
        """Score a path by solving a neural CDE driven by the interpolated data.

        The previous signature ``(slef, ts, *, key)`` could never run. It used
        ``self``, ``ys`` and ``init`` without defining them, so ``ys`` is now an
        explicit argument.

        Args:
            ts: 1-D array of observation times; ``ts[0]``/``ts[-1]`` bound the solve.
            ys: observations at ``ts``; assumed shape ``(len(ts), data_size)``.
                TODO confirm against the caller.
            key: unused, because the CDE is deterministic given ``(ts, ys)``.
                It is still accepted so call sites passing ``key=`` keep working.

        Returns:
            Array of shape ``(2, 1)``: the readout at ``t0`` and at ``t1``.
        """
        # Continuous-time control obtained by linearly interpolating the data.
        control = diffrax.LinearInterpolation(ts, ys)
        vf = diffrax.ODETerm(self.vf)
        cvf = diffrax.ControlTerm(self.cvf, control)
        terms = diffrax.MultiTerm(vf, cvf)
        solver = diffrax.ReversibleHeun()
        t0 = ts[0]
        t1 = ts[-1]
        dt0 = 1.0
        # ``self.initial`` expects ``data_size + 1`` inputs: start time + first point.
        init = jnp.concatenate([ts[0, None], ys[0]])
        y0 = self.initial(init)
        # Output at t0 (has seen only the initial point) and at t1 (whole path).
        saveat = diffrax.SaveAt(t0=True, t1=True)
        sol = diffrax.diffeqsolve(
            terms, solver, t0, t1, dt0, y0, saveat=saveat, max_steps=64
        )
        return jax.vmap(self.readout)(sol.ys)
    @eqx.filter_jit
    def clip_weights(self):
        leaves, treedef = jax.tree_flatten(
            self, is_leaf=lambda x: isinstance(x, eqx.nn.Linear)
        )
        new_leaves = []
        for leaf in leaves