From 4815daabe9db51eb5cbe8062678722d30077b873 Mon Sep 17 00:00:00 2001 From: Oriol Abril Date: Mon, 21 Oct 2019 23:01:48 +0200 Subject: [PATCH] fix #822 and add tests to prevent it (#823) --- arviz/tests/test_data_pymc.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/arviz/tests/test_data_pymc.py b/arviz/tests/test_data_pymc.py index eb17398af7..9c67cc680e 100644 --- a/arviz/tests/test_data_pymc.py +++ b/arviz/tests/test_data_pymc.py @@ -152,3 +152,29 @@ def test_constant_data(self): test_dict = {"posterior": ["beta"], "observed_data": ["obs"], "constant_data": ["x"]} fails = check_multiple_attrs(test_dict, inference_data) assert not fails + + def test_no_trace(self): + with pm.Model(): + x = pm.Data("x", [1.0, 2.0, 3.0]) + y = pm.Data("y", [1.0, 2.0, 3.0]) + beta = pm.Normal("beta", 0, 1) + obs = pm.Normal("obs", x * beta, 1, observed=y) # pylint: disable=unused-variable + trace = pm.sample(100, tune=100) + prior = pm.sample_prior_predictive() + posterior_predictive = pm.sample_posterior_predictive(trace) + + # Only prior + inference_data = from_pymc3(prior=prior) + test_dict = {"prior": ["beta", "obs"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + # Only posterior_predictive + inference_data = from_pymc3(posterior_predictive=posterior_predictive) + test_dict = {"posterior_predictive": ["obs"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + # Prior and posterior_predictive but no trace + inference_data = from_pymc3(prior=prior, posterior_predictive=posterior_predictive) + test_dict = {"prior": ["beta", "obs"], "posterior_predictive": ["obs"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails