In [95]:
import jax.experimental.ode as ode
import jax
import ticktack
import time

In [96]:
# Event parameters passed as `args` to `production`, which unpacks them as
# (start_time, duration, phase, area).
parameters = (774.86, 0.25, 0.8, 6.44)
# Load the pre-saved "Guttler15" model and compile it before it is queried.
model = ticktack.load_presaved_model("Guttler15", production_rate_units="atoms/cm^2/s")
model.compile()

# Production rate returned by `equilibrate` for the target of 707, then the
# state returned by `equilibrate` for that production rate.
STEADY_PROD = model.equilibrate(target_C_14=707.)
STEADY_STATE = model.equilibrate(production_rate=STEADY_PROD)
# Private model attributes captured as module constants; they are the default
# `prod_coeffs` and `matrix` arguments of `derivative`.
PROD_COEFFS = model._production_coefficients
MATRIX = model._matrix
# NOTE(review): GROWTH is bound to the model object itself, so the `del model`
# below only removes the name `model`, not the object. GROWTH is not referenced
# in the visible cells -- confirm whether a growth attribute was intended.
GROWTH = model

del model

In [97]:
def profile(func, *args, **kwargs) -> dict:
    """Time ten consecutive calls of ``func(*args, **kwargs)``.

    Parameters
    ----------
    func : callable
        The function to benchmark.
    *args, **kwargs
        Forwarded unchanged to ``func`` on every trial.

    Returns
    -------
    dict
        ``"average"``: mean wall-clock duration in seconds of trials 2-10
        (the first trial is discarded, presumably because it includes JIT
        compilation), as a Python ``float``.
        ``"solution"``: the value returned by the final call of ``func``.
    """
    n_trials = 10
    durations = []  # Plain list: avoids rebuilding an immutable JAX array per trial
    solution = None

    for _ in range(n_trials):
        start = time.perf_counter()  # Wall clock; process_time ignores time spent waiting
        solution = func(*args, **kwargs)
        # JAX returns before the computation has finished, so wait on every
        # array in the result; non-JAX leaves are passed through untouched.
        jax.tree_util.tree_map(
            lambda leaf: leaf.block_until_ready()
            if hasattr(leaf, "block_until_ready") else leaf,
            solution,
        )
        durations.append(time.perf_counter() - start)

    return {
        "average": sum(durations[1:]) / (n_trials - 1),
        "solution": solution
        }

In [98]:
@jax.tree_util.Partial
@jax.jit
def production(t, args):
    """Evaluate the production rate at time ``t``.

    The rate is the sum of a sinusoid of period 11 (in the units of ``t``)
    around a constant baseline and a flat-topped pulse (a super-Gaussian with
    exponent 16), scaled by a constant factor.

    Parameters
    ----------
    t : scalar or array
        Time(s) at which to evaluate the rate.
    args : sequence of four numbers
        ``(start_time, duration, phase, area)`` of the pulse and sinusoid.

    Returns
    -------
    The scaled production rate, with the same shape as ``t``.
    """
    onset, length, shift, total = jax.numpy.array(args)
    centre = onset + length / 2.
    peak = total / length  # The pulse height is chosen as area / duration

    # Sinusoid with an amplitude of 18% of the baseline.
    baseline = 1.8803862513018528
    cycle_angle = 2 * jax.numpy.pi / 11 * t + shift * 2 * jax.numpy.pi / 11
    solar = baseline + 0.18 * baseline * jax.numpy.sin(cycle_angle)

    # Flat-topped pulse centred on the middle of the event.
    width = 1. / 1.93516 * length
    pulse = peak * jax.numpy.exp(- ((t - centre) / width) ** 16.)

    # NOTE(review): 3.747273140033743 looks like a unit-conversion factor -- confirm.
    return (solar + pulse) * 3.747273140033743

