In [4]:
# # installations

# !pip install numpy
# !pip install scipy
# !pip install math
# !pip install scikit-learn
# !pip install pandas
# !pip install matplotlib

In [2]:
import importlib.util as import_util
import_util.find_spec('venv')

ModuleSpec(name='venv', loader=<_frozen_importlib_external.SourceFileLoader object at 0x1106c50d0>, origin='/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/venv/__init__.py', submodule_search_locations=['/opt/homebrew/Cellar/python@3.13/3.13.2/Frameworks/Python.framework/Versions/3.13/lib/python3.13/venv'])

In [3]:
# imports

import numpy as np
import scipy
import math
from collections.abc import Sequence

# secondary imports

import pandas as pd
from sklearn.linear_model import Lasso
import matplotlib.pyplot as plt

In [20]:
# bayes VAMP

type prior = tuple[str, tuple]
# allowed priors: normal, rademacher, 3-point, laplace, bernoulli*normal
# params need to be: (tau2, None), None, (p, q), (theta, None), (p, tau2)

# compute f = E[beta | beta + N(0, 1/gamma_1) = r]
def f_condexp(r_1, gamma_1, prior, prior_params):
    
    # gaussian: beta_bar ~ N(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior == "gaussian":
        tau2, _ = prior_params

        # closed form solution
        f_vals = (tau2 * gamma_1 / (1 + tau2 * gamma_1)) * r_1

    # rademacher: beta_bar ~ Uniform{-1, 1}
    # -----------------------------------------------------------------------------------------
    elif prior == "rademacher":
        # no prior params
        
        # closed form solution
        f_vals = np.tanh(gamma_1 * r_1)

    # three point: beta_bar ~ {-1 : p, 0 : q, 1 : (1 - p-q)}
    # -----------------------------------------------------------------------------------------
    elif prior == "three_point":
        theta_1, theta_2 = prior_params

        var = 1/max(gamma_1, 1e-12)
        
        # compute density for the derivative
        def marginal_density(r):
            return (
                theta_1 * scipy.stats.norm.pdf(r, loc = -1, scale = np.sqrt(var))
                + theta_2 * scipy.stats.norm.pdf(r, loc = 0, scale = np.sqrt(var))
                + (1 - theta_1 - theta_2) * scipy.stats.norm.pdf(r, loc = 1, scale = np.sqrt(var))
            )
        
        # compute derivative for numerator
        def density_derivative(r):
            return (
                -gamma_1 * (theta_1 * scipy.stats.norm.pdf(r, loc = -1, scale = np.sqrt(var)) * (r+1)
                + theta_2 * scipy.stats.norm.pdf(r, loc = 0, scale = np.sqrt(var)) * (r)
                + (1 - theta_1 - theta_2) * scipy.stats.norm.pdf(r, loc = 1, scale = np.sqrt(var)) * (r-1))
            )

        # use tweedie's formula 
        f_vals = r_1 + var * density_derivative(r_1) / (marginal_density(r_1) + 1e-8)


    # bernoulli/normal mix: beta_bar ~ bernoulli(theta) * normal(0, tau2)
    # -----------------------------------------------------------------------------------------
    elif prior == "bernoulli_gaussian":
        theta, tau2 = prior_params

        var = 1/max(gamma_1, 1e-16)
        
        # compute density for the derivative
        def marginal_density(r):
            return (1-theta) * scipy.stats.norm.pdf(r, 0, np.sqrt(var)) + theta * scipy.stats.norm.pdf(r, 0, np.sqrt(tau2 + var))
        
        # compute derivative for the numerator
        def density_derivative(r) :
            return -((1-theta) * (gamma_1 * r) * scipy.stats.norm.pdf(r, 0, np.sqrt(var)) + theta * (r / (tau2 + var) * scipy.stats.norm.pdf(r, 0, np.sqrt(tau2 + var))))
        
        # use tweedie's formula
        f_vals = r_1 + var * density_derivative(r_1) / (marginal_density(r_1) + 1e-8)

    else:
        raise ValueError("given prior is not yet supported")

    return f_vals

