In [118]:
import jax 
import ticktack
import matplotlib.pyplot as pyplot

In [None]:
# Load the presaved "Guttler15" model from `ticktack`, with production rates
# expressed in atoms/cm^2/s, and compile it before reading its internals.
model = ticktack.load_presaved_model("Guttler15", production_rate_units="atoms/cm^2/s")
model.compile()

# Keep only the two arrays used by the rest of the notebook.
# NOTE(review): both are private attributes of the model object; presumably
# `_matrix` is the linear system matrix and `_production_coefficients` the
# production (forcing) vector -- confirm against the ticktack source.
MATRIX = model._matrix
PROJECTION = model._production_coefficients

# The model object itself is not needed beyond this point.
del model

In [7]:
def load(filename: str):
    """
    Load a whitespace-delimited numeric table into a `JAX` array.

    The first line of the file is treated as a header and discarded. Every
    remaining non-blank line is split on runs of whitespace (spaces or tabs)
    and each field is parsed with `float` before the table is handed to `JAX`.

    Parameters:
        filename: String -> The file address of the data
    Returns:
        DeviceArray -> The transposed table, so that element `[i, j]` is
        field `i` of data row `j`. A file with no data rows gives an empty array.
    Raises:
        ValueError -> If a field cannot be parsed as a float
    """
    with open(filename) as handle:
        # Skip the header row; the default stops an empty file from raising StopIteration.
        next(handle, None)

        # `split()` with no argument collapses repeated spaces/tabs, and the
        # `strip()` filter drops blank lines (e.g. a trailing newline at EOF).
        rows = [
            [float(field) for field in line.split()]
            for line in handle
            if line.strip()
        ]

    # float64 is requested as before; JAX only honours it when x64 mode is
    # enabled and otherwise falls back to float32.
    return jax.numpy.array(rows, dtype=jax.numpy.float64).T

So basically I want to compute the matrix exponential $e^{At}$ of the model matrix $A$ at each output time $t$ and then apply it to the production vector, which will give me the time series for the model.

In [69]:
# Output time grid: 100 evenly spaced sample points from 0 to 10 (inclusive).
# NOTE(review): the time units are not stated in this notebook -- confirm.
time_out = jax.numpy.linspace(0, 10, 100)

In [114]:
# Batched matrix exponential: `vmap` maps `expm` over the leading axis of its argument.
matrix_exponential = jax.jit(jax.vmap(jax.scipy.linalg.expm))
# Multiplies the first argument (held fixed, `in_axes=None`) by each entry along axis 0 of the second.
multiply = jax.jit(jax.vmap(jax.numpy.multiply, in_axes=(None, 0)))
# Matrix-multiplies each slice along axis 0 of the first argument by the second argument (held fixed).
matrix_multiply = jax.jit(jax.vmap(jax.numpy.matmul, in_axes=(0, None)))

In [162]:
%%timeit
time_series = multiply(MATRIX, time_out)
state_transition_matrix = matrix_exponential(time_series)
analytic_solution = matrix_multiply(state_transition_matrix, PROJECTION)

2.56 ms ± 300 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


So above I have the first method of computation. Now where do I go from here? I need to do the eigendecomposition and then construct the fundamental matrix of the system from scratch using it.

In [152]:
# Eigendecomposition of the system matrix.
eigenvals, eigenvecs = jax.numpy.linalg.eig(MATRIX)
# Keep only the real parts; any imaginary component is silently discarded here.
# NOTE(review): this is only valid if the spectrum of MATRIX is purely real -- confirm.
eigenvecs, eigenvals = eigenvecs.real, eigenvals.real
# Inverse of the (real part of the) eigenvector matrix, reused by later cells.
inverse = jax.numpy.linalg.inv(eigenvecs) 

In [153]:
@jax.jit
def construct_fundamental_matrix_t(eigenvecs, eigenvals, times=None):
    """
    Build the fundamental matrix `V @ diag(exp(eigenvals * t))` for every time `t`.

    `eigenvecs` follows the `linalg.eig` convention: column `j` is the
    eigenvector belonging to `eigenvals[j]`. Each column must therefore be
    scaled by the exponential of its own eigenvalue. Scaling the rows instead
    gives `diag(exp(eigenvals * t)) @ V`, which is not a fundamental matrix.

    Parameters:
        eigenvecs: Array (n, n) -> Eigenvectors stored as columns
        eigenvals: Array (n,) -> Eigenvalues, in the same order as the columns
        times: Array (T,) -> Times to evaluate at; defaults to the notebook-level `time_out`
    Returns:
        Array (T, n, n) -> Element `[t, i, j]` is `eigenvecs[i, j] * exp(eigenvals[j] * times[t])`
    """
    if times is None:
        times = time_out

    # (T, n): exp(lambda_j * t) for every time and eigenvalue.
    growth = jax.numpy.exp(times[:, None] * eigenvals[None, :])
    # Broadcast over the row axis so column j is scaled by exp(lambda_j * t).
    return eigenvecs[None, :, :] * growth[:, None, :]

In [161]:
%%timeit
partial_state_trans_mat = construct_fundamental_matrix_t(eigenvecs, eigenvals)
state_trans_mat = matrix_multiply(partial_state_trans_mat, inverse)
analytic_sol = matrix_multiply(state_trans_mat, PROJECTION)

55.3 µs ± 2.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


So I have now discovered the state transition matrix. The problem is that I do not know how to compute it numerically. Ironically, this is what we were already doing, but I want to go faster. I have since discovered the fundamental matrix of the system, which can be constructed from the eigenfunctions of the system. I will have to compare these two methods based on the speed of the calculation after `jit` compilation.

In [None]:
@jax.jit
def analytic_solution(time_out, /, matrix=MATRIX, projection=PROJECTION, onset=774.86):
    """
    Evaluate the analytic solution: zero up to `onset`, the impulse solution after it.

    Parameters:
        time_out: Array (T,) -> Times at which to evaluate the solution (positional only)
        matrix: Array -> System matrix; defaults to the notebook-level `MATRIX`.
            NOTE(review): not used by the body yet.
        projection: Array -> Production vector; defaults to the notebook-level `PROJECTION`
        onset: Float -> Times strictly greater than this take the impulse solution;
            all other times return zeros. Defaults to the previously hard-coded 774.86.
    Returns:
        Array -> Same shape as the impulse solution, with the rows for times
        at or before `onset` set to zero
    """
    # Unit-length version of the production vector.
    # NOTE(review): computed but not consumed below -- presumably meant to be
    # the initial condition passed to the impulse solution; confirm.
    initial_position = 1.0 / jax.numpy.linalg.norm(projection) * projection

    # NOTE(review): `vmap_util` is not defined in the visible part of this
    # notebook; it must exist in the kernel for this function to trace.
    impulse_solution = vmap_util(time_out)
    steady_solution = jax.numpy.zeros_like(impulse_solution)
    # Column vector so the mask broadcasts across the state dimension.
    condition = (time_out > onset).reshape(-1, 1)

    return jax.numpy.where(condition, impulse_solution, steady_solution)