In [1]:
import os 
import time
os.environ["JAX_ENABLE_X64"] = "True"
# os.environ["JAX_DISABLE_JIT"] = "True"
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
from jax.experimental import sparse
import functools

import matplotlib.pyplot as plt

import numpy as np

In [2]:
res_dim = 10
key = jax.random.PRNGKey(0)
density = 0.5
wrkey1, wrkey2 = jax.random.split(key)
spec_rad = 0.8

N_nonzero = int(res_dim**2 * density)
wr_indices = jax.random.choice(
    wrkey1,
    res_dim**2,
    shape=(N_nonzero,),
)
wr_vals = jax.random.uniform(
    wrkey2, shape=N_nonzero, minval=-1, maxval=1
)
wr = jnp.zeros(res_dim * res_dim)
wr = wr.at[wr_indices].set(wr_vals)
wr = wr.reshape(res_dim, res_dim)
wr_dense = wr * (spec_rad / jnp.max(jnp.abs(jnp.linalg.eigvals(wr))))
wr_sparse = sparse.BCOO.fromdense(wr_dense)

In [3]:
def create_wr(res_dim, density):
    key = jax.random.PRNGKey(0)
    wrkey1, wrkey2 = jax.random.split(key)
    N_nonzero = int(res_dim**2 * density)
    wr_indices = jax.random.choice(
        wrkey1,
        res_dim**2,
        shape=(N_nonzero,),
    )
    wr_vals = jax.random.uniform(
        wrkey2, shape=N_nonzero, minval=-1, maxval=1
    )
    wr = jnp.zeros(res_dim * res_dim)
    wr = wr.at[wr_indices].set(wr_vals)
    wr = wr.reshape(res_dim, res_dim)
    return wr

In [4]:
# def max_eig_arnoldi_np(A, tol=1e-12, max_iters=1000, seed=0):
#     n = A.shape[0]
#     np.random.seed(seed)
#     b = np.random.rand(n)
#     b = b / np.linalg.norm(b)

#     V = np.zeros((n, max_iters + 1))
#     H = np.zeros((max_iters + 1, max_iters))
#     V[:, 0] = b

#     for j in range(max_iters):
#         w = A @ V[:, j]
        
#         for i in range(j + 1):
#             # H[i, j] = np.dot(V[:, i], w)
#             H[i, j] = V[:, i].conj().T @ w
#             w = w - H[i, j] * V[:, i]
        
#         H[j + 1, j] = np.linalg.norm(w)

        
#         V[:, j + 1] = w / H[j + 1, j]
        
#         # Compute eigenvalues of the current H (top-left (j+1)x(j+1) block)
#         eigvals = np.linalg.eigvals(H[:j+1, :j+1])
#         lambda_max = eigvals[np.argmax(np.abs(eigvals))]

#         # Optional early convergence check (change in largest Ritz value)
#         if j > 0:
#             delta = abs(lambda_max - prev_lambda)
#             if delta < tol:
#                 break
#         prev_lambda = lambda_max
    
#     print("Number of iterations / max_iters:", j + 1 , "/", max_iters)

#     return lambda_max


# A = np.random.randn(res_dim, res_dim)
# tol=1e-12
# max_iters=100
# seed=0
# print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_np(A, tol=tol, max_iters=max_iters, seed=seed)))
# print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

In [5]:
### naive jax version without lax primitives
def max_eig_arnoldi_jax(A, tol=1e-12, max_iters=1000, seed=0):
    n = A.shape[0]
    key = jax.random.PRNGKey(seed)
    b = jax.random.uniform(key, shape=(n,))
    b = b / jnp.linalg.norm(b)

    V = jnp.zeros((n, max_iters + 1))
    H = jnp.zeros((max_iters + 1, max_iters))
    V = V.at[:, 0].set(b)

    lambda_max = 3

    for j in range(max_iters):
        w = A @ V[:, j]
        
        for i in range(j + 1):
            H = H.at[i, j].set(V[:, i].conj().T @ w)
            w = w - H[i, j] * V[:, i]
        
        H = H.at[j + 1, j].set(jnp.linalg.norm(w))
        if H[j + 1, j] < tol:
            break

        V = V.at[:, j + 1].set(w / H[j + 1, j])
        
        # Compute eigenvalues of the current H (top-left (j+1)x(j+1) block)
        eigvals = jnp.linalg.eigvals(H[:j+1, :j+1])
        lambda_max = eigvals[jnp.argmax(jnp.abs(eigvals))]

        # Optional early convergence check (change in largest Ritz value)
        if j > 0:
            delta = abs(lambda_max - prev_lambda)
            if delta < tol:
                break
        prev_lambda = lambda_max
    
    num_iters = j + 1
    return lambda_max, num_iters


