Skip to content

Commit

Permalink
Merge pull request #1 from JohannesBuchner/feature-moreutils
Browse files Browse the repository at this point in the history
extracting flat posterior
  • Loading branch information
JohannesBuchner committed Jun 4, 2020
2 parents 914edcf + 9d4d8d6 commit 9ae0ee0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion stan_utility/__init__.py
@@ -1,4 +1,4 @@
from .utils import compile_model, compile_model_code, check_all_diagnostics, sample_model, plot_corner
from .utils import compile_model, compile_model_code, check_all_diagnostics, sample_model, plot_corner, get_flat_posterior
from . import cache

# from stan_utility.stan_generator import StanGenerator
Expand Down
26 changes: 18 additions & 8 deletions stan_utility/utils.py
@@ -1,10 +1,12 @@
import pystan
import pickle
import numpy
import os
import hashlib
import re
import warnings
import collections

import pystan
import arviz

from stan_utility.cache import get_path as get_path_of_cache
Expand Down Expand Up @@ -368,36 +370,44 @@ def sample_model(model, data, outprefix=None, **kwargs):

return fit

def get_flat_posterior(results):
la = results.posterior.data_vars
flat_posterior = collections.OrderedDict()
for k, v in la.items():
a = v.data
newshape = tuple([a.shape[0] * a.shape[1]] + list(a.shape)[2:])
flat_posterior[k] = v.data.reshape(newshape)
return flat_posterior

def plot_corner(samples, outprefix=None, **kwargs):
def plot_corner(results, outprefix=None, **kwargs):
"""
Store a simple corner plot in outprefix_corner.pdf, based on samples
extracted from fit.
Additional kwargs are passed to MCSamples.
"""
la = samples.posterior.data_vars
la = get_flat_posterior(results)
samples = []
paramnames = []
badlist = ['lp__']
badlist += [k for k in la.keys() if 'log' in k and k.replace('log', '') in la.keys()]

for k in sorted(la.keys()):
print('%20s: %.4f +- %.4f' % (k, la[k].mean(), la[k].std()))
if la[k].ndim == 2 and k not in badlist:
samples.append(la[k].data.flatten())
if k not in badlist and la[k].ndim == 2:
samples.append(la[k])
paramnames.append(k)

if len(samples) == 0:
arrays = [k for k in sorted(la.keys()) if la[k].ndim == 3 and la[k].shape[2] <= 20 and k not in badlist]
arrays = [k for k in la.keys() if la[k].ndim == 3 and la[k].shape[2] <= 20 and k not in badlist]
if len(arrays) != 1:
warnings.warn("no scalar variables found")
return

k = arrays[0]
# flatten across chains and column for each variable
samples = numpy.rollaxis(la[k].data, 2).reshape((la[k].shape[2], -1))
paramnames = ['%s[%d]' % (k, i + 1) for i in range(la[k].shape[2])]
samples = la[k]
paramnames = ['%s[%d]' % (k, i + 1) for i in range(la[k].shape[1])]

samples = numpy.transpose(samples)
import matplotlib.pyplot as plt
Expand Down
8 changes: 7 additions & 1 deletion tests/test_compile.py
Expand Up @@ -47,7 +47,7 @@ def test_compile_string():
)
if os.path.exists("mytest_fitfit.hdf5"):
os.unlink("mytest_fitfit.hdf5")
samples = stan_utility.sample_model(model, data, outprefix="mytest_fit", chains=1)
samples = stan_utility.sample_model(model, data, outprefix="mytest_fit", chains=2, iter=346)
assert os.path.exists("mytest_fitfit.hdf5")
os.unlink("mytest_fitfit.hdf5")

Expand All @@ -56,6 +56,12 @@ def test_compile_string():
stan_utility.plot_corner(samples, outprefix="mytest_fit")
assert os.path.exists("mytest_fit_corner.pdf")
os.unlink("mytest_fit_corner.pdf")

flat_samples = stan_utility.get_flat_posterior(samples)
assert set(flat_samples.keys()) == {"x", "y"}, flat_samples.keys()
assert flat_samples['x'].shape == (346,), flat_samples['x'].shape
assert flat_samples['y'].shape == (346, 10), flat_samples['y'].shape



if __name__ == '__main__':
Expand Down

0 comments on commit 9ae0ee0

Please sign in to comment.