In [99]:
@jax.tree_util.Partial
@jax.jit
def derivative(y, t, args, /, matrix=MATRIX, production=production, prod_coeffs=PROD_COEFFS): 
    """Right-hand side of the linear ODE ``dy/dt = matrix @ y + prod_coeffs * production(t, args)``.

    Parameters
    ----------
    y : array
        Current state vector.
    t : scalar
        Current time, forwarded to ``production``.
    args : sequence
        Parameters forwarded to ``production``.
    matrix, production, prod_coeffs
        Transfer matrix, production-rate function and production coefficients;
        they default to the module-level ``MATRIX``, ``production`` and
        ``PROD_COEFFS``.

    Returns
    -------
    array
        Time derivative of ``y``.
    """
    transfer = jax.numpy.matmul(matrix, y)
    source = prod_coeffs * production(t, args)
    return transfer + source

In [100]:
@jax.tree_util.Partial
@jax.jit
def solve(y_initial, time, args, /, dydx=derivative):
    """Integrate ``dydx`` from ``y_initial`` over the grid ``time``.

    Parameters
    ----------
    y_initial : array
        State at ``time[0]``.
    time : array
        Time points at which the state is evaluated.
    args
        Extra argument forwarded to ``dydx``.
    dydx : callable
        Derivative function called as ``dydx(y, t, args)``.

    Returns
    -------
    array
        Column 1 of the integrated states, one value per entry of ``time``.
    """
    trajectory = ode.odeint(dydx, y_initial, time, args)
    # Only the state component at index 1 is kept.
    return trajectory[:, 1]

In [101]:
# Read the whitespace-separated table and store it column-wise, so that
# data[0], data[1] and data[2] are the first three columns of the file.
# The file handle and parsed rows get their own names instead of reusing
# `data` for three different kinds of object.
with open("miyake12.csv") as csv_file:
    next(csv_file)  # Skip the header row of column titles
    # Blank lines are ignored and any run of whitespace separates fields.
    rows = [[float(field) for field in line.split()]
            for line in csv_file if line.strip()]

data = jax.numpy.array(rows, dtype=jax.numpy.float64).T

In [102]:
@jax.tree_util.Partial
@jax.jit
def bin_data(data, kernel):    
    """Reduce ``data`` to one kernel-weighted average per bin.

    ``data`` is split into consecutive bins of ``kernel.shape[0]`` samples;
    each bin is weighted element-wise by ``kernel``, summed, and divided by
    the sum of the kernel.

    Parameters
    ----------
    data : array
        Flat array whose size is a multiple of ``kernel.shape[0]``.
    kernel : array
        1-D weights applied within every bin.

    Returns
    -------
    array
        One weighted average per bin.
    """
    width = kernel.shape[0]
    weighted = data.reshape(-1, width) * kernel
    return jax.numpy.sum(weighted, axis=1) / jax.numpy.sum(kernel)

In [103]:
# 12-entry 0/1 mask with ones at indices 3-8 (presumably one entry per month,
# marking the growing season -- confirm); default for `get_growth_kernel`.
growth = jax.numpy.array([0.] * 3 + [1.] * 6 + [0.] * 3, dtype=jax.numpy.float64)

