# MCMC for Bayesian Variable Selection

## Initialisation

In [None]:
import numpy as np
import random
import pandas as pd
# mport scipy
from scipy import special
from scipy import stats
import time
import math
import matplotlib.pyplot as plt
from tqdm import tqdm

## BVS_MCMC

In [None]:
class BVS_MCMC():
    def __init__ (self, X, y, g, 
                  h_exp_size = 5, h_type = 2, h = None, Z = None, 
                  model = "linear", prior_type = "ind", scale = True, ddof = 0):
        
        # computing n and p
        self.n, self.p = X.shape
        
        
        if isinstance(g, (int, float)) :
            self.g = g
            self.random_g = False
        elif g == "random":
            self.random_g = True
            self.log_prior_g = self.log_half_cauchy_wj
        
        
        self.model = model
        self.prior_type = prior_type
                
        
        
        if model == "linear":
            
            self.sample_now = self.linear_sample_now
            self.set_alg_par = self.linear_set_alg_par
            
            if Z is None:
                
                self.X = X - X.mean(axis = 0)
                self.y = y - y.mean()
            
                if scale:
                    self.X = self.X/self.X.std(axis = 0, ddof = ddof)
            
            else:
            
                Z = np.hstack((np.ones((self.n,1)), Z))
                
                self.y = (np.eye(self.n) - Z.dot(np.linalg.inv(Z.T.dot(Z))).dot(Z.T)).dot(y)
                self.X = (np.eye(self.n) - Z.dot(np.linalg.inv(Z.T.dot(Z))).dot(Z.T)).dot(X)
            
                if scale:
                    self.X = self.X/self.X.std(axis = 0, ddof = ddof)
                    
            self.yty = np.sum(self.y**2)
            self.ytX = y.T.dot(self.X)
            self.diag_XtX = np.sum(self.X**2, axis = 0)
            
            if prior_type == "g":
                self.log_llh = self.linear_log_llh_g
                self.compute_bf = self.linear_compute_bf_g
            elif prior_type == "ind":
                # if not self.random_g :
                #     self.diag_V = self.diag_XtX + 1/self.g
                self.log_llh = self.linear_log_llh_ind
                self.compute_bf = self.linear_compute_bf_ind
                
            # model prior with hyperparameter h
            # h is the highest parameter for setting the prior
            if h is not None:
                
                self.h = h
                
                if (type(h) == float):
                    self.h_exp = h # * self.p
                    self.h_odd = self.h_odd_1
                    self.log_m_prior = self.log_m_prior_1
                    self.h_til = self.h_til_1
            
                elif ((type(h) == list) & (len(h) == 2)):
                    self.h_exp = h[1]/(h[1] + h[2])
                    self.h_odd = self.h_odd_2
                    self.log_m_prior = self.log_m_prior_2
                    self.h_til = self.h_til_2
            
            # if h is not defined, using h_exp and h_type to indicate prior for model
            elif h is None:
                if h_type == 1:
                    self.h_exp = h_exp_size/self.p
                    self.h = self.h_exp
                    self.h_odd = self.h_odd_1
                    self.log_m_prior = self.log_m_prior_1
                    self.h_til = self.h_til_1
                    
                elif h_type == 2:
                    self.h_exp = h_exp_size/self.p
                    self.h = [1, (1-self.h_exp)/self.h_exp]
                    self.h_odd = self.h_odd_2
                    self.log_m_prior = self.log_m_prior_2
                    self.h_til = self.h_til_2
                
                
                
            
            
            
        
    # model prior (related functions)
    def h_odd_1 (self, I, *_):
        return(self.h/(1-self.h) if I else (1-self.h)/self.h)
        
    def log_m_prior_1 (self, gamma_par):
        gamma_par.log_mp = gamma_par.p_gam * (math.log(self.h) - math.log(1-self.h))
        
    def h_til_1 (self, gamma_par):
        gamma_par.h_til = self.h
        
            
            
    def h_odd_2 (self, I, gamma_par):
        return((gamma_par.p_gam+self.h[0])/(self.p-gamma_par.p_gam-1+self.h[1]) if I else (self.p-gamma_par.p_gam+self.h[1])/(gamma_par.p_gam-1+self.h[0]))
        
    def log_m_prior_2 (self, gamma_par):
        gamma_par.log_mp = special.betaln(gamma_par.p_gam+self.h[0], self.p-gamma_par.p_gam+self.h[1])
        
    def h_til_2 (self, gamma_par):
        h_til = (gamma_par.p_gam - gamma_par.gamma + self.h[0])/(self.p + self.h[0] + self.h[1] - 1)
        gamma_par.h_til = h_til
        
        
    
    
    # log-likelihood for linear model
    def linear_log_llh_g (self, gamma_par):


        if gamma_par.p_gam == 0:
            A = self.yty
            gamma_par.log_llh = - self.n * math.log(self.yty)/2
        else:
            X_gam = self.X[:,gamma_par.includes]
            V_gam = X_gam.T.dot(X_gam)
            gamma_par.inv_V_gam = np.linalg.inv(V_gam)

            ytX_gam = self.ytX[0, gamma_par.includes]

            ytXgFXgty = ytX_gam.dot(gamma_par.inv_V_gam).dot(ytX_gam.T) 

            A = self.yty - self.g/(1+self.g) * ytXgFXgty
            gamma_par.log_llh = - gamma_par.p_gam/2 * math.log(1+self.g) - self.n * math.log(A)/2
            
        


    def linear_log_llh_ind (self, gamma_par):
        

        if gamma_par.p_gam == 0:

            gamma_par.log_llh = - self.n * math.log(self.yty)/2

        else:
            X_gam = self.X[:,gamma_par.includes]
            V_gam = X_gam.T.dot(X_gam)
            V_gam[np.diag_indices(gamma_par.p_gam)] += 1/self.g
            gamma_par.inv_V_gam = np.linalg.inv(V_gam)
            
            ytX_gam = self.ytX[0, gamma_par.includes]
            ytXgFXgty = ytX_gam.dot(gamma_par.inv_V_gam).dot(ytX_gam.T)

            A = self.yty - ytXgFXgty

            sqrt_det_Vg = math.log(np.linalg.det(V_gam))/2 # (math.log(np.diag(L_Vg))).sum()

            gamma_par.log_llh = - gamma_par.p_gam/2 * math.log(self.g) - sqrt_det_Vg - self.n * math.log(A)/2
        
    
    class make_gamma_par():
        """Container for one model indicator vector plus its cached quantities.

        ``gamma`` is the 0/1 (or boolean) inclusion vector. ``p_gam``,
        ``includes`` and ``excludes`` are derived from it immediately. The other
        attributes start as ``None`` and are filled in by the ``BVS_MCMC``
        likelihood, prior, ``h_til`` and Bayes-factor methods.
        """

        def __init__ (self, gamma):
            self.gamma = gamma
            # model size and index sets of included / excluded covariates
            self.p_gam = gamma.sum(dtype = int)
            self.includes = np.flatnonzero(gamma)
            self.excludes = np.flatnonzero(gamma == 0)
            # cached quantities, set later by BVS_MCMC methods
            for attr_name in ("log_llh", "log_mp", "inv_V_gam", "BF", "h_til"):
                setattr(self, attr_name, None)
            # more variables
            
            
    def update_gamma_par(self, gamma_par):
        self.log_llh(gamma_par)
        self.log_m_prior(gamma_par)
        
    
    def update_llh_gamma_par(self, gamma_par):
        self.log_llh(gamma_par)
        
    
    def init_gamma_par (self):
        
        if self.gamma_init is None:
            gamma = np.random.random(self.p) < self.h_exp
        else:
            gamma = self.gamma_init
            
        gamma_par = self.make_gamma_par(gamma)
        
        self.update_gamma_par(gamma_par)
        
        return(gamma_par)               
    
    
    # with Jacobian for log-transformation
    def log_half_cauchy_wj(self, g):
        """Half-Cauchy(0, 1) log-density of g, up to a constant, including the log-scale Jacobian.

        Returns ``log(g) - log(1 + g^2)``; the ``+log(g)`` term is the Jacobian
        for the random walk on ``log(g)`` in ``update_g``. For g > 1 the
        equivalent form ``-log(g) - log1p(g^-2)`` is used so that ``g**2``
        cannot overflow.
        """
        log_g = math.log(g)
        if g > 1:
            inv_g = 1.0/g
            return -log_g - math.log1p(inv_g * inv_g)
        return log_g - math.log1p(g * g)
    
    
    def update_g(self, gamma_par):
        
        # temp_g = copy.deepcopy(self.g)
        temp_g = self.g
        
        # update g
        self.g *=  math.exp(math.exp(self.g_logsd[self.t]) * np.random.normal())
        
        # cases where g goes extreme
        if (self.g == 0) or (np.isinf(self.g)):
            g_acc_rate = 0
        else:
            
            # new_gamma_par = copy.deepcopy(gamma_par)  # **too slow**
            new_gamma_par = self.make_gamma_par(gamma_par.gamma.copy())
            new_gamma_par.log_mp = gamma_par.log_mp
            self.update_llh_gamma_par(new_gamma_par)
            
            g_acc_rate = min(1, math.exp(self.PT_temps[self.t] * (new_gamma_par.log_llh- gamma_par.log_llh) + self.log_prior_g(self.g)  - self.log_prior_g(temp_g)))
            # g_acc_rate = 1
            
        
        # accept or reject the new_g
        if np.random.random() < g_acc_rate:
            
            gamma_par = new_gamma_par
            
            # store acc_times after burn-in
            if self.epoch > self.N_burnin:
                self.g_acc_times[self.t] += 1
                
        else:
            self.g = temp_g
            
        
        # record g
        self.g_matrix[self.t, self.i] = self.g
        self.gs[self.t, self.i, self.epoch] = self.g
        
        # adapt g_logsd
        self.g_logsd[self.t] += self.epoch**self.g_adapt_phi * (g_acc_rate - self.g_target_acc)# /self.n_chain
        
        
        return(gamma_par)
    
        
    def linear_compute_bf_g(self, gamma_par):

        g_ratio = self.g/(self.g+1)
        inv_sqrtg1 = 1/math.sqrt(self.g+1)
        n_power = self.n/2


        if gamma_par.p_gam == 0:

            BF = (self.yty/(self.yty - g_ratio * self.ytX[0]**2 / self.diag_XtX))**n_power * inv_sqrtg1

        else:

            X_gam = self.X[:,gamma_par.includes]
            XgtX = X_gam.T.dot(self.X)

            ytXg = self.ytX[0, gamma_par.includes]
            ytXgFXgty = ytXg.dot(gamma_par.inv_V_gam).dot(ytXg.T) 

            A = self.yty - ytXgFXgty * g_ratio
            d_vec = 1 / (self.diag_XtX - np.einsum('ij,ji->i', XgtX.T.dot(gamma_par.inv_V_gam), XgtX))

            ytXFXtxj_vec = ytXg.dot(gamma_par.inv_V_gam).dot(XgtX) 

            tilda_A_vec = A - d_vec * (ytXFXtxj_vec - self.ytX)**2 * g_ratio

            BF = (A/tilda_A_vec.reshape(self.p))**n_power * inv_sqrtg1

            if gamma_par.p_gam == 1:

                BF[gamma_par.includes] = (self.yty/A)**n_power * inv_sqrtg1

            else:
                zj = -1
                for j in gamma_par.includes :

                    zj = zj + 1
                    A_ratio = 1 + (ytXg.dot(gamma_par.inv_V_gam[:,zj]))**2 * g_ratio / (A * gamma_par.inv_V_gam[zj,zj])
                    BF[j] = A_ratio**n_power * inv_sqrtg1

        gamma_par.BF = BF
            
    
    def linear_compute_bf_ind(self, gamma_par):

        inv_g = 1/self.g
        inv_sqrt_g = math.sqrt(inv_g)
        n_power = self.n/2

        diag_V = self.diag_XtX + 1/self.g

        if gamma_par.p_gam == 0:

            BF = (self.yty/(self.yty - self.ytX[0]**2 / diag_V))**n_power * np.sqrt(1/diag_V) * inv_sqrt_g 

        else:

            X_gam = self.X[:,gamma_par.includes]
            XgtX = X_gam.T.dot(self.X)

            ytXg = self.ytX[0, gamma_par.includes]
            ytXgFXgty = ytXg.dot(gamma_par.inv_V_gam).dot(ytXg.T) 

            A = self.yty - ytXgFXgty
            d_vec = 1 / (diag_V - np.einsum('ij,ji->i', XgtX.T.dot(gamma_par.inv_V_gam), XgtX))
            # the above is equivalent to
            # d_vec = 1 / (diag_V - (XgtX.T.dot(gamma_par.inv_V_gam).T * XgtX).sum(axis = 0))


            d_vec[gamma_par.includes] = 0

            ytXFXtxj_vec = ytXg.dot(gamma_par.inv_V_gam).dot(XgtX)
            tilda_A_vec = A - d_vec * (ytXFXtxj_vec - self.ytX)**2

            BF = np.sqrt(d_vec) * (A/tilda_A_vec.reshape(self.p))**n_power * inv_sqrt_g

            if gamma_par.p_gam == 1:
                BF[gamma_par.includes] = math.sqrt(gamma_par.inv_V_gam) * (self.yty/A)**n_power * inv_sqrt_g
            else:
                zj = -1
                for j in gamma_par.includes :

                    zj = zj + 1
                    A_ratio = (A + (ytXg.dot(gamma_par.inv_V_gam[:,zj]))**2/gamma_par.inv_V_gam[zj,zj])/A
                    BF[j] = math.sqrt(gamma_par.inv_V_gam[zj,zj]) * A_ratio**n_power * inv_sqrt_g


        gamma_par.BF = BF
    
    
    
    def update_temperatures (self):
        self.PT_temps = np.append(1, np.exp(-np.exp(self.PT_taus))).cumprod()
    
    
    def PT_swap_chain (self):
        # sample swap index
        swap_idx = np.random.randint(self.n_temp - 1)
        swap_acc_rate = math.exp((self.PT_temps[swap_idx] - self.PT_temps[swap_idx+1]) * (self.agg_gamma_par[swap_idx+1,self.i].log_llh - self.agg_gamma_par[swap_idx,self.i].log_llh))

        # accept or reject the swap
        if np.random.random() < swap_acc_rate:

            # swap gamma_par if accept
            temp_gamma_par = self.agg_gamma_par[swap_idx+1,self.i]
            self.agg_gamma_par[swap_idx+1,self.i] = self.agg_gamma_par[swap_idx,self.i]
            self.agg_gamma_par[swap_idx,self.i] = temp_gamma_par

            # store acc_times after burn-in
            if self.epoch > self.N_burnin:
                self.PT_acc_times += 1
    
    def PT_adapt_temperatures(self):
        self.PT_H *= self.PT_temps[:(self.n_temp-1)] - self.PT_temps[1:]
        self.PT_taus += self.epoch**self.PT_adapt_phi * (np.minimum(1, np.exp(self.PT_H)).mean(axis = 0) - self.PT_target_acc)
        self.update_temperatures()
        self.PT_temperatures[:, self.epoch] = self.PT_temps
    
    
    
    def linear_set_alg_par (self, sampler = "PARNI", N_iter = 500, N_burnin = 500, n_chain = 1, 
                            # parameters for adaptation
                            N_adapt_PIPs = None, N_rb = None, use_rb = True, kappa = 0.001, 
                            # parameters for PARNI
                            PARNI_omega_init = 0.5, PARNI_bal_fun = "hastings",
                            PARNI_omega_adapt = "RM", PARNI_adapt_phi = -0.7, PARNI_target_acc = 0.65,
                            # PARNI_omega_adapt = "KW", PARNI_adapt_phi = [-1, -0.5],
                            # parameters for ASI
                            ASI_adapt_phi = -0.7, ASI_target_acc = 0.234, ASI_zeta_init = 0.5,
                            # parameters for g
                            g_adapt_phi = -0.7, g_target_acc = 0.234, g_init = None,
                            # parameters for parallel-tempering 
                            n_temp = 1, PT_taus = None, PT_adapt_phi = -0.7, PT_target_acc = 0.234,
                            verbose = False, store_chains = False, gamma_init = None, f = None, *_):
        
        
        self.N_iter = N_iter
        self.N_burnin = N_burnin
        self.N_total = N_iter + N_burnin
        
        # need correcting
        self.use_rb = use_rb
        self.N_rb = N_burnin
        # self.N_adapt_PIPs = N_burnin
        
        
        self.n_chain = n_chain
        self.n_temp = n_temp
        
        if N_adapt_PIPs is None:
            self.N_adapt_PIPs = N_burnin
        
        
        
        self.sampler = sampler
        
        if sampler == "PARNI":
            if PARNI_bal_fun == "hastings":
                self.bal_fun = self.hastings
            elif PARNI_bal_fun == "barker":
                self.bal_fun = self.barker
            elif PARNI_bal_fun == "sqrt":
                self.bal_fun = math.sqrt
            else:
                self.bal_fun = PARNI_bal_fun
            
            self.propose = self.PARNI_propose
            self.adapt = True
            self.kappa = kappa
            self.logit_eps = 0.1/self.p
            
            if PARNI_omega_adapt == "RM":
                self.update_par = self.PARNI_update_par_rm
                self.adapt_phi = PARNI_adapt_phi
                self.target_acc = PARNI_target_acc
                self.PARNI_init_par_rm(PARNI_omega_init)
            
            else:
                self.update_par = self.PARNI_update_par_kw
                self.adapt_phi_a = PARNI_adapt_phi[0]
                self.adapt_phi_c = PARNI_adapt_phi[1]
                self.PARNI_init_par_kw(PARNI_omega_init)
                if n_chain == 1:
                    raise ValueError("Number of chains must be larger than 1 for Kiefer-Wolfowitz scheme.")
            
            
            
        elif sampler == "ASI":
            self.propose = self.ASI_propose
            self.update_par = self.ASI_update_par
            self.adapt = True
            self.kappa = kappa
            self.adapt_phi = ASI_adapt_phi
            self.target_acc = ASI_target_acc
            self.logit_eps = 0.1/self.p
            self.ASI_init_par(ASI_zeta_init)
            
            
        elif sampler == "ADS":
            self.propose = self.ADS_propose
            self.adapt = False
        
        
        if f is not None:
            self.eval_f = True
            self.f = f
            self.f_sum = np.zeros(self.f(np.zeros(self.p)).shape)
            self.fs = np.zeros((n_chain, self.N_total+1) + self.f_sum.shape)
        else:
            self.eval_f = False
            
        
        
        self.gamma_init = gamma_init
        
        self.verbose = verbose
        self.store_chains = store_chains
        
        
        # initialisations:
        # initialise the chain if self.store_chains is True
        if self.store_chains:
            self.chains = np.zeros(shape = (self.n_chain, self.N_total+1, self.p))
            
        # initialise vectors for log-posterior and model-size
        self.log_posts = np.zeros(shape = (self.n_temp, self.n_chain,  self.N_total+1))
        self.model_sizes = np.zeros(shape = (self.n_temp, self.n_chain,  self.N_total+1))
        
        # initialise acceptance times and ASJD
        self.acc_times = np.zeros(self.n_temp) 
        self.ASJD = np.zeros(self.n_temp) 
        
        
        # initialise PIPs
        self.estm_PIPs = np.zeros(shape = (self.n_temp, self.p))
        
        
        # initialisation for parallel tempering
        if (self.n_temp > 1):
            
            if PT_taus is not None:
                self.PT_taus = PT_taus
            else:
                self.PT_taus = -4 * np.ones(self.n_temp-1)
                
            # initialise temperature
            self.update_temperatures()
            self.PT_temperatures = np.zeros((self.n_temp, self.N_total+1))
            self.PT_temperatures[:,0] = self.PT_temps
            
            # initilise accemptance times for swapping
            self.PT_acc_times = 0
            
            # other quantities
            self.PT_target_acc = PT_target_acc
            self.PT_adapt_phi = PT_adapt_phi
            self.PT_H = np.zeros((self.n_chain, self.n_temp-1))
            
        else:
            self.PT_temps = np.ones(1)
        
        
        
        
        # ionitilise g
        if self.random_g:
            if isinstance(g_init, (int, float)):
                self.g_matrix = g_init * np.ones((self.n_temp, self.n_chain))
            else:
                self.g_matrix = stats.halfcauchy.rvs(size = (self.n_temp, self.n_chain))
            self.gs = np.zeros((self.n_temp, self.n_chain, self.N_total+1))
            self.g_logsd = np.ones(self.n_temp)
            self.g_acc_times = np.ones(self.n_temp)
            self.g_adapt_phi = g_adapt_phi
            self.g_target_acc = g_target_acc
            

        # initalise aggregate gamma_par (starting point)
        self.agg_gamma_par = {}
        
        for t in range(self.n_temp):
            
            # create a dictionary to store the current gamma
            # self.agg_gamma_par[t] = {}

            for i in range(self.n_chain):
                
                
                # assign g first
                if self.random_g:
                    self.g = self.g_matrix[t,i]
                    self.gs[t,i,0] = self.g
                
                self.agg_gamma_par[t,i] = self.init_gamma_par()
                
                
                if (t == 0) & self.store_chains:
                    self.chains[i,0] = self.agg_gamma_par[t,i].gamma
                
                if (t == 0) & self.eval_f:
                    self.fs[i,0] = self.f(self.agg_gamma_par[t,i].gamma)
                
                
                self.log_posts[t,i,0] =  self.agg_gamma_par[t,i].log_mp + self.PT_temps[t] * self.agg_gamma_par[t,i].log_llh
                self.model_sizes[t,i,0] = self.agg_gamma_par[t,i].p_gam
    
    
    
    
    def linear_sample_now(self):
        
        
        
        # time count 1
        time_count_1 = time.time()
        
        # initialise progress bar
        if self.verbose:
            pbar = tqdm(total = self.N_total, position=0, leave=True)
        
        for epoch in range(1,(self.N_total+1)):
            
            self.epoch = epoch
            
            if self.verbose:
                pbar.update(1)
                

            for i in range(self.n_chain):
                
                self.i = i
                    
                
                # swap between temperatures
                if self.n_temp > 1:
                    self.PT_swap_chain()
                

                for t in range(self.n_temp):
                    
                    self.t = t
                    
                    
                    # assign curr_gamma_par for current chain (i) and temperatures (t)
                    curr_gamma_par = self.agg_gamma_par[t,i]
                    
                    
                    if self.random_g:
                        
                        # assign g if random_g is True
                        self.g = self.g_matrix[t,i]
                        
                        # update g
                        self.agg_gamma_par[t,i] = curr_gamma_par = self.update_g(curr_gamma_par)
                        
                    
                    
                    prop_gamma_par, log_prop_odd, change = self.propose(curr_gamma_par)

                    curr_log_post = curr_gamma_par.log_mp + self.PT_temps[t] * curr_gamma_par.log_llh
                    prop_log_post = prop_gamma_par.log_mp + self.PT_temps[t] * prop_gamma_par.log_llh
                    
                    # accept probability
                    acc_rate = min(1, math.exp(prop_log_post - curr_log_post + log_prop_odd))
                    
                    # accept or reject the proposal
                    if np.random.random() < acc_rate:
                        # accecpt new model
                        curr_gamma_par = prop_gamma_par
                        self.agg_gamma_par[t,i] = prop_gamma_par
                        accepted = True
                    else:
                        accepted = False       
                    
                        
                        
                    if self.adapt:
                        self.update_par(curr_gamma_par, acc_rate, change)
                        # ...
                        
                        

                    # update
                    self.log_posts[t,i,epoch] = curr_gamma_par.log_mp + self.PT_temps[t] * curr_gamma_par.log_llh
                    self.model_sizes[t,i,epoch] = curr_gamma_par.p_gam

                    if epoch > self.N_burnin:
                        self.acc_times[t] += accepted
                        self.ASJD[t] += change * acc_rate
                        self.estm_PIPs[t] += curr_gamma_par.gamma

                    
                    
                    if t == 0:
                        if self.store_chains:
                            self.chains[i, epoch] = curr_gamma_par.gamma
                        if self.eval_f:
                            f_curr = self.f(curr_gamma_par.gamma)
                            self.fs[i,epoch] = f_curr
                            self.f_sum += f_curr
                    
                    
                    # update PT_H for adapting temperatures
                    if self.n_temp > 1:
                        if t == 0:
                            self.PT_H[i,t] = -curr_gamma_par.log_llh
                        elif t == (self.n_temp-1):
                            self.PT_H[i,t-1] += curr_gamma_par.log_llh
                        else:
                            self.PT_H[i,t] = -curr_gamma_par.log_llh
                            self.PT_H[i,t-1] += curr_gamma_par.log_llh
                        
                    
                    

            if self.n_temp > 1:
                
                # adapt temperatures
                self.PT_adapt_temperatures()
                
                
            
            # time count point 2 for burn-in period
            if epoch == (self.N_burnin):
                time_count_2 = time.time()


            
        time_count_3 = time.time()
        
        if self.verbose:
            pbar.close()
        
        # CPU times
        self.time_total = time_count_3 - time_count_1
        self.time_burnin = time_count_2 - time_count_1
        self.time_sample = time_count_3 - time_count_2
        
        # constant
        N_const = self.n_chain * self.N_iter
        # N_const_2 = self.n_chain * self.N_iter * self.n_temp
        
        self.ASJD /= N_const
        self.acc_times /= N_const
        self.estm_PIPs /= N_const
        
        if self.n_temp > 1:
            self.PT_acc_times/= N_const
        
        
        if self.eval_f:
            self.f_sum /= N_const
            
        if self.random_g:
            self.g_acc_times /= N_const
    
    
    
    
    
    def ADS_propose(self, gamma_par):
        
        if (0 < gamma_par.p_gam < self.p):
            
            if np.random.random() < 1/3:
                # swap
                del_index = gamma_par.includes[np.random.randint(gamma_par.p_gam)]
                add_index = gamma_par.excludes[np.random.randint(self.p - gamma_par.p_gam)]
                
                
                gamma_prop = gamma_par.gamma.copy()
                gamma_prop[del_index] -= 1
                gamma_prop[add_index] += 1
                
                log_prob_prop = 0
                log_prob_rev = 0
                
                change = 2
                
            elif np.random.random() < 1/2:
                # add
                add_index = gamma_par.excludes[np.random.randint(self.p - gamma_par.p_gam)] 
                
                gamma_prop = gamma_par.gamma.copy()
                gamma_prop[add_index] += 1
                
                log_prob_prop = -math.log(self.p - gamma_par.p_gam)
                log_prob_rev = -math.log(gamma_par.p_gam + 1)
                
                change = 1
                
            else:
                # delete
                del_index = gamma_par.includes[np.random.randint(gamma_par.p_gam)]
                
                gamma_prop = gamma_par.gamma.copy()
                gamma_prop[del_index] -= 1
                
                log_prob_prop = -math.log(gamma_par.p_gam)
                log_prob_rev = -math.log(self.p - gamma_par.p_gam + 1)
                
                change = 1
                
        elif gamma_par.p_gam == 0:
            # add
            add_index = gamma_par.excludes[np.random.randint(self.p - gamma_par.p_gam)] # 
            
            gamma_prop = gamma_par.gamma.copy()
            gamma_prop[add_index] += 1
                
            log_prob_prop = -math.log(self.p)
            log_prob_rev = 0
            
            change = 1
            
        elif gamma_par.p_gam == self.p:
            # delete
            del_index = gamma_par.includes[np.random.randint(gamma_par.p_gam)] 
            
            gamma_prop = gamma_par.gamma.copy()
            gamma_prop[del_index] -= 1
                
            log_prob_prop = -math.log(self.p)
            log_prob_rev = 0
            
            change = 1
        
        
        prop_gamma_par = self.make_gamma_par(gamma_prop)
        self.update_gamma_par(prop_gamma_par)
        
        return(prop_gamma_par, log_prob_rev - log_prob_prop, change)
    
    

    def AD_sample (self, new_sample, AD_prob, sample = None, prob_in_log = True):
        
        if new_sample:
            sample = np.where(np.random.random(self.p) < AD_prob)[0] 
    
        if prob_in_log:
            prob = np.log(AD_prob[sample]).sum()
        else:
            prob = AD_prob[sample].prob()
            
        return(sample, prob)
    
    
    
    
    
        
    def ASI_propose (self, gamma_par):
        
        AD_prob = (1-gamma_par.gamma) * self.A[self.t] + gamma_par.gamma * self.D[self.t]
        
        which_change, log_prob_prop = self.AD_sample(True, AD_prob, None, True)
        
        # propose new gamma
        gamma_prop = gamma_par.gamma.copy()
        gamma_prop[which_change] = 1 - gamma_prop[which_change]
        
        # AD_prob_prop = (1-gamma_prop) * self.A[self.t,:] + gamma_prop * self.D[self.t,:]
        # which_change, log_prob_rev = self.AD_sample(False, AD_prob_prop, which_change, True)
        AD_prob[which_change] = (1-gamma_prop[which_change]) * self.A[self.t,which_change] + gamma_prop[which_change] * self.D[self.t,which_change]
        which_change, log_prob_rev = self.AD_sample(False, AD_prob, which_change, True)
        
        prop_gamma_par = self.make_gamma_par(gamma_prop)
        self.update_gamma_par(prop_gamma_par)
        
        return(prop_gamma_par, log_prob_rev - log_prob_prop, which_change.size)
    
    
    
    def inv_logit_e (self, x) :
        if x < 0:
            y = (self.logit_eps + (1-self.logit_eps)*math.exp(x)) / (1 + math.exp(x))
        else:
            y = (self.logit_eps*math.exp(-x) + 1 - self.logit_eps) / (1 + math.exp(-x))

        return(y)
    
    def inv_logit_e_vec (self, x) :
        y = np.empty_like(x)
        y[x <= 0] = (self.logit_eps + (1-self.logit_eps)*np.exp(x[x <= 0])) / (1 + np.exp(x[x <= 0]))
        y[x > 0] = (self.logit_eps + (1-self.logit_eps)*np.exp(x[x > 0])) / (1 + np.exp(x[x > 0])) 
        return(y)
    
    
    def logit_e (self, y) :
        # math.log(y - self.logit_eps) - log(1 - y - self.logit_eps)
        return(math.log(y - self.logit_eps) - math.log(1 - y - self.logit_eps))
    
    


    
    def ASI_init_par (self, zeta_init):
        
        self.adapt_PIPs = np.zeros((self.n_temp,self.p))
        self.sum_adapt_PIPs = np.zeros((self.n_temp,self.p))
        
        self.A_til = np.ones((self.n_temp,self.p)) * (self.h_exp/(1-self.h_exp))
        self.D_til = np.ones((self.n_temp,self.p))
        
        self.A = self.A_til.copy()
        self.D = self.A_til.copy()
        
        self.temp_PIPs = np.zeros((self.n_temp,self.n_chain,self.p))
        self.temp_acc_rates = np.zeros((self.n_temp,self.n_chain))
        
        logitzeta_init = self.logit_e(zeta_init)
        self.logitzeta = np.ones(self.n_temp) * logitzeta_init
        self.zetas = np.zeros((self.n_temp,self.N_total+1))
        self.zetas[:,0] = zeta_init
    
    
    
    def ASI_update_par (self, gamma_par, acc_rate, JD):
        
        
            
        self.temp_acc_rates[self.t, self.i] = acc_rate
            
        if self.epoch <= self.N_adapt_PIPs:
            if self.epoch <= self.N_rb:
                if gamma_par.BF is None:
                    self.h_til(gamma_par)
                    self.compute_bf(gamma_par)
                    
                    
                    if self.t == 0:
                        self.temp_PIPs[self.t,self.i] = (gamma_par.h_til * gamma_par.BF)/(1 - gamma_par.h_til + gamma_par.h_til*gamma_par.BF)
                    else:
                        BF_temp = gamma_par.BF**self.PT_temps[self.t]
                        self.temp_PIPs[self.t,self.i] = (gamma_par.h_til * BF_temp)/(1 - gamma_par.h_til + gamma_par.h_til*BF_temp)
                        
            else:
                self.temp_PIPs[self.t, self.i] = gamma_par.gamma
        
        if self.i == (self.n_chain-1):
            
            if self.epoch <= self.N_adapt_PIPs:
                self.sum_adapt_PIPs[self.t] += self.temp_PIPs[self.t].sum(axis=0)
                self.adapt_PIPs[self.t] = self.kappa + (1-2*self.kappa)*self.sum_adapt_PIPs[self.t]/(self.epoch*self.n_chain)
                
                self.A_til[self.t] = np.minimum(1, self.adapt_PIPs[self.t]/(1-self.adapt_PIPs[self.t]))
                self.D_til[self.t] = np.minimum(1, (1-self.adapt_PIPs[self.t])/self.adapt_PIPs[self.t])
                
            self.logitzeta[self.t] += self.epoch**self.adapt_phi * (self.temp_acc_rates[self.t].mean() - self.target_acc)
            
            # need to check expected change
            zeta = self.inv_logit_e(self.logitzeta[self.t])
            self.zetas[self.t, self.epoch] = zeta
            
            self.A[self.t] = zeta * self.A_til[self.t]
            self.D[self.t] = zeta * self.D_til[self.t]
            
            
    def barker(self, x):
        """Barker balancing function g(x) = x / (1 + x).

        Satisfies the balancing identity g(x) = x * g(1/x); presumably
        selected as ``self.bal_fun`` for PARNI proposals.
        """
        denominator = 1 + x
        return x / denominator
    
    def hastings(self, x):
        """Metropolis-Hastings balancing function g(x) = min(1, x).

        Satisfies the balancing identity g(x) = x * g(1/x).
        """
        return x if x < 1 else 1
    
        
    def PARNI_propose (self, gamma_par):
        """Generate one PARNI proposal for chain (self.t, self.i).

        The proposal is built in three steps:

        1. ``self.AD_sample`` draws the set ``k`` of coordinates to visit using
           per-coordinate probabilities ``A[t]`` (where gamma_j == 0) and
           ``D[t]`` (where gamma_j == 1); ``k`` is then randomly permuted.
        2. Each ``kj`` is visited in turn.  It is flipped with weight
           ``omega * bal_fun(post_ratio * odds)`` and kept with weight
           ``(1 - omega) * bal_fun(1)``.  Here ``post_ratio`` is the
           tempered-posterior ratio of the flipped state to the current walk
           state, and ``odds`` come from ``adapt_PIPs``.
        3. The log normalising constants of every forward step and of the
           corresponding reverse step are accumulated.

        Parameters
        ----------
        gamma_par : object
            Current state; must expose ``gamma``, ``log_llh`` and ``log_mp``.

        Returns
        -------
        tuple
            ``(proposed gamma_par, log_ratio, JD)``, where ``JD`` is the number
            of flipped coordinates and ``log_ratio`` is
            ``prod_bal_const - rev_prod_bal_const + log_post_curr - log_post_temp``.

        NOTE(review): the commented-out check at the end of this method
        compares ``prod_bal_const - rev_prod_bal_const`` with
        ``log_post_temp - log_post_curr + rev_prob_prop - prob_prop
        + rev_k_prob - k_prob``.  The returned ``log_ratio`` therefore
        presumably holds only the proposal-ratio part, and the caller adds
        the tempered posterior difference.  Confirm against the caller.
        """
        
        # Flip-versus-keep weight for this (temperature, chain); read only here.
        omega = self.omega[self.t, self.i]
        
        # Per-coordinate selection probability: A where gamma_j == 0, D where gamma_j == 1.
        AD_prob = (1-gamma_par.gamma) * self.A[self.t] + gamma_par.gamma * self.D[self.t]
        
        # k: coordinates to visit; k_prob: probability of that draw as returned
        # by AD_sample (apparently on the log scale, judging by the disabled
        # check below).
        k, k_prob = self.AD_sample(True, AD_prob, None, True)
        # Visit the selected coordinates in a random order.
        k = np.random.permutation(k)
        
        # Number of coordinates actually flipped.
        JD = 0
        
        # Log probabilities of the forward / reverse flip-or-keep choices
        # (only consumed by the disabled check below).
        prob_prop = 0
        rev_prob_prop = 0
        
        # Sums of log normalising constants of the forward / reverse steps.
        prod_bal_const = 0
        rev_prod_bal_const = 0
        
        # Walk state, starting at the current state.  gamma_par itself is not
        # modified: each candidate is a new object built by make_gamma_par.
        temp_gamma_par = gamma_par
        # Tempered log posterior: only the log-likelihood is scaled by PT_temps[t].
        log_post_curr = self.PT_temps[self.t] * temp_gamma_par.log_llh + temp_gamma_par.log_mp
        log_post_temp = log_post_curr
        
        
        for kj in k:
            
            # Candidate: the current walk state with coordinate kj flipped.
            temp_gamma_prop = temp_gamma_par.gamma.copy()
            temp_kj = temp_gamma_prop[kj]
            temp_gamma_prop[kj] = 1 - temp_kj
            
            temp_prop_gamma_par = self.make_gamma_par(temp_gamma_prop)
            self.update_gamma_par(temp_prop_gamma_par)
            
            
            log_post_temp_prop = self.PT_temps[self.t] * temp_prop_gamma_par.log_llh + temp_prop_gamma_par.log_mp
            # Tempered posterior ratio (candidate / walk state).
            # NOTE(review): math.exp raises OverflowError once the log
            # difference exceeds about 709, so a very strong single-coordinate
            # improvement would abort sampling here.
            post_prop_temp = math.exp(log_post_temp_prop - log_post_temp)
            
            
            # Adaptive-PIP odds for kj: pi/(1-pi) if kj is currently included
            # (the flip removes it), (1-pi)/pi if it is currently excluded.
            mar_eff = self.adapt_PIPs[self.t, kj]
            odd_kj = (mar_eff/(1-mar_eff))**(2*temp_kj-1)
            
            # Unnormalised flip / keep weights, then normalised by their sum.
            change_prob = omega * self.bal_fun(post_prop_temp * odd_kj)
            keep_prob = (1-omega) * self.bal_fun(1)
            
            bal_const = change_prob + keep_prob
            change_prob /= bal_const
            keep_prob /= bal_const
            
            
            if np.random.random() < change_prob:
                
                # Flip chosen: the reverse step (candidate back to walk state)
                # uses the inverse posterior ratio and inverse odds.
                rev_change_prob = omega * self.bal_fun(1 / post_prop_temp / odd_kj)
                rev_keep_prob = (1-omega) * self.bal_fun(1)
                
                rev_bal_const = rev_change_prob + rev_keep_prob
                rev_change_prob /= rev_bal_const
                rev_keep_prob /= rev_bal_const
                
                
                # Move the walk state to the candidate.
                temp_gamma_par = temp_prop_gamma_par
                log_post_temp = log_post_temp_prop
                JD += 1
                
                
                prob_prop += math.log(change_prob)
                rev_prob_prop += math.log(rev_change_prob)
                
                
            else:
                # Coordinate kept: the code treats the reverse step as identical
                # to the forward one (same probabilities and normaliser).
                rev_change_prob = change_prob
                rev_keep_prob = keep_prob
                rev_bal_const = bal_const
                
                prob_prop += math.log(keep_prob)
                rev_prob_prop += math.log(rev_keep_prob)
            
            
            prod_bal_const += math.log(bal_const)
            rev_prod_bal_const += math.log(rev_bal_const)
            
            
        # Disabled sanity check: recompute the reverse selection probability and
        # print prod_bal_const - rev_prod_bal_const next to the explicit log
        # ratio it is expected to equal.
