In [80]:
import diffrax
import jax
import ticktack
import matplotlib.pyplot as pyplot

In [94]:
model = ticktack.load_presaved_model("Guttler15", production_rate_units="atoms/cm^2/s")
model.compile()

STEADY_PROD = model.equilibrate(target_C_14=707.)
STEADY_STATE = model.equilibrate(production_rate=STEADY_PROD)
PROD_COEFFS = model._production_coefficients
MATRIX = model._matrix
GROWTH = model
del model

In [95]:
@jax.tree_util.Partial
@jax.jit
def production(t, args):
    start_time, duration, phase, area = jax.numpy.array(args)
    middle = start_time + duration / 2.
    height = area / duration

    gauss = height * jax.numpy.exp(- ((t - middle) / (1. / 1.93516 * duration)) ** 16.)
    sine = 1.8803862513018528 + 0.18 * 1.8803862513018528 *\
        jax.numpy.sin(2 * jax.numpy.pi / 11 * t + phase * 2 * jax.numpy.pi / 11) 
        
    return (sine + gauss) * 3.747273140033743

In [96]:
@jax.tree_util.Partial
@jax.jit
def derivative(t, y, args, /, matrix=MATRIX, production=production, prod_coeffs=PROD_COEFFS): 
    
    ans = jax.numpy.matmul(matrix, y) 
    production_rate_constant = production(t, args)
    production_term = prod_coeffs * production_rate_constant 
    return ans + production_term 

In [97]:
@jax.tree_util.Partial
@jax.jit
def solve(y_initial, time, args, /, solver=diffrax.Dopri5(), dydx=derivative):
    term = diffrax.ODETerm(dydx)
    saveat = diffrax.SaveAt(ts=time)
    stepsize = diffrax.PIDController(rtol=1e-5, atol=1e-5)

    t0 = time.min()
    t1 = time.max()
    dt0 = (time[1] - time[0]) / 1000
    
    sol = diffrax.diffeqsolve(args=args, terms=term, solver=solver, t0=t0, t1=t1,\
        dt0=dt0, y0=y_initial, saveat=saveat, stepsize_controller=stepsize)
    
    return sol.ys[:, 1]

In [221]:
with open("miyake12.csv") as data:
    _ = next(data)  # String titles 
    data = jax.numpy.array([row.strip().split(" ") for row in data],\
        dtype=jax.numpy.float64)
    data = data.T

In [325]:
# @jax.tree_util.Partial(jax.jit, static_argnums=(3))
def bin_data(data, time_out, growth, oversample):
    masked = jax.numpy.linspace(0, 1, oversample)
    # So the kernel represents the fraction of decimals between 0 and 1 that are less than the length of the growth season normalised
    kernel = (masked < jax.numpy.count_nonzero(growth)/12)
    
    # These top two return the index of the first one and the first zero
    first1 = jax.numpy.where(growth == 1, size=1)[0][0]
    first0 = jax.numpy.where(growth == 0, size=1)[0][0]

    # Here we zone onto the start of wrapped seasons
    all1s = jax.numpy.where(growth == 1, size=12)[0]
    after1 = jax.numpy.where(all1s > first0, all1s, 0)
    after1 = after1.at[jax.numpy.nonzero(after1, size=1)].get()[0]

    # difference between the wrapped and non-wrapped 
    num = jax.lax.sub(first1, after1)

    # if the season is not wrapped then return the index of the season start.
    # if the seasoon is wrapped return the start and the difference
    val = jax.lax.cond(num == 0, lambda: first1, lambda: after1)
    # Make sure that the tree doesn't grow all year around 
    shifted_index = jax.lax.cond(jax.numpy.all(growth == 1), lambda: 0, lambda: val)

    @jax.tree_util.Partial
    @jax.jit
    def fun(i, val, /, oversample=oversample, data=data, kernel=kernel, shifted_index=shifted_index):
        # So val is the array that is getting updated here
        translated_index = i * oversample + shifted_index * oversample // 12
        year_from_index = jax.lax.dynamic_slice(data, (translated_index,), (oversample,))
        selected_data = jax.numpy.multiply(year_from_index, kernel)
        year_mean = jax.numpy.array([jax.numpy.sum(selected_data) / (jax.numpy.sum(kernel))])

        # So this will update val 
        solution = jax.lax.dynamic_update_slice(val, year_mean, (i,))
        return solution 

    binned_data = jax.numpy.zeros((len(time_out),))
    binned_data = jax.lax.fori_loop(0, len(time_out), fun, binned_data)

    return binned_data
    

