Skip to content

Commit

Permalink
increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Jun 8, 2020
1 parent b2d924e commit 3ced13b
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions arviz/tests/external_tests/test_data_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,13 @@ def test_inference_data_num_chains(self, predictions_data, chains):
nchains = inference_data.predictions.dims["chain"]
assert nchains == chains

def test_log_likelihood_warning(self):
@pytest.mark.parametrize("log_likelihood", [True, False])
def test_log_likelihood(self, log_likelihood):
"""Test behaviour when log likelihood cannot be retrieved.
If log_likelihood=True there is a warning to say log_likelihood group is skipped,
if log_likelihood=False there is no warning and log_likelihood is skipped.
"""
x = torch.randn((10, 2))
y = torch.randn(10)

Expand All @@ -239,8 +245,11 @@ def model_constant_data(x, y=None):
nuts_kernel = pyro.infer.NUTS(model_constant_data)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
mcmc.run(x=x, y=y)
with pytest.warns(UserWarning, match="Could not get vectorized trace"):
inference_data = from_pyro(mcmc)
if log_likelihood:
with pytest.warns(UserWarning, match="Could not get vectorized trace"):
inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
else:
inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
test_dict = {
"posterior": ["beta"],
"sample_stats": ["diverging"],
Expand Down

0 comments on commit 3ced13b

Please sign in to comment.