#         AD_prob = (1-temp_gamma_par.gamma) * self.A[self.t] + temp_gamma_par.gamma * self.D[self.t]
#         k, rev_k_prob = self.AD_sample(False, AD_prob, k, True)

#         print(JD, prod_bal_const - rev_prod_bal_const, 
#               log_post_temp - log_post_curr + rev_prob_prop - prob_prop + rev_k_prob - k_prob)
        
        # print(min(1,exp(prod_bal_const - rev_prod_bal_const)))
        
        # Return the walk's final state, the log ratio described in the
        # docstring, and the number of flips.
        return(temp_gamma_par, prod_bal_const - rev_prod_bal_const + log_post_curr - log_post_temp, JD)
        
        
        
    def PARNI_init_par_rm (self, omega_init):
        
        self.adapt_PIPs = np.ones((self.n_temp,self.p)) * self.h_exp 
        self.sum_adapt_PIPs = np.zeros((self.n_temp,self.p))
        
        self.A = np.ones((self.n_temp,self.p)) * (self.h_exp/(1-self.h_exp))
        self.D = np.ones((self.n_temp,self.p))
        
        self.temp_PIPs = np.zeros((self.n_temp,self.n_chain,self.p))
        self.temp_acc_rates = np.zeros((self.n_temp,self.n_chain))
        
        logitomega_init = self.logit_e(omega_init)
        self.logitomega = np.ones(self.n_temp) * logitomega_init
        
        self.omega = np.ones((self.n_temp, self.n_chain)) * omega_init
        self.omegas = np.zeros((self.n_temp,self.N_total+1))
        self.omegas[:,0] = omega_init
    
    
    
    def PARNI_update_par_rm (self, gamma_par, acc_rate, JD):
        
            
        self.temp_acc_rates[self.t, self.i] = acc_rate
            
        if self.epoch <= self.N_adapt_PIPs:
            if self.epoch <= self.N_rb:
                if gamma_par.BF is None:
                    self.h_til(gamma_par)
                    self.compute_bf(gamma_par)
                    
                    
                    if self.t == 0:
                        self.temp_PIPs[self.t,self.i] = (gamma_par.h_til * gamma_par.BF)/(1 - gamma_par.h_til + gamma_par.h_til*gamma_par.BF)
                    else:
                        BF_temp = gamma_par.BF**self.PT_temps[self.t]
                        self.temp_PIPs[self.t,self.i] = (gamma_par.h_til * BF_temp)/(1 - gamma_par.h_til + gamma_par.h_til*BF_temp)
                        
            else:
                self.temp_PIPs[self.t, self.i] = gamma_par.gamma
        
        if self.i == (self.n_chain-1):
            
            if self.epoch <= self.N_adapt_PIPs:
                self.sum_adapt_PIPs[self.t] += self.temp_PIPs[self.t].sum(axis=0)
                self.adapt_PIPs[self.t] = self.kappa + (1-2*self.kappa)*self.sum_adapt_PIPs[self.t]/(self.epoch*self.n_chain)
                
                self.A[self.t] = np.minimum(1, self.adapt_PIPs[self.t]/(1-self.adapt_PIPs[self.t]))
                self.D[self.t] = np.minimum(1, (1-self.adapt_PIPs[self.t])/self.adapt_PIPs[self.t])
                
            self.logitomega[self.t] += self.epoch**self.adapt_phi * (self.temp_acc_rates[self.t].mean() - self.target_acc)
            
            # need to check expected change
            omega = self.inv_logit_e(self.logitomega[self.t])
            self.omegas[self.t, self.epoch] = omega
            self.omega[self.t,:] = omega

            
    def PARNI_init_par_kw (self, omega_init):
        
        self.adapt_PIPs = np.ones((self.n_temp,self.p)) * self.h_exp
        self.sum_adapt_PIPs = np.zeros((self.n_temp,self.p))
        
        self.A = np.ones((self.n_temp,self.p)) * (self.h_exp/(1-self.h_exp))
        self.D = np.ones((self.n_temp,self.p))
        
        self.temp_PIPs = np.zeros((self.n_temp,self.n_chain,self.p))
        # self.temp_acc_rates = np.zeros((self.n_temp,self.n_chain))
        
        # phi_c_epoch = 1
        logitomega_init = self.logit_e(omega_init)
        self.logitomega = np.ones(self.n_temp) * logitomega_init
        omega = self.inv_logit_e_vec(logitomega_init + np.array([0, -1, 1]))
        
        
        self.n_pos = math.floor(self.n_chain/2)
        # self.n_neg = self.n_chain - self.n_pos
        # pos_idx = np.zeros(self.n_chain, dtype = bool)
        pos_idx = np.zeros(self.n_chain, dtype = bool)
        pos_idx[0:self.n_pos] = True
        
        self.pos_idx = np.random.permutation(pos_idx)
        
        
        self.omega = np.ones((self.n_temp, self.n_chain))
        self.omega[:,self.pos_idx] = omega[2]
        self.omega[:,~self.pos_idx] = omega[1]
        
        self.omegas = np.zeros((self.n_temp,self.N_total+1))
        self.omegas[:,0] = omega[0]
        
        self.ASJD_temp = np.zeros((self.n_temp, self.n_chain))
        
    
    
    def PARNI_update_par_kw (self, gamma_par, acc_rate, JD):
            
        # self.temp_acc_rates[self.t, self.i] = acc_rate
        self.ASJD_temp[self.t, self.i] = JD*acc_rate
        
        if self.epoch <= self.N_adapt_PIPs:
            if self.epoch <= self.N_rb:
                if gamma_par.BF is None:
                    self.h_til(gamma_par)
                    self.compute_bf(gamma_par)
                    
                    
                    if self.t == 0:
                        self.temp_PIPs[self.t,self.i] = (gamma_par.h_til * gamma_par.BF)/(1 - gamma_par.h_til + gamma_par.h_til*gamma_par.BF)
                    else:
                        BF_temp = gamma_par.BF**self.PT_temps[self.t]
                        self.temp_PIPs[self.t,self.i] = (gamma_par.h_til * BF_temp)/(1 - gamma_par.h_til + gamma_par.h_til*BF_temp)
                        
            else:
                self.temp_PIPs[self.t, self.i] = gamma_par.gamma
        
        if self.i == (self.n_chain-1):
            
            if self.epoch <= self.N_adapt_PIPs:
                self.sum_adapt_PIPs[self.t] += self.temp_PIPs[self.t].sum(axis=0)
                self.adapt_PIPs[self.t] = self.kappa + (1-2*self.kappa)/(self.epoch*self.n_chain)*self.sum_adapt_PIPs[self.t]
                
                self.A[self.t] = np.minimum(1, self.adapt_PIPs[self.t]/(1-self.adapt_PIPs[self.t]))
                self.D[self.t] = np.minimum(1, (1-self.adapt_PIPs[self.t])/self.adapt_PIPs[self.t])
            
            
            self.logitomega[self.t] += self.epoch**self.adapt_phi_a * (self.ASJD_temp[self.t,self.pos_idx].mean()-self.ASJD_temp[self.t,~self.pos_idx].mean())/(2*self.epoch**self.adapt_phi_c)
            omega = self.inv_logit_e_vec(self.logitomega[self.t] + np.array([0,-(self.epoch+1)**self.adapt_phi_c,(self.epoch+1)**self.adapt_phi_c]))
            self.omegas[self.t,self.epoch] = omega[0]
            
            if self.t == 0:
                self.new_pos_idx = np.random.permutation(self.pos_idx)
            
            self.omega[self.t, self.new_pos_idx] = omega[2]
            self.omega[self.t, ~self.new_pos_idx] = omega[1]
            
            if self.t == (self.n_temp-1):
                self.pos_idx = self.new_pos_idx
            
            
    # Result visualisation
    def plot_temperatures (self):
        """Overlay the trace ``self.PT_temperatures[level]`` of every tempering level and show the figure.

        The y-axis is fixed to [0, 1.1]; presumably each trace is the history
        of that level's temperature.
        """
        for level in range(self.n_temp):
            trace = self.PT_temperatures[level]
            plt.plot(trace)
            plt.ylim(0, 1.1)
        plt.show()
        
        
    def plot_pips (self):
        """Draw one bar chart of the estimated PIPs per tempering level.

        Panel ``t`` shows ``self.estm_PIPs[t]`` against covariate indices
        ``0 .. p-1``, with the y-axis fixed to [0, 1].
        """
        # squeeze=False keeps ``axs`` two-dimensional even when n_temp == 1.
        # With the default, a single panel comes back as a bare Axes object,
        # which cannot be indexed by ``t``.
        fig, axs = plt.subplots(self.n_temp, squeeze=False, figsize=(7, 4*self.n_temp))
        for t in range(self.n_temp):
            ax = axs[t, 0]
            ax.bar(range(self.p), self.estm_PIPs[t])
            ax.set_ylim(0,1)
        fig.tight_layout()

