In [None]:
from functools import partial
from os import getcwd
from time import process_time

from jax import jacfwd, jacrev, jit
from jax.experimental.ode import odeint as DP5
from jax.lax import scan
from jax.numpy import float64, zeros, mean, var
from pandas import read_csv

from ode import odeint as BS3
from ticktack import load_presaved_model
from ticktack.fitting import SingleFitter

In [None]:
# Name of the presaved ticktack model. It is passed both to `load_presaved_model`
# and to `SingleFitter`, so a single constant keeps the two calls in sync.
MODEL_NAME = "Guttler14"

carbon_box_model = load_presaved_model(MODEL_NAME, production_rate_units="atoms/cm^2/s")
fitting_object = SingleFitter(carbon_box_model, MODEL_NAME)

In [None]:
def odeint_linear(derivative, y0, timesteps, *args, **kwargs):
    """
    Integrates dy/dt = derivative(y, t, *args) with the explicit (forward) Euler method.

    Intended as a drop-in replacement for ``jax.experimental.ode.odeint``: the solution
    is reported at every entry of ``timesteps`` along a new leading axis of length
    ``len(timesteps)``, and the first row is ``y0`` itself. Each step uses the width of
    its own interval ``timesteps[i+1] - timesteps[i]``, so non-uniform grids work.

    Parameters
    ----------
    derivative : callable
        ``derivative(y, t, *args)`` returning dy/dt with the same structure as ``y``.
    y0 : array or pytree of arrays
        Initial state at ``timesteps[0]``.
    timesteps : 1-D array
        Times at which the solution is reported.
    *args
        Extra positional arguments forwarded to ``derivative`` (as odeint does).
    **kwargs
        Accepted and ignored, so adaptive-solver options such as ``rtol``/``atol``
        can be passed without error.

    Returns
    -------
    Array (or pytree of arrays) of the state at each time in ``timesteps``.
    """
    from jax import lax, numpy as jnp
    from jax.tree_util import tree_map

    timesteps = jnp.asarray(timesteps)
    y0 = tree_map(jnp.asarray, y0)

    def euler_step(y, interval):
        t_start, t_end = interval
        dydt = derivative(y, t_start, *args)
        y_next = tree_map(lambda yi, di: yi + di * (t_end - t_start), y, dydt)
        return y_next, y_next

    # One Euler step per interval [t_i, t_{i+1}]; scan compiles the body itself,
    # so no explicit jit is needed.
    _, later_states = lax.scan(euler_step, y0, (timesteps[:-1], timesteps[1:]))

    # Prepend y0 so row i is the state at timesteps[i], matching odeint's output.
    return tree_map(
        lambda first, rest: jnp.concatenate([first[None], rest]),
        y0,
        later_states,
    )

In [None]:
def profile(func, args=(), trials=10) -> tuple:
    """
    Times repeated calls of ``func(*args)`` and returns the mean and variance of the durations.

    After each call, the timer waits for every JAX array in the result (any pytree) to
    finish computing. This keeps asynchronous dispatch from making a call look faster
    than it is. Durations are wall-clock seconds measured with ``time.perf_counter``.

    Parameters
    ----------
    func : callable
        Function to profile.
    args : sequence, optional
        Positional arguments passed to ``func`` on every trial.
    trials : int, optional
        Number of timed calls (default 10).

    Returns
    -------
    tuple
        ``(mean, variance)`` of the per-call duration in seconds (population variance).

    Raises
    ------
    ValueError
        If ``trials`` is smaller than 1.
    """
    # Local imports keep this function independent of module-level names such as
    # ``mean``/``var``, which are easily shadowed in a notebook namespace.
    from time import perf_counter
    from jax import numpy as jnp
    from jax.tree_util import tree_leaves

    if trials < 1:
        raise ValueError(f"trials must be at least 1, got {trials}")

    durations = []
    for _ in range(trials):
        start = perf_counter()
        result = func(*args)
        # Block on every array leaf so the measured time includes the actual computation.
        for leaf in tree_leaves(result):
            if hasattr(leaf, "block_until_ready"):
                leaf.block_until_ready()
        durations.append(perf_counter() - start)

    samples = jnp.asarray(durations)
    return jnp.mean(samples), jnp.var(samples)

In [None]:
# Benchmark each ODE solver on three tasks: a forward model run ("odeint"), the
# gradient of the loss ("gradient") and its Hessian ("hessian"). Each row stores
# the solver label, the mean time reported by `profile` in milliseconds, and the task.
#
# NOTE(review): `time_out`, `production`, `y_initial`, `parameters`,
# `PRODUCTION_RATE`, `UNIT_FACTOR`, `loss_function`, `dydx` and `u0` are not defined
# in the visible part of this notebook; they must come from another cell. Confirm.
results = {
    "module": [],
    "time": [],
    "type": []
}

odeints = {BS3: "ode", DP5: "jax", odeint_linear: "linear"}  # solver function -> label

dataset = read_csv(f"{getcwd()}/datasets/775AD/NH/Miyake12_Cedar.csv")  # Miyake cedar dataset


def _record(label, seconds, kind):
    """Appends one benchmark row (solver label, time in ms, task kind) to `results`."""
    results["module"].append(label)
    results["time"].append(float(seconds) * 1000)
    results["type"].append(kind)


for solver, label in odeints.items():
    # Forward model run using this solver. Names avoid `mean`, which would shadow
    # jax.numpy.mean at module level.
    mean_time, _ = profile(carbon_box_model.run, [time_out, 100,
                production, solver, y_initial, parameters,
                PRODUCTION_RATE * UNIT_FACTOR])
    _record(label, mean_time, "odeint")

    # Loss as a function of the parameters only.
    loss = partial(loss_function, solver, dataset, dydx, u0)

    # The "ode" solver is differentiated in forward mode and the others in reverse mode.
    # Presumably the "ode" solver does not support reverse-mode autodiff; confirm.
    jac = jacfwd if label == "ode" else jacrev

    mean_time, _ = profile(jac(loss), [parameters])
    _record(label, mean_time, "gradient")

    mean_time, _ = profile(jac(jac(loss)), [parameters])
    _record(label, mean_time, "hessian")