In [89]:
import diffrax
import jax
from jax.experimental import ode
import matplotlib.pyplot as pyplot

In [91]:
# JAX defaults to single precision: unless x64 mode is switched on, the float64
# dtypes requested below are truncated to float32 (with only a warning).
# The call is idempotent, so it is harmless if x64 was already enabled.
jax.config.update("jax_enable_x64", True)

# Multipliers applied to the two components of the driving term.
PARAMS = jax.numpy.array([1.0, 1.0], dtype=jax.numpy.float64)
# Coefficient matrix of the linear part of the ODE right-hand side.
MATRIX = jax.numpy.array([[2.0, 1.0], [1.0, 2.0]], dtype=jax.numpy.float64)
# Initial state handed to both ODE solvers.
Y_INIT = jax.numpy.array([0.0, 0.0], dtype=jax.numpy.float64)

In [92]:
@jax.jit
def driving_term(t, args):
    """Evaluate the forcing vector ``[exp(t), sin(t)]`` multiplied by ``args``.

    Args:
        t: Time value(s) at which the forcing is evaluated.
        args: Multiplier applied to the stacked forcing components.

    Returns:
        The array ``[exp(t), sin(t)] * args``.
    """
    exponential = jax.numpy.exp(t)
    sinusoid = jax.numpy.sin(t)
    forcing = jax.numpy.array([exponential, sinusoid])
    return forcing * args

In [93]:
@jax.jit
def jax_ode(y, t, args):
    """Right-hand side of the ODE, with the state as the first argument.

    This is the callable handed to ``ode.odeint`` in ``jax_loss``.

    Args:
        y: Current state vector.
        t: Current time.
        args: Parameters forwarded to ``driving_term``.

    Returns:
        ``MATRIX @ y`` plus the driving term evaluated at ``t``.
    """
    linear_part = MATRIX @ y
    forcing = driving_term(t, args)
    return linear_part + forcing

In [94]:
@diffrax.ODETerm
@jax.jit
def diffrax_ode(t, y, args):
    """Right-hand side of the ODE, with time as the first argument.

    The jitted function is wrapped in ``diffrax.ODETerm`` so that it can be
    passed as ``terms`` to ``diffrax.diffeqsolve``.

    Args:
        t: Current time.
        y: Current state vector.
        args: Parameters forwarded to ``driving_term``.

    Returns:
        ``MATRIX @ y`` plus the driving term evaluated at ``t``.
    """
    linear_part = MATRIX @ y
    forcing = driving_term(t, args)
    return linear_part + forcing

In [95]:
@jax.jit
def y(t):
    """Evaluate the two-component reference curve used to build ``data``.

    Each component is a fixed linear combination of ``exp(3t)``, ``exp(t)``,
    ``cos(t)`` and ``sin(t)``; both components are zero at ``t = 0``.

    NOTE(review): a quick substitution suggests the ``exp(t)`` and ``sin(t)``
    coefficients do not satisfy ``y' = MATRIX @ y + driving_term(t, [1, 1])``.
    Confirm which system this curve is meant to represent.

    Args:
        t: Time value(s) at which to evaluate the curve.

    Returns:
        Array whose leading axis holds the two components ``[y_1, y_2]``.
    """
    exp_3t = jax.numpy.exp(3 * t)
    exp_t = jax.numpy.exp(t)
    cos_t = jax.numpy.cos(t)
    sin_t = jax.numpy.sin(t)

    first = 0.05 * exp_3t - 0.25 * exp_t + 0.2 * cos_t + 0.1 * sin_t
    second = 0.05 * exp_3t + 0.25 * exp_t - 0.3 * cos_t + 0.4 * sin_t
    return jax.numpy.array([first, second])

In [121]:
# Sample grid: 100 evenly spaced points from 0 to 1.
time = jax.numpy.linspace(start=0, stop=1, num=100)
# Synthetic observations: the reference curve plus standard-normal noise drawn
# from a fixed PRNG key, so the same data are produced on every run.
data = y(time) + jax.random.normal(
    key=jax.random.PRNGKey(0),
    shape=(2, 100),
)

In [122]:
@jax.jit
def jax_loss(args):
    """Squared residuals between the ``odeint`` solution and ``data``.

    Integrates ``jax_ode`` from ``Y_INIT`` over the ``time`` grid and compares
    the result with the transposed observations.

    Args:
        args: Parameters forwarded to the ODE right-hand side.

    Returns:
        Element-wise squared differences; the array is not reduced to a scalar.
    """
    trajectory = ode.odeint(jax_ode, Y_INIT, time, args)
    residual = trajectory - data.T
    return residual ** 2

In [124]:
%%timeit
jax_loss(PARAMS)

15.6 µs ± 562 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [125]:
# Adaptive step-size controller for the forward solve.
step_size = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
# Request the solution at the points of the `time` grid.
times = diffrax.SaveAt(ts=time)
# Integration bounds come from the same grid that is passed to SaveAt.
# (The previous `jax_times` is not defined anywhere in this notebook, so the
# cell failed on a fresh kernel.)
t_initial = time.min()
t_final = time.max()

# Step-size controller for the backward (adjoint) solve; same tolerances as the
# forward solve, with the error norm set to diffrax.adjoint_rms_seminorm.
adjoint_controller = diffrax.PIDController(
    norm=diffrax.adjoint_rms_seminorm,
    rtol=1.4e-8, atol=1.4e-8)

# Adjoint method passed to diffeqsolve in diffrax_loss.
grad = diffrax.BacksolveAdjoint(
    stepsize_controller=adjoint_controller,
    solver=diffrax.Dopri5())

In [126]:
@jax.jit
def diffrax_loss(args):
    """Squared residuals between the diffrax solution and ``data``.

    Solves ``diffrax_ode`` from ``t_initial`` to ``t_final`` with Dopri5,
    saving the state at the points configured in ``times``.

    Args:
        args: Parameters forwarded to the ODE right-hand side.

    Returns:
        Element-wise squared differences; the array is not reduced to a scalar.
    """
    result = diffrax.diffeqsolve(
        terms=diffrax_ode,
        solver=diffrax.Dopri5(),
        t0=t_initial,
        t1=t_final,
        dt0=None,
        y0=Y_INIT,
        args=args,
        saveat=times,
        stepsize_controller=step_size,
        adjoint=grad,
    )
    residual = result.ys - data.T
    return residual ** 2

In [129]:
%%timeit
diffrax_loss(PARAMS)

The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached.
413 µs ± 519 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
