Skip to content

Commit

Permalink
Merge 293b2e3 into 46b52c5
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartikainen committed Oct 13, 2018
2 parents 46b52c5 + 293b2e3 commit a07e716
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 4 deletions.
28 changes: 28 additions & 0 deletions arviz/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,31 @@ def load_cached_models(draws, chains):
with open(path, "rb") as buff:
models[library.__name__] = pickle.load(buff)
return models


def pystan_extract_unpermuted(fit, var_names=None):
"""Extract PyStan samples unpermuted.
Function return everything as a float.
"""
if var_names is None:
var_names = fit.model_pars
extract = fit.extract(var_names, permuted=False)
if not isinstance(extract, dict):
extract_permuted = fit.extract(var_names, permuted=True)
permutation_order = fit.sim["permutation"]
ary_order = []
for order in permutation_order:
order = np.argsort(order) + len(ary_order)
ary_order.extend(list(order))
nchain = fit.sim["chains"]
extract = {}
for key, ary in extract_permuted.items():
ary = np.asarray(ary)[ary_order]
if ary.shape:
ary_shape = ary.shape[1:]
else:
ary_shape = ary.shape
ary = ary.reshape((-1, nchain, *ary_shape), order="F")
extract[key] = ary
return extract
41 changes: 37 additions & 4 deletions arviz/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from_pystan,
from_emcee,
)
from .helpers import eight_schools_params, load_cached_models, BaseArvizTest
from .helpers import (
eight_schools_params,
load_cached_models,
BaseArvizTest,
pystan_extract_unpermuted,
)


class TestNumpyToDataArray:
Expand Down Expand Up @@ -156,7 +161,7 @@ def setup_class(cls):
cls.data = eight_schools_params()
cls.draws, cls.chains = 500, 2
_, stan_fit = load_cached_models(cls.draws, cls.chains)["pystan"]
stan_dict = stan_fit.extract(stan_fit.model_pars, permuted=False)
stan_dict = pystan_extract_unpermuted(stan_fit)
cls.obj = {}
for name, vals in stan_dict.items():
if name not in {"y_hat", "log_lik"}: # extra vars
Expand Down Expand Up @@ -256,8 +261,12 @@ def setup_class(cls):
cls.model, cls.obj = load_cached_models(cls.draws, cls.chains)["pystan"]

def get_inference_data(self):
"""log_likelihood as a var."""
prior = pystan_extract_unpermuted(self.obj)
prior = {"theta_test": prior["theta"]}
return from_pystan(
fit=self.obj,
prior=prior,
posterior_predictive="y_hat",
observed_data=["y"],
log_likelihood="log_lik",
Expand All @@ -272,13 +281,14 @@ def get_inference_data(self):
)

def get_inference_data2(self):
"""log_likelihood as a ndarray."""
# dictionary
observed_data = {"y_hat": self.data["y"]}
# ndarray
log_likelihood = self.obj.extract("log_lik", permuted=False)["log_lik"]
log_likelihood = pystan_extract_unpermuted(self.obj, "log_lik")["log_lik"]
return from_pystan(
fit=self.obj,
posterior_predictive="y_hat",
posterior_predictive=["y_hat"],
observed_data=observed_data,
log_likelihood=log_likelihood,
coords={"school": np.arange(self.data["J"])},
Expand All @@ -291,17 +301,40 @@ def get_inference_data2(self):
},
)

def get_inference_data3(self):
"""log_likelihood as a ndarray."""
# ndarray
log_likelihood = pystan_extract_unpermuted(self.obj, "log_lik")["log_lik"]
return from_pystan(
fit=self.obj,
posterior_predictive=["y_hat"],
observed_data=["y"],
log_likelihood=log_likelihood,
coords={"school": np.arange(self.data["J"])},
dims={
"theta": ["school"],
"y": ["school"],
"log_lik": ["school"],
"y_hat": ["school"],
"theta_tilde": ["school"],
},
)

def test_sampler_stats(self):
inference_data = self.get_inference_data()
assert hasattr(inference_data, "sample_stats")

def test_inference_data(self):
inference_data1 = self.get_inference_data()
inference_data2 = self.get_inference_data2()
inference_data3 = self.get_inference_data3()
assert hasattr(inference_data1.sample_stats, "log_likelihood")
assert hasattr(inference_data1.prior, "theta_test")
assert hasattr(inference_data1.observed_data, "y")
assert hasattr(inference_data2.sample_stats, "log_likelihood")
assert hasattr(inference_data2.observed_data, "y_hat")
assert hasattr(inference_data3.sample_stats, "log_likelihood")
assert hasattr(inference_data3.observed_data, "y")


class TestCmdStanNetCDFUtils(BaseArvizTest):
Expand Down

0 comments on commit a07e716

Please sign in to comment.