Skip to content

Commit

Permalink
Merge a778bff into cf63185
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Nov 14, 2019
2 parents cf63185 + a778bff commit b1c1625
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
17 changes: 8 additions & 9 deletions arviz/stats/stats.py
Expand Up @@ -485,14 +485,10 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
"""
inference_data = convert_to_inference_data(data)
for group in ("posterior", "sample_stats"):
if not hasattr(inference_data, group):
raise TypeError(
"Must be able to extract a {group} group from data!".format(group=group)
)
if not hasattr(inference_data, "sample_stats"):
raise TypeError("Must be able to extract a sample_stats group from data.")
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
posterior = inference_data.posterior
log_likelihood = inference_data.sample_stats.log_likelihood
log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
shape = log_likelihood.shape
Expand All @@ -509,6 +505,9 @@ def loo(data, pointwise=False, reff=None, scale="deviance"):
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')

if reff is None:
if not hasattr(inference_data, "posterior"):
raise TypeError("Must be able to extract a posterior group from data.")
posterior = inference_data.posterior
n_chains = len(posterior.chain)
if n_chains == 1:
reff = 1.0
Expand Down Expand Up @@ -907,12 +906,12 @@ def summary(

fmt_group = ("wide", "long", "xarray")
if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
raise TypeError("Invalid format: '{}'! Formatting options are: {}".format(fmt, fmt_group))
raise TypeError("Invalid format: '{}'. Formatting options are: {}".format(fmt, fmt_group))

unpack_order_group = ("C", "F")
if not isinstance(order, str) or (order.upper() not in unpack_order_group):
raise TypeError(
"Invalid order: '{}'! Unpacking options are: {}".format(order, unpack_order_group)
"Invalid order: '{}'. Unpacking options are: {}".format(order, unpack_order_group)
)

alpha = 1 - credible_interval
Expand Down Expand Up @@ -1148,7 +1147,7 @@ def waic(data, pointwise=False, scale="deviance"):
for group in ("sample_stats",):
if not hasattr(inference_data, group):
raise TypeError(
"Must be able to extract a {group} group from data!".format(group=group)
"Must be able to extract a {group} group from data.".format(group=group)
)
if "log_likelihood" not in inference_data.sample_stats:
raise TypeError("Data must include log_likelihood in sample_stats")
Expand Down
9 changes: 9 additions & 0 deletions arviz/tests/test_stats.py
Expand Up @@ -321,6 +321,15 @@ def test_loo_bad_scale(centered_eight):
loo(centered_eight, scale="bad_scale")


def test_loo_bad_no_posterior_reff(centered_eight):
loo(centered_eight, reff=None)
centered_eight = deepcopy(centered_eight)
del centered_eight.posterior
with pytest.raises(TypeError):
loo(centered_eight, reff=None)
loo(centered_eight, reff=0.7)


def test_loo_warning(centered_eight):
centered_eight = deepcopy(centered_eight)
# make one of the khats infinity
Expand Down

0 comments on commit b1c1625

Please sign in to comment.