In the model 
$$y\sim N(X\beta,\sigma^2I)$$
$$\beta|\gamma\sim N(0,D_\gamma)$$
where $$D_\gamma = diag((1-\gamma_j)\sigma_0^2 + \gamma_j\sigma^2_1)$$
$$p(\gamma_j = 1) = \pi_1$$
$$\sigma^2\sim IG(a,b)$$

The SSVS iteratively samples:

1. 
$$\beta\sim N(\frac{A}{\sigma^2} X^T y, A)$$ 
where 
$$A = (X^TX/\sigma^2 + D_\gamma^{-1})^{-1}$$

2. $$\sigma^2\sim IG(n/2 + a, ||y-X\beta||^2/2 + b)$$

3. $$p(\gamma_j = 1) = \frac{p_1}{p_0+p_1}$$
 where 

 $$p_1 = \pi_1 p(\beta|\gamma_j=1,\gamma_{-j})$$
 $$p_0 = \pi_0 p(\beta|\gamma_j=0,\gamma_{-j})$$



In [127]:
import numpy as np
from scipy.stats import invgamma  
from scipy.stats import multivariate_normal
from scipy.stats import bernoulli
import pandas as pd
def ssvs(X,y,pi0,var0,var1,ig_a=0.01,ig_b=0.01,n_burnin=500,n_post=10000,printevery = 10):
    n,p = X.shape
    XtX = np.dot(X.T,X)
    Xty = np.dot(X.T,y)
    beta_draws = np.zeros((n_burnin + n_post,p))
    gamma_draws = np.zeros((n_burnin + n_post,p))
    sigma2_draws = np.zeros(n_burnin + n_post)
    gamma = np.zeros(p)
    y = np.squeeze(y)
    sigma2 = np.var(y)
    
    for i in range(n_burnin + n_post):
        if(i%printevery==0):
            print(' '.join(['drawing sample',str(i)]))
        # sample beta
        d_inv = 1/(gamma*var1 + (1-gamma)*var0)
        A = np.linalg.inv(XtX/sigma2 + np.diag(d_inv))
        beta = np.random.multivariate_normal(np.squeeze(np.matmul(A/sigma2,Xty)),A)
        sigma2 = invgamma.rvs(1,n/2+ig_a,sum((y-np.matmul(X,beta))**2)/2+ig_b)
        for j in range(p):
            d1 = (gamma*var1 + (1-gamma)*var0)
            d1[j] = var1
            d0 = (gamma*var1 + (1-gamma)*var0)
            d0[j] = var0
            p1 = (1-pi0) * multivariate_normal.pdf(beta, mean=np.zeros(p), cov=np.diag(d1))
            p0 = pi0 * multivariate_normal.pdf(beta, mean=np.zeros(p), cov=np.diag(d0))
            gamma[j] = bernoulli.rvs(p1/(p1+p0), size=1)
        beta_draws[i,:] = beta
        gamma_draws[i,:] = gamma
        sigma2_draws[i] = sigma2
    return {'beta':beta_draws[n_burnin:,:],'gamma':gamma_draws[n_burnin:,:],'sigma2':sigma2_draws[n_burnin:]}
        


In [128]:
def gen_data(n,p,pi0,var0,var1,sigma2,XI = True):
    beta0 = np.random.normal(0,np.sqrt(var0),p)
    beta1 = np.random.normal(0,np.sqrt(var1),p)
    gamma = bernoulli.rvs(1-pi0, size=p)
    beta = beta1 * gamma + (1-gamma) * beta0
    if XI:
        y = beta + np.random.normal(0,np.sqrt(sigma2),p)
        return {'beta':beta,'y':y,'X':np.diag(np.ones(p)),'gamma':gamma}
    else:
        X = np.random.normal(size=(n,p))
        beta = np.reshape(beta,(p,1))
        y = np.matmul(X,beta) + np.reshape(np.random.normal(0,np.sqrt(sigma2),n),(n,1))
        return {'beta':beta,'y':y,'X':X,'gamma':gamma}

In [131]:
datax = gen_data(100,10,0.5,0.01,5,1,XI=False)

array([-0.02022971, -0.05262649,  2.16381827,  0.06471874,  1.45040414,
       -0.23270781, -0.09633481,  0.26639295,  4.5215936 ,  0.03435869])

In [119]:
fit = ssvs(datax['X'],datax['y'],0.5,0.01,5,
ig_a=0.01,ig_b=0.01,n_burnin=1000,
n_post=2000,
printevery = 100)

drawing sample 0
drawing sample 100
drawing sample 200
drawing sample 300
drawing sample 400
drawing sample 500
drawing sample 600
drawing sample 700
drawing sample 800
drawing sample 900
drawing sample 1000
drawing sample 1100
drawing sample 1200
drawing sample 1300
drawing sample 1400
drawing sample 1500
drawing sample 1600
drawing sample 1700
drawing sample 1800
drawing sample 1900
drawing sample 2000
drawing sample 2100
drawing sample 2200
drawing sample 2300
drawing sample 2400
drawing sample 2500
drawing sample 2600
drawing sample 2700
drawing sample 2800
drawing sample 2900


In [132]:
upper_ci = np.apply_along_axis(np.quantile,0,fit,q=0.975)
lower_ci = np.apply_along_axis(np.quantile,0,fit,q=0.025)
pm = np.apply_along_axis(np.mean,0,fit)
res = pd.DataFrame({'true_beta':np.squeeze(datax['beta']),
'ols': np.squeeze(np.matmul(np.linalg.inv(np.matmul(datax['X'].T,datax['X'])),np.matmul(datax['X'].T,datax['y']))),
'posterior_mean':pm,
'lower':lower_ci,
'upper':upper_ci})
res

Unnamed: 0,true_beta,ols,posterior_mean,lower,upper
0,0.134602,-0.02023,-0.132063,-3.702432,3.225506
1,-0.055601,-0.052626,-0.05741,-3.195973,2.924375
2,2.150908,2.163818,0.027142,-3.261504,3.228945
3,0.010992,0.064719,-0.010777,-3.420915,3.189246
4,1.424295,1.450404,0.011625,-3.058356,3.155593
5,-0.074215,-0.232708,0.259813,-2.623448,3.757211
6,0.054716,-0.096335,0.066065,-3.094691,3.313081
7,0.175528,0.266393,-0.061471,-3.25349,3.183751
8,4.443326,4.521594,-0.55858,-4.44558,2.868801
9,-0.090402,0.034359,0.008629,-3.363321,3.311434
