
# Linear regression example using Spike-Slab Prior

The model equation is y = ax + b with a, b being the model parameters, while the
likelihood model is based on a normal zero-mean additive model error distribution. The problem is solved via sampling using emcee.


In [1]:
import numpy as np
import matplotlib.pyplot as plt
import emcee
from scipy.stats import invgamma, halfcauchy, norm, bernoulli, uniform

In [2]:
def log_prior(theta):
    if theta[0] < 0.0 or theta[0] > 1.0:
        return -np.Inf
    return bernoulli.logpmf(0 if theta[0] < 0.5 else 1, p = 0.2)

""" def log_prior(theta):
    if theta[0] < 0.0 or theta[0] > 1.0:
        return -np.Inf
    elif theta[0] < 0.5:
        theta[0] = 0
        return bernoulli.logpmf(0, p = 0.2)
    elif theta[0] >= 0.5:
        theta[0] = 1
        return bernoulli.logpmf(1, p = 0.2) """
            
nwalkers = 2
ndim = 1
start_parameters = np.random.rand(nwalkers, ndim)
sampler = emcee.EnsembleSampler(
    nwalkers, ndim, log_prior)
sampler.run_mcmc(start_parameters, 5000, progress=True)

100%|██████████| 5000/5000 [00:01<00:00, 3146.26it/s]


State([[0.37813003]
 [0.17072287]], log_prob=[-0.22314355 -0.22314355], blobs=None, random_state=('MT19937', array([3091799997, 3480733788, 4005924982, 3998446630, 3841600198,
       3598582375, 1038229332, 1618922272, 1626168388, 4140504290,
       1059444225, 4211787647, 1920654897,  287462340,  647522092,
       2396386857, 2752171419, 3177510445, 2381707856,  397009677,
        607146868, 3155131131, 1248244075, 2650148511, 2219941060,
       2470686875,  179698518, 2251611515, 4070725569, 4266671140,
       3178590789, 2922908390, 2759954758,  457057901, 1720119931,
       2577904933, 1508485904, 3884453263,  503080566, 3280880856,
        847462329, 2869160164, 1442270807, 1792186315, 4003548826,
       1744661497, 3888277147, 3650541342, 3606917584, 3725891114,
        305569795, 2502566957,  280275641,  242384322,  697915293,
       1152149353,  149216035, 4029416359, 2611215133, 2058101577,
       3707553240,  574008936, 4004305488, 4086811151, 3838441415,
       1381671008,  

Plotting the trace of the chain

In [3]:
samples = sampler.get_chain(discard=1000, thin=15, flat=True)
sum(samples > 0.5)/len(samples)

array([0.17105263])

In [None]:
fig, axes = plt.subplots(ndim, figsize=(10, 7), sharex=True)
samples = sampler.get_chain()
labels = ["lmbda_a"] # Change the labels over here if changes in parameters are made in json file.
for i in range(ndim):
    ax = axes
    ax.plot(samples[:, :, i], "k", alpha=0.3)
    ax.set_xlim(0, len(samples))
    ax.set_ylabel(labels[i])
    ax.yaxis.set_label_coords(-0.1, 0.5)

axes.set_xlabel("step number")

Plotting the pair-plot of the chain.

In [None]:
flat_samples = sampler.get_chain(discard=500, thin=15, flat=True)
print(flat_samples.shape)
import corner
labels = ["lmbda_a"]  # Change the labels over here if changes in parameters are made in json file.
fig = corner.corner(
    flat_samples, labels=labels, truths=[1]
)