In [18]:
import os 
import time
os.environ["JAX_ENABLE_X64"] = "True"
# os.environ["JAX_DISABLE_JIT"] = "True"
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
from jax.experimental import sparse
import functools

import matplotlib.pyplot as plt

import numpy as np

In [19]:
res_dim = 10
key = jax.random.PRNGKey(0)
density = 0.5
wrkey1, wrkey2 = jax.random.split(key)
spec_rad = 0.8

N_nonzero = int(res_dim**2 * density)
wr_indices = jax.random.choice(
    wrkey1,
    res_dim**2,
    shape=(N_nonzero,),
)
wr_vals = jax.random.uniform(
    wrkey2, shape=N_nonzero, minval=-1, maxval=1
)
wr = jnp.zeros(res_dim * res_dim)
wr = wr.at[wr_indices].set(wr_vals)
wr = wr.reshape(res_dim, res_dim)
wr_dense = wr * (spec_rad / jnp.max(jnp.abs(jnp.linalg.eigvals(wr))))
wr_sparse = sparse.BCOO.fromdense(wr_dense)

In [15]:
# def max_eig_arnoldi_np(A, tol=1e-12, max_iters=1000, seed=0):
#     n = A.shape[0]
#     np.random.seed(seed)
#     b = np.random.rand(n)
#     b = b / np.linalg.norm(b)

#     V = np.zeros((n, max_iters + 1))
#     H = np.zeros((max_iters + 1, max_iters))
#     V[:, 0] = b

#     for j in range(max_iters):
#         w = A @ V[:, j]
        
#         for i in range(j + 1):
#             # H[i, j] = np.dot(V[:, i], w)
#             H[i, j] = V[:, i].conj().T @ w
#             w = w - H[i, j] * V[:, i]
        
#         H[j + 1, j] = np.linalg.norm(w)

        
#         V[:, j + 1] = w / H[j + 1, j]
        
#         # Compute eigenvalues of the current H (top-left (j+1)x(j+1) block)
#         eigvals = np.linalg.eigvals(H[:j+1, :j+1])
#         lambda_max = eigvals[np.argmax(np.abs(eigvals))]

#         # Optional early convergence check (change in largest Ritz value)
#         if j > 0:
#             delta = abs(lambda_max - prev_lambda)
#             if delta < tol:
#                 break
#         prev_lambda = lambda_max
    
#     print("Number of iterations / max_iters:", j + 1 , "/", max_iters)

#     return lambda_max


# A = np.random.randn(res_dim, res_dim)
# tol=1e-12
# max_iters=100
# seed=0
# print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_np(A, tol=tol, max_iters=max_iters, seed=seed)))
# print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

In [16]:
### naive jax version without lax primitives
def max_eig_arnoldi_jax(A, tol=1e-12, max_iters=1000, seed=0):
    n = A.shape[0]
    key = jax.random.PRNGKey(seed)
    b = jax.random.uniform(key, shape=(n,))
    b = b / jnp.linalg.norm(b)

    V = jnp.zeros((n, max_iters + 1))
    H = jnp.zeros((max_iters + 1, max_iters))
    V = V.at[:, 0].set(b)

    lambda_max = 3

    for j in range(max_iters):
        w = A @ V[:, j]
        
        for i in range(j + 1):
            H = H.at[i, j].set(V[:, i].conj().T @ w)
            w = w - H[i, j] * V[:, i]
        
        H = H.at[j + 1, j].set(jnp.linalg.norm(w))
        if H[j + 1, j] < tol:
            break

        V = V.at[:, j + 1].set(w / H[j + 1, j])
        
        # Compute eigenvalues of the current H (top-left (j+1)x(j+1) block)
        eigvals = jnp.linalg.eigvals(H[:j+1, :j+1])
        lambda_max = eigvals[jnp.argmax(jnp.abs(eigvals))]

        # Optional early convergence check (change in largest Ritz value)
        if j > 0:
            delta = abs(lambda_max - prev_lambda)
            if delta < tol:
                break
        prev_lambda = lambda_max

    return lambda_max


# A = np.random.randn(res_dim, res_dim)
A = jax.random.normal(key, shape=(res_dim, res_dim))
tol=1e-16
max_iters=100
seed=0
print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_jax(A, tol=tol, max_iters=max_iters, seed=seed)))
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi): 3.443780942023299
max eigenvalue (check): 3.443780942023293


