# One Way Normal Model in Stan

In [None]:
import os
import pickle
from timeit import default_timer as timer

import numpy as np
import pystan

from utils import generate_datasets, count_divergences, SEED, I, SIGMA

## 1. Models

In [None]:
# optionally:
# from stan_model import *
# recompile_centered_model()
# recompile_non_centered_model()

In [None]:
# centered
def stan_model_c():
  """
  Load the pickled, pre-compiled Stan model for the centered parameterization.

  Reads '1wayN_centered.pkl' from the current working directory and closes
  the file afterwards.
  NOTE(review): pickle.load can execute arbitrary code -- only load model
  files you produced yourself.
  """
  with open('1wayN_centered.pkl', 'rb') as f:
    return pickle.load(f)

# non-centered
def stan_model_nc():
  """
  Load the pickled, pre-compiled Stan model for the non-centered parameterization.

  Reads '1wayN_noncentered.pkl' from the current working directory and closes
  the file afterwards.
  NOTE(review): pickle.load can execute arbitrary code -- only load model
  files you produced yourself.
  """
  with open('1wayN_noncentered.pkl', 'rb') as f:
    return pickle.load(f)

## 2. Inference

The NUTS configurations (number of iterations, warmup length, thinning, and `adapt_delta`) are taken from [1].

### NUTS - centered

In [None]:
def stan_nuts_c(seeds=SEED, iters=100000, warmup=5000, thin=1000, adapt_delta=0.999):
  """
  Runs Stan's NUTS algorithm for centered parameterization.

  Default parameters: configuration from the paper [1].

  For each seed, a dataset is generated with `generate_datasets`, the centered
  model is sampled with 4 chains, and a results dict is pickled to
  'results/stan/nuts_c_<seed>.pkl'.

  param seeds: iterable of seeds; one dataset and one sampling run per seed
  param iters: number of iterations per chain (passed as Stan's `iter`)
  param warmup: number of warmup iterations per chain
  param thin: thinning interval
  param adapt_delta: target acceptance statistic for step-size adaptation
  """
  Y, _ = generate_datasets(seeds=seeds)
  sm_c = stan_model_c()
  # Create the output directory up front so a missing folder does not abort
  # the loop after an expensive sampling run.
  os.makedirs('results/stan', exist_ok=True)
  for y, seed in zip(Y, seeds):
    print(seed)
    stan_data = dict(I=I, y=y, sigma=[SIGMA]*I)
    start = timer()
    fit = sm_c.sampling(data=stan_data, iter=iters, thin=thin,
                        warmup=warmup, chains=4, seed=seed,
                        refresh=100000, control=dict(adapt_delta=adapt_delta))
    end = timer()
    # Keep the first two parameters along the last axis (presumably mu, tau --
    # TODO confirm declaration order in the Stan model) and average over axis 1.
    # NOTE(review): per the PyStan 2 docs, axis 1 of extract(permuted=False) is
    # the chain axis, so this averages draws across chains -- confirm intended.
    e = fit.extract(permuted=False)[:, :, :2].mean(axis=1)
    results = {'mu': e[:, 0], 'tau': e[:, 1], 'time': end - start,
               'iters': iters, 'warmup': warmup, 'thin': thin,
               'divergences': int(count_divergences(fit))}
    with open('results/stan/nuts_c_{}.pkl'.format(seed), 'wb') as f:
      pickle.dump(results, f)
  print('Done')

### NUTS - non-centered

