In [31]:
# Numerical
from jax.numpy import exp, sin, pi, mean    # Arithmetic
from jax.numpy import arange, array, asarray, concatenate # Array creation / joining routines
from jax.lax import scan 

# Miscellaneous
from functools import partial 
from time import process_time
from ticktack import load_presaved_model

# ODEINTs
from ode import odeint as BS3
from jax.experimental.ode import odeint as DP5

# Visualisation
from pandas import DataFrame
from plotnine import ggplot, aes, theme_bw, geom_point

In [32]:
# Baseline (steady-state) production rate; presumably in atoms/cm^2/s, the unit
# the box model is loaded with further down — TODO confirm.
PRODUCTION_RATE = 1.76

# Multiplier applied to production rates before they are handed to the box model.
# NOTE(review): the factors look like the 14C molar mass (g/mol), Avogadro's
# number (x1e23), Earth's surface area (x1e18 cm^2) and seconds per year (x1e3),
# i.e. a conversion from atoms/cm^2/s to kg/yr — TODO confirm.
UNIT_FACTOR = (14.003242 / 6.022 * 5.11 * 31536.) / 1.e5

In [33]:
def odeint_linear(derivative, y0, timesteps, *args, **kwargs):
    """Integrate ``dy/dt = derivative(y, t, *args)`` with the explicit (forward) Euler method.

    Follows the calling convention of ``jax.experimental.ode.odeint`` so it can be
    swapped in for the adaptive solvers: row 0 of the result is ``y0`` itself and
    row ``i`` approximates the state at ``timesteps[i]``.

    Args:
        derivative: Callable ``derivative(y, t, *args)`` returning dy/dt with the
            same shape as ``y``.
        y0: Initial state at ``timesteps[0]`` (array-like).
        timesteps: 1-D sequence of at least one time. The spacing need not be
            uniform: each step uses its own ``timesteps[i + 1] - timesteps[i]``.
        *args: Extra positional arguments forwarded to ``derivative``.
        **kwargs: Accepted only for signature compatibility with adaptive solvers
            (e.g. tolerances) and ignored — a fixed-step method has nothing to tune.

    Returns:
        Array of shape ``(len(timesteps),) + y0.shape`` holding the state at each time.
    """
    timesteps = asarray(timesteps)
    y0 = asarray(y0)

    def euler_step(state, interval):
        # One Euler step across [t_start, t_end]; the new state is both the
        # carry and the per-step output that scan stacks.
        t_start, t_end = interval
        next_state = state + derivative(state, t_start, *args) * (t_end - t_start)
        return next_state, next_state

    # Step over each consecutive (t_i, t_{i+1}) pair. scan stacks the states at
    # timesteps[1:], so y0 is prepended to keep row i aligned with timesteps[i].
    _, later_states = scan(euler_step, y0, (timesteps[:-1], timesteps[1:]))
    return concatenate([y0[None], later_states])

In [34]:
def production(t, *args, steady_state=PRODUCTION_RATE, solar_amplitude=0.18, solar_period=11):
    """Production rate at time ``t``, multiplied by ``UNIT_FACTOR``.

    The rate is the sum of three terms:

    * a constant ``steady_state`` baseline;
    * a sinusoid with amplitude ``solar_amplitude * steady_state`` and period
      ``solar_period`` (presumably the ~11-year solar cycle with ``t`` in years —
      TODO confirm);
    * a flat-topped super-Gaussian pulse (exponent 16) centred on
      ``start + duration / 2`` with height ``area / duration``, so it integrates
      to roughly ``area``.

    Args:
        t: Time(s) at which to evaluate; scalars or arrays broadcast elementwise.
        *args: The four parameters ``(start, duration, phase, area)``, given either
            as four scalars or as one 4-element array (they are flattened first).
        steady_state: Baseline production rate.
        solar_amplitude: Sinusoid amplitude as a fraction of ``steady_state``.
        solar_period: Sinusoid period, in the same units as ``t``.

    Returns:
        ``UNIT_FACTOR * (pulse + sinusoid + steady_state)``.
    """
    start, duration, phase, area = array(list(args)).reshape(-1)

    middle = start + duration / 2.
    half_width = duration / 2.
    height = area / duration
    # Integer exponent: lowers to repeated multiplication and avoids raising a
    # negative base to a floating-point power.
    pulse = height * exp(-((t - middle) / half_width) ** 16)
    sinusoid = solar_amplitude * steady_state * sin(2 * pi / solar_period * t + phase)

    return UNIT_FACTOR * (pulse + sinusoid + steady_state)

In [35]:
# Load the presaved "Guttler14" box model, with production rates expressed in
# atoms/cm^2/s, then run its compile() step before using it.
carbon_box_model = load_presaved_model(
    "Guttler14",
    production_rate_units="atoms/cm^2/s",
)
carbon_box_model.compile()

