In [1]:
import ticktack # For the compilation of the model
import ode      # As a method of solving the odes

from jax import jacfwd, jacrev          # Forward and reverse mode the autodiff
from jax import numpy as np             # Importing the jax numpy 
import jax.experimental.ode as jax_ode  # Jax odeint 

import pandas as pd             # Becuase I want to 
from time import process_time   # For timing the different ode implementations
from os import getcwd
from functools import partial
from plotnine import *
import matplotlib.pyplot as plt # Plotting library for visualisations
plt.style.use("dark_background")# Because dark mpde

In [2]:
#! Fix scope name shadowing 
parameters = np.array([ # Empty array to hold the parameters of the prodcution function 
    7.044873503263437,  # Mean of the sinusoidal production 
    0.18,               # Amplitude of the sinusoidal production 
    11.0,               # Period of the sinusoidal production 
    1.25,               # Phase of the sinusoidal production 
    120.05769867244142, # Height of the super-gaussian 
    12.0                # Width of the super-gaussian 
], dtype=np.float64)

INFO[2022-01-16 14:41:52,832]: Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO[2022-01-16 14:41:52,832]: Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
INFO[2022-01-16 14:41:52,833]: Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


In [3]:
projection = np.array([ # An array to hold the projection of the production function 
    0.7,    # Projection into the stratosphere 
    0.3,    # Projection into the troposphere 
    0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 
], dtype=np.float64)   

In [4]:
#! Type hints
def production(t, parameters):    
    """
    The production function defined as the sum of sinusoidal and super-gaussian terms 
    """
    return parameters[0] * (1 + parameters[1] * \
        np.sin(2 * np.pi / parameters[2] * t + parameters[3])) + \
        parameters[4] * np.exp(- (parameters[5] * (t - 775)) ** 16)

In [5]:
cbm = ticktack.load_presaved_model("Guttler14", production_rate_units="atoms/cm^2/s")
cbm.compile()   # Constructing the transfer operator 

In [6]:
u0 = np.array([   # Housing the initial C14 concentr    ations 
    135.76261605786132, # Initial stratosphere concentration
    709.7591911307035,  # Initial troposphere concentration
    1191.489526709938,  # Initial surface ocean concentration
    3.97158546945107,   # Initial surface biota concentration 
    45158.65854589925,  # Initial deep ocean concentration
    155.54703228960028, # Initial short-lived biota concentration
    634.290736403387,   # Initial long-lived biota concentration
    423.5954241095565,  # Initial litter concentration
    1808.3343542055652, # Initial soil concentration 
    665.933052175064,   # Initial peat concentration
    7348.751626918585   # Initial sedimentary sink concentration 
], dtype=np.float64)

In [7]:
#! Need to improve the doc strings here
#? Look into jit compilation 
def dydx(y, t, p):
    """
    The gradient of the carbon box model.
    """
    return cbm._matrix @ y + production(t, p) * projection

In [8]:
def loss_function(odeint, args, dataset):
    """
    The bayesian error of the loss function. 
    """
    d14c = np.array([*dataset["d14c"]], dtype=np.float64)   # JAX array for manipulation
    year = np.array([*dataset["year"]], dtype=np.float64)   # JAX array for the year
    sig_d14c = np.array([*dataset["sig_d14c"]], dtype=np.float64)# JAx array for manipulation 

    dydx, u0, params = args

    simulation = odeint(dydx, u0, year, params)[:, 2]   # Running the simulation 
    simulation = 1000 * (simulation - u0[2]) / u0[2]    # Deviations from initial 
    simulation += np.mean(d14c[1:4])    # Adding the offset 

    return - 0.5 * np.sum(((d14c - simulation) / sig_d14c) ** 2)   # Log likelhood as chi-squared 

In [9]:
def profile(func, args=[]) -> tuple:
    """
    Profiles the function, func, with the arguments, args, returning the speed and the variance of the speed.
    """
    time_sample = np.zeros(10)  # Storing the trials
    
    for i in range(10):
        timer = process_time()                          # starting a timer 
        func(*args).block_until_ready()                                     # Running the model 
        time_sample = time_sample.at[i].set(process_time() - timer)   # Stopping the timer and storing 

    return np.mean(time_sample), np.var(time_sample)

So the problem is that the `*args` list is does not have the correct time_series information in it. I also need to majorly look at the binning

In [17]:
results = {
    "module": [],
    "odeint": []
    # "grad": [],
    # "hessian": [] 
}

namespaces = {ode : "ode", jax_ode : "jax"} # For translation from string to module 

dataset = pd.read_csv(f"{getcwd()}/datasets/775AD/NH/Miyake12_Cedar.csv") # Importing the miyake data

for namespace in namespaces:# Iterating through the namespaces of odeint function 
    mean, variance = profile(
        namespace.odeint, 
        [dydx, u0, np.arange(750.0, 820.0, 1/12), parameters]
    )   # Profiling the speed
    results["module"].append(namespaces[namespace])
    results["odeint"].append(mean)

    # Look into the use of `@partial` from `functools` and generally fix this shit up.
    # loss = lambda p: loss_function(
    #     namespace.odeint, 
    #     [dydx, u0, np.arange(750.0, 820.0, 1/12), p]
    # )

    # #* The gradients exist below 
    # def gradient(p): 
    #     if namespaces[namespace] == "ode":
    #         return jacfwd(loss)(p)         # Calculating the gradient 
    #     else:
    #         return jacrev(loss)(p)

    # mean, variance = profile(gradient, [parameters])
    # results[namespaces[namespace]].append(mean)

    # def hessians(p): 
    #     if namespaces[namespace] == "ode":
    #         return jacfwd(jacfwd(loss))(p) # Calculating the hessian 
    #     else:
    #         return jacrev(jacrev(loss))(p)

    # mean, variance = profile(hessians, [parameters])
    # results[namespaces[namespace]].append(mean)

In [19]:
results = pd.DataFrame(results)  # Switching to a pandas data frame

In [23]:
results

Unnamed: 0,module,odeint
0,ode,0.0011163700000002
1,jax,0.0011241246999997


TypeError: no numeric data to plot