In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy
import sklearn as sk 
import jax 
import jax.numpy as jnp
import optax
from jax import vmap


In [None]:
def l1_sparsity(params):
    """Elementwise L1 norm, sum(|params|), of an array of any shape.

    An explicit sum of absolute values is used instead of
    ``jnp.linalg.norm(params, 1)``. For a 2-D array that call returns the
    induced matrix 1-norm (maximum absolute column sum), which is not a
    sparsity measure. It also raises for arrays with more than two
    dimensions. For 1-D input the result is unchanged. This matches
    ``l1_param``.
    """
    return jnp.sum(jnp.abs(params))

def l1_sparsity_activation(est_by_term):
    """Per-row L1 norm of an (N, p) array of per-term estimates.

    Returns an (N,) array whose n-th entry is sum_j |est_by_term[n, j]|.
    """
    row_l1 = jnp.linalg.norm(est_by_term, axis=1, ord=1)
    return row_l1

def kld_sparsity_activation(est_by_term):
    """Per-row KL divergence of normalised activation magnitudes from uniform.

    Each row's absolute values are rescaled to sum to one. This gives a
    distribution q over the p terms, and the function returns KL(q || uniform)
    for every row. The value is 0 when a row's magnitudes are all equal and
    log(p) for a row with a single non-zero entry, so larger means sparser.

    Args:
        est_by_term: array of shape (N, p).

    Returns:
        Array of shape (N,), one divergence per row.
    """
    _, p = est_by_term.shape  # unpacking also enforces a 2-D input
    magnitudes = jnp.abs(est_by_term)

    # L1-normalise so each row really is a probability distribution. Dividing
    # by the L2 norm leaves row sums anywhere in [1, sqrt(p)], so even a
    # perfectly uniform row would score a non-zero "divergence".
    # The 1e-8 keeps all-zero rows finite.
    normed = magnitudes / (jnp.sum(magnitudes, axis=1, keepdims=True) + 1e-8)

    # Target (uniform) distribution over the p terms.
    uniform = jnp.ones((p,)) / p

    def row_kld(row):
        return jnp.sum(jax.scipy.special.kl_div(row, uniform))

    return vmap(row_kld)(normed)  # shape (N,)


In [None]:
# Smoke test: KL-to-uniform score for 5 random rows of 10 terms (one value per row).
key = jax.random.PRNGKey(2)
X = jax.random.normal(key, (5, 10))
kld_sparsity_activation(X)

In [None]:
# Same score for 100 random rows of 4 terms each.
ests = jax.random.normal(jax.random.PRNGKey(3), (100, 4))
kld_sparsity_activation(ests)

In [None]:
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt
from jax import grad, jit, vmap
from jax.scipy.special import kl_div
from functools import partial


def generate_data(N=100, p=20, sparsity_level=0.3, seed=0, noise_std=0.1):
    """Synthetic sparse linear-regression data.

    Args:
        N: number of samples.
        p: number of features.
        sparsity_level: fraction of coefficients that are non-zero; the count is
            ``int(p * sparsity_level)``.
        seed: seed for ``jax.random.PRNGKey``.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        ``(X, y, beta_true)``:
        - ``X`` is an (N, p) array of standard normals.
        - ``beta_true`` is a (p,) array with standard-normal values at randomly
          chosen positions and zeros elsewhere.
        - ``y = X @ beta_true + noise_std * eps``, where ``eps`` is standard
          normal.
    """
    key = jax.random.PRNGKey(seed)
    # One independent key per random draw. Reusing a key in JAX produces
    # statistically dependent samples (previously the support and the values
    # shared a key, and so did X and the noise).
    key_support, key_values, key_X, key_noise = jax.random.split(key, 4)

    num_nonzero = int(p * sparsity_level)
    nonzero_indices = jax.random.choice(key_support, p, (num_nonzero,), replace=False)
    beta_true = jnp.zeros(p).at[nonzero_indices].set(
        jax.random.normal(key_values, (num_nonzero,))
    )

    X = jax.random.normal(key_X, (N, p))
    noise = noise_std * jax.random.normal(key_noise, (N,))
    y = X @ beta_true + noise
    return X, y, beta_true

def predict(X, beta):
    """Linear-model predictions: the matrix product ``X @ beta``."""
    fitted = X @ beta
    return fitted