In [36]:
# Oversampling factor shared by the burn-in and the high-resolution reference run.
burn_in_oversample = 1000
# Burn-in time grid: -225 up to and including 774 (presumably years).
burn_in_time_out = arange(-225, 775)
# Pulse parameters, unpacked by `production` as (start, duration, phase, area).
parameters = (774.86, 0.25, 0.8, 6.44)
# Twelve 0/1 flags, presumably the months of the growing season — not used in
# the cells shown here.
growth_season = array([0] * 3 + [1] * 6 + [0] * 3)

In [37]:
# Burn-in run over `burn_in_time_out` at the reference oversampling. Only the
# second return value is kept; it becomes `y0` for every later run.
# NOTE(review): the first return value (indexed [:, 1] in later cells) is the
# state trajectory and is discarded here. Confirm that the second value really
# is the state at the end of the burn-in and not, e.g., a steady-state solution;
# if it is the latter, the burn-in has no effect on the later runs.
_, y_initial = carbon_box_model.run(
    burn_in_time_out,
    burn_in_oversample,
    production,
    args=parameters,
    steady_state_production=PRODUCTION_RATE * UNIT_FACTOR,
)

In [38]:
# Time grid for the solver benchmark (750..799); it overlaps the last entries
# (750..774) of `burn_in_time_out`.
time_out = arange(start=750, stop=800)

In [39]:
# High-resolution reference run from `y_initial`, at the burn-in oversampling
# and with whatever solver `run` uses by default. The cheaper solver runs below
# are compared against it.
true_values, _ = carbon_box_model.run(
    time_out,
    burn_in_oversample,
    production,
    y0=y_initial,
    args=parameters,
    steady_state_production=PRODUCTION_RATE * UNIT_FACTOR,
)
true_values = true_values[:, 1]  # keep column 1 only: the troposphere

In [None]:
# Solvers under test, keyed by the label recorded in the results table.
odeint = dict(BS3=BS3, DP5=DP5, Linear=odeint_linear)

In [41]:
# Column-oriented results table: one list per column, one entry per benchmark run.
odeint_oversample_resids = {
    column: [] for column in ("Solver", "Oversample", "Time (s)", "Accuracy")
}

In [42]:
# Sweep every solver over a range of oversampling factors, recording CPU time
# and the error against the high-resolution `true_values` reference.
# The reference is thinned by striding, so only factors that divide
# `burn_in_oversample` exactly give a reference with a matching sample count
# (e.g. 1000 / 30 is not an integer); other factors are left out of the sweep.
oversample_values = [
    oversample for oversample in range(10, 100, 10)
    if burn_in_oversample % oversample == 0
]
# NOTE(review): if `run` is jit-compiled, each new `oversample` changes array
# shapes and forces a recompile, so these timings presumably include
# compilation — TODO confirm, and add a warm-up call if only solve time matters.
for solver in odeint:
    for oversample in oversample_values:
        timer = process_time()
        solution, _ = carbon_box_model.run(
            time_out, oversample, production,
            solver=odeint[solver], y0=y_initial, args=parameters,
            steady_state_production=PRODUCTION_RATE * UNIT_FACTOR,
        )
        # JAX dispatches work asynchronously: wait for the result so the timer
        # covers the computation, not just its dispatch.
        solution.block_until_ready()
        timer = process_time() - timer

        # Compare the troposphere column (index 1) with every
        # `oversample_ratio`-th reference sample.
        # NOTE(review): assumes both runs sample the same time span, so the
        # strided reference lines up with `solution` — confirm against how
        # `run` builds its time grid.
        oversample_ratio = burn_in_oversample // oversample
        solution = solution[:, 1]
        local_true_value = true_values[::oversample_ratio]
        if solution.shape != local_true_value.shape:
            raise ValueError(
                f"{solver} at oversample={oversample}: solution shape "
                f"{solution.shape} does not match reference shape "
                f"{local_true_value.shape}"
            )
        residuals = solution - local_true_value
        # Mean absolute residual (lower is better); a signed mean would let
        # over- and under-shoots cancel. It is stored as a Python float so the
        # results table gets a numeric column.
        accuracy = float(mean(abs(residuals)))

        odeint_oversample_resids["Solver"].append(solver)
        odeint_oversample_resids["Oversample"].append(oversample)
        odeint_oversample_resids["Time (s)"].append(timer)
        odeint_oversample_resids["Accuracy"].append(accuracy)

AttributeError: 'tuple' object has no attribute 'block_until_ready'

In [None]:
# Build the results frame under a new name. Rebinding `odeint_oversample_resids`
# to a DataFrame would break re-running the benchmark loop, which appends to
# its lists.
odeint_oversample_frame = DataFrame(odeint_oversample_resids)

# Run time against the recorded "Accuracy" value for every run. Shape tells the
# solvers apart (`group` alone has no visible effect on points); colour shows
# the oversampling factor.
(ggplot(odeint_oversample_frame,
        aes(x="Time (s)", y="Accuracy", color="Oversample", shape="Solver"))
    + theme_bw()
    + geom_point())