In [149]:
import diffrax
import jax
from jax.experimental import ode

In [150]:
PARAMS = jax.numpy.array([1.0, 1.0], dtype=jax.numpy.float64)
MATRIX = jax.numpy.array([[2.0, 1.0], [1.0, 2.0]], dtype=jax.numpy.float64)
Y_INIT = jax.numpy.array([0.0, 0.0], dtype=jax.numpy.float64)

  lax._check_user_dtype_supported(dtype, "array")


In [151]:
@jax.jit
def driving_term(t, args):
    return jax.numpy.array([jax.numpy.exp(t), jax.numpy.sin(t)]) * args

In [152]:
@jax.jit
def jax_ode(y, t, args):
    return MATRIX @ y + driving_term(t, args)

In [153]:
@diffrax.ODETerm
@jax.jit
def diffrax_ode(t, y, args):
    return MATRIX @ y + driving_term(t, args)

In [154]:
@jax.jit
def y(t):
    y_1 = 0.05 * jax.numpy.exp(3 * t) - 0.25 * jax.numpy.exp(t) + \
        0.2 * jax.numpy.cos(t) + 0.1 * jax.numpy.sin(t)
    y_2 = 0.05 * jax.numpy.exp(3 * t) + 0.25 * jax.numpy.exp(t) + \
        - 0.3 * jax.numpy.cos(t) + 0.4 * jax.numpy.sin(t)
    return jax.numpy.array([y_1, y_2])

In [155]:
time = jax.numpy.linspace(0, 1, 100)
data = y(time) + jax.random.normal(jax.random.PRNGKey(0), (2, 100))

In [156]:
@jax.jit
def jax_loss(args, /, dydx=jax_ode, y0=Y_INIT, t=time, data=data.T):
    solution = ode.odeint(dydx, y0, t, args)
    sum_squares = (solution - data) ** 2
    return jax.numpy.sum(sum_squares)

In [157]:
%%timeit
jax_loss(PARAMS)

The slowest run took 53.17 times longer than the fastest. This could mean that an intermediate result is being cached.
184 µs ± 390 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [158]:
step_size = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
times = diffrax.SaveAt(ts=time)
t_initial = jax_times.min()
t_final = jax_times.max()

adjoint_controller = diffrax.PIDController(
    norm=diffrax.adjoint_rms_seminorm, 
    rtol=1.4e-8, atol=1.4e-8)

grad = diffrax.BacksolveAdjoint(
    stepsize_controller=adjoint_controller,
    solver=diffrax.Dopri5())

solver = diffrax.Dopri5()

In [159]:
@jax.jit
def diffrax_loss(args, /, t0=t_initial, t1=t_final, dydx=diffrax_ode, adjoint=grad,
                steps=step_size, saves=times, y0=Y_INIT, solver=solver, data=data.T):
    solution = diffrax.diffeqsolve(
        t0=t0, t1=t1, dt0=None, terms=dydx, stepsize_controller=steps, 
        saveat=saves, adjoint=adjoint, solver=solver, args=args, y0=y0)
    sum_squares = (solution.ys - data) ** 2
    return jax.numpy.sum(sum_squares)

In [160]:
%%timeit
diffrax_loss(PARAMS)

219 µs ± 99.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [161]:
jax_grad = jax.jit(jax.grad(jax_loss))
diffrax_grad = jax.jit(jax.grad(diffrax_loss))

In [163]:
%%timeit
jax_grad(PARAMS)

220 µs ± 1.78 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [166]:
%%timeit
diffrax_grad(PARAMS)

2.34 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