In [17]:
@functools.partial(jax.jit, static_argnames=["max_iters"])
def max_eig_arnoldi_lax(A, tol=1e-12, max_iters=100, seed=0):
    '''
    Perform Arnoldi iteration to find the largest eigenvalue of a matrix A. 

    Args:
        A: The input matrix (n x n).
        tol: Tolerance for convergence.
        max_iters: Maximum number of iterations.
        seed: Random seed for initialization.
    Returns:
        The largest eigenvalue of A.
    '''
    n = A.shape[0]
    key = jax.random.PRNGKey(seed)
    v0 = jax.random.normal(key, (n,))
    v0 = v0 / jnp.linalg.norm(v0) # initial guess

    V = jnp.zeros((n, max_iters + 1)) # orthogonal krylov basis
    H = jnp.zeros((max_iters + 1, max_iters)) # hessnian matrix
    V = V.at[:, 0].set(v0)

    def body(carry):
        j, V, H, prev_lam, done = carry
        w = A @ V[:, j]

        # Gram–Schmidt
        def gs_step(i, val):
            w, H = val
            h = jnp.vdot(V[:, i], w)
            H = H.at[i, j].set(h)
            w = w - h * V[:, i]
            return (w, H)
        w, H = jax.lax.fori_loop(0, j + 1, gs_step, (w, H))

        h_next = jnp.linalg.norm(w)
        H = H.at[j + 1, j].set(h_next)
        V = V.at[:, j + 1].set(jnp.where(h_next > 0, w / h_next, V[:, j + 1]))
        eigvals = jnp.linalg.eigvals(H[:-1, :])     # now square & static
        lam_max = eigvals[jnp.argmax(jnp.abs(eigvals))]

        done = jnp.logical_or(done, jnp.abs(lam_max - prev_lam) < tol) # not actually used
        return (j + 1, V, H, lam_max, done)

    # check against max_iters
    def cond(carry):
        j, _, _, _, _ = carry
        return j < max_iters

    init_state = (0, V, H, 0.0 + 0j, False)
    _, _, _, lambda_max, _ = jax.lax.while_loop(cond, body, init_state)
    return lambda_max


A = jax.random.normal(key, shape=(res_dim, res_dim))
tol=1e-16
max_iters=100
seed=0
print("Max eigenvalue (arnoldi):", np.abs(max_eig_arnoldi_lax(A, tol=tol, max_iters=max_iters, seed=seed)))
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi): 3.443780942023302
max eigenvalue (check): 3.443780942023293


In [None]:
@functools.partial(jax.jit, static_argnames=('max_iters',))
def max_eig_arnoldi_lax_sparse(A, tol=1e-12, max_iters=300, seed=0):
    """
    Arnoldi method optimized for sparse matrices
    """
    n = A.shape[0]
    key = jax.random.PRNGKey(seed) 
    v0 = jax.random.normal(key, (n,))
    v0 = v0 / jnp.linalg.norm(v0) # TODO with parallel / ensemble initialize with prev eigenvector
    
    # Pre-allocate arrays
    V = jnp.zeros((n, max_iters + 1), dtype=A.dtype)
    V = V.at[:, 0].set(v0)
    H = jnp.zeros((max_iters, max_iters), dtype=A.dtype)
    
    
    def arnoldi_step(carry):
        j, V, H, prev_ritz, converged = carry
        
        # w = Av
        w = jax.experimental.sparse.bcoo_dot_general(A, V[:, j],  dimension_numbers=(((1,), (0,)), ((), ())))

        # Modified Gram–Schmidt
        def gs_body(i, inner):
            w, H = inner
            h_ij = jnp.vdot(V[:, i], w)
            H = H.at[i, j].set(h_ij)
            w = w - h_ij * V[:, i]
            return (w, H)

        w, H = jax.lax.fori_loop(0, j + 1, gs_body, (w, H))
        beta = jnp.linalg.norm(w)

        # Write subdiag of hessenberg if matrix not full 
        H = jax.lax.cond(
            (j + 1) < max_iters,
            lambda args: args[0].at[args[1] + 1, args[1]].set(args[2]),
            lambda args: args[0],
            (H, j, beta)
        )

        # Normalise and append new basis vector
        V = jax.lax.cond(
            beta > 0,
            lambda args: args[0].at[:, args[1] + 1].set(args[2] / args[3]),
            lambda args: args[0],
            (V, j, w, beta)
        )

        def get_max_eigenvalue(H, size):
            
            mask = jnp.arange(max_iters)[:, None] < size
            mask = mask & (jnp.arange(max_iters)[None, :] < size)
            
            H_masked = jnp.where(mask, H, 0.0)
            
            all_eigvals = jnp.linalg.eigvals(H_masked)
            
            return all_eigvals[jnp.argmax(jnp.abs(all_eigvals))]
        
        lambda_max = get_max_eigenvalue(H, j+1)
        
        converged = (beta < tol) | ((j > 0) & (jnp.abs(lambda_max - prev_ritz) < tol))
        
        return (j + 1, V, H, lambda_max, converged)

    def cond_fn(carry):
        j, _, _, _, conv = carry
        return (j < max_iters - 1) & (~conv)

    init = (0, V, H, 0.0, False)
    j_final, _, _, lambda_max, _ = jax.lax.while_loop(cond_fn, arnoldi_step, init)
    
    return lambda_max#, j_final

In [9]:
res_dim = 200
A = jax.random.normal(key, (res_dim, res_dim))
tol=1e-10
max_iters=100
seed=0

out = np.abs(max_eig_arnoldi_lax(A, tol=tol, max_iters=max_iters, seed=seed))
print("Max eigenvalue (arnoldi lax):", out)
# print("Number of iterations:", out[1])
print("max eigenvalue (check):", np.max(np.abs(np.linalg.eigvals(A))))