# A = np.random.randn(res_dim, res_dim)
A = jax.random.normal(key, shape=(res_dim, res_dim))
tol=1e-16
max_iters=100
seed=0
print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_jax(A, tol=tol, max_iters=max_iters, seed=seed))[0])
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi): 3.443780942023299
max eigenvalue (check): 3.443780942023293


In [6]:
@functools.partial(jax.jit, static_argnames=["max_iters"])
def arnoldi_jax(A, max_iters, tol=1e-12, seed=0):
    m = A.shape[0]
    n = min(max_iters, m)

    key = jax.random.PRNGKey(seed)
    b = jax.random.normal(key, (m,))
    q0 = b / jnp.linalg.norm(b)

    Q = jnp.zeros((m, n+1))
    H = jnp.zeros((n+1, n))

    Q = Q.at[:,0].set(q0)

    for k in range(n):
        v = A @ Q[:,k]
        for j in range(k+1):
            h_jk = jnp.dot(Q[:,j], v)
            H = H.at[j, k].set(h_jk)
            v = v - h_jk * Q[:,j]
        h_kplus1k = jnp.linalg.norm(v)
        # if h_kplus1k > 0:
        H = H.at[k+1, k].set(h_kplus1k)
        Q = Q.at[:,k+1].set(v / h_kplus1k)

    return Q[:,:n], H[:n,:n]


In [7]:
A = jax.random.normal(key, shape=(res_dim, res_dim))
Q, H = arnoldi_jax(A, max_iters = 15, tol = 1e-12)

In [8]:
jnp.max(jnp.abs(jnp.linalg.eigvals(A)))

Array(3.44378094, dtype=float64)

In [9]:
jnp.max(jnp.abs(jnp.linalg.eigvals(H)))

Array(3.44378094, dtype=float64)

In [10]:
@functools.partial(jax.jit, static_argnames=["max_iters"])
def arnoldi_jax2(A, max_iters, tol=1e-12, seed=0):
    # A is m x m; n is the size of the krylov basis
    m = A.shape[0]
    n = min(max_iters, m)

    # choose a random vector to start iterating on 
    key = jax.random.PRNGKey(seed)
    b = jax.random.normal(key, (m,))
    q0 = b / jnp.linalg.norm(b)

    # init krylov basis Q and hessenberg matrix H
    Q = jnp.zeros((m, n+1))
    H = jnp.zeros((n+1, n))
    Q = Q.at[:,0].set(q0)

    # modified gs step to form orth krylov basis
    def gs_step(carry, _):
        v, j, Q = carry
        h_jk = jnp.dot(Q[:,j], v)
        v = v - h_jk*Q[:,j]
        j = j+1
        carry = (v,j,Q)
        return carry, h_jk

    for k in range(n):
        # new candidate vector 
        v = A @ Q[:,k]
        
        # run modified gs
        final_carry, h_jk_vals = jax.lax.scan(gs_step, (v, 0, Q), length=k+1)
        v = final_carry[0] # orgogonalized candidate vector
        h_kplus1k = jnp.linalg.norm(v) # H subdiag value

        # fill in a column of H (everything down to one below diagonal)
        idxs = jnp.arange(k+2)
        H_vals = jnp.concatenate([h_jk_vals, jnp.array([h_kplus1k])])
        H = H.at[idxs, k].set(H_vals)

        # fill in Q 
        Q = Q.at[:,k+1].set(v / h_kplus1k)

    return Q[:,:n], H[:n,:n]


