In [1]:
import diffrax
import jax
from jax.experimental import ode
jax.config.update("jax_enable_x64", True)

In [2]:
# Benchmark inputs for the 2-D linear ODE  dy/dt = MATRIX @ y + [exp(t), sin(t)].
# NOTE(review): PARAMS is not referenced by any cell visible here; remove if unused.
PARAMS = jax.numpy.ones(2, dtype=jax.numpy.float64)
MATRIX = jax.numpy.asarray(
    [[2.0, 1.0],
     [1.0, 2.0]],
    dtype=jax.numpy.float64,
)
Y_INIT = jax.numpy.zeros(2, dtype=jax.numpy.float64)  # start from the origin
# 100 evenly spaced save points on [0, 1], both endpoints included.
TIMES = jax.numpy.linspace(0, 1, num=100)



In [41]:
def my_ode(t, y, args):
    """Vector field dy/dt = MATRIX @ y + [exp(t), sin(t)] in diffrax's argument order.

    Args:
        t: Scalar time.
        y: State vector; presumably shape (2,) to match MATRIX and Y_INIT.
        args: Required by diffrax's ``(t, y, args)`` signature; ignored here.

    Returns:
        The time derivative of ``y``.
    """
    # Time-dependent forcing, one component per state entry.
    forcing = jax.numpy.array([jax.numpy.exp(t), jax.numpy.sin(t)])
    return MATRIX @ y + forcing

In [42]:
@jax.jit
def diffrax_loss(times):
    """Solve ``my_ode`` from ``Y_INIT`` with diffrax and return the saved states.

    The integration interval is taken from ``times`` itself (first and last
    entry), so it always agrees with the requested save points. Previously it
    was read from the global ``TIMES``, which broke calls with any other grid.

    Args:
        times: 1-D array of increasing save times; ``Y_INIT`` is imposed at
            ``times[0]``, matching ``odeint``'s convention.

    Returns:
        ``solution.ys``: the state at every entry of ``times`` (one row per
        time point). Despite the name, this is the trajectory, not a scalar loss.
    """
    t_start = times[0]
    t_end = times[-1]
    solution = diffrax.diffeqsolve(
        terms=diffrax.ODETerm(my_ode),
        solver=diffrax.Dopri5(),
        t0=t_start,
        t1=t_end,
        dt0=None,  # let the adaptive controller pick the initial step
        y0=Y_INIT,
        # rtol/atol equal odeint's defaults (1.4e-8) so both solvers are
        # benchmarked at the same requested accuracy.
        stepsize_controller=diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8),
        saveat=diffrax.SaveAt(ts=times),
        # Only affects reverse-mode gradients; the forward solve timed below is unchanged.
        adjoint=diffrax.BacksolveAdjoint(),
    )
    return solution.ys

In [47]:
%%timeit
other_test = diffrax_loss(TIMES)

105 µs ± 9.69 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [9]:
def jax_ode(y, t, args):
    """Vector field in ``jax.experimental.ode.odeint``'s ``func(y, t, *args)`` order.

    Delegates to ``my_ode`` (diffrax's ``(t, y, args)`` order) rather than
    re-implementing dy/dt = MATRIX @ y + [exp(t), sin(t)]. This keeps the
    odeint and diffrax benchmarks on exactly the same right-hand side.

    Args:
        y: State vector.
        t: Scalar time.
        args: Extra argument forwarded by odeint; passed through (my_ode ignores it).

    Returns:
        The time derivative of ``y``.
    """
    return my_ode(t, y, args)

In [44]:
@jax.jit
def jax_loss(times):
    """Integrate ``my_ode`` with odeint and reduce the whole trajectory to its sum.

    odeint calls its vector field as ``f(y, t, *args)``, so ``my_ode`` (which
    takes ``(t, y, args)``) is wrapped with its first two arguments swapped.

    NOTE(review): another cell below also defines ``jax_loss`` and returns the
    full trajectory instead of this scalar. Whichever cell ran last wins, so
    the timing cell's result depends on execution order. Keep only one.
    """
    def swapped_rhs(y, t, _unused):
        return my_ode(t, y, 0)

    trajectory = ode.odeint(swapped_rhs, Y_INIT, times, 0)
    return jax.numpy.sum(trajectory)

In [18]:
@jax.jit
def jax_loss(time):
    """Integrate ``jax_ode`` from ``Y_INIT`` with odeint and return the trajectory.

    Args:
        time: 1-D array of output times; ``Y_INIT`` is the state at ``time[0]``.

    Returns:
        The state at every entry of ``time`` (one row per time point).

    NOTE(review): the cell above also defines ``jax_loss`` (returning a scalar
    sum). Run top to bottom, this definition shadows that one.
    """
    unused_arg = 0  # odeint forwards extra positional args to jax_ode's `args`
    return ode.odeint(jax_ode, Y_INIT, time, unused_arg)

In [46]:
%%timeit
test = jax_loss(TIMES)

22.3 µs ± 696 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
