In [136]:
import diffrax
import jax
from jax.experimental import ode
import ticktack

In [137]:
# Event parameters, in exactly the order `driving_term` unpacks them:
# (start_time, duration, phase, area).
PARAMS = (774.86, 0.25, 0.8, 6.44)

# Build and compile the presaved "Guttler15" box model. Only a handful of
# derived constants are kept; the model object itself is dropped at the end
# of this cell so later cells cannot depend on it.
_box_model = ticktack.load_presaved_model(
    "Guttler15",
    production_rate_units="atoms/cm^2/s",
)
_box_model.compile()

# Presumably the constant production rate whose equilibrium meets
# target_C_14=707. (units as defined by ticktack) — it is fed back in as
# `production_rate` on the next line and used as the baseline in `driving_term`.
STEADY_PROD = _box_model.equilibrate(target_C_14=707.)
# Equilibrium state when driven at STEADY_PROD; used as the ODE initial condition.
STEADY_STATE = _box_model.equilibrate(production_rate=STEADY_PROD)
# NOTE(review): both of these read private (underscore) attributes of the
# ticktack model, so they may break on a ticktack upgrade — confirm there is
# no public accessor.
PROD_COEFFS = _box_model._production_coefficients
MATRIX = _box_model._matrix

del _box_model

In [138]:
@jax.tree_util.Partial
@jax.jit
def driving_term(t, args):
    """Production rate at time ``t``: an oscillating baseline plus a flat-topped pulse.

    Parameters
    ----------
    t : scalar (or array) time.
    args : sequence of four numbers ``(start_time, duration, phase, area)``.

    Returns
    -------
    ``(cycle + pulse) * 3.747``, broadcast like ``t``.

    Notes
    -----
    * ``cycle`` oscillates about ``STEADY_PROD`` with relative amplitude 0.18 and
      period 11 in the units of ``t`` (presumably years, i.e. the solar cycle).
      ``phase`` is a time shift in those same units.
    * ``pulse`` is a super-Gaussian (exponent 16) centred on the event window,
      with half-width ``duration / 2`` and peak height ``area / duration``. Its
      integral is ``area * Gamma(17/16)`` (about ``0.97 * area``), not exactly
      ``area``.
    * NOTE(review): the overall factor 3.747 is undocumented here (presumably a
      unit conversion) — confirm.
    """
    jnp = jax.numpy
    start_time, duration, phase, area = jnp.array(args)

    # Baseline cycle. The argument is kept as (2*pi/11)*t + phase*2*pi/11 so the
    # floating-point evaluation order matches the original expression.
    angle = 2 * jnp.pi / 11 * t + phase * 2 * jnp.pi / 11
    cycle = STEADY_PROD + 0.18 * STEADY_PROD * jnp.sin(angle)

    # Event pulse: exp(-x**16) is close to 1 inside the window and drops sharply outside it.
    centre = start_time + duration / 2.
    peak = area / duration
    half_width = 0.5 * duration
    pulse = peak * jnp.exp(-((t - centre) / half_width) ** 16.)

    return (cycle + pulse) * 3.747

In [139]:
@jax.tree_util.Partial
@jax.jit
def jax_dydt(y, t, args, /, matrix=MATRIX, production=driving_term,
        prod_coeffs=PROD_COEFFS):
    """Right-hand side ``dy/dt = matrix @ y + prod_coeffs * production(t, args)``.

    The ``(y, t, args)`` argument order is the one ``jax.experimental.ode.odeint``
    uses for its ``func(y, t, *args)`` callback (``ode`` is imported in the
    first cell).

    Parameters
    ----------
    y : state vector of the box model.
    t : current time.
    args : passed through unchanged to ``production``.
    matrix : linear transfer matrix (defaults to ``MATRIX``).
    production : callable ``(t, args) -> scalar`` rate (defaults to ``driving_term``).
    prod_coeffs : per-box weights applied to the production rate
        (defaults to ``PROD_COEFFS``).
    """
    transfer = jax.numpy.matmul(matrix, y)
    source = prod_coeffs * production(t, args)
    return transfer + source

In [154]:
@jax.tree_util.Partial
@jax.jit
def diffrax_dydt(t, y, args, /, matrix=MATRIX, production=driving_term,
                 prod_coeffs=PROD_COEFFS):
    """diffrax-ordered version of ``jax_dydt``.

    ``diffrax.ODETerm`` calls its vector field as ``f(t, y, args)``, whereas
    ``jax_dydt`` takes ``(y, t, args)``. This function only swaps the first two
    arguments and forwards everything else, so both solvers integrate the
    same right-hand side:
    ``dy/dt = matrix @ y + prod_coeffs * production(t, args)``.

    Parameters are the same as in ``jax_dydt``.
    """
    # Delegate rather than duplicate the body: previously this was a verbatim
    # copy of jax_dydt, and the two could silently drift apart. The nested
    # jitted call is traced inline into this function's computation.
    return jax_dydt(y, t, args, matrix=matrix, production=production,
                    prod_coeffs=prod_coeffs)

In [155]:
# 1000 evenly spaced output times from 750 to 800 inclusive.
# NOTE(review): `time_out` is not referenced by any visible cell — remove it if unused.
time_out = jax.numpy.linspace(750, 800, num=1000)

In [156]:
# --- diffrax solve configuration (consumed by `diffrax_loss`) ---
# Vector field in diffrax's (t, y, args) argument order.
term = diffrax.ODETerm(diffrax_dydt)
solver = diffrax.Dopri5()
# Adaptive forward step-size control. 1.4e-8 is presumably chosen to match the
# default rtol/atol of jax.experimental.ode.odeint, for a like-for-like comparison.
step_size = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)
# Save the solution at the observation times in the first row of DATA.
# NOTE(review): DATA is not defined in any visible cell — it must come from a
# cell that was deleted or run out of order, so Restart & Run All will fail here.
save_time = diffrax.SaveAt(ts=DATA[0])

# Gradients are computed by solving the adjoint ODE backwards in time
# (BacksolveAdjoint), using its own Dopri5 solver and a PID controller that
# measures error with diffrax's adjoint RMS seminorm.
adjoint_controller = diffrax.PIDController(
    norm=diffrax.adjoint_rms_seminorm,
    rtol=1.4e-8,
    atol=1.4e-8,
)
adjoint = diffrax.BacksolveAdjoint(
    stepsize_controller=adjoint_controller,
    solver=diffrax.Dopri5(),
)

In [157]:
@jax.jit
def diffrax_loss(args):
    """Return ``-0.5 * chi^2`` of the model's second box (index 1) against DATA.

    The ODE is integrated with diffrax from ``STEADY_STATE``, over the span of
    the observation times ``DATA[0]``, and saved at those times. Residuals
    against ``DATA[1]`` are weighted by the uncertainties ``DATA[2]``. This is
    a Gaussian log-likelihood up to an additive constant.

    ``args`` is forwarded to the vector field; it is
    ``(start_time, duration, phase, area)`` as unpacked by ``driving_term``.

    NOTE(review): in the recorded run, ``jax.grad`` of this function raised
    "body_fun output and input must have identical types" (``tprev`` float64
    vs float32 inside diffrax's loop). Suspect a dtype mismatch between
    t0/t1/ts and diffrax's internal times — confirm the dtypes of DATA and the
    diffrax version in use.
    """
    obs_times = DATA[0]
    sol = diffrax.diffeqsolve(
        terms=term,
        solver=solver,
        t0=obs_times.min(),
        t1=obs_times.max(),
        dt0=None,
        y0=STEADY_STATE,
        args=args,
        saveat=save_time,
        stepsize_controller=step_size,
        adjoint=adjoint,
    )
    # Same arithmetic as (model - data)**2 / sigma**2, summed.
    weighted_sq = (sol.ys[:, 1] - DATA[1]) ** 2 / DATA[2] ** 2
    return -0.5 * jax.numpy.sum(weighted_sq)

In [160]:
# Gradients of the scalar losses with respect to the event parameters (the
# "_jac" names are historical: jax.grad requires a scalar output, so these are
# gradients, not full Jacobians). The outer jit compiles the whole backward pass.
# NOTE(review): `jax_loss` is not defined in any visible cell.
diffrax_jac = jax.jit(jax.grad(diffrax_loss, argnums=0))
jax_jac = jax.jit(jax.grad(jax_loss, argnums=0))

In [161]:
%%timeit
diffrax_jac(PARAMS)

TypeError: body_fun output and input must have identical types, got
_State(y='ShapedArray(float64[11])', tprev='DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', tnext='ShapedArray(float64[])', made_jump='ShapedArray(bool[])', solver_state='ShapedArray(float64[11])', controller_state=('ShapedArray(bool[])', 'ShapedArray(bool[])', 'ShapedArray(float64[])', 'ShapedArray(float64[], weak_type=True)'), result='ShapedArray(int64[], weak_type=True)', num_steps='ShapedArray(int64[], weak_type=True)', num_accepted_steps='ShapedArray(int64[], weak_type=True)', num_rejected_steps='ShapedArray(int64[], weak_type=True)', saveat_ts_index='ShapedArray(int64[], weak_type=True)', ts='ShapedArray(float64[28], weak_type=True)', ys='ShapedArray(float64[28,11], weak_type=True)', save_index='ShapedArray(int64[], weak_type=True)', dense_ts=None, dense_infos=None, dense_save_index=None).

In [129]:
%%timeit
jax_jac(PARAMS)

915 ms ± 38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [109]:
%%timeit
# Time a forward solve with `jax_dydt` from STEADY_STATE at the DATA[0] times.
# NOTE(review): `legal_ode` is not defined in any visible cell. JAX dispatches
# asynchronously, so unless legal_ode blocks on its result this may time only
# the dispatch — consider calling `.block_until_ready()` on the returned array(s).
legal_ode(jax_dydt, STEADY_STATE, DATA[0], PARAMS)

1.22 ms ± 56.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


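So at this point I know that `ode.odeint` is faster when removed (TODO: state exactly what was removed), but it recompiles on every run for reasons I don't yet understand. I believe the recompilation happens because I am calling a `jit`-ed function from within another `jit`-ed function. If so, it can be avoided by using either `static_argnums` or `jax.tree_util.Partial`. (TODO: confirm where the recompiles occur by enabling `jax.config.update("jax_log_compiles", True)` before timing.)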

I need to talk to Ben about what these findings imply for the code structure. Nothing has come back yet, and at this point I am thoroughly frustrated.