From a0ca4a8af507652fbd12704a0553b233eac4f0de Mon Sep 17 00:00:00 2001 From: Johannes Buchner Date: Thu, 4 Jun 2020 17:52:30 +0200 Subject: [PATCH 1/3] add function for getting the flattened posterior --- stan_utility/__init__.py | 2 +- stan_utility/utils.py | 27 +++++++++++++++++++-------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/stan_utility/__init__.py b/stan_utility/__init__.py index e51b747..574516d 100644 --- a/stan_utility/__init__.py +++ b/stan_utility/__init__.py @@ -1,4 +1,4 @@ -from .utils import compile_model, compile_model_code, check_all_diagnostics, sample_model, plot_corner +from .utils import compile_model, compile_model_code, check_all_diagnostics, sample_model, plot_corner, get_flat_posterior from . import cache # from stan_utility.stan_generator import StanGenerator diff --git a/stan_utility/utils.py b/stan_utility/utils.py index 1dc3545..39e41a6 100644 --- a/stan_utility/utils.py +++ b/stan_utility/utils.py @@ -1,10 +1,12 @@ -import pystan import pickle import numpy import os import hashlib import re import warnings +import collections + +import pystan import arviz from stan_utility.cache import get_path as get_path_of_cache @@ -368,15 +370,24 @@ def sample_model(model, data, outprefix=None, **kwargs): return fit - -def plot_corner(samples, outprefix=None, **kwargs): +def get_flat_posterior(results): + la = results.posterior.data_vars + flat_posterior = collections.OrderedDict() + for k, v in la.items(): + a = v.data + b = numpy.rollaxis(a, -1) + newshape = tuple([b.shape[0] * b.shape[1]] + list(b.shape)[2:]) + flat_posterior[k] = v.data.transpose().reshape(newshape) + return flat_posterior + +def plot_corner(results, outprefix=None, **kwargs): """ Store a simple corner plot in outprefix_corner.pdf, based on samples extracted from fit. Additional kwargs are passed to MCSamples. """ - la = samples.posterior.data_vars + la = get_flat_posterior(results) samples = [] paramnames = [] badlist = ['lp__'] @@ -384,19 +395,19 @@ def plot_corner(samples, outprefix=None, **kwargs): for k in sorted(la.keys()): print('%20s: %.4f +- %.4f' % (k, la[k].mean(), la[k].std())) - if la[k].ndim == 2 and k not in badlist: - samples.append(la[k].data.flatten()) + if k not in badlist and la[k].ndim == 2: + samples.append(la[k]) paramnames.append(k) if len(samples) == 0: - arrays = [k for k in sorted(la.keys()) if la[k].ndim == 3 and la[k].shape[2] <= 20 and k not in badlist] + arrays = [k for k in la.keys() if la[k].ndim == 3 and la[k].shape[2] <= 20 and k not in badlist] if len(arrays) != 1: warnings.warn("no scalar variables found") return k = arrays[0] # flatten across chains and column for each variable - samples = numpy.rollaxis(la[k].data, 2).reshape((la[k].shape[2], -1)) + samples = numpy.rollaxis(la[k], 2).reshape((la[k].shape[2], -1)) paramnames = ['%s[%d]' % (k, i + 1) for i in range(la[k].shape[2])] samples = numpy.transpose(samples) From c8ee78bafc584d46d977095bde1a3585e5c3ff79 Mon Sep 17 00:00:00 2001 From: Johannes Buchner Date: Thu, 4 Jun 2020 18:12:30 +0200 Subject: [PATCH 2/3] fix shapes returned by get_flat_posterior to be chain iter first --- stan_utility/utils.py | 5 ++--- tests/test_compile.py | 8 +++++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/stan_utility/utils.py b/stan_utility/utils.py index 39e41a6..7dfe7c5 100644 --- a/stan_utility/utils.py +++ b/stan_utility/utils.py @@ -375,9 +375,8 @@ def get_flat_posterior(results): flat_posterior = collections.OrderedDict() for k, v in la.items(): a = v.data - b = numpy.rollaxis(a, -1) - newshape = tuple([b.shape[0] * b.shape[1]] + list(b.shape)[2:]) - flat_posterior[k] = v.data.transpose().reshape(newshape) + newshape = tuple([a.shape[0] * a.shape[1]] + list(a.shape)[2:]) + flat_posterior[k] = v.data.reshape(newshape) return flat_posterior def plot_corner(results, outprefix=None, **kwargs): diff --git a/tests/test_compile.py b/tests/test_compile.py index f7d1f00..3d3f0b5 100644 --- a/tests/test_compile.py +++ b/tests/test_compile.py @@ -47,7 +47,7 @@ def test_compile_string(): ) if os.path.exists("mytest_fitfit.hdf5"): os.unlink("mytest_fitfit.hdf5") - samples = stan_utility.sample_model(model, data, outprefix="mytest_fit", chains=1) + samples = stan_utility.sample_model(model, data, outprefix="mytest_fit", chains=2, iter=346) assert os.path.exists("mytest_fitfit.hdf5") os.unlink("mytest_fitfit.hdf5") @@ -56,6 +56,12 @@ def test_compile_string(): stan_utility.plot_corner(samples, outprefix="mytest_fit") assert os.path.exists("mytest_fit_corner.pdf") os.unlink("mytest_fit_corner.pdf") + + flat_samples = stan_utility.get_flat_posterior(samples) + assert set(flat_samples.keys()) == {"x", "y"}, flat_samples.keys() + assert flat_samples['x'].shape == (346,), flat_samples['x'].shape + assert flat_samples['y'].shape == (346, 10), flat_samples['y'].shape + if __name__ == '__main__': From 9d4d8d68c314bc0c83bce49a6f63ef7fd47c80b1 Mon Sep 17 00:00:00 2001 From: Johannes Buchner Date: Thu, 4 Jun 2020 18:18:57 +0200 Subject: [PATCH 3/3] avoid double transpose --- stan_utility/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stan_utility/utils.py b/stan_utility/utils.py index 7dfe7c5..3848bf4 100644 --- a/stan_utility/utils.py +++ b/stan_utility/utils.py @@ -406,8 +406,8 @@ def plot_corner(results, outprefix=None, **kwargs): k = arrays[0] # flatten across chains and column for each variable - samples = numpy.rollaxis(la[k], 2).reshape((la[k].shape[2], -1)) - paramnames = ['%s[%d]' % (k, i + 1) for i in range(la[k].shape[2])] + samples = la[k] + paramnames = ['%s[%d]' % (k, i + 1) for i in range(la[k].shape[1])] samples = numpy.transpose(samples) import matplotlib.pyplot as plt