# for f = E[beta | beta + N(0, 1/gamma_1) = r], compute f'(r) wrt r
def f_derivative(r_1, gamma_1, prior, prior_params):

    # gaussian: beta_bar ~ N(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior == "gaussian" :
        tau2, _ = prior_params
        
        # full_like to facilitate return of a vector of values, although constant
        return np.full_like(r_1, (gamma_1 * tau2) / (gamma_1 * tau2 + 1))
    
    # rademacher: beta_bar ~ Uniform{-1, 1}
    # -----------------------------------------------------------------------------------------
    if prior == "rademacher":
        if gamma_1 > 128.0 :
            return np.zeros(len(r_1))
          
        return gamma_1 * (1/np.cosh(gamma_1 * r_1)) ** 2
    

    # three point: beta_bar ~ {-1 : p, 0 : q, 1 : (1 - p-q)}
    # -----------------------------------------------------------------------------------------
    elif prior == "three_point":
        theta_1, theta_2 = prior_params
        theta_3 = 1 - theta_1 - theta_2

        var = 1/max(gamma_1, 1e-12)
        
        # compute density for the derivative
        def marginal_density(r):
            return (
                theta_1 * scipy.stats.norm.pdf(r, loc = -1, scale = np.sqrt(var))
                + theta_2 * scipy.stats.norm.pdf(r, loc = 0, scale = np.sqrt(var))
                + (1 - theta_1 - theta_2) * scipy.stats.norm.pdf(r, loc = 1, scale = np.sqrt(var))
            )
        
        # compute derivative for numerator
        def density_derivative(r):
            return (
                -gamma_1 * (theta_1 * scipy.stats.norm.pdf(r, loc = -1, scale = np.sqrt(var)) * (r+1)
                + theta_2 * scipy.stats.norm.pdf(r, loc = 0, scale = np.sqrt(var)) * (r)
                + (1 - theta_1 - theta_2) * scipy.stats.norm.pdf(r, loc = 1, scale = np.sqrt(var)) * (r-1))
            )
        
        def density_2ndderivative(r) :
            return (
                -gamma_1 * (theta_1 * scipy.stats.norm.pdf(r, loc = -1, scale = np.sqrt(var)) * (1 - gamma_1 * (r+1)**2) 
                            + theta_2 * scipy.stats.norm.pdf(r, scale = np.sqrt(var)) * (1 - gamma_1 * r**2) 
                            + theta_3 * scipy.stats.norm.pdf(r, loc = 1, scale = np.sqrt(var)) * (1 - gamma_1 * (r-1)**2))
            )
        
        f_prime_vals = 1 + var * (density_2ndderivative(r_1) * marginal_density(r_1) - density_derivative(r_1)**2) / np.maximum(marginal_density(r_1)**2, 1e-16)

    # bernoulli/normal mix: beta_bar ~ bernoulli(theta) * normal(0, tau2)
    # -----------------------------------------------------------------------------------------
    elif prior == "bernoulli_gaussian":
        theta, tau2 = prior_params

        var = 1/max(gamma_1, 1e-16)
        
        # compute density for the derivative
        def marginal_density(r):
            return (1-theta) * scipy.stats.norm.pdf(r, 0, np.sqrt(var)) + theta * scipy.stats.norm.pdf(r, 0, np.sqrt(tau2 + var))
        
        # compute derivative for the numerator
        def density_derivative(r) :
            return -((1-theta) * (gamma_1 * r) * scipy.stats.norm.pdf(r, 0, np.sqrt(var)) + theta * (r / (tau2 + var) * scipy.stats.norm.pdf(r, 0, np.sqrt(tau2 + var))))
        
        def density_2ndderivative(r) :
            return (1 - theta) * scipy.stats.norm.pdf(r, 0, np.sqrt(var)) * (-gamma_1 + (gamma_1 * r)**2) + theta * scipy.stats.norm.pdf(r, 0, np.sqrt(tau2 + var)) * (-1/(tau2 + var) + (r/(tau2 + var))**2)
        
        f_prime_vals = 1 + var * (density_2ndderivative(r_1) * marginal_density(r_1) - density_derivative(r_1)**2) / np.maximum(marginal_density(r_1)**2, 1e-16)

    return f_prime_vals

