In [1]:
import ticktack                 # For the compilation of the model
import ode                      # As a method of solving the odes
import numpy as np              # For the basic functions that need to be written
from jax import numpy as jnp    # annoying shit that I wish I didn't hav to deal with 
import jax.experimental.ode as jax  # Jax odeint 
import pandas as pd             # Becuase I want to 
from time import process_time   # For timing the different ode implementations

In [2]:
parameters = np.zeros(6)            # Empty array to hold the parameters of the prodcution function 
parameters[0] = 7.044873503263437   # Mean of the sinusoidal production 
parameters[1] = 0.18                # Amplitude of the sinusoidal production 
parameters[2] = 11.0                # Period of the sinusoidal production 
parameters[3] = 1.25                # Phase of the sinusoidal production 
parameters[4] = 120.05769867244142  # Height of the super-gaussian 
parameters[5] = 12.0                # Width of the super-gaussian 

In [3]:
projection = np.zeros(11)   # An array to hold the projection of the production function 
projection[0] = 0.7         # Proportion projection into the stratosphere       
projection[1] = 0.3         # Proportion projection into the troposphere

In [4]:
def production(t):    
    """
    The production function defined as the sum of sinusoidal and super-gaussian terms 
    """
    return parameters[0] * (1 + parameters[1] * \
        jnp.sin(2 * jnp.pi / parameters[2] * t + parameters[3])) + \
        parameters[4] * jnp.exp(- (parameters[5] * (t - 775)) ** 16)

In [5]:
cbm = ticktack.load_presaved_model("Guttler14", production_rate_units="atoms/cm^2/s")
cbm.compile()   # Constructing the transfer operator 



In [6]:
u0 = np.array([135.76261605786132, 709.7591911307035,
    1191.489526709938, 3.97158546945107, 45158.65854589925,
    155.54703228960028, 634.290736403387, 423.5954241095565,
    1808.3343542055652, 665.933052175064, 7348.751626918585])

In [7]:
def dydx(y, t):
    """
    The gradient of the carbon box model.
    """
    return cbm._matrix @ y + production(t) * projection

This is going very very well. I need to go plumbing again though. this means that I can look through the depths of the ode implementation and destroy the extra return improving my ability to do things

The code below is quite rough. I need to remove the global variable using returns from the profile function as a tuple.

In [46]:
results = {    # Storing the data using a dictionary 
    "solver" : [],              # The library containing the solver 
    "function" : [],            # The function that was getting profiled 
    "time" : [],                # The mean run time of 10 trials 
    "time variance" : []       # The variance of the run time of 10 trials
}

def profile(func, args=[]):
    time_sample = np.zeros(10)  # Storing the trials
    
    for i in range(10):
        timer = process_time()                  # starting a timer 
        func(*args)                             # Running the model 
        time_sample[i] = process_time() - timer # Stopping the timer and storing 

    return np.mean(time_sample), np.var(time_sample)

namespaces = {ode : "ode", jax : "jax"} # For translation from string to module 

for namespace in namespaces:# Iterating through the namespaces of odeint function 
    mean, variance = profile(
        namespace.odeint, [dydx, u0, np.arange(760.0, 790.0, 1)])   # Profiling the speed

    results["function"].append("odeint") # The function that was profiled 
    results["solver"].append(namespaces[namespace]) # Stroing implementation 
    results["time"].append(mean)        # Storing the mean run time 
    results["time variance"].append(variance)# Storing the variance of the mean

    # profile()


SyntaxError: unmatched ')' (<ipython-input-46-c1c4cb502e76>, line 26)

In [36]:
results = pd.DataFrame(results) # Switching to a pandas data frame
results

Unnamed: 0,solver,function,time,time variance
0,ode,odeint,0.001006,1.067806e-07
1,jax,odeint,0.000425,6.939772e-08


In [45]:
from jax import jvp, grad
import jax.numpy as np

jvp(np.sin, (np.pi,), (1.0,))
# cos = grad(np.sin)
# cos(3.14159)

(DeviceArray(1.2246468e-16, dtype=float64, weak_type=True),
 DeviceArray(-1., dtype=float64, weak_type=True))

In [None]:
import matplotlib.pyplot as plt # Plotting library for visualisations
plt.style.use("dark_background")

time_space = np.arange(760.0, 790.0)    # times of the sampling 

jax_solution = jax.odeint(dydx, u0, np.arange(760.0, 790.0))[:, 2]
ode_solution = ode.odeint(dydx, u0, np.arange(760.0, 790.0))[0][:, 2]

#! I'm not sure that this is the best way to go about it
# jax_solution = (jax_solution - np.median(jax_solution)) / jax_solution
# ode_solution = (ode_solution - np.median(ode_solution)) / ode_solution

plt.plot(time_space, jax_solution, label="jax") # Plotting the entire data 
plt.plot(time_space, ode_solution, label="ode")
plt.legend()
plt.show()                      # Showing the plot 