In [11]:
A = jax.random.normal(key, shape=(1000, 1000))
Q, H = arnoldi_jax2(A, max_iters = 200, tol = 1e-12)

In [12]:
jnp.max(jnp.abs(jnp.linalg.eigvals(A)))

Array(32.95245833, dtype=float64)

In [13]:
jnp.max(jnp.abs(jnp.linalg.eigvals(H)))

Array(32.95245848, dtype=float64)

In [14]:
@functools.partial(jax.jit, static_argnames=["max_iters"])
def arnoldi_jax3(A, max_iters, tol=1e-12, seed=0):
    # A is m x m; n is the size of the krylov basis
    m = A.shape[0]
    n = max_iters

    # choose a random vector to start iterating on 
    key = jax.random.PRNGKey(seed)
    b = jax.random.normal(key, (m,))
    q0 = b / jnp.linalg.norm(b)

    # init krylov basis Q and hessenberg matrix H
    Q = jnp.zeros((m, n+1), dtype=A.dtype)
    H = jnp.zeros((n+1, n), dtype=A.dtype)
    Q = Q.at[:,0].set(q0)

    # modified gs step to form orth krylov basis
    def gs_step(carry, j):
        v, Q, mask = carry
        # Only apply when j is less than or equal to k (mask is True)
        h_jk = jnp.where(mask[j], jnp.dot(Q[:,j], v), 0.0)
        v = jnp.where(mask[j], v - h_jk*Q[:,j], v)
        return (v, Q, mask), h_jk
    
    def arnoldi_step(carry, k):
        A, Q, H = carry

        # new candidate vector
        v = A @ Q[:, k]

        # Create a mask for valid indices (0 to k)
        idx_mask = jnp.arange(n+1) <= k
        
        # run modified gs with fixed-size loop and masking
        final_carry_gs, h_jk_vals = jax.lax.scan(
            gs_step, 
            (v, Q, idx_mask), 
            jnp.arange(n+1)  # fixed size scan
        )
        v = final_carry_gs[0]  # orthogonalized candidate vector
        
        # Calculate h_kplus1k
        h_kplus1k = jnp.linalg.norm(v)
        
        # Update H column k using a mask
        col_indices = jnp.arange(n+1)
        mask = col_indices <= k  # For the first k+1 elements
        
        # Apply h_jk_vals for the first k+1 elements using the mask
        H_col = jnp.where(mask, h_jk_vals, H[:, k])
        
        # Set the k+1 element to h_kplus1k
        H_col = H_col.at[k+1].set(h_kplus1k)
        
        # Update the k-th column of H
        H = H.at[:, k].set(H_col)

        # Update Q[:, k+1]
        Q = Q.at[:,k+1].set(v / (h_kplus1k + tol))  # Add tol to prevent division by zero

        return (A, Q, H), None

    final_carry_arnoldi, _ = jax.lax.scan(arnoldi_step, (A, Q, H), jnp.arange(n))
    _, Q, H = final_carry_arnoldi

    return Q[:,:n], H[:n,:n]

In [15]:
A = jax.random.normal(key, shape=(1000, 1000))
Q, H = arnoldi_jax3(A, max_iters = 200, tol = 1e-12)

In [16]:
jnp.max(jnp.abs(jnp.linalg.eigvals(A)))

Array(32.95245833, dtype=float64)

In [17]:
jnp.max(jnp.abs(jnp.linalg.eigvals(H)))

Array(32.95245848, dtype=float64)

