In [1]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import time
import numpy as np
import matplotlib.pyplot as plt

DTYPE_FP64 = jnp.float64


# Generate matrices
def generate_test_matrix(key, size, condition="normal"):
    if condition == "normal":
        return jax.random.normal(key, (size, size), dtype=DTYPE_FP64) / jnp.sqrt(size)
    elif condition == "stiff":
        A = jax.random.normal(key, (size, size), dtype=DTYPE_FP64)
        A = A - 10 * jnp.eye(size)
        return A / jnp.sqrt(size)
    elif condition == "skew":
        A = jax.random.normal(key, (size, size), dtype=DTYPE_FP64)
        return (A - A.T) / (2 * jnp.sqrt(size))



# Padé [13/13] coefficients
alpha = jnp.array([
    64764752532480000, 32382376266240000, 7771770303897600,
    1187353796428800, 129060195264000, 10559470521600,
    670442572800, 33522128640, 1323241920,
    40840800, 960960, 16380,
    182, 1
], dtype=DTYPE_FP64)



# Scaling
def scaling(A):
    A_norm = jnp.linalg.norm(A, ord=1)
    theta_13 = 5.371920351148152
    s_f = jnp.maximum(0.0, jnp.ceil(jnp.log2(A_norm / theta_13)))
    s = int(float(s_f))
    return A / (2.0 ** s), s



# Naive Padé
def expm_pade_naive(A):
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    A_scaled, s = scaling(A)

    U = jnp.zeros_like(A)
    V = jnp.zeros_like(A)
    Ak = I
    for k in range(14):
        if k > 0:
            Ak = Ak @ A_scaled
        if k % 2 == 0:
            V = V + alpha[k] * Ak
        else:
            U = U + alpha[k] * Ak

    R = jnp.linalg.solve(V - U, V + U)
    for _ in range(s):
        R = R @ R
    return R, s



# Horner method
def expm_pade_horner(A):
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    A_scaled, s = scaling(A)

    # Horner in X = A^2
    beta  = [alpha[2*i]   for i in range(7)]  # even coefficients
    gamma = [alpha[2*i+1] for i in range(7)]  # odd coefficients
    X = A_scaled @ A_scaled

    # Even
    p_even = beta[-1] * I
    for coeff in reversed(beta[:-1]):
        p_even = X @ p_even + coeff * I
    # Odd
    p_odd = gamma[-1] * I
    for coeff in reversed(gamma[:-1]):
        p_odd = X @ p_odd + coeff * I

    pA = p_even + A_scaled @ p_odd  # numerator
    qA = p_even - A_scaled @ p_odd  # denominator

    R = jnp.linalg.solve(qA, pA)
    for _ in range(s):
        R = R @ R
    return R, s



# PS v=3
def expm_pade_ps_v3(A):
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    A_scaled, s = scaling(A)

    beta  = [alpha[2*i]   for i in range(7)]
    gamma = [alpha[2*i+1] for i in range(7)]

    # A even powers
    X = A_scaled @ A_scaled
    X2 = X @ X
    X3 = X2 @ X
    X6 = X3 @ X3

    # Blocks
    E0 = beta[0]*I + beta[1]*X + beta[2]*X2
    E1 = beta[3]*I + beta[4]*X + beta[5]*X2
    E2 = beta[6]*I

    O0 = gamma[0]*I + gamma[1]*X + gamma[2]*X2
    O1 = gamma[3]*I + gamma[4]*X + gamma[5]*X2
    O2 = gamma[6]*I

    Ue = E0 + (E1 @ X3) + (E2 @ X6)
    So = O0 + (O1 @ X3) + (O2 @ X6)
    Uo = A_scaled @ So

    R = jnp.linalg.solve(Ue - Uo, Ue + Uo)
    for _ in range(s):
        R = R @ R
    return R, s



# PS v=4
def expm_pade_ps_v4(A):
    I = jnp.eye(A.shape[0], dtype=A.dtype)
    A_scaled, s = scaling(A)

    beta  = [alpha[2*i]   for i in range(7)]
    gamma = [alpha[2*i+1] for i in range(7)]

    # A even powers
    X = A_scaled @ A_scaled
    X2 = X @ X
    X3 = X2 @ X
    X4 = X2 @ X2

    # Blocks
    E0 = beta[0]*I + beta[1]*X + beta[2]*X2 + beta[3]*X3
    E1 = beta[4]*I + beta[5]*X + beta[6]*X2

    O0 = gamma[0]*I + gamma[1]*X + gamma[2]*X2 + gamma[3]*X3
    O1 = gamma[4]*I + gamma[5]*X + gamma[6]*X2

    Ue = E0 + (E1 @ X4)
    So = O0 + (O1 @ X4)
    Uo = A_scaled @ So

    R = jnp.linalg.solve(Ue - Uo, Ue + Uo)
    for _ in range(s):
        R = R @ R
    return R, s



