In [3]:
import jax 
import ticktack
import matplotlib.pyplot as pyplot

In [4]:
# Build the presaved "Guttler15" box model (production rate in atoms/cm^2/s), compile it,
# and keep only the two arrays the rest of the notebook works with.
# NOTE(review): `_matrix` and `_production_coefficients` are underscore-prefixed (private)
# ticktack attributes rather than public API — confirm they still exist when upgrading ticktack.
_guttler15 = ticktack.load_presaved_model("Guttler15", production_rate_units="atoms/cm^2/s")
_guttler15.compile()

MATRIX, PROJECTION = _guttler15._matrix, _guttler15._production_coefficients

del _guttler15  # the model object itself is not needed past this point



In [5]:
def load(filename: str):
    """
    A custom `JAX` file loading protocol designed to be very quick and return a value that is `JAX` transformable.

    The first line of the file is a header and is skipped. Every remaining non-blank line
    is split on whitespace (any run of spaces or tabs) and parsed as one row of floats.

    Parameters:
        filename: String -> The file address of the data
    Returns:
        jax.Array -> The data transposed, so row i holds column i of the file
            (shape (columns, rows)); shape (0,) when the file has no data rows.
    Raises:
        ValueError -> If a field is not a number or the rows have different lengths.
    """
    with open(filename) as handle:    # Opening the data file
        next(handle, None)            # Skip the header row; tolerate an empty file

        # `split()` with no argument collapses any run of whitespace. `split(" ")` yielded
        # empty fields for doubled spaces, did not split on tabs, and blank lines broke parsing.
        rows = [[float(field) for field in line.split()] for line in handle if line.strip()]

    # NOTE: float64 is only honoured when `jax_enable_x64` is enabled; otherwise JAX warns
    # and stores the data as float32.
    data = jax.numpy.array(rows, dtype=jax.numpy.float64)
    return data.T

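So basically I want to compute the matrix exponential $e^{At}$ of the system matrix $A$ and apply it to the initial state, $y(t) = e^{At} y_0$. In the eigenbasis this amounts to multiplying each mode by $e^{\lambda_i t}$, which will give me the time series for the model.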

In [6]:
# Evaluation grid: 100 evenly spaced times from 0 to 10, both endpoints included.
time_out = jax.numpy.linspace(start=0, stop=10, num=100)

In [57]:
def _batched_jit(function, in_axes):
    """Return `function` vectorised with `jax.vmap(..., in_axes=in_axes)` and then `jax.jit`-compiled."""
    return jax.jit(jax.vmap(function, in_axes=in_axes))


# Element-wise product of a fixed first operand with every slice along axis 0 of the second.
multiply = _batched_jit(jax.numpy.multiply, in_axes=(None, 0))
# Matrix product of every slice along axis 0 of the first operand with a fixed second operand.
matrix_multiply = _batched_jit(jax.numpy.matmul, in_axes=(0, None))

So above I have the first method of computation. Now where do I go from here? I need to compute the eigendecomposition $A = V \Lambda V^{-1}$ and then use it to construct the fundamental matrix $V e^{\Lambda t}$ of the system from scratch.

In [42]:
def construct_analytic_template(matrix=MATRIX, projection=PROJECTION):
    """
    Build a jitted closed-form solution of the linear system dy/dt = A y from the
    eigendecomposition A = V diag(λ) V^{-1}.

    The state transition matrix is Φ(t) = V diag(exp(λ t)) V^{-1}; the returned template
    evaluates Φ(t) @ projection at every requested time.

    Parameters:
        matrix: Array -> The square system matrix A, shape (n, n).
        projection: Array -> The vector Φ(t) is applied to (presumably shape (n,)).
    Returns:
        jax.tree_util.Partial -> `analytic_template(time_out)`, mapping a 1-D array of T
            times to Φ(t) @ projection for each time, i.e. shape (T, n) for a vector projection.
    Raises:
        ValueError -> If `matrix` has eigenvalues with non-negligible imaginary parts, for
            which keeping only the real parts of the eigenpairs would give a wrong solution.
    """
    eigenvals, eigenvecs = jax.numpy.linalg.eig(matrix)
    # `eig` returns complex values. Keeping only the real parts is exact for a real spectrum,
    # so check that instead of silently truncating a complex one.
    if not jax.numpy.allclose(eigenvals.imag, 0.0):
        raise ValueError("matrix has complex eigenvalues; a real eigendecomposition cannot represent it")
    eigenvecs, eigenvals = eigenvecs.real, eigenvals.real
    inverse = jax.numpy.linalg.inv(eigenvecs)

    @jax.jit
    @jax.vmap
    def fundamental_matrix(time_out, /, eigenvecs=eigenvecs, eigenvals=eigenvals):
        # V @ diag(exp(λ t)) at a single time t: broadcasting scales column j of V by
        # exp(λ_j t), giving an (n, n) matrix. The vmapped `multiply` helper would instead map
        # over the n entries of exp(λ t) and stack n scaled copies of V into (n, n, n).
        return eigenvecs * jax.numpy.exp(eigenvals * time_out)

    @jax.tree_util.Partial
    @jax.jit
    def analytic_template(time_out, /, inverse=inverse, projection=projection):
        state_trans_mat = fundamental_matrix(time_out)                # (T, n, n): V exp(Λt)
        state_trans_mat = matrix_multiply(state_trans_mat, inverse)  # (T, n, n): Φ(t)
        return matrix_multiply(state_trans_mat, projection)          # Φ(t) @ projection per time

    return analytic_template


In [61]:
eigenvals, eigenvecs = jax.numpy.linalg.eig(MATRIX)
# `eig` returns complex values; truncating to the real parts below is exact only for a real
# spectrum, so check that rather than silently producing a wrong solution.
if not jax.numpy.allclose(eigenvals.imag, 0.0):
    raise ValueError("MATRIX has complex eigenvalues; a real eigendecomposition cannot represent it")
eigenvecs, eigenvals = eigenvecs.real, eigenvals.real
inverse = jax.numpy.linalg.inv(eigenvecs)

@jax.tree_util.Partial
@jax.jit
def fundamental_matrix(time_out, /, eigenvecs=eigenvecs, eigenvals=eigenvals):
    """
    Fundamental matrix V exp(Λ t) of the system at each requested time.

    Parameters:
        time_out: Array -> 1-D array of T times.
    Returns:
        Array -> Shape (T, n, n); slice k is V with column j scaled by exp(λ_j * time_out[k]).
    """
    # multiply(eigenvals, time_out) builds the (T, n) table λ_j t_k; the outer `multiply`
    # then scales the columns of V by each row of exp(λ t).
    return multiply(eigenvecs, jax.numpy.exp(multiply(eigenvals, time_out)))

@jax.tree_util.Partial
@jax.jit
def analytic_template(time_out, /, inverse=inverse, projection=PROJECTION):
    """
    State transition matrix Φ(t) = V exp(Λ t) V^{-1} applied to `projection` at each time.

    Parameters:
        time_out: Array -> 1-D array of T times.
    Returns:
        Array -> Φ(t) @ projection for every time; shape (T, n) for a vector projection.
    """
    state_trans_mat = fundamental_matrix(time_out)
    state_trans_mat = matrix_multiply(state_trans_mat, inverse)
    return matrix_multiply(state_trans_mat, projection)

So I have now discovered the state transition matrix, $\Phi(t) = V e^{\Lambda t} V^{-1}$. The problem is that I do not know how to compute this numerically. Ironically, this is what we were already doing, but I want to go faster. I have since discovered that the fundamental matrix of the system, $V e^{\Lambda t}$, can be built from the eigenvalues and eigenvectors of the system matrix. I will have to compare these two methods based on the speed of the calculation after `jit` compilation.

In [None]:
@jax.jit
def analytic_solution(time_out, /, matrix=MATRIX, projection=PROJECTION, event_time=774.86):
    """
    This is the analytic solution itself: y(t) = exp(A t) @ y0 for the linear system
    dy/dt = A y, started from the unit vector along `projection`, with every time at or
    before `event_time` set to zero.

    Parameters:
        time_out: Array -> 1-D array of T times at which the solution is evaluated.
        matrix: Array -> The square system matrix A, shape (n, n).
        projection: Array -> Direction of the initial state y0; normalised to unit length.
        event_time: float -> Rows for times <= event_time are zero (defaults to the
            previously hard-coded 774.86).
    Returns:
        Array -> Shape (T, n) for a vector projection.
    """
    # `jax.numpy.norm` does not exist; the vector norm lives in `jax.numpy.linalg`.
    initial_position = projection / jax.numpy.linalg.norm(projection)

    # NOTE(review): the original called an undefined `vmap_util` here. Assumed intent is
    # exp(A t) @ initial_position at each time, via A = V diag(λ) V^{-1} — confirm. Also
    # confirm whether the response should be evaluated at t or at t - event_time.
    eigenvals, eigenvecs = jax.numpy.linalg.eig(matrix)
    coefficients = jax.numpy.linalg.solve(eigenvecs, initial_position)   # V^{-1} y0
    modes = jax.numpy.exp(jax.numpy.outer(time_out, eigenvals))          # (T, n): exp(λ_j t_k)
    # Row k is V @ (exp(λ t_k) * coefficients). The decomposition is complex in general, but
    # for a real A the product is real up to round-off, so only the real part is kept.
    impulse_solution = ((modes * coefficients) @ eigenvecs.T).real

    steady_solution = jax.numpy.zeros_like(impulse_solution)
    condition = (time_out > event_time).reshape(-1, 1)  # broadcast over the n components

    return jax.numpy.where(condition, impulse_solution, steady_solution)