In [None]:
# Build the linear-model sampler with a fixed g = 100 (not random).  Because
# Z is None, X and y are centred, and because scale=True, X's columns are
# divided by their standard deviation with ddof = 1.
# NOTE(review): X and y are created by the lrrsg(...) cell further down this
# notebook; run that cell first, otherwise this raises NameError.
fit = BVS_MCMC(X = X, y = y, Z = None, model = "linear", ddof = 1,
               g = 100, prior_type = "ind",
               h_exp_size = 5, h_type = 1, h = None, scale = True)

In [None]:
# Configure PARNI with 25 chains at each of 5 tempering levels.
# NOTE(review): presumably PARNI_bal_fun="hastings" selects the hastings
# balancing function, PARNI_omega_adapt="KW" selects the PARNI_*_kw
# adaptation rules, and PARNI_adapt_phi supplies (adapt_phi_a, adapt_phi_c).
# Whether ASI_zeta_init is used when sampler == "PARNI", and the exact
# meaning of N_iter / N_burnin, are defined in set_alg_par (not visible
# here).
fit.set_alg_par (sampler = "PARNI", N_iter = 2000, N_burnin = 1000, 
                 PARNI_bal_fun = "hastings", 
                 PARNI_omega_adapt = "KW", PARNI_adapt_phi = [-1, -0.5],
                 ASI_zeta_init = 0.5,
                 n_chain = 25, n_temp = 5, verbose = True)

