In [22]:
import diffrax
import jax
from jax.experimental import ode

In [24]:
PARAMS = (774.86, 0.25, 0.8, 6.44)

STEADY_PROD = 1.8803862513018528

STEADY_STATE = jax.numpy.array(
    [1.34432991e+02, 7.07000000e+02, 1.18701144e+03,
    3.95666872e+00, 4.49574232e+04, 1.55056740e+02,
    6.32017337e+02, 4.22182768e+02, 1.80125397e+03,
    6.63307283e+02, 7.28080320e+03], 
    dtype=jax.numpy.float64)

PROD_COEFFS = jax.numpy.array(
    [0.7, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
    dtype=jax.numpy.float64)

MATRIX = jax.numpy.array([
    [-0.509, 0.009, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.508, -0.44, 0.068, 0.0, 0.0, 0.545, 0.0, 0.167, 0.002, 0.002, 0.0],
    [0.0, 0.121, -0.155, 12.0, 0.001, 0.0, 0.0, 0.003, 0.0, 0.0, 0.0],
    [0.0, 0.0, 4.4000e-02, -1.3333e+01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.042, 1.333, -0.001, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.229, 0.0, 0.0, 0.0, -1.046, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.136, -0.033, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.364, 0.033, -0.183, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, -0.002, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003, 0.0, -0.002, 0.0],
    [0.0, 0.0, 3.333e-04, 0.0, 5.291e-06, 0.0, 0.0, 0.0, 0.0, 4.0e-04, -1.2340e-04]], 
    dtype=jax.numpy.float64)

DATA = jax.numpy.array(
    [[760.   , 761.   , 762.   , 763.   , 764.   , 765.   ,
      766.   , 767.   , 768.   , 769.   , 770.   , 771.   ,
      772.   , 773.   , 774.   , 775.   , 776.   , 777.   ,
      778.   , 779.   , 780.   , 781.   , 782.   , 783.   ,
      784.   , 785.   , 786.   , 787.   ],
     [-21.63 , -22.28 , -22.64 , -23.83 , -22.2  , -22.99 ,
      -20.73 , -21.59 , -25.32 , -25.6  , -25.7  , -24.   ,
      -23.73 , -21.91 , -23.44 ,  -9.335,  -6.46 ,  -9.7  ,
      -11.17 , -10.31 , -11.1  , -10.72 , -10.67 ,  -8.63 ,
       -9.68 ,  -9.31 , -12.33 , -14.44 ],
     [  1.8  ,   1.84 ,   2.02 ,   1.79 ,   1.77 ,   1.25 ,
        1.83 ,   1.79 ,   1.85 ,   1.705,   1.82 ,   1.72 ,
        1.26 ,   1.28 ,   1.73 ,   1.97 ,   1.3  ,   1.78 ,
        1.76 ,   1.33 ,   1.27 ,   1.735,   1.76 ,   1.74 ,
        1.76 ,   1.73 ,   1.77 ,   1.83 ]], 
    dtype=jax.numpy.float64)

In [25]:
@jax.jit 
def driving_term(t, args):
    start_time, duration, phase, area = jax.numpy.array(args)
    middle = start_time + duration / 2.
    height = area / duration

    gauss = height * \
        jax.numpy.exp(- ((t - middle) / (0.5 * duration)) ** 16.)
    sine = STEADY_PROD + 0.18 * STEADY_PROD *\
        jax.numpy.sin(2 * jax.numpy.pi / 11 * t + phase * 2 * jax.numpy.pi / 11)

    return (sine + gauss) * 3.747

In [26]:
@jax.jit
def jax_dydt(y, t, args, /, matrix=MATRIX, production=driving_term, 
                   prod_coeffs=PROD_COEFFS):
    ans = jax.numpy.matmul(matrix, y)
    production_rate_constant = production(t, args)
    production_term = prod_coeffs * production_rate_constant
    return ans + production_term

In [27]:
@jax.jit
def diffrax_dydt(t, y, args, /, matrix=MATRIX, production=driving_term, 
                 prod_coeffs=PROD_COEFFS):
    ans = jax.numpy.matmul(matrix, y)
    production_rate_constant = production(t, args)
    production_term = prod_coeffs * production_rate_constant
    return ans + production_term

In [28]:
time_out = jax.numpy.linspace(750, 800, 1000)

In [29]:
term = diffrax.ODETerm(diffrax_dydt)
solver = diffrax.Dopri5()
step_size = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
save_time = diffrax.SaveAt(ts=DATA[0])
adjoint_controller = diffrax.PIDController(norm=diffrax.adjoint_rms_seminorm, 
    rtol=1.4e-8, atol=1.4e-8)
adjoint = diffrax.BacksolveAdjoint(stepsize_controller=adjoint_controller,
    solver=diffrax.Dopri5())

In [52]:
@jax.jit
def jax_solve(y_initial, time, args, /, dydx=jax_dydt):
    states = ode.odeint(dydx, y_initial, time, args, atol=1e-10, rtol=1e-10)
    return states[:, 1]

In [30]:
@jax.jit
def diffrax_loss(args):
    solution = diffrax.diffeqsolve(args=args, terms=term, solver=solver, 
        y0=STEADY_STATE, t0=DATA[0].min(), t1=DATA[0].max(), dt0=None, 
        saveat=save_time, stepsize_controller=step_size, adjoint=adjoint)
    chi_sq = jax.numpy.sum((solution.ys[:, 1] - DATA[1]) ** 2 / DATA[2] ** 2)
    return -0.5 * chi_sq

In [55]:
@jax.jit 
def jax_loss(args, /, y0=STEADY_STATE, time=DATA[0], dc14=DATA[1], sig_dc14=DATA[2]):
    solution = jax_solve(y0, time, args)
    chi_sq = jax.numpy.sum((solution - dc14) ** 2 / sig_dc14 ** 2)
    return -0.5 * chi_sq

In [42]:
diffrax_jac = jax.jit(jax.grad(diffrax_loss))
jax_jac = jax.jit(jax.grad(jax_loss))

In [35]:
%%timeit
diffrax_jac(PARAMS)

3.14 ms ± 1.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [47]:
%%timeit
jax_jac(PARAMS)

648 ms ± 43.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [50]:
%%timeit
ode.odeint(jax_dydt, STEADY_STATE, DATA[0], PARAMS)

2.13 ms ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [59]:
%%timeit
jax_loss(PARAMS)

2.08 s ± 41.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [64]:
jax_solve(STEADY_STATE, DATA[0], PARAMS)

DeviceArray([707.     , 696.49384, 674.52985, 649.49756, 624.86725,
             601.99567, 581.2823 , 562.8076 , 546.5007 , 532.2323 ,
             519.79645, 508.92395, 499.25262, 490.41992, 482.0909 ,
             469.04684, 468.5465 , 460.33243, 450.65308, 441.07117,
             432.33908, 424.71466, 418.18393, 412.53754, 407.47372,
             402.70587, 398.0082 , 393.28372], dtype=float32)

In [67]:
diffrax.diffeqsolve(args=PARAMS, terms=term, solver=solver, 
        y0=STEADY_STATE, t0=DATA[0].min(), t1=DATA[0].max(), dt0=None, 
        saveat=save_time, stepsize_controller=step_size, adjoint=adjoint)

Solution(t0=DeviceArray(760., dtype=float32), t1=DeviceArray(787., dtype=float32), ts=DeviceArray([760., 761., 762., 763., 764., 765., 766., 767., 768., 769.,
             770., 771., 772., 773., 774., 775., 776., 777., 778., 779.,
             780., 781., 782., 783., 784., 785., 786., 787.],            dtype=float32, weak_type=True), ys=DeviceArray([[1.3443298e+02, 7.0700000e+02, 1.1870115e+03, 3.9566686e+00,
              4.4957422e+04, 1.5505675e+02, 6.3201733e+02, 4.2218277e+02,
              1.8012540e+03, 6.6330731e+02, 7.2808032e+03],
             [9.0297478e+01, 6.9649414e+02, 1.1817042e+03, 3.9011424e+00,
              4.4967418e+04, 1.5421484e+02, 6.3220599e+02, 4.2212173e+02,
              1.8018723e+03, 6.6324738e+02, 7.2808032e+03],
             [6.3622334e+01, 6.7453247e+02, 1.1749722e+03, 3.8793766e+00,
              4.4977137e+04, 1.5133296e+02, 6.3214374e+02, 4.2145432e+02,
              1.8024868e+03, 6.6318658e+02, 7.2808008e+03],
             [4.7192314e+01, 6.49497