In [1]:
import os 
import time
os.environ["JAX_ENABLE_X64"] = "True"
# os.environ["JAX_DISABLE_JIT"] = "True"
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
from jax.experimental import sparse
import functools

import matplotlib.pyplot as plt

import numpy as np

In [2]:
res_dim = 10
key = jax.random.PRNGKey(0)
density = 0.5
wrkey1, wrkey2 = jax.random.split(key)
spec_rad = 0.8

N_nonzero = int(res_dim**2 * density)
wr_indices = jax.random.choice(
    wrkey1,
    res_dim**2,
    shape=(N_nonzero,),
)
wr_vals = jax.random.uniform(
    wrkey2, shape=N_nonzero, minval=-1, maxval=1
)
wr = jnp.zeros(res_dim * res_dim)
wr = wr.at[wr_indices].set(wr_vals)
wr = wr.reshape(res_dim, res_dim)
wr_dense = wr * (spec_rad / jnp.max(jnp.abs(jnp.linalg.eigvals(wr))))
wr_sparse = sparse.BCOO.fromdense(wr_dense)

In [3]:
# def max_eig_arnoldi_np(A, tol=1e-12, max_iters=1000, seed=0):
#     n = A.shape[0]
#     np.random.seed(seed)
#     b = np.random.rand(n)
#     b = b / np.linalg.norm(b)

#     V = np.zeros((n, max_iters + 1))
#     H = np.zeros((max_iters + 1, max_iters))
#     V[:, 0] = b

#     for j in range(max_iters):
#         w = A @ V[:, j]
        
#         for i in range(j + 1):
#             # H[i, j] = np.dot(V[:, i], w)
#             H[i, j] = V[:, i].conj().T @ w
#             w = w - H[i, j] * V[:, i]
        
#         H[j + 1, j] = np.linalg.norm(w)

        
#         V[:, j + 1] = w / H[j + 1, j]
        
#         # Compute eigenvalues of the current H (top-left (j+1)x(j+1) block)
#         eigvals = np.linalg.eigvals(H[:j+1, :j+1])
#         lambda_max = eigvals[np.argmax(np.abs(eigvals))]

#         # Optional early convergence check (change in largest Ritz value)
#         if j > 0:
#             delta = abs(lambda_max - prev_lambda)
#             if delta < tol:
#                 break
#         prev_lambda = lambda_max
    
#     print("Number of iterations / max_iters:", j + 1 , "/", max_iters)

#     return lambda_max


# A = np.random.randn(res_dim, res_dim)
# tol=1e-12
# max_iters=100
# seed=0
# print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_np(A, tol=tol, max_iters=max_iters, seed=seed)))
# print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

In [4]:
### naive jax version without lax primitives
def max_eig_arnoldi_jax(A, tol=1e-12, max_iters=1000, seed=0):
    n = A.shape[0]
    key = jax.random.PRNGKey(seed)
    b = jax.random.uniform(key, shape=(n,))
    b = b / jnp.linalg.norm(b)

    V = jnp.zeros((n, max_iters + 1))
    H = jnp.zeros((max_iters + 1, max_iters))
    V = V.at[:, 0].set(b)

    lambda_max = 3

    for j in range(max_iters):
        w = A @ V[:, j]
        
        for i in range(j + 1):
            H = H.at[i, j].set(V[:, i].conj().T @ w)
            w = w - H[i, j] * V[:, i]
        
        H = H.at[j + 1, j].set(jnp.linalg.norm(w))
        if H[j + 1, j] < tol:
            break

        V = V.at[:, j + 1].set(w / H[j + 1, j])
        
        # Compute eigenvalues of the current H (top-left (j+1)x(j+1) block)
        eigvals = jnp.linalg.eigvals(H[:j+1, :j+1])
        lambda_max = eigvals[jnp.argmax(jnp.abs(eigvals))]

        # Optional early convergence check (change in largest Ritz value)
        if j > 0:
            delta = abs(lambda_max - prev_lambda)
            if delta < tol:
                break
        prev_lambda = lambda_max

    return lambda_max


# A = np.random.randn(res_dim, res_dim)
A = jax.random.normal(key, shape=(res_dim, res_dim))
tol=1e-16
max_iters=100
seed=0
print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_jax(A, tol=tol, max_iters=max_iters, seed=seed)))
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi): 3.443780942023299
max eigenvalue (check): 3.443780942023293


In [5]:
import functools, jax, jax.numpy as jnp
from jax import lax

