In [9]:
import diffrax
import jax
import ticktack

In [42]:
# Load the pre-saved "Guttler15" box model (production rates in atoms/cm^2/s) and compile it.
model = ticktack.load_presaved_model("Guttler15", production_rate_units="atoms/cm^2/s")
model.compile()

# Production rate that equilibrates the model to target_C_14=707.
# NOTE(review): 707. is a magic number (presumably a target 14C content) — document its source/units.
STEADY_PROD = model.equilibrate(target_C_14=707.)
# Equilibrium state for that production rate; used as the default y0 of log_likelihood.
STEADY_STATE = model.equilibrate(production_rate=STEADY_PROD)
# Extract the pieces the ODE right-hand side (`derivative`) binds as keyword defaults.
# These are private ticktack attributes, so they may break across ticktack versions.
PROD_COEFFS = model._production_coefficients
MATRIX = model._matrix
conv_prod_rate = model._convert_production_rate

# Removes only the name: conv_prod_rate is presumably a bound method and so still references the model.
del model

In [8]:
@jax.tree_util.Partial
@jax.jit
def production(t, args):
    """Production-rate forcing: a sinusoid plus a flat-topped (super-Gaussian) pulse.

    Parameters
    ----------
    t : float or scalar array
        Time at which to evaluate the production rate. The sinusoid period
        below is 11 in the same units (presumably years — TODO confirm).
    args : sequence of 4 floats
        ``(start_time, duration, phase, area)``. The pulse is centred at
        ``start_time + duration / 2`` with peak height ``area / duration``;
        ``phase`` shifts the sinusoid in time (same units as ``t``).

    Returns
    -------
    Scalar array: sinusoid + pulse.
    """
    # NOTE(review): baseline of the sinusoid, presumably the steady-state
    # production rate; it is hard-coded rather than taken from STEADY_PROD —
    # confirm the two are meant to agree.
    baseline = 1.8803862513018528
    relative_amplitude = 0.18  # sinusoid amplitude as a fraction of the baseline
    period = 11.  # presumably the ~11-year solar cycle (TODO confirm)
    width_divisor = 1.93516  # pulse half-width is duration / width_divisor
    # Integer exponent: jnp.power lowers it to repeated multiplication
    # (integer_pow), which is well-defined for the negative bases that occur
    # before the pulse centre.
    sharpness = 16

    start_time, duration, phase, area = jax.numpy.array(args)
    middle = start_time + duration / 2.
    height = area / duration

    pulse = height * jax.numpy.exp(
        -((t - middle) / (duration / width_divisor)) ** sharpness)
    sine = baseline + relative_amplitude * baseline * jax.numpy.sin(
        2 * jax.numpy.pi / period * t + phase * 2 * jax.numpy.pi / period)

    return sine + pulse

In [48]:
@jax.tree_util.Partial
@jax.jit
def derivative(t, y, args, /, matrix=MATRIX, prod=STEADY_PROD, production=production,
               conv_rate=conv_prod_rate, prod_coeffs=PROD_COEFFS):
    """ODE right-hand side in diffrax's ``f(t, y, args)`` form.

    Computes ``matrix @ y + prod_coeffs * conv_rate(production(t, args) - prod)``:
    linear transfer between boxes plus a source term driven by the production
    rate in excess of ``prod`` (STEADY_PROD by default). The keyword defaults
    bind the model pieces extracted at setup, so callers pass only (t, y, args).
    """
    excess_rate = conv_rate(production(t, args) - prod)
    transfer = jax.numpy.matmul(matrix, y)
    return transfer + prod_coeffs * excess_rate