def mse_loss(X, y, beta):
    """Mean squared error between ``predict(X, beta)`` and the targets ``y``."""
    residuals = predict(X, beta) - y
    return jnp.mean(residuals ** 2)

def l1_param(beta):
    """L1 penalty on the coefficients themselves: sum of |beta|."""
    return jnp.abs(beta).sum()

def l1_activation(X, beta):
    """Mean over samples of the L1 norm of per-feature contributions.

    The contribution of feature j to sample n is ``X[n, j] * beta[j]``. Each
    row of contributions is reduced to its L1 norm, and the norms are averaged.
    """
    contributions = X * beta  # beta broadcasts across the rows of X
    per_sample_l1 = jnp.sum(jnp.abs(contributions), axis=1)
    return per_sample_l1.mean()

def kl_activation(X, beta, eps=1e-8):
    """Mean KL divergence of per-sample activation magnitudes from uniform.

    For each sample (row of X), the contributions ``|X[n, j] * beta[j]|`` are
    rescaled to sum to one. Each rescaled row is compared with the uniform
    distribution over the p features via KL(q || uniform), and the mean over
    samples is returned. The result is 0 when every row's contributions are
    equal and log(p) when every row has a single non-zero contribution.

    Args:
        X: (N, p) design matrix.
        beta: (p,) coefficients.
        eps: added to each row sum so an all-zero row stays finite.

    Returns:
        Scalar mean divergence.
    """
    act = jnp.abs(X * beta)
    # Normalise by the row *sum* (L1), not the L2 norm, for two reasons:
    # (1) it makes each row a proper probability distribution;
    # (2) its gradient is finite for an all-zero row. The L2 norm's sqrt has an
    #     infinite derivative at 0, so the old version produced NaN gradients
    #     at beta = 0, which is where fit_model initialises beta.
    normed = act / (jnp.sum(act, axis=1, keepdims=True) + eps)
    uniform = jnp.ones(X.shape[1]) / X.shape[1]

    def row_kl(row):
        return jnp.sum(kl_div(row, uniform))

    return jnp.mean(vmap(row_kl)(normed))

@partial(jit, static_argnames=["penalty_type"])
def loss_fn(beta, X, y, lam, penalty_type):
    """Penalised mean-squared-error objective for coefficients ``beta``.

    Args:
        beta: coefficient vector (the argument that gets differentiated).
        X, y: data passed to ``mse_loss`` and, where needed, to the
            activation penalties.
        lam: penalty weight.
        penalty_type: one of "l1_param", "l1_activation" or "kl_activation".
            It is static under jit, so each value compiles its own branch.

    Returns:
        Scalar objective.

    Raises:
        ValueError: if ``penalty_type`` is not recognised.
    """
    data_term = mse_loss(X, y, beta)
    if penalty_type == "l1_param":
        return data_term + lam * l1_param(beta)
    if penalty_type == "l1_activation":
        return data_term + lam * l1_activation(X, beta)
    if penalty_type == "kl_activation":
        # kl_activation measures divergence from the *uniform* distribution.
        # It is lowest when a sample's contributions are spread evenly and
        # highest when they concentrate on a few features: for a normalised
        # row q, KL(q || uniform) = log p - H(q).
        # Minimising +lam * KL would therefore reward dense activations.
        # Subtracting it rewards concentrated (sparse) ones, which is the
        # purpose of a sparsity penalty.
        return data_term - lam * kl_activation(X, beta)
    raise ValueError(f"Unknown penalty type: {penalty_type}")

def fit_model(X, y, lam, penalty_type, num_steps=1000, lr=1e-2):
    """Fit linear coefficients by full-batch SGD on ``loss_fn``.

    Args:
        X: design matrix; ``X.shape[1]`` sets the number of coefficients.
        y: targets passed to ``loss_fn``.
        lam: penalty weight passed to ``loss_fn``.
        penalty_type: penalty name understood by ``loss_fn``.
        num_steps: number of gradient steps.
        lr: SGD learning rate.

    Returns:
        The coefficient vector after ``num_steps`` updates, starting from zeros.
    """
    beta = jnp.zeros(X.shape[1])
    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(beta)

    # Compile the whole update once. Without jit, autodiff is re-applied in
    # Python and every optimizer op is dispatched separately on each step.
    # X, y, lam and penalty_type are closed over, so penalty_type stays a
    # plain Python string (static) inside the trace.
    @jax.jit
    def step(beta, opt_state):
        grads = jax.grad(loss_fn)(beta, X, y, lam, penalty_type)
        updates, opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(beta, updates), opt_state

    for _ in range(num_steps):
        beta, opt_state = step(beta, opt_state)
    return beta