In [None]:
@functools.partial(jax.jit, static_argnames=("max_iters",))
def arnoldi_fast(A, tol=1e-12, max_iters=300, seed=0):
    # A is m x m; n is the size of the krylov basis
    m = A.shape[0]
    n = min(max_iters, m)

    # choose a random vector to start iterating on 
    key = jax.random.PRNGKey(seed)
    q0  = jax.random.normal(key, (m,))
    q0  = q0 / jnp.linalg.norm(q0)

    # init krylov basis Q and hessenberg matrix H
    Q = jnp.zeros((m, n + 1), dtype=A.dtype)
    H = jnp.zeros((n + 1, n), dtype=A.dtype)
    Q = Q.at[:, 0].set(q0)

    # run arnoldi one arnoldi step for an entire column of H
    col_idx = jnp.arange(n + 1)
    def arnoldi_col_step(carry, k):
        Q, H = carry

        # new candidate vector
        v = A @ Q[:, k]                       

        # orthogonalize in a batch 
        h_full = jnp.dot(Q.T, v) # all inner products
        h_mask = (col_idx <= k)                 
        h = jnp.where(h_mask, h_full, 0)  # zeros beyond k
        v = v - Q @ h
        beta = jnp.linalg.norm(v)

        # build the whole column 
        h_col= h.at[k+1].set(beta) # subdiag is normed   
        H  = H.at[:, k].set(h_col)           

        Q  = Q.at[:, k+1].set(v / beta) 
        return (Q, H), None

    (Q, H), _ = jax.lax.scan(arnoldi_col_step, (Q, H), jnp.arange(n))
    return Q[:, :n], H[:n, :n]


In [19]:
A = jax.random.normal(key, shape=(1000, 1000))
Q, H = arnoldi_fast(A, max_iters = 200, tol = 1e-12)

In [20]:
jnp.max(jnp.abs(jnp.linalg.eigvals(H)))

Array(32.95245848, dtype=float64)

In [21]:
jnp.max(jnp.abs(jnp.linalg.eigvals(A)))

Array(32.95245833, dtype=float64)

In [22]:
funcs_to_compare = [arnoldi_jax2, arnoldi_jax3, arnoldi_fast]
res_dim = 1000
num_trials = 10

times = []
for i,func in enumerate(funcs_to_compare):
    print("Running func:", func.__name__)
    key = jax.random.PRNGKey(i)
    # A = jax.random.normal(key, shape=(res_dim, res_dim)) # dense
    A = create_wr(res_dim=res_dim, density = 0.01)
    start_t = time.time()
    for j in range(num_trials):
        Q, H = func(A, max_iters = 150, tol = 1e-12)
        Q.block_until_ready()
        H.block_until_ready()
    times.append((time.time()-start_t)/num_trials)

Running func: arnoldi_jax2
Running func: arnoldi_jax3
Running func: arnoldi_fast


In [23]:
print(times)

[0.598164439201355, 0.05701522827148438, 0.05118889808654785]


In [24]:
a

NameError: name 'a' is not defined

In [None]:
@functools.partial(jax.jit, static_argnames=["max_iters"])
def max_eig_arnoldi_lax(A, tol=1e-12, max_iters=100, seed=0):
    '''
    Perform Arnoldi iteration to find the largest eigenvalue of a matrix A. 

    Args:
        A: The input matrix (n x n).
        tol: Tolerance for convergence.
        max_iters: Maximum number of iterations.
        seed: Random seed for initialization.
    Returns:
        The largest eigenvalue of A.
    '''
    n = A.shape[0]
    key = jax.random.PRNGKey(seed)
    v0 = jax.random.normal(key, (n,))
    v0 = v0 / jnp.linalg.norm(v0) # initial guess

    V = jnp.zeros((n, max_iters + 1)) # orthogonal krylov basis
    H = jnp.zeros((max_iters + 1, max_iters)) # hessnian matrix
    V = V.at[:, 0].set(v0)

    def body(carry):
        j, V, H, prev_lam, done = carry
        w = A @ V[:, j]

        # Gram–Schmidt
        def gs_step(i, val):
            w, H = val
            h = jnp.vdot(V[:, i], w)
            H = H.at[i, j].set(h)
            w = w - h * V[:, i]
            return (w, H)
        w, H = jax.lax.fori_loop(0, j + 1, gs_step, (w, H))

        h_next = jnp.linalg.norm(w)
        H = H.at[j + 1, j].set(h_next)
        V = V.at[:, j + 1].set(jnp.where(h_next > 0, w / h_next, V[:, j + 1]))
        eigvals = jnp.linalg.eigvals(H[:-1, :])     # now square & static
        lam_max = eigvals[jnp.argmax(jnp.abs(eigvals))]

        done = jnp.logical_or(done, jnp.abs(lam_max - prev_lam) < tol) # not actually used
        return (j + 1, V, H, lam_max, done)

    # check against max_iters
    def cond(carry):
        j, _, _, _, _ = carry
        return j < max_iters

    init_state = (0, V, H, 0.0 + 0j, False)
    _, _, _, lambda_max, _ = jax.lax.while_loop(cond, body, init_state)
    return lambda_max


