
# Linear regression example using Spike-Slab Prior

The model equation is y = ax + b with a, b being the model parameters, while the
likelihood model is based on a normal zero-mean additive model error distribution. The problem is solved via sampling using emcee.


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import emcee
from scipy.stats import invgamma, halfcauchy, norm, bernoulli, uniform

In [2]:
def log_prior(theta):
    if theta[0] < 0.0 or theta[0] > 1.0:
        return -np.Inf
    return bernoulli.logpmf(0 if theta[0] < 0.5 else 1, p = 0.2)

""" def log_prior(theta):
    if theta[0] < 0.0 or theta[0] > 1.0:
        return -np.Inf
    elif theta[0] < 0.5:
        theta[0] = 0
        return bernoulli.logpmf(0, p = 0.2)
    elif theta[0] >= 0.5:
        theta[0] = 1
        return bernoulli.logpmf(1, p = 0.2) """
            
nwalkers = 2
ndim = 1
start_parameters = np.random.rand(nwalkers, ndim)
sampler = emcee.EnsembleSampler(
    nwalkers, ndim, log_prior)
sampler.run_mcmc(start_parameters, 5000, progress=True)

  0%|          | 0/5000 [00:00<?, ?it/s]

100%|██████████| 5000/5000 [00:01<00:00, 3137.46it/s]


State([[0.25580334]
 [0.79779069]], log_prob=[-0.22314355 -1.60943791], blobs=None, random_state=('MT19937', array([ 511415355, 2976889677,  691957307, 2562479105, 1029262931,
       2078597897, 4171473786,  133695670, 1457771619, 1152820194,
       1206562975, 2457282709,   39838820,  171601938, 3665507867,
        895036996, 2411484048, 3615228695,   40791962, 2889811943,
       2343972694, 1068386204, 2880080642,  753670122, 3129724787,
        650397100,  289981738,  333740544, 3796513899, 1320463754,
       2298522561,  278228596, 3689952013, 3360586661, 1070688814,
       2014230620, 3685019848, 2828668468, 1039753614, 3806630384,
       1320851311, 1619371235, 3864147487, 1473589626, 3326832103,
        431328244, 1786498715,  309233622,  567381042, 1410276307,
        876133221, 1366744474, 3428543899, 2812428033, 3718393368,
       3018675732, 3200479158, 4175083794, 3917272861, 3621382036,
       1149812343, 2729342129, 3052457255, 4201390690, 3373871837,
       3660256001, 1

Plotting the trace of the chain

In [3]:
samples = sampler.get_chain(discard=1000, thin=15, flat=True)
sum(samples > 0.5)/len(samples)

array([0.2424812])

In [None]:
fig, axes = plt.subplots(ndim, figsize=(10, 7), sharex=True)
samples = sampler.get_chain()
labels = ["lmbda_a"] # Change the labels over here if changes in parameters are made in json file.
for i in range(ndim):
    ax = axes
    ax.plot(samples[:, :, i], "k", alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_ylabel(labels[i])
    ax.yaxis.set_label_coords(-0.1, 0.5)

axes.set_xlabel("step number")

Plotting the pair-plot of the chain.

In [None]:
flat_samples = sampler.get_chain(discard=500, thin=15, flat=True)
print(flat_samples.shape)
import corner
labels = ["lmbda_a"]  # Change the labels over here if changes in parameters are made in json file.
fig = corner.corner(
    flat_samples, labels=labels, truths=[1]
)