In [15]:
import os 
import time
os.environ["JAX_ENABLE_X64"] = "True"
# os.environ["JAX_DISABLE_JIT"] = "True"
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
import jax.numpy as jnp
from jax.experimental import sparse

import matplotlib.pyplot as plt

In [16]:
res_dim = 10
key = jax.random.PRNGKey(0)
density = 0.5
wrkey1, wrkey2 = jax.random.split(key)
spec_rad = 0.8

N_nonzero = int(res_dim**2 * density)
wr_indices = jax.random.choice(
    wrkey1,
    res_dim**2,
    shape=(N_nonzero,),
)
wr_vals = jax.random.uniform(
    wrkey2, shape=N_nonzero, minval=-1, maxval=1
)
wr = jnp.zeros(res_dim * res_dim)
wr = wr.at[wr_indices].set(wr_vals)
wr = wr.reshape(res_dim, res_dim)
wr_dense = wr * (spec_rad / jnp.max(jnp.abs(jnp.linalg.eigvals(wr))))
wr_sparse = sparse.BCOO.fromdense(wr_dense)

In [17]:
def create_wr(res_dim, density):
    key = jax.random.PRNGKey(0)
    wrkey1, wrkey2 = jax.random.split(key)
    N_nonzero = int(res_dim**2 * density)
    wr_indices = jax.random.choice(
        wrkey1,
        res_dim**2,
        shape=(N_nonzero,),
    )
    wr_vals = jax.random.uniform(
        wrkey2, shape=N_nonzero, minval=-1, maxval=1
    )
    wr = jnp.zeros(res_dim * res_dim)
    wr = wr.at[wr_indices].set(wr_vals)
    wr = wr.reshape(res_dim, res_dim)
    return wr

In [18]:
jnp.max(jnp.linalg.eigvals(wr))

Array(1.22398162+0.j, dtype=complex128)

In [49]:
def max_eigval_power_iteration(A, seed=0, tol = 1e-12, max_iters=2000):
    key = jax.random.PRNGKey(seed)
    b_k = jax.random.normal(key, shape=(A.shape[0],)) 
    b_k = b_k / jnp.linalg.norm(b_k)

    def continue_iteration(params):
        b_k, b_k_prev, iter_count = params
        return jnp.logical_and(jnp.linalg.norm(b_k - b_k_prev) > tol, iter_count <= max_iters)
    
    def power_iteration_step(params):
        b_k, _, iter_count = params
        b_k_next = A @ b_k
        b_k_next = b_k_next / jnp.linalg.norm(b_k_next) 
        iter_count += 1
        return (b_k_next, b_k, iter_count)

    out = jax.lax.while_loop(continue_iteration,power_iteration_step,(b_k, b_k+1, 0))
    b_k_final, num_iters = out[0], out[2]
    print("Number of iterations / total iters: ", num_iters-1, "/", max_iters)
    max_eigval = jnp.dot(b_k_final.T, A @ b_k_final) / jnp.dot(b_k_final.T, b_k_final) # Rayleigh quotient
    return max_eigval

In [50]:
# test it out 
wr_dense = create_wr(200, 0.02)
wr_sparse = sparse.BCOO.fromdense(wr_dense)
iters = 1000

print("Max eigval from jnp.linalg.eigvals: ", jnp.max(jnp.linalg.eigvals(wr_dense)))
print("Max eigval from power iteration: ", max_eigval_power_iteration(wr_dense,max_iters=10000, tol = 1e-10))
print("Max eigval from sparse power iteration: ", max_eigval_power_iteration(wr_sparse,max_iters=10000, tol = 1e-10))

Max eigval from jnp.linalg.eigvals:  (1.059319272073892+0j)
Number of iterations / total iters:  10000 / 10000
Max eigval from power iteration:  0.9931483664559289
Number of iterations / total iters:  10000 / 10000
Max eigval from sparse power iteration:  0.9931483664559291