A = jax.random.normal(key, shape=(res_dim, res_dim))
tol=1e-16
max_iters=100
seed=0
print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_lax(A, tol=tol, max_iters=max_iters, seed=seed)))
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi): 3.443780942023302
max eigenvalue (check): 3.443780942023293


In [None]:
@functools.partial(jax.jit, static_argnames=('max_iters',))
def max_eig_arnoldi_lax_sparse(A, tol=1e-12, max_iters=300, seed=0):
    """
    Arnoldi method optimized for sparse matrices
    """
    n = A.shape[0]
    key = jax.random.PRNGKey(seed) 
    v0 = jax.random.normal(key, (n,))
    v0 = v0 / jnp.linalg.norm(v0) # TODO with parallel / ensemble initialize with prev eigenvector
    
    # Pre-allocate arrays
    V = jnp.zeros((n, max_iters + 1), dtype=A.dtype)
    V = V.at[:, 0].set(v0)
    H = jnp.zeros((max_iters, max_iters), dtype=A.dtype)
    
    
    def arnoldi_step(carry):
        j, V, H, prev_ritz, converged = carry
        
        # w = Av
        w = jax.experimental.sparse.bcoo_dot_general(A, V[:, j],  dimension_numbers=(((1,), (0,)), ((), ())))

        # Modified Gram–Schmidt
        def gs_body(i, inner):
            w, H = inner
            h_ij = jnp.vdot(V[:, i], w)
            H = H.at[i, j].set(h_ij)
            w = w - h_ij * V[:, i]
            return (w, H)

        w, H = jax.lax.fori_loop(0, j + 1, gs_body, (w, H))
        beta = jnp.linalg.norm(w)

        # Write subdiag of hessenberg if matrix not full 
        H = jax.lax.cond(
            (j + 1) < max_iters,
            lambda args: args[0].at[args[1] + 1, args[1]].set(args[2]),
            lambda args: args[0],
            (H, j, beta)
        )

        # Normalise and append new basis vector
        V = jax.lax.cond(
            beta > 0,
            lambda args: args[0].at[:, args[1] + 1].set(args[2] / args[3]),
            lambda args: args[0],
            (V, j, w, beta)
        )

        def get_max_eigenvalue(H, size):
            
            mask = jnp.arange(max_iters)[:, None] < size
            mask = mask & (jnp.arange(max_iters)[None, :] < size)
            
            H_masked = jnp.where(mask, H, 0.0)
            
            all_eigvals = jnp.linalg.eigvals(H_masked)
            
            return all_eigvals[jnp.argmax(jnp.abs(all_eigvals))]
        
        lambda_max = get_max_eigenvalue(H, j+1)
        
        converged = (beta < tol) | ((j > 0) & (jnp.abs(lambda_max - prev_ritz) < tol))
        
        return (j + 1, V, H, lambda_max, converged)

    def cond_fn(carry):
        j, _, _, _, conv = carry
        return (j < max_iters - 1) & (~conv)

    init = (0, V, H, 0.0, False)
    j_final, _, _, lambda_max, _ = jax.lax.while_loop(cond_fn, arnoldi_step, init)
    
    return lambda_max#, j_final

In [None]:
res_dim = 200
A = jax.random.normal(key, (res_dim, res_dim))
tol=1e-10
max_iters=100
seed=0

