In [1]:
import diffrax
import jax
from jax.experimental import ode
jax.config.update("jax_enable_x64", True)

In [2]:
PARAMS = jax.numpy.array([1.0, 1.0], dtype=jax.numpy.float64)
MATRIX = jax.numpy.array([[2.0, 1.0], [1.0, 2.0]], dtype=jax.numpy.float64)
Y_INIT = jax.numpy.array([0.0, 0.0], dtype=jax.numpy.float64)
TIMES = jax.numpy.linspace(0, 1, 100)



In [3]:
def my_ode(t, y, args):
    return MATRIX @ y + jax.numpy.array([jax.numpy.exp(t), jax.numpy.sin(t)])

In [6]:
@jax.jit
def diffrax_loss(times):
    solution = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(my_ode),
        solver=diffrax.Dopri5(),
        t0=TIMES.min(),
        t1=TIMES.max(),
        dt0=None,
        y0=Y_INIT,
        stepsize_controller=diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8),
        saveat=diffrax.SaveAt(ts=times),
        adjoint=diffrax.BacksolveAdjoint(),
    )
    return solution.ys

In [22]:
other_test = diffrax_loss(TIMES)

In [9]:
def jax_ode(y, t, args):
    return MATRIX @ y + jax.numpy.array([jax.numpy.exp(t), jax.numpy.sin(t)])

In [18]:
@jax.jit
def jax_loss(time):
    solution = ode.odeint(jax_ode, Y_INIT, time, 0)
    return solution

In [21]:
test = jax_loss(TIMES)

In [15]:
step_size = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
times = diffrax.SaveAt(ts=time)
t_initial = time.min()
t_final = time.max()

adjoint_controller = diffrax.PIDController(
    norm=diffrax.adjoint_rms_seminorm, 
    rtol=1.4e-8, atol=1.4e-8)

grad = diffrax.BacksolveAdjoint(
    stepsize_controller=adjoint_controller,
    solver=diffrax.Dopri5())

solver = diffrax.Dopri5()

In [16]:
@jax.jit
def diffrax_loss(args, /, t0=t_initial, t1=t_final, dydx=diffrax_ode, adjoint=grad,
                steps=step_size, saves=times, y0=Y_INIT, solver=solver, data=data.T):
    solution = diffrax.diffeqsolve(
        t0=t0, t1=t1, dt0=None, terms=dydx, stepsize_controller=steps, 
        saveat=saves, adjoint=adjoint, solver=solver, args=args, y0=y0)
    sum_squares = (solution.ys - data) ** 2
    return jax.numpy.sum(sum_squares)

In [17]:
%%timeit
diffrax_loss(PARAMS)

122 µs ± 54 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
jax_grad = jax.jit(jax.grad(jax_loss))
diffrax_grad = jax.jit(jax.grad(diffrax_loss))

In [19]:
%%timeit
jax_grad(PARAMS)

155 µs ± 61.5 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [166]:
%%timeit
diffrax_grad(PARAMS)

2.34 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
