In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax.random as random

def _normalized_effective_rank(W):
    """
    Compute eigenvalue magnitudes and the normalized effective rank of a square matrix.

    The effective rank used here is  sum_i |lambda_i| / (max_i |lambda_i| * n),
    where lambda_i are the (possibly complex) eigenvalues of the n x n matrix W.
    Whenever at least one eigenvalue is non-zero the value lies in [1/n, 1], and
    it is invariant to rescaling W by a non-zero constant. If every eigenvalue is
    zero (e.g. the zero matrix or a nilpotent matrix) the ratio would be 0/0; that
    case is reported as 0 instead of NaN.

    Args:
        W (jax.numpy.ndarray): A non-empty square (n, n) matrix.

    Returns:
        tuple: (eigenvalue_magnitudes, effective_rank), where eigenvalue_magnitudes
        is the 1-D array of |lambda_i| and effective_rank is a scalar array.

    Raises:
        ValueError: If W is not a non-empty square 2-D matrix.
    """
    W = jnp.asarray(W)
    if W.ndim != 2 or W.shape[0] != W.shape[1] or W.shape[0] == 0:
        raise ValueError(f'Expected a non-empty square 2-D matrix, got shape {W.shape}')
    n = W.shape[0]

    # eigvals handles non-symmetric matrices, so eigenvalues may be complex;
    # only their magnitudes are used below.
    magnitudes = jnp.abs(jnp.linalg.eigvals(W))
    leading = jnp.max(magnitudes)
    # Sum of |eigenvalues|. Note this is NOT the trace (the trace is the signed/complex sum).
    total = jnp.sum(magnitudes)

    # Guard the 0/0 case: use a safe denominator and report a rank of 0.
    has_signal = leading > 0
    safe_leading = jnp.where(has_signal, leading, 1.0)
    rank = jnp.where(has_signal, total / (safe_leading * n), 0.0)
    return magnitudes, rank


def compare_effective_ranks(W1, W2):
    """
    Compare the normalized effective ranks of two square weight matrices and plot them.

    For each matrix the effective rank is sum(|eigenvalues|) / (max|eigenvalue| * N),
    i.e. a scale-invariant fraction in [1/N, 1] (0 if all eigenvalues are zero).
    Both values are printed, and a three-panel figure is shown: the sorted
    eigenvalue magnitudes of W1, the same for W2, and a bar chart of the two ranks.

    Args:
        W1 (jax.numpy.ndarray): The first square weight matrix.
        W2 (jax.numpy.ndarray): The second square weight matrix.

    Raises:
        ValueError: If either matrix is not a non-empty square 2-D matrix.
    """
    eig_mag_W1, effective_rank_W1 = _normalized_effective_rank(W1)
    eig_mag_W2, effective_rank_W2 = _normalized_effective_rank(W2)

    # Output the effective ranks
    print(f'Effective Rank of W1: {effective_rank_W1}')
    print(f'Effective Rank of W2: {effective_rank_W2}')

    fig, axes = plt.subplots(1, 3, figsize=(14, 6))

    # Panels 1-2: eigenvalue magnitudes sorted in descending order.
    for ax, magnitudes, name in ((axes[0], eig_mag_W1, 'W1'), (axes[1], eig_mag_W2, 'W2')):
        ax.plot(jnp.sort(magnitudes)[::-1], 'o-')
        ax.set_title(f'Eigenvalues of {name}')
        ax.set_xlabel('Index')
        ax.set_ylabel('Eigenvalue Magnitude')
        ax.grid(True)

    # Panel 3: side-by-side comparison of the two effective ranks.
    axes[2].bar(['W1', 'W2'], [effective_rank_W1, effective_rank_W2], color=['blue', 'orange'])
    axes[2].set_title('Effective Rank Comparison')
    axes[2].set_ylabel('Effective Rank')
    axes[2].grid(True)

    fig.tight_layout()
    plt.show()

# Example usage: compare two random Gaussian weight matrices with different gains.
key = random.PRNGKey(0)
N = 300
g1 = 1.5
g2 = 0.5

# Generate two *independent* random weight matrices. Reusing one key for both
# draws would make W2 an exact rescaling of W1 (W2 = (g2 / g1) * W1); since the
# normalized effective rank (a ratio of eigenvalue magnitudes) is scale-invariant,
# the comparison would then trivially report identical ranks.
key_w1, key_w2 = random.split(key)
W1 = random.normal(key_w1, (N, N)) * (g1 / jnp.sqrt(N))
W2 = random.normal(key_w2, (N, N)) * (g2 / jnp.sqrt(N))

# Compare their effective ranks
compare_effective_ranks(W1, W2)


In [2]:
import extra_initializers

# Network / grid configuration.
n_rec = 100             # size of the recurrent population (weights below are n_rec x n_rec)
grid_shape = (10, 10)   # passed as `gridshape` to the extra_initializers helpers
sigma = 0.012           # passed to initialize_connectivity_mask; TODO document meaning/units

# One seed for this section, split into separate keys for positions, mask and weights.
key = random.PRNGKey(2150)
subkey1, subkey2, subkey3 = random.split(key, 3)

In [None]:
# Build the neuron-position initializer and invoke the returned callable right away.
# Presumably yields one grid position per recurrent neuron — confirm against extra_initializers.
cells_loc = extra_initializers.initialize_neurons_position(
    n_rec=n_rec,
    key=subkey1,
    gridshape=grid_shape,
)()
cells_loc.shape

In [4]:
# Connectivity mask for the recurrent weights, built from the neuron positions above.
# NOTE(review): with local_connectivity=True, `sigma` presumably controls how quickly
# connections fall off with grid distance — confirm in extra_initializers.
M = extra_initializers.initialize_connectivity_mask(
    n_rec=n_rec,
    key=subkey2,
    gridshape=grid_shape,
    neuron_indices=cells_loc,
    local_connectivity=True,
    sigma=sigma,
)()

In [5]:
from flax.linen import initializers

# Kaiming-normal base initializer shared by both recurrent matrices.
init = initializers.kaiming_normal()

# Recurrent weights without a connectivity mask.
W1 = extra_initializers.generalized_initializer(
    init_fn=init,
    gain=1.0,
    avoid_self_recurrence=True,
    mask_connectivity=None,
)(key=subkey3, shape=(n_rec, n_rec))

# Same settings and the same key (subkey3), but restricted by the mask M.
# NOTE(review): reusing subkey3 presumably makes W2 the masked version of the
# same random draw as W1 — confirm this controlled comparison is intended.
W2 = extra_initializers.generalized_initializer(
    init_fn=init,
    gain=1.0,
    avoid_self_recurrence=True,
    mask_connectivity=M,
)(key=subkey3, shape=(n_rec, n_rec))

In [None]:
# Unmasked (W1) vs. connectivity-masked (W2) recurrent weights.
compare_effective_ranks(W1=W1, W2=W2)

In [None]:
# Inspect the masked recurrent weight matrix (built with mask_connectivity=M).
W2

In [None]:
# Inspect the unmasked recurrent weight matrix (built with mask_connectivity=None).
W1