In [104]:
def get_growth_kernel(oversample, /, growth=growth):
    """Stretch the ``growth`` mask to ``oversample`` samples.

    Every entry of ``growth`` is repeated ``oversample // len(growth)`` times,
    so the returned kernel has exactly ``oversample`` entries.

    Parameters
    ----------
    oversample : int
        Number of samples the kernel must span; has to be a positive multiple
        of the length of ``growth``.
    growth : array
        1-D mask of weights (defaults to the module-level ``growth``).

    Returns
    -------
    array
        Flat float64 kernel of length ``oversample``.

    Raises
    ------
    ValueError
        If ``oversample`` is not a positive multiple of ``len(growth)``.
        Floor division would otherwise silently return a kernel shorter than
        ``oversample``.
    """
    growth = jax.numpy.asarray(growth, dtype=jax.numpy.float64)
    periods = growth.shape[0]
    if oversample <= 0 or oversample % periods != 0:
        raise ValueError(
            f"oversample must be a positive multiple of {periods}, got {oversample}"
        )
    # Equivalent to flattening diag(growth) @ ones((periods, k)) without
    # building the dense matrices.
    return jax.numpy.repeat(growth, oversample // periods)

In [112]:
@jax.tree_util.Partial
@jax.jit
def log_likelihood(args, /, data=data, func=solve, bin=bin_data, y0=STEADY_STATE, time_out=None, kernel=None):
    """Return ``-0.5 * chi-squared`` of the binned model against ``data``.

    Parameters
    ----------
    args : sequence
        Production parameters forwarded to ``func``.
    data : array
        Table whose row 1 is compared with the model and whose row 2 scales
        the residuals (presumably measurements and their uncertainties --
        confirm).
    func : callable
        Solver called as ``func(y0, time_out, args)``.
    bin : callable
        Binning function called as ``bin(solution, kernel)``.
    y0 : array
        Initial state handed to ``func``.
    time_out : array
        Time grid handed to ``func``; must be supplied by the caller.
    kernel : array
        Binning weights handed to ``bin``; must be supplied by the caller.

    Returns
    -------
    scalar
        ``-0.5 * sum(((model - data[1]) / data[2]) ** 2)``.
    """
    trajectory = func(y0, time_out, args)
    binned = bin(trajectory, kernel)

    # Fractional deviation from the steady-state value, shifted by the mean of
    # the first four data points.
    # NOTE(review): this normalises with the global STEADY_STATE, not `y0`.
    reference = STEADY_STATE[1]
    offset = jax.numpy.mean(data[1][:4])
    modelled = (binned - reference) / reference + offset

    residuals = modelled - data[1]
    return - 0.5 * jax.numpy.sum(residuals ** 2 / data[2] ** 2)

In [122]:
# Compiled gradient of the log-likelihood with respect to its first argument.
jac_binned = jax.jit(jax.grad(log_likelihood))
# Compiled Hessian, taken as the reverse-mode Jacobian of the gradient
# (`jax.jacobian` is the same function as `jax.jacrev`).
hes_binned = jax.jit(jax.jacrev(jac_binned))

In [126]:
# Number of model samples generated per data point.
oversample = 48
# Evenly spaced grid from the first time in data[0] to two units past the
# last, with `oversample` points per data point.
n_points = data[0].size * oversample
time_out = jax.numpy.linspace(data[0].min(), data[0].max() + 2, n_points)
# Binning weights spanning `oversample` samples.
kernel = get_growth_kernel(oversample)

In [128]:
%%timeit
# Times the ODE solve followed by the binning step.
# NOTE(review): JAX dispatches work asynchronously, so without blocking on the
# result (e.g. `.block_until_ready()`) this may not include the full
# computation time -- confirm.
small_oversample = solve(STEADY_STATE, time_out, parameters)
small_oversample = bin_data(small_oversample, kernel)

254 µs ± 3.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [129]:
%%timeit
# Times one evaluation of the log-likelihood.
# NOTE(review): the result is not blocked on, so asynchronous dispatch may hide
# part of the run time -- confirm.
log_likelihood(parameters, kernel=kernel, time_out=time_out)

238 µs ± 75.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [130]:
%%timeit
# Times one evaluation of the gradient of the log-likelihood.
# NOTE(review): the result is not blocked on, so asynchronous dispatch may hide
# part of the run time -- confirm.
jac_binned(parameters, kernel=kernel, time_out=time_out)

4.36 ms ± 706 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [132]:
%%timeit
# Times one evaluation of the Hessian of the log-likelihood.
# NOTE(review): the result is not blocked on, so asynchronous dispatch may hide
# part of the run time -- confirm.
hes_binned(parameters, kernel=kernel, time_out=time_out)

327 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
