# Donut in Stan

In [None]:
from timeit import default_timer as timer
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
import pickle
import numpy as np

import pystan as ps

from utils import generate_datasets, count_divergences, SEED

In [None]:
# Show which PyStan version is in use, so the recorded results can be traced back to it.
pystan_version = ps.__version__
print(pystan_version)

## 1. Model

In [None]:
# Optionally (uncomment to use): recompile the Stan model via stan_model.recompile_model()
# from stan_model import *
# recompile_model()

In [None]:
# Load the pickled, pre-compiled Stan model. Use a context manager so the file
# handle is closed. Only unpickle files you produced yourself: pickle can
# execute arbitrary code on load.
with open('sm_donut.pkl', 'rb') as model_file:
  sm = pickle.load(model_file)

## 2. Inference

### NUTS

In [None]:
def stan_nuts(model, n_samples, n_dim, seeds=SEED):
  """
  Run Stan's NUTS sampler once per seed and pickle a summary of each run.

  For every seed, the dataset returned by `generate_datasets` for that seed is
  sampled with NUTS (4 chains, adapt_delta=0.99). A dict holding the iteration
  counts, the number of divergent transitions and the wall-clock time is
  written to `results/stan/nuts_{n_dim}d_{seed}`.

  Parameters
  ----------
  model : compiled Stan model exposing `.sampling(...)` (presumably a
      pystan 2 StanModel -- confirm)
  n_samples : number of observations per dataset (passed to Stan as N)
  n_dim : dimensionality of the observations (passed to Stan as D)
  seeds : iterable of seeds; one dataset and one run per seed
  """
  import os

  n_iter, n_warmup = 2000, 1000
  # Only the observations are needed here; the other returned values are unused.
  Y, _, _, _ = generate_datasets(n_samples, n_dim, seeds)
  # Create the output directory up front so a long run is not lost on the final write.
  os.makedirs('results/stan', exist_ok=True)

  for seed, y in zip(seeds, Y):
    stan_data = dict(N=n_samples, D=n_dim, y=y)

    print(seed)
    start = timer()
    # Seed the sampler (as stan_vi does) so results stored under this seed are reproducible.
    fit = model.sampling(data=stan_data, iter=n_iter, warmup=n_warmup,
                         chains=4, seed=seed,
                         control=dict(adapt_delta=0.99))
    end = timer()
    results = {'iters': n_iter, 'warmup': n_warmup,
               'divergences': int(count_divergences(fit)),
               'time': end - start}
    # NOTE: unlike the VI output, this filename has no .pkl suffix; it is kept
    # unchanged so existing readers still find it.
    with open('results/stan/nuts_{}d_{}'.format(n_dim, seed), 'wb') as f:
      pickle.dump(results, f)
  print('Done')

In [None]:
# Small dataset: 1000 observations in 2 dimensions, sampled with NUTS
n_samples, n_dim = 1000, 2
stan_nuts(sm, n_samples, n_dim)

In [None]:
# Big dataset: 5000 observations in 5 dimensions, sampled with NUTS
n_samples, n_dim = 5000, 5
stan_nuts(sm, n_samples, n_dim)

### VI

In [None]:
def stan_vi(model, n_samples, n_dim, seeds=SEED):
  """
  Run Stan's ADVI (mean-field approximation) per seed for several iteration budgets.

  For every seed, the dataset returned by `generate_datasets` is fitted with
  `model.vb` for 5 evenly spaced iteration budgets between 1000 and 50000
  (tol_rel_obj=0.0001, 1000 output samples). Each budget produces one dict
  with the budget, tolerance, wall-clock time and the draws for R, r and
  C_0..C_{n_dim-1}.

  All dicts for one seed are pickled one after another into
  `results/stan/vi_{n_dim}d_{seed}.pkl`. Read them back with repeated
  `pickle.load` calls until EOFError. The file is truncated at the start of
  each seed's run, so re-running does not mix in stale records.

  Parameters
  ----------
  model : compiled Stan model exposing `.vb(...)` (presumably a pystan 2
      StanModel -- confirm)
  n_samples : number of observations per dataset (passed to Stan as N)
  n_dim : dimensionality of the observations (passed to Stan as D)
  seeds : iterable of seeds; one dataset and one set of runs per seed
  """
  import os

  tol = 0.0001
  # Plain Python ints for the iteration budgets (numpy ints are avoided when passed to Stan).
  iters = np.linspace(1000, 50000, 5).astype(int).tolist()
  # Only the observations are needed here; the other returned values are unused.
  Y, _, _, _ = generate_datasets(n_samples, n_dim, seeds)
  os.makedirs('results/stan', exist_ok=True)

  for seed, y in zip(seeds, Y):
    stan_data = dict(N=n_samples, D=n_dim, y=y)
    out_path = 'results/stan/vi_{}d_{}.pkl'.format(n_dim, seed)
    # Start each seed from an empty file. Records are appended below, and
    # without this a re-run would stack new results on top of old ones.
    open(out_path, 'wb').close()

    print(seed)
    for it in iters:
      start = timer()
      fit = model.vb(data=stan_data, algorithm='meanfield', iter=it,
                     tol_rel_obj=tol, seed=seed, output_samples=1000)
      end = timer()

      # NOTE(review): assumes sampler_params lists draws in the model's
      # parameter order R, r, C[0..n_dim-1] -- confirm against the Stan model
      # and fit['sampler_param_names'].
      e = fit['sampler_params']
      results = {'iters': it, 'tol': tol, 'time': end - start,
                 'R': e[0], 'r': e[1]}
      for i in range(n_dim):
        results['C_{}'.format(i)] = e[2 + i]

      # Append and close per budget, so finished budgets survive a later crash.
      with open(out_path, 'ab') as f:
        pickle.dump(results, f)
  print('Done')

In [None]:
# Small dataset: 1000 observations in 2 dimensions, fitted with mean-field ADVI
n_samples, n_dim = 1000, 2
stan_vi(sm, n_samples, n_dim)

In [None]:
# Big dataset: 5000 observations in 5 dimensions, fitted with mean-field ADVI
n_samples, n_dim = 5000, 5
stan_vi(sm, n_samples, n_dim)