In [None]:
# Run the sampler (linear_sample_now for model="linear"), then print the
# timings it records.  Presumably these are in seconds; they are set inside
# the sampler, which is not visible here.
fit.sample_now()
print(fit.time_total)
print(fit.time_burnin)
print(fit.time_sample)

## Simulated dataset generation

In [None]:
def lrrsg(n, p, a = 0, beta0 = None,
          rho = 0, SNR = 1, sigma2 = 1, seed = None):
    """Simulate a sparse linear-regression dataset with AR(1)-correlated covariates.

    The model is as follows:

    - The first ``len(beta0)`` coefficients are ``beta0`` scaled by
      ``SNR * sqrt(sigma2 * log(p) / n)``; all remaining coefficients are zero.
    - Covariates follow ``X[:, 0] = e_0`` and
      ``X[:, j] = rho * X[:, j-1] + sqrt(1 - rho**2) * e_j``,
      with ``e_j`` i.i.d. standard normal.
    - Responses are ``y = a + X @ beta + eps`` with ``eps ~ N(0, sigma2)``.

    Parameters
    ----------
    n, p : int
        Number of observations and of covariates.
    a : float
        Intercept added to every response.
    beta0 : array_like, optional
        Unscaled non-zero coefficients.  Defaults to
        ``[2, -3, 2, 2, -3, 3, -2, 3, -2, 3]``; must have at most ``p`` entries.
    rho : float
        Lag-one correlation between neighbouring covariates.
    SNR : float
        Signal-to-noise multiplier applied to ``beta0``.
    sigma2 : float
        Noise variance; also enters the coefficient scaling.
    seed : int, optional
        If given, NumPy's global RNG is seeded with it first.

    Returns
    -------
    beta : ndarray, shape (p, 1)
    y : ndarray, shape (n, 1)
    X : ndarray, shape (n, p)

    Raises
    ------
    ValueError
        If ``beta0`` has more than ``p`` entries.
    """
    if seed is not None:
        np.random.seed(seed)

    # A None sentinel avoids sharing one mutable array default across calls.
    if beta0 is None:
        beta0 = np.array([2,-3,2,2,-3,3,-2,3,-2,3])
    beta0 = np.asarray(beta0).reshape(-1)
    b_n = beta0.size
    if b_n > p:
        raise ValueError(f"beta0 has {b_n} entries but p = {p}")

    beta = np.zeros((p,1))
    beta[:b_n, 0] = beta0
    beta = SNR * np.sqrt(sigma2 * np.log(p)/n) * beta

    b = np.sqrt(1-rho**2)

    X = np.random.normal(size = (n,p))
    for j in range(1,p):
        X[:,j] = rho * X[:,j-1] + b * X[:,j]

    # Responses.  The noise has variance sigma2, matching the coefficient
    # scaling; for the default sigma2 = 1 the draws are unchanged.
    y = a + X.dot(beta) + np.sqrt(sigma2) * np.random.normal(size = (n,1))

    return(beta, y, X)


In [None]:
# Simulate n = 10000 observations of p = 5 independent (rho = 0) covariates.
# All five coefficients are non-zero, since beta0 has p entries.
# NOTE(review): no seed is passed, so every run draws a different dataset.
# This cell must run before the BVS_MCMC(...) cell above, which uses X and y.
beta,y,X = lrrsg(n = 10000, p = 5, beta0 = np.array([2,-3,2,2,-3]), rho = 0)