In [None]:
def stan_nuts_nc(iters=50000, warmup=5000, adapt_delta=0.8, mode='nominal', seeds=SEED):
  """
  Runs Stan's NUTS algorithm for non-centered parameterization.

  Default parameters: nominal configuration from the paper [1].
  For baseline use: iters=100000, warmup=5000, adapt_delta=0.99, mode='baseline'

  For each seed, a dataset is generated with `generate_datasets`, the
  non-centered model is sampled with 4 chains, and a results dict is pickled to
  'results/stan/nuts_nc_<mode>_<seed>.pkl'.

  param iters: number of iterations per chain (passed as Stan's `iter`)
  param warmup: number of warmup iterations per chain
  param adapt_delta: target acceptance statistic for step-size adaptation
  param mode: label used only in the output file name ('nominal' or 'baseline')
  param seeds: iterable of seeds; one dataset and one sampling run per seed
  """
  Y, _ = generate_datasets(seeds=seeds)
  sm_nc = stan_model_nc()
  # Create the output directory up front so a missing folder does not abort
  # the loop after an expensive sampling run.
  os.makedirs('results/stan', exist_ok=True)
  for y, seed in zip(Y, seeds):
    print(seed)
    stan_data = dict(I=I, y=y, sigma=[SIGMA]*I)
    start = timer()
    # Sample the non-centered model (previously referenced the undefined `sm_c`).
    fit = sm_nc.sampling(data=stan_data, iter=iters, warmup=warmup, chains=4, seed=seed,
                         refresh=100000, control=dict(adapt_delta=adapt_delta))
    end = timer()
    # Keep the first two parameters along the last axis (presumably mu, tau --
    # TODO confirm declaration order in the Stan model) and average over axis 1.
    # NOTE(review): per the PyStan 2 docs, axis 1 of extract(permuted=False) is
    # the chain axis, so this averages draws across chains -- confirm intended.
    e = fit.extract(permuted=False)[:, :, :2].mean(axis=1)
    results = {'mu': e[:, 0], 'tau': e[:, 1], 'time': end - start,
               'iters': iters, 'warmup': warmup,
               'divergences': int(count_divergences(fit))}
    with open('results/stan/nuts_nc_{}_{}.pkl'.format(mode, seed), 'wb') as f:
      pickle.dump(results, f)
  print('Done')

In [None]:
# Run NUTS on the centered parameterization for every seed (paper configuration).
stan_nuts_c(seeds=SEED)

In [None]:
# Run NUTS on the non-centered parameterization with the nominal configuration.
stan_nuts_nc(mode='nominal')

In [None]:
# Run NUTS on the non-centered parameterization with the baseline configuration.
stan_nuts_nc(
  iters=100000,
  warmup=5000,
  adapt_delta=0.99,
  mode='baseline',
  seeds=SEED,
)

## Variational Inference (ADVI)

In [None]:
def stan_vi(mode='c', seeds=SEED):
  """
  Runs Stan's ADVI algorithm (meanfield approximation).

  For each seed, a dataset is generated and ADVI is run once for each of 5
  iteration budgets evenly spaced between 50,000 and 500,000. Each run's
  results dict is appended as a separate pickle record to
  'results/stan/vi5_<mode>_<seed>.pkl'.

  param mode: if 'c' use the centered parameterization, if 'nc' the non-centered
  param seeds: iterable of seeds; one dataset per seed
  raises ValueError: if mode is neither 'c' nor 'nc'
  """
  if mode not in ('c', 'nc'):
    # A bare string cannot be raised in Python 3 (it would surface as a TypeError).
    raise ValueError("Mode has to be 'c' for centered or 'nc' for non-centered model!")

  model = stan_model_c() if mode == 'c' else stan_model_nc()
  Y, _ = generate_datasets(seeds=seeds)
  # Iteration budgets and tolerance are identical for every seed: compute once.
  iters = np.linspace(50000, 500000, 5).astype(int)
  tol = 0.0001
  os.makedirs('results/stan', exist_ok=True)

  for seed, y in zip(seeds, Y):
    stan_data = dict(I=I, y=y, sigma=[SIGMA]*I)
    print(seed)
    for it in iters:
      start = timer()
      fit = model.vb(data=stan_data, algorithm='meanfield', iter=it,
                     tol_rel_obj=tol, seed=seed, output_samples=1000)
      end = timer()

      # presumably sampler_params[0] / [1] hold the draws of mu / tau --
      # TODO confirm ordering against the Stan model's parameter block.
      e = fit['sampler_params']
      results = {'iters': it, 'tol': tol, 'time': end - start, 'mu': e[0], 'tau': e[1]}

      # 'ab' appends one pickle record per run, so a reader must call
      # pickle.load repeatedly. NOTE(review): re-running appends duplicates to
      # existing files -- delete old results first if that is not wanted.
      with open('results/stan/vi5_{}_{}.pkl'.format(mode, seed), 'ab') as f:
        pickle.dump(results, f)
  print('Done')

In [None]:
# Run ADVI on the centered parameterization for every seed.
stan_vi(mode='c')

In [None]:
# Run ADVI on the non-centered parameterization for every seed.
stan_vi(mode='nc', seeds=SEED)

### References

[1] Betancourt, Michael J. and Girolami, Mark. Hamiltonian Monte Carlo for Hierarchical Models. 2013.