def compare_sparsity_models(N=100, p=20, sparsity_level=0.3, lam=1.0):
    """Fit one model per sparsity penalty on a single synthetic dataset.

    Returns a dict with the ground-truth coefficients under "beta_true",
    followed by the fitted coefficients for each penalty.
    """
    X, y, beta_true = generate_data(N, p, sparsity_level)

    # Result key -> penalty used to fit that model (fitted in this order).
    penalty_for_key = {
        "beta_l1": "l1_param",
        "beta_l1_activation": "l1_activation",
        "beta_kl_activation": "kl_activation",
    }

    results = {"beta_true": beta_true}
    for result_key, penalty_type in penalty_for_key.items():
        results[result_key] = fit_model(X, y, lam, penalty_type)
    return results

def plot_coefficients(results):
    """Grouped bar chart of true vs. learned coefficients.

    ``results`` needs the keys "beta_true", "beta_l1", "beta_l1_activation"
    and "beta_kl_activation". Each key maps to a 1-D coefficient array of the
    same length. Four bars are drawn side by side per coefficient index.
    """
    num_coef = results["beta_true"].shape[0]
    positions = jnp.arange(num_coef)
    bar_width = 0.2

    # (result key, legend label, offset in bar widths, extra bar kwargs)
    series = [
        ("beta_true", "True", -1.5, {"color": "black"}),
        ("beta_l1", "L1 Weights", -0.5, {"alpha": 0.7}),
        ("beta_l1_activation", "L1 Activation", 0.5, {"alpha": 0.7}),
        ("beta_kl_activation", "KL Activation", 1.5, {"alpha": 0.7}),
    ]

    tick_labels = [f"$\\beta_{{{i}}}$" for i in range(num_coef)]

    fig, ax = plt.subplots(figsize=(12, 5))
    for result_key, legend_label, offset, extra in series:
        ax.bar(
            positions + offset * bar_width,
            results[result_key],
            width=bar_width,
            label=legend_label,
            **extra,
        )

    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45)
    ax.set_ylabel("Coefficient Value")
    ax.set_title("Comparison of Learned Coefficients with Different Sparsity Penalties")
    ax.legend()
    plt.tight_layout()
    plt.show()

# Run and plot: generate the synthetic dataset, fit one model per penalty,
# and compare the learned coefficients with the ground truth.
results = compare_sparsity_models()
plot_coefficients(results)

In [None]:
# Print each coefficient vector rounded to 2 decimals.
# NOTE(review): this refits all three models with the same default arguments
# as the previous cell; reusing that cell's `results` would skip the retraining.
results = compare_sparsity_models()

for name, beta in results.items():
    print(f"{name}: {beta.round(2)}")

In [None]:
import numpy as np

def generate_sample():
    """Draw one flattened rank-1 2x2 tensor as a length-4 vector.

    Two vectors u and v are drawn (u first, then v) from NumPy's global RNG,
    each uniform on [-1, 1) in both coordinates. The return value is
    np.kron(u, v) = [u0*v0, u0*v1, u1*v0, u1*v1].
    """
    u, v = (np.random.uniform(-1, 1, 2) for _ in range(2))
    return np.kron(u, v)

# Stack 4 samples (each a length-4 Kronecker product) into a 4x4 matrix and
# report its numerical rank.
A = np.vstack([generate_sample() for _ in range(4)])
rank = np.linalg.matrix_rank(A)
print("Rank of A:", rank)

# Repeat 500 times and report any draw whose numerical rank is not 4.
# NOTE(review): no RNG seed is set, so which draws (if any) get flagged varies
# between runs; call np.random.seed(...) beforehand for reproducibility.
for _ in range(500):
    A = np.vstack([generate_sample() for _ in range(4)])
    rank = np.linalg.matrix_rank(A)
    if rank != 4:
        print('rank is not 4')
        print(rank) 
print('done')