@functools.partial(jax.jit, static_argnames=["max_iters"])
def max_eig_arnoldi_lax1(A, tol=1e-12, max_iters=1000, seed=0):
    n   = A.shape[0]
    key = jax.random.PRNGKey(seed)
    v0  = jax.random.normal(key, (n,))
    v0  = v0 / jnp.linalg.norm(v0)

    V = jnp.zeros((n, max_iters + 1))
    H = jnp.zeros((max_iters + 1, max_iters))      # rectangular
    V = V.at[:, 0].set(v0)

    def body(carry):
        j, V, H, prev_lam, done = carry
        w = A @ V[:, j]

        # Gram–Schmidt
        def gs_step(i, val):
            w, H = val
            h = jnp.vdot(V[:, i], w)
            H = H.at[i, j].set(h)
            w = w - h * V[:, i]
            return (w, H)
        w, H = lax.fori_loop(0, j + 1, gs_step, (w, H))

        h_next = jnp.linalg.norm(w)
        H = H.at[j + 1, j].set(h_next)
        V = V.at[:, j + 1].set(jnp.where(h_next > 0, w / h_next, V[:, j + 1]))
        # ‹★› use the square upper block (shape max_iters × max_iters)
        eigvals = jnp.linalg.eigvals(H[:-1, :])     # now square & static
        lam_max = eigvals[jnp.argmax(jnp.abs(eigvals))]

        done = jnp.logical_or(done, jnp.abs(lam_max - prev_lam) < tol)
        return (j + 1, V, H, lam_max, done)

    def cond(carry):
        j, *_ = carry
        return j < max_iters

    init_state = (0, V, H, 0.0 + 0j, False)
    _, _, _, lambda_max, _ = lax.while_loop(cond, body, init_state)
    return lambda_max


A = jax.random.normal(key, shape=(res_dim, res_dim))
tol=1e-16
max_iters=100
seed=0
print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_lax1(A, tol=tol, max_iters=max_iters, seed=seed)))
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi): 3.443780942023302
max eigenvalue (check): 3.443780942023293


In [6]:
import jax
import jax.experimental
import jax.experimental.sparse
import jax.numpy as jnp
from jax import random, lax
from functools import partial


@functools.partial(jax.jit, static_argnames=('max_iters',))
def max_eig_arnoldi_lax(A, *, tol=1e-12, max_iters=300, seed=0):
    """
    Arnoldi method optimized for sparse matrices (JAX sparse.BCOO format)
    """
    n = A.shape[0]
    key = jax.random.PRNGKey(seed)
    v0 = jax.random.normal(key, (n,))
    v0 = v0 / jnp.linalg.norm(v0)
    
    # Pre-allocate arrays with correct types
    V = jnp.zeros((n, max_iters + 1), dtype=A.dtype)
    V = V.at[:, 0].set(v0)
    H = jnp.zeros((max_iters, max_iters), dtype=A.dtype)
    
    # Check if A is a sparse matrix
    is_sparse = True#isinstance(A, sparse.BCOO)
    
    def arnoldi_step(carry):
        j, V, H, prev_ritz, converged = carry
        
        # Use the appropriate matrix-vector product based on type
        if is_sparse:
        # w = A.coo_matvec(V[:, j])
            w = jax.experimental.sparse.bcoo_dot_general(A, V[:, j],  dimension_numbers=(((1,), (0,)), ((), ())))
        else:
            w = A @ V[:, j]

        # Modified Gram–Schmidt
        def gs_body(i, inner):
            w, H = inner
            h_ij = jnp.vdot(V[:, i], w)
            H = H.at[i, j].set(h_ij)
            w = w - h_ij * V[:, i]
            return (w, H)

        w, H = lax.fori_loop(0, j + 1, gs_body, (w, H))
        beta = jnp.linalg.norm(w)

        # Safely write the sub‑diagonal if we haven't filled the matrix
        H = lax.cond(
            (j + 1) < max_iters,
            lambda args: args[0].at[args[1] + 1, args[1]].set(args[2]),
            lambda args: args[0],
            (H, j, beta)
        )

        # Normalise and append new basis vector
        V = lax.cond(
            beta > 0,
            lambda args: args[0].at[:, args[1] + 1].set(args[2] / args[3]),
            lambda args: args[0],
            (V, j, w, beta)
        )

        # KEY CHANGE: Use custom eigenvalue computation with masking
        # Instead of H[:j+1, :j+1] which causes dynamic slicing error
        def get_max_eigenvalue(H, size):
            # Create a mask for the active submatrix
            mask = jnp.arange(max_iters)[:, None] < size
            mask = mask & (jnp.arange(max_iters)[None, :] < size)
            
            # Create a masked copy (zeros outside the j×j submatrix)
            H_masked = jnp.where(mask, H, 0.0)
            
            # Compute eigenvalues of the full matrix (most will be 0)
            all_eigvals = jnp.linalg.eigvals(H_masked)
            
            # Return the one with largest magnitude
            return all_eigvals[jnp.argmax(jnp.abs(all_eigvals))]
        
        lambda_max = get_max_eigenvalue(H, j+1)
        
        # Convergence tests
        converged = (beta < tol) | ((j > 0) & (jnp.abs(lambda_max - prev_ritz) < tol))
        
        return (j + 1, V, H, lambda_max, converged)

    def cond_fn(carry):
        j, _, _, _, conv = carry
        return (j < max_iters - 1) & (~conv)

    init = (0, V, H, 0.0, False)
    j_final, _, _, lambda_max, _ = lax.while_loop(cond_fn, arnoldi_step, init)
    
    return lambda_max#, j_final

