# Gaussian Mixture Model in Stan

In [2]:
from timeit import default_timer as timer
import numpy as np
import pickle

import pystan

from utils import SEED, count_divergences

## 1. Data

In [4]:
def load_datasets(filename='gmm_6k.pkl', n=3):
  """
  Read the first `n` objects that were pickled back to back into one file.

  :param filename: name of the pickle file
  :param n: number of datasets to read (defaults to 3)
  :return: list of the loaded datasets, in the order they are stored
  """
  with open(filename, 'rb') as source:
    # Each pickle.load call consumes exactly one object from the stream.
    return [pickle.load(source) for _ in range(n)]

## 2. Model

In [20]:
# optionally
# from stan_model import recompile_model
# sm = recompile_model()

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_ca775848777f6510647342598992bdde NOW.


In [3]:
# Load the model object `sm` from a pickle file; the inference functions in
# this notebook call `sm.sampling(...)` and `sm.vb(...)` on it.
# NOTE(review): presumably a compiled pystan StanModel saved earlier -- confirm.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
with open('gmm_1d.pkl', 'rb') as f:
    sm = pickle.load(f)

## 3. Inference

### NUTS

In [5]:
def stan_nuts(filename='gmm_6k.pkl', n=3):
  """
  Runs Stan NUTS algorithm on the first `n` datasets stored in `filename`.

  For every dataset the sampling call is timed, the draws are averaged over
  axis 1 of the extracted array, and a summary dict (estimates, wall time,
  sampler settings, divergence count) is pickled to
  'results/stan/nuts_{K}k_{seed}.pkl'. The 'results/stan' directory must
  already exist.

  :param filename: name of the pickle file holding the datasets
  :param n: number of datasets to read (defaults to 3)
  :return: list with one fit object per dataset, in dataset order
  """
  # Sampler settings are named once so the call and the saved metadata agree.
  iters, warmup, thin, chains = 5000, 2000, 10, 4

  # Collected locally: the previous version appended to a `fits` name that
  # was never defined, which raised NameError after the first dataset.
  fits = []

  for dataset in load_datasets(filename, n):
    K = dataset['K']
    stan_data = {'N': dataset['N'], 'K': K, 'y': dataset['y'], 'alpha': [1]*K}

    start = timer()
    fit = sm.sampling(data=stan_data, seed=dataset['seed'], iter=iters,
                      warmup=warmup, thin=thin, chains=chains)
    end = timer()

    # NOTE(review): with permuted=False PyStan is documented to return an
    # (iterations, chains, parameters) array, so axis=1 averages over chains
    # -- confirm against the installed PyStan version.
    e = fit.extract(permuted=False)[:,:,:3*K].mean(axis=1)
    # NOTE(review): assumes the first 3*K flat parameters are mu, sigma and w
    # (K entries each, in that order) -- confirm against the Stan program.
    mu = e[:,:K]
    sigma = e[:,K:2*K]
    w = e[:,2*K:3*K]

    results = {'mu': mu, 'sigma': sigma, 'w': w, 'time': end-start,
               'iters': iters, 'warmup': warmup, 'thin': thin,
               'divergences': int(count_divergences(fit))}
    with open('results/stan/nuts_{}k_{}.pkl'.format(K, dataset['seed']), 'wb') as f:
      pickle.dump(results, f)
    fits.append(fit)

  return fits

In [None]:
# Run NUTS on the first 3 datasets stored in 'gmm_3k.pkl'.
stan_nuts(filename='gmm_3k.pkl', n=3) 

In [None]:
# Run NUTS on the first 3 datasets stored in 'gmm_6k.pkl'.
stan_nuts(filename='gmm_6k.pkl', n=3) 

### ADVI

In [6]:
def stan_vi(filename='gmm_6k.pkl', n=3):
  """
  Runs Stan ADVI algorithm (mean-field) on the first `n` datasets in `filename`.

  Each dataset is fitted once per iteration budget (10 evenly spaced values
  from 500 to 50000). After every fit a summary dict is appended to
  'results/stan/vi_{K}k_{seed}.pkl', so one file holds several pickled
  records back to back.

  :param filename: name of the pickle file holding the datasets
  :param n: number of datasets to read (defaults to 3)
  """
  tolerance = 0.000001
  # The iteration budgets do not depend on the dataset, so build them once.
  iteration_grid = np.linspace(500, 50000, 10).astype(int)

  for ds in load_datasets(filename, n):
    n_comp = ds['K']
    seed = ds['seed']
    model_input = dict(N=ds['N'], K=n_comp, y=ds['y'], alpha=[1]*n_comp)
    out_path = 'results/stan/vi_{}k_{}.pkl'.format(n_comp, seed)

    for n_iter in iteration_grid:
      t0 = timer()
      approx = sm.vb(data=model_input, algorithm='meanfield', iter=n_iter,
                     tol_rel_obj=tolerance, seed=seed, output_samples=1000)
      elapsed = timer() - t0

      # NOTE(review): assumes the first 3*K rows of 'sampler_params' are the
      # draws of mu, sigma and w, in that order -- confirm against the model.
      draws = approx['sampler_params']
      summary = {
        'mu': np.array(draws[:n_comp]),
        'sigma': np.array(draws[n_comp:2*n_comp]),
        'w': np.array(draws[2*n_comp:3*n_comp]),
        'time': elapsed,
        'iters': n_iter,
        'tol': tolerance,
      }
      # Append mode: records for successive iteration budgets accumulate.
      with open(out_path, 'ab') as sink:
        pickle.dump(summary, sink)

In [None]:
# Run ADVI on the first 3 datasets stored in 'gmm_3k.pkl'.
stan_vi(filename='gmm_3k.pkl', n=3)

In [None]:
# Run ADVI on the first 3 datasets stored in 'gmm_6k.pkl'.
stan_vi(filename='gmm_6k.pkl', n=3)

## References

[1] Carpenter, Bob [Mixture models in Stan](http://andrewgelman.com/2017/08/21/mixture-models-stan-can-use-log_mix)

[2] Lieu, Maggie [Multivariate Gaussian Mixture Model done properly](https://maggielieu.com/2017/03/21/multivariate-gaussian-mixture-model-done-properly)

[3] Betancourt, Michael [Identifying Bayesian Mixture Models](http://mc-stan.org/users/documentation/case-studies/identifying_mixture_models.html)
