In [None]:
# Numerical
from jax.numpy import mean, var, exp, newaxis   # Arithematic
from jax.numpy import array, zeros, arange  # Array creation routines
from jax.numpy import sum as vsum
from jax.numpy.linalg import solve, eig, norm
from jax.lax import scan 
from jax import jit, jacrev, jacfwd, vmap

# Miscellaneous
from functools import partial 
from time import process_time
from ticktack import load_presaved_model
from ticktack.fitting import SingleFitter

# ODEINTs
from ticktack.bogacki_shampine import odeint as BS3
from ticktack.dormand_prince import odeint as DP5

# Visualisation
from pandas import DataFrame
from plotnine import *

In [None]:
def construct_analytic_solution(model="Guttler14", production_rate_units="atoms/cm^2/s"):
    """
    Build a jitted analytic impulse-response function for a presaved carbon box model.

    This creates a closure that pre-calculates the eigen-decomposition of the model's
    transfer matrix. It also pre-calculates the coefficients that express the model's
    production vector in that eigenbasis (the appropriate initial state). Each call of
    the returned function then only has to evaluate a sum of exponentials.

    Parameters:
        model: str -> Name of the presaved model passed to `load_presaved_model`
            (default "Guttler14").
        production_rate_units: str -> Units string passed to `load_presaved_model`
            (default "atoms/cm^2/s").

    Returns:
        Callable -> A jitted `impulse_response(time, start, area)`. The eigen-data are
        pre-bound as the `coeffs`, `eigenvals` and `eigenvecs` keyword arguments.
    """
    cbm = load_presaved_model(model, production_rate_units=production_rate_units)
    cbm.compile()

    # `eig` returns eigenvectors as the *columns* of `eigenvectors`.
    eigenvalues, eigenvectors = eig(cbm._matrix)
    # Expand the production vector in the eigenbasis: eigenvectors @ soln == production.
    soln = solve(eigenvectors, cbm._production_coefficients)
    # soln = 1 / norm(soln) * soln    # Normalising the impulse

    def analytic_solution(time, start, area, coeffs=None, eigenvals=None, eigenvecs=None):
        """
        Evaluate the analytic solution as a linear superposition of eigenfunctions.

        Parameters:
            time: Array -> 1-D array of times at which the solution is evaluated.
            start: scalar -> Time at which the impulse is applied.
            area: scalar -> Multiplicative size of the impulse.
            coeffs: Array -> Coefficient of each eigenvector in the initial state.
            eigenvals: Array -> Eigenvalues of the transfer matrix.
            eigenvecs: Array[Array] -> Matrix whose columns are the eigenvectors.

        Returns:
            Array -> Shape (len(time), number of eigenvector components). Row i is the
            state at time[i]. The dtype is complex because `eig` yields complex values.
            NOTE(review): for a real transfer matrix the imaginary part should be
            round-off only; confirm whether callers expect `.real`.
        """
        def single_mode(coeff, eigenval, eigenvec):
            # One eigenmode's contribution: (len(time), 1) * (n,) -> (len(time), n).
            decay = exp(eigenval * (time - start))[:, newaxis]
            return area * coeff * decay * eigenvec

        # Map over modes (rows of eigenvecs.T are the eigenvectors), then superpose them.
        modes = vmap(single_mode)(coeffs, eigenvals, eigenvecs.T)
        return vsum(modes, axis=0)

    impulse_response = partial(analytic_solution, coeffs=soln,
        eigenvals=eigenvalues, eigenvecs=eigenvectors)

    return jit(impulse_response)

In [None]:
analytic