In [96]:
@jax.tree_util.Partial
@jax.jit
def solve(y_initial, time, args, /, solver=diffrax.Dopri5(), dydx=derivative):
    """Integrate ``dydx`` over the span of ``time`` and return the states at ``time``.

    Parameters
    ----------
    y_initial : array
        Initial state at ``time.min()``.
    time : 1-D array
        Output times. Assumed sorted ascending with at least two entries
        (the initial step is derived from the first interval).
    args :
        Passed through to ``dydx(t, y, args)``.
    solver, dydx :
        Integrator (Dopri5 by default) and vector field (``derivative`` by default).

    Returns
    -------
    ``sol.ys`` from ``diffrax.diffeqsolve``: the state saved at each entry of ``time``.
    """
    term = diffrax.ODETerm(dydx)
    saveat = diffrax.SaveAt(ts=time)
    stepsize = diffrax.PIDController(rtol=1e-5, atol=1e-5)

    t0 = time.min()
    t1 = time.max()
    # Initial step guess: a thousandth of the first sampling interval. The
    # parentheses matter — without them only time[0] is divided and dt0 ends
    # up comparable to the absolute time value (or even negative).
    dt0 = (time[1] - time[0]) / 1000

    sol = diffrax.diffeqsolve(args=args, terms=term, solver=solver, t0=t0, t1=t1,
                              dt0=dt0, y0=y_initial, saveat=saveat,
                              stepsize_controller=stepsize)

    return sol.ys

In [118]:
# Load the observations: skip the header line, parse every non-blank row of
# whitespace-separated numbers into floats, then transpose so that data[i] is
# the i-th column (log_likelihood reads data[0], data[1], data[2]).
# Splitting on any whitespace tolerates repeated spaces/tabs, and blank lines
# (e.g. a trailing newline) are skipped instead of producing empty fields.
with open("miyake12.csv") as data_file:
    next(data_file)  # header line of column titles
    rows = [[float(field) for field in line.split()]
            for line in data_file if line.strip()]

# NOTE(review): unless jax_enable_x64 is set, JAX downgrades float64 to float32
# (with a warning) — confirm the intended precision.
data = jax.numpy.array(rows, dtype=jax.numpy.float64).T

In [None]:
@jax.tree_util.Partial
@jax.jit
def log_likelihood(args, /, data=data, func=solve, y0=STEADY_STATE):
    """Return -0.5 * chi-squared of a model run against the observations.

    ``data[0]`` supplies the output times, ``data[1]`` the observed values and
    ``data[2]`` their uncertainties; ``func(y0, times, args)`` produces the
    model values compared against them.

    NOTE(review): with the default ``func=solve`` the model values are the
    solver's saved states, presumably shaped (len(times), len(y0)), while
    ``data[1]`` is 1-D — the subtraction broadcasts over the last axis. Confirm
    which box/observable is meant to be compared.
    """
    times, observed, sigma = data[0], data[1], data[2]
    model_values = func(y0, times, args)
    chi_squared = (model_values - observed) ** 2 / sigma ** 2
    return -0.5 * jax.numpy.sum(chi_squared)
    

In [125]:
# Derivatives of the log-likelihood with respect to its first argument, `args`.
# `likelihood` was never defined; the function is `log_likelihood`. Keep the
# short name available because the timing cells call it.
likelihood = log_likelihood
# Build the reverse-mode gradient once and reuse it: the Hessian is the usual
# forward-over-reverse composition jacfwd(jacrev(f)).
_log_likelihood_grad = jax.jacrev(log_likelihood)
jacobian = jax.jit(_log_likelihood_grad)
hessian = jax.jit(jax.jacfwd(_log_likelihood_grad))

In [124]:
%%timeit
# NOTE(review): the log-likelihood unpacks `args` into four values
# (start_time, duration, phase, area), so a scalar 1.0 cannot work as written;
# the timing recorded below likely came from an earlier definition.
# Re-run with a 4-element args.
likelihood(1.0)

255 µs ± 4.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [128]:
%%timeit
# NOTE(review): differentiates with respect to `args`, which must be a
# 4-element (start_time, duration, phase, area) array, not a scalar; the
# recorded timing below is likely stale.
jacobian(1.0)

372 µs ± 27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [130]:
%%timeit
# NOTE(review): as with the Jacobian, `args` must be a 4-element
# (start_time, duration, phase, area) array, not a scalar; the recorded
# timing below is likely stale.
hessian(1.0)

430 µs ± 62.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
