In [None]:
# imports

import numpy as np
import scipy
import math
from collections.abc import Sequence

# secondary imports

import pandas as pd
from sklearn.linear_model import Lasso
import matplotlib.pyplot as plt

In [None]:
# BAYES VAMP STATE EVOLUTION

# MC method for 1-D gaussian, for discrete parts of beta
# 0, eta2_1
def gaussian_1dmc(h, mean, var, n_samples = 1000000, retrieve = False):
    """Monte-Carlo estimate of E[h(X)] for X ~ N(mean, var).

    Draws ``n_samples // 2`` standard normals and mirrors them (z and -z) so
    the evaluation points are symmetric about ``mean``. A negative ``var`` is
    treated as zero.

    Parameters
    ----------
    h : callable
        Vectorised integrand; receives the array of draws and must return an
        ndarray (its ``.size`` is used for the standard error).
    mean, var : float
        Mean and variance of the Gaussian.
    n_samples : int
        Total number of evaluation points (rounded down to an even number).
    retrieve : bool
        When True, also return the Monte-Carlo standard error.

    Returns
    -------
    The estimate, or the pair ``(estimate, standard_error)`` if ``retrieve``.
    """
    half_draws = np.random.normal(size = n_samples // 2)
    std_dev = np.sqrt(max(var, 0.0))

    # antithetic pairs: every draw z is matched with -z
    draws = mean + std_dev * np.concatenate([half_draws, -half_draws])

    values = h(draws)

    estimate = np.mean(values)
    std_error = np.std(values, ddof=1) / np.sqrt(values.size)

    return (estimate, std_error) if retrieve else estimate


# MC method for 2-D gaussian, for continuous parts of beta
# entries are P_k, [continuous part] beta
# 0, eta2_1, 0, tau2
def gaussian_2dmc(h, mean_x, var_x, mean_y, var_y, n_samples = 1000000, retrieve = False):
    """Monte-Carlo estimate of E[h(X, Y)] for independent Gaussians X and Y.

    X ~ N(mean_x, var_x) and Y ~ N(mean_y, var_y). ``n_samples // 2`` pairs of
    standard normals are drawn and mirrored jointly, i.e. (zx, zy) is paired
    with (-zx, -zy). Negative variances are treated as zero for BOTH
    coordinates, so a slightly negative ``var_y`` no longer produces NaN.

    Parameters
    ----------
    h : callable
        Vectorised integrand ``h(X, Y)``.
    mean_x, var_x, mean_y, var_y : float
        Means and variances of the two coordinates.
    n_samples : int
        Total number of evaluation points (rounded down to an even number).
    retrieve : bool
        When True, also return the Monte-Carlo standard error.

    Returns
    -------
    The estimate, or the pair ``(estimate, standard_error)`` if ``retrieve``.
    """
    sample_size = n_samples // 2

    Z_x_gen = np.random.normal(size = sample_size)
    Z_y_gen = np.random.normal(size = sample_size)

    Z_x = np.concatenate([Z_x_gen, -Z_x_gen])
    Z_y = np.concatenate([Z_y_gen, -Z_y_gen])

    # clamp both variances at zero, same guard for X and Y
    X = mean_x + np.sqrt(max(var_x, 0.0)) * Z_x
    Y = mean_y + np.sqrt(max(var_y, 0.0)) * Z_y

    # P_k, beta [continuous part]
    f_vals = np.asarray(h(X, Y))

    exp_estimate = np.mean(f_vals)

    mc_se = np.std(f_vals, ddof=1) / np.sqrt(f_vals.size)

    if retrieve:
        return exp_estimate, mc_se

    return exp_estimate



# compute f = E[beta | beta + N(0, 1/gamma_1) = (p_k + beta)]
def f_condexp(p_k, beta, gamma_1, prior_name, prior_params):
    """Posterior mean E[beta_bar | beta_bar + N(0, 1/gamma_1) = p_k + beta].

    Only the sum ``r = p_k + beta`` enters the computation.

    The mixture priors are evaluated through log-domain posterior weights
    rather than a ratio of Gaussian densities. The two forms are algebraically
    identical (Tweedie's formula), but the density ratio underflows to 0/0 far
    in the tails; with a floor on the denominator it then returns ``r`` itself
    instead of a value inside the prior's support.

    Parameters
    ----------
    p_k, beta : scalar or ndarray
        Broadcastable pieces of the observation ``r = p_k + beta``.
    gamma_1 : float
        Noise precision; floored at 1e-16 where it is inverted.
    prior_name : str
        One of "gaussian", "rademacher", "three_point", "bernoulli_gaussian".
    prior_params : tuple or None
        "gaussian": (tau2, _); "rademacher": unused;
        "three_point": (theta_1, theta_2) = P(-1), P(0);
        "bernoulli_gaussian": (theta, tau2) = P(non-zero), slab variance.

    Returns
    -------
    numpy.ndarray
        Posterior mean, same shape as ``p_k + beta``.

    Raises
    ------
    ValueError
        If ``prior_name`` is not one of the supported priors.
    """
    gamma = max(gamma_1, 1e-16)
    inv_var = 1 / gamma

    r = np.asarray(p_k + beta, dtype=float)

    # gaussian: beta_bar ~ N(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior_name == "gaussian":
        tau2, _ = prior_params

        # closed form solution (linear shrinkage)
        f_vals = (tau2 * gamma_1 / (1 + tau2 * gamma_1)) * r

    # rademacher: beta_bar ~ Uniform{-1, 1}
    # -----------------------------------------------------------------------------------------
    elif prior_name == "rademacher":
        # no prior params

        # closed form solution
        f_vals = np.tanh(gamma_1 * r)

    # three point: beta_bar ~ {-1 : theta1, 0 : theta2, 1 : (1 - theta1 - theta2)}
    # -----------------------------------------------------------------------------------------
    elif prior_name == "three_point":
        theta_1, theta_2 = prior_params
        theta_3 = 1 - theta_1 - theta_2

        # 1 - theta_1 - theta_2 can round to a tiny negative number; clip the
        # weights at 0 so the logarithm stays defined (log(0) = -inf is fine).
        with np.errstate(divide="ignore"):
            log_w_neg, log_w_zero, log_w_pos = np.log(
                np.clip([theta_1, theta_2, theta_3], 0.0, None)
            )

        # log posterior weight of each atom; the shared Gaussian normaliser cancels
        log_neg = log_w_neg - 0.5 * gamma * (r + 1) ** 2
        log_zero = log_w_zero - 0.5 * gamma * r ** 2
        log_pos = log_w_pos - 0.5 * gamma * (r - 1) ** 2

        # subtract the largest exponent so at least one weight equals 1
        shift = np.maximum(np.maximum(log_neg, log_zero), log_pos)
        w_neg = np.exp(log_neg - shift)
        w_zero = np.exp(log_zero - shift)
        w_pos = np.exp(log_pos - shift)

        # posterior mean over the atoms {-1, 0, +1}
        f_vals = (w_pos - w_neg) / (w_neg + w_zero + w_pos)

    # bernoulli/gaussian mix: beta_bar ~ bernoulli(theta) * Normal(0, tau^2)
    # -----------------------------------------------------------------------------------------
    elif prior_name == "bernoulli_gaussian":
        theta, tau2 = prior_params

        # marginal variance of r under the non-zero (slab) component
        slab_var = tau2 + inv_var

        with np.errstate(divide="ignore"):
            log_spike = np.log(max(1 - theta, 0.0)) - 0.5 * np.log(inv_var) - 0.5 * gamma * r ** 2
            log_slab = np.log(max(theta, 0.0)) - 0.5 * np.log(slab_var) - 0.5 * r ** 2 / slab_var

        # posterior probability that beta_bar came from the slab
        slab_prob = np.exp(log_slab - np.logaddexp(log_spike, log_slab))

        # within the slab the posterior mean is the linear shrinkage tau2/(tau2 + 1/gamma) * r
        f_vals = slab_prob * (tau2 / slab_var) * r

    else:
        raise ValueError("given prior is not yet supported")

    return np.array(f_vals)



# get f', the derivative of conditional expectation f
def f_derivative(p_k, beta, gamma_1, prior_name, prior_params) :
    """Derivative in r = p_k + beta of the posterior mean E[beta_bar | r].

    For Gaussian noise of precision gamma the derivative of the posterior mean
    equals ``gamma * Var(beta_bar | r)``. The mixture priors are therefore
    evaluated from log-domain posterior weights, which is algebraically the
    same as differentiating Tweedie's formula but does not break down when the
    marginal density underflows in the tails.

    Parameters
    ----------
    p_k, beta : scalar or ndarray
        Broadcastable pieces of the observation ``r = p_k + beta``.
    gamma_1 : float
        Noise precision; floored at 1e-16 where it is inverted.
    prior_name : str
        One of "gaussian", "rademacher", "three_point", "bernoulli_gaussian".
    prior_params : tuple or None
        "gaussian": (tau2, _); "rademacher": unused;
        "three_point": (theta_1, theta_2) = P(-1), P(0);
        "bernoulli_gaussian": (theta, tau2) = P(non-zero), slab variance.

    Returns
    -------
    numpy.ndarray
        Derivative values, same shape as ``p_k + beta``.

    Raises
    ------
    ValueError
        If ``prior_name`` is not one of the supported priors.
    """
    gamma = max(gamma_1, 1e-16)
    inv_var = 1 / gamma

    r = np.asarray(p_k + beta, dtype=float)

    # gaussian: beta_bar ~ N(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior_name == "gaussian":
        tau2, _ = prior_params

        # the posterior mean is linear in r, so the derivative is a constant
        f_prime_vals = np.full_like(r, tau2 * gamma_1 / (1 + tau2 * gamma_1))

    # rademacher: beta_bar ~ Uniform{-1, 1}
    # -----------------------------------------------------------------------------------------
    elif prior_name == "rademacher":
        u = gamma_1 * r

        # log sech(u) = log 2 - log(e^u + e^-u), evaluated without overflow
        log_sech_u = np.log(2.0) - np.logaddexp(u, -u)

        # d/dr tanh(gamma r) = gamma * sech(gamma r)^2  (note the square)
        f_prime_vals = gamma_1 * np.exp(2.0 * log_sech_u)

    # three point: beta_bar ~ {-1 : theta1, 0 : theta2, 1 : (1 - theta1 - theta2)}
    # -----------------------------------------------------------------------------------------
    elif prior_name == "three_point":
        theta_1, theta_2 = prior_params
        theta_3 = 1 - theta_1 - theta_2

        # 1 - theta_1 - theta_2 can round to a tiny negative number; clip at 0
        with np.errstate(divide="ignore"):
            log_w_neg, log_w_zero, log_w_pos = np.log(
                np.clip([theta_1, theta_2, theta_3], 0.0, None)
            )

        # log posterior weight of each atom; the shared Gaussian normaliser cancels
        log_neg = log_w_neg - 0.5 * gamma * (r + 1) ** 2
        log_zero = log_w_zero - 0.5 * gamma * r ** 2
        log_pos = log_w_pos - 0.5 * gamma * (r - 1) ** 2

        shift = np.maximum(np.maximum(log_neg, log_zero), log_pos)
        w_neg = np.exp(log_neg - shift)
        w_zero = np.exp(log_zero - shift)
        w_pos = np.exp(log_pos - shift)
        total = w_neg + w_zero + w_pos

        # first and second posterior moments over the atoms {-1, 0, +1}
        post_mean = (w_pos - w_neg) / total
        post_second = (w_pos + w_neg) / total

        f_prime_vals = gamma * (post_second - post_mean ** 2)

    # bernoulli/gaussian mix: beta_bar ~ bernoulli(theta) * Normal(0, tau^2)
    # -----------------------------------------------------------------------------------------
    elif prior_name == "bernoulli_gaussian":
        theta, tau2 = prior_params

        # marginal variance of r under the non-zero (slab) component
        slab_var = tau2 + inv_var
        shrink = tau2 / slab_var

        with np.errstate(divide="ignore"):
            log_spike = np.log(max(1 - theta, 0.0)) - 0.5 * np.log(inv_var) - 0.5 * gamma * r ** 2
            log_slab = np.log(max(theta, 0.0)) - 0.5 * np.log(slab_var) - 0.5 * r ** 2 / slab_var

        log_total = np.logaddexp(log_spike, log_slab)
        slab_prob = np.exp(log_slab - log_total)
        spike_prob = np.exp(log_spike - log_total)

        # gamma * posterior variance: within-slab part + between-component part
        f_prime_vals = slab_prob * shrink + gamma * slab_prob * spike_prob * (shrink * r) ** 2

    else:
        raise ValueError("given prior is not yet supported")

    return np.array(f_prime_vals)
        

# computing E[f_k']
def compute_E_fprime_bayes(prior, gamma_1, eta2_1, precision):
    """Expected derivative E[f'(P_k + beta_bar)] of the posterior-mean denoiser.

    P_k ~ N(0, eta2_1) is independent of beta_bar, which is drawn from the
    prior. The gaussian prior has a closed form; every other prior is
    estimated by Monte Carlo with ``int(precision)`` evaluation points per
    integral (previously only the three-point branch honoured ``precision``).

    Parameters
    ----------
    prior : tuple
        ``(prior_name, prior_params)`` as accepted by ``f_derivative``.
    gamma_1 : float
        Noise precision handed to the denoiser.
    eta2_1 : float
        Variance of P_k.
    precision : int or float
        Number of Monte-Carlo evaluation points; ignored for "gaussian".

    Returns
    -------
    float
        The (estimated) expectation.

    Raises
    ------
    ValueError
        If the prior name is not supported (previously ``None`` was returned).
    """
    prior_name, prior_params = prior

    # gaussian: beta_bar ~ N(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior_name == "gaussian":
        tau2, _ = prior_params
        return (tau2 * gamma_1) / (1 + tau2 * gamma_1)

    if prior_name not in ("three_point", "bernoulli_gaussian", "rademacher"):
        raise ValueError("given prior is not yet supported")

    n_samples = int(precision)

    # E over P_k of f' with beta_bar fixed at one atom of a discrete prior
    def mean_fprime_at(atom):
        def integrand(p_k):
            return f_derivative(p_k, atom, gamma_1, prior_name, prior_params)

        return gaussian_1dmc(integrand, 0.0, eta2_1, n_samples = n_samples)

    # three point: beta_bar ~ {-1 : theta1, 0 : theta2, 1 : (1 - theta1 - theta2)}
    # -----------------------------------------------------------------------------------------
    if prior_name == "three_point":
        theta_1, theta_2 = prior_params
        theta_3 = 1 - theta_1 - theta_2

        integral_part1 = mean_fprime_at(-1)
        integral_part2 = mean_fprime_at(0)
        integral_part3 = mean_fprime_at(1)

        return theta_1 * integral_part1 + theta_2 * integral_part2 + theta_3 * integral_part3

    # bernoulli/gaussian mix: beta_bar ~ bernoulli(theta) * Normal(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior_name == "bernoulli_gaussian":
        theta, tau2 = prior_params

        def f_prime_continuous(p_k, beta):
            return f_derivative(p_k, beta, gamma_1, prior_name, prior_params)

        # slab part: beta_bar ~ N(0, tau2); spike part: beta_bar = 0
        integral_continuous = gaussian_2dmc(f_prime_continuous, 0, eta2_1, 0, tau2, n_samples = n_samples)
        integral_discrete = mean_fprime_at(0)

        return theta * integral_continuous + (1 - theta) * integral_discrete

    # rademacher: beta_bar ~ Uniform{-1, 1}
    # -----------------------------------------------------------------------------------------
    integral_part1 = mean_fprime_at(-1)
    integral_part2 = mean_fprime_at(1)

    return 1/2 * (integral_part1 + integral_part2)


# compute E[f_k^2] with f_k = (posterior mean at P_k + beta_bar) - beta_bar,
# i.e. the mean squared error of the denoiser
def compute_E_f2_bayes(prior, gamma_1, eta2_1, precision):
    """Mean squared error E[(f(P_k + beta_bar) - beta_bar)^2] of the denoiser.

    P_k ~ N(0, eta2_1) is independent of beta_bar, which is drawn from the
    prior, and ``f`` is ``f_condexp``.

    The gaussian branch is the closed form
    ``kappa^2 * eta2_1 + (1 - kappa)^2 * tau2``, which is exactly this squared
    error for the linear denoiser ``kappa * (P_k + beta_bar)``. The
    Monte-Carlo branches now integrate the same quantity; they previously
    integrated ``f^2`` without subtracting ``beta_bar``, and the second
    rademacher term was missing its square altogether. All Monte-Carlo
    branches use ``int(precision)`` evaluation points.

    Parameters
    ----------
    prior : tuple
        ``(prior_name, prior_params)`` as accepted by ``f_condexp``.
    gamma_1 : float
        Noise precision handed to the denoiser.
    eta2_1 : float
        Variance of P_k.
    precision : int or float
        Number of Monte-Carlo evaluation points; ignored for "gaussian".

    Returns
    -------
    float
        The (estimated) expectation.

    Raises
    ------
    ValueError
        If the prior name is not supported (previously ``None`` was returned).
    """
    prior_name, prior_params = prior

    # gaussian: beta_bar ~ N(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior_name == "gaussian":
        tau2, _ = prior_params
        kappa = (tau2*gamma_1) / (1 + tau2*gamma_1)

        return kappa**2 * eta2_1 + (1-kappa)**2 * tau2

    if prior_name not in ("bernoulli_gaussian", "rademacher", "three_point"):
        raise ValueError("given prior is not yet supported")

    n_samples = int(precision)

    # E over P_k of the squared error with beta_bar fixed at one atom
    def sq_error_at(atom):
        def integrand(p_k):
            return (f_condexp(p_k, atom, gamma_1, prior_name, prior_params) - atom) ** 2

        return gaussian_1dmc(integrand, 0.0, eta2_1, n_samples = n_samples)

    # bernoulli/gaussian mix: beta_bar ~ bernoulli(theta) * Normal(0, tau^2)
    # -----------------------------------------------------------------------------------------
    if prior_name == "bernoulli_gaussian":
        theta, tau2 = prior_params

        def sq_error_continuous(p_k, beta):
            return (f_condexp(p_k, beta, gamma_1, prior_name, prior_params) - beta) ** 2

        # slab part: beta_bar ~ N(0, tau2); spike part: beta_bar = 0
        integral_continuous = gaussian_2dmc(sq_error_continuous, 0, eta2_1, 0, tau2, n_samples = n_samples)
        integral_discrete = sq_error_at(0)

        return theta * integral_continuous + (1 - theta) * integral_discrete

    # rademacher: beta_bar ~ Uniform{-1, 1}
    # -----------------------------------------------------------------------------------------
    if prior_name == "rademacher":
        integral_part1 = sq_error_at(-1)
        integral_part2 = sq_error_at(1)

        return 1/2 * (integral_part1 + integral_part2)

    # three point: beta_bar ~ {-1 : theta1, 0 : theta2, 1 : (1 - theta1 - theta2)}
    # -----------------------------------------------------------------------------------------
    theta_1, theta_2 = prior_params
    theta_3 = 1 - theta_1 - theta_2

    integral_part1 = sq_error_at(-1)
    integral_part2 = sq_error_at(0)
    integral_part3 = sq_error_at(1)

    return theta_1 * integral_part1 + theta_2 * integral_part2 + theta_3 * integral_part3
    

# helper for generating spiked matrix
def haar_orthonormal_columns(n, m):
    """Orthonormal columns from the QR factorisation of an n-by-m Gaussian matrix.

    Each column of Q is multiplied by the sign of the matching diagonal entry
    of R (a zero diagonal entry counts as +1), which removes the sign
    ambiguity of the factorisation.
    """
    gaussian = np.random.normal(size = (n, m))
    basis, upper = np.linalg.qr(gaussian)

    diag_signs = np.sign(np.diag(upper))
    # a zero diagonal entry would wipe out its column; treat it as +1
    diag_signs = np.where(diag_signs == 0, 1.0, diag_signs)

    return basis * diag_signs[np.newaxis, :]

# helper for spiked matrix generation
def generate_spiked(n, p, m = 50, alpha = 10.0, seed = None):
    """Generate an n-by-p spiked matrix: rank-m signal plus Gaussian noise.

    X = alpha * V @ W.T + Z / sqrt(n), where V (n-by-m) and W (p-by-m) come
    from ``haar_orthonormal_columns`` and Z has i.i.d. standard normal entries.

    Parameters
    ----------
    n, p : int
        Matrix dimensions.
    m : int
        Number of spikes; must not exceed ``min(n, p)``.
    alpha : float
        Scale applied to the signal part.
    seed : int or None
        When given, NumPy's global random state is seeded with it before any
        draw, so the output is reproducible. This parameter was previously
        accepted but ignored. Note that it resets the global generator that
        the other helpers in this file also draw from.

    Returns
    -------
    tuple
        ``(X, V, W, signal, noise)``.

    Raises
    ------
    ValueError
        If ``m > min(n, p)``.
    """
    if m > min(n, p):
        raise ValueError("m must be <= min(n, p)")

    if seed is not None:
        np.random.seed(seed)

    V = haar_orthonormal_columns(n, m)
    W = haar_orthonormal_columns(p, m)

    signal = alpha * (V @ W.T)
    noise = (1.0/math.sqrt(n)) * np.random.normal(size = (n, p))

    X = signal + noise

    return X, V, W, signal, noise

def generate_rri_design(n, p, k = -1, method = "lnn") :
    """Draw one n-by-p random design matrix of the requested type.

    Parameters
    ----------
    n, p : int
        Number of rows and columns.
    k : int
        Inner dimension of the "lnn" product; -1 means ``p``. Unused by the
        other methods.
    method : str
        "lnn": product of four Gaussian matrices (n-by-k, k-by-k, k-by-k,
        k-by-p) scaled by 1/sqrt(n).
        "heavy_tail": rows drawn from a multivariate t distribution with
        3 degrees of freedom, zero location and identity shape matrix.
        "spiked": the matrix returned by ``generate_spiked(n, p)``.

    Returns
    -------
    numpy.ndarray
        Array of shape (n, p).

    Raises
    ------
    ValueError
        If ``method`` is not recognised. Previously a message was printed and
        the function then failed on an unbound local.
    """
    if method == "lnn" :
        if k == -1 :
            k = p

        X_1 = np.random.normal(loc = 0, scale = 1, size = (n, k))
        X_2 = np.random.normal(loc = 0, scale = 1, size = (k, k))
        X_3 = np.random.normal(loc = 0, scale = 1, size = (k, k))
        X_4 = np.random.normal(loc = 0, scale = 1, size = (k, p))

        X = (1 / math.sqrt(n)) * X_1 @ X_2 @ X_3 @ X_4

    elif method == "heavy_tail" :
        df = 3
        mean = np.zeros(p)
        scale = np.eye(p)
        X = scipy.stats.multivariate_t.rvs(loc = mean, shape = scale, df = df, size = n)
        # rvs squeezes length-1 axes (n == 1 or p == 1); restore the 2-D shape
        X = np.asarray(X).reshape(n, p)

    elif method == "spiked" :
        X, V, W, signal, noise = generate_spiked(n, p)

    else :
        raise ValueError("invalid matrix type")

    return X


def sample_S(rri_type, n, p, num_samples = 10) :
    """Pool singular values from independently drawn design matrices.

    For each of ``num_samples`` draws of ``generate_rri_design(n, p,
    method=rri_type)`` the singular values are computed and zero-padded (or
    truncated) to exactly ``n`` entries; the per-draw vectors are then
    concatenated.

    Parameters
    ----------
    rri_type : str
        Design type forwarded as ``method`` to ``generate_rri_design``.
    n, p : int
        Dimensions of each design matrix.
    num_samples : int or float
        Number of matrices to draw; converted with ``int`` because callers in
        this file pass the float result of ``1e6 // n``.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``n * int(num_samples)`` (empty when no draw is
        requested).
    """
    draws = []

    for _ in range(int(num_samples)) :
        X = generate_rri_design(n, p, method = rri_type)

        # only the singular values are needed, so skip computing U and V
        singular_values = np.linalg.svd(X, compute_uv = False)

        # exactly n entries per draw: zeros fill the tail when p < n
        s = np.zeros(n)
        kept = min(n, singular_values.size)
        s[:kept] = singular_values[:kept]

        draws.append(s)

    if not draws :
        return np.empty(0)

    return np.concatenate(draws)



def monte_carlo_S(function, sample, retrieve = False) :
    """Average ``function`` over a pre-drawn sample.

    ``function`` is applied once to the whole ``sample`` array and must return
    an ndarray. Returns the mean of the result, or ``(mean, standard_error)``
    when ``retrieve`` is true; the standard error uses the ddof=1 standard
    deviation divided by the square root of the number of values.
    """
    evaluated = function(sample)

    mean_value = evaluated.mean()
    std_error = evaluated.std(ddof=1) / np.sqrt(evaluated.size)

    return (mean_value, std_error) if retrieve else mean_value


# MC method for 1-D gaussian, for discrete parts of beta
def gaussian_1dmc(h, mean, var, n_samples = 1000000, retrieve = False):
    """Monte-Carlo estimate of E[h(X)] for X ~ N(mean, var).

    NOTE: this is a second definition of ``gaussian_1dmc`` in this file; being
    later, it is the one in effect for calls made after this point.

    Draws ``n_samples // 2`` standard normals and mirrors them (z and -z). A
    negative ``var`` is treated as zero. The integrand's return value is
    converted with ``np.asarray`` so that an ``h`` returning a Python scalar
    or a list works, as it does with the ``np.mean``-based definition earlier
    in the file; a scalar return is treated as a constant integrand.

    Parameters
    ----------
    h : callable
        Integrand evaluated once on the whole array of draws.
    mean, var : float
        Mean and variance of the Gaussian.
    n_samples : int
        Total number of evaluation points (rounded down to an even number).
    retrieve : bool
        When True, also return the Monte-Carlo standard error.

    Returns
    -------
    The estimate, or the pair ``(estimate, standard_error)`` if ``retrieve``.
    """
    sample_size = n_samples // 2

    Z_generated = np.random.normal(size = sample_size)
    Z_symmetric = np.concatenate([Z_generated, -Z_generated])

    X = mean + np.sqrt(max(var, 0.0)) * Z_symmetric

    f_vals = np.asarray(h(X))

    # a constant integrand comes back as a 0-d value; give it one entry per draw
    if f_vals.ndim == 0:
        f_vals = np.broadcast_to(f_vals, X.shape)

    exp_estimate = f_vals.mean()

    mc_se = f_vals.std(ddof=1) / np.sqrt(f_vals.size)

    if retrieve:
        return exp_estimate, mc_se

    return exp_estimate



def compute_alpha(rri_type, n, p, sigma2, gamma_2bar) :
    """Monte-Carlo estimate of E[sigma2*gamma_2bar / (S^2 + sigma2*gamma_2bar)].

    S ranges over singular values pooled by ``sample_S`` from freshly drawn
    designs of type ``rri_type``; roughly one million values are pooled.

    Parameters
    ----------
    rri_type : str
        Design type forwarded to ``sample_S``.
    n, p : int
        Design dimensions.
    sigma2, gamma_2bar : float
        Enter only through their product.

    Returns
    -------
    float
        The estimated expectation.
    """
    ## use sigma^2 and gamma_2bar to determine sample size

    # integer replicate count (the former float 1e6 // n cannot drive range());
    # always draw at least one design, even when n exceeds the target
    target_values = 1_000_000
    num_samples = max(1, int(target_values // n))

    sample = sample_S(rri_type, n, p, num_samples)

    def function(S) :
        return sigma2 * gamma_2bar / (S**2 + sigma2 * gamma_2bar)

    return monte_carlo_S(function, sample, retrieve = False)



def compute_omega_1(rri_type, n, p, sigma2, gamma_2bar) :
    """Monte-Carlo estimate of E[(S / (S^2 + sigma2*gamma_2bar))^2].

    S ranges over singular values pooled by ``sample_S`` from freshly drawn
    designs of type ``rri_type``; roughly one million values are pooled.

    Parameters
    ----------
    rri_type : str
        Design type forwarded to ``sample_S``.
    n, p : int
        Design dimensions.
    sigma2, gamma_2bar : float
        Enter only through their product.

    Returns
    -------
    float
        The estimated expectation.
    """
    ## use sigma^2 and gamma_2bar to determine sample size

    # integer replicate count (the former float 1e6 // n cannot drive range());
    # always draw at least one design, even when n exceeds the target
    target_values = 1_000_000
    num_samples = max(1, int(target_values // n))

    sample = sample_S(rri_type, n, p, num_samples)

    def function(S) :
        return (S / (S**2 + sigma2 * gamma_2bar)) ** 2

    return monte_carlo_S(function, sample, retrieve = False)



def compute_omega_2(rri_type, n, p, sigma2, gamma_2bar) :
    """Monte-Carlo estimate of E[(sigma2*gamma_2bar / (S^2 + sigma2*gamma_2bar))^2].

    S ranges over singular values pooled by ``sample_S`` from freshly drawn
    designs of type ``rri_type``; roughly one million values are pooled.

    Parameters
    ----------
    rri_type : str
        Design type forwarded to ``sample_S``.
    n, p : int
        Design dimensions.
    sigma2, gamma_2bar : float
        Enter only through their product.

    Returns
    -------
    float
        The estimated expectation.
    """
    ## use sigma^2 and gamma_2bar to determine sample size

    # integer replicate count (the former float 1e6 // n cannot drive range());
    # always draw at least one design, even when n exceeds the target
    target_values = 1_000_000
    num_samples = max(1, int(target_values // n))

    sample = sample_S(rri_type, n, p, num_samples)

    def function(S) :
        return (sigma2 * gamma_2bar / (S**2 + sigma2 * gamma_2bar)) ** 2

    return monte_carlo_S(function, sample, retrieve = False)



# Type aliases written with the `type` statement (requires Python 3.12+).
# prior:  (prior_name, prior_params) pair, unpacked that way by the
#         compute_E_*_bayes helpers.
# design: (design_type_name, extra) pair; the name is what ends up as the
#         `method` argument of generate_rri_design.
type prior = tuple[str, tuple]
type design = tuple[str, tuple]

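# prior_info is a (prior_name, prior_params) pair. Prior names implemented by the helpers above, with their params:
#   "gaussian": (tau2, None) | "rademacher": None | "three_point": (theta_1, theta_2) | "bernoulli_gaussian": (theta, tau2)
# laplace, with params (theta, None), is also listed here, but none of the helpers above implements it.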
def vamp_bayes_se(design : design, prior_info : prior, n, p, error_var : float, iter = 100, precision = 1e8, retrieve = False) :
    """Run the Bayes VAMP state-evolution recursion.

    Each iteration has two halves:

    1. Denoiser half: from (gamma_1bar, kappa2_1) compute
       b_bar = E[f'], the denoiser error E_f2, and then kappa2_2 and
       gamma_2bar.
    2. Linear half: from gamma_2bar and sampled singular values compute
       c_bar (= alpha), omega_1, omega_2, and then the next kappa2_1 and
       gamma_1bar.

    Fixes relative to the previous version of this function:
    * ``c_bar_k`` was read before it was ever assigned; it is now set to
      ``alpha_k`` directly.
    * ``kappa2_2k`` was appended to the ``kappa2_1`` history; it now goes to
      ``kappa2_2``.
    * ``compute_omega_2`` was given the whole ``gamma_2bar`` history list
      instead of the current value.
    * The ``kappa2_1`` update subtracts ``c_bar_k**2 * kappa2_2k``, mirroring
      the ``b_bar_k**2 * kappa2_1k`` term of the first half (the square was
      missing).
    * ``1 / c_bar_k`` is floored like ``1 / b_bar_k`` already was.

    Parameters
    ----------
    design : tuple
        ``(design_type_name, extra)``; only the name is used, and it is
        forwarded to the singular-value sampling helpers.
    prior_info : tuple
        ``(prior_name, prior_params)`` forwarded to the compute_E_*_bayes
        helpers.
    n, p : int
        Design dimensions used when sampling singular values.
    error_var : float
        Noise variance sigma^2.
    iter : int
        Number of iterations.
    precision : int or float
        Forwarded to the compute_E_*_bayes helpers.
    retrieve : bool
        When True return the full histories, otherwise only the last stored
        ``gamma_1bar`` and ``kappa2_1``.

    Returns
    -------
    tuple
        ``(gamma_1bar, gamma_2bar, kappa2_1, kappa2_2, b_bar, c_bar)`` as
        lists if ``retrieve``; otherwise ``(gamma_1bar[-1], kappa2_1[-1])``,
        i.e. the values that entered the final iteration. The values computed
        at the end of the final iteration are not stored. With ``iter == 0``
        the initial values are returned.
    """
    # keep the SE params in arrays
    b_bar = []
    kappa2_2 = []
    gamma_2bar = []
    c_bar = []
    kappa2_1 = []
    gamma_1bar = []

    # initializations
    sigma2 = error_var
    rri_type, _ = design

    kappa2_1k = 0.1
    gamma_1bar_k = 1

    # -----------------------------------------------------------------------------------------

    for _ in range(iter) :
        # store the values entering this iteration
        gamma_1bar.append(gamma_1bar_k)
        kappa2_1.append(kappa2_1k)

        # -----------------------------------------------------------------------------------------

        # updates part 1

        # update (b_bar)_k via function to compute the expectation of f'
        b_bar_k = compute_E_fprime_bayes(prior_info, gamma_1bar_k, kappa2_1k, precision)

        # denoiser error term
        E_f2 = compute_E_f2_bayes(prior_info, gamma_1bar_k, kappa2_1k, precision)

        C2_1 = (1 / (1 - b_bar_k)) ** 2
        kappa2_2k = C2_1 * (E_f2 - b_bar_k**2 * kappa2_1k)

        # update (gamma_bar)_2 k
        gamma_2bar_k = gamma_1bar_k * (1/max(b_bar_k, 1e-16) - 1)

        # update c_bar_k: the divergence of the linear step is alpha_k itself
        # NOTE(review): each compute_* call below draws its own fresh set of
        # design matrices; sharing one sample per iteration would be cheaper.
        alpha_k = compute_alpha(rri_type, n, p, sigma2, gamma_2bar_k)

        c_bar_k = alpha_k

        # -----------------------------------------------------------------------------------------

        # interlude before updating the next round
        # append the new state evolution params to the vectors
        b_bar.append(b_bar_k)
        kappa2_2.append(kappa2_2k)
        gamma_2bar.append(gamma_2bar_k)
        c_bar.append(c_bar_k)

        # -----------------------------------------------------------------------------------------

        # updates part 2
        # now continue updating next round

        # update kappa^2_{1, k + 1}
        omega_1k = compute_omega_1(rri_type, n, p, sigma2, gamma_2bar_k)
        omega_2k = compute_omega_2(rri_type, n, p, sigma2, gamma_2bar_k)

        C2_2 = (1 / (1-c_bar_k)) ** 2

        kappa2_1k = C2_2 * (omega_1k * sigma2 + omega_2k * kappa2_2k - c_bar_k**2 * kappa2_2k)

        # update (gamma_bar)_1 k+1
        gamma_1bar_k = gamma_2bar_k * (1/max(c_bar_k, 1e-16) - 1)

    if retrieve :
        return gamma_1bar, gamma_2bar, kappa2_1, kappa2_2, b_bar, c_bar

    if not gamma_1bar :
        # no iteration ran: fall back to the initial values
        return gamma_1bar_k, kappa2_1k

    return gamma_1bar[-1], kappa2_1[-1]