In [334]:
import jax.numpy as jnp
from jax.lax import cond, dynamic_slice, dynamic_update_slice, fori_loop
from functools import partial
from jax import jit

In [372]:
@partial(jit, static_argnums=(1))
def bin_data(data, time_oversample, time_out, growth):
    if data.ndim != 1:
        raise ValueError("Data is not one-dimensional! Data must be contained in one row. ")

    masked = jnp.linspace(0, 1, time_oversample)
    kernel = (masked < jnp.count_nonzero(growth)/12)
    
    @partial(jit)
    def _shifted_index_finder(seasons):
        first1 = jnp.where(seasons == 1, size=1)[0][0]
        first0 = jnp.where(seasons == 0, size=1)[0][0]
        all1s = jnp.where(seasons == 1, size=12)[0]
        after1 = jnp.where(all1s > first0, all1s, 0)
        after1 = after1.at[jnp.nonzero(after1, size=1)].get()[0]
        num = jax.lax.sub(first1, after1)
        val = cond(num == 0, lambda x: first1, lambda x : after1, num)
        act = cond(jnp.all(seasons == 1), lambda x: 0, lambda x: val, seasons)
        return act

    shifted_index = _shifted_index_finder(growth)

    binned_data = _rebin1D(time_out, shifted_index, time_oversample, kernel, data)
    return binned_data


@partial(jit, static_argnums=(2))
def _rebin1D(time_out, shifted_index, oversample, kernel, s):
    binned_data = jnp.zeros((len(time_out),))
    fun = lambda i, val: dynamic_update_slice(val, jnp.array([jnp.sum(jnp.multiply(dynamic_slice(
        s, (i * oversample + shifted_index * oversample // 12,), (oversample,)), kernel)) / (
                                                                    jnp.sum(kernel))]), (i,))
    binned_data = fori_loop(0, len(time_out), fun, binned_data)
    return binned_data

In [373]:
oversample = 1000
time_out = jax.numpy.linspace(data[0].min(), data[0].max(), oversample)
growth = jax.numpy.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0], dtype=jax.numpy.float64)

In [374]:
solution = solve(STEADY_STATE, time_out, parameters)

In [375]:
bin_data(solution, oversample, time_out, growth)

DeviceArray([707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.35638888, 707.35638888, 707.35638888,
             707.35638888, 707.356

In [144]:
@jax.tree_util.Partial
@jax.jit
def log_likelihood(args, /, data=data, func=solve, y0=STEADY_STATE):
    solution = func(y0, data[0], args)
    solution = (solution - STEADY_STATE[1]) / STEADY_STATE[1]
    solution += jax.numpy.mean(data[1][:4])
    chi_squared = (solution - data[1]) ** 2 / data[2] ** 2
    return - 0.5 * jax.numpy.sum(chi_squared)
    

In [145]:
parameters = (774.86, 0.25, 0.8, 6.44)

In [146]:
log_likelihood(parameters)

DeviceArray(-397.78775581, dtype=float64)

In [127]:
%%timeit
log_likelihood(parameters)

476 µs ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [147]:
jac = jax.jit(jax.jacobian(log_likelihood))

In [148]:
jac(parameters)

(DeviceArray(-1.00546032, dtype=float64, weak_type=True),
 DeviceArray(-0.48309617, dtype=float64, weak_type=True),
 DeviceArray(-0.02440878, dtype=float64, weak_type=True),
 DeviceArray(0.1201897, dtype=float64, weak_type=True))

In [143]:
%%timeit
jac(parameters)

4.77 ms ± 175 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [136]:
hes = jax.jit(jax.hessian(log_likelihood))

In [139]:
%%timeit
hes(parameters)

44.7 ms ± 2.55 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