out = np.abs(max_eig_arnoldi_lax(A, tol=tol, max_iters=max_iters, seed=seed))
print("Max eigenvalue (arnoldi lax):", out)
# print("Number of iterations:", out[1])
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi lax): 14.401025051328915
max eigenvalue (check): 14.401025083145015


In [None]:
res_dim = 1000
density = 0.01
num_trials = 10  
max_iters = 200 
tol = 1e-16     
seed = 0        

wr_dense = create_wr(res_dim, density)
wr_sparse = sparse.BCOO.fromdense(wr_dense)

start_time = time.time()
eig = jnp.max(jnp.abs(jnp.linalg.eigvals(wr_dense))).block_until_ready()
end_time = time.time()
print("Time taken (dense eigvals):", end_time - start_time)
print("Max eigenvalue (dense eigvals):", eig)
print("=============================================\n")

start_time = time.time()
eig = max_eig_arnoldi_jax(wr_dense, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (dense naive):", end_time - start_time)
print("Max eigenvalue (dense naive):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_jax(wr_sparse, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (sparse naive):", end_time - start_time)
print("Max eigenvalue (sparse naive):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax(wr_dense, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (dense lax):", end_time - start_time)
print("Max eigenvalue (dense lax):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax(wr_sparse, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (sparse lax):", end_time - start_time)
print("Max eigenvalue (sparse lax):", jnp.abs(eig), "\n")

Time taken (dense eigvals): 4.45522665977478
Max eigenvalue (dense eigvals): 1.8291688631675567

Time taken (dense naive): 32.54138946533203
Max eigenvalue (dense naive): 1.829168894290524 

Time taken (sparse naive): 16.880793571472168
Max eigenvalue (sparse naive): 1.8291688942905129 

Time taken (dense lax): 5.0487377643585205
Max eigenvalue (dense lax): 1.829168834769205 

Time taken (sparse lax): 2.065659284591675
Max eigenvalue (sparse lax): 1.829168834769199 



In [None]:
# --- Timing jnp.linalg.eigvals (Dense) ---
times_dense_eigvals = []
for _ in range(num_trials):
    start_time = time.time()
    _ = jnp.max(jnp.abs(jnp.linalg.eigvals(wr_dense))).block_until_ready()
    end_time = time.time()
    times_dense_eigvals.append(end_time - start_time)
avg_time_dense_eigvals = np.mean(times_dense_eigvals)
print(f"Average time taken (dense eigvals over {num_trials} trials): {avg_time_dense_eigvals:.6f} seconds")
print("=============================================\n")

# --- Timing max_eig_arnoldi_jax (Dense) ---
times_dense_naive = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_jax(wr_dense, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_dense_naive.append(end_time - start_time)
avg_time_dense_naive = np.mean(times_dense_naive)
print(f"Average time taken (dense naive over {num_trials} trials): {avg_time_dense_naive:.6f} seconds\n")

# --- Timing max_eig_arnoldi_jax (Sparse) ---
times_sparse_naive = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_jax(wr_sparse, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_sparse_naive.append(end_time - start_time)
avg_time_sparse_naive = np.mean(times_sparse_naive)
print(f"Average time taken (sparse naive over {num_trials} trials): {avg_time_sparse_naive:.6f} seconds\n")

# --- Timing max_eig_arnoldi_lax (Dense) ---
times_dense_lax = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_lax(wr_dense, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_dense_lax.append(end_time - start_time)
avg_time_dense_lax = np.mean(times_dense_lax)
print(f"Average time taken (dense lax over {num_trials} trials): {avg_time_dense_lax:.6f} seconds\n")

# --- Timing max_eig_arnoldi_lax (Sparse) ---
times_sparse_lax = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_lax(wr_sparse, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_sparse_lax.append(end_time - start_time)
avg_time_sparse_lax = np.mean(times_sparse_lax)
print(f"Average time taken (sparse lax over {num_trials} trials): {avg_time_sparse_lax:.6f} seconds\n")

Average time taken (dense eigvals over 10 trials): 0.826949 seconds

