# In [2]:
import diffrax
import jax
import jax.experimental.ode
import ticktack

# Forcing-event parameters handed to the solvers as ``args``; driving_term
# unpacks them in this order.
PARAMS = (
    774.86,  # start_time
    0.25,    # duration
    0.8,     # phase
    6.44,    # area
)

# Baseline production rate that driving_term's sinusoid oscillates around.
STEADY_PROD = 1.8803862513018528

# JAX silently downcasts float64 requests to float32 unless 64-bit mode is
# enabled; float32 would also make the 1e-10 solver tolerances used later far
# below attainable precision. Enable it before any JAX array is created.
jax.config.update("jax_enable_x64", True)

# Initial state (y0) for both solver benchmarks: one entry per row of MATRIX.
# NOTE(review): presumably the model's steady state, per the name — confirm.
STEADY_STATE = jax.numpy.array(
    [1.34432991e+02, 7.07000000e+02, 1.18701144e+03,
     3.95666872e+00, 4.49574232e+04, 1.55056740e+02,
     6.32017337e+02, 4.22182768e+02, 1.80125397e+03,
     6.63307283e+02, 7.28080320e+03],
    dtype=jax.numpy.float64)

# Share of the driving production rate injected into each box. The dydt
# functions add ``prod_coeffs * production(t, args)``, so 70% goes to box 0,
# 30% to box 1, and the remaining nine boxes receive none.
PROD_COEFFS = jax.numpy.array([0.7, 0.3] + [0.0] * 9, dtype=jax.numpy.float64)

# Linear transfer matrix of the 11-box model. Both dydt functions evaluate
# ``matmul(matrix, y)``, so entry [i, j] is the coefficient with which the
# content of box j enters d y_i / dt. Every diagonal entry is negative, i.e.
# each box loses content in proportion to its own amount.
# NOTE(review): rates are presumably per unit of t (the units of time_out);
# confirm against where these coefficients came from.
# NOTE(review): JAX only honours dtype=float64 when x64 mode is enabled;
# otherwise this array is silently float32.
MATRIX = jax.numpy.array([
    [-0.509, 0.009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.508, -0.44, 0.068, 0.0, 0.0, 0.545, 0.0, 0.167, 0.002, 0.002, 0.0],
    [0.0, 0.121, -0.155, 12.0, 0.001, 0.0, 0.0, 0.003, 0.0, 0.0, 0.0],
    [0.0, 0.0, 4.4000e-02, -1.3333e+01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.042, 1.333, -0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.229, 0.0, 0.0, 0.0, -1.046, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.136, -0.033, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.364, 0.033, -0.183, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, -0.002, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.0, -0.002, 0.0],
    [0.0, 0.0, 3.333e-04, 0.0, 5.291e-06, 0.0, 0.0, 0.0, 0.0, 4.0e-04, -1.2340e-04]], 
    dtype=jax.numpy.float64)

@jax.jit
def driving_term(t, args):
    """Return the production-rate forcing at time ``t``.

    The forcing is a sinusoid around ``STEADY_PROD`` plus a flat-topped
    pulse, and the sum is scaled by the constant factor 3.747.

    Args:
        t: Time at which to evaluate the forcing.
        args: Sequence unpacked as ``(start_time, duration, phase, area)``.
            The pulse is centred on ``start_time + duration / 2`` with
            half-width ``duration / 2`` and height ``area / duration``;
            ``phase`` shifts the sinusoid.

    Returns:
        ``(sinusoid + pulse) * 3.747``.
    """
    jnp = jax.numpy
    start_time, duration, phase, area = jnp.array(args)

    # Flat-topped pulse: exponent 16 makes exp(-x**16) close to a box
    # function of half-width 1 in x.
    centre = start_time + duration / 2.
    half_width = 0.5 * duration
    pulse = (area / duration) * jnp.exp(-((t - centre) / half_width) ** 16.)

    # Oscillation with period 11 (in units of t) and amplitude 18% of the
    # steady production rate.
    angle = 2 * jnp.pi / 11 * t + phase * 2 * jnp.pi / 11
    baseline = STEADY_PROD + 0.18 * STEADY_PROD * jnp.sin(angle)

    return (baseline + pulse) * 3.747

@jax.jit
def jax_dydt(y, t, args, /, matrix=MATRIX, production=driving_term,
             prod_coeffs=PROD_COEFFS):
    """Return dy/dt = matrix @ y + prod_coeffs * production(t, args).

    Parameters come in (y, t, args) order. This is the callback passed to
    ``jax.experimental.ode.odeint`` in the benchmark below.

    Args:
        y: Current state vector.
        t: Current time.
        args: Parameters forwarded unchanged to ``production``.
        matrix: Linear transfer matrix applied to ``y``.
        production: Callable ``(t, args) -> rate`` giving the forcing.
        prod_coeffs: Per-box weights applied to the forcing rate.

    NOTE(review): because this is wrapped in ``jax.jit``, overriding
    ``production`` with another callable is likely to fail (callables are
    not valid traced arguments). Mark it static if overrides are needed.
    """
    forcing = prod_coeffs * production(t, args)
    return jax.numpy.matmul(matrix, y) + forcing

@jax.jit
def diffrax_dydt(t, y, args, /, matrix=MATRIX, production=driving_term,
                 prod_coeffs=PROD_COEFFS):
    """Same right-hand side as ``jax_dydt``, with parameters in (t, y, args) order.

    This is the signature wrapped by ``diffrax.ODETerm`` below. It computes
    dy/dt = matrix @ y + prod_coeffs * production(t, args).

    NOTE(review): as with ``jax_dydt``, the ``jax.jit`` wrapper probably
    prevents passing a different ``production`` callable.
    """
    rate = production(t, args)
    linear_part = jax.numpy.matmul(matrix, y)
    return linear_part + prod_coeffs * rate

# 1000 evenly spaced output times on [750, 800], both endpoints included;
# both benchmarks save their solution at these times.
time_out = jax.numpy.linspace(start=750, stop=800, num=1000)

%%timeit
jax.experimental.ode.odeint(jax_dydt, STEADY_STATE, time_out, PARAMS)

# Configuration for the diffrax benchmark below.
solver = diffrax.Bosh3()
term = diffrax.ODETerm(diffrax_dydt)          # RHS in (t, y, args) order
save_time = diffrax.SaveAt(ts=time_out)       # same output grid as odeint
# Adaptive step-size control. NOTE(review): 1e-10 tolerances only make sense
# with 64-bit floats enabled; in float32 the solve may exhaust max_steps.
step_size = diffrax.PIDController(atol=1e-10, rtol=1e-10)

%%timeit
diffrax.diffeqsolve(args=PARAMS, terms=term, solver=solver, y0=STEADY_STATE,
                    t0=time_out.min(), t1=time_out.max(), dt0=0.01,
                    saveat=save_time, stepsize_controller=step_size, 
                    max_steps=10000)