# In [None]:
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

# Parameters for the initialization
N = 300  # Network size (W1 and W2 are N x N)
g1 = 1.5  # Gain of the first matrix: entries have std g1 / sqrt(N)
g2 = 0.5  # Gain of the second matrix: entries have std g2 / sqrt(N)

# Random key for JAX, split so W1 and W2 are drawn from independent streams
key = random.PRNGKey(0)
subkey1, subkey2 = random.split(key)

# Initialize the first weight matrix W1 with Gaussian entries ~ N(0, g1^2 / N)
W1 = random.normal(subkey1, (N, N)) * (g1 / jnp.sqrt(N))

# Initialize the second weight matrix W2 with Gaussian entries ~ N(0, g2^2 / N)
W2 = random.normal(subkey2, (N, N)) * (g2 / jnp.sqrt(N))


def effective_rank(matrix, normalize=True):
    """Return the effective rank ``sum(s) / max(s)`` of ``matrix``.

    ``s`` are the singular values of ``matrix``, i.e. the eigenvalues of the
    positive semi-definite matrix ``sqrt(matrix^T matrix)``, so this is the
    ``trace / leading eigenvalue`` ratio applied to a PSD operator. Eigenvalue
    moduli of ``matrix`` itself are deliberately not used: for a non-symmetric
    matrix they do not measure rank (a nilpotent matrix has all-zero
    eigenvalues yet non-zero rank), and the sum of their moduli is not the
    trace.

    Args:
        matrix: 2-D array of shape (m, n).
        normalize: if True (default), divide by ``min(m, n)`` so the result
            lies in (0, 1] for a non-zero matrix; if False, the result lies in
            [1, min(m, n)].

    Returns:
        A 0-d JAX array. The all-zero matrix has effective rank 0.
    """
    singular_values = jnp.linalg.svd(matrix, compute_uv=False)
    largest = jnp.max(singular_values)
    # Guard the all-zero matrix: its singular values sum to 0, so dividing by
    # 1 instead of 0 yields 0 rather than NaN.
    ratio = jnp.sum(singular_values) / jnp.where(largest > 0, largest, 1.0)
    if normalize:
        # svd(compute_uv=False) returns min(m, n) values; for N x N this is N.
        ratio = ratio / singular_values.shape[0]
    return ratio


# Step 1: Compute the eigenvalues of both matrices (complex in general, since
# W1 and W2 are not symmetric); they are kept for the spectrum plots.
eigenvalues_W1 = jnp.linalg.eigvals(W1)
eigenvalues_W2 = jnp.linalg.eigvals(W2)

# Step 2: Leading eigenvalue modulus (spectral radius) and trace of W1.
# The trace is the plain diagonal sum (= sum of eigenvalues, not of their moduli).
leading_eigenvalue_W1 = jnp.max(jnp.abs(eigenvalues_W1))
trace_W1 = jnp.trace(W1)

# Step 3: Same quantities for W2.
leading_eigenvalue_W2 = jnp.max(jnp.abs(eigenvalues_W2))
trace_W2 = jnp.trace(W2)

# Step 4: Effective rank of both matrices from their singular values,
# normalized by N (reported as a fraction of full rank).
effective_rank_W1 = effective_rank(W1)
effective_rank_W2 = effective_rank(W2)

# Output the effective ranks (normalized by N)
print(f'Effective Rank of W1: {effective_rank_W1}')
print(f'Effective Rank of W2: {effective_rank_W2}')

# Plotting the eigenvalues and effective ranks for comparison: one figure with
# three panels (W1 spectrum, W2 spectrum, effective-rank bar chart).
fig, (ax_w1, ax_w2, ax_rank) = plt.subplots(1, 3, figsize=(14, 6))

# Spectrum panels: eigenvalue moduli sorted from largest to smallest.
# Each line carries a label, but no legend is drawn.
spectrum_panels = (
    (ax_w1, eigenvalues_W1, 'Eigenvalues W1', 'Eigenvalues of W1'),
    (ax_w2, eigenvalues_W2, 'Eigenvalues W2', 'Eigenvalues of W2'),
)
for ax, eigs, line_label, panel_title in spectrum_panels:
    moduli_desc = jnp.sort(jnp.abs(eigs))[::-1]
    ax.plot(moduli_desc, 'o-', label=line_label)
    ax.set_title(panel_title)
    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue Magnitude')
    ax.grid(True)

# Bar panel: the two effective-rank values side by side.
ax_rank.bar(['W1', 'W2'], [effective_rank_W1, effective_rank_W2], color=['blue', 'orange'])
ax_rank.set_title('Effective Rank Comparison')
ax_rank.set_ylabel('Effective Rank')
ax_rank.grid(True)

fig.tight_layout()
plt.show()