# Scipy expm for accuracy comparison
def scipy_expm(A):
    from scipy.linalg import expm
    return jnp.array(expm(np.array(A)))



# Benchmarking
def benchmark_method(method, A, num_runs=5000, warmup_runs=100):
    for _ in range(warmup_runs):
        _ = method(A)
    times = []
    for _ in range(num_runs):
        start = time.time()
        result = method(A)
        end = time.time()
        times.append(end - start)
    return result, np.mean(times), np.std(times)


# Compute relative error
def relative_error(result, reference):
    return jnp.linalg.norm(result - reference) / jnp.linalg.norm(reference)


# Experiments
def run_comparison():
    import jax
    import numpy as np

    key = jax.random.PRNGKey(42)
    sizes = [16, 32, 64, 128, 256]
    conditions = ["normal", "stiff", "skew"]

    methods = {
        "Naive Padé": expm_pade_naive,
        "Horner": expm_pade_horner,
        "PS ν=3": expm_pade_ps_v3,
        "PS ν=4": expm_pade_ps_v4,
    }

    results = {condition: {method: {} for method in methods} for condition in conditions}

    for size in sizes:
        for condition in conditions:
            key, subkey = jax.random.split(key)
            A = generate_test_matrix(subkey, size, condition)

            # Compute reference solution
            reference = scipy_expm(A)

            for method_name, method_func in methods.items():
                result, mean_time, std_time = benchmark_method(method_func, A)
                # Extract just the matrix result (ignore scaling parameter)
                if isinstance(result, tuple):
                    matrix_result = result[0]
                else:
                    matrix_result = result

                error = relative_error(matrix_result, reference)
                results[condition][method_name][size] = (mean_time * 1000, std_time * 1000, error)

    # Print results
    for condition in conditions:
        print(f"\nCondition: {condition}")
        for method_name in methods:
            print(f"  Method: {method_name}")
            for size in sizes:
                time, std, error = results[condition][method_name][size]
                print(f"    Size {size}x{size}: {time:.2f} ms ± {std:.2f} ms, rel_err: {error:.2e}")

In [2]:
run_comparison()


Condition: normal
  Method: Naive Padé
    Size 16x16: 10.01 ms ± 1.03 ms, rel_err: 3.39e-16
    Size 32x32: 9.93 ms ± 1.08 ms, rel_err: 5.36e-16
    Size 64x64: 10.13 ms ± 1.17 ms, rel_err: 6.52e-16
    Size 128x128: 10.16 ms ± 0.96 ms, rel_err: 1.42e-15
    Size 256x256: 10.18 ms ± 1.00 ms, rel_err: 1.66e-15
  Method: Horner
    Size 16x16: 8.71 ms ± 1.12 ms, rel_err: 3.24e-16
    Size 32x32: 8.94 ms ± 1.07 ms, rel_err: 5.32e-16
    Size 64x64: 8.90 ms ± 1.06 ms, rel_err: 6.28e-16
    Size 128x128: 8.86 ms ± 0.80 ms, rel_err: 1.45e-15
    Size 256x256: 9.11 ms ± 1.02 ms, rel_err: 1.57e-15
  Method: PS ν=3
    Size 16x16: 8.23 ms ± 1.01 ms, rel_err: 3.35e-16
    Size 32x32: 8.37 ms ± 1.02 ms, rel_err: 5.38e-16
    Size 64x64: 8.49 ms ± 1.08 ms, rel_err: 6.30e-16
    Size 128x128: 8.32 ms ± 0.94 ms, rel_err: 1.45e-15
    Size 256x256: 8.53 ms ± 1.00 ms, rel_err: 1.57e-15
  Method: PS ν=4
    Size 16x16: 7.84 ms ± 0.90 ms, rel_err: 3.25e-16
    Size 32x32: 8.06 ms ± 0.88 ms, rel_err: 5