In [51]:
def max_eigval_arnoldi(A, seed=0, tol = 1e-12, max_iters=2000):
    key = jax.random.PRNGKey(seed)
    b_k = jax.random.normal(key, shape=(A.shape[0],)) 
    b_k = b_k / jnp.linalg.norm(b_k)

    def continue_iteration(params):
        b_k, b_k_prev, iter_count = params
        return jnp.logical_and(jnp.linalg.norm(b_k - b_k_prev) > tol, iter_count <= max_iters)
    
    def arnoldi_step(params):
        b_k, _, iter_count = params
        b_k_next = A @ b_k
        b_k_next = b_k_next / jnp.linalg.norm(b_k_next) 
        iter_count += 1
        return (b_k_next, b_k, iter_count)

    out = jax.lax.while_loop(continue_iteration,arnoldi_step,(b_k, b_k+1, 0))
    b_k_final, num_iters = out[0], out[2]
    print("Number of iterations / total iters: ", num_iters-1, "/", max_iters)
    max_eigval = jnp.dot(b_k_final.T, A @ b_k_final) / jnp.dot(b_k_final.T, b_k_final) # Rayleigh quotient
    return max_eigval

# test it out
wr_dense = create_wr(200, 0.02)
wr_sparse = sparse.BCOO.fromdense(wr_dense)
iters = 1000
print("Max eigval from jnp.linalg.eigvals: ", jnp.max(jnp.linalg.eigvals(wr_dense)))
print("Max eigval from arnoldi: ", max_eigval_arnoldi(wr_dense,max_iters=10000, tol = 1e-10))
print("Max eigval from sparse arnoldi: ", max_eigval_arnoldi(wr_sparse,max_iters=10000, tol = 1e-10))

Max eigval from jnp.linalg.eigvals:  (1.059319272073892+0j)
Number of iterations / total iters:  10000 / 10000
Max eigval from arnoldi:  0.9931483664559289
Number of iterations / total iters:  10000 / 10000
Max eigval from sparse arnoldi:  0.9931483664559291


# Timing tests

In [7]:
# res_dims = jnp.arange(100, 1001, 100)
# densities = jnp.arange(0.01, 0.21, 0.01)
res_dims = [100, 200, 300]
densities = [0.01, 0.05]
num_trials = 10
results = jnp.zeros((len(res_dims), len(densities), num_trials))

In [None]:
### Dense power iteration
for i, res_dim in enumerate(res_dims):
    for j, density in enumerate(densities):
        for k in range(num_trials):
            wr = create_wr(res_dim, density)
            start_t = time.time()
            _ = max_eigval_power_iteration(wr).block_until_ready()
            end_t = time.time()
            results = results.at[i, j, k].set(end_t - start_t)

# average out the trials
avg_results_dense_pi = jnp.mean(results, axis=2)
std_results_dense_pi = jnp.std(results, axis=2)

2001
2001
2001


In [None]:
### Sparse power iteration
for i, res_dim in enumerate(res_dims):
    for j, density in enumerate(densities):
        for k in range(num_trials):
            wr = create_wr(res_dim, density)
            wr_sparse = sparse.BCOO.fromdense(wr)
            start_t = time.time()
            _ = max_eigval_power_iteration(wr_sparse).block_until_ready()
            end_t = time.time()
            results = results.at[i, j, k].set(end_t - start_t)

# average out the trials
avg_results_sparse_pi = jnp.mean(results, axis=2)
std_results_sparse_pi = jnp.std(results, axis=2)

In [None]:
### Dense jnp.linalg.eigvals
for i, res_dim in enumerate(res_dims):
    for j, density in enumerate(densities):
        for k in range(num_trials):
            wr = create_wr(res_dim, density)
            start_t = time.time()
            _ = jnp.max(jnp.linalg.eigvals(wr)).block_until_ready()
            end_t = time.time()
            results = results.at[i, j, k].set(end_t - start_t)

# average out the trials
avg_results_dense_eig = jnp.mean(results, axis=2)
std_results_dense_eig = jnp.std(results, axis=2)