Max eigenvalue (arnoldi lax): 14.401025039437345
max eigenvalue (check): 14.401025083145015


In [10]:
def create_wr(res_dim, density):
    key = jax.random.PRNGKey(0)
    wrkey1, wrkey2 = jax.random.split(key)
    N_nonzero = int(res_dim**2 * density)
    wr_indices = jax.random.choice(
        wrkey1,
        res_dim**2,
        shape=(N_nonzero,),
    )
    wr_vals = jax.random.uniform(
        wrkey2, shape=N_nonzero, minval=-1, maxval=1
    )
    wr = jnp.zeros(res_dim * res_dim)
    wr = wr.at[wr_indices].set(wr_vals)
    wr = wr.reshape(res_dim, res_dim)
    return wr

In [None]:
res_dim = 1000
density = 0.01
num_trials = 10  
max_iters = 200 
tol = 1e-16     
seed = 0        

wr_dense = create_wr(res_dim, density)
wr_sparse = sparse.BCOO.fromdense(wr_dense)

start_time = time.time()
eig = jnp.max(jnp.abs(jnp.linalg.eigvals(wr_dense))).block_until_ready()
end_time = time.time()
print("Time taken (dense eigvals):", end_time - start_time)
print("Max eigenvalue (dense eigvals):", eig)
print("=============================================\n")

start_time = time.time()
eig = max_eig_arnoldi_jax(wr_dense, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (dense naive):", end_time - start_time)
print("Max eigenvalue (dense naive):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_jax(wr_sparse, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (sparse naive):", end_time - start_time)
print("Max eigenvalue (sparse naive):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax(wr_dense, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (dense lax):", end_time - start_time)
print("Max eigenvalue (dense lax):", jnp.abs(eig), "\n")

start_time = time.time()
eig = max_eig_arnoldi_lax(wr_sparse, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
end_time = time.time()
print("Time taken (sparse lax):", end_time - start_time)
print("Max eigenvalue (sparse lax):", jnp.abs(eig), "\n")

Time taken (dense eigvals): 0.8503098487854004
Max eigenvalue (dense eigvals): 1.8291688631675567

Time taken (dense naive): 32.61605787277222
Max eigenvalue (dense naive): 1.829168894290524 

Time taken (sparse naive): 17.40448546409607
Max eigenvalue (sparse naive): 1.8291688942905129 

Time taken (dense lax): 1.6839098930358887
Max eigenvalue (dense lax): 1.8291688617096102 

Time taken (sparse lax): 1.658437728881836
Max eigenvalue (sparse lax): 1.829168861709604 

Time taken (dense lax1): 3.0642952919006348
Max eigenvalue (dense lax1): 1.829168834769205 

Time taken (sparse lax1): 5.257415294647217
Max eigenvalue (sparse lax1): 1.829168834769199 



In [12]:
# --- Timing jnp.linalg.eigvals (Dense) ---
times_dense_eigvals = []
for _ in range(num_trials):
    start_time = time.time()
    _ = jnp.max(jnp.abs(jnp.linalg.eigvals(wr_dense))).block_until_ready()
    end_time = time.time()
    times_dense_eigvals.append(end_time - start_time)
avg_time_dense_eigvals = np.mean(times_dense_eigvals)
print(f"Average time taken (dense eigvals over {num_trials} trials): {avg_time_dense_eigvals:.6f} seconds")
print("=============================================\n")

# --- Timing max_eig_arnoldi_jax (Dense) ---
times_dense_naive = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_jax(wr_dense, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_dense_naive.append(end_time - start_time)
avg_time_dense_naive = np.mean(times_dense_naive)
print(f"Average time taken (dense naive over {num_trials} trials): {avg_time_dense_naive:.6f} seconds\n")

# --- Timing max_eig_arnoldi_jax (Sparse) ---
times_sparse_naive = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_jax(wr_sparse, tol=1e-16, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_sparse_naive.append(end_time - start_time)
avg_time_sparse_naive = np.mean(times_sparse_naive)
print(f"Average time taken (sparse naive over {num_trials} trials): {avg_time_sparse_naive:.6f} seconds\n")

# --- Timing max_eig_arnoldi_lax (Dense) ---
times_dense_lax = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_lax(wr_dense, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_dense_lax.append(end_time - start_time)
avg_time_dense_lax = np.mean(times_dense_lax)
print(f"Average time taken (dense lax over {num_trials} trials): {avg_time_dense_lax:.6f} seconds\n")

# --- Timing max_eig_arnoldi_lax (Sparse) ---
times_sparse_lax = []
for _ in range(num_trials):
    start_time = time.time()
    _ = max_eig_arnoldi_lax(wr_sparse, tol=tol, max_iters=max_iters, seed=seed).block_until_ready()
    end_time = time.time()
    times_sparse_lax.append(end_time - start_time)
avg_time_sparse_lax = np.mean(times_sparse_lax)
print(f"Average time taken (sparse lax over {num_trials} trials): {avg_time_sparse_lax:.6f} seconds\n")

NameError: name 'num_trials' is not defined