# HMC for a Simple Dirichlet Process Model via Stan

```python
import stan as pystan  # PyStan 3 is installed as the `stan` package
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
```

We mimic a Dirichlet process (DP) Gaussian mixture model in Stan. Stan has no Dirichlet process prior and cannot sample discrete parameters such as cluster assignments, so we approximate the DP with a finite mixture truncated at $C$ components. The mixture weights are built with the stick-breaking construction inside the Stan code.
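The truncated stick-breaking construction can be sketched in NumPy, outside Stan. Each weight is $\pi_j = v_j\prod_{k<j}(1-v_k)$ for $j<C$, and the last weight takes whatever stick remains, $\pi_C=\prod_{k<C}(1-v_k)$, so the weights sum to 1. This assumes $v_j\sim\mathrm{Beta}(1,\alpha)$ as in the model below; the function name `stick_breaking_weights` and the value $\alpha=1$ in the example are hypothetical.

```python
import numpy as np

def stick_breaking_weights(v):
    """Truncated stick-breaking: map C proportions v in (0, 1) to C weights summing to 1.

    Only v[0..C-2] are used; the last weight is the piece of stick that remains.
    """
    v = np.asarray(v, dtype=float)
    C = v.size
    pi = np.empty(C)
    remaining = 1.0  # length of the stick not yet broken off
    for j in range(C - 1):
        pi[j] = v[j] * remaining  # break off a fraction v[j] of what remains
        remaining *= 1.0 - v[j]
    pi[C - 1] = remaining  # truncation: the last component takes the rest
    return pi

# Example: proportions drawn from Beta(1, alpha) with a hypothetical alpha = 1
rng = np.random.default_rng(1)
alpha, C = 1.0, 10
v = rng.beta(1.0, alpha, size=C)
pi = stick_breaking_weights(v)
print(pi.round(3), pi.sum())
```

A smaller $\alpha$ puts most of the mass on the first few sticks, while a larger $\alpha$ spreads it over more components. The Stan code computes the same weights with a ratio recursion that divides by `v[j-1]`; the running-remainder form above avoids that division.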

The data $y$ are drawn from a mixture of three Gaussian components, $y_{1}\sim\mathcal{N}(-3,0.5^2)$, $y_{2}\sim\mathcal{N}(0,0.75^2)$, and $y_{3}\sim\mathcal{N}(3,1^2)$, with mixing proportions $\pi=(0.1, 0.5, 0.4)$.
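The data file itself is not shown here. As a sketch, data from the same generating process can be simulated as follows; the sample size of 5000 matches the printed length of `y` later in the notebook, while the seed and the function name `simulate_mixture` are hypothetical.

```python
import numpy as np

def simulate_mixture(n, weights, means, sds, seed=0):
    """Draw n points from a 1-D Gaussian mixture; returns (y, component labels)."""
    rng = np.random.default_rng(seed)
    # Pick a component for each point according to the mixing proportions
    labels = rng.choice(len(weights), size=n, p=weights)
    # Draw each point from the normal distribution of its component
    y = rng.normal(np.asarray(means)[labels], np.asarray(sds)[labels])
    return y, labels

y_sim, z_sim = simulate_mixture(
    n=5000,
    weights=[0.1, 0.5, 0.4],
    means=[-3.0, 0.0, 3.0],
    sds=[0.5, 0.75, 1.0],
)
# The sample mean should be near the mixture mean 0.1*(-3) + 0.5*0 + 0.4*3 = 0.9
print(y_sim.mean())
```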

![Plot of the data y](data_plot.png)

# Write Model in Stan

A Stan program is organized into blocks. The three core ones are `data`, which declares the types and dimensions of the inputs used for sampling; `parameters`, which declares the quantities to be sampled; and `model`, which holds the priors and the likelihood (sampling statements and `target +=` increments). This model also uses a `transformed parameters` block to build the mixture weights from the stick-breaking proportions.

If a parameter has no sampling statement, Stan gives it an implicit (improper) uniform prior over its declared support, i.e. uniform(-infinity, +infinity) for an unconstrained real. You can restrict the support with `lower` and `upper` bounds when declaring a parameter; for example, `real<lower=0> sigma;` ensures the parameter is positive.
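Bounds are more than range checks: Stan samples on an unconstrained scale, maps each draw back through a transform, and adds the log absolute Jacobian of that transform to the target, which is why an unbounded-prior parameter ends up flat over its declared range. The NumPy sketch below mirrors the lower-bound and lower/upper-interval transforms described in the Stan Reference Manual; it is an illustration, not Stan code, and the function names are hypothetical.

```python
import numpy as np

def lower_bound(u, lb=0.0):
    """Map unconstrained u to x > lb; return (x, log|dx/du|)."""
    x = lb + np.exp(u)
    return x, u  # d/du (lb + exp(u)) = exp(u), so the log-Jacobian is u

def interval(u, lb=0.0, ub=1.0):
    """Map unconstrained u to lb < x < ub via the inverse logit; return (x, log|dx/du|)."""
    s = 1.0 / (1.0 + np.exp(-u))  # inverse logit
    x = lb + (ub - lb) * s
    # dx/du = (ub - lb) * s * (1 - s)
    log_jac = np.log(ub - lb) + np.log(s) + np.log1p(-s)
    return x, log_jac

u = np.linspace(-5, 5, 11)
sigma, _ = lower_bound(u)     # like sigma_cl, declared <lower=0>
v, _ = interval(u, 0.0, 1.0)  # like v, declared <lower=0, upper=1>
print(sigma.min() > 0, ((v > 0) & (v < 1)).all())
```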

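```python
# Simple DP example, via a truncated stick-breaking process
# (uses the `array[N] real` syntax required by current Stan versions)
model = """
data {
  int<lower=2> C;   // truncation level (maximum number of clusters)
  int<lower=0> N;   // number of data points
  array[N] real y;  // observations
}

parameters {
  array[C] real mu_cl;                // cluster means
  array[C] real<lower=0, upper=1> v;  // stick-breaking proportions
  array[C] real<lower=0> sigma_cl;    // cluster scales (standard deviations)
  real<lower=0> alpha;                // concentration of DP(alpha, base)
}

transformed parameters {
  simplex[C] pi;  // mixture weights
  pi[1] = v[1];

  // pi[j] = v[j] * prod_{k<j} (1 - v[k]), computed recursively from pi[j-1]
  for (j in 2:(C - 1)) {
    pi[j] = v[j] * (1 - v[j - 1]) * pi[j - 1] / v[j - 1];
  }
  pi[C] = 1 - sum(pi[1:(C - 1)]);  // remaining stick, so pi is a simplex
}

model {
  real a = 1.0;
  real b = 1.0;
  vector[C] ps;  // per-cluster log terms for one observation

  sigma_cl ~ inv_gamma(a, b);
  mu_cl ~ normal(0, 10);
  alpha ~ gamma(6, 1);
  v ~ beta(1, alpha);

  // Marginalize out the discrete cluster assignment of each observation
  for (i in 1:N) {
    for (c in 1:C) {
      ps[c] = log(pi[c]) + normal_lpdf(y[i] | mu_cl[c], sigma_cl[c]);
    }
    target += log_sum_exp(ps);
  }
}
"""
```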
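Stan cannot sample the discrete cluster label of each observation, so the model block sums it out: for every $y_i$ it adds $\log\sum_{c}\pi_c\,\mathcal{N}(y_i\mid\mu_c,\sigma_c)$ to `target`, evaluated stably with `log_sum_exp`. The NumPy sketch below computes the same log likelihood, vectorized over observations instead of the double loop; the function name `mixture_loglik` is hypothetical.

```python
import numpy as np

def mixture_loglik(y, pi, mu, sigma):
    """Sum over i of log( sum_c pi[c] * Normal(y[i] | mu[c], sigma[c]) )."""
    y = np.asarray(y, dtype=float)[:, None]  # shape (N, 1)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    # Normal log density for every (observation, cluster) pair, shape (N, C)
    log_norm = -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2
    ps = np.log(pi) + log_norm  # plays the role of ps[c] in the Stan model block
    # log-sum-exp over clusters: subtract the row maximum before exponentiating
    m = ps.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.exp(ps - m).sum(axis=1))))

print(mixture_loglik([0.2, -2.9, 3.1], [0.1, 0.5, 0.4], [-3.0, 0.0, 3.0], [0.5, 0.75, 1.0]))
```

Subtracting the maximum keeps the computation finite even for points far from every cluster, where each individual density would underflow to zero.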

# HMC Sampling

```python
# Read in the data (skip the header row; the second column holds y)
data = np.loadtxt('dat.dat', delimiter=' ', skiprows=1)

C = 10  # truncation level for the stick-breaking process
y = data[:, 1]
print(len(y))

# Put the data in a dictionary whose keys match the Stan data block
stanData = {'C': C, 'N': len(y), 'y': y}
```

5000


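```python
# PyStan 3 runs its own asyncio event loop; inside Jupyter this is needed first
import nest_asyncio
nest_asyncio.apply()

# Compile the model and bind the data (PyStan 3 replaces StanModel/sampling)
posterior = pystan.build(model, data=stanData, random_seed=101)

# 500 warmup + 500 retained draws on one chain (the old iter=1000, warmup=500)
fit = posterior.sample(num_chains=1, num_samples=500, num_warmup=500, num_thin=1)
```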

# Results

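```python
import arviz as az

# PyStan 3's fit object does not print a summary table; ArviZ reports the posterior
# mean, sd, HDI, effective sample sizes (ess_bulk, ess_tail) and r_hat
idata = az.from_pystan(posterior=fit)
print(az.summary(idata, var_names=['alpha', 'pi', 'mu_cl', 'sigma_cl']))
```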

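`lp__` is the log posterior density (up to an additive constant) on Stan's unconstrained scale. A stable, well-mixed `lp__` trace gives extra confidence that the whole sampling process has converged, but its absolute value is not meaningful on its own.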

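`n_eff` (reported by ArviZ as `ess_bulk` and `ess_tail`) is the effective sample size. Because successive MCMC draws are correlated, it can be much lower than the number of draws actually generated. Thinning the chain reduces the autocorrelation between the draws that are kept and saves memory, but it does not increase the effective sample size; running the chain longer does.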
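For a single chain, the effective sample size can be estimated as $N/\hat\tau$ with $\hat\tau = 1 + 2\sum_{t\ge 1}\hat\rho_t$, where $\hat\rho_t$ is the lag-$t$ autocorrelation and the sum is truncated with Geyer's initial positive sequence. The sketch below follows that recipe; it is a simplified single-chain version rather than Stan's exact multi-chain estimator, and `ess_single_chain` and the AR(1) test chain are hypothetical.

```python
import numpy as np

def ess_single_chain(x):
    """Effective sample size of one chain via Geyer's initial positive sequence."""
    x = np.asarray(x, dtype=float)
    n = x.size
    xc = x - x.mean()
    # Sample autocovariance at lags 0..n-1, normalized to autocorrelation
    acov = np.correlate(xc, xc, mode="full")[n - 1:] / n
    rho = acov / acov[0]
    # tau = -1 + 2 * sum_k (rho[2k] + rho[2k+1]), stopping at the first negative pair
    tau = -1.0
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau

def ar1(n, phi, seed=0):
    """Autocorrelated test chain: x[t] = phi * x[t-1] + standard normal noise."""
    rng = np.random.default_rng(seed)
    x = np.empty(n)
    x[0] = rng.normal()
    for t in range(1, n):
        x[t] = phi * x[t - 1] + rng.normal()
    return x

iid = np.random.default_rng(1).normal(size=5000)
sticky = ar1(5000, phi=0.9)
print(ess_single_chain(iid), ess_single_chain(sticky))
```

For an AR(1) chain with $\phi=0.9$ the theoretical value is $n(1-\phi)/(1+\phi)$, about 263 of 5000 draws. Keeping only every 10th draw (`sticky[::10]`) lowers the autocorrelation between retained draws but does not raise that number.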

`Rhat` ($\hat R$) is the Gelman-Rubin potential scale reduction factor. It compares the variance between chains with the variance within chains (the split version used by Stan first cuts each chain in half, so drift within a chain is also detected) and estimates how much the spread of the posterior estimates could shrink if sampling continued indefinitely. If $\hat R$ is close to 1, running longer would not be expected to reduce that spread, so the chain has likely (but not certainly) converged.
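The classic split-$\hat R$ (without the rank normalization that recent Stan and ArviZ versions add) can be sketched as follows. It assumes the draws are laid out as a (chains, draws) array; `split_rhat` is a hypothetical name.

```python
import numpy as np

def split_rhat(draws):
    """Classic split-R-hat for an array of shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain into two halves (an odd final draw is dropped)
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    chain_means = split.mean(axis=1)
    W = split.var(axis=1, ddof=1).mean()   # mean within-chain variance
    B = n * chain_means.var(ddof=1)        # between-chain variance
    var_plus = (n - 1) / n * W + B / n     # pooled estimate of the posterior variance
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(0)
stationary = rng.normal(size=(4, 1000))  # four well-mixed chains
drifting = np.linspace(0, 10, 2000)[None, :] + rng.normal(size=(1, 2000))  # one trending chain
print(split_rhat(stationary), split_rhat(drifting))
```

The drifting chain has only one chain, yet splitting it exposes the trend: its two halves have very different means, so $\hat R$ comes out well above 1.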

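```python
# Extract the traces. PyStan 3 returns arrays shaped (parameter dims..., draws),
# so vector parameters are transposed to (draws, C) for the plots below
pi = fit['pi'].T
mu = fit['mu_cl'].T
lp = np.ravel(fit['lp__'])
```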

# Plotting Posteriors

```python
# Define a function that plots the trace and posterior distribution for a given parameter
def plot_trace(param, param_name='parameter'):
    """Plot the trace and posterior of a parameter."""

    # Summary statistics
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, 2.5), np.percentile(param, 97.5)

    # Top panel: trace of the draws with mean, median and 95% interval
    plt.subplot(2, 1, 1)
    plt.plot(param)
    plt.xlabel('samples')
    plt.ylabel(param_name)
    plt.axhline(mean, color='r', lw=2, linestyle='--')
    plt.axhline(median, color='c', lw=2, linestyle='--')
    plt.axhline(cred_min, linestyle=':', color='k', alpha=0.2)
    plt.axhline(cred_max, linestyle=':', color='k', alpha=0.2)
    plt.title('Trace and Posterior Distribution for {}'.format(param_name))

    # Bottom panel: histogram and kernel density estimate of the draws
    plt.subplot(2, 1, 2)
    plt.hist(param, 30, density=True)
    sns.kdeplot(x=param, fill=True)  # `fill` replaces seaborn's deprecated `shade`
    plt.xlabel(param_name)
    plt.ylabel('density')
    plt.axvline(mean, color='r', lw=2, linestyle='--', label='mean')
    plt.axvline(median, color='c', lw=2, linestyle='--', label='median')
    plt.axvline(cred_min, linestyle=':', color='k', alpha=0.2, label='95% CI')
    plt.axvline(cred_max, linestyle=':', color='k', alpha=0.2)

    plt.gcf().tight_layout()
    plt.legend()
```

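```python
# Inspect the posterior mixture weights: clusters with non-negligible
# weight indicate how many clusters the data actually occupy
plt.boxplot(pi)
plt.xlabel('cluster')
plt.ylabel(r'$\pi$')
plt.show()
```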
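To turn the box plot into a number, one option is to count the components whose posterior mean weight exceeds a small threshold. The 0.01 threshold and the function name `count_active_clusters` are assumptions for illustration, not part of the original analysis. The function expects one row per draw and one column per cluster, so PyStan 3's `fit['pi']` must be transposed first.

```python
import numpy as np

def count_active_clusters(weights, threshold=0.01):
    """Return (count, 0-based indices) of clusters whose mean weight exceeds threshold."""
    weights = np.asarray(weights, dtype=float)  # shape (draws, C)
    mean_w = weights.mean(axis=0)
    active = np.flatnonzero(mean_w > threshold)
    # Order the active clusters by decreasing posterior mean weight
    active = active[np.argsort(mean_w[active])[::-1]]
    return active.size, active
```

The returned indices can be used to choose which columns of `mu` to plot instead of hard-coding them; add 1 to match Stan's 1-based labels.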

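```python
# Trace and posterior for the means of the clusters with non-negligible weight.
# Columns are 0-based in Python, labels are 1-based as in Stan; which clusters
# are occupied may differ between runs
plot_trace(mu[:, 0], r'$\mu[1]$')
plt.show()
plot_trace(mu[:, 2], r'$\mu[3]$')
plt.show()
plot_trace(mu[:, 3], r'$\mu[4]$')
plt.show()
plot_trace(lp, 'lp__')
plt.show()
```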