In [7]:
res_dim = 200
A = jax.random.normal(key, (res_dim, res_dim))
tol=1e-10
max_iters=100
seed=0

out = np.abs(max_eig_arnoldi_lax(A, tol=tol, max_iters=max_iters, seed=seed))
print("Max eigenvalue (arnoldi lax):", out)
# print("Number of iterations:", out[1])
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi lax): 14.401025039437345
max eigenvalue (check): 14.401025083145015


In [8]:
def create_wr(res_dim, density):
    key = jax.random.PRNGKey(0)
    wrkey1, wrkey2 = jax.random.split(key)
    N_nonzero = int(res_dim**2 * density)
    wr_indices = jax.random.choice(
        wrkey1,
        res_dim**2,
        shape=(N_nonzero,),
    )
    wr_vals = jax.random.uniform(
        wrkey2, shape=N_nonzero, minval=-1, maxval=1
    )
    wr = jnp.zeros(res_dim * res_dim)
    wr = wr.at[wr_indices].set(wr_vals)
    wr = wr.reshape(res_dim, res_dim)
    return wr

In [9]:
res_dim = 2000
density = 0.05

wr_dense = create_wr(res_dim, density)
wr_sparse = sparse.BCOO.fromdense(wr_dense)

start_time = time.time()
eig = jnp.max(jnp.abs(jnp.linalg.eigvals(wr_dense))).block_until_ready()
end_time = time.time()
print("Time taken (dense eigvals):", end_time - start_time)
print("Max eigenvalue (dense eigvals):", eig)
print("=============================================\n")

start_time = time.time()
eig = max_eig_arnoldi_jax(wr_dense, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (dense naive):", end_time - start_time)
print("Max eigenvalue (dense naive):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_jax(wr_sparse, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (sparse naive):", end_time - start_time)
print("Max eigenvalue (sparse naive):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax(wr_dense, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (dense lax):", end_time - start_time)
print("Max eigenvalue (dense lax):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax(wr_sparse, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (sparse lax):", end_time - start_time)
print("Max eigenvalue (sparse lax):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax1(wr_dense, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (dense lax1):", end_time - start_time)
print("Max eigenvalue (dense lax1):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax1(wr_sparse, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (sparse lax1):", end_time - start_time)
print("Max eigenvalue (sparse lax1):", jnp.abs(eig), "\n")

Time taken (dense eigvals): 3.250822067260742
Max eigenvalue (dense eigvals): 5.766258653353951

Time taken (dense naive): 10.997761964797974
Max eigenvalue (dense naive): 5.7254152264007105 

Time taken (sparse naive): 4.024555683135986
Max eigenvalue (sparse naive): 5.725415226400742 

Time taken (dense lax): 0.44333767890930176
Max eigenvalue (dense lax): 5.758381695577947 

Time taken (sparse lax): 0.4044039249420166
Max eigenvalue (sparse lax): 5.7583816955779605 

Time taken (dense lax1): 0.48555850982666016
Max eigenvalue (dense lax1): 5.755736557235635 

Time taken (sparse lax1): 0.4642350673675537
Max eigenvalue (sparse lax1): 5.755736557235636 



In [10]:
# def max_eigval_power_iteration(A, seed=0, tol = 1e-12, max_iters=2000):
#     key = jax.random.PRNGKey(seed)
#     b_k = jax.random.normal(key, shape=(A.shape[0],)) 
#     b_k = b_k / jnp.linalg.norm(b_k)

#     def continue_iteration(params):
#         b_k, b_k_prev, iter_count = params
#         return jnp.logical_and(jnp.linalg.norm(b_k - b_k_prev) > tol, iter_count <= max_iters)
    
#     def power_iteration_step(params):
#         b_k, _, iter_count = params
#         b_k_next = A @ b_k
#         b_k_next = b_k_next / jnp.linalg.norm(b_k_next) 
#         iter_count += 1
#         return (b_k_next, b_k, iter_count)

#     out = jax.lax.while_loop(continue_iteration,power_iteration_step,(b_k, b_k+1, 0))
#     b_k_final, num_iters = out[0], out[2]
#     print("Number of iterations / total iters: ", num_iters-1, "/", max_iters)
#     max_eigval = jnp.dot(b_k_final.T, A @ b_k_final) / jnp.dot(b_k_final.T, b_k_final) # Rayleigh quotient
#     return max_eigval

In [11]:
# # test it out 
# wr_dense = create_wr(200, 0.02)
# wr_sparse = sparse.BCOO.fromdense(wr_dense)
# iters = 1000

# print("Max eigval from jnp.linalg.eigvals: ", jnp.max(jnp.linalg.eigvals(wr_dense)))
# print("Max eigval from power iteration: ", max_eigval_power_iteration(wr_dense,max_iters=10000, tol = 1e-10))
# print("Max eigval from sparse power iteration: ", max_eigval_power_iteration(wr_sparse,max_iters=10000, tol = 1e-10))

In [12]:
# def max_eigval_arnoldi(A, seed=0, tol = 1e-12, max_iters=2000):
#     key = jax.random.PRNGKey(seed)
#     b_k = jax.random.normal(key, shape=(A.shape[0],)) 
#     b_k = b_k / jnp.linalg.norm(b_k)

#     def continue_iteration(params):
#         b_k, b_k_prev, iter_count = params
#         return jnp.logical_and(jnp.linalg.norm(b_k - b_k_prev) > tol, iter_count <= max_iters)
    
#     def arnoldi_step(params):
#         b_k, _, iter_count = params
#         b_k_next = A @ b_k
#         b_k_next = b_k_next / jnp.linalg.norm(b_k_next) 
#         iter_count += 1
#         return (b_k_next, b_k, iter_count)

#     out = jax.lax.while_loop(continue_iteration,arnoldi_step,(b_k, b_k+1, 0))
#     b_k_final, num_iters = out[0], out[2]
#     print("Number of iterations / total iters: ", num_iters-1, "/", max_iters)
#     max_eigval = jnp.dot(b_k_final.T, A @ b_k_final) / jnp.dot(b_k_final.T, b_k_final) # Rayleigh quotient
#     return max_eigval

# # test it out
# wr_dense = create_wr(200, 0.02)
# wr_sparse = sparse.BCOO.fromdense(wr_dense)
# iters = 1000
# print("Max eigval from jnp.linalg.eigvals: ", jnp.max(jnp.linalg.eigvals(wr_dense)))
# print("Max eigval from arnoldi: ", max_eigval_arnoldi(wr_dense,max_iters=10000, tol = 1e-10))
# print("Max eigval from sparse arnoldi: ", max_eigval_arnoldi(wr_sparse,max_iters=10000, tol = 1e-10))

# Timing tests

In [13]:
# res_dims = jnp.arange(100, 1001, 100)
# densities = jnp.arange(0.01, 0.21, 0.01)
res_dims = [100, 200, 300]
densities = [0.01, 0.05]
num_trials = 10
results = jnp.zeros((len(res_dims), len(densities), num_trials))

In [14]:
### Dense arnoldi
for i, res_dim in enumerate(res_dims):
    for j, density in enumerate(densities):
        for k in range(num_trials):
            wr = create_wr(res_dim, density)
            start_t = time.time()
            _ = max_eigval_power_iteration(wr).block_until_ready()
            end_t = time.time()
            results = results.at[i, j, k].set(end_t - start_t)

# average out the trials
avg_results_dense_pi = jnp.mean(results, axis=2)
std_results_dense_pi = jnp.std(results, axis=2)

NameError: name 'max_eigval_power_iteration' is not defined

In [None]:
### Sparse power iteration
for i, res_dim in enumerate(res_dims):
    for j, density in enumerate(densities):
        for k in range(num_trials):
            wr = create_wr(res_dim, density)
            wr_sparse = sparse.BCOO.fromdense(wr)
            start_t = time.time()
            _ = max_eigval_power_iteration(wr_sparse).block_until_ready()
            end_t = time.time()
            results = results.at[i, j, k].set(end_t - start_t)

# average out the trials
avg_results_sparse_pi = jnp.mean(results, axis=2)
std_results_sparse_pi = jnp.std(results, axis=2)

In [None]:
### Dense jnp.linalg.eigvals
for i, res_dim in enumerate(res_dims):
    for j, density in enumerate(densities):
        for k in range(num_trials):
            wr = create_wr(res_dim, density)
            start_t = time.time()
            _ = jnp.max(jnp.linalg.eigvals(wr)).block_until_ready()
            end_t = time.time()
            results = results.at[i, j, k].set(end_t - start_t)

# average out the trials
avg_results_dense_eig = jnp.mean(results, axis=2)
std_results_dense_eig = jnp.std(results, axis=2)