In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
import ezmc


# Minimal Example

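## Linear Regression

Simulate $n$ points from $y_i = \beta_0 + \beta_1 x_i + \varepsilon_i$ with $\varepsilon_i \sim \mathcal{N}(0, \sigma^2)$, then recover $(\beta_0, \beta_1, \sigma)$ (`intercept`, `slope`, `noise`) with ezmc's Metropolis sampler.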

In [None]:
# Seed NumPy's global RNG so the simulated data and the chain starting points
# drawn by `init` (which uses np.random) are identical on Restart & Run All.
# NOTE(review): this presumably also fixes ezmc's proposals if it draws from
# np.random's global state — TODO confirm.
SEED = 42
np.random.seed(SEED)

# True generating parameters: y ~ Normal(intercept + slope * x, noise).
intercept = 1.
slope = .5
noise = .2
n = 100
true_parameters = [intercept, slope, noise]

x = np.random.normal(0, 1, n)
y = np.random.normal(intercept + slope * x, noise, n)

plt.scatter(x, y)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Simulated data');

In [None]:
def get_prior_density(pars):
    """Log prior density of ``pars = (intercept, slope, noise)``.

    Every parameter has an independent Normal(loc=0, scale=2) prior, so the
    joint log prior is the sum of the three marginal log densities.
    Unpacking enforces that exactly three parameters are supplied.
    """
    intercept, slope, noise = pars
    log_terms = (stats.norm.logpdf(x=value, loc=0, scale=2)
                 for value in (intercept, slope, noise))
    return sum(log_terms)
    
def get_likelihood(pars, x, y):
    """Gaussian log likelihood of observations ``y`` given predictors ``x``.

    Model: ``y_i ~ Normal(intercept + slope * x_i, noise)``. Returns the sum
    of the per-observation log densities. For ``noise <= 0`` scipy produces
    NaN terms, so the total is NaN; callers must guard that case.
    """
    intercept, slope, noise = pars
    mean = intercept + slope * x
    # Positional form of logpdf(x, loc, scale).
    return np.sum(stats.norm.logpdf(y, mean, noise))

def get_posterior_density(pars, x, y):
    """Unnormalised log posterior: log prior + log likelihood.

    ``pars`` is ``(intercept, slope, noise)``. A non-positive ``noise`` lies
    outside the model's support, so the likelihood is not evaluated there
    (scipy would return NaN for ``scale <= 0``). Instead the log prior is
    returned minus a large finite penalty, so the sampler should
    (practically) never accept such a proposal.
    """
    _, _, noise = pars  # unpacking also enforces exactly three parameters
    prior = get_prior_density(pars)
    if noise <= 0:
        # Finite penalty rather than -inf; presumably chosen so the sampler's
        # acceptance arithmetic never sees infinities — TODO confirm with ezmc.
        return prior - 1e+10
    return prior + get_likelihood(pars, x, y)

def f(pars):
    """Target log density passed to the sampler as ``func``.

    Closes over the module-level simulated data ``x`` and ``y``, so the
    sampler only needs to supply the parameter vector.
    """
    return get_posterior_density(pars, x, y)

def init():
    """Draw a random starting point ``[intercept, slope, noise]`` for a chain.

    All three values are drawn from Normal(0, 10). Draws are rejected until
    ``noise`` is strictly positive. This matches the ``noise <= 0`` support
    check in get_posterior_density, so a chain never starts at the penalty.
    """
    start = np.random.normal(0, 10, 3)
    while start[2] <= 0:
        start = np.random.normal(0, 10, 3)
    return start

In [None]:
# Build ezmc's Metropolis sampler: `f` is the target log density and `init`
# draws each chain's starting point; parameter names label the output.
sampler = ezmc.MetropolisSampler(
    func=f,
    init_func=init,
    par_names=['intercept', 'slope', 'noise'],
    proposal_sd=.05,
    noisy=False,
)
# Presumably 10000 iterations per chain — TODO confirm ezmc's semantics.
sampler.sample_chains(10000)

In [None]:
# Trace plots of the raw chains from sampler.get_chains(), i.e. before the
# burn_in/thin filtering applied via get_results; presumably used to pick
# the burn-in by eye.
chains = sampler.get_chains()
ezmc.viz.traceplot(chains)
plt.show()

In [None]:
# Drop burn-in and thin the chains (presumably: discard the first 4000
# iterations and keep every 10th draw — TODO confirm ezmc's get_results),
# then plot traces for every parameter in sampler.par_names.
results = sampler.get_results(burn_in=4000, thin=10)
ezmc.viz.traceplot(results, pars=sampler.par_names);

In [None]:
import arviz as az

In [None]:
def to_arviz(sampler, burn_in=0, thin=1):
    """Convert an ezmc sampler's filtered draws into an ArviZ dataset.

    Parameters
    ----------
    sampler : ezmc sampler
        Must provide ``get_results(burn_in=..., thin=...)`` returning a pandas
        DataFrame with ``'chain'`` and ``'iter'`` columns plus one column per
        name in ``sampler.par_names``, and the ``par_names`` attribute itself.
    burn_in, thin : int
        Forwarded unchanged to ``sampler.get_results``.

    Returns
    -------
    xarray.Dataset
        Built by ``az.dict_to_dataset``; one variable per parameter, each a
        ``(chain, draw)`` array (rows = chains, columns = sorted iterations).
    """
    samples = sampler.get_results(burn_in=burn_in, thin=thin)
    par_dict = {
        # `pivot` (unlike `pivot_table`) raises on duplicate (chain, iter)
        # rows instead of silently averaging them into a single draw.
        name: samples.pivot(index='chain', columns='iter', values=name).values
        for name in sampler.par_names
    }
    # `az` is the module-level ArviZ import from the cell above.
    return az.dict_to_dataset(par_dict)

# Use the module-level `to_arviz` helper defined in this cell rather than a
# sampler method. burn_in/thin match the ezmc trace plot of `results`, so
# both views show the same draws.
posterior = to_arviz(sampler, burn_in=4000, thin=10)

In [None]:
# Per-parameter marginal densities and traces; `;` hides the Axes-array repr.
az.plot_trace(posterior);

In [None]:
# Pairwise joint posteriors as KDE contours; `;` hides the Axes-array repr.
az.plot_pair(posterior, kind='kde');

In [None]:
# Pairwise scatter of posterior draws. Current ArviZ styles the scatter
# points via `scatter_kwargs` (`plot_kwargs` is not a plot_pair argument).
az.plot_pair(posterior, kind='scatter', scatter_kwargs=dict(alpha=.1));

In [None]:
# az.plot_forest returns an ndarray of Axes (one per forest column), not a
# (fig, axes) tuple. Draw the zero reference on the first panel explicitly
# instead of relying on pyplot's "current axes".
axes = az.plot_forest(posterior, combined=True, figsize=(5, 3))
fig = axes[0].get_figure()
# axvline spans the full panel height, so no ylim lookup is needed.
axes[0].axvline(0, linestyle='dashed');