# bvamp_tag
def vamp_bayes(X, y, prior_info : Sequence[prior], oracle_sigma2, max_iter = 100, tol = 1e-8, retrieve = False, verbose = False) :

    _, p = X.shape
    # delta_inv = p / n

    prior_name, prior_params = prior_info

    # initialization
    r_1_k = 0.01 * np.ones(p)
    gamma_1_k = 0.05

    # empty arrays to store iterates
    beta_hat = []
    r_1 = []
    r_2 = []

    # iterate
    for k in range(max_iter) :

        # add r_1 to storage
        r_1.append(r_1_k)

        # -----------------------------------------------------------------------------------------

        # updates part 1 --- f and f'
        
        # update beta_hat
        beta_hat_k = f_condexp(r_1_k, gamma_1_k, prior_name, prior_params)

        if (np.isnan(gamma_1_k)) :
            print('uh oh gamma_1k')
            return

        # update b
        b_k = np.mean(f_derivative(r_1_k, gamma_1_k, prior_name, prior_params))

        if (np.isnan(b_k)) :
            print('uh oh b_k')
            return
        
        # -----------------------------------------------------------------------------------------

        # parts where prior doesn't matter
        
        # update eta 1
        eta_1_k = gamma_1_k / (max(b_k, 1e-8))

        if (np.isnan(eta_1_k)) :
            print('uh oh eta_1k')
            return
        
        # update gamma 2
        gamma_2_k = eta_1_k - gamma_1_k

        if (np.isnan(gamma_2_k)) :
            print('uh oh gamma_2k')
            return

        # update r_2
        r_2_k = (1 / (gamma_2_k + 1e-8)) * (eta_1_k * beta_hat_k - gamma_1_k * r_1_k)

        # -----------------------------------------------------------------------------------------
        
        # append new updates for storage
        beta_hat.append(beta_hat_k)
        r_2.append(r_2_k)

        # convergence conditions
        # -----------------------------------------------------------------------------------------

        if k > 1 and (1/p * np.linalg.norm(beta_hat[-1] - beta_hat[-2])) < tol :
            if verbose :
                print("converged at iteration " + str(k))
            if retrieve :
                return beta_hat, r_1, r_2
            else :
                return beta_hat[-1]
            
        if k > 1 :
            diff1 = beta_hat_k + np.ones(p)
            diff2 = beta_hat_k - np.ones(p)
            
            deviations = np.min(np.array([diff1, diff2, beta_hat_k]), axis = 0)

            if np.mean(deviations) <= 1e-4 :
                if verbose :
                    print("converged at iteration " + str(k))
                if retrieve :
                    return beta_hat, r_1, r_2
                else :
                    return beta_hat[-1]
            
        # -----------------------------------------------------------------------------------------

        # updates part 2, continue
        
        # update c for gamma_1
        inv_arg = np.linalg.inv((1/oracle_sigma2) * X.T @ X + (gamma_2_k+1e-8) * np.eye(p))
        c_k = (1/p) * np.trace(gamma_2_k * inv_arg)

        if (np.isnan(c_k)) :
            print('uh oh c_k')
            return

        # update gamma_1
        gamma_1_k = gamma_2_k * (1/(np.clip(c_k, 1e-12, 1-1e-12)) - 1)

        if (np.isnan(gamma_1_k)) :
            print('uh oh gamma_1k')
            return

        # update r_1
        r_1_k = (1 / max(1-c_k, 1e-8)) * (inv_arg @ (1/oracle_sigma2 * X.T @ y + gamma_2_k * r_2_k) - c_k * r_2_k)

    if retrieve :
        return beta_hat, r_1, r_2
    
    if verbose :
        print("did not converge early")
        
    return beta_hat[-1]

In [26]:
# lvamp_tag
# VAMP LASSO

def vamp_lasso(X, y, reg_lambda = 1, oracle_sigma2 = 1, max_iter = 100, tol = 1e-8, verbose = False, retrieve = False) :

    # soft thresholding denoiser for lasso
    def soft_threshold(x, threshold):
        return np.sign(x) * np.maximum(np.abs(x) - threshold, 0)

    def subgradient_soft_threshold(r_1, threshold):
        grad = np.zeros_like(r_1)
        grad[np.abs(r_1) > threshold] = 1.0
        grad[np.isclose(np.abs(r_1), threshold)] = 0.75
        
        return grad

    _, p = X.shape

    # initialization
    r_1 = np.random.normal(loc = 0, scale = 1, size = p)
    gamma_1 = 1

    beta_hat_store = []
    r_1_store = []
    r_2_store = []

    # iterate
    for _ in range(max_iter) :

        r_1_store.append(r_1)
        
        # update beta_hat
        beta_hat = soft_threshold(r_1, reg_lambda / (gamma_1+1e-8))

        beta_hat_store.append(beta_hat)

        # check divergence
        if np.isnan(np.linalg.norm(beta_hat)) :
            print("divergence")
            return beta_hat

        # update b
        b = np.mean(subgradient_soft_threshold(r_1, reg_lambda / (gamma_1+1e-8)))

        # update eta 1
        eta_1 = gamma_1 / (b+1e-8)
        
        # update gamma 2
        gamma_2 = eta_1 - gamma_1

        # update r_2
        r_2 = (1 /(gamma_2+1e-12)) * (eta_1*beta_hat - gamma_1*r_1)

        r_2_store.append(r_2)

        # update c for gamma_1
        inv_arg = np.linalg.inv((1/oracle_sigma2) * X.T @ X + gamma_2 * np.eye(p))
        c = (1/p) * np.trace(gamma_2 * inv_arg)

        # update gamma_1
        gamma_1 = gamma_2 * (1/(c + 1e-8) - 1)

        # update r_1
        r_1 = (1 / (1-c + 1e-8)) * (inv_arg @ (1/oracle_sigma2 * X.T @ y + gamma_2*r_2) - c*r_2)

    if retrieve :
        return beta_hat_store, r_1_store, r_2_store
    
    return beta_hat_store[-1]

