In [10]:
import diffrax
import jax

In [48]:
@jax.tree_util.Partial
def vector_field(t, y, args):
    """Right-hand side of the linear decay ODE ``dy/dt = -args * y``.

    ``t`` is unused (the system is autonomous) but is part of the
    ``f(t, y, args)`` signature expected by ``diffrax.ODETerm``.
    Wrapping with ``jax.tree_util.Partial`` turns the function into a pytree,
    so it can be passed as an argument through JAX transformations.
    """
    decay_rate = args
    return -decay_rate * y

In [96]:
def construct_solver(derivative, y_initial, time, solver, rtol=1e-5, atol=1e-5):
    """Build a jitted function that solves an ODE and returns it on a fixed time grid.

    Parameters
    ----------
    derivative : callable
        Vector field ``f(t, y, args)``; wrapped in ``diffrax.ODETerm``.
    y_initial : scalar or array
        Initial state ``y(t0)``.
    time : array-like
        1-D times at which the solution is saved (``diffrax.SaveAt(ts=time)``),
        with at least two points. ``t0``/``t1`` are its min/max and the
        initial step size is ``time[1] - time[0]``.
    solver : diffrax solver instance, e.g. ``diffrax.Dopri5()``.
    rtol, atol : float, optional
        Tolerances of the adaptive ``diffrax.PIDController`` (default 1e-5).

    Returns
    -------
    jax.tree_util.Partial
        ``solve(args, /, ...)`` returning ``sol.ys``, the solution at each
        entry of ``time``. Its keyword defaults carry the configuration built
        here; passing any of them explicitly overrides it.

    Raises
    ------
    ValueError
        If ``time`` is not 1-D or has fewer than two points.
    """
    time = jax.numpy.asarray(time)
    # dt0 below needs time[1]; fail with a clear message instead of an IndexError.
    if time.ndim != 1 or time.shape[0] < 2:
        raise ValueError(
            f"time must be a 1-D array with at least two points, got shape {time.shape}"
        )

    term = diffrax.ODETerm(derivative)
    saveat = diffrax.SaveAt(ts=time)
    stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol)

    t0 = time.min()
    t1 = time.max()
    dt0 = time[1] - time[0]

    @jax.tree_util.Partial
    @jax.jit
    def solve(args, /, term=term, solver=solver, saveat=saveat,
              stepsize=stepsize_controller, t0=t0, t1=t1, dt0=dt0, y0=y_initial):
        """Integrate from ``t0`` to ``t1`` with parameters ``args``; return ``sol.ys``."""
        sol = diffrax.diffeqsolve(
            args=args, terms=term, solver=solver, t0=t0, t1=t1, dt0=dt0, y0=y0,
            saveat=saveat, stepsize_controller=stepsize,
        )
        return sol.ys

    return solve

In [118]:
def construct_log_likelihood(func, data=None, error=None, seed=0, size=4):
    """Build a jitted Gaussian (chi-squared) log-likelihood for the model ``func``.

    Parameters
    ----------
    func : callable
        Model mapping ``args`` to predictions compared elementwise with
        ``data`` (e.g. the ``solve`` returned by ``construct_solver``).
    data, error : array-like, optional
        Observations and their per-point uncertainties. Each one that is
        omitted is replaced by synthetic values of length ``size`` drawn from
        ``jax.random.PRNGKey(seed)``: ``data`` from ``jax.random.uniform`` and
        ``error`` from ``jax.random.normal``.
    seed : int, optional
        Seed for the synthetic draws (default 0).
    size : int, optional
        Length of the synthetic arrays (default 4); should match the length of
        ``func``'s output.

    Returns
    -------
    jax.tree_util.Partial
        ``log_likelihood(args, /, data=..., error=...)`` returning
        ``-0.5 * sum((func(args) - data) ** 2 / error ** 2)``.
    """
    # Split the key: drawing data and error from the *same* key would make
    # both arrays deterministic functions of the same random bits.
    key_data, key_error = jax.random.split(jax.random.PRNGKey(seed))
    if data is None:
        data = jax.random.uniform(key_data, (size,))
    if error is None:
        # NOTE: standard-normal draws can be close to zero, which inflates
        # chi-squared; the sign is irrelevant because error is squared.
        error = jax.random.normal(key_error, (size,))
    data = jax.numpy.asarray(data)
    error = jax.numpy.asarray(error)

    @jax.tree_util.Partial
    @jax.jit
    def log_likelihood(args, /, data=data, error=error):
        """Return ``-0.5 * chi_squared`` of ``data`` against ``func(args)``."""
        solution = func(args)
        chi_squared = (solution - data) ** 2 / error ** 2
        return - 0.5 * jax.numpy.sum(chi_squared)

    return log_likelihood

In [119]:
# Decay problem y' = -args * y with y(0) = 1, saved at t = 0, 1, 2, 3.
solve = construct_solver(
    derivative=vector_field,
    y_initial=1,
    time=jax.numpy.arange(4.0),
    solver=diffrax.Dopri5(),
)
likelihood = construct_log_likelihood(func=solve)

In [125]:
# First and second derivatives of the log-likelihood w.r.t. its scalar parameter.
# jax.hessian(f) is defined as jacfwd(jacrev(f)), i.e. forward-over-reverse.
jacobian = jax.jit(jax.jacrev(likelihood))
hessian = jax.jit(jax.hessian(likelihood))

In [124]:
%%timeit
likelihood(1.0)

255 µs ± 4.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [128]:
%%timeit
jacobian(1.0)

372 µs ± 27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [130]:
%%timeit
hessian(1.0)

430 µs ± 62.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