In [25]:
# rvamp_tag
# VAMP RIDGE

def vamp_ridge(X, y, reg_lambda = 1, oracle_sigma2 = 1, max_iter = 100, tol = 1e-8, verbose = False) :

    # soft thresholding denoiser for ridge
    def l2_shrinkage(x, threshold):
        return (1.0 / (1.0 + threshold)) * x

    def gradient_shrinkage(x, threshold):
        scale = 1.0 / (1.0 + threshold)
        return np.full_like(x, scale)

    n, p = X.shape
    # delta_inv = p / n

    # initialization
    r_1 = np.random.normal(loc = 0, scale = 1, size = p)
    gamma_1 = 1.0

    # iterate
    for k in range(max_iter) :
        
        # update beta_hat
        beta_hat = l2_shrinkage(r_1, reg_lambda / (gamma_1+1e-12))

        # check divergence
        if np.isnan(np.linalg.norm(beta_hat)) :
            print("divergence")
            return beta_hat

        # update b
        b = np.mean(gradient_shrinkage(r_1, reg_lambda / (gamma_1+1e-12)))

        # update eta 1
        eta_1 = gamma_1 / (b+1e-12)
        
        # update gamma 2
        gamma_2 = eta_1 - gamma_1

        # update r_2
        r_2 = (1/ (gamma_2+1e-12)) * (eta_1 * beta_hat - gamma_1 * r_1)

        # update c for gamma_1
        inv_arg = np.linalg.inv((1/oracle_sigma2) * X.T @ X + (gamma_2+1e-12) * np.eye(p))
        c = (1/p) * np.trace(gamma_2 * inv_arg)

        # update gamma_1
        gamma_1 = gamma_2 * (1/(c+1e-12) - 1)

        # update r_1
        r_1 = (1 / (1-c + 1e-12)) * (inv_arg @ (1/oracle_sigma2 * X.T @ y + gamma_2*r_2) - c*r_2)

    return beta_hat

In [6]:
# this cell is the data generation helpers

def generate_beta(p, mean = 0.0, var = 1.0, prior = "gaussian") :
    if prior == "gaussian" :
        beta = np.random.normal(loc = 0, scale = math.sqrt(var), size = p)

    elif prior == "sparse_gaussian" :
        print("not done yet 5")
        
    elif prior == "other" :
        print("not done yet 3")
    return beta



def generate_gaussian_design(n, p, symmetry = "asymmetric") :
    if symmetry == "symmetric" :
        if n != p :
            print("symmetric matrices need to have same row and column dimensions")
        else :
            print("not done yet 1")

    elif symmetry == "asymmetric" :
        matrix = np.random.normal(loc = 0, scale = 1 / math.sqrt(n), size = (n, p))
        
    else :
        print("invalid matrix type")
        
    return matrix



def generate_rri_design(n, p, k = -1, method = "lnn") :
    if method == "lnn" :
        if k == -1 :
            k = p
            
        X_1 = np.random.normal(loc = 0, scale = 1, size = (n, k))
        X_2 = np.random.normal(loc = 0, scale = 1, size = (k, k))
        X_3 = np.random.normal(loc = 0, scale = 1, size = (k, k))
        X_4 = np.random.normal(loc = 0, scale = 1, size = (k, p))
        
        X = (1 / math.sqrt(n)) * X_1 @ X_2 @ X_3 @ X_4

    elif method == "heavy_tail" :
        df = 3
        mean = np.zeros(p)
        scale = np.eye(p)
        X = scipy.stats.multivariate_t.rvs(loc = mean, shape = scale, df = df, size = n)

    elif method == "spiked" :
        X, V, W, signal, noise = generate_spiked(n, p)
    
    else :
        print("invalid matrix type")

    return X



def generate_response(design_matrix, signal, noise_var) :
    n, _ = design_matrix.shape
    error = np.random.normal(loc = 0, scale = math.sqrt(noise_var), size = n)

    return design_matrix @ signal + error



# helper for generating spiked matrix
def haar_orthonormal_columns(n, m):
    G = np.random.normal(size = (n, m))
    Q, R = np.linalg.qr(G)
    
    signs = np.sign(np.diag(R))
    signs[signs == 0] = 1.0
    Q = Q * signs[np.newaxis, :]
    return Q

# helper for spiked matrix generation
def generate_spiked(n, p, m = 50, alpha = 10.0, seed = None):
    if m > min(n, p):
        raise ValueError("m must be <= min(n, p)")
    
    V = haar_orthonormal_columns(n, m) 
    W = haar_orthonormal_columns(p, m) 

    signal = alpha * (V @ W.T)               
    noise = (1.0/math.sqrt(n)) * np.random.normal(size = (n, p))

    X = signal + noise

    return X, V, W, signal, noise

In [18]:
# check for phase transition in 3-point for bayes VAMP.
# tagged1

delta = 0.8

p = 500
n = int(delta * p)

sigmas2 = [0.01] # [0.001, 0.0025, 0.005, 0.0075, 0.01, 0.0125, 0.015, 0.0175, 0.02, 0.0225, 0.025, 0.0275, 0.03]
tau2 = 1

theta1 = 0.2
theta2 = 0.4

mse_bvamp = []
mse_lvamp = []

for sigma2 in sigmas2 :
    errors_bvamp = []
    errors_lvamp = []

    for _ in range(20) :

        # use bernoulli * normal
        beta = np.random.choice((-1, 0, 1), p = (theta1, theta2, 1 - theta1 - theta2), size = p)#np.random.binomial(1, theta, size = p) * np.random.normal(0, scale = np.sqrt(tau2), size = p)
        prior_info = ("three_point", (theta1, theta2))

        X = generate_rri_design(n, p, method = "lnn")
        y = generate_response(X, beta, sigma2)
        
        lambda_lvamp = 0.1

        estimate_bvamp = vamp_bayes(X, y, prior_info, oracle_sigma2 = sigma2)

        if not np.all(np.isnan(np.linalg.norm(estimate_bvamp - beta))) :
            errors_bvamp.append(1/p * np.linalg.norm(estimate_bvamp - beta) ** 2)
    
    mse_bvamp.append(np.median(errors_bvamp))


In [19]:
mse_bvamp

[np.float64(0.01561818968878742)]

In [23]:
# check for phase transition in 3-point for bayes VAMP.
# tagged1

delta = 0.8

p = 500
n = int(delta * p)

sigmas2 = [0.01] # [0.001, 0.0025, 0.005, 0.0075, 0.01, 0.0125, 0.015, 0.0175, 0.02, 0.0225, 0.025, 0.0275, 0.03]
tau2 = 1

theta = 0.3

theta1 = 0.2
theta2 = 0.4

mse_bvamp = []
mse_lvamp = []

for sigma2 in sigmas2 :
    errors_bvamp = []
    errors_lvamp = []

    for _ in range(20) :

        # use bernoulli * normal
        beta = np.random.binomial(1, theta, size = p) * np.random.normal(0, scale = np.sqrt(tau2), size = p) # np.random.choice((-1, 0, 1), p = (theta1, theta2, 1 - theta1 - theta2), size = p)#
        prior_info = ("bernoulli_gaussian", (theta, tau2))

        X = generate_rri_design(n, p, method = "lnn")
        y = generate_response(X, beta, sigma2)
        
        lambda_lvamp = 0.1

        estimate_bvamp = vamp_bayes(X, y, prior_info, oracle_sigma2 = sigma2)

        if not np.all(np.isnan(np.linalg.norm(estimate_bvamp - beta))) :
            errors_bvamp.append(1/p * np.linalg.norm(estimate_bvamp - beta) ** 2)
    
    mse_bvamp.append(np.median(errors_bvamp))


In [24]:
mse_bvamp

[np.float64(0.007306725264257463)]